Skip to content

Commit

Permalink
Add probability argument to all transforms (#124)
Browse files Browse the repository at this point in the history
* Add probability argument to all transforms

Before this commit, the argument `proportion_to_augment` was present for only some computationally expensive transforms.

This commit generalizes this option to all transforms through the `p` argument.

Fixes #123.

Related to #34, #17.
  • Loading branch information
fepegar committed Apr 18, 2020
1 parent cbfff34 commit bb7b5ec
Show file tree
Hide file tree
Showing 25 changed files with 117 additions and 117 deletions.
2 changes: 1 addition & 1 deletion docs/source/cli.rst
Expand Up @@ -7,4 +7,4 @@ Command-line tools
A transform can be quickly applied to an image file using the command-line
tool ``torchio-transform``::

$ torchio-transform input.nii.gz RandomMotion output.nii.gz --kwargs "proportion_to_augment=1 num_transforms=4"
$ torchio-transform input.nii.gz RandomMotion output.nii.gz --kwargs "p=1 num_transforms=4"
Expand Up @@ -15,7 +15,7 @@ def test_random_elastic_deformation(self):
seed=42,
)
keys = ('t1', 't2', 'label')
fixtures = 2953.9197, 2989.769, 2975
fixtures = 2916.7192, 2955.1265, 2950
transformed = transform(self.sample)
for key, fixture in zip(keys, fixtures):
sample_data = self.sample[key][torchio.DATA].numpy()
Expand All @@ -27,11 +27,11 @@ def test_random_elastic_deformation(self):

def test_inputs_pta_gt_one(self):
with self.assertRaises(ValueError):
RandomElasticDeformation(proportion_to_augment=1.5)
RandomElasticDeformation(p=1.5)

def test_inputs_pta_lt_zero(self):
with self.assertRaises(ValueError):
RandomElasticDeformation(proportion_to_augment=-1)
RandomElasticDeformation(p=-1)

def test_inputs_interpolation_int(self):
with self.assertRaises(TypeError):
Expand Down
6 changes: 3 additions & 3 deletions tests/transforms/test_transforms.py
Expand Up @@ -39,8 +39,8 @@ def test_transforms(self):
ToCanonical(),
Resample((1, 1.1, 1.25)),
RandomFlip(axes=(0, 1, 2), flip_probability=1),
RandomMotion(proportion_to_augment=1),
RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)),
RandomMotion(),
RandomGhosting(axes=(0, 1, 2)),
RandomSpike(),
RandomNoise(),
RandomBlur(),
Expand All @@ -50,7 +50,7 @@ def test_transforms(self):
RescaleIntensity((0, 1)),
ZNormalization(masking_method='label'),
HistogramStandardization(landmarks_dict=landmarks_dict),
RandomElasticDeformation(proportion_to_augment=1),
RandomElasticDeformation(),
RandomAffine(),
Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
Crop((3, 2, 8, 0, 1, 4)),
Expand Down
32 changes: 9 additions & 23 deletions torchio/transforms/augmentation/intensity/random_bias_field.py
Expand Up @@ -15,40 +15,32 @@ class RandomBiasField(RandomTransform):
If a tuple :math:`(a, b)` is specified, then
:math:`n \sim \mathcal{U}(a, b)`.
order: Order of the basis polynomial functions.
proportion_to_augment: Probability that this transform will be applied.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
coefficients: Union[float, Tuple[float, float]] = 0.5,
order: int = 3,
proportion_to_augment: float = 1,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
super().__init__(p=p, seed=seed)
self.coefficients_range = self.parse_range(
coefficients, 'coefficients_range')
self.order = order
self.proportion_to_augment = self.parse_probability(
proportion_to_augment,
'proportion_to_augment',
)

