Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add neurips cell seg dataset; update raw image collection dataset; up… #111

Merged
merged 1 commit into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions experiments/neurips-cell-seg/check_neurips_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from torch_em.data.datasets import get_neurips_cellseg_supervised_loader, get_neurips_cellseg_unsupervised_loader
from torch_em.util.debug import check_loader


def check_supervised(split, n_images=6, root="/home/pape/Work/data/neurips-cell-seg"):
    """Visually check batches from the supervised NeurIPS cell segmentation loader.

    Args:
        split: Split that is handed to the loader (this script uses "train" and "val").
        n_images: Number of batches that are passed on to ``check_loader``.
        root: Root folder of the dataset. The default is the path that used to be
            hard-coded here; pass your own location to run the check elsewhere.
    """
    patch_shape = [384, 384]
    loader = get_neurips_cellseg_supervised_loader(root, split, patch_shape, batch_size=1)
    # The loader yields RGB raw data and instance labels, so tell the viewer about both.
    check_loader(loader, n_images, instance_labels=True, rgb=True)


def check_unsupervised(n_images=10, root="/home/pape/Work/data/neurips-cell-seg"):
    """Visually check batches from the unsupervised NeurIPS cell segmentation loader.

    Args:
        n_images: Number of batches that are passed on to ``check_loader``.
        root: Root folder of the dataset. The default is the path that used to be
            hard-coded here; pass your own location to run the check elsewhere.
    """
    patch_shape = [384, 384]
    # Include both the regular unlabeled images and the whole-slide images.
    loader = get_neurips_cellseg_unsupervised_loader(
        root, patch_shape, batch_size=1, use_images=True, use_wholeslide=True
    )
    check_loader(loader, n_images, rgb=True)


def main():
    """Run the supervised check for both splits, then the unsupervised check."""
    for split_name in ("train", "val"):
        check_supervised(split_name)

    check_unsupervised()


# Only run the checks when this file is executed as a script.
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions torch_em/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .mitoem import get_mitoem_loader
from .monuseg import get_monuseg_loader
from .mouse_embryo import get_mouse_embryo_loader
from .neurips_cell_seg import get_neurips_cellseg_supervised_loader, get_neurips_cellseg_unsupervised_loader
from .plantseg import get_plantseg_loader
from .platynereis import (get_platynereis_cell_loader,
get_platynereis_nuclei_loader)
Expand Down
166 changes: 166 additions & 0 deletions torch_em/data/datasets/neurips_cell_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import json
import os
from glob import glob

import numpy as np
import torch
import torch_em


def to_rgb(image):
    """Return ``image`` as a channel-first array with three channels.

    A 2D input is replicated three times along a new leading axis.
    Any other input is passed through and must already be 3D with
    3 entries along its first axis, otherwise an assertion fails.
    """
    is_single_plane = image.ndim == 2
    if is_single_plane:
        # Same result as concatenating three copies of image[None] along axis 0.
        image = np.stack((image, image, image), axis=0)
    assert image.ndim == 3
    assert image.shape[0] == 3, f"{image.shape}"
    return image


