Skip to content

Commit

Permalink
Merge pull request #137 from constantinpape/expose-get-dataset
Browse files Browse the repository at this point in the history
Implement get_dataset for more datasets
  • Loading branch information
constantinpape committed Jul 14, 2023
2 parents 70a3629 + c8e06c6 commit 620ab64
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 77 deletions.
13 changes: 13 additions & 0 deletions scripts/datasets/check_livecell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch_em.data.datasets import get_livecell_loader
from torch_em.util.debug import check_loader

LIVECELL_ROOT = "/home/pape/Work/data/incu_cyte/livecell"


def check_livecell():
    """Load a few LiveCELL training batches and display them for visual inspection."""
    patch_shape, batch_size = (512, 512), 1
    loader = get_livecell_loader(LIVECELL_ROOT, "train", patch_shape, batch_size)
    # show 15 samples, overlaying the instance segmentation labels
    check_loader(loader, 15, instance_labels=True)


if __name__ == "__main__":
    check_livecell()
13 changes: 13 additions & 0 deletions scripts/datasets/check_neurips_cellseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch_em.data.datasets import get_neurips_cellseg_supervised_loader
from torch_em.util.debug import check_loader

NEURIPS_ROOT = "/home/pape/Work/data/neurips-cell-seg"


def check_neurips():
    """Sample batches from the supervised NeurIPS CellSeg loader and display them."""
    patch_shape, batch_size = (512, 512), 1
    loader = get_neurips_cellseg_supervised_loader(NEURIPS_ROOT, "train", patch_shape, batch_size)
    # show 15 samples; the raw data is RGB and the labels are instance segmentations
    check_loader(loader, 15, instance_labels=True, rgb=True)


if __name__ == "__main__":
    check_neurips()
30 changes: 30 additions & 0 deletions scripts/datasets/check_tissuenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
from torch_em.transform.raw import standardize, normalize_percentile

from torch_em.data.datasets import get_tissuenet_loader
from torch_em.util.debug import check_loader

TISSUENET_ROOT = "/home/pape/Work/data/tissuenet"


def raw_trafo(raw):
    """Percentile-normalize each channel, average over channels, then standardize."""
    normalized = normalize_percentile(raw, axis=(1, 2))
    channel_mean = np.mean(normalized, axis=0)
    return standardize(channel_mean)


# NOTE: the tissuenet data cannot be downloaded automatically.
# you need to download it yourself from https://datasets.deepcell.org/data
def check_tissuenet():
    """Load a few TissueNet batches and display them for visual inspection.

    The TissueNet data has to be downloaded manually beforehand; see the note above.
    """
    # set this path to where you have downloaded the tissuenet data
    loader = get_tissuenet_loader(
        TISSUENET_ROOT, "train",
        raw_channel="rgb", label_channel="cell",
        patch_shape=(512, 512), batch_size=1,
        shuffle=True, raw_transform=raw_trafo,
    )
    # rgb=False because raw_trafo collapses the channels to a single one
    check_loader(loader, 15, instance_labels=True, rgb=False)


if __name__ == "__main__":
    check_tissuenet()
11 changes: 7 additions & 4 deletions torch_em/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@
from .cem import get_cem_mitolab_loader
from .covid_if import get_covid_if_loader
from .cremi import get_cremi_loader
from .deepbacs import get_deepbacs_loader
from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset
from .dsb import get_dsb_loader
from .hpa import get_hpa_segmentation_loader
from .isbi2012 import get_isbi_loader
from .kasthuri import get_kasthuri_loader
from .livecell import get_livecell_loader
from .livecell import get_livecell_loader, get_livecell_dataset
from .lucchi import get_lucchi_loader
from .mitoem import get_mitoem_loader
from .monuseg import get_monuseg_loader
from .mouse_embryo import get_mouse_embryo_loader
from .neurips_cell_seg import get_neurips_cellseg_supervised_loader, get_neurips_cellseg_unsupervised_loader
from .neurips_cell_seg import (
get_neurips_cellseg_supervised_loader, get_neurips_cellseg_supervised_dataset,
get_neurips_cellseg_unsupervised_loader
)
from .plantseg import get_plantseg_loader
from .platynereis import (get_platynereis_cell_loader,
get_platynereis_nuclei_loader)
from .snemi import get_snemi_loader
from .tissuenet import get_tissuenet_loader
from .tissuenet import (get_tissuenet_loader, get_tissuenet_dataset)
from .util import get_bioimageio_dataset_id
from .vnc import get_vnc_mito_loader
from .uro_cell import get_uro_cell_loader
20 changes: 17 additions & 3 deletions torch_em/data/datasets/deepbacs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from ...segmentation import default_segmentation_loader
import torch_em
from . import util