def apply_transform(self, sample: dict) -> dict:
for image_name, image_dict in sample.items():
if not is_image_dict(image_dict):
continue
if image_dict[TYPE] != INTENSITY:
continue
do_augmentation, coefficients = self.get_params(
coefficients = self.get_params(
self.order,
self.coefficients_range,
self.proportion_to_augment,
)
sample[image_name]['random_bias_coefficients'] = coefficients
sample[image_name]['random_bias_do_augmentation'] = do_augmentation
if not do_augmentation:
continue
bias_field = self.generate_bias_field(
image_dict[DATA], self.order, coefficients)
image_with_bias = image_dict[DATA] * torch.from_numpy(bias_field)
Expand All @@ -59,31 +51,25 @@ def apply_transform(self, sample: dict) -> dict:
def get_params(
order: int,
coefficients_range: Tuple[float, float],
probability: float,
) -> Tuple[bool, np.ndarray]:
"""
Sampling of the appropriate number of coefficients for the creation
of the bias field map
"""
# Sampling of the appropriate number of coefficients for the creation
# of the bias field map
random_coefficients = []
for x_order in range(0, order + 1):
for y_order in range(0, order + 1 - x_order):
for _ in range(0, order + 1 - (x_order + y_order)):
number = torch.FloatTensor(1).uniform_(*coefficients_range)
random_coefficients.append(number.item())
do_augmentation = torch.rand(1) < probability
return do_augmentation, np.array(random_coefficients)
return np.array(random_coefficients)

@staticmethod
def generate_bias_field(
data: TypeData,
order: int,
coefficients: TypeData,
) -> np.ndarray:
"""
Create the bias field map using a linear combination of polynomial
functions and the coefficients previously sampled
"""
# Create the bias field map using a linear combination of polynomial
# functions and the coefficients previously sampled
shape = np.array(data.shape[1:]) # first axis is channels
half_shape = shape / 2

Expand Down
4 changes: 3 additions & 1 deletion torchio/transforms/augmentation/intensity/random_blur.py
Expand Up @@ -16,14 +16,16 @@ class RandomBlur(RandomTransform):
to blur the image along each axis,
where :math:`\sigma_i \sim \mathcal{U}(a, b)`.
If a single value :math:`n` is provided, then :math:`a = b = n`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
std: Union[float, Tuple[float, float]] = (0, 4),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
super().__init__(p=p, seed=seed)
self.std_range = self.parse_range(std, 'std')
if any(np.array(self.std_range) < 0):
message = (
Expand Down
20 changes: 5 additions & 15 deletions torchio/transforms/augmentation/intensity/random_ghosting.py
Expand Up @@ -18,7 +18,7 @@ class RandomGhosting(RandomTransform):
axes: Axis along which the ghosts will be created. If
:py:attr:`axes` is a tuple, the axis will be randomly chosen
from the passed values.
proportion_to_augment: Probability that this transform will be applied.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. note:: The execution time of this transform does not depend on the
Expand All @@ -28,14 +28,10 @@ def __init__(
self,
num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
proportion_to_augment: float = 1,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
self.proportion_to_augment = self.parse_probability(
proportion_to_augment,
'proportion_to_augment',
)
super().__init__(p=p, seed=seed)
if not isinstance(axes, tuple):
axes = (axes,)
self.axes = axes
Expand All @@ -53,14 +49,10 @@ def apply_transform(self, sample: dict) -> dict:
params = self.get_params(
self.num_ghosts_range,
self.axes,
self.proportion_to_augment,
)
num_ghosts_param, axis_param, do_it = params
num_ghosts_param, axis_param = params
sample[image_name]['random_ghosting_axis'] = axis_param
sample[image_name]['random_ghosting_num_ghosts'] = num_ghosts_param
sample[image_name]['random_ghosting_do'] = do_it
if not do_it:
return sample
if (image_dict[DATA][0] < -0.1).any():
# I use -0.1 instead of 0 because Python was warning me when
# a value in a voxel was -7.191084e-35
Expand Down Expand Up @@ -91,13 +83,11 @@ def apply_transform(self, sample: dict) -> dict:
def get_params(
num_ghosts_range: Tuple[int, int],
axes: Tuple[int, ...],
probability: float,
) -> Tuple:
ng_min, ng_max = num_ghosts_range
num_ghosts_param = torch.randint(ng_min, ng_max + 1, (1,)).item()
axis_param = axes[torch.randint(0, len(axes), (1,))]
do_it = torch.rand(1) < probability
return num_ghosts_param, axis_param, do_it
return num_ghosts_param, axis_param

def add_artifact(
self,
Expand Down
20 changes: 5 additions & 15 deletions torchio/transforms/augmentation/intensity/random_motion.py
Expand Up @@ -43,7 +43,7 @@ class RandomMotion(RandomTransform):
num_transforms: Number of simulated movements.
Larger values generate more distorted images.
image_interpolation: See :ref:`Interpolation`.
proportion_to_augment: Probability that this transform will be applied.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. warning:: Large numbers of movements lead to longer execution times.
Expand All @@ -54,18 +54,14 @@ def __init__(
translation: float = 10, # in mm
num_transforms: int = 2,
image_interpolation: Interpolation = Interpolation.LINEAR,
proportion_to_augment: float = 1,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
super().__init__(p=p, seed=seed)
self.degrees_range = self.parse_degrees(degrees)
self.translation_range = self.parse_translation(translation)
self.num_transforms = num_transforms
self.image_interpolation = image_interpolation
self.proportion_to_augment = self.parse_probability(
proportion_to_augment,
'proportion_to_augment',
)

def apply_transform(self, sample: dict) -> dict:
for image_name, image_dict in sample.items():
Expand All @@ -77,19 +73,15 @@ def apply_transform(self, sample: dict) -> dict:
self.degrees_range,
self.translation_range,
self.num_transforms,
self.proportion_to_augment
)
times_params, degrees_params, translation_params, do_it = params
times_params, degrees_params, translation_params = params
keys = (
'random_motion_times',
'random_motion_degrees',
'random_motion_translation',
'random_motion_do',
)
for key, param in zip(keys, params):
sample[image_name][key] = param
if not do_it:
return sample
if (image_dict[DATA][0] < -0.1).any():
# I use -0.1 instead of 0 because Python was warning me when
# a value in a voxel was -7.191084e-35
Expand Down Expand Up @@ -127,7 +119,6 @@ def get_params(
degrees_range: Tuple[float, float],
translation_range: Tuple[float, float],
num_transforms: int,
probability: float,
perturbation: float = 0.3,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]:
# If perturbation is 0, time intervals between movements are constant
Expand All @@ -141,8 +132,7 @@ def get_params(
noise.uniform_(-step * perturbation, step * perturbation)
times += noise
times_params = times.numpy()
do_it = torch.rand(1) < probability
return times_params, degrees_params, translation_params, do_it
return times_params, degrees_params, translation_params

def get_rigid_transforms(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchio/transforms/augmentation/intensity/random_noise.py
Expand Up @@ -14,14 +14,16 @@ class RandomNoise(RandomTransform):
from which the noise is sampled.
If two values :math:`(a, b)` are providede,
then :math:`\sigma \sim \mathcal{U}(a, b)`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
std: Tuple[float, float] = (0, 0.25),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
super().__init__(p=p, seed=seed)
self.std_range = self.parse_range(std, 'std')
if any(np.array(self.std_range) < 0):
message = (
Expand Down
20 changes: 5 additions & 15 deletions torchio/transforms/augmentation/intensity/random_spike.py
Expand Up @@ -19,7 +19,7 @@ class RandomSpike(RandomTransform):
intensity: Ratio :math:`r` between the spike intensity and the maximum
of the spectrum.
Larger values generate more distorted images.
proportion_to_augment: Probability that this transform will be applied.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. note:: The execution time of this transform does not depend on the
Expand All @@ -29,14 +29,10 @@ def __init__(
self,
num_spikes: Union[int, Tuple[int, int]] = 1,
intensity: Union[float, Tuple[float, float]] = (0.1, 1),
proportion_to_augment: float = 1,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
self.proportion_to_augment = self.parse_probability(
proportion_to_augment,
'proportion_to_augment',
)
super().__init__(p=p, seed=seed)
self.intensity_range = self.parse_range(
intensity, 'intensity_range')
if isinstance(num_spikes, int):
Expand All @@ -53,14 +49,10 @@ def apply_transform(self, sample: dict) -> dict:
params = self.get_params(
self.num_spikes_range,
self.intensity_range,
self.proportion_to_augment,
)
num_spikes_param, intensity_param, do_it = params
num_spikes_param, intensity_param = params
sample[image_name]['random_spike_intensity'] = intensity_param
sample[image_name]['random_spike_num_spikes'] = num_spikes_param
sample[image_name]['random_spike_do'] = do_it
if not do_it:
return sample
if (image_dict[DATA][0] < -0.1).any():
# I use -0.1 instead of 0 because Python was warning me when
# a value in a voxel was -7.191084e-35
Expand Down Expand Up @@ -91,13 +83,11 @@ def apply_transform(self, sample: dict) -> dict:
def get_params(
num_spikes_range: Tuple[int, int],
intensity_range: Tuple[float, float],
probability: float,
) -> Tuple:
ns_min, ns_max = num_spikes_range
num_spikes_param = torch.randint(ns_min, ns_max + 1, (1,)).item()
intensity_param = torch.FloatTensor(1).uniform_(*intensity_range)
do_it = torch.rand(1) < probability
return num_spikes_param, intensity_param.item(), do_it
return num_spikes_param, intensity_param.item()

def add_artifact(
self,
Expand Down
4 changes: 3 additions & 1 deletion torchio/transforms/augmentation/intensity/random_swap.py
Expand Up @@ -15,15 +15,17 @@ class RandomSwap(RandomTransform):
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided, :math:`d = h = w = n`.
num_iterations: Number of times that two patches will be swapped.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
patch_size: TypeTuple = 15,
num_iterations: int = 100,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(seed=seed)
super().__init__(p=p, seed=seed)
self.patch_size = to_tuple(patch_size)
self.num_iterations = num_iterations

Expand Down

0 comments on commit bb7b5ec

Please sign in to comment.