Skip to content

Commit

Permalink
Merge 3bb5d67 into b09abfe
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar authored Aug 9, 2020
2 parents b09abfe + 3bb5d67 commit f948d84
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 163 deletions.
116 changes: 0 additions & 116 deletions docs/source/example.rst

This file was deleted.

10 changes: 2 additions & 8 deletions tests/transforms/augmentation/test_random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,12 @@ class TestRandomFlip(TorchioTestCase):
"""Tests for `RandomFlip`."""
def test_2d(self):
sample = self.make_2d(self.sample)
transform = RandomFlip(axes=(0, 1), flip_probability=1)
transform = RandomFlip(axes=(1, 2), flip_probability=1)
transformed = transform(sample)
assert_array_equal(
sample.t1.data.numpy()[:, :, ::-1, ::-1],
transformed.t1.data.numpy())

def test_wrong_axes(self):
sample = self.make_2d(self.sample)
transform = RandomFlip(axes=2, flip_probability=1)
with self.assertRaises(RuntimeError):
transform(sample)

def test_out_of_range_axis(self):
with self.assertRaises(ValueError):
RandomFlip(axes=3)
Expand All @@ -29,7 +23,7 @@ def test_out_of_range_axis_in_tuple(self):

def test_wrong_axes_type(self):
with self.assertRaises(ValueError):
RandomFlip(axes='wrong')
RandomFlip(axes=None)

def test_wrong_flip_probability_type(self):
with self.assertRaises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion tests/transforms/augmentation/test_random_ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_out_of_range_axis_in_tuple(self):

def test_wrong_axes_type(self):
with self.assertRaises(ValueError):
RandomGhosting(axes='wrong')
RandomGhosting(axes=None)

def test_out_of_range_restore(self):
with self.assertRaises(ValueError):
Expand Down
47 changes: 47 additions & 0 deletions torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,53 @@ def spacing(self):
def memory(self):
return np.prod(self.shape) * 4 # float32, i.e. 4 bytes per voxel

def axis_name_to_index(self, axis: str):
"""Convert an axis name to an axis index.
Args:
axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
and first letters are also valid, as only the first letter will be
used.
.. note:: If you are working with animals, you should probably use
``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
respectively.
"""
if not isinstance(axis, str):
raise ValueError('Axis must be a string')
axis = axis[0].upper()

# Generally, TorchIO tensors are (C, D, H, W)
if axis == 'H':
return -2
elif axis == 'W':
return -1
else:
try:
index = self.orientation.index(axis)
except ValueError:
index = self.orientation.index(self.flip_axis(axis))
# Return negative indices so that it does not matter whether we
# refer to spatial dimensions or not
index = -4 + index
return index

@staticmethod
def flip_axis(axis):
if axis == 'R': return 'L'
elif axis == 'L': return 'R'
elif axis == 'A': return 'P'
elif axis == 'P': return 'A'
elif axis == 'I': return 'S'
elif axis == 'S': return 'I'
else:
message = (
f'Axis not understood. Please use a value in {tuple("LRAPIS")}'
)
raise ValueError(message)

def get_spacing_string(self):
strings = [f'{n:.2f}' for n in self.spacing]
string = f'({", ".join(strings)})'
Expand Down
18 changes: 17 additions & 1 deletion torchio/data/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def get_images(self, intensity_only=True):
images_dict = self.get_images_dict(intensity_only=intensity_only)
return list(images_dict.values())

def get_first_image(self):
return self.get_images(intensity_only=False)[0]

def check_consistent_shape(self) -> None:
shapes_dict = {}
iterable = self.get_images_dict(intensity_only=False).items()
Expand All @@ -128,11 +131,24 @@ def check_consistent_shape(self) -> None:
num_unique_shapes = len(set(shapes_dict.values()))
if num_unique_shapes > 1:
message = (
'Images in sample have inconsistent shapes:'
'Images in subject have inconsistent shapes:'
f'\n{pprint.pformat(shapes_dict)}'
)
raise ValueError(message)

def check_consistent_orientation(self) -> None:
orientations_dict = {}
iterable = self.get_images_dict(intensity_only=False).items()
for image_name, image in iterable:
orientations_dict[image_name] = image.orientation
num_unique_orientations = len(set(orientations_dict.values()))
if num_unique_orientations > 1:
message = (
'Images in subject have inconsistent orientations:'
f'\n{pprint.pformat(orientations_dict)}'
)
raise ValueError(message)

