Add RandomDownsample transform (#225)
* Add some code for random anisotropy

* Add spacing property to Subject

* Use new nomenclature to avoid deprecated method

* Add option to remove citation message

* Add verbosity option to CLI tool

* Rename new transform

* Add copy flag to transform base class

* Add RandomDownsample to tests

* Add comment about copying subject

* Tell Resample not to make a copy of the input

* Improve coverage of utils module

* Add RandomDownsample to docs
fepegar committed Jul 11, 2020
1 parent 0661035 commit 5d27296
Showing 15 changed files with 161 additions and 55 deletions.
11 changes: 10 additions & 1 deletion docs/source/transforms/augmentation.rst
@@ -65,12 +65,21 @@ Spatial
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. image:: ../../images/random_elastic_deformation.gif
:alt: Random elastic deformation

.. autoclass:: RandomElasticDeformation
:show-inheritance:


:class:`RandomDownsample`
^^^^^^^^^^^^^^^^^^^^^^^^^

.. image:: https://user-images.githubusercontent.com/12688084/87075276-fe9d9d00-c217-11ea-81a4-db0cac163ce7.png
:alt: Simulation of an image with highly anisotropic spacing

.. autoclass:: RandomDownsample
:show-inheritance:


Intensity
---------
48 changes: 19 additions & 29 deletions tests/test_utils.py
@@ -5,46 +5,21 @@
import unittest
import torch
import numpy as np
from torchio import LABEL, INTENSITY
import SimpleITK as sitk
from torchio import LABEL, INTENSITY, RandomFlip
from torchio.utils import (
to_tuple,
get_stem,
guess_type,
sitk_to_nib,
apply_transform_to_file,
)
from .utils import TorchioTestCase


class TestUtils(TorchioTestCase):
"""Tests for `utils` module."""

def get_sample(self, consistent):
shape = 1, 10, 20, 30
affine = np.diag((1, 2, 3, 1))
affine[:3, 3] = 40, 50, 60
shape2 = 1, 20, 10, 30
sample = {
't1': dict(
data=self.getRandomData(shape),
affine=affine,
type=INTENSITY,
),
't2': dict(
data=self.getRandomData(shape if consistent else shape2),
affine=affine,
type=INTENSITY,
),
'label': dict(
data=(self.getRandomData(shape) > 0.5).float(),
affine=affine,
type=LABEL,
),
}
return sample

@staticmethod
def getRandomData(shape):
return torch.rand(*shape)

def test_to_tuple(self):
assert to_tuple(1) == (1,)
assert to_tuple((1,)) == (1,)
@@ -73,3 +48,18 @@ def test_check_consistent_shape(self):
good_sample.check_consistent_shape()
with self.assertRaises(ValueError):
bad_sample.check_consistent_shape()

def test_apply_transform_to_file(self):
transform = RandomFlip()
apply_transform_to_file(
self.get_image_path('input'),
transform,
self.get_image_path('output'),
verbose=True,
)

def test_sitk_to_nib(self):
data = np.random.rand(10, 10)
image = sitk.GetImageFromArray(data)
tensor, affine = sitk_to_nib(image)
self.assertAlmostEqual(data.sum(), tensor.sum())
1 change: 1 addition & 0 deletions tests/transforms/test_transforms.py
@@ -23,6 +23,7 @@ def get_transform(self, channels, is_3d=True):
transforms = (
torchio.CropOrPad(cp_args),
torchio.ToCanonical(),
torchio.RandomDownsample(),
torchio.Resample((1, 1.1, 1.25)),
torchio.RandomFlip(axes=flip_axes, flip_probability=1),
torchio.RandomMotion(),
5 changes: 3 additions & 2 deletions tests/utils.py
@@ -105,9 +105,10 @@ def get_image_path(
if binary:
data = (data > 0.5).astype(np.uint8)
affine = np.diag((*spacing, 1))
suffix = random.choice(('.nii.gz', '.nii'))
suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.minc', '.img'))
path = self.dir / f'{stem}{suffix}'
nib.Nifti1Image(data, affine).to_filename(str(path))
if np.random.rand() > 0.5:
path = str(path)
image = Image(tensor=data, affine=affine)
image.save(path)
return path
14 changes: 10 additions & 4 deletions torchio/__init__.py
@@ -4,14 +4,20 @@
__email__ = 'fernando.perezgarcia.17@ucl.ac.uk'
__version__ = '0.17.9'

import os
from . import utils
from .torchio import *
from .transforms import *
from .data import io, sampler, inference, ImagesDataset, Image, Queue, Subject
from . import datasets
from . import reference

print('If you use TorchIO for your research, please cite the following paper:')
print('Pérez-García et al., TorchIO: a Python library for efficient loading,')
print('preprocessing, augmentation and patch-based sampling of medical images')
print('in deep learning. Link: https://arxiv.org/abs/2003.04696\n')
CITATION = """If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696
"""

# Thanks for citing torchio. Without citations, researchers will not use TorchIO
if 'TORCHIO_HIDE_CITATION_PROMPT' not in os.environ:
print(CITATION)
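
For context, a minimal sketch of opting out of the citation message; it assumes only that the variable must be present in the environment before the first import, since the check above only tests membership:

import os

# Any value works: the check above only tests for the variable's presence.
# Set it before the first `import torchio`, because the check runs at import time.
os.environ['TORCHIO_HIDE_CITATION_PROMPT'] = '1'

import torchio  # noqa: E402  (importing after the environment tweak is intentional)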
8 changes: 8 additions & 0 deletions torchio/cli.py
@@ -19,12 +19,19 @@
type=int,
help='Seed for PyTorch random number generator.',
)
@click.option(
'--verbose/--no-verbose', '-v',
type=bool,
default=False,
help='Print random transform parameters.',
)
def apply_transform(
input_path,
transform_name,
output_path,
kwargs,
seed,
verbose,
):
"""Apply transform to an image.
@@ -51,6 +58,7 @@ def apply_transform(
input_path,
transform,
output_path,
verbose=verbose,
)
return 0

2 changes: 1 addition & 1 deletion torchio/data/dataset.py
@@ -101,7 +101,7 @@ def __getitem__(self, index: int) -> dict:
if not isinstance(index, int):
raise ValueError(f'Index "{index}" must be int, not {type(index)}')
subject = self.subjects[index]
sample = copy.deepcopy(subject)
sample = copy.deepcopy(subject) # cheap since images not loaded yet
sample.load()

# Apply transform (this is usually the bottleneck)
18 changes: 14 additions & 4 deletions torchio/data/subject.py
@@ -68,22 +68,32 @@ def _parse_images(images: List[Tuple[str, Image]]) -> None:

@property
def shape(self):
"""Return shape of first image in sample.
"""Return shape of first image in subject.
Consistency of shapes across images in the sample is checked first.
Consistency of shapes across images in the subject is checked first.
"""
self.check_consistent_shape()
image = self.get_images(intensity_only=False)[0]
return image.shape

@property
def spatial_shape(self):
"""Return spatial shape of first image in sample.
"""Return spatial shape of first image in subject.
Consistency of shapes across images in the sample is checked first.
Consistency of shapes across images in the subject is checked first.
"""
return self.shape[1:]

@property
def spacing(self):
"""Return spacing of first image in subject.
Consistency of shapes across images in the subject is checked first.
"""
self.check_consistent_shape()
image = self.get_images(intensity_only=False)[0]
return image.spacing

def get_images_dict(self, intensity_only=True):
images = {}
for image_name, image in self.items():
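
A minimal sketch of the new spacing property, following the in-memory Image pattern used in tests/utils.py above; the 3-D tensor input and the exact return format are assumptions:

import numpy as np
import torch
from torchio import Image, Subject

affine = np.diag((1, 2, 3, 1))  # voxel sizes of 1 x 2 x 3 mm
image = Image(tensor=torch.rand(10, 20, 30), affine=affine)
subject = Subject(t1=image)
print(subject.spacing)  # expected: (1.0, 2.0, 3.0), taken from the first image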
1 change: 1 addition & 0 deletions torchio/transforms/__init__.py
@@ -10,6 +10,7 @@

from .augmentation.spatial import RandomFlip
from .augmentation.spatial import RandomAffine
from .augmentation.spatial import RandomDownsample
from .augmentation.spatial import RandomElasticDeformation

from .augmentation.intensity import RandomSwap
3 changes: 1 addition & 2 deletions torchio/transforms/augmentation/intensity/random_ghosting.py
@@ -20,8 +20,7 @@ class RandomGhosting(RandomTransform):
intensity: Positive number representing the artifact strength
:math:`s` with respect to the maximum of the :math:`k`-space.
If ``0``, the ghosts will not be visible. If a tuple
:math:`(a, b)`, is provided then
:math:`s \sim \mathcal{U}(a, b)`.
:math:`(a, b)` is provided then :math:`s \sim \mathcal{U}(a, b)`.
restore: Number between ``0`` and ``1`` indicating how much of the
:math:`k`-space center should be restored after removing the planes
that generate the artifact.
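
As a quick illustration of the corrected docstring, the intensity range might be passed like this (the values are arbitrary):

from torchio import RandomGhosting

# With a tuple, the artifact strength s is sampled as s ~ U(0.2, 0.8)
transform = RandomGhosting(intensity=(0.2, 0.8))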
1 change: 1 addition & 0 deletions torchio/transforms/augmentation/spatial/__init__.py
@@ -1,3 +1,4 @@
from .random_flip import RandomFlip
from .random_affine import RandomAffine
from .random_downsample import RandomDownsample
from .random_elastic_deformation import RandomElasticDeformation
80 changes: 80 additions & 0 deletions torchio/transforms/augmentation/spatial/random_downsample.py
@@ -0,0 +1,80 @@
from typing import Union, Tuple, Optional
import torch
from ....torchio import DATA
from ....data.subject import Subject
from ....utils import to_tuple
from .. import RandomTransform
from ...preprocessing import Resample


class RandomDownsample(RandomTransform):
"""Downsample an image along an axis.
This transform simulates an image that has been acquired with anisotropic
spacing, by downsampling along one axis with nearest-neighbor interpolation.
Args:
axes: Axis or tuple of axes along which the image will be downsampled.
downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
:math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""

def __init__(
self,
axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
downsampling: Union[float, Tuple[float, float]] = (1.5, 5),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
self.axes = self.parse_axes(axes)
self.downsampling_range = self.parse_downsampling(downsampling)

@staticmethod
def get_params(
axes: Tuple[int, ...],
downsampling_range: Tuple[float, float],
) -> Tuple[int, float]:
axis = axes[torch.randint(0, len(axes), (1,))]
downsampling = torch.FloatTensor(1).uniform_(*downsampling_range).item()
return axis, downsampling

@staticmethod
def parse_downsampling(downsampling_factor):
try:
iter(downsampling_factor)
except TypeError:
downsampling_factor = downsampling_factor, downsampling_factor
for n in downsampling_factor:
if n <= 1:
message = (
f'Downsampling factor must be a number > 1, not {n}')
raise ValueError(message)
return downsampling_factor

@staticmethod
def parse_axes(axes: Union[int, Tuple[int, ...]]):
axes_tuple = to_tuple(axes)
for axis in axes_tuple:
is_int = isinstance(axis, int)
if not is_int or axis not in (0, 1, 2):
raise ValueError('All axes must be 0, 1 or 2')
return axes_tuple

def apply_transform(self, sample: Subject) -> Subject:
axis, downsampling = self.get_params(self.axes, self.downsampling_range)
random_parameters_dict = {'axis': axis, 'downsampling': downsampling}
items = sample.get_images_dict(intensity_only=False).items()

target_spacing = list(sample.spacing)
target_spacing[axis] *= downsampling
transform = Resample(
tuple(target_spacing),
image_interpolation='nearest',
copy=False,  # the subject was already copied in Transform.__call__
)
sample = transform(sample)
sample.add_transform(self, random_parameters_dict)
return sample
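
A usage sketch of the new transform, assuming it is called directly on a Subject (as the base class __call__ below allows) and that the sampled parameters are recorded in the subject's history; the example subject is the one built in the spacing sketch above:

from torchio import RandomDownsample

# Downsample along axis 0 or 1 by a factor drawn from U(2, 4)
transform = RandomDownsample(axes=(0, 1), downsampling=(2, 4))
transformed = transform(subject)
print(transformed.history[0])  # e.g. the sampled axis and downsampling factor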
3 changes: 2 additions & 1 deletion torchio/transforms/preprocessing/spatial/resample.py
@@ -66,8 +66,9 @@ def __init__(
image_interpolation: str = 'linear',
pre_affine_name: Optional[str] = None,
p: float = 1,
copy: bool = True,
):
super().__init__(p=p)
super().__init__(p=p, copy=copy)
self.reference_image, self.target_spacing = self.parse_target(target)
self.interpolation_order = self.parse_interpolation(image_interpolation)
self.affine_name = pre_affine_name
12 changes: 5 additions & 7 deletions torchio/transforms/transform.py
@@ -27,9 +27,11 @@ class Transform(ABC):
Args:
p: Probability that this transform will be applied.
copy: Make a deep copy of the input before applying the transform.
"""
def __init__(self, p: float = 1):
def __init__(self, p: float = 1, copy: bool = True):
self.probability = self.parse_probability(p)
self.copy = copy

def __call__(self, data: Union[Subject, torch.Tensor, np.ndarray]):
"""Transform a sample and return the result.
@@ -53,9 +55,7 @@ def __call__(self, data: Union[Subject, torch.Tensor, np.ndarray]):
sample = data
self.parse_sample(sample)

# If the input is a tensor, it will be deepcopied when calling
# ImagesDataset.__getitem__
if not is_tensor:
if self.copy:
sample = deepcopy(sample)

with np.errstate(all='raise'):
@@ -147,9 +147,7 @@ def _get_subject_from_tensor(tensor: torch.Tensor) -> Subject:
image = Image(tensor=channel_tensor, type=INTENSITY)
subject_dict[name] = image
subject = Subject(subject_dict)
dataset = ImagesDataset([subject])
sample = dataset[0]
return sample
return subject

@staticmethod
def nib_to_sitk(data: TypeData, affine: TypeData):
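
A brief sketch of the new copy flag as exposed through Resample above; the only assumed behaviour is that copy=False skips the deepcopy in __call__, so the input subject may be modified in place:

from torchio import Resample

copying = Resample((1, 1, 1))               # default: the input is deep-copied first
in_place = Resample((1, 1, 1), copy=False)  # skip the copy when mutating the input is acceptable
resampled = in_place(subject)               # the input subject itself may be modified and returned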
9 changes: 5 additions & 4 deletions torchio/utils.py
@@ -116,13 +116,14 @@ def apply_transform_to_file(
transform, # : Transform seems to create a circular import (TODO)
output_path: TypePath,
type: str = INTENSITY,
verbose: bool = False,
):
from . import Image, ImagesDataset, Subject
subject = Subject(image=Image(input_path, type))
dataset = ImagesDataset([subject], transform=transform)
transformed = dataset[0]
dataset.save_sample(transformed, dict(image=output_path))

transformed = transform(subject)
transformed.image.save(output_path)
if verbose and transformed.history:
print(transformed.history[0])

def guess_type(string: str) -> Any:
# Adapted from
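
Finally, a usage sketch of the updated helper, mirroring the new test above; both file paths are placeholders:

from torchio import RandomFlip
from torchio.utils import apply_transform_to_file

apply_transform_to_file(
    'input.nii.gz',    # placeholder input path
    RandomFlip(),
    'output.nii.gz',   # placeholder output path
    verbose=True,      # prints the recorded transform parameters
)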
