From f5f587608d30244bfe72fcf12252a1b3f5b0b87e Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:26:13 +0200 Subject: [PATCH] Minor updates to CAMUS dataset (#292) Update camus dataset - resize_inputs functionality --- scripts/datasets/check_camus.py | 50 ------------------------- scripts/datasets/medical/check_camus.py | 21 +++++++++++ torch_em/data/datasets/medical/camus.py | 19 ++++------ 3 files changed, 29 insertions(+), 61 deletions(-) delete mode 100644 scripts/datasets/check_camus.py create mode 100644 scripts/datasets/medical/check_camus.py diff --git a/scripts/datasets/check_camus.py b/scripts/datasets/check_camus.py deleted file mode 100644 index ef268785..00000000 --- a/scripts/datasets/check_camus.py +++ /dev/null @@ -1,50 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_camus_loader - - -ROOT = "/media/anwai/ANWAI/data/camus" - - -def check_camus(): - loader = get_camus_loader( - path=ROOT, - patch_shape=(1, 256, 256), - batch_size=2, - chamber=2, - resize_inputs=True, - download=False, - ) - check_loader(loader, 8) - - -def test_camus_images(): - import os - from glob import glob - - import napari - import nibabel as nib - - all_patient_dir = glob(os.path.join(ROOT, "database_nifti", "patient*")) - - # v = napari.Viewer() - - for per_patient_dir in all_patient_dir: - all_volume_paths = sorted(glob(os.path.join(per_patient_dir, "*_4CH_*.nii.gz"))) - for vol_path in all_volume_paths: - vol = nib.load(vol_path) - vol = vol.get_fdata() - - if vol.ndim == 2: - print(vol.shape) - # v.add_image(vol) - - print() - - # napari.run() - - # breakpoint() - - -if __name__ == "__main__": - # test_camus_images() - check_camus() diff --git a/scripts/datasets/medical/check_camus.py b/scripts/datasets/medical/check_camus.py new file mode 100644 index 00000000..79fb7f87 --- /dev/null +++ b/scripts/datasets/medical/check_camus.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_camus_loader + + +ROOT = "/media/anwai/ANWAI/data/camus" + + +def check_camus(): + loader = get_camus_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + chamber=2, + resize_inputs=True, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_camus() diff --git a/torch_em/data/datasets/medical/camus.py b/torch_em/data/datasets/medical/camus.py index 882ddfad..ff360a9e 100644 --- a/torch_em/data/datasets/medical/camus.py +++ b/torch_em/data/datasets/medical/camus.py @@ -3,13 +3,14 @@ from typing import Union, Tuple, Optional import torch_em -from torch_em.transform.generic import ResizeInputs from .. import util URL = "https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/folder/63fde55f73e9f004868fb7ac/download" -CHECKSUM = "43745d640db5d979332bda7f00f4746747a2591b46efc8f1966b573ce8d65655" + +# TODO: the checksums are different with each download, not sure why +# CHECKSUM = "43745d640db5d979332bda7f00f4746747a2591b46efc8f1966b573ce8d65655" def get_camus_data(path, download): @@ -20,7 +21,7 @@ def get_camus_data(path, download): return data_dir zip_path = os.path.join(path, "CAMUS.zip") - util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.download_source(path=zip_path, url=URL, download=download, checksum=None) util.unzip(zip_path=zip_path, dst=path) return data_dir @@ -60,12 +61,10 @@ def get_camus_dataset( image_paths, gt_paths = _get_camus_paths(path=path, chamber=chamber, download=download) if resize_inputs: - raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False) - label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) - patch_shape = None - else: - patch_shape = patch_shape - raw_trafo, label_trafo = None, None + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) dataset = torch_em.default_segmentation_dataset( raw_paths=image_paths, @@ -73,8 +72,6 @@ def get_camus_dataset( label_paths=gt_paths, label_key="data", patch_shape=patch_shape, - raw_transform=raw_trafo, - label_transform=label_trafo, **kwargs )