Skip to content

Commit

Permalink
Merge ec19bc3 into b9ac52d
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Dec 29, 2020
2 parents b9ac52d + ec19bc3 commit c1a29fa
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 4 deletions.
7 changes: 7 additions & 0 deletions docs/source/transforms/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ Spatial
:members: _get_six_bounds_parameters


:class:`EnsureShapeMultiple`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: EnsureShapeMultiple
:show-inheritance:


:class:`Crop`
~~~~~~~~~~~~~

Expand Down
31 changes: 31 additions & 0 deletions tests/transforms/preprocessing/test_ensure_shape_multiple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torchio as tio
from ...utils import TorchioTestCase


class TestEnsureShapeMultiple(TorchioTestCase):

def test_bad_method(self):
with self.assertRaises(ValueError):
tio.EnsureShapeMultiple(1, method='bad')

def test_pad(self):
sample_t1 = self.sample_subject.t1
assert sample_t1.shape == (1, 10, 20, 30)
transform = tio.EnsureShapeMultiple(4, method='pad')
transformed = transform(sample_t1)
assert transformed.shape == (1, 12, 20, 32)

def test_crop(self):
sample_t1 = self.sample_subject.t1
assert sample_t1.shape == (1, 10, 20, 30)
transform = tio.EnsureShapeMultiple(4, method='crop')
transformed = transform(sample_t1)
assert transformed.shape == (1, 8, 20, 28)

def test_2d(self):
sample_t1 = self.sample_subject.t1
sample_2d = sample_t1.data[..., :1]
assert sample_2d.shape == (1, 10, 20, 1)
transform = tio.EnsureShapeMultiple(4, method='crop')
transformed = transform(sample_2d)
assert transformed.shape == (1, 8, 20, 1)
1 change: 1 addition & 0 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_transform(self, channels, is_3d=True, labels=True):
tio.CropOrPad(cp_args),
tio.ToCanonical(),
tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
tio.EnsureShapeMultiple(2, method='crop'),
tio.Resample((1, 1.1, 1.25)),
tio.RandomFlip(axes=flip_axes, flip_probability=1),
tio.RandomMotion(),
Expand Down
2 changes: 2 additions & 0 deletions torchio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .preprocessing import ToCanonical
from .preprocessing import ZNormalization
from .preprocessing import RescaleIntensity
from .preprocessing import EnsureShapeMultiple
from .preprocessing import HistogramStandardization
from .preprocessing.intensity.histogram_standardization import train as train_histogram

Expand Down Expand Up @@ -78,5 +79,6 @@
'HistogramStandardization',
'RescaleIntensity',
'CropOrPad',
'EnsureShapeMultiple',
'train_histogram',
]
4 changes: 3 additions & 1 deletion torchio/transforms/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .spatial.resample import Resample
from .spatial.crop_or_pad import CropOrPad
from .spatial.to_canonical import ToCanonical
from .spatial.ensure_shape_multiple import EnsureShapeMultiple

from .intensity.rescale import RescaleIntensity
from .intensity.z_normalization import ZNormalization
Expand All @@ -15,7 +16,8 @@
'Resample',
'ToCanonical',
'CropOrPad',
'RescaleIntensity',
'EnsureShapeMultiple',
'ZNormalization',
'RescaleIntensity',
'HistogramStandardization',
]
62 changes: 62 additions & 0 deletions torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Union, Optional

import numpy as np

from ... import SpatialTransform
from ....utils import to_tuple
from ....data.subject import Subject
from ....typing import TypeTripletInt
from .crop_or_pad import CropOrPad


class EnsureShapeMultiple(SpatialTransform):
"""Crop or pad an image to a shape that is a multiple of :math:`N`.
Args:
target_multiple: Tuple :math:`(w, h, d)`. If a single value :math:`n` is
provided, then :math:`w = h = d = n`.
method: Either ``'crop'`` or ``'pad'``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
Example:
>>> import torchio as tio
>>> image = tio.datasets.Colin27().t1
>>> image.shape
(1, 181, 217, 181)
>>> transform = tio.EnsureShapeMultiple(8, method='pad')
>>> transformed = transform(image)
>>> transformed.shape
(1, 184, 224, 184)
>>> transform = tio.EnsureShapeMultiple(8, method='crop')
>>> transformed = transform(image)
>>> transformed.shape
(1, 176, 216, 176)
>>> image_2d = image.data[..., :1]
>>> image_2d.shape
torch.Size([1, 181, 217, 1])
>>> transformed = transform(image_2d)
>>> transformed.shape
torch.Size([1, 176, 216, 1])
"""
def __init__(
self,
target_multiple: Union[int, TypeTripletInt],
*,
method: Optional[str] = 'pad',
**kwargs
):
super().__init__(**kwargs)
self.target_multiple = np.array(to_tuple(target_multiple, 3))
if method not in ('crop', 'pad'):
raise ValueError('Method must be "crop" or "pad"')
self.method = method

def apply_transform(self, subject: Subject) -> Subject:
source_shape = np.array(subject.spatial_shape, np.uint16)
function = np.floor if self.method == 'crop' else np.ceil
integer_ratio = function(source_shape / self.target_multiple)
target_shape = integer_ratio * self.target_multiple
target_shape = np.maximum(target_shape, 1)
return CropOrPad(target_shape.astype(int))(subject)
12 changes: 9 additions & 3 deletions torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ def apply_transform(self, subject: Subject):

def add_transform_to_subject_history(self, subject):
from .augmentation import RandomTransform
from . import Compose, OneOf, CropOrPad
from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
call_others = (
RandomTransform,
Compose,
OneOf,
CropOrPad,
EnsureShapeMultiple,
)
if not isinstance(self, call_others):
subject.add_transform(self, self._get_reproducing_arguments())
Expand Down Expand Up @@ -311,8 +312,13 @@ def _get_reproducing_arguments(self):
Return a dictionary with the arguments that would be necessary to
reproduce the transform exactly.
"""
reproducing_arguments = dict(include=self.include, exclude=self.exclude, copy=self.copy)
reproducing_arguments.update({name: getattr(self, name) for name in self.args_names})
reproducing_arguments = dict(
include=self.include,
exclude=self.exclude,
copy=self.copy,
)
args_names = {name: getattr(self, name) for name in self.args_names}
reproducing_arguments.update(args_names)
return reproducing_arguments

def is_invertible(self):
Expand Down

0 comments on commit c1a29fa

Please sign in to comment.