From 628dafd2dc8b700971cbb8928c02faa757e00570 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 24 Jun 2024 09:20:14 +0200 Subject: [PATCH] Update resize inputs pipeline --- .../datasets/{ => medical}/check_siim_acr.py | 6 ++-- torch_em/data/datasets/medical/siim_acr.py | 31 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) rename scripts/datasets/{ => medical}/check_siim_acr.py (74%) diff --git a/scripts/datasets/check_siim_acr.py b/scripts/datasets/medical/check_siim_acr.py similarity index 74% rename from scripts/datasets/check_siim_acr.py rename to scripts/datasets/medical/check_siim_acr.py index ebceb51a..1f3df848 100644 --- a/scripts/datasets/check_siim_acr.py +++ b/scripts/datasets/medical/check_siim_acr.py @@ -3,7 +3,7 @@ from torch_em.data.datasets.medical import get_siim_acr_loader -ROOT = "/media/anwai/ANWAI/data/siim_acr" +ROOT = "/scratch/share/cidas/cca/data/siim_acr" def check_siim_acr(): @@ -13,10 +13,10 @@ def check_siim_acr(): patch_shape=(512, 512), batch_size=2, download=True, - resize_inputs=False, + resize_inputs=True, sampler=MinInstanceSampler() ) - check_loader(loader, 8) + check_loader(loader, 8, plt=True, save_path="./siim_acr.png") if __name__ == "__main__": diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py index f0033dcb..7326c7e1 100644 --- a/torch_em/data/datasets/medical/siim_acr.py +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -1,12 +1,10 @@ import os from glob import glob -from typing import Union, Tuple +from typing import Union, Tuple, Literal import torch_em -from torch_em.transform.generic import ResizeInputs from .. import util -from ... import ImageCollectionDataset KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks" @@ -42,7 +40,7 @@ def _get_siim_acr_paths(path, split, download): def get_siim_acr_dataset( path: Union[os.PathLike, str], - split: str, + split: Literal["train", "test"], patch_shape: Tuple[int, int], download: bool = False, resize_inputs: bool = False, @@ -60,19 +58,18 @@ def get_siim_acr_dataset( image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download) if resize_inputs: - raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False) - label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) - patch_shape = None - else: - patch_shape = patch_shape - raw_trafo, label_trafo = None, None - - dataset = ImageCollectionDataset( - raw_image_paths=image_paths, - label_image_paths=gt_paths, + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, patch_shape=patch_shape, - raw_transform=raw_trafo, - label_transform=label_trafo, + is_seg_dataset=False, **kwargs ) dataset.max_sampling_attempts = 5000 @@ -82,7 +79,7 @@ def get_siim_acr_dataset( def get_siim_acr_loader( path: Union[os.PathLike, str], - split: str, + split: Literal["train", "test"], patch_shape: Tuple[int, int], batch_size: int, download: bool = False,