Skip to content

Commit

Permalink
Update transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jul 12, 2020
1 parent 9a4be72 commit f3529c8
Show file tree
Hide file tree
Showing 12 changed files with 21 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,28 @@ class RandomBiasField(RandomTransform):
:math:`n \sim \mathcal{U}(a, b)`.
order: Order of the basis polynomial functions.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
coefficients: Union[float, Tuple[float, float]] = 0.5,
order: int = 3,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.coefficients_range = self.parse_range(
coefficients, 'coefficients_range')
self.order = order

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
for image_name, image_dict in sample.get_images_dict().items():
coefficients = self.get_params(
self.order,
self.coefficients_range,
)
random_parameters_dict = {'coefficients': coefficients}
random_parameters_images_dict[image_name] = random_parameters_dict

bias_field = self.generate_bias_field(
image_dict[DATA], self.order, coefficients)
image_with_bias = image_dict[DATA] * torch.from_numpy(bias_field)
image_dict[DATA] = image_with_bias
sample.add_transform(self, random_parameters_images_dict)
return sample

@staticmethod
Expand Down
7 changes: 1 addition & 6 deletions torchio/transforms/augmentation/intensity/random_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ class RandomBlur(RandomTransform):
where :math:`\sigma_i \sim \mathcal{U}(a, b)` mm.
If a single value :math:`n` is provided, then :math:`a = b = n`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
std: Union[float, Tuple[float, float]] = (0, 4),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.std_range = self.parse_range(std, 'std')
if any(np.array(self.std_range) < 0):
message = (
Expand All @@ -36,17 +34,14 @@ def __init__(
raise ValueError(message)

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
for image_name, image_dict in sample.get_images_dict().items():
std = self.get_params(self.std_range)
random_parameters_dict = {'std': std}
random_parameters_images_dict[image_name] = random_parameters_dict
image_dict[DATA][0] = blur(
image_dict[DATA][0],
image_dict[AFFINE],
std,
)
sample.add_transform(self, random_parameters_images_dict)
return sample

@staticmethod
Expand Down
12 changes: 1 addition & 11 deletions torchio/transforms/augmentation/intensity/random_ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class RandomGhosting(RandomTransform):
:math:`k`-space center should be restored after removing the planes
that generate the artifact.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. note:: The execution time of this transform does not depend on the
number of ghosts.
Expand All @@ -37,9 +36,8 @@ def __init__(
intensity: Union[float, Tuple[float, float]] = (0.5, 1),
restore: float = 0.02,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
if not isinstance(axes, tuple):
try:
axes = tuple(axes)
Expand Down Expand Up @@ -84,7 +82,6 @@ def parse_intensity(intensity):
return intensity

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
for image_name, image_dict in sample.get_images_dict().items():
data = image_dict[DATA]
is_2d = data.shape[-3] == 1
Expand All @@ -95,20 +92,13 @@ def apply_transform(self, sample: Subject) -> dict:
self.intensity_range,
)
num_ghosts_param, axis_param, intensity_param = params
random_parameters_dict = {
'axis': axis_param,
'num_ghosts': num_ghosts_param,
'intensity': intensity_param,
}
random_parameters_images_dict[image_name] = random_parameters_dict
image_dict[DATA][0] = self.add_artifact(
data[0],
num_ghosts_param,
axis_param,
intensity_param,
self.restore,
)
sample.add_transform(self, random_parameters_images_dict)
return sample

@staticmethod
Expand Down
7 changes: 1 addition & 6 deletions torchio/transforms/augmentation/intensity/random_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class RandomMotion(RandomTransform):
Larger values generate more distorted images.
image_interpolation: See :ref:`Interpolation`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. warning:: Large numbers of movements lead to longer execution times for
3D images.
Expand All @@ -55,9 +54,8 @@ def __init__(
num_transforms: int = 2,
image_interpolation: str = 'linear',
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.degrees_range = self.parse_degrees(degrees)
self.translation_range = self.parse_translation(translation)
if not 0 < num_transforms or not isinstance(num_transforms, int):
Expand All @@ -70,7 +68,6 @@ def __init__(
self.image_interpolation = self.parse_interpolation(image_interpolation)

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
for image_name, image_dict in sample.get_images_dict().items():
data = image_dict[DATA]
is_2d = data.shape[-3] == 1
Expand All @@ -86,7 +83,6 @@ def apply_transform(self, sample: Subject) -> dict:
'degrees': degrees_params,
'translation': translation_params,
}
random_parameters_images_dict[image_name] = random_parameters_dict
image = self.nib_to_sitk(
data[0],
image_dict[AFFINE],
Expand All @@ -105,7 +101,6 @@ def apply_transform(self, sample: Subject) -> dict:
# Add channels dimension
data = data[np.newaxis, ...]
image_dict[DATA] = torch.from_numpy(data)
sample.add_transform(self, random_parameters_images_dict)
return sample

@staticmethod
Expand Down
8 changes: 1 addition & 7 deletions torchio/transforms/augmentation/intensity/random_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ class RandomNoise(RandomTransform):
If two values :math:`(a, b)` are provided,
then :math:`\sigma \sim \mathcal{U}(a, b)`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
mean: Union[float, Tuple[float, float]] = 0,
std: Union[float, Tuple[float, float]] = (0, 0.25),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.mean_range = self.parse_range(mean, 'mean')
self.std_range = self.parse_range(std, 'std')
if any(np.array(self.std_range) < 0):
Expand All @@ -39,13 +37,9 @@ def __init__(
raise ValueError(message)

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
for image_name, image_dict in sample.get_images_dict().items():
mean, std = self.get_params(self.mean_range, self.std_range)
random_parameters_dict = {'std': std}
random_parameters_images_dict[image_name] = random_parameters_dict
image_dict[DATA] = add_noise(image_dict[DATA], mean, std)
sample.add_transform(self, random_parameters_images_dict)
return sample

@staticmethod
Expand Down
7 changes: 1 addition & 6 deletions torchio/transforms/augmentation/intensity/random_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class RandomSpike(RandomTransform):
of the spectrum.
Larger values generate more distorted images.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
.. note:: The execution time of this transform does not depend on the
number of spikes.
Expand All @@ -29,9 +28,8 @@ def __init__(
num_spikes: Union[int, Tuple[int, int]] = 1,
intensity: Union[float, Tuple[float, float]] = (1, 3),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.intensity_range = self.parse_range(
intensity, 'intensity_range')
if isinstance(num_spikes, int):
Expand All @@ -40,7 +38,6 @@ def __init__(
self.num_spikes_range = num_spikes

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
for image_name, image_dict in sample.get_images_dict().items():
params = self.get_params(
self.num_spikes_range,
Expand All @@ -51,7 +48,6 @@ def apply_transform(self, sample: Subject) -> dict:
'intensity': intensity_param,
'spikes_positions': spikes_positions_param,
}
random_parameters_images_dict[image_name] = random_parameters_dict
image_dict[DATA] = self.add_artifact(
image_dict.as_sitk(),
spikes_positions_param,
Expand All @@ -60,7 +56,6 @@ def apply_transform(self, sample: Subject) -> dict:
# Add channels dimension
image_dict[DATA] = image_dict[DATA][np.newaxis, ...]
image_dict[DATA] = torch.from_numpy(image_dict[DATA])
sample.add_transform(self, random_parameters_images_dict)
return sample

@staticmethod
Expand Down
4 changes: 1 addition & 3 deletions torchio/transforms/augmentation/intensity/random_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ class RandomSwap(RandomTransform):
If a single number :math:`n` is provided, :math:`d = h = w = n`.
num_iterations: Number of times that two patches will be swapped.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""
def __init__(
self,
patch_size: TypeTuple = 15,
num_iterations: int = 100,
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.patch_size = to_tuple(patch_size)
self.num_iterations = num_iterations

Expand Down
29 changes: 10 additions & 19 deletions torchio/transforms/augmentation/random_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,28 @@ class RandomTransform(Transform):
Args:
p: Probability that this transform will be applied.
seed: Seed for :py:mod:`torch` random number generator.
"""
def __init__(
self,
p: float = 1,
seed: Optional[int] = None,
):
def __init__(self, p: float = 1):
super().__init__(p=p)
self._seed = self.generate_seed() if seed is None else seed

@staticmethod
def generate_seed():
# https://github.com/fepegar/torchio/issues/208#issuecomment-650262724
return torch.randint(1, 2**63 - 1, (1,)).item()

def __call__(self, sample: Subject):
def __call__(self, sample: Subject, seed: Optional[int] = None):
seed = self.generate_seed() if seed is None else seed
with torch.random.fork_rng():
self.check_seed()
torch.manual_seed(seed)
transformed = super().__call__(sample)
if transformed is sample:
pass # the transform was not applied
elif not isinstance(transformed, Subject):
pass # random parameters are stored in instances of Subject
else:
_, random_params_dict = transformed.history[-1]
random_params_dict['seed'] = self._seed
transformed.add_transform(self, seed)
return transformed

@staticmethod
def generate_seed():
# https://github.com/fepegar/torchio/issues/208#issuecomment-650262724
return torch.randint(1, 2**63 - 1, (1,)).item()

@staticmethod
def parse_range(
nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
Expand Down Expand Up @@ -98,9 +92,6 @@ def parse_translation(
) -> Tuple[float, float]:
return self.parse_range(translation, 'translation')

def check_seed(self) -> None:
torch.manual_seed(self._seed)

@staticmethod
def fourier_transform(array: np.ndarray):
transformed = np.fft.fftn(array)
Expand Down
10 changes: 1 addition & 9 deletions torchio/transforms/augmentation/spatial/random_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class RandomAffine(RandomTransform):
`Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
image_interpolation: See :ref:`Interpolation`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
Example:
>>> import torchio
Expand Down Expand Up @@ -81,9 +80,8 @@ def __init__(
default_pad_value: Union[str, float] = 'otsu',
image_interpolation: str = 'linear',
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.scales = scales
self.degrees = self.parse_degrees(degrees)
self.translation = self.parse_range(translation, 'translation')
Expand Down Expand Up @@ -183,12 +181,6 @@ def apply_transform(self, sample: Subject) -> dict:
interpolation,
center_lps=center,
)
random_parameters_dict = {
'scaling': scaling_params,
'rotation': rotation_params,
'translation': translation_params,
}
sample.add_transform(self, random_parameters_dict)
return sample

def apply_affine_transform(
Expand Down
6 changes: 1 addition & 5 deletions torchio/transforms/augmentation/spatial/random_downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,15 @@ class RandomDownsample(RandomTransform):
downsampling: Downsampling factor :math:`m \gt 1`. If a tuple
:math:`(a, b)` is provided then :math:`m \sim \mathcal{U}(a, b)`.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
"""

def __init__(
self,
axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
downsampling: float = (1.5, 5),
p: float = 1,
seed: Optional[int] = None,
):
super().__init__(p=p, seed=seed)
super().__init__(p=p)
self.axes = self.parse_axes(axes)
self.downsampling_range = self.parse_downsampling(downsampling)

Expand Down Expand Up @@ -65,7 +63,6 @@ def parse_axes(axes: Union[int, Tuple[int, ...]]):

def apply_transform(self, sample: Subject) -> Subject:
axis, downsampling = self.get_params(self.axes, self.downsampling_range)
random_parameters_dict = {'axis': axis, 'downsampling': downsampling}
items = sample.get_images_dict(intensity_only=False).items()

target_spacing = list(sample.spacing)
Expand All @@ -76,5 +73,4 @@ def apply_transform(self, sample: Subject) -> Subject:
copy=False, # already copied in super().__init__
)
sample = transform(sample)
sample.add_transform(self, random_parameters_dict)
return sample
Loading

0 comments on commit f3529c8

Please sign in to comment.