Skip to content

Commit

Permalink
Merge 6827698 into ba5be71
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jan 19, 2021
2 parents ba5be71 + 6827698 commit c1160dd
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
4 changes: 2 additions & 2 deletions docs/source/transforms/transforms.rst
Expand Up @@ -125,10 +125,10 @@ or `aleatoric uncertainty estimation <https://www.sciencedirect.com/science/arti
>>> segmentations = []
>>> num_segmentations = 10
>>> for _ in range(num_segmentations):
... transform = tio.RandomAffine()
... transform = tio.RandomAffine(image_interpolation='bspline')
... transformed = transform(subject)
... segmentation = model(transformed)
... transformed_native_space = segmentation.apply_inverse_transform()
... transformed_native_space = segmentation.apply_inverse_transform(image_interpolation='linear')
... segmentations.append(transformed_native_space)
...

Expand Down
41 changes: 41 additions & 0 deletions tests/transforms/test_invertibility.py
@@ -1,5 +1,8 @@
import copy
import warnings

import torch
import torchio as tio
from torchio.transforms.intensity_transform import IntensityTransform
from ..utils import TorchioTestCase

Expand Down Expand Up @@ -37,3 +40,41 @@ def test_ignore_intensity(self):
inverse_transform = transformed.get_inverse_transform(warn=False)
for transform in inverse_transform:
assert not isinstance(transform, IntensityTransform)

def test_different_interpolation(self):
def model_probs(subject):
subject = copy.deepcopy(subject)
subject.im.set_data(torch.rand_like(subject.im.data))
return subject

def model_label(subject):
subject = model_probs(subject)
subject.im.set_data(torch.bernoulli(subject.im.data))
return subject

transform = tio.RandomAffine(image_interpolation='bspline')
subject = copy.deepcopy(self.sample_subject)
tensor = (torch.rand(1, 20, 20, 20) > 0.5).float() # 0s and 1s
subject = tio.Subject(im=tio.ScalarImage(tensor=tensor))
transformed = transform(subject)
assert transformed.im.data.min() < 0
assert transformed.im.data.max() > 1

subject_probs = model_probs(transformed)
transformed_back = subject_probs.apply_inverse_transform()
assert transformed_back.im.data.min() < 0
assert transformed_back.im.data.max() > 1
transformed_back_linear = subject_probs.apply_inverse_transform(
image_interpolation='linear',
)
assert transformed_back_linear.im.data.min() >= 0
assert transformed_back_linear.im.data.max() <= 1

subject_label = model_label(transformed)
transformed_back = subject_label.apply_inverse_transform()
assert transformed_back.im.data.min() < 0
assert transformed_back.im.data.max() > 1
transformed_back_linear = subject_label.apply_inverse_transform(
image_interpolation='nearest',
)
assert transformed_back_linear.im.data.unique().tolist() == [0, 1]
17 changes: 15 additions & 2 deletions torchio/data/subject.py
Expand Up @@ -120,6 +120,7 @@ def history(self):
def get_applied_transforms(
self,
ignore_intensity: bool = False,
image_interpolation: Optional[str] = None,
) -> List['Transform']:
from ..transforms.transform import Transform
from ..transforms.intensity_transform import IntensityTransform
Expand All @@ -132,22 +133,30 @@ def get_applied_transforms(
transform = name_to_transform[transform_name](**arguments)
if ignore_intensity and isinstance(transform, IntensityTransform):
continue
resamples = hasattr(transform, 'image_interpolation')
if resamples and image_interpolation is not None:
parsed = transform.parse_interpolation(image_interpolation)
transform.image_interpolation = parsed
transforms_list.append(transform)
return transforms_list

def get_composed_history(
self,
ignore_intensity: bool = False,
image_interpolation: Optional[str] = None,
) -> 'Compose':
from ..transforms.augmentation.composition import Compose
transforms = self.get_applied_transforms(
ignore_intensity=ignore_intensity)
ignore_intensity=ignore_intensity,
image_interpolation=image_interpolation,
)
return Compose(transforms)

def get_inverse_transform(
self,
warn: bool = True,
ignore_intensity: bool = True,
image_interpolation: Optional[str] = None,
) -> 'Compose':
"""Get a reversed list of the inverses of the applied transforms.
Expand All @@ -156,9 +165,13 @@ def get_inverse_transform(
ignore_intensity: If ``True``, all instances of
:class:`~torchio.transforms.intensity_transform.IntensityTransform`
will be ignored.
image_interpolation: Modify interpolation for scalar images inside
transforms that perform resampling.
"""
history_transform = self.get_composed_history(
ignore_intensity=ignore_intensity)
ignore_intensity=ignore_intensity,
image_interpolation=image_interpolation,
)
inverse_transform = history_transform.inverse(warn=warn)
return inverse_transform

Expand Down

0 comments on commit c1160dd

Please sign in to comment.