diff --git a/docs/source/cli.rst b/docs/source/cli.rst index ebcf7f51c..dee6031ae 100644 --- a/docs/source/cli.rst +++ b/docs/source/cli.rst @@ -7,4 +7,4 @@ Command-line tools A transform can be quickly applied to an image file using the command-line tool ``torchio-transform``:: - $ torchio-transform input.nii.gz RandomMotion output.nii.gz --kwargs "proportion_to_augment=1 num_transforms=4" + $ torchio-transform input.nii.gz RandomMotion output.nii.gz --kwargs "p=1 num_transforms=4" diff --git a/tests/transforms/augmentation/test_random_elastic_deformation.py b/tests/transforms/augmentation/test_random_elastic_deformation.py index 40d8b4b47..293f8a1a6 100644 --- a/tests/transforms/augmentation/test_random_elastic_deformation.py +++ b/tests/transforms/augmentation/test_random_elastic_deformation.py @@ -15,7 +15,7 @@ def test_random_elastic_deformation(self): seed=42, ) keys = ('t1', 't2', 'label') - fixtures = 2953.9197, 2989.769, 2975 + fixtures = 2916.7192, 2955.1265, 2950 transformed = transform(self.sample) for key, fixture in zip(keys, fixtures): sample_data = self.sample[key][torchio.DATA].numpy() @@ -27,11 +27,11 @@ def test_random_elastic_deformation(self): def test_inputs_pta_gt_one(self): with self.assertRaises(ValueError): - RandomElasticDeformation(proportion_to_augment=1.5) + RandomElasticDeformation(p=1.5) def test_inputs_pta_lt_zero(self): with self.assertRaises(ValueError): - RandomElasticDeformation(proportion_to_augment=-1) + RandomElasticDeformation(p=-1) def test_inputs_interpolation_int(self): with self.assertRaises(TypeError): diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 173662739..e7ec3a3b6 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -39,8 +39,8 @@ def test_transforms(self): ToCanonical(), Resample((1, 1.1, 1.25)), RandomFlip(axes=(0, 1, 2), flip_probability=1), - RandomMotion(proportion_to_augment=1), - RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)), + RandomMotion(), + RandomGhosting(axes=(0, 1, 2)), RandomSpike(), RandomNoise(), RandomBlur(), @@ -50,7 +50,7 @@ def test_transforms(self): RescaleIntensity((0, 1)), ZNormalization(masking_method='label'), HistogramStandardization(landmarks_dict=landmarks_dict), - RandomElasticDeformation(proportion_to_augment=1), + RandomElasticDeformation(), RandomAffine(), Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3), Crop((3, 2, 8, 0, 1, 4)), diff --git a/torchio/transforms/augmentation/intensity/random_bias_field.py b/torchio/transforms/augmentation/intensity/random_bias_field.py index 5fb8e4f9a..734faf2fb 100644 --- a/torchio/transforms/augmentation/intensity/random_bias_field.py +++ b/torchio/transforms/augmentation/intensity/random_bias_field.py @@ -15,24 +15,20 @@ class RandomBiasField(RandomTransform): If a tuple :math:`(a, b)` is specified, then :math:`n \sim \mathcal{U}(a, b)`. order: Order of the basis polynomial functions. - proportion_to_augment: Probability that this transform will be applied. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. 
""" def __init__( self, coefficients: Union[float, Tuple[float, float]] = 0.5, order: int = 3, - proportion_to_augment: float = 1, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.coefficients_range = self.parse_range( coefficients, 'coefficients_range') self.order = order - self.proportion_to_augment = self.parse_probability( - proportion_to_augment, - 'proportion_to_augment', - ) def apply_transform(self, sample: dict) -> dict: for image_name, image_dict in sample.items(): @@ -40,15 +36,11 @@ def apply_transform(self, sample: dict) -> dict: continue if image_dict[TYPE] != INTENSITY: continue - do_augmentation, coefficients = self.get_params( + coefficients = self.get_params( self.order, self.coefficients_range, - self.proportion_to_augment, ) sample[image_name]['random_bias_coefficients'] = coefficients - sample[image_name]['random_bias_do_augmentation'] = do_augmentation - if not do_augmentation: - continue bias_field = self.generate_bias_field( image_dict[DATA], self.order, coefficients) image_with_bias = image_dict[DATA] * torch.from_numpy(bias_field) @@ -59,20 +51,16 @@ def apply_transform(self, sample: dict) -> dict: def get_params( order: int, coefficients_range: Tuple[float, float], - probability: float, ) -> Tuple[bool, np.ndarray]: - """ - Sampling of the appropriate number of coefficients for the creation - of the bias field map - """ + # Sampling of the appropriate number of coefficients for the creation + # of the bias field map random_coefficients = [] for x_order in range(0, order + 1): for y_order in range(0, order + 1 - x_order): for _ in range(0, order + 1 - (x_order + y_order)): number = torch.FloatTensor(1).uniform_(*coefficients_range) random_coefficients.append(number.item()) - do_augmentation = torch.rand(1) < probability - return do_augmentation, np.array(random_coefficients) + return np.array(random_coefficients) @staticmethod def generate_bias_field( @@ -80,10 +68,8 @@ def generate_bias_field( order: int, coefficients: TypeData, ) -> np.ndarray: - """ - Create the bias field map using a linear combination of polynomial - functions and the coefficients previously sampled - """ + # Create the bias field map using a linear combination of polynomial + # functions and the coefficients previously sampled shape = np.array(data.shape[1:]) # first axis is channels half_shape = shape / 2 diff --git a/torchio/transforms/augmentation/intensity/random_blur.py b/torchio/transforms/augmentation/intensity/random_blur.py index 684cd93bc..ad8b6e44c 100644 --- a/torchio/transforms/augmentation/intensity/random_blur.py +++ b/torchio/transforms/augmentation/intensity/random_blur.py @@ -16,14 +16,16 @@ class RandomBlur(RandomTransform): to blur the image along each axis, where :math:`\sigma_i \sim \mathcal{U}(a, b)`. If a single value :math:`n` is provided, then :math:`a = b = n`. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. 
""" def __init__( self, std: Union[float, Tuple[float, float]] = (0, 4), + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.std_range = self.parse_range(std, 'std') if any(np.array(self.std_range) < 0): message = ( diff --git a/torchio/transforms/augmentation/intensity/random_ghosting.py b/torchio/transforms/augmentation/intensity/random_ghosting.py index 143f5b410..23043fc73 100644 --- a/torchio/transforms/augmentation/intensity/random_ghosting.py +++ b/torchio/transforms/augmentation/intensity/random_ghosting.py @@ -18,7 +18,7 @@ class RandomGhosting(RandomTransform): axes: Axis along which the ghosts will be created. If :py:attr:`axes` is a tuple, the axis will be randomly chosen from the passed values. - proportion_to_augment: Probability that this transform will be applied. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. .. note:: The execution time of this transform does not depend on the @@ -28,14 +28,10 @@ def __init__( self, num_ghosts: Union[int, Tuple[int, int]] = (4, 10), axes: Union[int, Tuple[int, ...]] = (0, 1, 2), - proportion_to_augment: float = 1, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) - self.proportion_to_augment = self.parse_probability( - proportion_to_augment, - 'proportion_to_augment', - ) + super().__init__(p=p, seed=seed) if not isinstance(axes, tuple): axes = (axes,) self.axes = axes @@ -53,14 +49,10 @@ def apply_transform(self, sample: dict) -> dict: params = self.get_params( self.num_ghosts_range, self.axes, - self.proportion_to_augment, ) - num_ghosts_param, axis_param, do_it = params + num_ghosts_param, axis_param = params sample[image_name]['random_ghosting_axis'] = axis_param sample[image_name]['random_ghosting_num_ghosts'] = num_ghosts_param - sample[image_name]['random_ghosting_do'] = do_it - if not do_it: - return sample if (image_dict[DATA][0] < -0.1).any(): # I use -0.1 instead of 0 because Python was warning me when # a value in a voxel was -7.191084e-35 @@ -91,13 +83,11 @@ def apply_transform(self, sample: dict) -> dict: def get_params( num_ghosts_range: Tuple[int, int], axes: Tuple[int, ...], - probability: float, ) -> Tuple: ng_min, ng_max = num_ghosts_range num_ghosts_param = torch.randint(ng_min, ng_max + 1, (1,)).item() axis_param = axes[torch.randint(0, len(axes), (1,))] - do_it = torch.rand(1) < probability - return num_ghosts_param, axis_param, do_it + return num_ghosts_param, axis_param def add_artifact( self, diff --git a/torchio/transforms/augmentation/intensity/random_motion.py b/torchio/transforms/augmentation/intensity/random_motion.py index 6f6e902d1..afb3286ed 100644 --- a/torchio/transforms/augmentation/intensity/random_motion.py +++ b/torchio/transforms/augmentation/intensity/random_motion.py @@ -43,7 +43,7 @@ class RandomMotion(RandomTransform): num_transforms: Number of simulated movements. Larger values generate more distorted images. image_interpolation: See :ref:`Interpolation`. - proportion_to_augment: Probability that this transform will be applied. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. .. warning:: Large numbers of movements lead to longer execution times. 
@@ -54,18 +54,14 @@ def __init__( translation: float = 10, # in mm num_transforms: int = 2, image_interpolation: Interpolation = Interpolation.LINEAR, - proportion_to_augment: float = 1, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.degrees_range = self.parse_degrees(degrees) self.translation_range = self.parse_translation(translation) self.num_transforms = num_transforms self.image_interpolation = image_interpolation - self.proportion_to_augment = self.parse_probability( - proportion_to_augment, - 'proportion_to_augment', - ) def apply_transform(self, sample: dict) -> dict: for image_name, image_dict in sample.items(): @@ -77,19 +73,15 @@ def apply_transform(self, sample: dict) -> dict: self.degrees_range, self.translation_range, self.num_transforms, - self.proportion_to_augment ) - times_params, degrees_params, translation_params, do_it = params + times_params, degrees_params, translation_params = params keys = ( 'random_motion_times', 'random_motion_degrees', 'random_motion_translation', - 'random_motion_do', ) for key, param in zip(keys, params): sample[image_name][key] = param - if not do_it: - return sample if (image_dict[DATA][0] < -0.1).any(): # I use -0.1 instead of 0 because Python was warning me when # a value in a voxel was -7.191084e-35 @@ -127,7 +119,6 @@ def get_params( degrees_range: Tuple[float, float], translation_range: Tuple[float, float], num_transforms: int, - probability: float, perturbation: float = 0.3, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]: # If perturbation is 0, time intervals between movements are constant @@ -141,8 +132,7 @@ def get_params( noise.uniform_(-step * perturbation, step * perturbation) times += noise times_params = times.numpy() - do_it = torch.rand(1) < probability - return times_params, degrees_params, translation_params, do_it + return times_params, degrees_params, translation_params def get_rigid_transforms( self, diff --git a/torchio/transforms/augmentation/intensity/random_noise.py b/torchio/transforms/augmentation/intensity/random_noise.py index 333cd936e..3610c4b21 100644 --- a/torchio/transforms/augmentation/intensity/random_noise.py +++ b/torchio/transforms/augmentation/intensity/random_noise.py @@ -14,14 +14,16 @@ class RandomNoise(RandomTransform): from which the noise is sampled. If two values :math:`(a, b)` are providede, then :math:`\sigma \sim \mathcal{U}(a, b)`. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. """ def __init__( self, std: Tuple[float, float] = (0, 0.25), + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.std_range = self.parse_range(std, 'std') if any(np.array(self.std_range) < 0): message = ( diff --git a/torchio/transforms/augmentation/intensity/random_spike.py b/torchio/transforms/augmentation/intensity/random_spike.py index 6a2daffb6..b1bcc3a3a 100644 --- a/torchio/transforms/augmentation/intensity/random_spike.py +++ b/torchio/transforms/augmentation/intensity/random_spike.py @@ -19,7 +19,7 @@ class RandomSpike(RandomTransform): intensity: Ratio :math:`r` between the spike intensity and the maximum of the spectrum. Larger values generate more distorted images. - proportion_to_augment: Probability that this transform will be applied. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. .. 
note:: The execution time of this transform does not depend on the @@ -29,14 +29,10 @@ def __init__( self, num_spikes: Union[int, Tuple[int, int]] = 1, intensity: Union[float, Tuple[float, float]] = (0.1, 1), - proportion_to_augment: float = 1, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) - self.proportion_to_augment = self.parse_probability( - proportion_to_augment, - 'proportion_to_augment', - ) + super().__init__(p=p, seed=seed) self.intensity_range = self.parse_range( intensity, 'intensity_range') if isinstance(num_spikes, int): @@ -53,14 +49,10 @@ def apply_transform(self, sample: dict) -> dict: params = self.get_params( self.num_spikes_range, self.intensity_range, - self.proportion_to_augment, ) - num_spikes_param, intensity_param, do_it = params + num_spikes_param, intensity_param = params sample[image_name]['random_spike_intensity'] = intensity_param sample[image_name]['random_spike_num_spikes'] = num_spikes_param - sample[image_name]['random_spike_do'] = do_it - if not do_it: - return sample if (image_dict[DATA][0] < -0.1).any(): # I use -0.1 instead of 0 because Python was warning me when # a value in a voxel was -7.191084e-35 @@ -91,13 +83,11 @@ def apply_transform(self, sample: dict) -> dict: def get_params( num_spikes_range: Tuple[int, int], intensity_range: Tuple[float, float], - probability: float, ) -> Tuple: ns_min, ns_max = num_spikes_range num_spikes_param = torch.randint(ns_min, ns_max + 1, (1,)).item() intensity_param = torch.FloatTensor(1).uniform_(*intensity_range) - do_it = torch.rand(1) < probability - return num_spikes_param, intensity_param.item(), do_it + return num_spikes_param, intensity_param.item() def add_artifact( self, diff --git a/torchio/transforms/augmentation/intensity/random_swap.py b/torchio/transforms/augmentation/intensity/random_swap.py index d6680c5bd..bab24b256 100644 --- a/torchio/transforms/augmentation/intensity/random_swap.py +++ b/torchio/transforms/augmentation/intensity/random_swap.py @@ -15,15 +15,17 @@ class RandomSwap(RandomTransform): of size :math:`d \times h \times w`. If a single number :math:`n` is provided, :math:`d = h = w = n`. num_iterations: Number of times that two patches will be swapped. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. """ def __init__( self, patch_size: TypeTuple = 15, num_iterations: int = 100, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.patch_size = to_tuple(patch_size) self.num_iterations = num_iterations diff --git a/torchio/transforms/augmentation/random_transform.py b/torchio/transforms/augmentation/random_transform.py index 331686e37..134d6cb4d 100644 --- a/torchio/transforms/augmentation/random_transform.py +++ b/torchio/transforms/augmentation/random_transform.py @@ -14,13 +14,15 @@ class RandomTransform(Transform): """Base class for stochastic augmentation transforms. Args: - seed: Seed for :mod:`torch` random number generator. + p: Probability that this transform will be applied. + seed: Seed for :py:mod:`torch` random number generator. 
""" def __init__( self, + p: float = 1, seed: Optional[int] = None, ): - super().__init__() + super().__init__(p=p) self._seed = seed def __call__(self, sample: dict): @@ -79,13 +81,6 @@ def parse_translation( ) -> Tuple[float, float]: return self.parse_range(translation, 'translation') - @staticmethod - def parse_probability(probability: float, name: str) -> float: - is_number = isinstance(probability, numbers.Number) - if not (is_number and 0 <= probability <= 1): - raise ValueError(f'{name} must be a number in [0, 1]') - return probability - @staticmethod def parse_interpolation(interpolation: Interpolation) -> Interpolation: if not isinstance(interpolation, Interpolation): diff --git a/torchio/transforms/augmentation/spatial/random_affine.py b/torchio/transforms/augmentation/spatial/random_affine.py index 77cac2971..f0033e963 100644 --- a/torchio/transforms/augmentation/spatial/random_affine.py +++ b/torchio/transforms/augmentation/spatial/random_affine.py @@ -35,6 +35,7 @@ class RandomAffine(RandomTransform): border that lie under an `Otsu threshold `_. image_interpolation: See :ref:`Interpolation`. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. .. note:: Rotations are performed around the center of the image. @@ -63,9 +64,10 @@ def __init__( isotropic: bool = False, default_pad_value: Union[str, float] = 'otsu', image_interpolation: Interpolation = Interpolation.LINEAR, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.scales = scales self.degrees = self.parse_degrees(degrees) self.isotropic = isotropic diff --git a/torchio/transforms/augmentation/spatial/random_elastic_deformation.py b/torchio/transforms/augmentation/spatial/random_elastic_deformation.py index f748835d9..17f11e2d1 100644 --- a/torchio/transforms/augmentation/spatial/random_elastic_deformation.py +++ b/torchio/transforms/augmentation/spatial/random_elastic_deformation.py @@ -53,7 +53,7 @@ class RandomElasticDeformation(RandomTransform): The value of the dense displacement at each voxel is always interpolated with cubic B-splines from the values at the control points of the coarse grid. - proportion_to_augment: Probability that this transform will be applied. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. `This gist `_ @@ -114,10 +114,10 @@ def __init__( max_displacement: Union[float, Tuple[float, float, float]] = 7.5, locked_borders: int = 2, image_interpolation: Interpolation = Interpolation.LINEAR, - proportion_to_augment: float = 1, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self._bspline_transformation = None self.num_control_points = to_tuple(num_control_points, length=3) self.parse_control_points(self.num_control_points) @@ -133,10 +133,6 @@ def __init__( ' or use more control points.' 
) raise ValueError(message) - self.proportion_to_augment = self.parse_probability( - proportion_to_augment, - 'proportion_to_augment', - ) self.interpolation = self.parse_interpolation(image_interpolation) @staticmethod @@ -169,7 +165,6 @@ def get_params( num_control_points: Tuple[int, int, int], max_displacement: Tuple[float, float, float], num_locked_borders: int, - probability: float, ) -> Tuple: grid_shape = num_control_points num_dimensions = 3 @@ -187,8 +182,7 @@ def get_params( coarse_field[:, i] = 0 coarse_field[:, -1 - i] = 0 - do_augmentation = torch.rand(1) < probability - return do_augmentation, coarse_field.numpy() + return coarse_field.numpy() @staticmethod def get_bspline_transform( @@ -232,16 +226,12 @@ def apply_transform(self, sample: dict) -> dict: else: interpolation = self.interpolation if bspline_params is None: - do_augmentation, bspline_params = self.get_params( + bspline_params = self.get_params( self.num_control_points, self.max_displacement, self.num_locked_borders, - self.proportion_to_augment, ) params_dict['bspline_params'] = bspline_params - params_dict['do_augmentation'] = int(do_augmentation) - if not do_augmentation: - return sample image_dict[DATA] = self.apply_bspline_transform( image_dict[DATA], image_dict[AFFINE], diff --git a/torchio/transforms/augmentation/spatial/random_flip.py b/torchio/transforms/augmentation/spatial/random_flip.py index ed63d99e8..9511cfa2f 100644 --- a/torchio/transforms/augmentation/spatial/random_flip.py +++ b/torchio/transforms/augmentation/spatial/random_flip.py @@ -12,6 +12,7 @@ class RandomFlip(RandomTransform): axes: Axis or tuple of axes along which the image will be flipped. flip_probability: Probability that the image will be flipped. This is computed on a per-axis basis. + p: Probability that this transform will be applied. seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`. """ @@ -19,13 +20,13 @@ def __init__( self, axes: Union[int, Tuple[int, ...]] = 0, flip_probability: float = 0.5, + p: float = 1, seed: Optional[int] = None, ): - super().__init__(seed=seed) + super().__init__(p=p, seed=seed) self.axes = self.parse_axes(axes) self.flip_probability = self.parse_probability( flip_probability, - 'flip_probability', ) def apply_transform(self, sample: dict) -> dict: diff --git a/torchio/transforms/lambda_transform.py b/torchio/transforms/lambda_transform.py index 585dc76e7..3f97b0237 100644 --- a/torchio/transforms/lambda_transform.py +++ b/torchio/transforms/lambda_transform.py @@ -14,6 +14,7 @@ class Lambda(Transform): types_to_apply: List of strings corresponding to the image types to which this transform should be applied. If ``None``, the transform will be applied to all images in the sample. + p: Probability that this transform will be applied. Example: >>> import torchio @@ -28,8 +29,9 @@ def __init__( self, function: TypeCallable, types_to_apply: Optional[Sequence[str]] = None, + p: float = 1, ): - super().__init__() + super().__init__(p=p) self.function = function self.types_to_apply = types_to_apply diff --git a/torchio/transforms/preprocessing/intensity/histogram_standardization.py b/torchio/transforms/preprocessing/intensity/histogram_standardization.py index 970a846cf..c1fe1aa6f 100644 --- a/torchio/transforms/preprocessing/intensity/histogram_standardization.py +++ b/torchio/transforms/preprocessing/intensity/histogram_standardization.py @@ -27,14 +27,15 @@ class HistogramStandardization(NormalizationTransform): with :py:meth:`torchio.transforms.HistogramStandardization.train`. 
masking_method: See :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`. - + p: Probability that this transform will be applied. """ def __init__( self, landmarks_dict: Dict[str, np.ndarray], masking_method: Union[str, TypeCallable, None] = None, + p: float = 1, ): - super().__init__(masking_method=masking_method) + super().__init__(masking_method=masking_method, p=p) self.landmarks_dict = landmarks_dict def apply_normalization( diff --git a/torchio/transforms/preprocessing/intensity/normalization_transform.py b/torchio/transforms/preprocessing/intensity/normalization_transform.py index 094d43a56..27f848adb 100644 --- a/torchio/transforms/preprocessing/intensity/normalization_transform.py +++ b/torchio/transforms/preprocessing/intensity/normalization_transform.py @@ -35,6 +35,7 @@ class NormalizationTransform(Transform): def __init__( self, masking_method: Union[str, TypeCallable, None] = None, + p: float = 1, ): """ masking_method is used to choose the values used for normalization. @@ -43,7 +44,7 @@ def __init__( - A function: the mask will be computed using the function - None: all values are used """ - super().__init__() + super().__init__(p=p) self.mask_name = None if masking_method is None: self.masking_method = self.ones diff --git a/torchio/transforms/preprocessing/intensity/rescale.py b/torchio/transforms/preprocessing/intensity/rescale.py index e2fc1243c..d3305f957 100644 --- a/torchio/transforms/preprocessing/intensity/rescale.py +++ b/torchio/transforms/preprocessing/intensity/rescale.py @@ -20,6 +20,7 @@ class RescaleIntensity(NormalizationTransform): Isensee et al. use ``(0.05, 99.5)`` in their `nn-UNet paper`_. masking_method: See :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`. + p: Probability that this transform will be applied. .. _this scikit-image example: https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_equalize.html#sphx-glr-auto-examples-color-exposure-plot-equalize-py .. _nn-UNet paper: https://arxiv.org/abs/1809.10486 @@ -29,8 +30,9 @@ def __init__( out_min_max: Tuple[float, float], percentiles: Tuple[int, int] = (0, 100), masking_method: Union[str, TypeCallable, None] = None, + p: float = 1, ): - super().__init__(masking_method=masking_method) + super().__init__(masking_method=masking_method, p=p) self.out_min, self.out_max = out_min_max self.percentiles = percentiles diff --git a/torchio/transforms/preprocessing/intensity/z_normalization.py b/torchio/transforms/preprocessing/intensity/z_normalization.py index 4e406d93e..f3fd0862f 100644 --- a/torchio/transforms/preprocessing/intensity/z_normalization.py +++ b/torchio/transforms/preprocessing/intensity/z_normalization.py @@ -10,9 +10,14 @@ class ZNormalization(NormalizationTransform): Args: masking_method: See :py:class:`~torchio.transforms.preprocessing.normalization_transform.NormalizationTransform`. + p: Probability that this transform will be applied. 
""" - def __init__(self, masking_method: Union[str, TypeCallable, None] = None): - super().__init__(masking_method=masking_method) + def __init__( + self, + masking_method: Union[str, TypeCallable, None] = None, + p: float = 1, + ): + super().__init__(masking_method=masking_method, p=p) def apply_normalization( self, diff --git a/torchio/transforms/preprocessing/spatial/bounds_transform.py b/torchio/transforms/preprocessing/spatial/bounds_transform.py index d043d3597..c06f08751 100644 --- a/torchio/transforms/preprocessing/spatial/bounds_transform.py +++ b/torchio/transforms/preprocessing/spatial/bounds_transform.py @@ -16,12 +16,19 @@ class BoundsTransform(Transform): - """Base class for transforms that change image bounds.""" + """Base class for transforms that change image bounds. + + Args: + bounds_parameters: + p: Probability that this transform will be applied. + + """ def __init__( self, bounds_parameters: TypeBounds, + p: float = 1, ): - super().__init__() + super().__init__(p=p) self.bounds_parameters = self.parse_bounds(bounds_parameters) @property diff --git a/torchio/transforms/preprocessing/spatial/crop_or_pad.py b/torchio/transforms/preprocessing/spatial/crop_or_pad.py index 5212a664a..5672b69bc 100644 --- a/torchio/transforms/preprocessing/spatial/crop_or_pad.py +++ b/torchio/transforms/preprocessing/spatial/crop_or_pad.py @@ -26,6 +26,7 @@ class CropOrPad(BoundsTransform): If a string is given, the output volume center will be the center of the bounding box of non-zero values in the image named :py:attr:`mask_name`. + p: Probability that this transform will be applied. Example: >>> import torchio @@ -51,8 +52,9 @@ def __init__( padding_mode: str = 'constant', padding_fill: Optional[float] = None, mask_name: Optional[str] = None, + p: float = 1, ): - super().__init__(target_shape) + super().__init__(target_shape, p=p) self.padding_mode = padding_mode self.padding_fill = padding_fill if mask_name is not None and not isinstance(mask_name, str): diff --git a/torchio/transforms/preprocessing/spatial/pad.py b/torchio/transforms/preprocessing/spatial/pad.py index 0071b25a2..771a041bf 100644 --- a/torchio/transforms/preprocessing/spatial/pad.py +++ b/torchio/transforms/preprocessing/spatial/pad.py @@ -39,6 +39,7 @@ class Pad(BoundsTransform): fill: Value for constant fill. Default is ``0``. This value is only used when :attr:`padding_mode` is ``constant``. + p: Probability that this transform will be applied. """ @@ -57,13 +58,14 @@ def __init__( padding: TypeBounds, padding_mode: str = 'constant', fill: float = None, + p: float = 1, ): """ padding_mode can be 'constant', 'reflect', 'replicate' or 'circular'. See https://pytorch.org/docs/stable/nn.functional.html#pad for more information about this transform. """ - super().__init__(padding) + super().__init__(padding, p=p) if fill is not None and padding_mode != 'constant': message = ( 'If the value of "fill" is not None,' diff --git a/torchio/transforms/preprocessing/spatial/resample.py b/torchio/transforms/preprocessing/spatial/resample.py index 6013f2fcf..47ac4c084 100644 --- a/torchio/transforms/preprocessing/spatial/resample.py +++ b/torchio/transforms/preprocessing/spatial/resample.py @@ -27,6 +27,7 @@ class Resample(Transform): :py:attr:`torchio.Interpolation.NEAREST`, :py:attr:`torchio.Interpolation.LINEAR` and :py:attr:`torchio.Interpolation.BSPLINE`. + p: Probability that this transform will be applied. .. 
note:: Resampling is performed using :py:meth:`nibabel.processing.resample_to_output` or @@ -45,8 +46,9 @@ def __init__( target: Union[TypeSpacing, str], antialiasing: bool = True, image_interpolation: Interpolation = Interpolation.LINEAR, + p: float = 1, ): - super().__init__() + super().__init__(p=p) self.target_spacing: Tuple[float, float, float] self.reference_image: str self.parse_target(target) diff --git a/torchio/transforms/preprocessing/spatial/to_canonical.py b/torchio/transforms/preprocessing/spatial/to_canonical.py index 566b61816..deace1a06 100644 --- a/torchio/transforms/preprocessing/spatial/to_canonical.py +++ b/torchio/transforms/preprocessing/spatial/to_canonical.py @@ -18,6 +18,9 @@ class ToCanonical(Transform): See `NiBabel docs about image orientation`_ for more information. + Args: + p: Probability that this transform will be applied. + .. note:: The reorientation is performed using :py:meth:`nibabel.as_closest_canonical`. diff --git a/torchio/transforms/transform.py b/torchio/transforms/transform.py index c3ef144ad..9c591b75c 100644 --- a/torchio/transforms/transform.py +++ b/torchio/transforms/transform.py @@ -1,7 +1,11 @@ +import numbers import warnings from copy import deepcopy from abc import ABC, abstractmethod + +import torch import SimpleITK as sitk + from ..utils import is_image_dict, nib_to_sitk, sitk_to_nib from .. import TypeData, TYPE @@ -14,10 +18,18 @@ class Transform(ABC): All subclasses should overwrite :py:meth:`torchio.tranforms.Transform.apply_transform`, which takes a sample, applies some transformation and returns the result. + + Args: + p: Probability that this transform will be applied. """ + def __init__(self, p: float = 1): + self.probability = self.parse_probability(p) + def __call__(self, sample: dict): """Transform a sample and return the result.""" self.parse_sample(sample) + if torch.rand(1).item() > self.probability: + return sample sample = deepcopy(sample) sample = self.apply_transform(sample) return sample @@ -26,6 +38,17 @@ def __call__(self, sample: dict): def apply_transform(self, sample: dict): raise NotImplementedError + @staticmethod + def parse_probability(probability: float) -> float: + is_number = isinstance(probability, numbers.Number) + if not (is_number and 0 <= probability <= 1): + message = ( + 'Probability must be a number in [0, 1],' + f' not {probability}' + ) + raise ValueError(message) + return probability + @staticmethod def parse_sample(sample: dict) -> None: images_found = False
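
A minimal usage sketch of the refactored interface follows; it is an illustration of the new API, not part of the patch above. It assumes the transforms are importable from torchio.transforms (as referenced in the docstrings); dataset loading and composition are omitted.

    # Every transform now accepts `p`, the probability that it will be applied,
    # instead of a per-class `proportion_to_augment` argument.
    from torchio.transforms import (
        RandomMotion,
        RandomBiasField,
        RandomElasticDeformation,
    )

    transforms = [
        RandomMotion(num_transforms=4, p=0.5),     # applied to ~50% of samples
        RandomBiasField(coefficients=0.5, p=0.3),  # applied to ~30% of samples
        RandomElasticDeformation(),                # p defaults to 1: always applied
    ]

    # Out-of-range probabilities are rejected in Transform.__init__
    # via Transform.parse_probability:
    try:
        RandomMotion(p=1.5)
    except ValueError as error:
        print(error)  # Probability must be a number in [0, 1], not 1.5

Because the base Transform class now returns the sample unchanged with probability 1 - p, the per-transform bookkeeping (the do_it flags returned by get_params and the random_*_do keys written into the sample) is no longer needed, which is why those code paths are removed throughout the diff.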