From f8de82eb71e2e13903300d8486c07588724fe6e1 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 15:18:13 +0100 Subject: [PATCH 01/14] Use H,W,D for spatial dimensions --- torchio/data/image.py | 8 ++--- torchio/data/inference/grid_sampler.py | 10 +++--- torchio/data/sampler/sampler.py | 12 +++---- .../augmentation/intensity/random_swap.py | 6 ++-- .../transforms/preprocessing/spatial/crop.py | 18 +++++----- .../preprocessing/spatial/crop_or_pad.py | 12 +++---- .../transforms/preprocessing/spatial/pad.py | 18 +++++----- .../preprocessing/spatial/resample.py | 6 ++-- .../preprocessing/spatial/to_canonical.py | 6 ++-- torchio/transforms/transform.py | 4 +-- torchio/utils.py | 34 +++++++++---------- 11 files changed, 69 insertions(+), 65 deletions(-) diff --git a/torchio/data/image.py b/torchio/data/image.py index 641e615c5..752be7057 100644 --- a/torchio/data/image.py +++ b/torchio/data/image.py @@ -248,11 +248,11 @@ def axis_name_to_index(self, axis: str): raise ValueError('Axis must be a string') axis = axis[0].upper() - # Generally, TorchIO tensors are (C, D, H, W) + # Generally, TorchIO tensors are (C, H, W, D) if axis == 'H': - return -2 + return 1 elif axis == 'W': - return -1 + return 2 else: try: index = self.orientation.index(axis) @@ -414,7 +414,7 @@ def save(self, path, squeeze=True, channels_last=True): ) def is_2d(self) -> bool: - return self.shape[-3] == 1 + return self.shape[-1] == 1 def numpy(self) -> np.ndarray: """Get a NumPy array containing the image data.""" diff --git a/torchio/data/inference/grid_sampler.py b/torchio/data/inference/grid_sampler.py index 0ba46edc3..b06c58bc1 100644 --- a/torchio/data/inference/grid_sampler.py +++ b/torchio/data/inference/grid_sampler.py @@ -18,18 +18,18 @@ class GridSampler(PatchSampler, Dataset): Args: sample: Instance of :py:class:`~torchio.data.subject.Subject` from which patches will be extracted. - patch_size: Tuple of integers :math:`(d, h, w)` to generate patches + patch_size: Tuple of integers :math:`(h, w, d)` to generate patches of size :math:`d \times h \times w`. If a single number :math:`n` is provided, - :math:`d = h = w = n`. - patch_overlap: Tuple of even integers :math:`(d_o, h_o, w_o)` specifying + :math:`h = w = d = n`. + patch_overlap: Tuple of even integers :math:`(h_o, w_o, d_o)` specifying the overlap between patches for dense inference. If a single number - :math:`n` is provided, :math:`d_o = h_o = w_o = n`. + :math:`n` is provided, :math:`h_o = w_o = d_o = n`. padding_mode: Same as :attr:`padding_mode` in :py:class:`~torchio.transforms.Pad`. If ``None``, the volume will not be padded before sampling and patches at the border will not be cropped by the aggregator. Otherwise, the volume will be padded with - :math:`\left(\frac{d_o}{2}, \frac{h_o}{2}, \frac{w_o}{2}\right)` + :math:`\left(\frac{h_o}{2}, \frac{w_o}{2}, \frac{d_o}{2} \right)` on each side before sampling. If the sampler is passed to a :py:class:`~torchio.data.GridAggregator`, it will crop the output to its original size. diff --git a/torchio/data/sampler/sampler.py b/torchio/data/sampler/sampler.py index d96cf5210..0d94108d8 100644 --- a/torchio/data/sampler/sampler.py +++ b/torchio/data/sampler/sampler.py @@ -12,9 +12,9 @@ class PatchSampler: r"""Base class for TorchIO samplers. Args: - patch_size: Tuple of integers :math:`(d, h, w)` to generate patches - of size :math:`d \times h \times w`. - If a single number :math:`n` is provided, :math:`d = h = w = n`. + patch_size: Tuple of integers :math:`(h, w, d)` to generate patches + of size :math:`h \times w \times d`. + If a single number :math:`n` is provided, :math:`h = w = d = n`. """ def __init__(self, patch_size: TypePatchSize): patch_size_array = np.array(to_tuple(patch_size, length=3)) @@ -43,9 +43,9 @@ class RandomSampler(PatchSampler): r"""Base class for TorchIO samplers. Args: - patch_size: Tuple of integers :math:`(d, h, w)` to generate patches - of size :math:`d \times h \times w`. - If a single number :math:`n` is provided, :math:`d = h = w = n`. + patch_size: Tuple of integers :math:`(h, w, d)` to generate patches + of size :math:`h \times w \times d`. + If a single number :math:`n` is provided, :math:`h = w = d = n`. """ def __call__( self, diff --git a/torchio/transforms/augmentation/intensity/random_swap.py b/torchio/transforms/augmentation/intensity/random_swap.py index 822060f89..e5a0fc4b9 100644 --- a/torchio/transforms/augmentation/intensity/random_swap.py +++ b/torchio/transforms/augmentation/intensity/random_swap.py @@ -15,9 +15,9 @@ class RandomSwap(RandomTransform, IntensityTransform): `_. Args: - patch_size: Tuple of integers :math:`(d, h, w)` to swap patches - of size :math:`d \times h \times w`. - If a single number :math:`n` is provided, :math:`d = h = w = n`. + patch_size: Tuple of integers :math:`(h, w, d)` to swap patches + of size :math:`h \times w \times d`. + If a single number :math:`n` is provided, :math:`h = w = d = n`. num_iterations: Number of times that two patches will be swapped. p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. diff --git a/torchio/transforms/preprocessing/spatial/crop.py b/torchio/transforms/preprocessing/spatial/crop.py index ddcf20c3a..8b69ac428 100644 --- a/torchio/transforms/preprocessing/spatial/crop.py +++ b/torchio/transforms/preprocessing/spatial/crop.py @@ -8,17 +8,19 @@ class Crop(BoundsTransform): Args: cropping: Tuple - :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})` + :math:`(h_{ini}, h_{fin}, w_{ini}, w_{fin}, d_{ini}, d_{fin})` defining the number of values cropped from the edges of each axis. If the initial shape of the image is - :math:`D \times H \times W`, the final shape will be - :math:`(- d_{ini} + D - d_{fin}) \times (- h_{ini} + H - h_{fin}) \times (- w_{ini} + W - w_{fin})`. - If only three values :math:`(d, h, w)` are provided, then - :math:`d_{ini} = d_{fin} = d`, - :math:`h_{ini} = h_{fin} = h` and - :math:`w_{ini} = w_{fin} = w`. + :math:`H \times W \times D`, the final shape will be + :math:`(- h_{ini} + H - h_{fin}) \times (- w_{ini} + W - w_{fin}) + \times (- d_{ini} + D - d_{fin})`. + If only three values :math:`(h, w, d)` are provided, then + :math:`h_{ini} = h_{fin} = h`, + :math:`w_{ini} = w_{fin} = w` and + :math:`d_{ini} = d_{fin} = d`. If only one value :math:`n` is provided, then - :math:`d_{ini} = d_{fin} = h_{ini} = h_{fin} = w_{ini} = w_{fin} = n`. + :math:`h_{ini} = h_{fin} = w_{ini} = w_{fin} + = d_{ini} = d_{fin} = n`. """ @property diff --git a/torchio/transforms/preprocessing/spatial/crop_or_pad.py b/torchio/transforms/preprocessing/spatial/crop_or_pad.py index a40c34333..239b49c6c 100644 --- a/torchio/transforms/preprocessing/spatial/crop_or_pad.py +++ b/torchio/transforms/preprocessing/spatial/crop_or_pad.py @@ -18,8 +18,8 @@ class CropOrPad(BoundsTransform): physical positions of the voxels are maintained. Args: - target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is - provided, then :math:`D = H = W = N`. + target_shape: Tuple :math:`(H, W, D)`. If a single value :math:`N` is + provided, then :math:`H = W = D = N`. padding_mode: Same as :attr:`padding_mode` in :py:class:`~torchio.transforms.Pad`. mask_name: If ``None``, the centers of the input and output volumes @@ -112,11 +112,11 @@ def _get_six_bounds_parameters( r"""Compute bounds parameters for ITK filters. Args: - parameters: Tuple :math:`(d, h, w)` with the number of voxels to be + parameters: Tuple :math:`(h, w, d)` with the number of voxels to be cropped or padded. Returns: - Tuple :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})`, + Tuple :math:`(h_{ini}, h_{fin}, w_{ini}, w_{fin}, d_{ini}, d_{fin})`, where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`. @@ -158,8 +158,8 @@ def _compute_center_crop_or_pad( sample: Subject, ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]: source_shape = self._get_sample_shape(sample) - # The parent class turns the 3-element shape tuple (d, h, w) - # into a 6-element bounds tuple (d, d, h, h, w, w) + # The parent class turns the 3-element shape tuple (h, w, d) + # into a 6-element bounds tuple (h, h, w, w, d, d) target_shape = np.array(self.bounds_parameters[::2]) parameters = self._compute_cropping_padding_from_shapes( source_shape, target_shape) diff --git a/torchio/transforms/preprocessing/spatial/pad.py b/torchio/transforms/preprocessing/spatial/pad.py index a086f9fdb..eecc1ad05 100644 --- a/torchio/transforms/preprocessing/spatial/pad.py +++ b/torchio/transforms/preprocessing/spatial/pad.py @@ -9,17 +9,19 @@ class Pad(BoundsTransform): Args: padding: Tuple - :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})` + :math:`(h_{ini}, h_{fin}, w_{ini}, w_{fin}, d_{ini}, d_{fin})` defining the number of values padded to the edges of each axis. If the initial shape of the image is - :math:`D \times H \times W`, the final shape will be - :math:`(d_{ini} + D + d_{fin}) \times (h_{ini} + H + h_{fin}) \times (w_{ini} + W + w_{fin})`. - If only three values :math:`(d, h, w)` are provided, then - :math:`d_{ini} = d_{fin} = d`, - :math:`h_{ini} = h_{fin} = h` and - :math:`w_{ini} = w_{fin} = w`. + :math:`H \times W \times D`, the final shape will be + :math:`(h_{ini} + H + h_{fin}) \times (w_{ini} + W + w_{fin}) + \times (d_{ini} + D + d_{fin})`. + If only three values :math:`(h, w, d)` are provided, then + :math:`h_{ini} = h_{fin} = h`, + :math:`w_{ini} = w_{fin} = w` and + :math:`d_{ini} = d_{fin} = d`. If only one value :math:`n` is provided, then - :math:`d_{ini} = d_{fin} = h_{ini} = h_{fin} = w_{ini} = w_{fin} = n`. + :math:`h_{ini} = h_{fin} = w_{ini} = w_{fin} = + d_{ini} = d_{fin} = n`. padding_mode: Type of padding. Should be one of: diff --git a/torchio/transforms/preprocessing/spatial/resample.py b/torchio/transforms/preprocessing/spatial/resample.py index e0817ab07..8f83b340b 100644 --- a/torchio/transforms/preprocessing/spatial/resample.py +++ b/torchio/transforms/preprocessing/spatial/resample.py @@ -26,8 +26,8 @@ class Resample(SpatialTransform): """Change voxel spacing by resampling. Args: - target: Tuple :math:`(s_d, s_h, s_w)`. If only one value - :math:`n` is specified, then :math:`s_d = s_h = s_w = n`. + target: Tuple :math:`(s_h, s_w, s_d)`. If only one value + :math:`n` is specified, then :math:`s_h = s_w = s_d = n`. If a string or :py:class:`~pathlib.Path` is given, all images will be resampled using the image with that name as reference or found at the path. @@ -208,7 +208,7 @@ def apply_transform(self, sample: Subject) -> dict: @staticmethod def apply_resample( - tensor: torch.Tensor, # (C, D, H, W) + tensor: torch.Tensor, affine: np.ndarray, interpolation_order: int, target_spacing: Optional[Tuple[float, float, float]] = None, diff --git a/torchio/transforms/preprocessing/spatial/to_canonical.py b/torchio/transforms/preprocessing/spatial/to_canonical.py index db2c20aa2..7161cead4 100644 --- a/torchio/transforms/preprocessing/spatial/to_canonical.py +++ b/torchio/transforms/preprocessing/spatial/to_canonical.py @@ -32,14 +32,14 @@ def apply_transform(self, sample: Subject) -> dict: affine = image[AFFINE] if nib.aff2axcodes(affine) == tuple('RAS'): continue - array = image[DATA].numpy()[np.newaxis] # (1, C, D, H, W) + array = image[DATA].numpy()[np.newaxis] # (1, C, H, W, D) # NIfTI images should have channels in 5th dimension - array = array.transpose(2, 3, 4, 0, 1) # (D, H, W, 1, C) + array = array.transpose(2, 3, 4, 0, 1) # (H, W, D, 1, C) nii = nib.Nifti1Image(array, affine) reoriented = nib.as_closest_canonical(nii) array = reoriented.get_fdata(dtype=np.float32) # https://github.com/facebookresearch/InferSent/issues/99#issuecomment-446175325 - array = array.copy().transpose(3, 4, 0, 1, 2) # (1, C, D, H, W) + array = array.copy().transpose(3, 4, 0, 1, 2) # (1, C, H, W, D) image[DATA] = torch.from_numpy(array[0]) image[AFFINE] = reoriented.affine return sample diff --git a/torchio/transforms/transform.py b/torchio/transforms/transform.py index 1a0b619f6..7885106e2 100644 --- a/torchio/transforms/transform.py +++ b/torchio/transforms/transform.py @@ -50,8 +50,8 @@ def __call__(self, data: Union[Subject, torch.Tensor, np.ndarray]): Args: data: Instance of :py:class:`~torchio.Subject`, 4D :py:class:`torch.Tensor` or 4D NumPy array with dimensions - :math:`(C, D, H, W)`, where :math:`C` is the number of channels - and :math:`D, H, W` are the spatial dimensions. If the input is + :math:`(C, H, W, D)`, where :math:`C` is the number of channels + and :math:`H, W, D` are the spatial dimensions. If the input is a tensor, the affine matrix is an identity and a tensor will be also returned. """ diff --git a/torchio/utils.py b/torchio/utils.py index 4376d82e7..17158febb 100644 --- a/torchio/utils.py +++ b/torchio/utils.py @@ -180,17 +180,17 @@ def nib_to_sitk( if data.ndim != 4: raise ValueError(f'Input must be 4D, but has shape {tuple(data.shape)}') # Possibilities - # (1, 1, h, w) - # (c, 1, h, w) - # (1, d, h, w) - # (c, d, h, w) + # (1, h, w, 1) + # (c, h, w, 1) + # (1, h, w, 1) + # (c, h, w, d) array = np.asarray(data) affine = np.asarray(affine).astype(np.float64) is_multichannel = array.shape[0] > 1 and not force_4d - is_2d = array.shape[1] == 1 and not force_3d + is_2d = array.shape[3] == 1 and not force_3d if is_2d: - array = array[:, 0, :, :] + array = array[..., 0] if not is_multichannel and not force_4d: array = array[0] array = array.transpose() # (W, H, D, C) or (W, H, D) @@ -199,7 +199,7 @@ def nib_to_sitk( rotation, spacing = get_rotation_and_spacing_from_affine(affine) origin = np.dot(FLIP_XY, affine[:3, 3]) direction = np.dot(FLIP_XY, rotation) - if is_2d: # ignore first dimension if 2D (1, 1, H, W) + if is_2d: # ignore first dimension if 2D (1, H, W, 1) direction = direction[1:3, 1:3] image.SetOrigin(origin) # should I add a 4th value if force_4d? image.SetSpacing(spacing) @@ -229,7 +229,7 @@ def sitk_to_nib( origin = image.GetOrigin() if len(direction) == 9: rotation = direction.reshape(3, 3) - elif len(direction) == 4: # ignore first dimension if 2D (1, 1, H, W) + elif len(direction) == 4: # ignore first dimension if 2D (1, H, W, 1) rotation_2d = direction.reshape(2, 2) rotation = np.eye(3) rotation[1:3, 1:3] = rotation_2d @@ -266,26 +266,26 @@ def ensure_4d( if tensor.shape[-1] == 1: tensor = tensor[..., 0, :] if num_dimensions == 4: # assume 3D multichannel - if channels_last: # (D, H, W, C) - tensor = tensor.permute(3, 0, 1, 2) # (C, D, H, W) + if channels_last: # (H, W, D, C) + tensor = tensor.permute(3, 0, 1, 2) # (C, H, W, C) elif num_dimensions == 2: # assume 2D monochannel (H, W) - tensor = tensor[np.newaxis, np.newaxis] # (1, 1, H, W) + tensor = tensor[np.newaxis, np.newaxis] # (1, H, W, 1) elif num_dimensions == 3: # 2D multichannel or 3D monochannel? if num_spatial_dims == 2: if channels_last: # (H, W, C) tensor = tensor.permute(2, 0, 1) # (C, H, W) - tensor = tensor[:, np.newaxis] # (C, 1, H, W) - elif num_spatial_dims == 3: # (D, H, W) - tensor = tensor[np.newaxis] # (1, D, H, W) + tensor = tensor[..., np.newaxis] # (C, H, W, 1) + elif num_spatial_dims == 3: # (H, W, D) + tensor = tensor[np.newaxis] # (1, H, W, D) else: # try to guess shape = tensor.shape maybe_rgb = 3 in (shape[0], shape[-1]) if maybe_rgb: if shape[-1] == 3: # (H, W, 3) tensor = tensor.permute(2, 0, 1) # (3, H, W) - tensor = tensor[:, np.newaxis] # (3, 1, H, W) - else: # (D, H, W) - tensor = tensor[np.newaxis] # (1, D, H, W) + tensor = tensor[..., np.newaxis] # (3, H, W, 1) + else: # (H, W, D) + tensor = tensor[np.newaxis] # (1, H, W, D) else: message = ( f'{num_dimensions}D images not supported yet. Please create an' From 6da1148dfef978cab46b1097776ae66583a69c07 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 18:15:48 +0100 Subject: [PATCH 02/14] Fix 2D shapes --- tests/data/test_image.py | 8 ++++---- tests/test_utils.py | 16 ++++++++-------- tests/transforms/test_transforms.py | 6 +++--- tests/utils.py | 3 ++- torchio/data/image.py | 8 ++++---- torchio/data/subject.py | 2 +- .../augmentation/spatial/random_flip.py | 5 +---- torchio/utils.py | 14 +++++++------- 8 files changed, 30 insertions(+), 32 deletions(-) diff --git a/tests/data/test_image.py b/tests/data/test_image.py index 1d0841fc6..ce302f335 100644 --- a/tests/data/test_image.py +++ b/tests/data/test_image.py @@ -112,7 +112,7 @@ def test_nans_tensor(self): def test_nans_file(self): image = ScalarImage(self.get_image_path('repr_test', add_nans=True)) with self.assertWarns(UserWarning): - image._load() + image.load() def test_get_center(self): tensor = torch.rand(1, 3, 3, 3) @@ -139,19 +139,19 @@ def test_with_a_list_of_images_with_different_shapes(self): path2 = self.get_image_path('path2', shape=(7, 5, 5)) image = ScalarImage(path=[path1, path2]) with self.assertRaises(RuntimeError): - image._load() + image.load() def test_with_a_list_of_images_with_different_affines(self): path1 = self.get_image_path('path1', spacing=(1, 1, 1)) path2 = self.get_image_path('path2', spacing=(1, 2, 1)) image = ScalarImage(path=[path1, path2]) with self.assertWarns(RuntimeWarning): - image._load() + image.load() def test_with_a_list_of_2d_paths(self): shape = (5, 5) path1 = self.get_image_path('path1', shape=shape) path2 = self.get_image_path('path2', shape=shape) image = ScalarImage(path=[path1, path2]) - self.assertEqual(image.shape, (2, 1, 5, 5)) + self.assertEqual(image.shape, (2, 5, 5, 1)) self.assertEqual(image[STEM], ['path1', 'path2']) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8a062f355..00b637a9f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -53,7 +53,7 @@ def test_apply_transform_to_file(self): ) def test_sitk_to_nib(self): - data = np.random.rand(10, 10) + data = np.random.rand(10, 12) image = sitk.GetImageFromArray(data) tensor, affine = sitk_to_nib(image) self.assertAlmostEqual(data.sum(), tensor.sum()) @@ -64,36 +64,36 @@ def setUp(self): super().setUp() self.affine = np.eye(4) - def test_wrong_dims(self): + def test_wrong_num_dims(self): with self.assertRaises(ValueError): nib_to_sitk(np.random.rand(10, 10), self.affine) def test_2d_single(self): - data = np.random.rand(1, 1, 10, 12) + data = np.random.rand(1, 10, 12, 1) image = nib_to_sitk(data, self.affine) assert image.GetDimension() == 2 assert image.GetSize() == (10, 12) assert image.GetNumberOfComponentsPerPixel() == 1 def test_2d_multi(self): - data = np.random.rand(5, 1, 10, 12) + data = np.random.rand(5, 10, 12, 1) image = nib_to_sitk(data, self.affine) assert image.GetDimension() == 2 assert image.GetSize() == (10, 12) assert image.GetNumberOfComponentsPerPixel() == 5 def test_2d_3d_single(self): - data = np.random.rand(1, 1, 10, 12) + data = np.random.rand(1, 10, 12, 1) image = nib_to_sitk(data, self.affine, force_3d=True) assert image.GetDimension() == 3 - assert image.GetSize() == (1, 10, 12) + assert image.GetSize() == (10, 12, 1) assert image.GetNumberOfComponentsPerPixel() == 1 def test_2d_3d_multi(self): - data = np.random.rand(5, 1, 10, 12) + data = np.random.rand(5, 10, 12, 1) image = nib_to_sitk(data, self.affine, force_3d=True) assert image.GetDimension() == 3 - assert image.GetSize() == (1, 10, 12) + assert image.GetSize() == (10, 12, 1) assert image.GetNumberOfComponentsPerPixel() == 5 def test_3d_single(self): diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 9a6230612..2d5339992 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -14,11 +14,11 @@ def get_transform(self, channels, is_3d=True, labels=True): landmarks_dict = { channel: np.linspace(0, 100, 13) for channel in channels } - disp = 1 if is_3d else (0.01, 1, 1) + disp = 1 if is_3d else (1, 1, 0.01) elastic = torchio.RandomElasticDeformation(max_displacement=disp) - cp_args = (9, 21, 30) if is_3d else (1, 21, 30) + cp_args = (9, 21, 30) if is_3d else (21, 30, 1) flip_axes = (0, 1, 2) if is_3d else (0, 1) - swap_patch = (2, 3, 4) if is_3d else (1, 3, 4) + swap_patch = (2, 3, 4) if is_3d else (3, 4, 1) pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6) crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4) transforms = [ diff --git a/tests/utils.py b/tests/utils.py index 541125b10..50e780881 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -63,7 +63,7 @@ def setUp(self): def make_2d(self, sample): sample = copy.deepcopy(sample) for image in sample.get_images(intensity_only=False): - image[DATA] = image[DATA][:, 0:1, ...] + image[DATA] = image[DATA][..., :1] return sample def make_4d(self, sample): @@ -137,6 +137,7 @@ def get_image_path( components=1, add_nans=False ): + shape = (*shape, 1) if len(shape) == 2 else shape data = np.random.rand(components, *shape) if binary: data = (data > 0.5).astype(np.uint8) diff --git a/torchio/data/image.py b/torchio/data/image.py index 752be7057..b9116d624 100644 --- a/torchio/data/image.py +++ b/torchio/data/image.py @@ -61,7 +61,7 @@ class Image(dict): :py:class:`~torchio.data.sampler.weighted.WeightedSampler`. tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D :py:class:`torch.Tensor` or NumPy array with dimensions - :math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess + :math:`(C, H, W, D)`. If it is not 4D, TorchIO will try to guess the dimensions meanings. If 2D, the shape will be interpreted as :math:`(H, W)`. If 3D, the number of spatial dimensions should be determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims` @@ -174,7 +174,7 @@ def __repr__(self): def __getitem__(self, item): if item in (DATA, AFFINE): if item not in self: - self._load() + self.load() return super().__getitem__(item) def __array__(self): @@ -350,11 +350,11 @@ def parse_affine(affine: np.ndarray) -> np.ndarray: raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}') return affine - def _load(self) -> None: + def load(self) -> None: r"""Load the image from disk. Returns: - Tuple containing a 4D tensor of size :math:`(C, D, H, W)` and a 2D + Tuple containing a 4D tensor of size :math:`(C, H, W, D)` and a 2D :math:`4 \times 4` affine matrix to convert voxel indices to world coordinates. """ diff --git a/torchio/data/subject.py b/torchio/data/subject.py index f3dc13358..60365deb8 100644 --- a/torchio/data/subject.py +++ b/torchio/data/subject.py @@ -148,7 +148,7 @@ def add_transform( def load(self): for image in self.get_images(intensity_only=False): - image._load() + image.load() def crop(self, index_ini, index_fin): result_dict = {} diff --git a/torchio/transforms/augmentation/spatial/random_flip.py b/torchio/transforms/augmentation/spatial/random_flip.py index bc97e3e41..d1461a836 100644 --- a/torchio/transforms/augmentation/spatial/random_flip.py +++ b/torchio/transforms/augmentation/spatial/random_flip.py @@ -19,6 +19,7 @@ class RandomFlip(RandomTransform, SpatialTransform): ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``, ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or ``'i'`` (inferior). Only the first letter of the string will be + used. If the image is 2D, ``'Height'`` and ``'Width'`` may be used. flip_probability: Probability that the image will be flipped. This is computed on a per-axis basis. @@ -33,10 +34,6 @@ class RandomFlip(RandomTransform, SpatialTransform): .. tip:: It is handy to specify the axes as anatomical labels when the image orientation is not known. - - .. warning:: Note that height and width of 2D images correspond to axes - ``1`` and ``2`` respectively, as TorchIO images are generally considered - to have 3 spatial dimensions. """ def __init__( diff --git a/torchio/utils.py b/torchio/utils.py index 17158febb..ad2f26fe2 100644 --- a/torchio/utils.py +++ b/torchio/utils.py @@ -200,14 +200,14 @@ def nib_to_sitk( origin = np.dot(FLIP_XY, affine[:3, 3]) direction = np.dot(FLIP_XY, rotation) if is_2d: # ignore first dimension if 2D (1, H, W, 1) - direction = direction[1:3, 1:3] + direction = direction[:2, :2] image.SetOrigin(origin) # should I add a 4th value if force_4d? image.SetSpacing(spacing) image.SetDirection(direction.flatten()) if data.ndim == 4: assert image.GetNumberOfComponentsPerPixel() == data.shape[0] num_spatial_dims = 2 if is_2d else 3 - assert image.GetSize() == data.shape[-num_spatial_dims:] + assert image.GetSize() == data.shape[1: 1 + num_spatial_dims] return image @@ -223,7 +223,7 @@ def sitk_to_nib( if not keepdim: data = ensure_4d(data, False, num_spatial_dims=input_spatial_dims) assert data.shape[0] == num_components - assert data.shape[-input_spatial_dims:] == image.GetSize() + assert data.shape[1: 1 + input_spatial_dims] == image.GetSize() spacing = np.array(image.GetSpacing()) direction = np.array(image.GetDirection()) origin = image.GetOrigin() @@ -232,9 +232,9 @@ def sitk_to_nib( elif len(direction) == 4: # ignore first dimension if 2D (1, H, W, 1) rotation_2d = direction.reshape(2, 2) rotation = np.eye(3) - rotation[1:3, 1:3] = rotation_2d - spacing = 1, *spacing - origin = 0, *origin + rotation[:2, :2] = rotation_2d + spacing = *spacing, 1 + origin = *origin, 0 rotation = np.dot(FLIP_XY, rotation) rotation_zoom = rotation * spacing translation = np.dot(FLIP_XY, origin) @@ -269,7 +269,7 @@ def ensure_4d( if channels_last: # (H, W, D, C) tensor = tensor.permute(3, 0, 1, 2) # (C, H, W, C) elif num_dimensions == 2: # assume 2D monochannel (H, W) - tensor = tensor[np.newaxis, np.newaxis] # (1, H, W, 1) + tensor = tensor[np.newaxis, ..., np.newaxis] # (1, H, W, 1) elif num_dimensions == 3: # 2D multichannel or 3D monochannel? if num_spatial_dims == 2: if channels_last: # (H, W, C) From 6c64d103ba60555cc4ae9ba07e394e74fb4eafbd Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 18:25:53 +0100 Subject: [PATCH 03/14] Transpose 2D images before ITK reading and writing --- torchio/data/io.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchio/data/io.py b/torchio/data/io.py index b4b373cde..407a61d8f 100644 --- a/torchio/data/io.py +++ b/torchio/data/io.py @@ -37,11 +37,16 @@ def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]: return tensor, affine -def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]: +def _read_sitk( + path: TypePath, + transpose_2d: bool = True, + ) -> Tuple[torch.Tensor, np.ndarray]: if Path(path).is_dir(): # assume DICOM image = _read_dicom(path) else: image = sitk.ReadImage(str(path)) + if image.GetDimension() == 2 and transpose_2d: + image = sitk.PermuteAxes(image, (1, 0)) data, affine = sitk_to_nib(image, keepdim=True) if data.dtype != np.float32: data = data.astype(np.float32) @@ -115,6 +120,7 @@ def _write_sitk( path: TypePath, squeeze: bool = True, use_compression: bool = True, + transpose_2d: bool = True, ) -> None: assert tensor.ndim == 4 path = Path(path) @@ -122,6 +128,8 @@ def _write_sitk( warnings.warn(f'Casting to uint 8 before saving to {path}') tensor = tensor.numpy().astype(np.uint8) image = nib_to_sitk(tensor, affine, squeeze=squeeze) + if image.GetDimension() == 2 and transpose_2d: + image = sitk.PermuteAxes(image, (1, 0)) sitk.WriteImage(image, str(path), use_compression) From 43aa81181f222195dad279d3fe76590e77fd40fa Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 18:47:24 +0100 Subject: [PATCH 04/14] Add weight and height properties to image --- torchio/data/image.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchio/data/image.py b/torchio/data/image.py index b9116d624..aa6ce2a58 100644 --- a/torchio/data/image.py +++ b/torchio/data/image.py @@ -217,6 +217,21 @@ def shape(self) -> Tuple[int, int, int, int]: def spatial_shape(self) -> TypeTripletInt: return self.shape[1:] + def check_is_2d(self): + if not self.is_2d(): + message = f'Image is not 2D. Spatial shape: {self.spatial_shape}' + raise RuntimeError(message) + + @property + def height(self) -> int: + self.check_is_2d() + return self.spatial_shape[0] + + @property + def width(self) -> int: + self.check_is_2d() + return self.spatial_shape[1] + @property def orientation(self): return nib.aff2axcodes(self.affine) From 2b295e3af16babc89feec02140a4ebf2f46f1d98 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 18:47:36 +0100 Subject: [PATCH 05/14] Fix transforms for 2D images --- .../augmentation/intensity/random_ghosting.py | 2 +- .../augmentation/intensity/random_motion.py | 15 +++++++-------- .../augmentation/spatial/random_affine.py | 4 ++-- .../spatial/random_elastic_deformation.py | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/torchio/transforms/augmentation/intensity/random_ghosting.py b/torchio/transforms/augmentation/intensity/random_ghosting.py index 44fc4d266..118b90be2 100644 --- a/torchio/transforms/augmentation/intensity/random_ghosting.py +++ b/torchio/transforms/augmentation/intensity/random_ghosting.py @@ -92,7 +92,7 @@ def apply_transform(self, sample: Subject) -> dict: for image_name, image in self.get_images_dict(sample).items(): transformed_tensors = [] is_2d = image.is_2d() - axes = [a for a in self.axes if a != 0] if is_2d else self.axes + axes = [a for a in self.axes if a != 2] if is_2d else self.axes for channel_idx, tensor in enumerate(image[DATA]): params = self.get_params( self.num_ghosts_range, diff --git a/torchio/transforms/augmentation/intensity/random_motion.py b/torchio/transforms/augmentation/intensity/random_motion.py index 21ddbc115..0701b64c2 100644 --- a/torchio/transforms/augmentation/intensity/random_motion.py +++ b/torchio/transforms/augmentation/intensity/random_motion.py @@ -76,15 +76,14 @@ def __init__( def apply_transform(self, sample: Subject) -> dict: random_parameters_images_dict = {} - for image_name, image_dict in self.get_images_dict(sample).items(): + for image_name, image in self.get_images_dict(sample).items(): result_arrays = [] - for channel_idx, data in enumerate(image_dict[DATA]): - is_2d = data.shape[-3] == 1 + for channel_idx, data in enumerate(image[DATA]): params = self.get_params( self.degrees_range, self.translation_range, self.num_transforms, - is_2d=is_2d, + is_2d=image.is_2d(), ) times_params, degrees_params, translation_params = params random_parameters_dict = { @@ -96,7 +95,7 @@ def apply_transform(self, sample: Subject) -> dict: random_parameters_images_dict[key] = random_parameters_dict image = nib_to_sitk( data[np.newaxis], - image_dict[AFFINE], + image[AFFINE], force_3d=True, ) transforms = self.get_rigid_transforms( @@ -112,7 +111,7 @@ def apply_transform(self, sample: Subject) -> dict: ) result_arrays.append(data) result = np.stack(result_arrays) - image_dict[DATA] = torch.from_numpy(result) + image[DATA] = torch.from_numpy(result) sample.add_transform(self, random_parameters_images_dict) return sample @@ -130,8 +129,8 @@ def get_params( translation_params = get_params_array( translation_range, num_transforms) if is_2d: # imagine sagittal (1, A, S) - degrees_params[:, -2:] = 0 # rotate around R axis only - translation_params[:, 0] = 0 # translate in AS plane only + degrees_params[:, :-1] = 0 # rotate around Z axis only + translation_params[:, 2] = 0 # translate in XY plane only step = 1 / (num_transforms + 1) times = torch.arange(0, 1, step)[1:] noise = torch.FloatTensor(num_transforms) diff --git a/torchio/transforms/augmentation/spatial/random_affine.py b/torchio/transforms/augmentation/spatial/random_affine.py index d561b2d4b..15b1a8f25 100644 --- a/torchio/transforms/augmentation/spatial/random_affine.py +++ b/torchio/transforms/augmentation/spatial/random_affine.py @@ -172,8 +172,8 @@ def apply_transform(self, sample: Subject) -> dict: interpolation = self.interpolation if image.is_2d(): - scaling_params[0] = 1 - rotation_params[-2:] = 0 + scaling_params[2] = 1 + rotation_params[:-1] = 0 if self.use_image_center: center = image.get_center(lps=True) diff --git a/torchio/transforms/augmentation/spatial/random_elastic_deformation.py b/torchio/transforms/augmentation/spatial/random_elastic_deformation.py index c6ea4a942..2154b105a 100644 --- a/torchio/transforms/augmentation/spatial/random_elastic_deformation.py +++ b/torchio/transforms/augmentation/spatial/random_elastic_deformation.py @@ -227,7 +227,7 @@ def apply_transform(self, sample: Subject) -> dict: else: interpolation = self.interpolation if image.is_2d(): - bspline_params[..., -3] = 0 # no displacement in LR axis + bspline_params[..., -1] = 0 # no displacement in IS axis image[DATA] = self.apply_bspline_transform( image[DATA], image[AFFINE], From 9319a09c2d1821d08d7d08591db5ef93f07aa823 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 19:03:00 +0100 Subject: [PATCH 06/14] Force 2D images suffixes --- tests/data/test_image.py | 4 ++-- tests/utils.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/data/test_image.py b/tests/data/test_image.py index ce302f335..8a706c9d5 100644 --- a/tests/data/test_image.py +++ b/tests/data/test_image.py @@ -150,8 +150,8 @@ def test_with_a_list_of_images_with_different_affines(self): def test_with_a_list_of_2d_paths(self): shape = (5, 5) - path1 = self.get_image_path('path1', shape=shape) - path2 = self.get_image_path('path2', shape=shape) + path1 = self.get_image_path('path1', shape=shape, suffix='.nii') + path2 = self.get_image_path('path2', shape=shape, suffix='.img') image = ScalarImage(path=[path1, path2]) self.assertEqual(image.shape, (2, 5, 5, 1)) self.assertEqual(image[STEM], ['path1', 'path2']) diff --git a/tests/utils.py b/tests/utils.py index 50e780881..f1cfb39ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -135,7 +135,8 @@ def get_image_path( shape=(10, 20, 30), spacing=(1, 1, 1), components=1, - add_nans=False + add_nans=False, + suffix=None, ): shape = (*shape, 1) if len(shape) == 2 else shape data = np.random.rand(components, *shape) @@ -144,7 +145,8 @@ def get_image_path( if add_nans: data[:] = np.nan affine = np.diag((*spacing, 1)) - suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img')) + if suffix is None: + suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img')) path = self.dir / f'{stem}{suffix}' if np.random.rand() > 0.5: path = str(path) From 620ef793e378c48cf77f15646bf3f99be701978e Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 19:03:07 +0100 Subject: [PATCH 07/14] Fix variable name --- torchio/transforms/augmentation/intensity/random_motion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchio/transforms/augmentation/intensity/random_motion.py b/torchio/transforms/augmentation/intensity/random_motion.py index 0701b64c2..f22f66fa2 100644 --- a/torchio/transforms/augmentation/intensity/random_motion.py +++ b/torchio/transforms/augmentation/intensity/random_motion.py @@ -93,7 +93,7 @@ def apply_transform(self, sample: Subject) -> dict: } key = f'{image_name}_channel_{channel_idx}' random_parameters_images_dict[key] = random_parameters_dict - image = nib_to_sitk( + sitk_image = nib_to_sitk( data[np.newaxis], image[AFFINE], force_3d=True, @@ -101,10 +101,10 @@ def apply_transform(self, sample: Subject) -> dict: transforms = self.get_rigid_transforms( degrees_params, translation_params, - image, + sitk_image, ) data = self.add_artifact( - image, + sitk_image, transforms, times_params, self.image_interpolation, From 7d7eaf024bfe27d89bc1f88e1229e395f219cbed Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 19:39:46 +0100 Subject: [PATCH 08/14] Print affine matrices if different --- torchio/data/image.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchio/data/image.py b/torchio/data/image.py index aa6ce2a58..188a77fd1 100644 --- a/torchio/data/image.py +++ b/torchio/data/image.py @@ -392,7 +392,13 @@ def load(self) -> None: warnings.warn(f'NaNs found in file "{path}"') if not np.array_equal(affine, new_affine): - message = 'Files have different affine matrices' + message = ( + 'Files have different affine matrices.' + f'\nMatrix of {paths[0]}:' + f'\n{affine}' + f'\nMatrix of {path}:' + f'\n{new_affine}' + ) warnings.warn(message, RuntimeWarning) if not tensor.shape[1:] == new_tensor.shape[1:]: From 593ae55e78dbc5b5a7a41b46474958e80a623b3b Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 20:04:16 +0100 Subject: [PATCH 09/14] Add .hdr to suffixes --- tests/data/test_image.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/data/test_image.py b/tests/data/test_image.py index 8a706c9d5..3594fc7ce 100644 --- a/tests/data/test_image.py +++ b/tests/data/test_image.py @@ -152,6 +152,7 @@ def test_with_a_list_of_2d_paths(self): shape = (5, 5) path1 = self.get_image_path('path1', shape=shape, suffix='.nii') path2 = self.get_image_path('path2', shape=shape, suffix='.img') - image = ScalarImage(path=[path1, path2]) - self.assertEqual(image.shape, (2, 5, 5, 1)) - self.assertEqual(image[STEM], ['path1', 'path2']) + path3 = self.get_image_path('path3', shape=shape, suffix='.hdr') + image = ScalarImage(path=[path1, path2, path3]) + self.assertEqual(image.shape, (3, 5, 5, 1)) + self.assertEqual(image[STEM], ['path1', 'path2', 'path3']) From 051237aefb8ba6446f3aa3918fe7dda02fe607b9 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 20:05:04 +0100 Subject: [PATCH 10/14] Use NiBabel to write .img files --- torchio/data/io.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchio/data/io.py b/torchio/data/io.py index 407a61d8f..616194faf 100644 --- a/torchio/data/io.py +++ b/torchio/data/io.py @@ -108,10 +108,16 @@ def _write_nibabel( if channels_last: tensor = tensor.permute(1, 2, 3, 0) tensor = tensor.squeeze() if squeeze else tensor - nii = nib.Nifti1Image(np.asarray(tensor), affine) - nii.header['qform_code'] = 1 - nii.header['sform_code'] = 0 - nii.to_filename(str(path)) + suffix = Path(path).suffix + if '.nii' in suffix: + img = nib.Nifti1Image(np.asarray(tensor), affine) + elif '.hdr' in suffix or '.img' in suffix: + img = nib.Nifti1Pair(np.asarray(tensor), affine) + else: + raise nib.loadsave.ImageFileError + img.header['qform_code'] = 1 + img.header['sform_code'] = 0 + img.to_filename(str(path)) def _write_sitk( From cd4d40cb3e4998a4bcd68c637bfb94518ffc3a41 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 20:16:29 +0100 Subject: [PATCH 11/14] Add test for axis names --- tests/data/test_image.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/data/test_image.py b/tests/data/test_image.py index 3594fc7ce..0ff7a94c8 100644 --- a/tests/data/test_image.py +++ b/tests/data/test_image.py @@ -149,10 +149,18 @@ def test_with_a_list_of_images_with_different_affines(self): image.load() def test_with_a_list_of_2d_paths(self): - shape = (5, 5) + shape = (5, 6) path1 = self.get_image_path('path1', shape=shape, suffix='.nii') path2 = self.get_image_path('path2', shape=shape, suffix='.img') path3 = self.get_image_path('path3', shape=shape, suffix='.hdr') image = ScalarImage(path=[path1, path2, path3]) self.assertEqual(image.shape, (3, 5, 5, 1)) self.assertEqual(image[STEM], ['path1', 'path2', 'path3']) + + def test_axis_name_2d(self): + path = self.get_image_path('im2d', shape=(5, 6)) + image = ScalarImage(path) + height_idx = image.axis_name_to_index['h'] + width_idx = image.axis_name_to_index['w'] + self.assertEqual(image.height, image.shape[height_idx]) + self.assertEqual(image.width, image.shape[width_idx]) From 425f92f1539875c8c25f33e3af4dc2010dbf2b78 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 20:18:20 +0100 Subject: [PATCH 12/14] Fix Image tests --- tests/data/test_image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/data/test_image.py b/tests/data/test_image.py index 0ff7a94c8..ab86e90d5 100644 --- a/tests/data/test_image.py +++ b/tests/data/test_image.py @@ -154,13 +154,13 @@ def test_with_a_list_of_2d_paths(self): path2 = self.get_image_path('path2', shape=shape, suffix='.img') path3 = self.get_image_path('path3', shape=shape, suffix='.hdr') image = ScalarImage(path=[path1, path2, path3]) - self.assertEqual(image.shape, (3, 5, 5, 1)) + self.assertEqual(image.shape, (3, 5, 6, 1)) self.assertEqual(image[STEM], ['path1', 'path2', 'path3']) def test_axis_name_2d(self): path = self.get_image_path('im2d', shape=(5, 6)) image = ScalarImage(path) - height_idx = image.axis_name_to_index['h'] - width_idx = image.axis_name_to_index['w'] + height_idx = image.axis_name_to_index('h') + width_idx = image.axis_name_to_index('w') self.assertEqual(image.height, image.shape[height_idx]) self.assertEqual(image.width, image.shape[width_idx]) From edc2a84ac14c9f1323ffe010bf32bfacafb785ac Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 20:39:28 +0100 Subject: [PATCH 13/14] Ignore gunzip suffix --- torchio/data/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchio/data/io.py b/torchio/data/io.py index 616194faf..2b02139d3 100644 --- a/torchio/data/io.py +++ b/torchio/data/io.py @@ -108,7 +108,7 @@ def _write_nibabel( if channels_last: tensor = tensor.permute(1, 2, 3, 0) tensor = tensor.squeeze() if squeeze else tensor - suffix = Path(path).suffix + suffix = Path(str(path).replace('.gz', '')).suffix if '.nii' in suffix: img = nib.Nifti1Image(np.asarray(tensor), affine) elif '.hdr' in suffix or '.img' in suffix: From 27caf580ef5e998b0601d6cff01aca7bb0684c61 Mon Sep 17 00:00:00 2001 From: Fernando Date: Mon, 17 Aug 2020 21:11:01 +0100 Subject: [PATCH 14/14] Fix random swap per channel --- .../augmentation/intensity/random_swap.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/torchio/transforms/augmentation/intensity/random_swap.py b/torchio/transforms/augmentation/intensity/random_swap.py index e5a0fc4b9..e9913053e 100644 --- a/torchio/transforms/augmentation/intensity/random_swap.py +++ b/torchio/transforms/augmentation/intensity/random_swap.py @@ -53,10 +53,8 @@ def get_params(): def apply_transform(self, sample: Subject) -> dict: for image in self.get_images(sample): tensors = [] - for tensor in image[DATA]: - tensor = swap(tensor, self.patch_size, self.num_iterations) - tensors.append(tensor) - image[DATA] = torch.stack(tensors) + tensor = image[DATA] + image[DATA] = swap(tensor, self.patch_size, self.num_iterations) return sample @@ -69,12 +67,12 @@ def swap( patch_size = to_tuple(patch_size) for _ in range(num_iterations): first_ini, first_fin = get_random_indices_from_shape( - tensor.shape, + tensor.shape[-3:], patch_size, ) while True: second_ini, second_fin = get_random_indices_from_shape( - tensor.shape, + tensor.shape[-3:], patch_size, ) larger_than_initial = np.all(second_ini >= first_ini) @@ -91,10 +89,10 @@ def swap( def insert(tensor: TypeData, patch: TypeData, index_ini: np.ndarray) -> None: - index_fin = index_ini + np.array(patch.shape) + index_fin = index_ini + np.array(patch.shape[-3:]) i_ini, j_ini, k_ini = index_ini i_fin, j_fin, k_fin = index_fin - tensor[i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = patch + tensor[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = patch def crop( @@ -104,20 +102,20 @@ def crop( ) -> Union[np.ndarray, torch.Tensor]: i_ini, j_ini, k_ini = index_ini i_fin, j_fin, k_fin = index_fin - return image[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] + return image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] def get_random_indices_from_shape( - shape: TypeTripletInt, + spatial_shape: TypeTripletInt, patch_size: TypeTripletInt, ) -> Tuple[np.ndarray, np.ndarray]: - shape_array = np.array(shape) + shape_array = np.array(spatial_shape) patch_size_array = np.array(patch_size) max_index_ini = shape_array - patch_size_array if (max_index_ini < 0).any(): message = ( - f'Patch size {patch_size} must not be' - f' larger than image size {shape}' + f'Patch size {patch_size} cannot be' + f' larger than image spatial shape {spatial_shape}' ) raise ValueError(message) max_index_ini = max_index_ini.astype(np.uint16)