URLS = {
Expand Down Expand Up @@ -40,7 +40,9 @@ def _get_paths(path, bac_type, split):
return image_folder, label_folder


def get_deepbacs_loader(path, split, bac_type="mixed", download=False, **kwargs):
def get_deepbacs_dataset(
path, split, patch_shape, bac_type="mixed", download=False, **kwargs
):
assert split in ("train", "test")
bac_types = list(URLS.keys())
assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"
Expand All @@ -51,4 +53,16 @@ def get_deepbacs_loader(path, split, bac_type="mixed", download=False, **kwargs)

image_folder, label_folder = _get_paths(path, bac_type, split)

return default_segmentation_loader(image_folder, "*.tif", label_folder, "*.tif", **kwargs)
kwargs = util.ensure_transforms(ndim=2, **kwargs)

dataset = torch_em.default_segmentation_dataset(
image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs
)
return dataset


def get_deepbacs_loader(path, split, patch_shape, batch_size, bac_type="mixed", download=False, **kwargs):
    """Return a DataLoader for the DeepBacs segmentation dataset.

    Keyword arguments are partitioned between the dataset constructor and the
    loader, based on the signature of torch_em.default_segmentation_dataset.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **dataset_kwargs)
    return torch_em.get_data_loader(ds, batch_size, **loader_kwargs)
89 changes: 29 additions & 60 deletions torch_em/data/datasets/livecell.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch_em
import torch.utils.data
from .util import download_source, unzip, update_kwargs
from . import util

try:
from pycocotools.coco import COCO
Expand Down Expand Up @@ -39,8 +39,8 @@ def _download_livecell_images(path, download):
url = URLS["images"]
checksum = CHECKSUM
zip_path = os.path.join(path, "livecell.zip")
download_source(zip_path, url, download, checksum)
unzip(zip_path, path, True)
util.download_source(zip_path, url, download, checksum)
util.unzip(zip_path, path, True)


# TODO use download flag
Expand Down Expand Up @@ -144,42 +144,11 @@ def _download_livecell_annotations(path, split, download, cell_types, label_path
return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types)


def _livecell_segmentation_loader(
image_paths, label_paths,
batch_size, patch_shape,
label_transform=None,
label_transform2=None,
raw_transform=None,
transform=None,
label_dtype=torch.float32,
dtype=torch.float32,
n_samples=None,
**loader_kwargs
def get_livecell_dataset(
path, split, patch_shape, download=False,
offsets=None, boundaries=False, binary=False,
cell_types=None, label_path=None, label_dtype=torch.int64, **kwargs
):

# we always use a raw transform in the convenience function
if raw_transform is None:
raw_transform = torch_em.transform.get_raw_transform()

# we always use augmentations in the convenience function
if transform is None:
transform = torch_em.transform.get_augmentations(ndim=2)

ds = torch_em.data.ImageCollectionDataset(image_paths, label_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
label_transform=label_transform,
label_transform2=label_transform2,
label_dtype=label_dtype,
transform=transform,
n_samples=n_samples)

return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs)


def get_livecell_loader(path, patch_shape, split, download=False,
offsets=None, boundaries=False, binary=False,
cell_types=None, label_path=None, label_dtype=torch.int64, **kwargs):
assert split in ("train", "val", "test")
if cell_types is not None:
assert isinstance(cell_types, (list, tuple)),\
Expand All @@ -188,25 +157,25 @@ def get_livecell_loader(path, patch_shape, split, download=False,
_download_livecell_images(path, download)
image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types, label_path)

assert sum((offsets is not None, boundaries, binary)) <= 1
if offsets is not None:
# we add a binary target channel for foreground background segmentation
label_transform = torch_em.transform.label.AffinityTransform(offsets=offsets,
add_binary_target=True,
add_mask=True)
msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
kwargs = update_kwargs(kwargs, 'label_transform2', label_transform, msg=msg)
label_dtype = torch.float32
elif boundaries:
label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
kwargs = update_kwargs(kwargs, 'label_transform', label_transform, msg=msg)
label_dtype = torch.float32
elif binary:
label_transform = torch_em.transform.label.labels_to_binary
msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
kwargs = update_kwargs(kwargs, 'label_transform', label_transform, msg=msg)
label_dtype = torch.float32

kwargs.update({"patch_shape": patch_shape})
return _livecell_segmentation_loader(image_paths, seg_paths, label_dtype=label_dtype, **kwargs)
kwargs = util.ensure_transforms(ndim=2, **kwargs)
kwargs, label_dtype = util.add_instance_label_transform(
kwargs, add_binary_target=True, label_dtype=label_dtype,
offsets=offsets, boundaries=boundaries, binary=binary
)

dataset = torch_em.data.ImageCollectionDataset(
image_paths, seg_paths, patch_shape=patch_shape, label_dtype=label_dtype
)
return dataset


def get_livecell_loader(path, split, patch_shape, batch_size, download=False,
                        offsets=None, boundaries=False, binary=False,
                        cell_types=None, label_path=None, label_dtype=torch.int64, **kwargs):
    """Return a DataLoader for the LiveCELL dataset.

    Extra keyword arguments are split between the dataset constructor and the
    loader, based on the signature of torch_em.default_segmentation_dataset.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_livecell_dataset(
        path, split, patch_shape, download=download,
        offsets=offsets, boundaries=boundaries, binary=binary,
        cell_types=cell_types, label_path=label_path, label_dtype=label_dtype,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(ds, batch_size, **loader_kwargs)
29 changes: 25 additions & 4 deletions torch_em/data/datasets/neurips_cell_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def _get_image_and_label_paths(root, split, val_fraction):
return image_paths, label_paths


def get_neurips_cellseg_supervised_loader(
root, split,
patch_shape, batch_size,
def get_neurips_cellseg_supervised_dataset(
root, split, patch_shape,
make_rgb=True,
label_transform=None,
label_transform2=None,
Expand All @@ -75,7 +74,6 @@ def get_neurips_cellseg_supervised_loader(
n_samples=None,
sampler=None,
val_fraction=0.1,
**loader_kwargs
):
assert split in ("train", "val", None), split
image_paths, label_paths = _get_image_and_label_paths(root, split, val_fraction)
Expand All @@ -95,6 +93,29 @@ def get_neurips_cellseg_supervised_loader(
transform=transform,
n_samples=n_samples,
sampler=sampler)
return ds


def get_neurips_cellseg_supervised_loader(
    root, split,
    patch_shape, batch_size,
    make_rgb=True,
    label_transform=None,
    label_transform2=None,
    raw_transform=None,
    transform=None,
    label_dtype=torch.float32,
    dtype=torch.float32,
    n_samples=None,
    sampler=None,
    val_fraction=0.1,
    **loader_kwargs
):
    """Return a DataLoader for the supervised NeurIPS CellSeg dataset.

    All dataset arguments are forwarded to get_neurips_cellseg_supervised_dataset;
    remaining keyword arguments go to the data loader.
    """
    dataset = get_neurips_cellseg_supervised_dataset(
        root, split, patch_shape,
        make_rgb=make_rgb,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        label_dtype=label_dtype,
        dtype=dtype,
        n_samples=n_samples,
        sampler=sampler,
        val_fraction=val_fraction,
    )
    return torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)


Expand Down
29 changes: 23 additions & 6 deletions torch_em/data/datasets/tissuenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import z5py

from tqdm import tqdm
from .util import unzip
from . import util


# Automated download is currently not possible, because of authentication
Expand Down Expand Up @@ -43,15 +43,16 @@ def _create_split(path, split):


def _create_dataset(path, zip_path):
unzip(zip_path, path, remove=False)
util.unzip(zip_path, path, remove=False)
splits = ["train", "val", "test"]
assert all([os.path.exists(os.path.join(path, f"tissuenet_v1.1_{split}.npz")) for split in splits])
for split in splits:
_create_split(path, split)


# TODO enable loading specific tissue types etc. (from the 'meta' attributes)
def get_tissuenet_loader(path, split, raw_channel, label_channel, download=False, **kwargs):
def get_tissuenet_dataset(
path, split, patch_shape, raw_channel, label_channel, download=False, **kwargs
):
assert raw_channel in ("nucleus", "cell", "rgb")
assert label_channel in ("nucleus", "cell")

Expand All @@ -76,7 +77,23 @@ def get_tissuenet_loader(path, split, raw_channel, label_channel, download=False
assert len(data_path) > 0

raw_key, label_key = f"raw/{raw_channel}", f"labels/{label_channel}"

kwargs = util.ensure_transforms(ndim=2, **kwargs)
with_channels = True if raw_channel == "rgb" else False
return torch_em.default_segmentation_loader(
data_path, raw_key, data_path, label_key, is_seg_dataset=True, ndim=2, with_channels=with_channels, **kwargs
kwargs = util.update_kwargs(kwargs, "with_channels", with_channels)
kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
kwargs = util.update_kwargs(kwargs, "ndim", 2)

return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)


# TODO enable loading specific tissue types etc. (from the 'meta' attributes)
def get_tissuenet_loader(
    path, split, patch_shape, batch_size, raw_channel, label_channel, download=False, **kwargs
):
    """Return a DataLoader for the TissueNet dataset.

    NOTE: the data cannot be downloaded automatically and has to be obtained
    manually from https://datasets.deepcell.org/data.
    Keyword arguments are split between the dataset constructor and the loader,
    based on the signature of torch_em.default_segmentation_dataset.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_tissuenet_dataset(
        path, split, patch_shape, raw_channel, label_channel, download=download, **dataset_kwargs
    )
    return torch_em.get_data_loader(ds, batch_size, **loader_kwargs)
Loading

0 comments on commit 620ab64

Please sign in to comment.