From 223c79b0ab2634ed166d29f4fd9662a3f47f6c58 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Sun, 26 May 2024 13:30:41 +0200 Subject: [PATCH 1/3] Add MicroUS prostate segmentation dataset --- scripts/datasets/medical/check_micro_usp.py | 22 +++++ torch_em/data/datasets/medical/__init__.py | 1 + torch_em/data/datasets/medical/micro_usp.py | 99 +++++++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 scripts/datasets/medical/check_micro_usp.py create mode 100644 torch_em/data/datasets/medical/micro_usp.py diff --git a/scripts/datasets/medical/check_micro_usp.py b/scripts/datasets/medical/check_micro_usp.py new file mode 100644 index 00000000..b570edd4 --- /dev/null +++ b/scripts/datasets/medical/check_micro_usp.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_micro_usp_loader + + +ROOT = "/media/anwai/ANWAI/data/micro-usp" + + +def check_micro_usp(): + loader = get_micro_usp_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + split="train", + resize_inputs=False, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_micro_usp() diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py index 88c0ffa9..997037d7 100644 --- a/torch_em/data/datasets/medical/__init__.py +++ b/torch_em/data/datasets/medical/__init__.py @@ -3,6 +3,7 @@ from .busi import get_busi_dataset, get_busi_loader from .camus import get_camus_dataset, get_camus_loader from .drive import get_drive_dataset, get_drive_loader +from .micro_usp import get_micro_usp_dataset, get_micro_usp_loader from .papila import get_papila_dataset, get_papila_loader from .plethora import get_plethora_dataset, get_plethora_loader from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader diff --git a/torch_em/data/datasets/medical/micro_usp.py b/torch_em/data/datasets/medical/micro_usp.py new file mode 100644 index 00000000..e094e3b7 --- /dev/null +++ b/torch_em/data/datasets/medical/micro_usp.py @@ -0,0 +1,99 @@ +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple + +import torch_em + +from .. import util + + +URL = "https://zenodo.org/records/10475293/files/Micro_Ultrasound_Prostate_Segmentation_Dataset.zip?" +CHECKSUM = "031645dc30948314e379d0a0a7d54bad1cd4e1f3f918b77455d69810aa05dce3" +FNAME = "Micro_Ultrasound_Prostate_Segmentation_Dataset" + + +def get_micro_usp_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, FNAME) + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, f"{FNAME}.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_micro_usp_paths(path, split, download): + data_dir = get_micro_usp_data(path=path, download=download) + + image_paths = natsorted(glob(os.path.join(data_dir, split, "micro_ultrasound_scans", "*.nii.gz"))) + gt_paths = natsorted(glob(os.path.join(data_dir, split, "expert_annotations", "*.nii.gz"))) + + for image_path, gt_path in zip(image_paths, gt_paths): + import nibabel as nib + + image = nib.load(image_path) + image = image.get_fdata() + + gt = nib.load(gt_path) + gt = gt.get_fdata() + + import napari + v = napari.Viewer() + v.add_image(image.transpose(2, 0, 1)) + v.add_labels(gt.transpose(2, 0, 1).astype("uint8")) + napari.run() + + breakpoint() + + return image_paths, gt_paths + + +def get_micro_usp_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of prostate in micro-ultrasound scans. + + This dataset is from Jiang et al. - https://doi.org/10.1016/j.compmedimag.2024.102326. + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_micro_usp_paths(path=path, split=split, download=download) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_micro_usp_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of prostate in micro-ultrasound scans. See `get_micro_usp_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_micro_usp_dataset( + path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader From c01b8c0b83a7d1aa75b98cb9daf4f4b26fdfee63 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 5 Jun 2024 15:47:17 +0200 Subject: [PATCH 2/3] Add resize_inputs functionality --- torch_em/data/datasets/medical/micro_usp.py | 23 +++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/torch_em/data/datasets/medical/micro_usp.py b/torch_em/data/datasets/medical/micro_usp.py index e094e3b7..72a8fd4c 100644 --- a/torch_em/data/datasets/medical/micro_usp.py +++ b/torch_em/data/datasets/medical/micro_usp.py @@ -1,5 +1,6 @@ import os from glob import glob +from pathlib import Path from natsort import natsorted from typing import Union, Tuple @@ -16,11 +17,12 @@ def get_micro_usp_data(path, download): os.makedirs(path, exist_ok=True) - data_dir = os.path.join(path, FNAME) + fname = Path(URL).stem + data_dir = os.path.join(path, fname) if os.path.exists(data_dir): return data_dir - zip_path = os.path.join(path, f"{FNAME}.zip") + zip_path = os.path.join(path, f"{fname}.zip") util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) util.unzip(zip_path=zip_path, dst=path) @@ -33,14 +35,11 @@ def _get_micro_usp_paths(path, split, download): image_paths = natsorted(glob(os.path.join(data_dir, split, "micro_ultrasound_scans", "*.nii.gz"))) gt_paths = natsorted(glob(os.path.join(data_dir, split, "expert_annotations", "*.nii.gz"))) - for image_path, gt_path in zip(image_paths, gt_paths): - import nibabel as nib - - image = nib.load(image_path) - image = image.get_fdata() + from tukra.utils import read_image - gt = nib.load(gt_path) - gt = gt.get_fdata() + for image_path, gt_path in zip(image_paths, gt_paths): + image = read_image(image_path) + gt = read_image(gt_path) import napari v = napari.Viewer() @@ -68,6 +67,12 @@ def get_micro_usp_dataset( """ image_paths, gt_paths = _get_micro_usp_paths(path=path, split=split, download=download) + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + dataset = torch_em.default_segmentation_dataset( raw_paths=image_paths, raw_key="data", From 67f1d4cddbd09aaf9f6239b81a55dde7265e6647 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 5 Jun 2024 15:50:51 +0200 Subject: [PATCH 3/3] Remove visualization scripts --- scripts/datasets/medical/check_micro_usp.py | 4 +++- torch_em/data/datasets/medical/micro_usp.py | 15 --------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/scripts/datasets/medical/check_micro_usp.py b/scripts/datasets/medical/check_micro_usp.py index b570edd4..9ea3ed56 100644 --- a/scripts/datasets/medical/check_micro_usp.py +++ b/scripts/datasets/medical/check_micro_usp.py @@ -1,3 +1,4 @@ +from torch_em.data import MinInstanceSampler from torch_em.util.debug import check_loader from torch_em.data.datasets.medical import get_micro_usp_loader @@ -11,8 +12,9 @@ def check_micro_usp(): patch_shape=(1, 512, 512), batch_size=2, split="train", - resize_inputs=False, + resize_inputs=True, download=True, + sampler=MinInstanceSampler(), ) check_loader(loader, 8) diff --git a/torch_em/data/datasets/medical/micro_usp.py b/torch_em/data/datasets/medical/micro_usp.py index 72a8fd4c..eb22ca7f 100644 --- a/torch_em/data/datasets/medical/micro_usp.py +++ b/torch_em/data/datasets/medical/micro_usp.py @@ -11,7 +11,6 @@ URL = "https://zenodo.org/records/10475293/files/Micro_Ultrasound_Prostate_Segmentation_Dataset.zip?" CHECKSUM = "031645dc30948314e379d0a0a7d54bad1cd4e1f3f918b77455d69810aa05dce3" -FNAME = "Micro_Ultrasound_Prostate_Segmentation_Dataset" def get_micro_usp_data(path, download): @@ -35,20 +34,6 @@ def _get_micro_usp_paths(path, split, download): image_paths = natsorted(glob(os.path.join(data_dir, split, "micro_ultrasound_scans", "*.nii.gz"))) gt_paths = natsorted(glob(os.path.join(data_dir, split, "expert_annotations", "*.nii.gz"))) - from tukra.utils import read_image - - for image_path, gt_path in zip(image_paths, gt_paths): - image = read_image(image_path) - gt = read_image(gt_path) - - import napari - v = napari.Viewer() - v.add_image(image.transpose(2, 0, 1)) - v.add_labels(gt.transpose(2, 0, 1).astype("uint8")) - napari.run() - - breakpoint() - return image_paths, gt_paths