diff --git a/test/data/test_image_collection_dataset.py b/test/data/test_image_collection_dataset.py
index 40d1928b..21d8c8fd 100644
--- a/test/data/test_image_collection_dataset.py
+++ b/test/data/test_image_collection_dataset.py
@@ -1,8 +1,11 @@
 import os
+import tempfile
 import unittest
 from glob import glob
 from shutil import rmtree
+import tifffile
+import numpy as np
 
 from torch_em.util.test import create_image_collection_test_data
 
 
@@ -37,5 +40,50 @@ def test_dataset(self):
         self.assertEqual(y.shape, expected_shape)
 
 
+def generate_sample_data(folder, n_images, image_shape, label_shape):
+    im_folder = os.path.join(folder, "images")
+    label_folder = os.path.join(folder, "labels")
+    os.makedirs(im_folder)
+    os.makedirs(label_folder)
+    for i in range(n_images):
+        raw = np.empty(image_shape, dtype=np.uint8)
+        label = np.ones(label_shape, dtype=np.float32)
+        tifffile.imwrite(os.path.join(im_folder, f"test_{i}.tif"), raw)
+        tifffile.imwrite(os.path.join(label_folder, f"test_{i}.tif"), label)
+
+
+class TestChannelsDataset(unittest.TestCase):
+    def test_channel_end(self):
+        from torch_em.data import ImageCollectionDataset
+
+        patch_shape = (256, 256)
+
+        with tempfile.TemporaryDirectory() as td:
+            generate_sample_data(td, 10, (256, 256, 2), (256, 256))
+            raw_paths = glob(os.path.join(td, "images", "*.tif"))
+            label_paths = glob(os.path.join(td, "labels", "*.tif"))
+
+
+            ds = ImageCollectionDataset(raw_paths, label_paths,
+                                        patch_shape=patch_shape)
+            self.assertEqual(len(ds), 10)
+            self.assertEqual(ds._get_sample(0)[0].shape[0], 2)
+
+    def test_channel_begin(self):
+        from torch_em.data import ImageCollectionDataset
+
+        patch_shape = (256, 256)
+
+        with tempfile.TemporaryDirectory() as td:
+            generate_sample_data(td, 10, (2, 256, 256), (256, 256))
+            raw_paths = glob(os.path.join(td, "images", "*.tif"))
+            label_paths = glob(os.path.join(td, "labels", "*.tif"))
+
+            ds = ImageCollectionDataset(raw_paths, label_paths,
+                                        patch_shape=patch_shape)
+            self.assertEqual(len(ds), 10)
+            self.assertEqual(ds._get_sample(0)[0].shape[0], 2)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py
index a7d246fd..fa93710c 100644
--- a/torch_em/data/image_collection_dataset.py
+++ b/torch_em/data/image_collection_dataset.py
@@ -28,7 +28,13 @@ def _check_inputs(self, raw_images, label_images):
 
             # we assume axis last
             if is_multichan:
-                shape = shape[:-1]
+                # use heuristic to decide whether the data is stored in channel last or channel first order:
+                # if the last axis has a length smaller than 16 we assume that it's the channel axis,
+                # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
+                if shape[-1] < 16:
+                    shape = shape[:-1]
+                else:
+                    shape = shape[1:]
 
             label_shape = load_image(label_im).shape
             if shape != label_shape:
@@ -80,7 +86,7 @@ def ndim(self):
     def _sample_bounding_box(self, shape):
         if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
-            raise NotImplementedError("Image padding is not supported yet.")
+            raise NotImplementedError(f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}.")
         bb_start = [
             np.random.randint(0, sh - psh) if sh - psh > 0 else 0
             for sh, psh in zip(shape, self.patch_shape)
         ]
@@ -101,18 +107,28 @@ def _get_sample(self, index):
             raise NotImplementedError("Multi-channel labels are not supported.")
 
         shape = raw.shape
-        # we assume images are loaded with channel last!
+        # we determine whether the image has channels as the first or the last axis based on the array shape.
+        # This will only work for images with fewer than 16 channels.
+        prefix_box = tuple()
         if have_raw_channels:
-            shape = shape[:-1]
+            # use heuristic to decide whether the data is stored in channel last or channel first order:
+            # if the last axis has a length smaller than 16 we assume that it's the channel axis,
+            # otherwise we assume it's a spatial axis and that the first axis is the channel axis.
+            if shape[-1] < 16:
+                shape = shape[:-1]
+            else:
+                shape = shape[1:]
+                prefix_box = (slice(None), )
+
+        # sample random bounding box for this image
         bb = self._sample_bounding_box(shape)
-        raw = np.array(raw[bb])
+        raw = np.array(raw[prefix_box + bb])
         label = np.array(label[bb])
 
         # to channel first
-        if have_raw_channels:
+        if have_raw_channels and len(prefix_box) == 0:
             raw = raw.transpose((2, 0, 1))
 
         return raw, label
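For reference, a minimal standalone sketch of the channel-axis heuristic the patch introduces (the helper name infer_spatial_shape is made up for illustration and is not part of the patch): a trailing axis shorter than 16 is treated as the channel axis, otherwise the leading axis is assumed to hold the channels.

def infer_spatial_shape(shape):
    # heuristic from the patch: a last axis shorter than 16 is taken to be the channel axis
    if shape[-1] < 16:
        return shape[:-1], "channel-last"
    # otherwise the first axis is assumed to be the channel axis
    return shape[1:], "channel-first"


print(infer_spatial_shape((256, 256, 2)))  # ((256, 256), 'channel-last')
print(infer_spatial_shape((2, 256, 256)))  # ((256, 256), 'channel-first')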
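Similarly, a small sketch of the prefix_box slicing used in _get_sample above (the array and bounding box are made-up example values): prepending slice(None) keeps the whole channel axis while the spatial axes are cropped, so channel-first data needs no transpose afterwards.

import numpy as np

raw = np.zeros((2, 256, 256))          # channel-first image with 2 channels
bb = (slice(10, 138), slice(20, 148))  # spatial bounding box, 128 x 128 patch
prefix_box = (slice(None),)            # keep all channels

patch = raw[prefix_box + bb]
print(patch.shape)                     # (2, 128, 128), already channel-first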