From d6f35c9ff8c64ece98143a15749d3854a759fa3b Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 15 May 2024 00:53:49 +0200 Subject: [PATCH 1/6] Add cxr pneumothorax dataset --- scripts/datasets/check_siim_acr.py | 22 +++++++ torch_em/data/datasets/medical/__init__.py | 1 + torch_em/data/datasets/medical/siim_acr.py | 71 ++++++++++++++++++++++ torch_em/data/datasets/util.py | 17 ++++++ 4 files changed, 111 insertions(+) create mode 100644 scripts/datasets/check_siim_acr.py create mode 100644 torch_em/data/datasets/medical/siim_acr.py diff --git a/scripts/datasets/check_siim_acr.py b/scripts/datasets/check_siim_acr.py new file mode 100644 index 00000000..98e13d6a --- /dev/null +++ b/scripts/datasets/check_siim_acr.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data import MinForegroundSampler +from torch_em.data.datasets.medical import get_siim_acr_loader + + +ROOT = "/media/anwai/ANWAI/data/siim_acr" + + +def check_siim_acr(): + loader = get_siim_acr_loader( + path=ROOT, + split="train", + patch_shape=(1024, 1024), + batch_size=2, + download=True, + sampler=MinForegroundSampler(min_fraction=0.001) + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_siim_acr() diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py index 5bc385b8..4695644c 100644 --- a/torch_em/data/datasets/medical/__init__.py +++ b/torch_em/data/datasets/medical/__init__.py @@ -1,2 +1,3 @@ from .autopet import get_autopet_loader from .btcv import get_btcv_dataset, get_btcv_loader +from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py new file mode 100644 index 00000000..cade75fc --- /dev/null +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -0,0 +1,71 @@ +import os +from glob import glob +from typing import Union, Tuple + +import torch_em + +from .. import util +from ... import ImageCollectionDataset + + +KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks" +CHECKSUM = "1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538" + + +def get_siim_acr_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "siim-acr-pneumothorax") + if os.path.exists(data_dir): + return data_dir + + util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download) + + zip_path = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip") + util._check_checksum(path=zip_path, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_siim_acr_paths(path, split, download): + data_dir = get_siim_acr_data(path=path, download=download) + + assert split in ["train", "test"], f"'{split}' is not a valid split." + + image_paths = sorted(glob(os.path.join(data_dir, "png_images", f"*_{split}_*.png"))) + gt_paths = sorted(glob(os.path.join(data_dir, "png_masks", f"*_{split}_*.png"))) + + return image_paths, gt_paths + + +def get_siim_acr_dataset( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + download: bool = False, + **kwargs +): + image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download) + + dataset = ImageCollectionDataset( + raw_image_paths=image_paths, label_image_paths=gt_paths, patch_shape=patch_shape, **kwargs + ) + + return dataset + + +def get_siim_acr_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + download: bool = False, + **kwargs +): + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_siim_acr_dataset( + path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index e6a38236..3fd416e4 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -146,6 +146,23 @@ def download_source_empiar(path, access_id, download): return download_path +def download_source_kaggle(path, dataset_name, download): + if not download: + raise RuntimeError(f"Cannot fine the data at {path}, but download was set to False.") + + try: + from kaggle.api.kaggle_api_extended import KaggleApi + except ModuleNotFoundError: + msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. " + msg += "After you have installed kaggle, you would need an API token. " + msg += "Follow the instructions at https://www.kaggle.com/docs/api." + raise ModuleNotFoundError(msg) + + api = KaggleApi() + api.authenticate() + api.dataset_download_files(dataset=dataset_name, path=path, quiet=False) + + def update_kwargs(kwargs, key, value, msg=None): if key in kwargs: msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg From db741ea195dd50d16615812ddc9d5d7b17eb5b05 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 15 May 2024 08:28:06 +0200 Subject: [PATCH 2/6] Add docstrings --- torch_em/data/datasets/medical/siim_acr.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py index cade75fc..a7f1859c 100644 --- a/torch_em/data/datasets/medical/siim_acr.py +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -46,6 +46,15 @@ def get_siim_acr_dataset( download: bool = False, **kwargs ): + """Dataset for pneumothorax segmentation in CXR. + + The database is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data + + This dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition: + https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation + + Please cite it if you use this dataset for a publication. + """ image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download) dataset = ImageCollectionDataset( @@ -63,6 +72,8 @@ def get_siim_acr_loader( download: bool = False, **kwargs ): + """Dataloader for pneumothorax segmentation in CXR. See `get_siim_acr_dataset` for details. + """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_siim_acr_dataset( path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs From 023b9fe77641272345da51b1f2c38ce1c755c386 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 15 May 2024 08:47:44 +0200 Subject: [PATCH 3/6] Allow patch_shape as None to get the true image shape --- scripts/datasets/check_siim_acr.py | 5 +++-- torch_em/data/datasets/medical/siim_acr.py | 20 ++++++++++++++++-- torch_em/data/image_collection_dataset.py | 15 ++++++++++---- torch_em/transform/generic.py | 24 +++++++++++++++++++++- 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/scripts/datasets/check_siim_acr.py b/scripts/datasets/check_siim_acr.py index 98e13d6a..f3cd850d 100644 --- a/scripts/datasets/check_siim_acr.py +++ b/scripts/datasets/check_siim_acr.py @@ -10,10 +10,11 @@ def check_siim_acr(): loader = get_siim_acr_loader( path=ROOT, split="train", - patch_shape=(1024, 1024), + patch_shape=(512, 512), batch_size=2, download=True, - sampler=MinForegroundSampler(min_fraction=0.001) + resize_inputs=False, + sampler=MinForegroundSampler(min_fraction=0.001), ) check_loader(loader, 8) diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py index a7f1859c..5096f42e 100644 --- a/torch_em/data/datasets/medical/siim_acr.py +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -3,6 +3,7 @@ from typing import Union, Tuple import torch_em +from torch_em.transform.generic import ResizeInputs from .. import util from ... import ImageCollectionDataset @@ -44,6 +45,7 @@ def get_siim_acr_dataset( split: str, patch_shape: Tuple[int, int], download: bool = False, + resize_inputs: bool = False, **kwargs ): """Dataset for pneumothorax segmentation in CXR. @@ -57,8 +59,21 @@ def get_siim_acr_dataset( """ image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download) + if resize_inputs: + raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False) + label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) + patch_shape = None + else: + patch_shape = patch_shape + raw_trafo, label_trafo = None, None + dataset = ImageCollectionDataset( - raw_image_paths=image_paths, label_image_paths=gt_paths, patch_shape=patch_shape, **kwargs + raw_image_paths=image_paths, + label_image_paths=gt_paths, + patch_shape=patch_shape, + raw_transform=raw_trafo, + label_transform=label_trafo, + **kwargs ) return dataset @@ -70,13 +85,14 @@ def get_siim_acr_loader( patch_shape: Tuple[int, int], batch_size: int, download: bool = False, + resize_inputs: bool = False, **kwargs ): """Dataloader for pneumothorax segmentation in CXR. See `get_siim_acr_dataset` for details. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_siim_acr_dataset( - path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs + path=path, split=split, patch_shape=patch_shape, download=download, resize_inputs=resize_inputs, **ds_kwargs ) loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) return loader diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index e005d15a..4b6e3e4a 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -68,7 +68,8 @@ def __init__( self.label_images = label_image_paths self._ndim = 2 - assert len(patch_shape) == self._ndim + if patch_shape is not None: + assert len(patch_shape) == self._ndim self.patch_shape = patch_shape self.raw_transform = raw_transform @@ -95,11 +96,16 @@ def ndim(self): return self._ndim def _sample_bounding_box(self, shape): + if self.patch_shape is None: + patch_shape_for_bb = shape + else: + patch_shape_for_bb = self.patch_shape + bb_start = [ np.random.randint(0, sh - psh) if sh - psh > 0 else 0 - for sh, psh in zip(shape, self.patch_shape) + for sh, psh in zip(shape, patch_shape_for_bb) ] - return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape)) + return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb)) def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first): shape = raw.shape @@ -137,7 +143,8 @@ def _load_data(self, raw_path, label_path): if have_raw_channels: channel_first = raw.shape[-1] > 16 - raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first) + if self.patch_shape is not None: + raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first) shape = raw.shape prefix_box = tuple() diff --git a/torch_em/transform/generic.py b/torch_em/transform/generic.py index b74c5be2..26b5faa6 100644 --- a/torch_em/transform/generic.py +++ b/torch_em/transform/generic.py @@ -3,7 +3,7 @@ import numpy as np import torch -from skimage.transform import rescale +from skimage.transform import rescale, resize class Tile(torch.nn.Module): @@ -72,6 +72,28 @@ def __call__(self, *inputs): return outputs +class ResizeInputs: + def __init__(self, target_shape, is_label=False): + self.target_shape = target_shape + self.is_label = is_label + + def __call__(self, inputs): + if self.is_label: + anti_aliasing = True + else: + anti_aliasing = False + + inputs = resize( + image=inputs, + output_shape=self.target_shape, + order=3, + anti_aliasing=anti_aliasing, + preserve_range=True, + ) + + return inputs + + class PadIfNecessary: def __init__(self, shape): self.shape = tuple(shape) From f963e51c7c3025151372e9263344bd788d3a750d Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 15 May 2024 08:58:40 +0200 Subject: [PATCH 4/6] Increase sampling attempts for dataset --- scripts/datasets/check_siim_acr.py | 6 +++--- torch_em/data/datasets/medical/siim_acr.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/datasets/check_siim_acr.py b/scripts/datasets/check_siim_acr.py index f3cd850d..288991e1 100644 --- a/scripts/datasets/check_siim_acr.py +++ b/scripts/datasets/check_siim_acr.py @@ -1,5 +1,5 @@ from torch_em.util.debug import check_loader -from torch_em.data import MinForegroundSampler +from torch_em.data import MinInstanceSampler from torch_em.data.datasets.medical import get_siim_acr_loader @@ -13,8 +13,8 @@ def check_siim_acr(): patch_shape=(512, 512), batch_size=2, download=True, - resize_inputs=False, - sampler=MinForegroundSampler(min_fraction=0.001), + resize_inputs=True, + sampler=MinInstanceSampler() ) check_loader(loader, 8) diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py index 5096f42e..f0033dcb 100644 --- a/torch_em/data/datasets/medical/siim_acr.py +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -75,6 +75,7 @@ def get_siim_acr_dataset( label_transform=label_trafo, **kwargs ) + dataset.max_sampling_attempts = 5000 return dataset From 79125425e5cdcd1e27274b36bec863e29072d2eb Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 15 May 2024 09:42:11 +0200 Subject: [PATCH 5/6] Refactor bb_start logic --- torch_em/data/image_collection_dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index 4b6e3e4a..2c47083a 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -98,13 +98,14 @@ def ndim(self): def _sample_bounding_box(self, shape): if self.patch_shape is None: patch_shape_for_bb = shape + bb_start = [0] * len(shape) else: patch_shape_for_bb = self.patch_shape + bb_start = [ + np.random.randint(0, sh - psh) if sh - psh > 0 else 0 + for sh, psh in zip(shape, patch_shape_for_bb) + ] - bb_start = [ - np.random.randint(0, sh - psh) if sh - psh > 0 else 0 - for sh, psh in zip(shape, patch_shape_for_bb) - ] return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb)) def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first): From 11d79fb3b5b7ef28a5bf5dbccad0d8f55238f4d5 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 15 May 2024 09:46:12 +0200 Subject: [PATCH 6/6] Sort is_label criterion for resize params --- scripts/datasets/check_siim_acr.py | 2 +- torch_em/transform/generic.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/scripts/datasets/check_siim_acr.py b/scripts/datasets/check_siim_acr.py index 288991e1..ebceb51a 100644 --- a/scripts/datasets/check_siim_acr.py +++ b/scripts/datasets/check_siim_acr.py @@ -13,7 +13,7 @@ def check_siim_acr(): patch_shape=(512, 512), batch_size=2, download=True, - resize_inputs=True, + resize_inputs=False, sampler=MinInstanceSampler() ) check_loader(loader, 8) diff --git a/torch_em/transform/generic.py b/torch_em/transform/generic.py index 26b5faa6..a9068e0b 100644 --- a/torch_em/transform/generic.py +++ b/torch_em/transform/generic.py @@ -1,10 +1,10 @@ from typing import Any, Dict, Optional, Sequence, Union import numpy as np -import torch - from skimage.transform import rescale, resize +import torch + class Tile(torch.nn.Module): _params = None @@ -78,18 +78,17 @@ def __init__(self, target_shape, is_label=False): self.is_label = is_label def __call__(self, inputs): - if self.is_label: - anti_aliasing = True - else: - anti_aliasing = False + if self.is_label: # kwargs needed for int data + kwargs = {"order": 0, "anti_aliasing": False} + else: # we use the default settings for float data + kwargs = {} inputs = resize( image=inputs, output_shape=self.target_shape, - order=3, - anti_aliasing=anti_aliasing, preserve_range=True, - ) + **kwargs + ).astype(inputs.dtype) return inputs