Skip to content

Commit

Permalink
Merge pull request #100 from constantinpape/livecell-cell-types
Browse files Browse the repository at this point in the history
Update livecell.py
  • Loading branch information
constantinpape committed Nov 30, 2022
2 parents 0cec1e7 + 17bcc39 commit 692275d
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions torch_em/data/datasets/livecell.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _download_annotation_file(path, split, download):
return annotation_file


def _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder):
def _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types):
# TODO try except to explain how to install this
from pycocotools.coco import COCO

Expand All @@ -62,6 +62,11 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo
# get the path for the image data and make sure the corresponding image exists
image_metadata = coco.loadImgs(image_id)[0]
file_name = image_metadata["file_name"]

# if cell_type names are given we only select file names that match a cell_type
if cell_types is not None and (not any([cell_type in file_name for cell_type in cell_types])):
continue

sub_folder = file_name.split("_")[0]
image_path = os.path.join(image_folder, sub_folder, file_name)
# something changed in the image layout? we keep the old version around in case this chagnes back...
Expand Down Expand Up @@ -91,10 +96,12 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo

imageio.imwrite(seg_path, seg)

assert len(image_paths) == len(seg_paths)
assert len(image_paths) > 0, f"No matching image paths were found. Did you pass invalid cell type naems ({cell_types})?"
return image_paths, seg_paths


def _download_livecell_annotations(path, split, download):
def _download_livecell_annotations(path, split, download, cell_types):
annotation_file = _download_annotation_file(path, split, download)
if split == "test":
split_name = "livecell_test_images"
Expand All @@ -106,7 +113,7 @@ def _download_livecell_annotations(path, split, download):
seg_folder = os.path.join(path, "annotations", split_name)
assert os.path.exists(image_folder), image_folder

return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder)
return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types)


def _livecell_segmentation_loader(
Expand Down Expand Up @@ -145,13 +152,16 @@ def _livecell_segmentation_loader(
return loader


# TODO add support for the different benchmark tasks
def get_livecell_loader(path, patch_shape, split, download=False,
offsets=None, boundaries=False, binary=False,
**kwargs):
cell_types=None, **kwargs):
assert split in ("train", "val", "test")
if cell_types is not None:
assert isinstance(cell_types, (list, tuple)),\
f"cell_types must be passed as a list or tuple instead of {cell_types}"

_download_livecell_images(path, download)
image_paths, seg_paths = _download_livecell_annotations(path, split, download)
image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types)

assert sum((offsets is not None, boundaries, binary)) <= 1
label_dtype = torch.int64
Expand Down

0 comments on commit 692275d

Please sign in to comment.