From 17bcc3943f4c8302a51c3a1cc302a3af75c5eba7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 30 Nov 2022 18:10:07 +0100 Subject: [PATCH] Update livecell.py --- torch_em/data/datasets/livecell.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/torch_em/data/datasets/livecell.py b/torch_em/data/datasets/livecell.py index e605cf4b..12920b43 100644 --- a/torch_em/data/datasets/livecell.py +++ b/torch_em/data/datasets/livecell.py @@ -49,7 +49,7 @@ def _download_annotation_file(path, split, download): return annotation_file -def _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder): +def _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types): # TODO try except to explain how to install this from pycocotools.coco import COCO @@ -62,6 +62,11 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo # get the path for the image data and make sure the corresponding image exists image_metadata = coco.loadImgs(image_id)[0] file_name = image_metadata["file_name"] + + # if cell_type names are given we only select file names that match a cell_type + if cell_types is not None and (not any([cell_type in file_name for cell_type in cell_types])): + continue + sub_folder = file_name.split("_")[0] image_path = os.path.join(image_folder, sub_folder, file_name) # something changed in the image layout? we keep the old version around in case this chagnes back... @@ -91,10 +96,12 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo imageio.imwrite(seg_path, seg) + assert len(image_paths) == len(seg_paths) + assert len(image_paths) > 0, f"No matching image paths were found. Did you pass invalid cell type naems ({cell_types})?" return image_paths, seg_paths -def _download_livecell_annotations(path, split, download): +def _download_livecell_annotations(path, split, download, cell_types): annotation_file = _download_annotation_file(path, split, download) if split == "test": split_name = "livecell_test_images" @@ -106,7 +113,7 @@ def _download_livecell_annotations(path, split, download): seg_folder = os.path.join(path, "annotations", split_name) assert os.path.exists(image_folder), image_folder - return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder) + return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types) def _livecell_segmentation_loader( @@ -145,13 +152,16 @@ def _livecell_segmentation_loader( return loader -# TODO add support for the different benchmark tasks def get_livecell_loader(path, patch_shape, split, download=False, offsets=None, boundaries=False, binary=False, - **kwargs): + cell_types=None, **kwargs): assert split in ("train", "val", "test") + if cell_types is not None: + assert isinstance(cell_types, (list, tuple)),\ + f"cell_types must be passed as a list or tuple instead of {cell_types}" + _download_livecell_images(path, download) - image_paths, seg_paths = _download_livecell_annotations(path, split, download) + image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types) assert sum((offsets is not None, boundaries, binary)) <= 1 label_dtype = torch.int64