From e4ba975d676f272e5b26b81171c8e60086a030e6 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 25 Oct 2023 23:40:56 +0200 Subject: [PATCH 1/3] Add test split in bcss dataset --- scripts/datasets/check_bcss.py | 3 +- torch_em/data/datasets/bcss.py | 85 +++++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/scripts/datasets/check_bcss.py b/scripts/datasets/check_bcss.py index 805ad7b8..384c6854 100644 --- a/scripts/datasets/check_bcss.py +++ b/scripts/datasets/check_bcss.py @@ -42,7 +42,8 @@ def check_bcss(): patch_shape=(512, 512), batch_size=2, download=False, - label_transform=BCSSLabelTrafo(label_choices=[0, 1, 2]) + label_transform=BCSSLabelTrafo(label_choices=[0, 1, 2]), + split="train" ) check_loader(chosen_label_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./bcss.png") diff --git a/torch_em/data/datasets/bcss.py b/torch_em/data/datasets/bcss.py index 725a5051..9d62b9bb 100644 --- a/torch_em/data/datasets/bcss.py +++ b/torch_em/data/datasets/bcss.py @@ -1,5 +1,7 @@ import os +import shutil from glob import glob +from pathlib import Path import torch @@ -15,6 +17,18 @@ CHECKSUM = None +TEST_LIST = [ + 'TCGA-A2-A0SX-DX1_xmin53791_ymin56683_MPP-0.2500', 'TCGA-BH-A0BG-DX1_xmin64019_ymin24975_MPP-0.2500', + 'TCGA-AR-A1AI-DX1_xmin38671_ymin10616_MPP-0.2500', 'TCGA-E2-A574-DX1_xmin54962_ymin47475_MPP-0.2500', + 'TCGA-GM-A3XL-DX1_xmin29910_ymin15820_MPP-0.2500', 'TCGA-E2-A14X-DX1_xmin88836_ymin66393_MPP-0.2500', + 'TCGA-A2-A04P-DX1_xmin104246_ymin48517_MPP-0.2500', 'TCGA-E2-A14N-DX1_xmin21383_ymin66838_MPP-0.2500', + 'TCGA-EW-A1OV-DX1_xmin126026_ymin65132_MPP-0.2500', 'TCGA-S3-AA15-DX1_xmin55486_ymin28926_MPP-0.2500', + 'TCGA-LL-A5YO-DX1_xmin36631_ymin44396_MPP-0.2500', 'TCGA-GI-A2C9-DX1_xmin20882_ymin11843_MPP-0.2500', + 'TCGA-BH-A0BW-DX1_xmin42346_ymin30843_MPP-0.2500', 'TCGA-E2-A1B6-DX1_xmin16266_ymin50634_MPP-0.2500', + 'TCGA-AO-A0J2-DX1_xmin33561_ymin14515_MPP-0.2500' +] + + def _download_bcss_dataset(path, download): """Current recommendation: - download the folder from URL manually @@ -28,7 +42,52 @@ def _download_bcss_dataset(path, download): util.download_source_gdrive(path=path, url=URL, download=download, checksum=CHECKSUM, download_type="folder") -def get_bcss_dataset(path, patch_shape, download=False, label_dtype=torch.int64, **kwargs): +def _get_image_and_label_paths(path): + # when downloading the files from `URL`, the input images are stored under `rgbs_colorNormalized` + # when getting the files from the git repo's command line feature, the input images are stored under `images` + if os.path.exists(os.path.join(path, "images")): + image_paths = sorted(glob(os.path.join(path, "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "masks", "*"))) + elif os.path.exists(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized")): + image_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized", "*"))) + label_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "masks", "*"))) + else: + raise ValueError("Please check the image directory. If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\"") + + return image_paths, label_paths + + +def _assort_bcss_data(path, download): + if download: + _download_bcss_dataset(path, download) + + if os.path.exists(os.path.join(path, "train")) and os.path.exists(os.path.join(path, "test")): + return + + all_image_paths, all_label_paths = _get_image_and_label_paths(path) + + train_img_dir, train_lab_dir = os.path.join(path, "train", "images"), os.path.join(path, "train", "masks") + test_img_dir, test_lab_dir = os.path.join(path, "test", "images"), os.path.join(path, "test", "masks") + os.makedirs(train_img_dir, exist_ok=True) + os.makedirs(train_lab_dir, exist_ok=True) + os.makedirs(test_img_dir, exist_ok=True) + os.makedirs(test_lab_dir, exist_ok=True) + + for image_path, label_path in zip(all_image_paths, all_label_paths): + img_idx, label_idx = os.path.split(image_path)[-1], os.path.split(label_path)[-1] + if Path(image_path).stem in TEST_LIST: + # move image and label to test + dst_img_path, dst_lab_path = os.path.join(test_img_dir, img_idx), os.path.join(test_lab_dir, label_idx) + shutil.copy(src=image_path, dst=dst_img_path) + shutil.copy(src=label_path, dst=dst_lab_path) + else: + # move image and label to train + dst_img_path, dst_lab_path = os.path.join(train_img_dir, img_idx), os.path.join(train_lab_dir, label_idx) + shutil.copy(src=image_path, dst=dst_img_path) + shutil.copy(src=label_path, dst=dst_lab_path) + + +def get_bcss_dataset(path, patch_shape, split, download=False, label_dtype=torch.int64, **kwargs): """Dataset for breast cancer tissue segmentation in histopathology. This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/. @@ -58,21 +117,13 @@ def get_bcss_dataset(path, patch_shape, download=False, label_dtype=torch.int64, - 20: dcis - 21: other """ - if download: - _download_bcss_dataset(path, download) + assert split in ["train", "test"], "Please choose from the available `train` / `test` splits" - # when downloading the files from `URL`, the input images are stored under `rgbs_colorNormalized` - # when getting the files from the git repo's command line feature, the input images are stored under `images` - if os.path.exists(os.path.join(path, "images")): - image_paths = sorted(glob(os.path.join(path, "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "masks", "*"))) - elif os.path.exists(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized")): - image_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized", "*"))) - label_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "masks", "*"))) - else: - raise ValueError( - "Please check the image directory. If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\"" - ) + _assort_bcss_data(path, download) + + # update the paths now based on the splits + image_paths = sorted(glob(os.path.join(path, split, "images", "*"))) + label_paths = sorted(glob(os.path.join(path, split, "masks", "*"))) assert len(image_paths) == len(label_paths) dataset = ImageCollectionDataset( @@ -82,12 +133,12 @@ def get_bcss_dataset(path, patch_shape, download=False, label_dtype=torch.int64, def get_bcss_loader( - path, patch_shape, batch_size, download=False, label_dtype=torch.int64, **kwargs + path, patch_shape, batch_size, split, download=False, label_dtype=torch.int64, **kwargs ): """Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details.""" ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_bcss_dataset( - path, patch_shape, download=download, label_dtype=label_dtype, **ds_kwargs + path, patch_shape, split, download=download, label_dtype=label_dtype, **ds_kwargs ) loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) return loader From 3793f76f99684a94e150f867845b69fcf9ef5bb3 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Thu, 26 Oct 2023 00:00:18 +0200 Subject: [PATCH 2/3] Add train val split from bcss train data --- torch_em/data/datasets/bcss.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/torch_em/data/datasets/bcss.py b/torch_em/data/datasets/bcss.py index 9d62b9bb..00ec320e 100644 --- a/torch_em/data/datasets/bcss.py +++ b/torch_em/data/datasets/bcss.py @@ -3,8 +3,9 @@ from glob import glob from pathlib import Path -import torch +from sklearn.model_selection import train_test_split +import torch import torch_em from torch_em.data.datasets import util from torch_em.data import ImageCollectionDataset @@ -87,7 +88,7 @@ def _assort_bcss_data(path, download): shutil.copy(src=label_path, dst=dst_lab_path) -def get_bcss_dataset(path, patch_shape, split, download=False, label_dtype=torch.int64, **kwargs): +def get_bcss_dataset(path, patch_shape, split, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs): """Dataset for breast cancer tissue segmentation in histopathology. This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/. @@ -117,13 +118,25 @@ def get_bcss_dataset(path, patch_shape, split, download=False, label_dtype=torch - 20: dcis - 21: other """ - assert split in ["train", "test"], "Please choose from the available `train` / `test` splits" + assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits" _assort_bcss_data(path, download) - # update the paths now based on the splits - image_paths = sorted(glob(os.path.join(path, split, "images", "*"))) - label_paths = sorted(glob(os.path.join(path, split, "masks", "*"))) + if split == "test": + image_paths = sorted(glob(os.path.join(path, "test", "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "test", "masks", "*"))) + else: + image_paths = sorted(glob(os.path.join(path, "train", "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "train", "masks", "*"))) + + (train_image_paths, val_image_paths, + train_label_paths, val_label_paths) = train_test_split( + image_paths, label_paths, test_size=val_fraction, random_state=42 + ) + + image_paths = train_image_paths if split == "train" else val_image_paths + label_paths = train_label_paths if split == "train" else val_label_paths + assert len(image_paths) == len(label_paths) dataset = ImageCollectionDataset( @@ -133,12 +146,12 @@ def get_bcss_dataset(path, patch_shape, split, download=False, label_dtype=torch def get_bcss_loader( - path, patch_shape, batch_size, split, download=False, label_dtype=torch.int64, **kwargs + path, patch_shape, batch_size, split, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs ): """Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details.""" ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_bcss_dataset( - path, patch_shape, split, download=download, label_dtype=label_dtype, **ds_kwargs + path, patch_shape, split, val_fraction, download=download, label_dtype=label_dtype, **ds_kwargs ) loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) return loader From 0e5468b9f2f47c6b560b1e7b13566a64155f3cb6 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Thu, 26 Oct 2023 16:17:59 +0200 Subject: [PATCH 3/3] Add optional split for bcss --- scripts/datasets/check_bcss.py | 16 +++++++--------- torch_em/data/datasets/bcss.py | 33 ++++++++++++++++++--------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/scripts/datasets/check_bcss.py b/scripts/datasets/check_bcss.py index 384c6854..fa495973 100644 --- a/scripts/datasets/check_bcss.py +++ b/scripts/datasets/check_bcss.py @@ -1,8 +1,7 @@ import numpy as np from typing import Optional, List -import vigra - +import torch_em from torch_em.util.debug import check_loader from torch_em.data.datasets import get_bcss_loader @@ -22,13 +21,12 @@ def __call__( self, labels: np.ndarray ) -> np.ndarray: - """Returns the transformed labels (use-case for SAM) - """ + """Returns the transformed labels (use-case for SAM)""" if self.label_choices is not None: labels[~np.isin(labels, self.label_choices)] = 0 - segmentation, _, _ = vigra.analysis.relabelConsecutive(labels.astype("uint64")) + segmentation = torch_em.transform.label.label_consecutive(labels) else: - segmentation, _, _ = vigra.analysis.relabelConsecutive(labels) + segmentation = torch_em.transform.label.label_consecutive(labels) return segmentation @@ -40,11 +38,11 @@ def check_bcss(): chosen_label_loader = get_bcss_loader( path=BCSS_ROOT, patch_shape=(512, 512), - batch_size=2, + batch_size=1, download=False, - label_transform=BCSSLabelTrafo(label_choices=[0, 1, 2]), - split="train" + label_transform=BCSSLabelTrafo(label_choices=[0, 1, 2]) ) + print("Length of loader:", len(chosen_label_loader)) check_loader(chosen_label_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./bcss.png") diff --git a/torch_em/data/datasets/bcss.py b/torch_em/data/datasets/bcss.py index 00ec320e..adc079a4 100644 --- a/torch_em/data/datasets/bcss.py +++ b/torch_em/data/datasets/bcss.py @@ -88,7 +88,7 @@ def _assort_bcss_data(path, download): shutil.copy(src=label_path, dst=dst_lab_path) -def get_bcss_dataset(path, patch_shape, split, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs): +def get_bcss_dataset(path, patch_shape, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs): """Dataset for breast cancer tissue segmentation in histopathology. This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/. @@ -118,24 +118,27 @@ def get_bcss_dataset(path, patch_shape, split, val_fraction=0.2, download=False, - 20: dcis - 21: other """ - assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits" - _assort_bcss_data(path, download) - if split == "test": - image_paths = sorted(glob(os.path.join(path, "test", "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "test", "masks", "*"))) + if split is None: + image_paths = sorted(glob(os.path.join(path, "*", "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "*", "masks", "*"))) else: - image_paths = sorted(glob(os.path.join(path, "train", "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "train", "masks", "*"))) + assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits" + if split == "test": + image_paths = sorted(glob(os.path.join(path, "test", "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "test", "masks", "*"))) + else: + image_paths = sorted(glob(os.path.join(path, "train", "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "train", "masks", "*"))) - (train_image_paths, val_image_paths, - train_label_paths, val_label_paths) = train_test_split( - image_paths, label_paths, test_size=val_fraction, random_state=42 - ) + (train_image_paths, val_image_paths, + train_label_paths, val_label_paths) = train_test_split( + image_paths, label_paths, test_size=val_fraction, random_state=42 + ) - image_paths = train_image_paths if split == "train" else val_image_paths - label_paths = train_label_paths if split == "train" else val_label_paths + image_paths = train_image_paths if split == "train" else val_image_paths + label_paths = train_label_paths if split == "train" else val_label_paths assert len(image_paths) == len(label_paths) @@ -146,7 +149,7 @@ def get_bcss_dataset(path, patch_shape, split, val_fraction=0.2, download=False, def get_bcss_loader( - path, patch_shape, batch_size, split, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs + path, patch_shape, batch_size, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs ): """Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details.""" ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)