Skip to content

Commit

Permalink
Use torch to get random b-spline parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Dec 30, 2019
1 parent f652384 commit e5f5971
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions torchio/transforms/random_elastic_deformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,28 @@ def __init__(
self.verbose = verbose

def apply_transform(self, sample):
bspline_params = self.get_params()

# only do augmentation with a probability `proportion_to_augment`
# Only do augmentation with a probability `proportion_to_augment`
do_augmentation = torch.rand(1) < self.proportion_to_augment
if not do_augmentation:
return sample

bspline_params = None
for image_dict in sample.values():
if not is_image_dict(image_dict):
continue
if image_dict['type'] == LABEL:
interpolation = Interpolation.NEAREST
else:
interpolation = self.image_interpolation
# TODO: assert that all images have the same shape
if bspline_params is None:
image = self.nib_to_sitk(
image_dict['data'].squeeze(), image_dict['affine'])
bspline_params = self.get_params(
image,
self.num_control_points,
self.deformation_std,
)
image_dict['data'] = self.apply_bspline_transform(
image_dict['data'],
image_dict['affine'],
Expand All @@ -52,23 +60,18 @@ def apply_transform(self, sample):
return sample

@staticmethod
def get_params():
pass
def get_params(image, num_control_points, deformation_std):
mesh_shape = 3 * (num_control_points,)
bspline_transform = sitk.BSplineTransformInitializer(image, mesh_shape)
default_params = bspline_transform.GetParameters()
bspline_params = torch.rand(len(default_params)) * deformation_std
return bspline_params.numpy()

@staticmethod
def get_bspline_transform(shape, deformation_std, num_control_points):
shape = list(reversed(shape))
shape[0], shape[2] = shape[2], shape[0]
itkimg = sitk.GetImageFromArray(np.zeros(shape))
trans_from_domain_mesh_size = 3 * [num_control_points]
bspline_transform = sitk.BSplineTransformInitializer(
itkimg, trans_from_domain_mesh_size)
params = bspline_transform.GetParameters()
params_numpy = np.asarray(params, dtype=float)
params_numpy = params_numpy + np.random.randn(
params_numpy.shape[0]) * deformation_std
params = tuple(params_numpy)
bspline_transform.SetParameters(params)
def get_bspline_transform(image, num_control_points, bspline_params):
mesh_shape = 3 * (num_control_points,)
bspline_transform = sitk.BSplineTransformInitializer(image, mesh_shape)
bspline_transform.SetParameters(bspline_params.tolist())
return bspline_transform

def apply_bspline_transform(
Expand All @@ -87,19 +90,19 @@ def apply_bspline_transform(
for i, channel_array in enumerate(array): # use sitk.VectorImage?
image = self.nib_to_sitk(channel_array, affine)
bspline_transform = self.get_bspline_transform(
channel_array.shape,
self.deformation_std,
image,
self.num_control_points,
bspline_params,
)

resampler = sitk.ResampleImageFilter()
resampler.SetInterpolator(interpolation.value)
resampler.SetReferenceImage(image)
resampler.SetDefaultPixelValue(0)
resampler.SetDefaultPixelValue(0) # should I change this?
resampler.SetTransform(bspline_transform)
resampled = resampler.Execute(image)

channel_array = sitk.GetArrayFromImage(resampled)
channel_array = channel_array.transpose(2, 1, 0) # ITK to NumPy
channel_array = channel_array.transpose() # ITK to NumPy
array[i] = channel_array
return array

0 comments on commit e5f5971

Please sign in to comment.