From 9be0934267493d33e839ce6c88b691115d0aa1b7 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 20 Sep 2022 15:17:44 +0000 Subject: [PATCH 01/11] allow channel be first --- test/data/test_image_collection_dataset.py | 49 ++++++++++++++++++++++ torch_em/data/image_collection_dataset.py | 9 +++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/test/data/test_image_collection_dataset.py b/test/data/test_image_collection_dataset.py index 40d1928b..7b84f7c5 100644 --- a/test/data/test_image_collection_dataset.py +++ b/test/data/test_image_collection_dataset.py @@ -3,6 +3,8 @@ from glob import glob from shutil import rmtree +import tifffile +import numpy as np from torch_em.util.test import create_image_collection_test_data @@ -37,5 +39,52 @@ def test_dataset(self): self.assertEqual(y.shape, expected_shape) +def generate_sample_data(folder, n_images, shape): + im_folder = os.path.join(folder, "images") + label_folder = os.path.join(folder, "labels") + os.makedirs(im_folder) + os.makedirs(label_folder) + for i in range(n_images): + raw = np.empty(shape, dtype=np.uint8) + label = np.ones(shape, dtype=np.float32) + tifffile.imwrite(os.path.join(im_folder, f"test_{i}.tif"), raw) + tifffile.imwrite(os.path.join(label_folder, f"test_{i}.tif"), label) + + + + + +class TestChannelsDataset(unittest.TestCase): + def test_channel_end(self): + from torch_em.data import ImageCollectionDataset + + patch_shape = (256, 256) + + with tempfile.TemporaryDirectory() as td: + im_folder = glob(os.path.join(td, "images", "*.tif")) + label_folder = glob(os.path.join(td, "labels", "*.tif")) + + generate_sample_data(td, 10, (64, 64, 2)) + ds = ImageCollectionDataset(raw_paths, label_paths, + patch_shape=patch_shape) + self.assertEqual(len(ds), 10) + self.assertEqual(ds._get_sample(0)[0].shape[0], 2) + + def test_channel_begin(self): + from torch_em.data import ImageCollectionDataset + + patch_shape = (256, 256) + + with tempfile.TemporaryDirectory() as td: + im_folder = glob(os.path.join(td, "images", "*.tif")) + label_folder = glob(os.path.join(td, "labels", "*.tif")) + + generate_sample_data(td, 10, (2, 64, 64)) + ds = ImageCollectionDataset(raw_paths, label_paths, + patch_shape=patch_shape) + self.assertEqual(len(ds), 10) + self.assertEqual(ds._get_sample(0)[0].shape[0], 2) + + if __name__ == '__main__': unittest.main() diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index a7d246fd..09fc6fbe 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -103,12 +103,17 @@ def _get_sample(self, index): shape = raw.shape # we assume images are loaded with channel last! if have_raw_channels: - shape = shape[:-1] + if shape[0] < 16: + shape = shape[1:] + prefix_box = (slice(None), ) + else: + shape = shape[:-1] + prefix_box = (, ) # sample random bounding box for this image bb = self._sample_bounding_box(shape) - raw = np.array(raw[bb]) + raw = np.array(raw[prefix_box + bb]) label = np.array(label[bb]) # to channel first From e5e5bdb516abdbbe97345446a6b360b3dd0edf68 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 20 Sep 2022 15:19:27 +0000 Subject: [PATCH 02/11] fix reorder --- torch_em/data/image_collection_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index 09fc6fbe..7310ab3c 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -117,7 +117,7 @@ def _get_sample(self, index): label = np.array(label[bb]) # to channel first - if have_raw_channels: + if have_raw_channels and len(prefix_box) == 0: raw = raw.transpose((2, 0, 1)) return raw, label From 96b9b7aff12a27e42e710e66edbc885e02849038 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 20 Sep 2022 15:21:05 +0000 Subject: [PATCH 03/11] fix formating --- test/data/test_image_collection_dataset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/data/test_image_collection_dataset.py b/test/data/test_image_collection_dataset.py index 7b84f7c5..89f63249 100644 --- a/test/data/test_image_collection_dataset.py +++ b/test/data/test_image_collection_dataset.py @@ -51,9 +51,6 @@ def generate_sample_data(folder, n_images, shape): tifffile.imwrite(os.path.join(label_folder, f"test_{i}.tif"), label) - - - class TestChannelsDataset(unittest.TestCase): def test_channel_end(self): from torch_em.data import ImageCollectionDataset From d4297983a1ea4a154453e59b83e044229eb20d1d Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 20 Sep 2022 15:26:41 +0000 Subject: [PATCH 04/11] fix definition of empty tuple --- torch_em/data/image_collection_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index 7310ab3c..0d5d9fe7 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -108,7 +108,7 @@ def _get_sample(self, index): prefix_box = (slice(None), ) else: shape = shape[:-1] - prefix_box = (, ) + prefix_box = tuple() # sample random bounding box for this image bb = self._sample_bounding_box(shape) From e094a861f6fff25a89c23f19858ede016af2bb69 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 20 Sep 2022 15:52:40 +0000 Subject: [PATCH 05/11] fix prefix_box variable initialization --- torch_em/data/image_collection_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index 0d5d9fe7..34087f2e 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -102,13 +102,13 @@ def _get_sample(self, index): shape = raw.shape # we assume images are loaded with channel last! + prefix_box = tuple() if have_raw_channels: if shape[0] < 16: shape = shape[1:] prefix_box = (slice(None), ) else: shape = shape[:-1] - prefix_box = tuple() # sample random bounding box for this image bb = self._sample_bounding_box(shape) From 1f3f6995330ac0572718c95ba5259346dfbf1902 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Tue, 20 Sep 2022 15:54:21 +0000 Subject: [PATCH 06/11] reorder for better consistencyt with data evaluation --- torch_em/data/image_collection_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index 34087f2e..e03403d8 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -104,11 +104,12 @@ def _get_sample(self, index): # we assume images are loaded with channel last! prefix_box = tuple() if have_raw_channels: - if shape[0] < 16: + if shape[-1] < 16: + shape = shape[:-1] + else: shape = shape[1:] prefix_box = (slice(None), ) - else: - shape = shape[:-1] + # sample random bounding box for this image bb = self._sample_bounding_box(shape) From 364a9dc7afb0fe942714290479cc8cc74f522800 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 21 Sep 2022 09:54:51 +0000 Subject: [PATCH 07/11] Add tempfile import --- test/data/test_image_collection_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/data/test_image_collection_dataset.py b/test/data/test_image_collection_dataset.py index 89f63249..5c789853 100644 --- a/test/data/test_image_collection_dataset.py +++ b/test/data/test_image_collection_dataset.py @@ -1,4 +1,5 @@ import os +import tempfile import unittest from glob import glob from shutil import rmtree From 05f61833adaf2b145ef7933aff98aef3fa7c7c0c Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 21 Sep 2022 09:57:34 +0000 Subject: [PATCH 08/11] improve comments about channel axis --- torch_em/data/image_collection_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index e03403d8..fae3f228 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -101,7 +101,8 @@ def _get_sample(self, index): raise NotImplementedError("Multi-channel labels are not supported.") shape = raw.shape - # we assume images are loaded with channel last! + # we determine if image has channels as te first or last axis base on array shape. + # This will work only for images with less than 16 channels. prefix_box = tuple() if have_raw_channels: if shape[-1] < 16: From 927ba96652d3894082e754f2ac0dd03f8d6c9156 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 21 Sep 2022 11:57:57 +0200 Subject: [PATCH 09/11] Update torch_em/data/image_collection_dataset.py Co-authored-by: Constantin Pape --- torch_em/data/image_collection_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index fae3f228..cc10b5bc 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -105,6 +105,9 @@ def _get_sample(self, index): # This will work only for images with less than 16 channels. prefix_box = tuple() if have_raw_channels: + # use heuristic to decide whether the data is stored in channel last or channel first order: + # if the last axis has a length smaller than 16 we assume that it's the channel axis, + # otherwise we assume it's a spatial axis and that the first axis is the channel axis. if shape[-1] < 16: shape = shape[:-1] else: From 79f4ffed6deb070e87618276ed0b0498c41c1cc0 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 21 Sep 2022 10:45:33 +0000 Subject: [PATCH 10/11] fix variable names --- test/data/test_image_collection_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/data/test_image_collection_dataset.py b/test/data/test_image_collection_dataset.py index 5c789853..f3518c9e 100644 --- a/test/data/test_image_collection_dataset.py +++ b/test/data/test_image_collection_dataset.py @@ -59,8 +59,8 @@ def test_channel_end(self): patch_shape = (256, 256) with tempfile.TemporaryDirectory() as td: - im_folder = glob(os.path.join(td, "images", "*.tif")) - label_folder = glob(os.path.join(td, "labels", "*.tif")) + raw_paths = glob(os.path.join(td, "images", "*.tif")) + label_paths = glob(os.path.join(td, "labels", "*.tif")) generate_sample_data(td, 10, (64, 64, 2)) ds = ImageCollectionDataset(raw_paths, label_paths, @@ -74,8 +74,8 @@ def test_channel_begin(self): patch_shape = (256, 256) with tempfile.TemporaryDirectory() as td: - im_folder = glob(os.path.join(td, "images", "*.tif")) - label_folder = glob(os.path.join(td, "labels", "*.tif")) + raw_paths = glob(os.path.join(td, "images", "*.tif")) + label_paths = glob(os.path.join(td, "labels", "*.tif")) generate_sample_data(td, 10, (2, 64, 64)) ds = ImageCollectionDataset(raw_paths, label_paths, From d6cc323215898bc1bd31e8e1c7bb205944b1cfc3 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Wed, 21 Sep 2022 15:26:58 +0200 Subject: [PATCH 11/11] fix test and better debug information --- test/data/test_image_collection_dataset.py | 11 ++++++----- torch_em/data/image_collection_dataset.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/test/data/test_image_collection_dataset.py b/test/data/test_image_collection_dataset.py index f3518c9e..21d8c8fd 100644 --- a/test/data/test_image_collection_dataset.py +++ b/test/data/test_image_collection_dataset.py @@ -40,14 +40,14 @@ def test_dataset(self): self.assertEqual(y.shape, expected_shape) -def generate_sample_data(folder, n_images, shape): +def generate_sample_data(folder, n_images, image_shape, label_shape): im_folder = os.path.join(folder, "images") label_folder = os.path.join(folder, "labels") os.makedirs(im_folder) os.makedirs(label_folder) for i in range(n_images): - raw = np.empty(shape, dtype=np.uint8) - label = np.ones(shape, dtype=np.float32) + raw = np.empty(image_shape, dtype=np.uint8) + label = np.ones(label_shape, dtype=np.float32) tifffile.imwrite(os.path.join(im_folder, f"test_{i}.tif"), raw) tifffile.imwrite(os.path.join(label_folder, f"test_{i}.tif"), label) @@ -59,10 +59,11 @@ def test_channel_end(self): patch_shape = (256, 256) with tempfile.TemporaryDirectory() as td: + generate_sample_data(td, 10, (256, 256, 2), (256, 256)) raw_paths = glob(os.path.join(td, "images", "*.tif")) label_paths = glob(os.path.join(td, "labels", "*.tif")) - generate_sample_data(td, 10, (64, 64, 2)) + ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape) self.assertEqual(len(ds), 10) @@ -74,10 +75,10 @@ def test_channel_begin(self): patch_shape = (256, 256) with tempfile.TemporaryDirectory() as td: + generate_sample_data(td, 10, (2, 256, 256), (256, 256)) raw_paths = glob(os.path.join(td, "images", "*.tif")) label_paths = glob(os.path.join(td, "labels", "*.tif")) - generate_sample_data(td, 10, (2, 64, 64)) ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape) self.assertEqual(len(ds), 10) diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index cc10b5bc..fa93710c 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -28,7 +28,13 @@ def _check_inputs(self, raw_images, label_images): # we assume axis last if is_multichan: - shape = shape[:-1] + # use heuristic to decide whether the data is stored in channel last or channel first order: + # if the last axis has a length smaller than 16 we assume that it's the channel axis, + # otherwise we assume it's a spatial axis and that the first axis is the channel axis. + if shape[-1] < 16: + shape = shape[:-1] + else: + shape = shape[1:] label_shape = load_image(label_im).shape if shape != label_shape: @@ -80,7 +86,7 @@ def ndim(self): def _sample_bounding_box(self, shape): if any(sh < psh for sh, psh in zip(shape, self.patch_shape)): - raise NotImplementedError("Image padding is not supported yet.") + raise NotImplementedError("Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}") bb_start = [ np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, self.patch_shape)