Skip to content

Commit

Permalink
Add test for different interpolation when inverting
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jan 18, 2021
1 parent bf665d2 commit 13b6ad8
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/transforms/test_invertibility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy
import warnings

import torch
import torchio as tio
from torchio.transforms.intensity_transform import IntensityTransform
from ..utils import TorchioTestCase

Expand Down Expand Up @@ -37,3 +40,41 @@ def test_ignore_intensity(self):
inverse_transform = transformed.get_inverse_transform(warn=False)
for transform in inverse_transform:
assert not isinstance(transform, IntensityTransform)

def test_different_interpolation(self):
def model_probs(subject):
subject = copy.deepcopy(subject)
subject.im.set_data(torch.rand_like(subject.im.data))
return subject

def model_label(subject):
subject = model_probs(subject)
subject.im.set_data(torch.bernoulli(subject.im.data))
return subject

transform = tio.RandomAffine(image_interpolation='bspline')
subject = copy.deepcopy(self.sample_subject)
tensor = (torch.rand(1, 20, 20, 20) > 0.5).float() # 0s and 1s
subject = tio.Subject(im=tio.ScalarImage(tensor=tensor))
transformed = transform(subject)
assert transformed.im.data.min() < 0
assert transformed.im.data.max() > 1

subject_probs = model_probs(transformed)
transformed_back = subject_probs.apply_inverse_transform()
assert transformed_back.im.data.min() < 0
assert transformed_back.im.data.max() > 1
transformed_back_linear = subject_probs.apply_inverse_transform(
image_interpolation='linear',
)
assert transformed_back_linear.im.data.min() >= 0
assert transformed_back_linear.im.data.max() <= 1

subject_label = model_label(transformed)
transformed_back = subject_label.apply_inverse_transform()
assert transformed_back.im.data.min() < 0
assert transformed_back.im.data.max() > 1
transformed_back_linear = subject_label.apply_inverse_transform(
image_interpolation='nearest',
)
assert transformed_back_linear.im.data.unique().tolist() == [0, 1]

0 comments on commit 13b6ad8

Please sign in to comment.