Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Add doc strings and tests for volume loader
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 26, 2018
1 parent 0cbaf1d commit 175af35
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
52 changes: 52 additions & 0 deletions inferno/io/volumetric/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@


class VolumeLoader(SyncableDataset):
""" Loader for in-memory volumetric data.
Parameters
----------
volume: np.ndarray
the volumetric data
window_size: list or tuple
size of the (3d) sliding window used for iteration
stride: list or tuple
stride of the (3d) sliding window used for iteration
downsampling_ratio: list or tuple (default: None)
        factor by which the data is downsampled (no downsampling by default)
padding: list (default: None)
padding for data, follows np.pad syntax
padding_mode: str (default: 'reflect')
padding mode as in np.pad
transforms: callable (default: None)
transforms applied on each batch loaded from volume
return_index_spec: bool (default: False)
whether to return the index spec for each batch
name: str (default: None)
name of this volume
is_multichannel: bool (default: False)
is this a multichannel volume? sliding window is NOT applied to channel dimension
"""

def __init__(self, volume, window_size, stride, downsampling_ratio=None, padding=None,
padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
is_multichannel=False):
Expand Down Expand Up @@ -125,6 +151,32 @@ def __repr__(self):


class HDF5VolumeLoader(VolumeLoader):
""" Loader for volumes stored in hdf5, zarr or n5.
    Zarr and n5 are file formats very similar to hdf5, but they store
    data directly on the regular filesystem, whereas hdf5 stores a whole
    filesystem inside a single file.
    The file type will be inferred from the extension:
.hdf5, .h5 and .hdf map to hdf5
.n5 maps to n5
.zr and .zarr map to zarr
It will fail for other extensions.
Parameters
----------
path: str
path to file
path_in_h5_dataset: str (default: None)
path in file
data_slice: slice (default: None)
slice loaded from dataset
transforms: callable (default: None)
transforms applied on each batch loaded from volume
name: str (default: None)
name of this volume
slicing_config: kwargs
keyword arguments for base class `VolumeLoader`
"""

@staticmethod
def is_h5(file_path):
Expand Down
57 changes: 57 additions & 0 deletions tests/test_io/test_volumetric/test_volume_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest
import os
from shutil import rmtree

import numpy as np
import h5py


class TestVolumeLoader(unittest.TestCase):
    """Tests for iterating an in-memory volume with VolumeLoader."""
    # Edge lengths of the cubic random test volume.
    shape = (100, 100, 100)

    def setUp(self):
        # Fresh random volume for each test.
        self.data = np.random.rand(*self.shape)

    def test_loader(self):
        """Each batch must equal the data slice addressed by its index spec."""
        from inferno.io.volumetric import VolumeLoader
        window = (10, 10, 10)
        loader = VolumeLoader(self.data,
                              window_size=window,
                              stride=window, return_index_spec=True)
        for batch, index in loader:
            # Look up the slicing that produced this batch and compare
            # against slicing the raw data directly.
            window_slice = loader.base_sequence[int(index)]
            reference = self.data[window_slice]
            self.assertEqual(batch.shape, reference.shape)
            self.assertTrue(np.allclose(batch, reference))


class TestHDF5VolumeLoader(unittest.TestCase):
    """Tests for HDF5VolumeLoader reading from a temporary hdf5 file."""
    # Edge lengths of the cubic random test volume.
    shape = (100, 100, 100)

    def setUp(self):
        # The directory may be left over from a previous aborted run,
        # so an already-exists OSError is ignored deliberately.
        try:
            os.mkdir('./tmp')
        except OSError:
            pass
        self.data = np.random.rand(*self.shape)
        # Open with an explicit write mode: omitting the mode was
        # deprecated in h5py and defaults to read-only in h5py >= 3,
        # which would make this dataset creation fail.
        with h5py.File('./tmp/data.h5', 'w') as f:
            f.create_dataset('data', data=self.data)

    def tearDown(self):
        # Best-effort cleanup of the temporary directory.
        try:
            rmtree('./tmp')
        except OSError:
            pass

    def test_hdf5_loader(self):
        """Each batch must equal the data slice addressed by its index spec."""
        from inferno.io.volumetric import HDF5VolumeLoader
        loader = HDF5VolumeLoader('./tmp/data.h5', 'data',
                                  window_size=(10, 10, 10),
                                  stride=(10, 10, 10), return_index_spec=True)
        for batch, idx in loader:
            slice_ = loader.base_sequence[int(idx)]
            expected = self.data[slice_]
            self.assertEqual(batch.shape, expected.shape)
            self.assertTrue(np.allclose(batch, expected))



# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 175af35

Please sign in to comment.