From 8f6b47cec075c1ac9a8d50f4b95472f776acff7e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 20 May 2023 22:32:25 +0200 Subject: [PATCH] Implement tif support in data loaders --- test/data/test_segmentation_dataset.py | 31 ++++++++++++++++++++++++-- torch_em/data/raw_dataset.py | 7 +++--- torch_em/data/segmentation_dataset.py | 11 +++++---- torch_em/segmentation.py | 19 ++++++++-------- torch_em/util/__init__.py | 2 +- torch_em/util/image.py | 20 ++++++++++++----- 6 files changed, 61 insertions(+), 29 deletions(-) diff --git a/test/data/test_segmentation_dataset.py b/test/data/test_segmentation_dataset.py index 6e7e71d4..5f972fb2 100644 --- a/test/data/test_segmentation_dataset.py +++ b/test/data/test_segmentation_dataset.py @@ -1,5 +1,6 @@ import os import unittest +from shutil import rmtree import h5py import numpy as np @@ -7,10 +8,14 @@ class TestSegmentationDataset(unittest.TestCase): - path = "./data.h5" + tmp_folder = "./tmp" + path = "./tmp/data.h5" + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) def tearDown(self): - os.remove(self.path) + rmtree(self.tmp_folder) def create_default_data(self, raw_key, label_key): shape = (128,) * 3 @@ -178,6 +183,28 @@ def test_with_raw_and_label_channels(self): self.assertEqual(x.shape, expected_raw_shape) self.assertEqual(y.shape, expected_label_shape) + def test_tif(self): + import imageio.v3 as imageio + from torch_em.data import SegmentationDataset + + raw_path = os.path.join(self.tmp_folder, "raw.tif") + label_path = os.path.join(self.tmp_folder, "labels.tif") + shape = (128, 128, 128) + imageio.imwrite(raw_path, np.random.rand(*shape).astype("float32")) + imageio.imwrite(label_path, np.random.rand(*shape).astype("float32")) + + patch_shape = (32, 32, 32) + raw_key, label_key = None, None + ds = SegmentationDataset( + raw_path, raw_key, label_path, label_key, patch_shape=patch_shape + ) + + expected_patch_shape = (1,) + patch_shape + for i in range(10): + x, y = ds[i] + self.assertEqual(x.shape, expected_patch_shape) + self.assertEqual(y.shape, expected_patch_shape) + if __name__ == "__main__": unittest.main() diff --git a/torch_em/data/raw_dataset.py b/torch_em/data/raw_dataset.py index d54fd083..0f128d00 100644 --- a/torch_em/data/raw_dataset.py +++ b/torch_em/data/raw_dataset.py @@ -2,10 +2,9 @@ import torch import numpy as np -from elf.io import open_file from elf.wrapper import RoiWrapper -from ..util import ensure_tensor_with_channels +from ..util import ensure_tensor_with_channels, load_data class RawDataset(torch.utils.data.Dataset): @@ -35,7 +34,7 @@ def __init__( ): self.raw_path = raw_path self.raw_key = raw_key - self.raw = open_file(raw_path, mode="r")[raw_key] + self.raw = load_data(raw_path, raw_key) self._with_channels = with_channels @@ -153,7 +152,7 @@ def __setstate__(self, state): raw_path, raw_key = state["raw_path"], state["raw_key"] roi = state["roi"] try: - raw = open_file(state["raw_path"], mode="r")[state["raw_key"]] + raw = load_data(raw_path, raw_key) if roi is not None: raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi) state["raw"] = raw diff --git a/torch_em/data/segmentation_dataset.py b/torch_em/data/segmentation_dataset.py index 0d318098..08c15512 100644 --- a/torch_em/data/segmentation_dataset.py +++ b/torch_em/data/segmentation_dataset.py @@ -2,10 +2,9 @@ import torch import numpy as np -from elf.io import open_file from elf.wrapper import RoiWrapper -from ..util import ensure_spatial_array, ensure_tensor_with_channels +from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data class SegmentationDataset(torch.utils.data.Dataset): @@ -40,11 +39,11 @@ def __init__( ): self.raw_path = raw_path self.raw_key = raw_key - self.raw = open_file(raw_path, mode="r")[raw_key] + self.raw = load_data(raw_path, raw_key) self.label_path = label_path self.label_key = label_key - self.labels = open_file(label_path, mode="r")[label_key] + self.labels = load_data(label_path, label_key) self._with_channels = with_channels self._with_label_channels = with_label_channels @@ -175,7 +174,7 @@ def __setstate__(self, state): label_path, label_key = state["label_path"], state["label_key"] roi = state["roi"] try: - raw = open_file(raw_path, mode="r")[raw_key] + raw = load_data(raw_path, raw_key) if roi is not None: raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi) state["raw"] = raw @@ -187,7 +186,7 @@ def __setstate__(self, state): state["raw"] = None try: - labels = open_file(label_path, mode="r")[label_key] + labels = load_data(label_path, label_key) if roi is not None: labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\ RoiWrapper(labels, roi) diff --git a/torch_em/segmentation.py b/torch_em/segmentation.py index 3105253b..b7b562b1 100644 --- a/torch_em/segmentation.py +++ b/torch_em/segmentation.py @@ -4,13 +4,13 @@ import torch import torch.utils.data -from elf.io import open_file from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset from .loss import DiceLoss from .trainer import DefaultTrainer from .trainer.tensorboard_logger import TensorboardLogger from .transform import get_augmentations, get_raw_transform +from .util import load_data # TODO add a heuristic to estimate this from the number of epochs @@ -61,7 +61,7 @@ def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key): def _can_open(path, key): try: - open_file(path, mode="r")[key] + load_data(path, key) return True except Exception: return False @@ -165,14 +165,13 @@ def _get_paths(rpath, rkey, lpath, lkey, this_roi): def _get_default_transform(path, key, is_seg_dataset, ndim): if is_seg_dataset and ndim is None: - with open_file(path, mode="r") as f: - shape = f[key].shape - if len(shape) == 2: - ndim = 2 - else: - # heuristics to figure out whether to use default 3d - # or default anisotropic augmentations - ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3 + shape = load_data(path, key).shape + if len(shape) == 2: + ndim = 2 + else: + # heuristics to figure out whether to use default 3d + # or default anisotropic augmentations + ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3 elif is_seg_dataset and ndim is not None: pass else: diff --git a/torch_em/util/__init__.py b/torch_em/util/__init__.py index f8bbd74d..dc92a521 100644 --- a/torch_em/util/__init__.py +++ b/torch_em/util/__init__.py @@ -1,4 +1,4 @@ -from .image import load_image, supports_memmap +from .image import load_data, load_image, supports_memmap from .reporting import get_training_summary from .training import parser_helper from .util import (auto_compile, ensure_array, ensure_spatial_array, diff --git a/torch_em/util/image.py b/torch_em/util/image.py index 790621a8..0bf148e4 100644 --- a/torch_em/util/image.py +++ b/torch_em/util/image.py @@ -1,8 +1,10 @@ # TODO this should be partially refactored into elf.io before the next elf release # and then be used in image_stack_wrapper as welll import os + +from elf.io import open_file try: - import imageio.v2 as imageio + import imageio.v3 as imageio except ImportError: import imageio @@ -11,7 +13,7 @@ except ImportError: tifffile = None -TIF_EXTS = ('.tif', '.tiff') +TIF_EXTS = (".tif", ".tiff") def supports_memmap(image_path): @@ -21,7 +23,7 @@ def supports_memmap(image_path): if ext.lower() not in TIF_EXTS: return False try: - tifffile.memmap(image_path, mode='r') + tifffile.memmap(image_path, mode="r") except ValueError: return False return True @@ -29,9 +31,15 @@ def supports_memmap(image_path): def load_image(image_path): if supports_memmap(image_path): - return tifffile.memmap(image_path, mode='r') - elif tifffile is not None and os.path.splitext(image_path)[1].lower() in {".tiff", ".tif"}: + return tifffile.memmap(image_path, mode="r") + elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"): return tifffile.imread(image_path) else: - # TODO handle multi-channel images return imageio.imread(image_path) + + +def load_data(path, key, mode="r"): + if key is None: + return load_image(path) + else: + return open_file(path, mode=mode)[key]