Skip to content

Commit

Permalink
Merge 4e8bc09 into a1abc16
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed May 23, 2020
2 parents a1abc16 + 4e8bc09 commit 93d80ab
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 20 deletions.
17 changes: 17 additions & 0 deletions tests/transforms/augmentation/test_random_flip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torchio
from ...utils import TorchioTestCase


class TestRandomFlip(TorchioTestCase):
"""Tests for `RandomFlip`."""
def test_2d(self):
sample = self.make_2d(self.sample)
transform = torchio.transforms.RandomFlip(
axes=(0, 1), flip_probability=1)
transform(sample)

def test_wrong_axes(self):
sample = self.make_2d(self.sample)
transform = torchio.transforms.RandomFlip(axes=2, flip_probability=1)
with self.assertRaises(RuntimeError):
transform(sample)
19 changes: 12 additions & 7 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
class TestTransforms(TorchioTestCase):
"""Tests for all transforms."""

def get_transform(self, channels):
def get_transform(self, channels, is_3d=True):
landmarks_dict = {
channel: np.linspace(0, 100, 13) for channel in channels
}
elastic = torchio.RandomElasticDeformation(max_displacement=1)
transforms = (
torchio.CropOrPad((9, 21, 30)),
torchio.CropOrPad((9, 21, 30)) if is_3d else torchio.CropOrPad((1, 21, 30)),
torchio.ToCanonical(),
torchio.Resample((1, 1.1, 1.25)),
torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1) if is_3d else torchio.RandomFlip(axes=(0, 1), flip_probability=1),
torchio.RandomMotion(),
torchio.RandomGhosting(axes=(0, 1, 2)),
torchio.RandomSpike(),
Expand All @@ -38,11 +38,16 @@ def get_transform(self, channels):
)
return torchio.Compose(transforms)

def test_transforms_sample(self):
transform = self.get_transform(channels=('t1', 't2'))
transform(self.sample)

def test_transforms_tensor(self):
tensor = torch.rand(2, 4, 5, 8)
transform = self.get_transform(channels=('channel_0', 'channel_1'))
transform(tensor)

def test_transforms_sample_3d(self):
transform = self.get_transform(channels=('t1', 't2'))
transform(self.sample)

def test_transforms_sample_2d(self):
transform = self.get_transform(channels=('t1', 't2'))
sample = self.make_2d(self.sample)
transform(sample)
9 changes: 8 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import shutil
import random
import tempfile
Expand All @@ -6,7 +7,7 @@
import numpy as np
import nibabel as nib
from torchio.datasets import IXITiny
from torchio import INTENSITY, LABEL, Image, ImagesDataset, Subject
from torchio import INTENSITY, LABEL, DATA, Image, ImagesDataset, Subject


class TorchioTestCase(unittest.TestCase):
Expand Down Expand Up @@ -53,6 +54,12 @@ def setUp(self):
self.dataset = ImagesDataset(self.subjects_list)
self.sample = self.dataset[-1]

def make_2d(self, sample):
sample = copy.deepcopy(sample)
for image in sample.get_images(intensity_only=False):
image[DATA] = image[DATA][:, 0:1, ...]
return sample

def get_inconsistent_sample(self):
"""Return a sample containing images of different shape."""
subject = Subject(
Expand Down
9 changes: 5 additions & 4 deletions torchio/data/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ def get_random_indices_from_shape(
shape: Tuple[int, int, int],
patch_size: Tuple[int, int, int],
) -> Tuple[np.ndarray, np.ndarray]:
shape_array = np.array(shape, dtype=np.uint16)
patch_size_array = np.array(patch_size, dtype=np.uint16)
shape_array = np.array(shape)
patch_size_array = np.array(patch_size)
max_index_ini = shape_array - patch_size_array
if (max_index_ini < 0).any():
message = (
f'Patch size {patch_size_array} must not be'
f' larger than image size {tuple(shape_array)}'
f'Patch size {patch_size} must not be'
f' larger than image size {shape}'
)
raise ValueError(message)
max_index_ini = max_index_ini.astype(np.uint16)
coordinates = []
for max_coordinate in max_index_ini.tolist():
if max_coordinate == 0:
Expand Down
6 changes: 3 additions & 3 deletions torchio/transforms/augmentation/intensity/random_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def get_params(
degrees_range, num_transforms)
translation_params = get_params_array(
translation_range, num_transforms)
if is_2d:
degrees_params[:-1] = 0 # rotate around z axis only
translation_params[-1] = 0 # translate in xy plane only
if is_2d: # imagine sagittal (1, A, S)
degrees_params[:, -2:] = 0 # rotate around R axis only
translation_params[:, 0] = 0 # translate in AS plane only
step = 1 / (num_transforms + 1)
times = torch.arange(0, 1, step)[1:]
noise = torch.FloatTensor(num_transforms)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,21 @@ def apply_transform(self, sample: Subject) -> dict:
self.max_displacement,
self.num_locked_borders,
)
random_parameters_dict = {'coarse_grid': bspline_params}
for image_dict in sample.get_images(intensity_only=False):
if image_dict[TYPE] == LABEL:
interpolation = Interpolation.NEAREST
else:
interpolation = self.interpolation
is_2d = image_dict[DATA].shape[-3] == 1
if is_2d:
bspline_params[..., -3] = 0 # no displacement in LR axis
image_dict[DATA] = self.apply_bspline_transform(
image_dict[DATA],
image_dict[AFFINE],
bspline_params,
interpolation,
)
random_parameters_dict = {'coarse_grid': bspline_params}
sample.add_transform(self, random_parameters_dict)
return sample

Expand Down
24 changes: 20 additions & 4 deletions torchio/transforms/augmentation/spatial/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class RandomFlip(RandomTransform):
computed on a per-axis basis.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. note:: If the input image is 2D, all axes should be in ``(0, 1)``.
"""

def __init__(
Expand All @@ -33,16 +35,30 @@ def __init__(
def apply_transform(self, sample: Subject) -> dict:
axes_to_flip_hot = self.get_params(self.axes, self.flip_probability)
random_parameters_dict = {'axes': axes_to_flip_hot}
for image_dict in sample.get_images(intensity_only=False):
tensor = image_dict[DATA]
items = sample.get_images_dict(intensity_only=False).items()
for image_name, image_dict in items:
data = image_dict[DATA]
is_2d = data.shape[-3] == 1
dims = []
for dim, flip_this in enumerate(axes_to_flip_hot):
if not flip_this:
continue
actual_dim = dim + 1 # images are 4D
# If the user is using 2D images and they use (0, 1) for axes,
# they probably mean (1, 2). This should make this transform
# more user-friendly.
if is_2d:
actual_dim += 1
if actual_dim > 3:
message = (
f'Image "{image_name}" with shape {data.shape} seems to'
' be 2D, so all axes must be in (0, 1),'
f' but they are {self.axes}'
)
raise RuntimeError(message)
dims.append(actual_dim)
tensor = torch.flip(tensor, dims=dims)
image_dict[DATA] = tensor
data = torch.flip(data, dims=dims)
image_dict[DATA] = data
sample.add_transform(self, random_parameters_dict)
return sample

Expand Down

0 comments on commit 93d80ab

Please sign in to comment.