diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 5e9068bd4..bb776a228 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -1,5 +1,4 @@ -#!/usr/bin/env python - +import copy import torch import numpy as np import torchio @@ -102,3 +101,16 @@ def test_transform_noop(self): tensor = torch.rand(2, 4, 5, 8).numpy() transformed = transform(tensor) self.assertIs(transformed, tensor) + + def test_original_unchanged(self): + sample = copy.deepcopy(self.sample) + composed = self.get_transform(channels=('t1', 't2'), is_3d=True) + sample = self.flip_affine_x(sample) + for transform in composed.transform.transforms: + original_data = copy.deepcopy(sample.t1.data) + transformed = transform(sample) + self.assertTensorEqual( + sample.t1.data, + original_data, + f'Changes after {transform.name}' + ) diff --git a/tests/utils.py b/tests/utils.py index decf2be43..38570f2f9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -145,3 +145,9 @@ def assertTensorNotEqual(self, a, b, message=None): assert not torch.all(torch.eq(a, b)) else: assert not torch.all(torch.eq(a, b)), message + + def assertTensorEqual(self, a, b, message=None): + if message is None: + assert torch.all(torch.eq(a, b)) + else: + assert torch.all(torch.eq(a, b)), message