diff --git a/docs/source/transforms/transforms.rst b/docs/source/transforms/transforms.rst index 6924d8c56..9147a01cb 100644 --- a/docs/source/transforms/transforms.rst +++ b/docs/source/transforms/transforms.rst @@ -125,10 +125,10 @@ or `aleatoric uncertainty estimation >> segmentations = [] >>> num_segmentations = 10 >>> for _ in range(num_segmentations): - ... transform = tio.RandomAffine() + ... transform = tio.RandomAffine(image_interpolation='bspline') ... transformed = transform(subject) ... segmentation = model(transformed) - ... transformed_native_space = segmentation.apply_inverse_transform() + ... transformed_native_space = segmentation.apply_inverse_transform(image_interpolation='linear') ... segmentations.append(transformed_native_space) ... diff --git a/tests/transforms/test_invertibility.py b/tests/transforms/test_invertibility.py index 0aca6ee63..0b16f38ee 100644 --- a/tests/transforms/test_invertibility.py +++ b/tests/transforms/test_invertibility.py @@ -1,5 +1,8 @@ +import copy import warnings +import torch +import torchio as tio from torchio.transforms.intensity_transform import IntensityTransform from ..utils import TorchioTestCase @@ -37,3 +40,41 @@ def test_ignore_intensity(self): inverse_transform = transformed.get_inverse_transform(warn=False) for transform in inverse_transform: assert not isinstance(transform, IntensityTransform) + + def test_different_interpolation(self): + def model_probs(subject): + subject = copy.deepcopy(subject) + subject.im.set_data(torch.rand_like(subject.im.data)) + return subject + + def model_label(subject): + subject = model_probs(subject) + subject.im.set_data(torch.bernoulli(subject.im.data)) + return subject + + transform = tio.RandomAffine(image_interpolation='bspline') + subject = copy.deepcopy(self.sample_subject) + tensor = (torch.rand(1, 20, 20, 20) > 0.5).float() # 0s and 1s + subject = tio.Subject(im=tio.ScalarImage(tensor=tensor)) + transformed = transform(subject) + assert transformed.im.data.min() < 0 + assert transformed.im.data.max() > 1 + + subject_probs = model_probs(transformed) + transformed_back = subject_probs.apply_inverse_transform() + assert transformed_back.im.data.min() < 0 + assert transformed_back.im.data.max() > 1 + transformed_back_linear = subject_probs.apply_inverse_transform( + image_interpolation='linear', + ) + assert transformed_back_linear.im.data.min() >= 0 + assert transformed_back_linear.im.data.max() <= 1 + + subject_label = model_label(transformed) + transformed_back = subject_label.apply_inverse_transform() + assert transformed_back.im.data.min() < 0 + assert transformed_back.im.data.max() > 1 + transformed_back_linear = subject_label.apply_inverse_transform( + image_interpolation='nearest', + ) + assert transformed_back_linear.im.data.unique().tolist() == [0, 1] diff --git a/torchio/data/subject.py b/torchio/data/subject.py index 74048dbef..1a9235d75 100644 --- a/torchio/data/subject.py +++ b/torchio/data/subject.py @@ -120,6 +120,7 @@ def history(self): def get_applied_transforms( self, ignore_intensity: bool = False, + image_interpolation: Optional[str] = None, ) -> List['Transform']: from ..transforms.transform import Transform from ..transforms.intensity_transform import IntensityTransform @@ -132,22 +133,30 @@ def get_applied_transforms( transform = name_to_transform[transform_name](**arguments) if ignore_intensity and isinstance(transform, IntensityTransform): continue + resamples = hasattr(transform, 'image_interpolation') + if resamples and image_interpolation is not None: + parsed = transform.parse_interpolation(image_interpolation) + transform.image_interpolation = parsed transforms_list.append(transform) return transforms_list def get_composed_history( self, ignore_intensity: bool = False, + image_interpolation: Optional[str] = None, ) -> 'Compose': from ..transforms.augmentation.composition import Compose transforms = self.get_applied_transforms( - ignore_intensity=ignore_intensity) + ignore_intensity=ignore_intensity, + image_interpolation=image_interpolation, + ) return Compose(transforms) def get_inverse_transform( self, warn: bool = True, ignore_intensity: bool = True, + image_interpolation: Optional[str] = None, ) -> 'Compose': """Get a reversed list of the inverses of the applied transforms. @@ -156,9 +165,13 @@ def get_inverse_transform( ignore_intensity: If ``True``, all instances of :class:`~torchio.transforms.intensity_transform.IntensityTransform` will be ignored. + image_interpolation: Modify interpolation for scalar images inside + transforms that perform resampling. """ history_transform = self.get_composed_history( - ignore_intensity=ignore_intensity) + ignore_intensity=ignore_intensity, + image_interpolation=image_interpolation, + ) inverse_transform = history_transform.inverse(warn=warn) return inverse_transform