Skip to content

Commit

Permalink
Improve random transforms reproducibility
Browse files Browse the repository at this point in the history
Add seed to dict only if transformed Subject

Add method to generate seed

Add reproducibility tests

Update tests

Fix docstring

Update CLI tool

Add get_transform function

Use JSON to store transforms history

Update transforms

Rename function to get transform class
  • Loading branch information
fepegar committed Jul 21, 2020
1 parent ac86f80 commit 169c235
Show file tree
Hide file tree
Showing 23 changed files with 202 additions and 145 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Expand Up @@ -9,6 +9,7 @@
data/dataset.rst
data/patch_based.rst
transforms/transforms.rst
reproducibility.rst
datasets.rst
cli.rst
slicer.rst
Expand Down
3 changes: 3 additions & 0 deletions docs/source/reproducibility.rst
@@ -0,0 +1,3 @@
###############
Reproducibility
###############
Expand Up @@ -11,11 +11,10 @@ def test_random_elastic_deformation(self):
transform = RandomElasticDeformation(
num_control_points=5,
max_displacement=(2, 3, 5), # half grid spacing is (3.3, 3.3, 5)
seed=42,
)
keys = ('t1', 't2', 'label')
fixtures = 2916.7192, 2955.1265, 2950
transformed = transform(self.sample)
transformed = transform(self.sample, seed=42)
for key, fixture in zip(keys, fixtures):
sample_data = self.sample[key].numpy()
transformed_data = transformed[key].numpy()
Expand Down
68 changes: 34 additions & 34 deletions tests/transforms/preprocessing/test_crop_pad.py
Expand Up @@ -24,42 +24,43 @@ def test_no_changes_mask(self):
transform = CropOrPad(shape, mask_name='label')
with self.assertWarns(UserWarning):
transformed = transform(self.sample)
for key in transformed:
image_dict = self.sample[key]
assert_array_equal(image_dict[DATA], transformed[key][DATA])
assert_array_equal(image_dict[AFFINE], transformed[key][AFFINE])
iterable = transformed.get_images_dict(intensity_only=False).items()
for image_name, image in iterable:
image = self.sample[image_name]
assert_array_equal(image[DATA], transformed[image_name][DATA])
assert_array_equal(image[AFFINE], transformed[image_name][AFFINE])

def test_different_shape(self):
shape = self.sample['t1'].spatial_shape
target_shape = 9, 21, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertNotEqual(shape, result_shape)

def test_shape_right(self):
target_shape = 9, 21, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_only_pad(self):
target_shape = 11, 22, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_only_crop(self):
target_shape = 9, 18, 30
transform = CropOrPad(target_shape)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_shape_negative(self):
Expand All @@ -77,8 +78,8 @@ def test_shape_string(self):
def test_shape_one(self):
transform = CropOrPad(1)
transformed = transform(self.sample)
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual((1, 1, 1), result_shape)

def test_wrong_mask_name(self):
Expand Down Expand Up @@ -106,17 +107,15 @@ def test_mask_only_pad(self):
mask [0, 4:6, 5:8, 3:7] = 1
transformed = transform(self.sample)
shapes = []
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
shapes.append(result_shape)
set_shapes = set(shapes)
message = f'Images have different shapes: {set_shapes}'
assert len(set_shapes) == 1, message
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape,
f'Wrong shape for image: {key}',
)
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_mask_only_crop(self):
target_shape = 9, 18, 30
Expand All @@ -126,17 +125,15 @@ def test_mask_only_crop(self):
mask [0, 4:6, 5:8, 3:7] = 1
transformed = transform(self.sample)
shapes = []
for key in transformed:
result_shape = transformed[key].spatial_shape
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
shapes.append(result_shape)
set_shapes = set(shapes)
message = f'Images have different shapes: {set_shapes}'
assert len(set_shapes) == 1, message
for key in transformed:
result_shape = transformed[key].spatial_shape
self.assertEqual(target_shape, result_shape,
f'Wrong shape for image: {key}',
)
for image in transformed.get_images(intensity_only=False):
result_shape = image.spatial_shape
self.assertEqual(target_shape, result_shape)

