diff --git a/docs/source/transforms/preprocessing.rst b/docs/source/transforms/preprocessing.rst index c26348f88..293b3f636 100644 --- a/docs/source/transforms/preprocessing.rst +++ b/docs/source/transforms/preprocessing.rst @@ -59,6 +59,13 @@ Spatial :members: _get_six_bounds_parameters +:class:`EnsureShapeMultiple` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: EnsureShapeMultiple + :show-inheritance: + + :class:`Crop` ~~~~~~~~~~~~~ diff --git a/tests/transforms/preprocessing/test_ensure_shape_multiple.py b/tests/transforms/preprocessing/test_ensure_shape_multiple.py new file mode 100644 index 000000000..c8a9e4521 --- /dev/null +++ b/tests/transforms/preprocessing/test_ensure_shape_multiple.py @@ -0,0 +1,31 @@ +import torchio as tio +from ...utils import TorchioTestCase + + +class TestEnsureShapeMultiple(TorchioTestCase): + + def test_bad_method(self): + with self.assertRaises(ValueError): + tio.EnsureShapeMultiple(1, method='bad') + + def test_pad(self): + sample_t1 = self.sample_subject.t1 + assert sample_t1.shape == (1, 10, 20, 30) + transform = tio.EnsureShapeMultiple(4, method='pad') + transformed = transform(sample_t1) + assert transformed.shape == (1, 12, 20, 32) + + def test_crop(self): + sample_t1 = self.sample_subject.t1 + assert sample_t1.shape == (1, 10, 20, 30) + transform = tio.EnsureShapeMultiple(4, method='crop') + transformed = transform(sample_t1) + assert transformed.shape == (1, 8, 20, 28) + + def test_2d(self): + sample_t1 = self.sample_subject.t1 + sample_2d = sample_t1.data[..., :1] + assert sample_2d.shape == (1, 10, 20, 1) + transform = tio.EnsureShapeMultiple(4, method='crop') + transformed = transform(sample_2d) + assert transformed.shape == (1, 8, 20, 1) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 33a8ca8e1..b1c89b3e8 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -25,6 +25,7 @@ def get_transform(self, channels, is_3d=True, labels=True): tio.CropOrPad(cp_args), tio.ToCanonical(), tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample), + tio.EnsureShapeMultiple(2, method='crop'), tio.Resample((1, 1.1, 1.25)), tio.RandomFlip(axes=flip_axes, flip_probability=1), tio.RandomMotion(), diff --git a/torchio/transforms/__init__.py b/torchio/transforms/__init__.py index 9ce0c9ee8..81ed1b446 100644 --- a/torchio/transforms/__init__.py +++ b/torchio/transforms/__init__.py @@ -33,6 +33,7 @@ from .preprocessing import ToCanonical from .preprocessing import ZNormalization from .preprocessing import RescaleIntensity +from .preprocessing import EnsureShapeMultiple from .preprocessing import HistogramStandardization from .preprocessing.intensity.histogram_standardization import train as train_histogram @@ -78,5 +79,6 @@ 'HistogramStandardization', 'RescaleIntensity', 'CropOrPad', + 'EnsureShapeMultiple', 'train_histogram', ] diff --git a/torchio/transforms/preprocessing/__init__.py b/torchio/transforms/preprocessing/__init__.py index b772884a6..be880c8e9 100644 --- a/torchio/transforms/preprocessing/__init__.py +++ b/torchio/transforms/preprocessing/__init__.py @@ -3,6 +3,7 @@ from .spatial.resample import Resample from .spatial.crop_or_pad import CropOrPad from .spatial.to_canonical import ToCanonical +from .spatial.ensure_shape_multiple import EnsureShapeMultiple from .intensity.rescale import RescaleIntensity from .intensity.z_normalization import ZNormalization @@ -15,7 +16,8 @@ 'Resample', 'ToCanonical', 'CropOrPad', - 'RescaleIntensity', + 'EnsureShapeMultiple', 'ZNormalization', + 'RescaleIntensity', 'HistogramStandardization', ] diff --git a/torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py b/torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py new file mode 100644 index 000000000..1995a3051 --- /dev/null +++ b/torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py @@ -0,0 +1,62 @@ +from typing import Union, Optional + +import numpy as np + +from ... import SpatialTransform +from ....utils import to_tuple +from ....data.subject import Subject +from ....typing import TypeTripletInt +from .crop_or_pad import CropOrPad + + +class EnsureShapeMultiple(SpatialTransform): + """Crop or pad an image to a shape that is a multiple of :math:`N`. + + Args: + target_multiple: Tuple :math:`(w, h, d)`. If a single value :math:`n` is + provided, then :math:`w = h = d = n`. + method: Either ``'crop'`` or ``'pad'``. + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + + Example: + >>> import torchio as tio + >>> image = tio.datasets.Colin27().t1 + >>> image.shape + (1, 181, 217, 181) + >>> transform = tio.EnsureShapeMultiple(8, method='pad') + >>> transformed = transform(image) + >>> transformed.shape + (1, 184, 224, 184) + >>> transform = tio.EnsureShapeMultiple(8, method='crop') + >>> transformed = transform(image) + >>> transformed.shape + (1, 176, 216, 176) + >>> image_2d = image.data[..., :1] + >>> image_2d.shape + torch.Size([1, 181, 217, 1]) + >>> transformed = transform(image_2d) + >>> transformed.shape + torch.Size([1, 176, 216, 1]) + + """ + def __init__( + self, + target_multiple: Union[int, TypeTripletInt], + *, + method: Optional[str] = 'pad', + **kwargs + ): + super().__init__(**kwargs) + self.target_multiple = np.array(to_tuple(target_multiple, 3)) + if method not in ('crop', 'pad'): + raise ValueError('Method must be "crop" or "pad"') + self.method = method + + def apply_transform(self, subject: Subject) -> Subject: + source_shape = np.array(subject.spatial_shape, np.uint16) + function = np.floor if self.method == 'crop' else np.ceil + integer_ratio = function(source_shape / self.target_multiple) + target_shape = integer_ratio * self.target_multiple + target_shape = np.maximum(target_shape, 1) + return CropOrPad(target_shape.astype(int))(subject) diff --git a/torchio/transforms/transform.py b/torchio/transforms/transform.py index 40af6f96b..0a50f9eef 100644 --- a/torchio/transforms/transform.py +++ b/torchio/transforms/transform.py @@ -120,12 +120,13 @@ def apply_transform(self, subject: Subject): def add_transform_to_subject_history(self, subject): from .augmentation import RandomTransform - from . import Compose, OneOf, CropOrPad + from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple call_others = ( RandomTransform, Compose, OneOf, CropOrPad, + EnsureShapeMultiple, ) if not isinstance(self, call_others): subject.add_transform(self, self._get_reproducing_arguments()) @@ -311,8 +312,13 @@ def _get_reproducing_arguments(self): Return a dictionary with the arguments that would be necessary to reproduce the transform exactly. """ - reproducing_arguments = dict(include=self.include, exclude=self.exclude, copy=self.copy) - reproducing_arguments.update({name: getattr(self, name) for name in self.args_names}) + reproducing_arguments = dict( + include=self.include, + exclude=self.exclude, + copy=self.copy, + ) + args_names = {name: getattr(self, name) for name in self.args_names} + reproducing_arguments.update(args_names) return reproducing_arguments def is_invertible(self):