# TODO: it would be better to make balanced splits across the different data modalities
# (this requires knowing the mapping from images to modalities)
def _get_image_and_label_paths(root, split, val_fraction):
    """Collect the labeled image and label paths for the requested split.

    Images are read from ``<root>/TrainLabeled/images`` and labels from
    ``<root>/TrainLabeled/labels``. Both listings are sorted and then paired
    by position.

    The train/val partition is drawn randomly the first time it is needed and
    is written to ``split_<val_fraction>.json`` in the folder of this module.
    Later calls with the same ``val_fraction`` read the partition from there.

    Args:
        root: Root folder of the dataset.
        split: Name of the split ("train" or "val"), or None to return all
            paths without splitting.
        val_fraction: Fraction of the images that goes into the val split.

    Returns:
        Tuple ``(image_paths, label_paths)`` of two lists with equal length.
    """
    path = os.path.join(root, "TrainLabeled")
    # Check the labeled sub-folder itself (not only root), so that a missing
    # folder is reported together with its full path.
    assert os.path.exists(path), path

    image_folder = os.path.join(path, "images")
    assert os.path.exists(image_folder), image_folder
    label_folder = os.path.join(path, "labels")
    assert os.path.exists(label_folder), label_folder

    all_image_paths = sorted(glob(os.path.join(image_folder, "*")))
    all_label_paths = sorted(glob(os.path.join(label_folder, "*")))
    assert len(all_image_paths) == len(all_label_paths), \
        f"{len(all_image_paths)} images vs. {len(all_label_paths)} labels"

    if split is None:
        return all_image_paths, all_label_paths

    split_file = os.path.join(
        os.path.split(__file__)[0], f"split_{val_fraction}.json"
    )

    if os.path.exists(split_file):
        # NOTE: the stored ids are used as they are; they are not checked
        # against the number of images that was found above.
        with open(split_file) as f:
            split_ids = json.load(f)[split]
    else:
        # split into training and val images
        n_images = len(all_image_paths)
        n_train = int((1.0 - val_fraction) * n_images)
        image_ids = list(range(n_images))
        np.random.shuffle(image_ids)
        train_ids, val_ids = image_ids[:n_train], image_ids[n_train:]
        assert len(train_ids) + len(val_ids) == n_images

        # Persist the random partition so that it is reused by later calls.
        with open(split_file, "w") as f:
            json.dump({"train": train_ids, "val": val_ids}, f)

        split_ids = val_ids if split == "val" else train_ids

    image_paths = [all_image_paths[idx] for idx in split_ids]
    label_paths = [all_label_paths[idx] for idx in split_ids]
    assert len(image_paths) == len(label_paths)
    return image_paths, label_paths


def get_neurips_cellseg_supervised_loader(
    root, split,
    patch_shape, batch_size,
    make_rgb=True,
    label_transform=None,
    label_transform2=None,
    raw_transform=None,
    transform=None,
    label_dtype=torch.float32,
    dtype=torch.float32,
    n_samples=None,
    sampler=None,
    val_fraction=0.1,
    **loader_kwargs
):
    """Build a data loader for the labeled NeurIPS cell segmentation images.

    Args:
        root: Root folder of the dataset.
        split: "train", "val" or None. None uses all labeled images.
        patch_shape: Patch shape that is passed to the dataset.
        batch_size: Batch size of the loader.
        make_rgb: Whether ``to_rgb`` is added to the default raw transform.
            Only has an effect if ``raw_transform`` is None.
        label_transform: Passed to the dataset.
        label_transform2: Passed to the dataset.
        raw_transform: Raw transform. If None, the default raw transform from
            ``torch_em.transform.get_raw_transform`` is used.
        transform: Joint transform. If None, the 2D augmentations from
            ``torch_em.transform.get_augmentations`` are used.
        label_dtype: Label data type passed to the dataset.
        dtype: Raw data type passed to the dataset.
        n_samples: Passed to the dataset.
        sampler: Passed to the dataset.
        val_fraction: Fraction of the images that is used for the val split.
        loader_kwargs: Further keyword arguments for
            ``torch_em.segmentation.get_data_loader``.

    Returns:
        The data loader created by ``torch_em.segmentation.get_data_loader``.
    """
    assert split in ("train", "val", None), split
    image_paths, label_paths = _get_image_and_label_paths(root, split, val_fraction)

    if raw_transform is None:
        trafo = to_rgb if make_rgb else None
        raw_transform = torch_em.transform.get_raw_transform(augmentation2=trafo)
    if transform is None:
        transform = torch_em.transform.get_augmentations(ndim=2)

    # dtype was accepted by this function but never handed to the dataset before,
    # so a non-default value had no effect. Forward it like label_dtype.
    ds = torch_em.data.ImageCollectionDataset(image_paths, label_paths,
                                              patch_shape=patch_shape,
                                              raw_transform=raw_transform,
                                              label_transform=label_transform,
                                              label_transform2=label_transform2,
                                              dtype=dtype,
                                              label_dtype=label_dtype,
                                              transform=transform,
                                              n_samples=n_samples,
                                              sampler=sampler)
    return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs)


def _get_image_paths(root):
    """Return the sorted paths of everything inside ``<root>/TrainUnlabeled``."""
    unlabeled_folder = os.path.join(root, "TrainUnlabeled")
    assert os.path.exists(unlabeled_folder), unlabeled_folder
    return sorted(glob(os.path.join(unlabeled_folder, "*")))


def _get_wholeslide_paths(root, patch_shape):
    """Find the usable whole-slide images and count the patches they provide.

    Only files in ``<root>/TrainUnlabeled_WholeSlide`` for which
    ``torch_em.util.supports_memmap`` is true are kept. Each kept image must
    be 3D with 3 entries along its last axis. It contributes the product of
    ``extent // patch_extent`` over its first two axes to the sample count.

    Returns:
        Tuple ``(image_paths, n_samples)``.
    """
    slide_folder = os.path.join(root, "TrainUnlabeled_WholeSlide")
    assert os.path.exists(slide_folder), slide_folder

    # one of the whole slides doesn't support memmap which will make it very slow to load
    image_paths = [
        candidate
        for candidate in sorted(glob(os.path.join(slide_folder, "*")))
        if torch_em.util.supports_memmap(candidate)
    ]
    assert len(image_paths) > 0

    def _patches_in(im_path):
        # Number of non-overlapping patches that fit into the two leading axes.
        shape = torch_em.util.load_image(im_path).shape
        assert len(shape) == 3 and shape[-1] == 3
        return np.prod([extent // patch_extent for extent, patch_extent in zip(shape[:2], patch_shape)])

    n_samples = sum(_patches_in(im_path) for im_path in image_paths)

    return image_paths, n_samples


def get_neurips_cellseg_unsupervised_loader(
    root, patch_shape, batch_size,
    make_rgb=True,
    raw_transform=None,
    transform=None,
    dtype=torch.float32,
    sampler=None,
    use_images=True,
    use_wholeslide=True,
    **loader_kwargs,
):
    """Build a data loader for the unlabeled NeurIPS cell segmentation images.

    One raw image dataset is created for the regular unlabeled images
    (``use_images``) and one for the whole-slide images (``use_wholeslide``).
    The datasets are concatenated; at least one of the two must be enabled.

    If ``raw_transform`` is None the default raw transform is used, with
    ``to_rgb`` added when ``make_rgb`` is set. If ``transform`` is None the
    default 2D augmentations are used. Further keyword arguments go to
    ``torch_em.segmentation.get_data_loader``.
    """
    if raw_transform is None:
        rgb_trafo = to_rgb if make_rgb else None
        raw_transform = torch_em.transform.get_raw_transform(augmentation2=rgb_trafo)
    if transform is None:
        transform = torch_em.transform.get_augmentations(ndim=2)

    # Settings that are identical for both datasets.
    shared_kwargs = {
        "patch_shape": patch_shape,
        "raw_transform": raw_transform,
        "transform": transform,
        "dtype": dtype,
        "sampler": sampler,
    }

    datasets = []
    if use_images:
        image_paths = _get_image_paths(root)
        image_ds = torch_em.data.RawImageCollectionDataset(image_paths, **shared_kwargs)
        datasets.append(image_ds)
    if use_wholeslide:
        # The whole-slide dataset gets an explicit sample count.
        slide_paths, n_slide_samples = _get_wholeslide_paths(root, patch_shape)
        slide_ds = torch_em.data.RawImageCollectionDataset(
            slide_paths, n_samples=n_slide_samples, **shared_kwargs
        )
        datasets.append(slide_ds)
    assert len(datasets) > 0

    combined = torch.utils.data.ConcatDataset(datasets)
    return torch_em.segmentation.get_data_loader(combined, batch_size, **loader_kwargs)
14 changes: 13 additions & 1 deletion torch_em/data/raw_image_collection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

# TODO pad images that are too small for the patch shape
class RawImageCollectionDataset(torch.utils.data.Dataset):
max_sampling_attempts = 500

def _check_inputs(self, raw_images):
is_multichan = None
for raw_im in raw_images:

# we only check for compatible shapes if both images support memmap, because
# we only check for compatible shapes if images support memmap, because
# we don't want to load everything into ram
if supports_memmap(raw_im):
shape = load_image(raw_im).shape
Expand All @@ -34,6 +35,7 @@ def __init__(
transform=None,
dtype=torch.float32,
n_samples=None,
sampler=None,
):
self._check_inputs(raw_image_paths)
self.raw_images = raw_image_paths
Expand All @@ -45,6 +47,7 @@ def __init__(
self.raw_transform = raw_transform
self.transform = transform
self.dtype = dtype
self.sampler = sampler

if n_samples is None:
self._len = len(self.raw_images)
Expand Down Expand Up @@ -84,6 +87,15 @@ def _get_sample(self, index):
bb = self._sample_bounding_box(shape)
raw = np.array(raw[bb])

if self.sampler is not None:
sample_id = 0
while not self.sampler(raw):
bb = self._sample_bounding_box(shape)
raw = np.array(raw[bb])
sample_id += 1
if sample_id > self.max_sampling_attempts:
raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

# to channel first
if have_raw_channels:
raw = raw.transpose((2, 0, 1))
Expand Down
28 changes: 19 additions & 9 deletions torch_em/util/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,28 +68,38 @@ def to_index(ns, rid, sid):
plt.show()


def _check_napari(loader, n_samples, instance_labels, model=None, device=None, rgb=False):
    """Show batches from ``loader`` in napari, one viewer per batch.

    Only the first element of each batch is displayed.

    Args:
        loader: Iterable that yields either ``(x, y)`` pairs or raw data only.
        n_samples: Number of batches to show.
        instance_labels: If true, ``y`` is added as a labels layer (cast to
            uint32), otherwise as an image layer.
        model: Optional model; if given its prediction for ``x`` is shown too.
        device: Optional device that ``x`` is moved to before calling the model.
        rgb: If true, ``x`` must have 3 channels on its first axis and is
            transposed to channel-last before it is displayed.
    """
    import napari

    for ii, sample in enumerate(loader):
        if ii >= n_samples:
            break

        # Decide by container type whether labels are present. Unpacking and
        # catching ValueError would split a raw-only batch with exactly two
        # elements into x and y.
        # NOTE(review): assumes paired samples are collated into a list or tuple - confirm.
        if isinstance(sample, (list, tuple)):
            x, y = sample
        else:
            x, y = sample, None

        if model is None:
            pred = None
        else:
            pred = model(x if device is None else x.to(device))
            pred = ensure_array(pred)[0]

        x = ensure_array(x)[0]
        if rgb:
            assert x.shape[0] == 3
            x = x.transpose((1, 2, 0))

        v = napari.Viewer()
        v.add_image(x)
        if y is not None:
            y = ensure_array(y)[0]
            if instance_labels:
                v.add_labels(y.astype("uint32"))
            else:
                v.add_image(y)
        if pred is not None:
            v.add_image(pred)
        napari.run()
Expand All @@ -108,8 +118,8 @@ def check_trainer(trainer, n_samples, instance_labels=False, split="val", loader
_check_napari(loader, n_samples, instance_labels, model=model, device=trainer.device)


def check_loader(loader, n_samples, instance_labels=False, plt=False, rgb=False):
    """Visually check the batches of a loader.

    Args:
        loader: The loader to check.
        n_samples: Number of batches to show.
        instance_labels: Whether the labels are instance labels.
        plt: If true the check is done with ``_check_plt``, otherwise with
            ``_check_napari``.
        rgb: Whether the raw data is RGB. This is only forwarded to
            ``_check_napari``; it is not used when ``plt`` is true.
    """
    if plt:
        _check_plt(loader, n_samples, instance_labels)
    else:
        _check_napari(loader, n_samples, instance_labels, rgb=rgb)