Skip to content

Commit

Permalink
Merge 169c235 into ac86f80
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jul 21, 2020
2 parents ac86f80 + 169c235 commit 04cfdd3
Show file tree
Hide file tree
Showing 23 changed files with 202 additions and 145 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Expand Up @@ -9,6 +9,7 @@
data/dataset.rst
data/patch_based.rst
transforms/transforms.rst
reproducibility.rst
datasets.rst
cli.rst
slicer.rst
Expand Down
3 changes: 3 additions & 0 deletions docs/source/reproducibility.rst
@@ -0,0 +1,3 @@
###############
Reproducibility
###############
Expand Up @@ -11,11 +11,10 @@ def test_random_elastic_deformation(self):
transform = RandomElasticDeformation(
num_control_points=5,
max_displacement=(2, 3, 5), # half grid spacing is (3.3, 3.3, 5)
seed=42,
)
keys = ('t1', 't2', 'label')
fixtures = 2916.7192, 2955.1265, 2950
transformed = transform(self.sample)
transformed = transform(self.sample, seed=42)
for key, fixture in zip(keys, fixtures):
sample_data = self.sample[key].numpy()
transformed_data = transformed[key].numpy()
Expand Down
68 changes: 34 additions & 34 deletions tests/transforms/preprocessing/test_crop_pad.py
Expand Up @@ -24,42 +24,43 @@ def test_no_changes_mask(self):
transform = CropOrPad(shape, mask_name='label')
with self.assertWarns(UserWarning):
transformed = transform(self.sample)
for key in transformed:
image_dict = self.sample[key]
assert_array_equal(image_dict[DATA], transformed[key][DATA])
assert_array_equal(image_dict[AFFINE], transformed[key][AFFINE])
iterable = transformed.get_images_dict(intensity_only=False).items()
for image_name, image in iterable:
image = self.sample[image_name]
assert_array_equal(image[DATA], transformed[image_name][DATA])
assert_array_equal(image[AFFINE], transformed[image_name][AFFINE])

def test_different_shape(self):
shape = self.sample['t1'].spatial_shape
target_shape = 9, 21, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertNotEqual(shape, result_shape)

def test_shape_right(self):
target_shape = 9, 21, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_only_pad(self):
target_shape = 11, 22, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_only_crop(self):
target_shape = 9, 18, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_shape_negative(self):
Expand All @@ -77,8 +78,8 @@ def test_shape_string(self):
def test_shape_one(self):
transform = CropOrPad(1)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual((1, 1, 1), result_shape)

