Skip to content

Commit

Permalink
Merge pull request #130 from constantinpape/tif-support
Browse files Browse the repository at this point in the history
Implement tif support in data loaders
  • Loading branch information
constantinpape committed May 20, 2023
2 parents 89ce03d + 8f6b47c commit d0847a8
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 29 deletions.
31 changes: 29 additions & 2 deletions test/data/test_segmentation_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import os
import unittest
from shutil import rmtree

import h5py
import numpy as np
from torch_em.util.test import create_segmentation_test_data


class TestSegmentationDataset(unittest.TestCase):
path = "./data.h5"
tmp_folder = "./tmp"
path = "./tmp/data.h5"

def setUp(self):
os.makedirs(self.tmp_folder, exist_ok=True)

def tearDown(self):
os.remove(self.path)
rmtree(self.tmp_folder)

def create_default_data(self, raw_key, label_key):
shape = (128,) * 3
Expand Down Expand Up @@ -178,6 +183,28 @@ def test_with_raw_and_label_channels(self):
self.assertEqual(x.shape, expected_raw_shape)
self.assertEqual(y.shape, expected_label_shape)

def test_tif(self):
    """Check that SegmentationDataset returns correctly shaped patches for tif inputs."""
    import imageio.v3 as imageio
    from torch_em.data import SegmentationDataset

    volume_shape = (128, 128, 128)
    patch_shape = (32, 32, 32)

    # Write one random float32 volume each for raw data and labels into the tmp folder.
    raw_path = os.path.join(self.tmp_folder, "raw.tif")
    imageio.imwrite(raw_path, np.random.rand(*volume_shape).astype("float32"))
    label_path = os.path.join(self.tmp_folder, "labels.tif")
    imageio.imwrite(label_path, np.random.rand(*volume_shape).astype("float32"))

    # Both keys are None, so the dataset is given the tif paths only.
    ds = SegmentationDataset(raw_path, None, label_path, None, patch_shape=patch_shape)

    # Every sample is expected to have a leading axis of size one before the patch axes.
    expected_patch_shape = (1, *patch_shape)
    for idx in range(10):
        x, y = ds[idx]
        self.assertEqual(x.shape, expected_patch_shape)
        self.assertEqual(y.shape, expected_patch_shape)


if __name__ == "__main__":
unittest.main()
7 changes: 3 additions & 4 deletions torch_em/data/raw_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import torch
import numpy as np
from elf.io import open_file
from elf.wrapper import RoiWrapper

from ..util import ensure_tensor_with_channels
from ..util import ensure_tensor_with_channels, load_data


class RawDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -35,7 +34,7 @@ def __init__(
):
self.raw_path = raw_path
self.raw_key = raw_key
self.raw = open_file(raw_path, mode="r")[raw_key]
self.raw = load_data(raw_path, raw_key)

self._with_channels = with_channels

Expand Down Expand Up @@ -153,7 +152,7 @@ def __setstate__(self, state):
raw_path, raw_key = state["raw_path"], state["raw_key"]
roi = state["roi"]
try:
raw = open_file(state["raw_path"], mode="r")[state["raw_key"]]
raw = load_data(raw_path, raw_key)
if roi is not None:
raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
state["raw"] = raw
Expand Down
11 changes: 5 additions & 6 deletions torch_em/data/segmentation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import torch
import numpy as np
from elf.io import open_file
from elf.wrapper import RoiWrapper

from ..util import ensure_spatial_array, ensure_tensor_with_channels
from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data


class SegmentationDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -40,11 +39,11 @@ def __init__(
):
self.raw_path = raw_path
self.raw_key = raw_key
self.raw = open_file(raw_path, mode="r")[raw_key]
self.raw = load_data(raw_path, raw_key)

self.label_path = label_path
self.label_key = label_key
self.labels = open_file(label_path, mode="r")[label_key]
self.labels = load_data(label_path, label_key)

self._with_channels = with_channels
self._with_label_channels = with_label_channels
Expand Down Expand Up @@ -175,7 +174,7 @@ def __setstate__(self, state):
label_path, label_key = state["label_path"], state["label_key"]
roi = state["roi"]
try:
raw = open_file(raw_path, mode="r")[raw_key]
raw = load_data(raw_path, raw_key)
if roi is not None:
raw = RoiWrapper(raw, (slice(None),) + roi) if state["_with_channels"] else RoiWrapper(raw, roi)
state["raw"] = raw
Expand All @@ -187,7 +186,7 @@ def __setstate__(self, state):
state["raw"] = None

