Skip to content

Commit

Permalink
Implement sponge em dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jul 15, 2023
1 parent 5f5c424 commit 3e29988
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
11 changes: 11 additions & 0 deletions scripts/datasets/check_sponge_em.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from torch_em.data.datasets import get_sponge_em_loader
from torch_em.util.debug import check_loader


def check_sponge_em():
    """Build a SpongeEM instance-segmentation loader and visually inspect it.

    Uses torch_em's debug utility to show 8 samples with instance labels.
    NOTE(review): downloads the data into ./data/sponge_em on first run —
    assumes network access is available.
    """
    instance_loader = get_sponge_em_loader(
        "./data/sponge_em", "instances", (8, 512, 512), 1, download=True
    )
    check_loader(instance_loader, 8, instance_labels=True)


if __name__ == "__main__":
    check_sponge_em()
1 change: 1 addition & 0 deletions torch_em/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
get_platynereis_nuclei_loader, get_platynereis_nuclei_dataset
)
from .snemi import get_snemi_loader, get_snemi_dataset
from .sponge_em import get_sponge_em_loader, get_sponge_em_dataset
from .tissuenet import get_tissuenet_loader, get_tissuenet_dataset
from .uro_cell import get_uro_cell_loader, get_uro_cell_dataset
from .util import get_bioimageio_dataset_id
Expand Down
41 changes: 41 additions & 0 deletions torch_em/data/datasets/sponge_em.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from glob import glob

import torch_em
from . import util

URL = "https://zenodo.org/record/8150818/files/sponge_em_train_data.zip?download=1"
CHECKSUM = "f1df616cd60f81b91d7642933e9edd74dc6c486b2e546186a7c1e54c67dd32a5"


def _require_sponge_em_data(path, download):
    """Ensure the SpongeEM archive is downloaded and extracted under ``path``.

    Creates ``path`` if needed, fetches the zip from Zenodo (checksum-verified
    by ``util.download_source``) and unpacks it in place.
    """
    os.makedirs(path, exist_ok=True)
    archive = os.path.join(path, "data.zip")
    util.download_source(archive, URL, download, CHECKSUM)
    util.unzip(archive, path)


def get_sponge_em_dataset(path, mode, patch_shape, sample_ids=None, download=False, **kwargs):
    """Dataset for EM volumes of a sponge, with semantic or instance labels.

    Args:
        path: Root folder holding (or receiving) the ``train_data_*.h5`` files.
        mode: Label flavor to load; one of ``"semantic"`` or ``"instances"``.
        patch_shape: 3D patch shape passed to the segmentation dataset.
        sample_ids: Optional iterable of 1-based sample ids; defaults to all.
        download: Whether to download the data if it is not found at ``path``.
        kwargs: Forwarded to ``torch_em.default_segmentation_dataset``.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If ``mode`` is not a supported label flavor.
        RuntimeError: If the expected number of files is not present.
    """
    # Validate with a real exception instead of `assert`, which is stripped
    # under `python -O` and would let a typo'd mode fall through to a
    # confusing missing-key error deep inside h5 loading.
    modes = ("semantic", "instances")
    if mode not in modes:
        raise ValueError(f"Invalid mode {mode}, expected one of {modes}.")

    n_files = len(glob(os.path.join(path, "*.h5")))
    if n_files == 0:
        _require_sponge_em_data(path, download)
        n_files = len(glob(os.path.join(path, "*.h5")))
    if n_files != 3:
        raise RuntimeError(f"Expected 3 hdf5 files at {path}, found {n_files}.")

    if sample_ids is None:
        sample_ids = range(1, n_files + 1)
    # Zero-pad via format spec rather than a hard-coded leading "0", so the
    # pattern stays correct if the dataset ever grows past 9 samples.
    paths = [os.path.join(path, f"train_data_{i:02d}.h5") for i in sample_ids]

    raw_key = "volumes/raw"
    label_key = f"volumes/labels/{mode}"
    return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs)


def get_sponge_em_loader(path, mode, patch_shape, batch_size, sample_ids=None, download=False, **kwargs):
    """Dataloader wrapper around ``get_sponge_em_dataset``.

    Splits ``kwargs`` into dataset vs. loader arguments via ``util.split_kwargs``
    and returns a torch_em data loader with the given ``batch_size``.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_sponge_em_dataset(
        path, mode, patch_shape, sample_ids=sample_ids, download=download, **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
2 changes: 1 addition & 1 deletion torch_em/data/datasets/uro_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,6 @@ def get_uro_cell_loader(
torch_em.default_segmentation_dataset, **kwargs
)
ds = get_uro_cell_dataset(
path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **kwargs
path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
)
return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

0 comments on commit 3e29988

Please sign in to comment.