Skip to content

Commit

Permalink
Update raw transforms and implement tests
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed May 21, 2023
1 parent 4359656 commit d47e7e0
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 13 deletions.
19 changes: 11 additions & 8 deletions test/transform/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import numpy
import unittest

import numpy as np
import torch

from torch_em.transform import Tile


from unittest import TestCase


class TestTile(TestCase):
class TestTile(unittest.TestCase):
def test_tile(self):
for ndim, reps in [(1, (4, 2)), (2, (4, 2)), (3, (4, 2))]:
with self.subTest():
Expand All @@ -17,11 +16,11 @@ def test_tile(self):
def _test_tile_impl(ndim, reps):
tile_aug = Tile(reps, match_shape_exactly=len(reps) == ndim)
test_shape = [2, 3, 4][:ndim]
data = numpy.random.random(test_shape)
data = np.random.random(test_shape)

x = torch.tensor(data)

expected = numpy.tile(x.numpy(), reps)
expected = np.tile(x.numpy(), reps)
if len(reps) == ndim:
expected_torch = x.repeat(*reps)
assert expected.shape == expected_torch.shape
Expand All @@ -30,7 +29,11 @@ def _test_tile_impl(ndim, reps):

assert actual.shape == expected.shape

a = numpy.array(data)
a = np.array(data)

actual = tile_aug(a)
assert actual.shape == expected.shape


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/transform/test_label_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,5 @@ def test_distance_transform_empty_labels(self):
self.assertTrue(np.allclose(tnew, 1.0))


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
103 changes: 103 additions & 0 deletions test/transform/test_raw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import unittest
from copy import deepcopy

import numpy as np
import torch


class TestRaw(unittest.TestCase):
def _test_standardize(self, input_):
from torch_em.transform.raw import standardize

def check_out(out):
self.assertEqual(out.shape, input_.shape)
if torch.is_tensor(out):
mean, std = out.mean().numpy(), out.std().numpy()
else:
mean, std = out.mean(), out.std()
self.assertLess(mean, 0.001)
self.assertLess(np.abs(1.0 - std), 0.001)

# test standardize without arguments
out = standardize(deepcopy(input_))
check_out(out)

# test standardize with axis
out = standardize(deepcopy(input_), axis=(1, 2))
check_out(out)

# test standardize with fixed mean and std
mean, std = input_.mean(), input_.std()
out = standardize(deepcopy(input_), mean=mean, std=std)
check_out(out)

def test_standardize_torch(self):
input_ = torch.rand(3, 128, 128)
self._test_standardize(input_)

def test_standardize_numpy(self):
input_ = np.random.rand(3, 128, 128)
self._test_standardize(input_)

def _test_normalize(self, input_):
from torch_em.transform.raw import normalize

def check_out(out):
self.assertEqual(out.shape, input_.shape)
if torch.is_tensor(out):
min_, max_ = out.min().numpy(), out.max().numpy()
else:
min_, max_ = out.min(), out.max()
self.assertLess(min_, 0.001)
self.assertLess(np.abs(1.0 - max_), 0.001)

# test normalize without arguments
out = normalize(deepcopy(input_))
check_out(out)

# test normalize with axis
out = normalize(deepcopy(input_), axis=(1, 2))
check_out(out)

# test normalize with fixed min, max
min_, max_ = input_.min(), input_.max() - input_.min()
out = normalize(deepcopy(input_), minval=min_, maxval=max_)
check_out(out)

def test_normalize_torch(self):
input_ = torch.randn(3, 128, 128)
self._test_normalize(input_)

def test_normalize_numpy(self):
input_ = np.random.randn(3, 128, 128)
self._test_normalize(input_)

def _test_normalize_percentile(self, input_):
from torch_em.transform.raw import normalize_percentile

def check_out(out):
self.assertEqual(out.shape, input_.shape)

# test normalize without arguments
out = normalize_percentile(deepcopy(input_))
check_out(out)

# test normalize with axis
out = normalize_percentile(deepcopy(input_), axis=(1, 2))
check_out(out)

# test normalize with percentile arguments
out = normalize_percentile(deepcopy(input_), lower=5.0, upper=95.0)
check_out(out)

def test_normalize_percentile_torch(self):
input_ = torch.randn(3, 128, 128)
self._test_normalize_percentile(input_)

def test_normalize_percentile_numpy(self):
input_ = np.random.randn(3, 128, 128)
self._test_normalize_percentile(input_)


if __name__ == "__main__":
unittest.main()
6 changes: 2 additions & 4 deletions torch_em/transform/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ def cast(inpt, typestring):
def standardize(raw, mean=None, std=None, axis=None, eps=1e-7):
raw = cast(raw, "float32")

# mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
mean = raw.mean() if mean is None else mean
mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
raw -= mean

# std = raw.std(axis=axis, keepdims=True) if std is None else std
std = raw.std() if std is None else std
std = raw.std(axis=axis, keepdims=True) if std is None else std
raw /= (std + eps)

return raw
Expand Down

0 comments on commit d47e7e0

Please sign in to comment.