From 281b2eaa95b8dfeb0442ebbb3ac3c44fe8d0e505 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?=
Date: Fri, 16 Oct 2020 13:50:51 +0100
Subject: [PATCH] Avoid using SimpleITK in some transforms (#334)

* Use PyTorch to crop

* Use Crop transform to create patches

* Remove typing hint

* Use NumPy to pad

* Use scipy to blur

* Fix kwarg name and default values
---
 tests/data/test_image.py                      | 22 -----
 tests/transforms/preprocessing/test_pad.py    | 20 +++++
 torchio/data/dataset.py                       |  2 +-
 torchio/data/image.py                         | 18 ----
 torchio/data/inference/aggregator.py          |  4 +-
 torchio/data/inference/grid_sampler.py        | 13 ++-
 torchio/data/sampler/sampler.py               | 36 ++++++--
 torchio/data/sampler/weighted.py              |  9 +-
 torchio/data/subject.py                       | 14 ---
 torchio/torchio.py                            |  1 +
 .../augmentation/intensity/random_blur.py     | 21 +++--
 .../preprocessing/spatial/bounds_transform.py | 31 +------
 .../transforms/preprocessing/spatial/crop.py  | 22 +++--
 .../transforms/preprocessing/spatial/pad.py   | 87 +++++++++----------
 14 files changed, 134 insertions(+), 166 deletions(-)
 create mode 100644 tests/transforms/preprocessing/test_pad.py

diff --git a/tests/data/test_image.py b/tests/data/test_image.py
index f8603b4a0..d44227b47 100644
--- a/tests/data/test_image.py
+++ b/tests/data/test_image.py
@@ -38,16 +38,6 @@ def test_tensor_affine(self):
         sample_input = torch.ones((4, 10, 10, 10))
         RandomAffine()(sample_input)
 
-    def test_crop_attributes(self):
-        cropped = self.sample.crop((1, 1, 1), (5, 5, 5))
-        self.assertIs(self.sample.t1['pre_affine'], cropped.t1['pre_affine'])
-
-    def test_crop_does_not_create_wrong_path(self):
-        data = torch.ones((1, 10, 10, 10))
-        image = ScalarImage(tensor=data)
-        cropped = image.crop((1, 1, 1), (5, 5, 5))
-        self.assertIs(cropped.path, None)
-
     def test_scalar_image_type(self):
         data = torch.ones((1, 10, 10, 10))
         image = ScalarImage(tensor=data)
@@ -68,18 +58,6 @@ def test_wrong_label_map_type(self):
         with self.assertRaises(ValueError):
             LabelMap(tensor=data, type=INTENSITY)
 
-    def test_crop_scalar_image_type(self):
-        data = torch.ones((1, 10, 10, 10))
-        image = ScalarImage(tensor=data)
-        cropped = image.crop((1, 1, 1), (5, 5, 5))
-        self.assertIs(cropped.type, INTENSITY)
-
-    def test_crop_label_map_type(self):
-        data = torch.ones((1, 10, 10, 10))
-        label = LabelMap(tensor=data)
-        cropped = label.crop((1, 1, 1), (5, 5, 5))
-        self.assertIs(cropped.type, LABEL)
-
     def test_no_input(self):
         with self.assertRaises(ValueError):
             ScalarImage()
diff --git a/tests/transforms/preprocessing/test_pad.py b/tests/transforms/preprocessing/test_pad.py
new file mode 100644
index 000000000..326dd750f
--- /dev/null
+++ b/tests/transforms/preprocessing/test_pad.py
@@ -0,0 +1,20 @@
+import torch
+import SimpleITK as sitk
+from torchio.utils import sitk_to_nib
+from torchio.transforms import Pad
+from ...utils import TorchioTestCase
+
+
+class TestPad(TorchioTestCase):
+    """Tests for `Pad`."""
+    def test_pad(self):
+        image = self.sample.t1
+        padding = 1, 2, 3, 4, 5, 6
+        sitk_image = image.as_sitk()
+        low, high = padding[::2], padding[1::2]
+        sitk_padded = sitk.ConstantPad(sitk_image, low, high, 0)
+        tio_padded = Pad(padding, padding_mode=0)(image)
+        sitk_tensor, sitk_affine = sitk_to_nib(sitk_padded)
+        tio_tensor, tio_affine = sitk_to_nib(tio_padded.as_sitk())
+        self.assertTensorEqual(sitk_tensor, tio_tensor)
+        self.assertTensorEqual(sitk_affine, tio_affine)
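The new test hinges on the correspondence between `sitk.ConstantPad` lower/upper bounds and NumPy's per-axis pad widths. A standalone sketch of that correspondence, with made-up shapes and values; note the axis reversal between SimpleITK's (x, y, z) indexing and the (z, y, x) arrays returned by `GetArrayFromImage`:

```python
import numpy as np
import SimpleITK as sitk

# Toy volume; SimpleITK images index (x, y, z), but GetArrayFromImage
# returns arrays indexed (z, y, x), so pad widths must be reversed.
array = np.arange(60, dtype=np.float64).reshape(3, 4, 5)  # (z, y, x)
image = sitk.GetImageFromArray(array)

low, high = (1, 2, 3), (4, 5, 6)  # (x, y, z) lower/upper bounds
padded_sitk = sitk.GetArrayFromImage(sitk.ConstantPad(image, low, high, 0))

# The same padding for np.pad: one (before, after) pair per axis,
# given in (z, y, x) order.
padded_np = np.pad(array, ((3, 6), (2, 5), (1, 4)), mode='constant')

assert np.array_equal(padded_sitk, padded_np)
```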
diff --git a/torchio/data/dataset.py b/torchio/data/dataset.py
index b8204fc28..279f2b093 100644
--- a/torchio/data/dataset.py
+++ b/torchio/data/dataset.py
@@ -75,7 +75,7 @@ def __init__(
     def __len__(self):
         return len(self.subjects)
 
-    def __getitem__(self, index: int) -> dict:
+    def __getitem__(self, index: int) -> Subject:
         if not isinstance(index, int):
             raise ValueError(f'Index "{index}" must be int, not {type(index)}')
         subject = self.subjects[index]
diff --git a/torchio/data/image.py b/torchio/data/image.py
index bb1a5f8c9..c45389313 100644
--- a/torchio/data/image.py
+++ b/torchio/data/image.py
@@ -472,24 +472,6 @@ def plot(self, **kwargs) -> None:
         from ..visualization import plot_volume  # avoid circular import
         plot_volume(self, **kwargs)
 
-    def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
-        new_origin = nib.affines.apply_affine(self.affine, index_ini)
-        new_affine = self.affine.copy()
-        new_affine[:3, 3] = new_origin
-        i0, j0, k0 = index_ini
-        i1, j1, k1 = index_fin
-        patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
-        kwargs = dict(
-            tensor=patch,
-            affine=new_affine,
-            type=self.type,
-            path=self.path,
-        )
-        for key, value in self.items():
-            if key in PROTECTED_KEYS: continue
-            kwargs[key] = value  # should I copy? deepcopy?
-        return self.__class__(**kwargs)
-
 
 class ScalarImage(Image):
     """Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.
diff --git a/torchio/data/inference/aggregator.py b/torchio/data/inference/aggregator.py
index ea68dc967..a41f7604a 100644
--- a/torchio/data/inference/aggregator.py
+++ b/torchio/data/inference/aggregator.py
@@ -27,9 +27,9 @@ class GridAggregator:
         information about patch-based sampling.
     """
     def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
-        sample = sampler.sample
+        subject = sampler.subject
         self.volume_padded = sampler.padding_mode is not None
-        self.spatial_shape = sample.spatial_shape
+        self.spatial_shape = subject.spatial_shape
         self._output_tensor = None
         self.patch_overlap = sampler.patch_overlap
         self.parse_overlap_mode(overlap_mode)
diff --git a/torchio/data/inference/grid_sampler.py b/torchio/data/inference/grid_sampler.py
index caf7eb229..36cf3e0ba 100644
--- a/torchio/data/inference/grid_sampler.py
+++ b/torchio/data/inference/grid_sampler.py
@@ -47,7 +47,7 @@ def __init__(
             patch_overlap: TypeTuple = (0, 0, 0),
             padding_mode: Union[str, float, None] = None,
             ):
-        self.sample = sample
+        self.subject = sample
         self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
         self.padding_mode = padding_mode
         if padding_mode is not None:
@@ -55,9 +55,9 @@ def __init__(
             border = self.patch_overlap // 2
             padding = border.repeat(2)
             pad = Pad(padding, padding_mode=padding_mode)
-            self.sample = pad(self.sample)
+            self.subject = pad(self.subject)
         PatchSampler.__init__(self, patch_size)
-        sizes = self.sample.spatial_shape, self.patch_size, self.patch_overlap
+        sizes = self.subject.spatial_shape, self.patch_size, self.patch_overlap
         self.parse_sizes(*sizes)
         self.locations = self.get_patches_locations(*sizes)
 
@@ -68,10 +68,9 @@ def __getitem__(self, index):
         # Assume 3D
         location = self.locations[index]
         index_ini = location[:3]
-        index_fin = location[3:]
-        cropped_sample = self.sample.crop(index_ini, index_fin)
-        cropped_sample[LOCATION] = location
-        return cropped_sample
+        cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
+        cropped_subject[LOCATION] = location
+        return cropped_subject
 
     @staticmethod
     def parse_sizes(
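With the sampler now cropping through the shared `PatchSampler.crop` helper, the inference loop itself is unchanged. A usage sketch, assuming `GridSampler` and `GridAggregator` are importable from `torchio.data` as in this version, and using the bundled `Colin27` subject (downloaded on first use):

```python
import torch
import torchio

subject = torchio.datasets.Colin27()
sampler = torchio.data.GridSampler(
    subject, patch_size=64, patch_overlap=4, padding_mode='edge')
loader = torch.utils.data.DataLoader(sampler, batch_size=4)
aggregator = torchio.data.GridAggregator(sampler)

for batch in loader:
    inputs = batch['t1'][torchio.DATA]
    outputs = inputs  # identity stands in for a trained network
    aggregator.add_batch(outputs, batch[torchio.LOCATION])

prediction = aggregator.get_output_tensor()  # same shape as the input volume
```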
diff --git a/torchio/data/sampler/sampler.py b/torchio/data/sampler/sampler.py
index 08d8e912c..864846f73 100644
--- a/torchio/data/sampler/sampler.py
+++ b/torchio/data/sampler/sampler.py
@@ -29,14 +29,38 @@ def __init__(self, patch_size: TypePatchSize):
 
     def extract_patch(
             self,
-            sample: Subject,
+            subject: Subject,
+            index_ini: TypeTripletInt,
+            ) -> Subject:
+        cropped_subject = self.crop(subject, index_ini, self.patch_size)
+        cropped_subject['index_ini'] = np.array(index_ini).astype(int)
+        return cropped_subject
+
+    def crop(
+            self,
+            subject: Subject,
             index_ini: TypeTripletInt,
+            patch_size: TypeTripletInt,
             ) -> Subject:
-        index_ini = np.array(index_ini)
-        index_fin = index_ini + self.patch_size
-        cropped_sample = sample.crop(index_ini, index_fin)
-        cropped_sample['index_ini'] = index_ini.astype(int)
-        return cropped_sample
+        transform = self.get_crop_transform(subject, index_ini, patch_size)
+        return transform(subject)
+
+    @staticmethod
+    def get_crop_transform(
+            subject,
+            index_ini,
+            patch_size: TypePatchSize,
+            ):
+        from ...transforms.preprocessing.spatial.crop import Crop
+        shape = np.array(subject.spatial_shape, dtype=np.uint16)
+        index_ini = np.array(index_ini, dtype=np.uint16)
+        patch_size = np.array(patch_size, dtype=np.uint16)
+        index_fin = index_ini + patch_size
+        crop_ini = index_ini.tolist()
+        crop_fin = (shape - index_fin).tolist()
+        start = ()
+        cropping = sum(zip(crop_ini, crop_fin), start)
+        return Crop(cropping)
 
 
 class RandomSampler(PatchSampler):
diff --git a/torchio/data/sampler/weighted.py b/torchio/data/sampler/weighted.py
index 9f8223fec..b16088b29 100644
--- a/torchio/data/sampler/weighted.py
+++ b/torchio/data/sampler/weighted.py
@@ -168,15 +168,14 @@ def get_cumulative_distribution_function(
 
     def extract_patch(
             self,
-            sample: Subject,
+            subject: Subject,
             probability_map: np.ndarray,
             cdf: np.ndarray
             ) -> Subject:
         index_ini = self.get_random_index_ini(probability_map, cdf)
-        index_fin = index_ini + self.patch_size
-        cropped_sample = sample.crop(index_ini, index_fin)
-        cropped_sample['index_ini'] = index_ini.astype(int)
-        return cropped_sample
+        cropped_subject = self.crop(subject, index_ini, self.patch_size)
+        cropped_subject['index_ini'] = index_ini.astype(int)
+        return cropped_subject
 
     def get_random_index_ini(
             self,
diff --git a/torchio/data/subject.py b/torchio/data/subject.py
index 464c9c4e0..0cdd59965 100644
--- a/torchio/data/subject.py
+++ b/torchio/data/subject.py
@@ -155,20 +155,6 @@ def load(self):
         for image in self.get_images(intensity_only=False):
             image.load()
 
-    def crop(self, index_ini, index_fin):
-        """Make a copy of the subject with a reduced field of view (patch)."""
-        result_dict = {}
-        for key, value in self.items():
-            if isinstance(value, Image):
-                # patch.clone() is much faster than copy.deepcopy(patch)
-                value = value.crop(index_ini, index_fin)
-            else:
-                value = copy.deepcopy(value)
-            result_dict[key] = value
-        new = Subject(result_dict)
-        new.history = self.history
-        return new
-
     def update_attributes(self):
         # This allows to get images using attribute notation, e.g. subject.t1
         self.__dict__.update(self)
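The interleaving in `get_crop_transform` is easiest to see with concrete numbers. A minimal sketch of the arithmetic with a made-up volume shape, patch size and index:

```python
import numpy as np

# A 16**3 patch starting at voxel (10, 20, 30) of a 181 x 217 x 181 volume.
shape = np.array((181, 217, 181))
index_ini = np.array((10, 20, 30))
patch_size = np.array((16, 16, 16))

index_fin = index_ini + patch_size
crop_ini = index_ini.tolist()           # voxels to remove at the low end
crop_fin = (shape - index_fin).tolist() # voxels to remove at the high end

# Interleave into (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin),
# the six-bounds format that Crop expects.
cropping = sum(zip(crop_ini, crop_fin), ())
print(cropping)  # (10, 155, 20, 181, 30, 135)
```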
diff --git a/torchio/torchio.py b/torchio/torchio.py
index b66a2812f..305a55b64 100644
--- a/torchio/torchio.py
+++ b/torchio/torchio.py
@@ -30,6 +30,7 @@
 TypeNumber = Union[int, float]
 TypeData = Union[torch.Tensor, np.ndarray]
 TypeTripletInt = Tuple[int, int, int]
+TypeSextetInt = Tuple[int, int, int, int, int, int]
 TypeTripletFloat = Tuple[float, float, float]
 TypeTuple = Union[int, TypeTripletInt]
 TypeRangeInt = Union[int, Tuple[int, int]]
diff --git a/torchio/transforms/augmentation/intensity/random_blur.py b/torchio/transforms/augmentation/intensity/random_blur.py
index 4c06b785b..c8bd9381d 100644
--- a/torchio/transforms/augmentation/intensity/random_blur.py
+++ b/torchio/transforms/augmentation/intensity/random_blur.py
@@ -2,8 +2,8 @@
 import torch
 import numpy as np
 import SimpleITK as sitk
-from ....utils import nib_to_sitk, sitk_to_nib
-from ....torchio import DATA, AFFINE, TypeData
+import scipy.ndimage as ndi
+from ....torchio import DATA, AFFINE, TypeData, TypeTripletFloat
 from ....data.subject import Subject
 from ... import IntensityTransform
 from .. import RandomTransform
@@ -25,7 +25,7 @@ class RandomBlur(RandomTransform, IntensityTransform):
     """
     def __init__(
             self,
-            std: Union[float, Tuple[float, float]] = (0, 4),
+            std: Union[float, Tuple[float, float]] = (0, 2),
             p: float = 1,
             seed: Optional[int] = None,
             keys: Optional[List[str]] = None,
@@ -44,7 +44,7 @@ def apply_transform(self, sample: Subject) -> dict:
                 random_parameters_images_dict[key] = random_parameters_dict
                 transformed_tensor = blur(
                     tensor,
-                    image[AFFINE],
+                    image.spacing,
                     std,
                 )
                 transformed_tensors.append(transformed_tensor)
@@ -58,10 +58,13 @@ def get_params(std_range: Tuple[float, float]) -> np.ndarray:
         return std
 
 
-def blur(data: TypeData, affine: TypeData, std: np.ndarray) -> torch.Tensor:
+def blur(
+        data: TypeData,
+        spacing: TypeTripletFloat,
+        std_physical: np.ndarray,
+        ) -> torch.Tensor:
     assert data.ndim == 3
-    image = nib_to_sitk(data[np.newaxis], affine)
-    image = sitk.DiscreteGaussian(image, std.tolist())
-    array, _ = sitk_to_nib(image)
-    tensor = torch.from_numpy(array[0])
+    std_voxel = np.array(std_physical) / np.array(spacing)
+    blurred = ndi.gaussian_filter(data, std_voxel)
+    tensor = torch.from_numpy(blurred)
     return tensor
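The sampled standard deviations are in physical units (mm), while `scipy.ndimage.gaussian_filter` works in voxels, so the image spacing converts between the two; this matches what `sitk.DiscreteGaussian` did with image spacing enabled. A standalone sketch of the conversion, with made-up spacing and std values:

```python
import numpy as np
import scipy.ndimage as ndi

data = np.random.rand(181, 217, 181).astype(np.float32)
spacing = np.array((1.0, 1.0, 2.0))       # mm per voxel along each axis
std_physical = np.array((2.0, 2.0, 2.0))  # desired blur in mm

# Fewer voxels of blur where the spacing is coarser.
std_voxel = std_physical / spacing        # -> (2.0, 2.0, 1.0) voxels
blurred = ndi.gaussian_filter(data, std_voxel)
assert blurred.shape == data.shape
```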
diff --git a/torchio/transforms/preprocessing/spatial/bounds_transform.py b/torchio/transforms/preprocessing/spatial/bounds_transform.py
index b2ae2451d..1416f7a2d 100644
--- a/torchio/transforms/preprocessing/spatial/bounds_transform.py
+++ b/torchio/transforms/preprocessing/spatial/bounds_transform.py
@@ -1,7 +1,6 @@
 from typing import Union, Tuple, List, Optional
 import torch
 import numpy as np
-import SimpleITK as sitk
 from ....data.subject import Subject
 from ....torchio import DATA, AFFINE, TypeTripletInt
 from ... import SpatialTransform
@@ -38,7 +37,7 @@ def bounds_function(self):
         raise NotImplementedError
 
     @staticmethod
-    def parse_bounds(bounds_parameters: TypeBounds) -> Tuple[int, ...]:
+    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
         try:
             bounds_parameters = tuple(bounds_parameters)
         except TypeError:
@@ -65,31 +64,3 @@ def parse_bounds(bounds_parameters: TypeBounds) -> Tuple[int, ...]:
             f' 3 or 6 integers, not {bounds_parameters}'
         )
         raise ValueError(message)
-
-    def apply_transform(self, sample: Subject) -> dict:
-        low = self.bounds_parameters[::2]
-        high = self.bounds_parameters[1::2]
-        for image in self.get_images(sample):
-            itk_image = image.as_sitk()
-            result = self._apply_bounds_function(itk_image, low, high)
-            data, affine = self.sitk_to_nib(result)
-            tensor = torch.from_numpy(data)
-            image[DATA] = tensor
-            image[AFFINE] = affine
-        return sample
-
-    def _apply_bounds_function(self, image, low, high):
-        num_components = image.GetNumberOfComponentsPerPixel()
-        if self.bounds_function == sitk.Crop or num_components == 1:
-            result = self.bounds_function(image, low, high)
-        else:  # padding not supported for vector images
-            components = [
-                sitk.VectorIndexSelectionCast(image, i)
-                for i in range(num_components)
-            ]
-            components_padded = [
-                self.bounds_function(component, low, high)
-                for component in components
-            ]
-            result = sitk.Compose(components_padded)
-        return result
diff --git a/torchio/transforms/preprocessing/spatial/crop.py b/torchio/transforms/preprocessing/spatial/crop.py
index 8191c0f5d..5636fd47d 100644
--- a/torchio/transforms/preprocessing/spatial/crop.py
+++ b/torchio/transforms/preprocessing/spatial/crop.py
@@ -1,5 +1,6 @@
-from typing import Callable
-import SimpleITK as sitk
+import numpy as np
+import nibabel as nib
+from ....torchio import TypeTripletInt, DATA, AFFINE
 from .bounds_transform import BoundsTransform
 
 
@@ -21,6 +22,17 @@ class Crop(BoundsTransform):
         If only one value :math:`n` is provided, then
         :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} = d_{ini} = d_{fin} = n`.
     """
-    @property
-    def bounds_function(self) -> Callable:
-        return sitk.Crop
+    def apply_transform(self, sample) -> dict:
+        low = self.bounds_parameters[::2]
+        high = self.bounds_parameters[1::2]
+        index_ini = low
+        index_fin = np.array(sample.spatial_shape) - high
+        for image in self.get_images(sample):
+            new_origin = nib.affines.apply_affine(image.affine, index_ini)
+            new_affine = image.affine.copy()
+            new_affine[:3, 3] = new_origin
+            i0, j0, k0 = index_ini
+            i1, j1, k1 = index_fin
+            image[DATA] = image[DATA][:, i0:i1, j0:j1, k0:k1].clone()
+            image[AFFINE] = new_affine
+        return sample
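`Crop` now slices the tensor directly and moves the affine's translation to the first kept voxel, so the patch stays anchored in world space. A minimal sketch of the origin update, with a made-up RAS affine:

```python
import numpy as np
import nibabel as nib

# 2 mm isotropic affine with an arbitrary origin; values are made up.
affine = np.array([
    [2., 0., 0.,  -90.],
    [0., 2., 0., -126.],
    [0., 0., 2.,  -72.],
    [0., 0., 0.,    1.],
])
index_ini = (10, 20, 30)  # first kept voxel along each spatial axis

# World coordinates of that voxel become the new origin.
new_origin = nib.affines.apply_affine(affine, index_ini)
new_affine = affine.copy()
new_affine[:3, 3] = new_origin
print(new_origin)  # [-70. -86. -12.]
```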
diff --git a/torchio/transforms/preprocessing/spatial/pad.py b/torchio/transforms/preprocessing/spatial/pad.py
index a7cae8a82..37797a91e 100644
--- a/torchio/transforms/preprocessing/spatial/pad.py
+++ b/torchio/transforms/preprocessing/spatial/pad.py
@@ -1,6 +1,12 @@
 from numbers import Number
 from typing import Callable, Union, List, Optional
-import SimpleITK as sitk
+
+import numpy as np
+import nibabel as nib
+import torch
+
+from ....torchio import DATA, AFFINE
+from ....data.subject import Subject
 from .bounds_transform import BoundsTransform, TypeBounds
 
 
@@ -22,35 +28,27 @@ class Pad(BoundsTransform):
         If only one value :math:`n` is provided, then
         :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} = d_{ini} = d_{fin} = n`.
-        padding_mode:
-            Type of padding. Should be one of:
-
-            - A number. Pad with a constant value.
-
-            - ``reflect`` Pad with reflection of image without repeating the last value on the edge.
-
-            - ``mirror`` Same as ``reflect``.
-
-            - ``edge`` Pad with the last value at the edge of the image.
-
-            - ``replicate`` Same as ``edge``.
-
-            - ``circular`` Pad with the wrap of the vector along the axis. The first values are used to pad the end and the end values are used to pad the beginning.
-
-            - ``wrap`` Same as ``circular``.
-
+        padding_mode: See possible modes in `NumPy docs`_. If it is a number,
+            the mode will be set to ``'constant'``.
         p: Probability that this transform will be applied.
         keys: See :py:class:`~torchio.transforms.Transform`.
+
+    .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
     """
-    PADDING_FUNCTIONS = {
-        'reflect': sitk.MirrorPad,
-        'mirror': sitk.MirrorPad,
-        'edge': sitk.ZeroFluxNeumannPad,
-        'replicate': sitk.ZeroFluxNeumannPad,
-        'circular': sitk.WrapPad,
-        'wrap': sitk.WrapPad,
-    }
+    PADDING_MODES = (
+        'empty',
+        'edge',
+        'wrap',
+        'constant',
+        'linear_ramp',
+        'maximum',
+        'mean',
+        'median',
+        'minimum',
+        'reflect',
+        'symmetric',
+    )
 
     def __init__(
             self,
@@ -59,17 +57,12 @@ def __init__(
             p: float = 1,
             keys: Optional[List[str]] = None,
             ):
-        """
-        padding_mode can be 'constant', 'reflect', 'replicate' or 'circular'.
-        See https://pytorch.org/docs/stable/nn.functional.html#pad for more
-        information about this transform.
-        """
         super().__init__(padding, p=p, keys=keys)
         self.padding_mode, self.fill = self.parse_padding_mode(padding_mode)
 
     @classmethod
     def parse_padding_mode(cls, padding_mode):
-        if padding_mode in cls.PADDING_FUNCTIONS:
+        if padding_mode in cls.PADDING_MODES:
             fill = None
         elif isinstance(padding_mode, Number):
             fill = padding_mode
@@ -77,21 +70,23 @@ def parse_padding_mode(cls, padding_mode):
         else:
             message = (
                 f'Padding mode "{padding_mode}" not valid. Valid options are'
-                f' {list(cls.PADDING_FUNCTIONS.keys())} or a number'
+                f' {list(cls.PADDING_MODES)} or a number'
             )
             raise KeyError(message)
         return padding_mode, fill
 
-    @property
-    def bounds_function(self) -> Callable:
-        if self.fill is not None:
-            function = _pad_with_fill(self.fill)
-        else:
-            function = self.PADDING_FUNCTIONS[self.padding_mode]
-        return function
-
-
-def _pad_with_fill(fill):
-    def wrapped(image, bounds1, bounds2):
-        return sitk.ConstantPad(image, bounds1, bounds2, fill)
-    return wrapped
+    def apply_transform(self, subject: Subject) -> Subject:
+        low = self.bounds_parameters[::2]
+        for image in self.get_images(subject):
+            new_origin = nib.affines.apply_affine(image.affine, -np.array(low))
+            new_affine = image.affine.copy()
+            new_affine[:3, 3] = new_origin
+            kwargs = dict(mode=self.padding_mode)
+            if self.padding_mode == 'constant':
+                kwargs['constant_values'] = self.fill
+            pad_params = self.bounds_parameters
+            paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]
+            padded = np.pad(image[DATA], paddings, **kwargs)
+            image[DATA] = torch.from_numpy(padded)
+            image[AFFINE] = new_affine
+        return subject
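`Pad.apply_transform` turns the six bounds into `np.pad` widths, with a zero pad for the leading channel axis. A toy run with made-up sizes (np.pad accepts the CPU torch tensor directly via the array protocol):

```python
import numpy as np
import torch

# Channels-first volume: axis 0 is the channel axis, so it gets (0, 0);
# the six bounds pad the W, H and D axes.
tensor = torch.ones(1, 4, 5, 6)
bounds = (1, 2, 3, 4, 5, 6)  # w_ini, w_fin, h_ini, h_fin, d_ini, d_fin

paddings = (0, 0), bounds[:2], bounds[2:4], bounds[4:]
padded = np.pad(tensor, paddings, mode='constant', constant_values=0)
print(torch.from_numpy(padded).shape)  # torch.Size([1, 7, 12, 17])
```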