Skip to content

Commit

Permalink
Fix approximation errors in CropOrPad
Browse files Browse the repository at this point in the history
Fixes #462.

Refactoring of the logic to compute cropping and padding values from
the center of a binary image.
  • Loading branch information
fepegar committed Feb 12, 2021
1 parent 68ed91a commit 56386c4
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 80 deletions.
29 changes: 0 additions & 29 deletions tests/transforms/preprocessing/test_crop_pad.py
Expand Up @@ -161,35 +161,6 @@ def test_mask_corners(self):
'Physical position is different after cropping',
)

def test_mask_origin(self):
    """A mask voxel at the image origin must shift the crop center.

    With a single nonzero label voxel at index (0, 0, 0), mask-centered
    CropOrPad should produce different data and a different affine
    translation than plain center cropping, while leaving the rotation
    part of the affine untouched; the voxel that sat at the origin must
    land at the center of the mask-cropped output.
    """
    target_shape = 7, 21, 29
    half_shape = np.floor(np.array(target_shape) / 2).astype(int)
    crop_center = tio.CropOrPad(target_shape)
    crop_mask = tio.CropOrPad(target_shape, mask_name='label')
    label_data = self.sample_subject['label'].data
    label_data *= 0
    label_data[0, 0, 0, 0] = 1
    subject_center = crop_center(self.sample_subject)
    subject_mask = crop_mask(self.sample_subject)
    pairs = zip(subject_center.values(), subject_mask.values())
    for image_center, image_mask in pairs:
        # Different crop windows mean different data
        self.assertTensorNotEqual(image_center.data, image_mask.data)
        # Rotation (direction) part of the affine is preserved
        rotation_center = image_center.affine[:3, :3]
        rotation_mask = image_mask.affine[:3, :3]
        self.assertTensorEqual(rotation_center, rotation_mask)
        # Translation (origin) part must differ
        origin_center = image_center.affine[:3, 3]
        origin_mask = image_mask.affine[:3, 3]
        self.assertTensorNotEqual(origin_center, origin_mask)
        # Voxel at the origin ends up at the center of the masked output
        origin_value = image_center.data[0, 0, 0, 0]
        i, j, k = half_shape
        transformed_value = image_mask.data[0, i, j, k]
        self.assertEqual(origin_value, transformed_value)

def test_2d(self):
# https://github.com/fepegar/torchio/issues/434
image = np.random.rand(1, 16, 16, 1)
Expand Down
73 changes: 44 additions & 29 deletions torchio/transforms/preprocessing/spatial/crop_or_pad.py
Expand Up @@ -8,7 +8,6 @@
from .bounds_transform import BoundsTransform
from ...transform import TypeTripletInt, TypeSixBounds
from ....data.subject import Subject
from ....utils import round_up


