Skip to content

Commit

Permalink
Merge pull request #94 from Czaki/read_CXY_data
Browse files Browse the repository at this point in the history
When loading data for train allow to read data that has channel on first position
  • Loading branch information
constantinpape committed Sep 21, 2022
2 parents 1c89315 + d6cc323 commit b9db670
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 6 deletions.
48 changes: 48 additions & 0 deletions test/data/test_image_collection_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import tempfile
import unittest
from glob import glob
from shutil import rmtree

import tifffile
import numpy as np
from torch_em.util.test import create_image_collection_test_data


Expand Down Expand Up @@ -37,5 +40,50 @@ def test_dataset(self):
self.assertEqual(y.shape, expected_shape)


def generate_sample_data(folder, n_images, image_shape, label_shape):
im_folder = os.path.join(folder, "images")
label_folder = os.path.join(folder, "labels")
os.makedirs(im_folder)
os.makedirs(label_folder)
for i in range(n_images):
raw = np.empty(image_shape, dtype=np.uint8)
label = np.ones(label_shape, dtype=np.float32)
tifffile.imwrite(os.path.join(im_folder, f"test_{i}.tif"), raw)
tifffile.imwrite(os.path.join(label_folder, f"test_{i}.tif"), label)


class TestChannelsDataset(unittest.TestCase):
def test_channel_end(self):
from torch_em.data import ImageCollectionDataset

patch_shape = (256, 256)

with tempfile.TemporaryDirectory() as td:
generate_sample_data(td, 10, (256, 256, 2), (256, 256))
raw_paths = glob(os.path.join(td, "images", "*.tif"))
label_paths = glob(os.path.join(td, "labels", "*.tif"))


ds = ImageCollectionDataset(raw_paths, label_paths,
patch_shape=patch_shape)
self.assertEqual(len(ds), 10)
self.assertEqual(ds._get_sample(0)[0].shape[0], 2)

def test_channel_begin(self):
from torch_em.data import ImageCollectionDataset

patch_shape = (256, 256)

with tempfile.TemporaryDirectory() as td:
generate_sample_data(td, 10, (2, 256, 256), (256, 256))
raw_paths = glob(os.path.join(td, "images", "*.tif"))
label_paths = glob(os.path.join(td, "labels", "*.tif"))

ds = ImageCollectionDataset(raw_paths, label_paths,
patch_shape=patch_shape)
self.assertEqual(len(ds), 10)
self.assertEqual(ds._get_sample(0)[0].shape[0], 2)


if __name__ == '__main__':
unittest.main()
28 changes: 22 additions & 6 deletions torch_em/data/image_collection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ def _check_inputs(self, raw_images, label_images):

# we assume axis last
if is_multichan:
shape = shape[:-1]
# use heuristic to decide whether the data is stored in channel last or channel first order:
# if the last axis has a length smaller than 16 we assume that it's the channel axis,
# otherwise we assume it's a spatial axis and that the first axis is the channel axis.
if shape[-1] < 16:
shape = shape[:-1]
else:
shape = shape[1:]

label_shape = load_image(label_im).shape
if shape != label_shape:
Expand Down Expand Up @@ -80,7 +86,7 @@ def ndim(self):

def _sample_bounding_box(self, shape):
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
raise NotImplementedError("Image padding is not supported yet.")
raise NotImplementedError("Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}")
bb_start = [
np.random.randint(0, sh - psh) if sh - psh > 0 else 0
for sh, psh in zip(shape, self.patch_shape)
Expand All @@ -101,18 +107,28 @@ def _get_sample(self, index):
raise NotImplementedError("Multi-channel labels are not supported.")

shape = raw.shape
# we assume images are loaded with channel last!
# we determine if image has channels as te first or last axis base on array shape.
# This will work only for images with less than 16 channels.
prefix_box = tuple()
if have_raw_channels:
shape = shape[:-1]
# use heuristic to decide whether the data is stored in channel last or channel first order:
# if the last axis has a length smaller than 16 we assume that it's the channel axis,
# otherwise we assume it's a spatial axis and that the first axis is the channel axis.
if shape[-1] < 16:
shape = shape[:-1]
else:
shape = shape[1:]
prefix_box = (slice(None), )


# sample random bounding box for this image
bb = self._sample_bounding_box(shape)

raw = np.array(raw[bb])
raw = np.array(raw[prefix_box + bb])
label = np.array(label[bb])

# to channel first
if have_raw_channels:
if have_raw_channels and len(prefix_box) == 0:
raw = raw.transpose((2, 0, 1))

return raw, label
Expand Down

0 comments on commit b9db670

Please sign in to comment.