def add_transform(
self,
transform: 'Transform',
Expand Down
13 changes: 11 additions & 2 deletions torchio/transforms/augmentation/intensity/random_ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class RandomGhosting(RandomTransform):
:math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
axes: Axis along which the ghosts will be created. If
:py:attr:`axes` is a tuple, the axis will be randomly chosen
from the passed values.
from the passed values. Anatomical labels may also be used (see
:py:class:`~torchio.transforms.augmentation.RandomFlip`).
intensity: Positive number representing the artifact strength
:math:`s` with respect to the maximum of the :math:`k`-space.
If ``0``, the ghosts will not be visible. If a tuple
Expand All @@ -40,6 +41,10 @@ class RandomGhosting(RandomTransform):
.. note:: The execution time of this transform does not depend on the
number of ghosts.
.. warning:: Note that height and width of 2D images correspond to axes
``1`` and ``2`` respectively, as TorchIO images are generally considered
to have 3 spatial dimensions.
"""
def __init__(
self,
Expand All @@ -58,7 +63,7 @@ def __init__(
except TypeError:
axes = (axes,)
for axis in axes:
if axis not in (0, 1, 2):
if not isinstance(axis, str) and axis not in (0, 1, 2):
raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
self.axes = axes
self.num_ghosts_range = self.parse_range(
Expand All @@ -79,6 +84,10 @@ def parse_restore(restore):

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
axes_string = False
if any(isinstance(n, str) for n in self.axes):
sample.check_consistent_orientation()
axes_string = True
for image_name, image in sample.get_images_dict().items():
transformed_tensors = []
is_2d = image.is_2d()
Expand Down
74 changes: 41 additions & 33 deletions torchio/transforms/augmentation/spatial/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,31 @@ class RandomFlip(RandomTransform):
"""Reverse the order of elements in an image along the given axes.
Args:
axes: Axis or tuple of axes along which the image will be flipped.
axes: Index or tuple of indices of the spatial dimensions along which
the image might be flipped. If they are integers, they must be in
``(0, 1, 2)``. Anatomical labels may also be used, such as
``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or
``'i'`` (inferior). Only the first letter of the string will be
used.
flip_probability: Probability that the image will be flipped. This is
computed on a per-axis basis.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
keys: See :py:class:`~torchio.transforms.Transform`.
.. note:: If the input image is 2D, all axes should be in ``(0, 1)``.
Example:
>>> import torchio as tio
>>> fpg = tio.datasets.FPG()
>>> flip = tio.RandomFlip(axes=('LR')) # flip along lateral axis only
.. tip:: It is handy to specify the axes as anatomical labels when the image
orientation is not known.
.. warning:: Note that height and width of 2D images correspond to axes
``1`` and ``2`` respectively, as TorchIO images are generally considered
to have 3 spatial dimensions.
"""

def __init__(
Expand All @@ -36,53 +53,44 @@ def __init__(
)

def apply_transform(self, sample: Subject) -> dict:
axes_to_flip_hot = self.get_params(self.axes, self.flip_probability)
axes = self.axes
axes_to_flip_hot = self.get_params(self.flip_probability)
if any(isinstance(n, str) for n in axes):
sample.check_consistent_orientation()
image = sample.get_first_image()
axes = sorted([4 + image.axis_name_to_index(n) for n in axes])
for i in range(3):
if i not in axes:
axes_to_flip_hot[i] = False
random_parameters_dict = {'axes': axes_to_flip_hot}
items = sample.get_images_dict(intensity_only=False).items()
for image_name, image_dict in items:
data = image_dict[DATA]
is_2d = data.shape[-3] == 1
for image_name, image in items:
dims = []
for dim, flip_this in enumerate(axes_to_flip_hot):
if not flip_this:
continue
actual_dim = dim + 1 # images are 4D
# If the user is using 2D images and they use (0, 1) for axes,
# they probably mean (1, 2). This should make this transform
# more user-friendly.
if is_2d:
actual_dim += 1
if actual_dim > 3:
message = (
f'Image "{image_name}" with shape {data.shape} seems to'
' be 2D, so all axes must be in (0, 1),'
f' but they are {self.axes}'
)
raise RuntimeError(message)
dims.append(actual_dim)
# data = torch.flip(data, dims=dims)
data = data.numpy()
data = np.flip(data, axis=dims)
data = data.copy() # remove negative strides
data = torch.from_numpy(data)
image_dict[DATA] = data
if dims:
data = image.numpy()
data = np.flip(data, axis=dims)
data = data.copy() # remove negative strides
data = torch.from_numpy(data)
image[DATA] = data
sample.add_transform(self, random_parameters_dict)
return sample

@staticmethod
def get_params(axes: Tuple[int, ...], probability: float) -> List[bool]:
axes_hot = [False, False, False]
for axis in axes:
random_number = torch.rand(1)
flip_this = bool(probability > random_number)
axes_hot[axis] = flip_this
return axes_hot
def get_params(probability: float) -> List[bool]:
return (probability > torch.rand(3)).tolist()

@staticmethod
def parse_axes(axes: Union[int, Tuple[int, ...]]):
axes_tuple = to_tuple(axes)
for axis in axes_tuple:
is_int = isinstance(axis, int)
if not is_int or axis not in (0, 1, 2):
raise ValueError('All axes must be 0, 1 or 2')
is_string = isinstance(axis, str)
if not is_string and not (is_int and axis in (0, 1, 2)):
message = f'All axes must be 0, 1 or 2, but found "{axis}"'
raise ValueError(message)
return axes_tuple
Loading

0 comments on commit f948d84

Please sign in to comment.