From 13b6ad8e1692a5673c0e7112ffe6ab6772df66d4 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 18 Jan 2021 00:05:55 +0000 Subject: [PATCH] Add test for different interpolation when inverting --- tests/transforms/test_invertibility.py | 41 ++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/transforms/test_invertibility.py b/tests/transforms/test_invertibility.py index 0aca6ee63..0b16f38ee 100644 --- a/tests/transforms/test_invertibility.py +++ b/tests/transforms/test_invertibility.py @@ -1,5 +1,8 @@ +import copy import warnings +import torch +import torchio as tio from torchio.transforms.intensity_transform import IntensityTransform from ..utils import TorchioTestCase @@ -37,3 +40,41 @@ def test_ignore_intensity(self): inverse_transform = transformed.get_inverse_transform(warn=False) for transform in inverse_transform: assert not isinstance(transform, IntensityTransform) + + def test_different_interpolation(self): + def model_probs(subject): + subject = copy.deepcopy(subject) + subject.im.set_data(torch.rand_like(subject.im.data)) + return subject + + def model_label(subject): + subject = model_probs(subject) + subject.im.set_data(torch.bernoulli(subject.im.data)) + return subject + + transform = tio.RandomAffine(image_interpolation='bspline') + subject = copy.deepcopy(self.sample_subject) + tensor = (torch.rand(1, 20, 20, 20) > 0.5).float() # 0s and 1s + subject = tio.Subject(im=tio.ScalarImage(tensor=tensor)) + transformed = transform(subject) + assert transformed.im.data.min() < 0 + assert transformed.im.data.max() > 1 + + subject_probs = model_probs(transformed) + transformed_back = subject_probs.apply_inverse_transform() + assert transformed_back.im.data.min() < 0 + assert transformed_back.im.data.max() > 1 + transformed_back_linear = subject_probs.apply_inverse_transform( + image_interpolation='linear', + ) + assert transformed_back_linear.im.data.min() >= 0 + assert transformed_back_linear.im.data.max() <= 1 + + subject_label = model_label(transformed) + transformed_back = subject_label.apply_inverse_transform() + assert transformed_back.im.data.min() < 0 + assert transformed_back.im.data.max() > 1 + transformed_back_linear = subject_label.apply_inverse_transform( + image_interpolation='nearest', + ) + assert transformed_back_linear.im.data.unique().tolist() == [0, 1]