diff --git a/scripts/datasets/check_sponge_em.py b/scripts/datasets/check_sponge_em.py
new file mode 100644
index 00000000..fb10f631
--- /dev/null
+++ b/scripts/datasets/check_sponge_em.py
@@ -0,0 +1,11 @@
+from torch_em.data.datasets import get_sponge_em_loader
+from torch_em.util.debug import check_loader
+
+
+def check_sponge_em():
+    loader = get_sponge_em_loader("./data/sponge_em", "instances", (8, 512, 512), 1, download=True)
+    check_loader(loader, 8, instance_labels=True)
+
+
+if __name__ == "__main__":
+    check_sponge_em()
diff --git a/torch_em/data/datasets/__init__.py b/torch_em/data/datasets/__init__.py
index fef8a0af..35ac4022 100644
--- a/torch_em/data/datasets/__init__.py
+++ b/torch_em/data/datasets/__init__.py
@@ -27,6 +27,7 @@
     get_platynereis_nuclei_loader, get_platynereis_nuclei_dataset
 )
 from .snemi import get_snemi_loader, get_snemi_dataset
+from .sponge_em import get_sponge_em_loader, get_sponge_em_dataset
 from .tissuenet import get_tissuenet_loader, get_tissuenet_dataset
 from .uro_cell import get_uro_cell_loader, get_uro_cell_dataset
 from .util import get_bioimageio_dataset_id
diff --git a/torch_em/data/datasets/sponge_em.py b/torch_em/data/datasets/sponge_em.py
new file mode 100644
index 00000000..15de1b57
--- /dev/null
+++ b/torch_em/data/datasets/sponge_em.py
@@ -0,0 +1,41 @@
+import os
+from glob import glob
+
+import torch_em
+from . import util
+
+URL = "https://zenodo.org/record/8150818/files/sponge_em_train_data.zip?download=1"
+CHECKSUM = "f1df616cd60f81b91d7642933e9edd74dc6c486b2e546186a7c1e54c67dd32a5"
+
+
+def _require_sponge_em_data(path, download):
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, "data.zip")
+    util.download_source(zip_path, URL, download, CHECKSUM)
+    util.unzip(zip_path, path)
+
+
+def get_sponge_em_dataset(path, mode, patch_shape, sample_ids=None, download=False, **kwargs):
+    assert mode in ("semantic", "instances")
+
+    n_files = len(glob(os.path.join(path, "*.h5")))
+    if n_files == 0:
+        _require_sponge_em_data(path, download)
+        n_files = len(glob(os.path.join(path, "*.h5")))
+    assert n_files == 3
+
+    if sample_ids is None:
+        sample_ids = range(1, n_files + 1)
+    paths = [os.path.join(path, f"train_data_0{i}.h5") for i in sample_ids]
+
+    raw_key = "volumes/raw"
+    label_key = f"volumes/labels/{mode}"
+    return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs)
+
+
+def get_sponge_em_loader(path, mode, patch_shape, batch_size, sample_ids=None, download=False, **kwargs):
+    ds_kwargs, loader_kwargs = util.split_kwargs(
+        torch_em.default_segmentation_dataset, **kwargs
+    )
+    ds = get_sponge_em_dataset(path, mode, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs)
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/uro_cell.py b/torch_em/data/datasets/uro_cell.py
index 95074458..332c2e07 100644
--- a/torch_em/data/datasets/uro_cell.py
+++ b/torch_em/data/datasets/uro_cell.py
@@ -136,6 +136,6 @@ def get_uro_cell_loader(
         torch_em.default_segmentation_dataset, **kwargs
     )
     ds = get_uro_cell_dataset(
-        path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **kwargs
+        path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
     )
     return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)