Skip to content

Commit

Permalink
Fix default pad value for label maps
Browse files Browse the repository at this point in the history
Fixes #626.
  • Loading branch information
fepegar committed Aug 20, 2021
1 parent f20eb27 commit e90e7c5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 56 deletions.
77 changes: 43 additions & 34 deletions tests/transforms/augmentation/test_random_affine.py
@@ -1,4 +1,5 @@
from torchio.transforms import RandomAffine
import torch
import torchio as tio
from ...utils import TorchioTestCase


Expand All @@ -12,7 +13,7 @@ def setUp(self):

def test_rotation_image(self):
# Rotation around image center
transform = RandomAffine(
transform = tio.RandomAffine(
degrees=(90, 90),
default_pad_value=0,
center='image',
Expand All @@ -23,7 +24,7 @@ def test_rotation_image(self):

def test_rotation_origin(self):
# Rotation around far away point, image should be empty
transform = RandomAffine(
transform = tio.RandomAffine(
degrees=(90, 90),
default_pad_value=0,
center='origin',
Expand All @@ -33,7 +34,7 @@ def test_rotation_origin(self):
self.assertEqual(total, 0)

def test_no_rotation(self):
transform = RandomAffine(
transform = tio.RandomAffine(
scales=(1, 1),
degrees=(0, 0),
default_pad_value=0,
Expand All @@ -45,7 +46,7 @@ def test_no_rotation(self):
transformed.t1.data,
)

transform = RandomAffine(
transform = tio.RandomAffine(
scales=(1, 1),
degrees=(180, 180),
default_pad_value=0,
Expand All @@ -59,20 +60,20 @@ def test_no_rotation(self):
)

def test_isotropic(self):
RandomAffine(isotropic=True)(self.sample_subject)
tio.RandomAffine(isotropic=True)(self.sample_subject)

def test_mean(self):
RandomAffine(default_pad_value='mean')(self.sample_subject)
tio.RandomAffine(default_pad_value='mean')(self.sample_subject)

def test_otsu(self):
RandomAffine(default_pad_value='otsu')(self.sample_subject)
tio.RandomAffine(default_pad_value='otsu')(self.sample_subject)

def test_bad_center(self):
with self.assertRaises(ValueError):
RandomAffine(center='bad')
tio.RandomAffine(center='bad')

def test_translation(self):
transform = RandomAffine(
transform = tio.RandomAffine(
scales=(1, 1),
degrees=0,
translation=(5, 5)
Expand All @@ -93,72 +94,80 @@ def test_translation(self):

def test_negative_scales(self):
with self.assertRaises(ValueError):
RandomAffine(scales=(-1, 1))
tio.RandomAffine(scales=(-1, 1))

def test_scale_too_large(self):
with self.assertRaises(ValueError):
RandomAffine(scales=1.5)
tio.RandomAffine(scales=1.5)

def test_scales_range_with_negative_min(self):
with self.assertRaises(ValueError):
RandomAffine(scales=(-1, 4))
tio.RandomAffine(scales=(-1, 4))

def test_wrong_scales_type(self):
with self.assertRaises(ValueError):
RandomAffine(scales='wrong')
tio.RandomAffine(scales='wrong')

def test_wrong_degrees_type(self):
with self.assertRaises(ValueError):
RandomAffine(degrees='wrong')
tio.RandomAffine(degrees='wrong')

def test_too_many_translation_values(self):
with self.assertRaises(ValueError):
RandomAffine(translation=(-10, 4, 42))
tio.RandomAffine(translation=(-10, 4, 42))

def test_wrong_translation_type(self):
with self.assertRaises(ValueError):
RandomAffine(translation='wrong')
tio.RandomAffine(translation='wrong')

def test_wrong_center(self):
with self.assertRaises(ValueError):
RandomAffine(center=0)
tio.RandomAffine(center=0)

def test_wrong_default_pad_value(self):
with self.assertRaises(ValueError):
RandomAffine(default_pad_value='wrong')
tio.RandomAffine(default_pad_value='wrong')

def test_wrong_image_interpolation_type(self):
with self.assertRaises(TypeError):
RandomAffine(image_interpolation=0)
tio.RandomAffine(image_interpolation=0)

def test_wrong_image_interpolation_value(self):
with self.assertRaises(ValueError):
RandomAffine(image_interpolation='wrong')
tio.RandomAffine(image_interpolation='wrong')

def test_incompatible_args_isotropic(self):
with self.assertRaises(ValueError):
RandomAffine(scales=(0.8, 0.5, 0.1), isotropic=True)
tio.RandomAffine(scales=(0.8, 0.5, 0.1), isotropic=True)

def test_parse_scales(self):
def do_assert(transform):
self.assertEqual(transform.scales, 3 * (0.9, 1.1))
do_assert(RandomAffine(scales=0.1))
do_assert(RandomAffine(scales=(0.9, 1.1)))
do_assert(RandomAffine(scales=3 * (0.1,)))
do_assert(RandomAffine(scales=3 * [0.9, 1.1]))
do_assert(tio.RandomAffine(scales=0.1))
do_assert(tio.RandomAffine(scales=(0.9, 1.1)))
do_assert(tio.RandomAffine(scales=3 * (0.1,)))
do_assert(tio.RandomAffine(scales=3 * [0.9, 1.1]))

def test_parse_degrees(self):
def do_assert(transform):
self.assertEqual(transform.degrees, 3 * (-10, 10))
do_assert(RandomAffine(degrees=10))
do_assert(RandomAffine(degrees=(-10, 10)))
do_assert(RandomAffine(degrees=3 * (10,)))
do_assert(RandomAffine(degrees=3 * [-10, 10]))
do_assert(tio.RandomAffine(degrees=10))
do_assert(tio.RandomAffine(degrees=(-10, 10)))
do_assert(tio.RandomAffine(degrees=3 * (10,)))
do_assert(tio.RandomAffine(degrees=3 * [-10, 10]))

def test_parse_translation(self):
def do_assert(transform):
self.assertEqual(transform.translation, 3 * (-10, 10))
do_assert(RandomAffine(translation=10))
do_assert(RandomAffine(translation=(-10, 10)))
do_assert(RandomAffine(translation=3 * (10,)))
do_assert(RandomAffine(translation=3 * [-10, 10]))
do_assert(tio.RandomAffine(translation=10))
do_assert(tio.RandomAffine(translation=(-10, 10)))
do_assert(tio.RandomAffine(translation=3 * (10,)))
do_assert(tio.RandomAffine(translation=3 * [-10, 10]))

def test_default_value_label_map(self):
# From https://github.com/fepegar/torchio/issues/626
a = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).reshape(1, 3, 3, 1)
image = tio.LabelMap(tensor=a)
aff = tio.RandomAffine(translation=(0, 1, 1), default_pad_value='otsu')
transformed = aff(image)
assert all(n in (0, 1) for n in transformed.data.flatten())
47 changes: 25 additions & 22 deletions torchio/transforms/augmentation/spatial/random_affine.py
Expand Up @@ -259,11 +259,6 @@ def apply_transform(self, subject: Subject) -> Subject:
translation_params = np.array(self.translation).copy()
subject.check_consistent_spatial_shape()
for image in self.get_images(subject):
if image[TYPE] != INTENSITY:
interpolation = 'nearest'
else:
interpolation = self.image_interpolation

if image.is_2d():
scaling_params[2] = 1
rotation_params[:-1] = 0
Expand All @@ -275,13 +270,33 @@ def apply_transform(self, subject: Subject) -> Subject:

transformed_tensors = []
for tensor in image.data:
transformed_tensor = self.apply_affine_transform(
tensor,
sitk_image = nib_to_sitk(
tensor[np.newaxis],
image.affine,
force_3d=True,
)
if image[TYPE] != INTENSITY:
interpolation = 'nearest'
default_value = 0
else:
interpolation = self.image_interpolation
if self.default_pad_value == 'minimum':
default_value = tensor.min().item()
elif self.default_pad_value == 'mean':
default_value = get_borders_mean(
sitk_image, filter_otsu=False)
elif self.default_pad_value == 'otsu':
default_value = get_borders_mean(
sitk_image, filter_otsu=True)
else:
default_value = self.default_pad_value
transformed_tensor = self.apply_affine_transform(
sitk_image,
scaling_params.tolist(),
rotation_params.tolist(),
translation_params.tolist(),
interpolation,
default_value,
center_lps=center,
)
transformed_tensors.append(transformed_tensor)
Expand All @@ -290,18 +305,15 @@ def apply_transform(self, subject: Subject) -> Subject:

def apply_affine_transform(
self,
tensor: torch.Tensor,
affine: np.ndarray,
sitk_image: sitk.Image,
scaling_params: Sequence[float],
rotation_params: Sequence[float],
translation_params: Sequence[float],
interpolation: str,
default_value: float,
center_lps: Optional[TypeTripletFloat] = None,
) -> torch.Tensor:
assert tensor.ndim == 3

image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
floating = reference = image
floating = reference = sitk_image

scaling_transform = self.get_scaling_transform(
scaling_params,
Expand All @@ -325,15 +337,6 @@ def apply_affine_transform(
if self.invert_transform:
transform = transform.GetInverse()

if self.default_pad_value == 'minimum':
default_value = tensor.min().item()
elif self.default_pad_value == 'mean':
default_value = get_borders_mean(image, filter_otsu=False)
elif self.default_pad_value == 'otsu':
default_value = get_borders_mean(image, filter_otsu=True)
else:
default_value = self.default_pad_value

resampler = sitk.ResampleImageFilter()
resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
resampler.SetReferenceImage(reference)
Expand Down

0 comments on commit e90e7c5

Please sign in to comment.