try:
labels = open_file(label_path, mode="r")[label_key]
labels = load_data(label_path, label_key)
if roi is not None:
labels = RoiWrapper(labels, (slice(None),) + roi) if state["_with_label_channels"] else\
RoiWrapper(labels, roi)
Expand Down
19 changes: 9 additions & 10 deletions torch_em/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import torch
import torch.utils.data
from elf.io import open_file

from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .loss import DiceLoss
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .util import load_data


# TODO add a heuristic to estimate this from the number of epochs
Expand Down Expand Up @@ -61,7 +61,7 @@ def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):

def _can_open(path, key):
# Report whether the data at ``path`` / ``key`` can be opened: returns True when
# the lookup succeeds and False when it raises any Exception.
# NOTE(review): indentation is flattened in this view, and the two statements
# inside ``try`` look like the removed and added versions of one line — confirm
# against the actual file.
try:
open_file(path, mode="r")[key]
load_data(path, key)
return True
except Exception:
return False
Expand Down Expand Up @@ -165,14 +165,13 @@ def _get_paths(rpath, rkey, lpath, lkey, this_roi):

def _get_default_transform(path, key, is_seg_dataset, ndim):
if is_seg_dataset and ndim is None:
with open_file(path, mode="r") as f:
shape = f[key].shape
if len(shape) == 2:
ndim = 2
else:
# heuristics to figure out whether to use default 3d
# or default anisotropic augmentations
ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3
shape = load_data(path, key).shape
if len(shape) == 2:
ndim = 2
else:
# heuristics to figure out whether to use default 3d
# or default anisotropic augmentations
ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3
elif is_seg_dataset and ndim is not None:
pass
else:
Expand Down
2 changes: 1 addition & 1 deletion torch_em/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .image import load_image, supports_memmap
from .image import load_data, load_image, supports_memmap
from .reporting import get_training_summary
from .training import parser_helper
from .util import (auto_compile, ensure_array, ensure_spatial_array,
Expand Down
20 changes: 14 additions & 6 deletions torch_em/util/image.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# TODO this should be partially refactored into elf.io before the next elf release
# and then be used in image_stack_wrapper as well
import os

from elf.io import open_file
try:
import imageio.v2 as imageio
import imageio.v3 as imageio
except ImportError:
import imageio

Expand All @@ -11,7 +13,7 @@
except ImportError:
tifffile = None

TIF_EXTS = ('.tif', '.tiff')
TIF_EXTS = (".tif", ".tiff")


def supports_memmap(image_path):
Expand All @@ -21,17 +23,23 @@ def supports_memmap(image_path):
if ext.lower() not in TIF_EXTS:
return False
try:
tifffile.memmap(image_path, mode='r')
tifffile.memmap(image_path, mode="r")
except ValueError:
return False
return True


def load_image(image_path):
# Load the image at ``image_path`` with the first reader that applies:
# a tifffile memory map, ``tifffile.imread``, or ``imageio.imread``.
# NOTE(review): indentation is flattened in this view, and the ``return``/``elif``
# pairs below look like the before/after versions of a quote-style change —
# confirm against the actual file.
if supports_memmap(image_path):
return tifffile.memmap(image_path, mode='r')
elif tifffile is not None and os.path.splitext(image_path)[1].lower() in {".tiff", ".tif"}:
return tifffile.memmap(image_path, mode="r")
elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"):
# tif extension, tifffile available, but memmap not supported: read via tifffile
return tifffile.imread(image_path)
else:
# TODO handle multi-channel images
return imageio.imread(image_path)


def load_data(path, key, mode="r"):
    """Return the data stored at ``path``.

    If ``key`` is None, ``path`` is read as a plain image via ``load_image``
    and ``mode`` is not used. Otherwise the file is opened with ``open_file``
    in the given ``mode`` and the entry stored under ``key`` is returned.
    """
    if key is not None:
        return open_file(path, mode=mode)[key]
    return load_image(path)

0 comments on commit d0847a8

Please sign in to comment.