class CropOrPad(BoundsTransform):
Expand Down Expand Up @@ -89,7 +88,7 @@ def _bbox_mask(mask_volume: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
j_min, j_max = np.where(j_any)[0][[0, -1]]
k_min, k_max = np.where(k_any)[0][[0, -1]]
bb_min = np.array([i_min, j_min, k_min])
bb_max = np.array([i_max, j_max, k_max])
bb_max = np.array([i_max, j_max, k_max]) + 1
return bb_min, bb_max

@staticmethod
Expand Down Expand Up @@ -119,6 +118,10 @@ def _get_six_bounds_parameters(
result.extend([ini, fin])
return tuple(result)

@property
def target_shape(self):
    """Return the target spatial shape.

    ``bounds_parameters`` stores an (ini, fin) pair per dimension; the
    even-indexed entries are taken to recover one value per dimension.
    """
    ini_values = self.bounds_parameters[0::2]
    return ini_values

def _compute_cropping_padding_from_shapes(
self,
source_shape: TypeTripletInt,
Expand Down Expand Up @@ -176,40 +179,50 @@ def _compute_mask_center_crop_or_pad(
warnings.warn(message, RuntimeWarning)
return self._compute_center_crop_or_pad(subject=subject)

# Original subject shape (from mask shape)
# Let's assume that the center of first voxel is at coordinate 0.5
# (which is typically not the case)
subject_shape = subject.spatial_shape
# Calculate bounding box of the mask center
bb_min, bb_max = self._bbox_mask(mask[0])
# Coordinates of the mask center
center_mask = (bb_max - bb_min) / 2 + bb_min
# List of padding to do
center_mask = np.mean((bb_min, bb_max), axis=0)
padding = []
# Final cropping (after padding)
cropping = []
for dim, center_dimension in enumerate(center_mask):
# Compute coordinates of the target shape taken from the center of
# the mask
center_dim = round_up(center_dimension)
begin = center_dim - (self.bounds_parameters[2 * dim] / 2)
end = center_dim + (self.bounds_parameters[2 * dim + 1] / 2)
# Check if dimension needs padding (before or after)
begin_pad = round_up(abs(min(begin, 0)))
end_pad = round(max(end - subject_shape[dim], 0))
# Check if cropping is needed
begin_crop = round_up(max(begin, 0))
end_crop = abs(round(min(end - subject_shape[dim], 0)))
# Add padding values of the dim to the list
padding.append(begin_pad)
padding.append(end_pad)
# Add the slice of the dimension to take
cropping.append(begin_crop)
cropping.append(end_crop)
target_shape = np.array(self.target_shape)

for dim in range(3):
target_dim = target_shape[dim]
center_dim = center_mask[dim]
subject_dim = subject_shape[dim]

center_on_index = not (center_dim % 1)
target_even = not (target_dim % 2)

# Approximation when the center cannot be computed exactly
# The output will be off by half a voxel, but this is just an
# implementation detail
if target_even ^ center_on_index:
center_dim -= 0.5

begin = center_dim - target_dim / 2
if begin >= 0:
crop_ini = begin
pad_ini = 0
else:
crop_ini = 0
pad_ini = -begin

end = center_dim + target_dim / 2
if end <= subject_dim:
crop_fin = subject_dim - end
pad_fin = 0
else:
crop_fin = 0
pad_fin = end - subject_dim

padding.extend([pad_ini, pad_fin])
cropping.extend([crop_ini, crop_fin])
# Conversion for SimpleITK compatibility
padding = np.asarray(padding, dtype=int)
cropping = np.asarray(cropping, dtype=int)
if subject.is_2d() == 1:
padding[-2:] = 0
cropping[-2:] = 0
padding_params = tuple(padding.tolist()) if padding.any() else None
cropping_params = tuple(cropping.tolist()) if cropping.any() else None
return padding_params, cropping_params
Expand All @@ -221,4 +234,6 @@ def apply_transform(self, subject: Subject) -> Subject:
subject = Pad(padding_params, **padding_kwargs)(subject)
if cropping_params is not None:
subject = Crop(cropping_params)(subject)
actual, target = subject.spatial_shape, self.target_shape
assert actual == target, (actual, target)
return subject
3 changes: 2 additions & 1 deletion torchio/transforms/transform.py
Expand Up @@ -474,10 +474,11 @@ def get_mask_from_anatomical_label(

@staticmethod
def get_mask_from_bounds(
self,
bounds_parameters: TypeBounds,
tensor: torch.Tensor,
) -> torch.Tensor:
bounds_parameters = Transform.parse_bounds(bounds_parameters)
bounds_parameters = self.parse_bounds(bounds_parameters)
low = bounds_parameters[::2]
high = bounds_parameters[1::2]
i0, j0, k0 = low
Expand Down
21 changes: 0 additions & 21 deletions torchio/utils.py
Expand Up @@ -152,27 +152,6 @@ def get_torchio_cache_dir():
return Path('~/.cache/torchio').expanduser()


def round_up(value: float) -> int:
    """Round half up, i.e. halves are rounded towards positive infinity.

    This differs from the built-in :func:`round`, which rounds exact halves
    to the nearest even integer ("banker's rounding"). Note that for
    negative exact halves this rounds towards zero, e.g.
    ``round_up(-2.5) == -2`` — it does NOT round away from zero.

    Args:
        value: The value to round.

    Example:
        >>> round(2.5)
        2
        >>> round(3.5)
        4
        >>> round_up(2.5)
        3
        >>> round_up(3.5)
        4
    """
    # floor(value + 0.5) implements round-half-up for all real inputs
    return int(np.floor(value + 0.5))


def compress(input_path, output_path):
with open(input_path, 'rb') as f_in:
with gzip.open(output_path, 'wb') as f_out:
Expand Down

0 comments on commit 56386c4

Please sign in to comment.