def test_wrong_mask_name(self):
Expand Down Expand Up @@ -106,17 +107,15 @@ def test_mask_only_pad(self):
mask [0, 4:6, 5:8, 3:7] = 1
transformed = transform(self.sample)
shapes = []
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
shapes.append(result_shape)
set_shapes = set(shapes)
message = f'Images have different shapes: {set_shapes}'
assert len(set_shapes) == 1, message
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape,
f'Wrong shape for image: {key}',
)
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_mask_only_crop(self):
target_shape = 9, 18, 30
Expand All @@ -126,17 +125,15 @@ def test_mask_only_crop(self):
mask [0, 4:6, 5:8, 3:7] = 1
transformed = transform(self.sample)
shapes = []
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
shapes.append(result_shape)
set_shapes = set(shapes)
message = f'Images have different shapes: {set_shapes}'
assert len(set_shapes) == 1, message
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape,
f'Wrong shape for image: {key}',
)
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_center_mask(self):
"""The mask bounding box and the input image have the same center"""
Expand All @@ -148,8 +145,9 @@ def test_center_mask(self):
mask[0, 4:6, 9:11, 14:16] = 1
transformed_center = transform_center(self.sample)
transformed_mask = transform_mask(self.sample)
zipped = zip(transformed_center.values(), transformed_mask.values())
for image_center, image_mask in zipped:
tc_images = transformed_center.get_images(intensity_only=False)
tm_images = transformed_mask.get_images(intensity_only=False)
for image_center, image_mask in zip(tc_images, tm_images):
assert_array_equal(
image_center[DATA], image_mask[DATA],
'Data is different after cropping',
Expand All @@ -171,8 +169,9 @@ def test_mask_corners(self):
mask[0, -1, -1, -1] = 1
transformed_center = transform_center(self.sample)
transformed_mask = transform_mask(self.sample)
zipped = zip(transformed_center.values(), transformed_mask.values())
for image_center, image_mask in zipped:
tc_images = transformed_center.get_images(intensity_only=False)
tm_images = transformed_mask.get_images(intensity_only=False)
for image_center, image_mask in zip(tc_images, tm_images):
assert_array_equal(
image_center[DATA], image_mask[DATA],
'Data is different after cropping',
Expand All @@ -193,8 +192,9 @@ def test_mask_origin(self):
mask[0, 0, 0, 0] = 1
transformed_center = transform_center(self.sample)
transformed_mask = transform_mask(self.sample)
zipped = zip(transformed_center.values(), transformed_mask.values())
for image_center, image_mask in zipped:
tc_images = transformed_center.get_images(intensity_only=False)
tm_images = transformed_mask.get_images(intensity_only=False)
for image_center, image_mask in zip(tc_images, tm_images):
# Arrays are different
assert not np.array_equal(image_center[DATA], image_mask[DATA])
# Rotation matrix doesn't change
Expand Down
23 changes: 11 additions & 12 deletions tests/transforms/preprocessing/test_resample.py
Expand Up @@ -13,33 +13,32 @@ def test_spacing(self):
spacing = 2
transform = Resample(spacing)
transformed = transform(self.sample)
for image_dict in transformed.values():
image = image_dict.as_sitk()
for image in transformed.get_images(intensity_only=False):
image = image.as_sitk()
self.assertEqual(image.GetSpacing(), 3 * (spacing,))

def test_reference_name(self):
sample = self.get_inconsistent_sample()
reference_name = 't1'
transform = Resample(reference_name)
transformed = transform(sample)
ref_image_dict = sample[reference_name]
for image_dict in transformed.values():
self.assertEqual(
ref_image_dict.shape, image_dict.shape)
assert_array_equal(ref_image_dict[AFFINE], image_dict[AFFINE])
ref_image = sample[reference_name]
for image in transformed.get_images(intensity_only=False):
self.assertEqual(ref_image.shape, image.shape)
assert_array_equal(ref_image[AFFINE], image[AFFINE])

def test_affine(self):
spacing = 1
affine_name = 'pre_affine'
transform = Resample(spacing, pre_affine_name=affine_name)
transformed = transform(self.sample)
for image_dict in transformed.values():
if affine_name in image_dict.keys():
for image in transformed.get_images(intensity_only=False):
if affine_name in image.keys():
new_affine = np.eye(4)
new_affine[0, 3] = 10
assert_array_equal(image_dict[AFFINE], new_affine)
assert_array_equal(image[AFFINE], new_affine)
else:
assert_array_equal(image_dict[AFFINE], np.eye(4))
assert_array_equal(image[AFFINE], np.eye(4))

def test_missing_affine(self):
transform = Resample(1, pre_affine_name='missing')
Expand All @@ -50,7 +49,7 @@ def test_reference_path(self):
reference_image, reference_path = self.get_reference_image_and_path()
transform = Resample(reference_path)
transformed = transform(self.sample)
for image in transformed.values():
for image in transformed.get_images(intensity_only=False):
self.assertEqual(reference_image.shape, image.shape)
assert_array_equal(reference_image.affine, image.affine)

Expand Down
89 changes: 89 additions & 0 deletions tests/transforms/test_reproducibility.py
@@ -0,0 +1,89 @@
import warnings
import torch
import torchio
from torchio import Subject, Image, INTENSITY
from torchio.transforms import RandomNoise
from ..utils import TorchioTestCase


class TestReproducibility(TorchioTestCase):

def setUp(self):
super().setUp()
self.subject = Subject(img=Image(tensor=torch.ones(4, 4, 4)))

def random_stuff(self, seed=None):
transform = RandomNoise(std=(100, 100))#, seed=seed)
transformed = transform(self.subject, seed=seed)
value = transformed.img.data.sum().item()
_, seed = transformed.get_applied_transforms()[0]
return value, seed

def test_reproducibility_no_seed(self):
a, seed_a = self.random_stuff()
b, seed_b = self.random_stuff()
self.assertNotEqual(a, b)
c, seed_c = self.random_stuff(seed_a)
self.assertEqual(c, a)
self.assertEqual(seed_c, seed_a)

def test_reproducibility_seed(self):
torch.manual_seed(42)
a, seed_a = self.random_stuff()
b, seed_b = self.random_stuff()
self.assertNotEqual(a, b)
c, seed_c = self.random_stuff(seed_a)
self.assertEqual(c, a)
self.assertEqual(seed_c, seed_a)

torch.manual_seed(42)
a2, seed_a2 = self.random_stuff()
self.assertEqual(a2, a)
self.assertEqual(seed_a2, seed_a)
b2, seed_b2 = self.random_stuff()
self.assertNotEqual(a2, b2)
self.assertEqual(b2, b)
self.assertEqual(seed_b2, seed_b)
c2, seed_c2 = self.random_stuff(seed_a2)
self.assertEqual(c2, a2)
self.assertEqual(seed_c2, seed_a2)
self.assertEqual(c2, c)
self.assertEqual(seed_c2, seed_c)

# def test_all_random_transforms(self):
# sample = Subject(
# t1=Image(tensor=torch.rand(20, 20, 20)),
# seg=Image(tensor=torch.rand(20, 20, 20) > 1, type=INTENSITY)
# )

# transforms_names = [
# name
# for name in dir(torchio)
# if name.startswith('Random')
# ]

# # Downsample at the end so that the image shape is not modified
# transforms_names.remove('RandomDownsample')
# transforms_names.append('RandomDownsample')

# transforms = []
# for transform_name in transforms_names:
# transform = getattr(torchio, transform_name)()
# transforms.append(transform)
# composed_transform = torchio.Compose(transforms)
# with warnings.catch_warnings(): # ignore elastic deformation warning
# warnings.simplefilter('ignore', UserWarning)
# transformed = composed_transform(sample)

# new_transforms = []
# for transform_name, params_dict in transformed.history:
# transform_class = getattr(torchio, transform_name)
# transform = transform_class(seed=params_dict['seed'])
# new_transforms.append(transform)
# composed_transform = torchio.Compose(transforms)
# with warnings.catch_warnings(): # ignore elastic deformation warning
# warnings.simplefilter('ignore', UserWarning)
# new_transformed = composed_transform(sample)

# self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
# self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
14 changes: 9 additions & 5 deletions tests/utils.py
Expand Up @@ -4,6 +4,7 @@
import tempfile
import unittest
from pathlib import Path
import torch
import numpy as np
import nibabel as nib
from torchio.datasets import IXITiny
Expand Down Expand Up @@ -54,6 +55,14 @@ def setUp(self):
self.dataset = ImagesDataset(self.subjects_list)
self.sample = self.dataset[-1]

def tearDown(self):
"""Tear down test fixtures, if any."""
print('Deleting', self.dir)
shutil.rmtree(self.dir)

def assertTensorEqual(self, a, b):
assert torch.all(torch.eq(a, b))

def make_2d(self, sample):
sample = copy.deepcopy(sample)
for image in sample.get_images(intensity_only=False):
Expand Down Expand Up @@ -85,11 +94,6 @@ def get_reference_image_and_path(self):
image = Image(path, INTENSITY)
return image, path

def tearDown(self):
"""Tear down test fixtures, if any."""
print('Deleting', self.dir)
shutil.rmtree(self.dir)

def get_ixi_tiny(self):
root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
return IXITiny(root_dir, download=True)
Expand Down
3 changes: 1 addition & 2 deletions torchio/cli.py
Expand Up @@ -51,14 +51,13 @@ def apply_transform(
raise ValueError(message) from error

params_dict = get_params_dict_from_kwargs(kwargs)
if issubclass(transform_class, RandomTransform):
params_dict['seed'] = seed
transform = transform_class(**params_dict)
apply_transform_to_file(
input_path,
transform,
output_path,
verbose=verbose,
seed=seed,
)
return 0

Expand Down
2 changes: 1 addition & 1 deletion torchio/data/queue.py
Expand Up @@ -54,7 +54,7 @@ class Queue(Dataset):
>>> patch_size = 96
>>> queue_length = 300
>>> samples_per_volume = 10
>>> sample = torchio.data.UniformSampler(patch_size)
>>> sampler = torchio.data.UniformSampler(patch_size)
>>> patches_queue = torchio.Queue(
... subjects_dataset, # instance of torchio.ImagesDataset
... queue_length,
Expand Down

0 comments on commit 04cfdd3

Please sign in to comment.