diff --git a/scripts/datasets/medical/check_micro_usp.py b/scripts/datasets/medical/check_micro_usp.py new file mode 100644 index 00000000..9ea3ed56 --- /dev/null +++ b/scripts/datasets/medical/check_micro_usp.py @@ -0,0 +1,24 @@ +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_micro_usp_loader + + +ROOT = "/media/anwai/ANWAI/data/micro-usp" + + +def check_micro_usp(): + loader = get_micro_usp_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + split="train", + resize_inputs=True, + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_micro_usp() diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py index 4402c9a0..40b52fb5 100644 --- a/torch_em/data/datasets/medical/__init__.py +++ b/torch_em/data/datasets/medical/__init__.py @@ -6,9 +6,10 @@ from .feta24 import get_feta24_dataset, get_feta24_loader from .idrid import get_idrid_dataset, get_idrid_loader from .jnuifm import get_jnuifm_dataset, get_jnuifm_loader +from .micro_usp import get_micro_usp_dataset, get_micro_usp_loader from .montgomery import get_montgomery_dataset, get_montgomery_loader -from .oimhs import get_oimhs_dataset, get_oimhs_loader from .msd import get_msd_dataset, get_msd_loader +from .oimhs import get_oimhs_dataset, get_oimhs_loader from .osic_pulmofib import get_osic_pulmofib_dataset, get_osic_pulmofib_loader from .papila import get_papila_dataset, get_papila_loader from .plethora import get_plethora_dataset, get_plethora_loader diff --git a/torch_em/data/datasets/medical/micro_usp.py b/torch_em/data/datasets/medical/micro_usp.py new file mode 100644 index 00000000..eb22ca7f --- /dev/null +++ b/torch_em/data/datasets/medical/micro_usp.py @@ -0,0 +1,89 @@ +import os +from glob import glob +from pathlib import Path +from natsort import natsorted +from typing import Union, Tuple + +import torch_em + +from .. import util + + +URL = "https://zenodo.org/records/10475293/files/Micro_Ultrasound_Prostate_Segmentation_Dataset.zip?" +CHECKSUM = "031645dc30948314e379d0a0a7d54bad1cd4e1f3f918b77455d69810aa05dce3" + + +def get_micro_usp_data(path, download): + os.makedirs(path, exist_ok=True) + + fname = Path(URL).stem + data_dir = os.path.join(path, fname) + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, f"{fname}.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_micro_usp_paths(path, split, download): + data_dir = get_micro_usp_data(path=path, download=download) + + image_paths = natsorted(glob(os.path.join(data_dir, split, "micro_ultrasound_scans", "*.nii.gz"))) + gt_paths = natsorted(glob(os.path.join(data_dir, split, "expert_annotations", "*.nii.gz"))) + + return image_paths, gt_paths + + +def get_micro_usp_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of prostate in micro-ultrasound scans. + + This dataset is from Jiang et al. - https://doi.org/10.1016/j.compmedimag.2024.102326. + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_micro_usp_paths(path=path, split=split, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_micro_usp_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of prostate in micro-ultrasound scans. See `get_micro_usp_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_micro_usp_dataset( + path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader