Skip to content

Commit

Permalink
Update tissuenet dataset
Browse files (browse the repository at this point in the history)
  • Loading branch information
constantinpape committed Jul 14, 2023
1 parent 34e9d9c commit 70a3629
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions torch_em/data/datasets/tissuenet.py
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
import os
from glob import glob

import z5py
import numpy as np
import pandas as pd
import torch_em
from tqdm import tqdm
import z5py

from tqdm import tqdm
from .util import unzip


Expand All @@ -17,16 +18,27 @@ def _create_split(path, split):
split_file = os.path.join(path, f"tissuenet_v1.1_{split}.npz")
split_folder = os.path.join(path, split)
os.makedirs(split_folder, exist_ok=True)
data = np.load(split_file)
data = np.load(split_file, allow_pickle=True)

x, y = data["X"], data["y"]
metadata = data["meta"]
metadata = pd.DataFrame(metadata[1:], columns=metadata[0])

for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"):
out_path = os.path.join(split_folder, f"image_{i:04}.n5")
out_path = os.path.join(split_folder, f"image_{i:04}.zarr")
nucleus_channel = im[..., 0]
cell_channel = im[..., 1]
rgb = np.stack([np.zeros_like(nucleus_channel), cell_channel, nucleus_channel])
chunks = cell_channel.shape
with z5py.File(out_path, "a") as f:
f.create_dataset("raw/nucleus", data=im[..., 0], compression="gzip", chunks=im[..., 0].shape)
f.create_dataset("raw/cell", data=im[..., 1], compression="gzip", chunks=im[..., 1].shape)
# the swithh 0<->1 is intentional, the data format is chaotic...
f.create_dataset("labels/nucleus", data=label[..., 1], compression="gzip", chunks=label[..., 1].shape)
f.create_dataset("labels/cell", data=label[..., 0], compression="gzip", chunks=label[..., 0].shape)

f.create_dataset("raw/nucleus", data=im[..., 0], compression="gzip", chunks=chunks)
f.create_dataset("raw/cell", data=cell_channel, compression="gzip", chunks=chunks)
f.create_dataset("raw/rgb", data=rgb, compression="gzip", chunks=(3,) + chunks)

# the switch 0<->1 is intentional, the data format is chaotic...
f.create_dataset("labels/nucleus", data=label[..., 1], compression="gzip", chunks=chunks)
f.create_dataset("labels/cell", data=label[..., 0], compression="gzip", chunks=chunks)
os.remove(split_file)


Expand All @@ -38,7 +50,11 @@ def _create_dataset(path, zip_path):
_create_split(path, split)


def get_tissuenet_loader(path, split, mode, download=False, **kwargs):
# TODO enable loading specific tissue types etc. (from the 'meta' attributes)
def get_tissuenet_loader(path, split, raw_channel, label_channel, download=False, **kwargs):
assert raw_channel in ("nucleus", "cell", "rgb")
assert label_channel in ("nucleus", "cell")

splits = ["train", "val", "test"]
assert split in splits

Expand All @@ -56,12 +72,11 @@ def get_tissuenet_loader(path, split, mode, download=False, **kwargs):

split_folder = os.path.join(path, split)
assert os.path.exists(split_folder)
data_path = glob(os.path.join(split_folder, "*.n5"))
data_path = glob(os.path.join(split_folder, "*.zarr"))
assert len(data_path) > 0
print(len(data_path))

assert mode in ["nucleus", "cell"], f"Got {mode}"
raw_key, label_key = f"raw/{mode}", f"labels/{mode}"
raw_key, label_key = f"raw/{raw_channel}", f"labels/{label_channel}"
with_channels = True if raw_channel == "rgb" else False
return torch_em.default_segmentation_loader(
data_path, raw_key, data_path, label_key, is_seg_dataset=True, ndim=2, **kwargs
data_path, raw_key, data_path, label_key, is_seg_dataset=True, ndim=2, with_channels=with_channels, **kwargs
)

0 comments on commit 70a3629

Please sign in to comment.