def test_center_mask(self):
"""The mask bounding box and the input image have the same center"""
Expand All @@ -148,8 +145,9 @@ def test_center_mask(self):
mask[0, 4:6, 9:11, 14:16] = 1
transformed_center = transform_center(self.sample)
transformed_mask = transform_mask(self.sample)
zipped = zip(transformed_center.values(), transformed_mask.values())
for image_center, image_mask in zipped:
tc_images = transformed_center.get_images(intensity_only=False)
tm_images = transformed_mask.get_images(intensity_only=False)
for image_center, image_mask in zip(tc_images, tm_images):
assert_array_equal(
image_center[DATA], image_mask[DATA],
'Data is different after cropping',
Expand All @@ -171,8 +169,9 @@ def test_mask_corners(self):
mask[0, -1, -1, -1] = 1
transformed_center = transform_center(self.sample)
transformed_mask = transform_mask(self.sample)
zipped = zip(transformed_center.values(), transformed_mask.values())
for image_center, image_mask in zipped:
tc_images = transformed_center.get_images(intensity_only=False)
tm_images = transformed_mask.get_images(intensity_only=False)
for image_center, image_mask in zip(tc_images, tm_images):
assert_array_equal(
image_center[DATA], image_mask[DATA],
'Data is different after cropping',
Expand All @@ -193,8 +192,9 @@ def test_mask_origin(self):
mask[0, 0, 0, 0] = 1
transformed_center = transform_center(self.sample)
transformed_mask = transform_mask(self.sample)
zipped = zip(transformed_center.values(), transformed_mask.values())
for image_center, image_mask in zipped:
tc_images = transformed_center.get_images(intensity_only=False)
tm_images = transformed_mask.get_images(intensity_only=False)
for image_center, image_mask in zip(tc_images, tm_images):
# Arrays are different
assert not np.array_equal(image_center[DATA], image_mask[DATA])
# Rotation matrix doesn't change
Expand Down
23 changes: 11 additions & 12 deletions tests/transforms/preprocessing/test_resample.py
Expand Up @@ -13,33 +13,32 @@ def test_spacing(self):
spacing = 2
transform = Resample(spacing)
transformed = transform(self.sample)
for image_dict in transformed.values():
image = image_dict.as_sitk()
for image in transformed.get_images(intensity_only=False):
image = image.as_sitk()
self.assertEqual(image.GetSpacing(), 3 * (spacing,))

def test_reference_name(self):
sample = self.get_inconsistent_sample()
reference_name = 't1'
transform = Resample(reference_name)
transformed = transform(sample)
ref_image_dict = sample[reference_name]
for image_dict in transformed.values():
self.assertEqual(
ref_image_dict.shape, image_dict.shape)
assert_array_equal(ref_image_dict[AFFINE], image_dict[AFFINE])
ref_image = sample[reference_name]
for image in transformed.get_images(intensity_only=False):
self.assertEqual(ref_image.shape, image.shape)
assert_array_equal(ref_image[AFFINE], image[AFFINE])

def test_affine(self):
spacing = 1
affine_name = 'pre_affine'
transform = Resample(spacing, pre_affine_name=affine_name)
transformed = transform(self.sample)
for image_dict in transformed.values():
if affine_name in image_dict.keys():
for image in transformed.get_images(intensity_only=False):
if affine_name in image.keys():
new_affine = np.eye(4)
new_affine[0, 3] = 10
assert_array_equal(image_dict[AFFINE], new_affine)
assert_array_equal(image[AFFINE], new_affine)
else:
assert_array_equal(image_dict[AFFINE], np.eye(4))
assert_array_equal(image[AFFINE], np.eye(4))

def test_missing_affine(self):
transform = Resample(1, pre_affine_name='missing')
Expand All @@ -50,7 +49,7 @@ def test_reference_path(self):
reference_image, reference_path = self.get_reference_image_and_path()
transform = Resample(reference_path)
transformed = transform(self.sample)
for image in transformed.values():
for image in transformed.get_images(intensity_only=False):
self.assertEqual(reference_image.shape, image.shape)
assert_array_equal(reference_image.affine, image.affine)

Expand Down
89 changes: 89 additions & 0 deletions tests/transforms/test_reproducibility.py
@@ -0,0 +1,89 @@
import warnings
import torch
import torchio
from torchio import Subject, Image, INTENSITY
from torchio.transforms import RandomNoise
from ..utils import TorchioTestCase


class TestReproducibility(TorchioTestCase):

def setUp(self):
super().setUp()
self.subject = Subject(img=Image(tensor=torch.ones(4, 4, 4)))

def random_stuff(self, seed=None):
transform = RandomNoise(std=(100, 100))#, seed=seed)
transformed = transform(self.subject, seed=seed)
value = transformed.img.data.sum().item()
_, seed = transformed.get_applied_transforms()[0]
return value, seed

def test_reproducibility_no_seed(self):
a, seed_a = self.random_stuff()
b, seed_b = self.random_stuff()
self.assertNotEqual(a, b)
c, seed_c = self.random_stuff(seed_a)
self.assertEqual(c, a)
self.assertEqual(seed_c, seed_a)

def test_reproducibility_seed(self):
torch.manual_seed(42)
a, seed_a = self.random_stuff()
b, seed_b = self.random_stuff()
self.assertNotEqual(a, b)
c, seed_c = self.random_stuff(seed_a)
self.assertEqual(c, a)
self.assertEqual(seed_c, seed_a)

torch.manual_seed(42)
a2, seed_a2 = self.random_stuff()
self.assertEqual(a2, a)
self.assertEqual(seed_a2, seed_a)
b2, seed_b2 = self.random_stuff()
self.assertNotEqual(a2, b2)
self.assertEqual(b2, b)
self.assertEqual(seed_b2, seed_b)
c2, seed_c2 = self.random_stuff(seed_a2)
self.assertEqual(c2, a2)
self.assertEqual(seed_c2, seed_a2)
self.assertEqual(c2, c)
self.assertEqual(seed_c2, seed_c)

# def test_all_random_transforms(self):
# sample = Subject(
# t1=Image(tensor=torch.rand(20, 20, 20)),
# seg=Image(tensor=torch.rand(20, 20, 20) > 1, type=INTENSITY)
# )

# transforms_names = [
# name
# for name in dir(torchio)
# if name.startswith('Random')
# ]

# # Downsample at the end so that the image shape is not modified
# transforms_names.remove('RandomDownsample')
# transforms_names.append('RandomDownsample')

# transforms = []
# for transform_name in transforms_names:
# transform = getattr(torchio, transform_name)()
# transforms.append(transform)
# composed_transform = torchio.Compose(transforms)
# with warnings.catch_warnings(): # ignore elastic deformation warning
# warnings.simplefilter('ignore', UserWarning)
# transformed = composed_transform(sample)

# new_transforms = []
# for transform_name, params_dict in transformed.history:
# transform_class = getattr(torchio, transform_name)
# transform = transform_class(seed=params_dict['seed'])
# new_transforms.append(transform)
# composed_transform = torchio.Compose(transforms)
# with warnings.catch_warnings(): # ignore elastic deformation warning
# warnings.simplefilter('ignore', UserWarning)
# new_transformed = composed_transform(sample)

# self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
# self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
14 changes: 9 additions & 5 deletions tests/utils.py
Expand Up @@ -4,6 +4,7 @@
import tempfile
import unittest
from pathlib import Path
import torch
import numpy as np
import nibabel as nib
from torchio.datasets import IXITiny
Expand Down Expand Up @@ -54,6 +55,14 @@ def setUp(self):
self.dataset = ImagesDataset(self.subjects_list)
self.sample = self.dataset[-1]

def tearDown(self):
"""Tear down test fixtures, if any."""
print('Deleting', self.dir)
shutil.rmtree(self.dir)

def assertTensorEqual(self, a, b):
assert torch.all(torch.eq(a, b))

def make_2d(self, sample):
sample = copy.deepcopy(sample)
for image in sample.get_images(intensity_only=False):
Expand Down Expand Up @@ -85,11 +94,6 @@ def get_reference_image_and_path(self):
image = Image(path, INTENSITY)
return image, path

def tearDown(self):
"""Tear down test fixtures, if any."""
print('Deleting', self.dir)
shutil.rmtree(self.dir)

def get_ixi_tiny(self):
root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
return IXITiny(root_dir, download=True)
Expand Down
3 changes: 1 addition & 2 deletions torchio/cli.py
Expand Up @@ -51,14 +51,13 @@ def apply_transform(
raise ValueError(message) from error

params_dict = get_params_dict_from_kwargs(kwargs)
if issubclass(transform_class, RandomTransform):
params_dict['seed'] = seed
transform = transform_class(**params_dict)
apply_transform_to_file(
input_path,
transform,
output_path,
verbose=verbose,
seed=seed,
)
return 0

Expand Down
2 changes: 1 addition & 1 deletion torchio/data/queue.py
Expand Up @@ -54,7 +54,7 @@ class Queue(Dataset):
>>> patch_size = 96
>>> queue_length = 300
>>> samples_per_volume = 10
>>> sample = torchio.data.UniformSampler(patch_size)
>>> sampler = torchio.data.UniformSampler(patch_size)
>>> patches_queue = torchio.Queue(
... subjects_dataset, # instance of torchio.ImagesDataset
... queue_length,
Expand Down

0 comments on commit 169c235

Please sign in to comment.