Skip to content

Commit

Permalink
Merge e9635ff into 90612ef
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jun 24, 2020
2 parents 90612ef + e9635ff commit 8cb194e
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 115 deletions.
9 changes: 2 additions & 7 deletions tests/transforms/augmentation/test_random_motion.py
Expand Up @@ -5,10 +5,5 @@
class TestRandomMotion(TorchioTestCase):
"""Tests for `RandomMotion`."""
def test_random_motion(self):
transform = torchio.transforms.RandomMotion(
seed=42,
)
transformed = transform(self.sample)
self.sample['t2'][torchio.DATA] = self.sample['t2'][torchio.DATA] - 0.5
with self.assertWarns(UserWarning):
transformed = transform(self.sample)
with self.assertRaises(ValueError):
transform = torchio.transforms.RandomMotion(num_transforms=0)
122 changes: 49 additions & 73 deletions torchio/transforms/augmentation/intensity/random_ghosting.py
@@ -1,4 +1,3 @@
import warnings
from typing import Tuple, Optional, Union
import torch
import numpy as np
Expand All @@ -18,10 +17,14 @@ class RandomGhosting(RandomTransform):
axes: Axis along which the ghosts will be created. If
:py:attr:`axes` is a tuple, the axis will be randomly chosen
from the passed values.
intensity: Number between 0 and 1 representing the artifact strength
:math:`s`. If ``0``, the ghosts will not be visible. If a tuple
intensity: Positive number representing the artifact strength
:math:`s` with respect to the maximum of the :math:`k`-space.
If ``0``, the ghosts will not be visible. If a tuple
:math:`(a, b)`, is provided then
:math:`s \sim \mathcal{U}(a, b)`.
restore: Number between ``0`` and ``1`` indicating how much of the
:math:`k`-space center should be restored after removing the planes
that generate the artifact.
p: Probability that this transform will be applied.
seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
Expand All @@ -33,6 +36,7 @@ def __init__(
num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
intensity: Union[float, Tuple[float, float]] = (0.5, 1),
restore: float = 0.02,
p: float = 1,
seed: Optional[int] = None,
):
Expand All @@ -52,10 +56,15 @@ def __init__(
self.num_ghosts_range = num_ghosts
self.intensity_range = self.parse_range(intensity, 'intensity')
for n in self.intensity_range:
if not 0 <= n <= 1:
if n < 0:
message = (
f'Intensity must be a number between 0 and 1, not {n}')
f'Intensity must be a positive number, not {n}')
raise ValueError(message)
if not 0 <= restore < 1:
message = (
f'Restore must be a number between 0 and 1, not {restore}')
raise ValueError(message)
self.restore = restore

def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict = {}
Expand All @@ -75,31 +84,13 @@ def apply_transform(self, sample: Subject) -> dict:
'intensity': intensity_param,
}
random_parameters_images_dict[image_name] = random_parameters_dict
if (data[0] < -0.1).any():
# I use -0.1 instead of 0 because Python was warning me when
# a value in a voxel was -7.191084e-35
# There must be a better way of solving this
message = (
f'Image "{image_name}" from "{image_dict["stem"]}"'
' has negative values.'
' Results can be unexpected because the transformed sample'
' is computed as the absolute values'
' of an inverse Fourier transform'
)
warnings.warn(message)
image = self.nib_to_sitk(
image_dict[DATA][0] = self.add_artifact(
data[0],
image_dict[AFFINE],
)
data = self.add_artifact(
image,
num_ghosts_param,
axis_param,
intensity_param,
self.restore,
)
# Add channels dimension
data = data[np.newaxis, ...]
image_dict[DATA] = torch.from_numpy(data)
sample.add_transform(self, random_parameters_images_dict)
return sample

Expand All @@ -115,57 +106,42 @@ def get_params(
intensity = torch.FloatTensor(1).uniform_(*intensity_range).item()
return num_ghosts, axis, intensity

@staticmethod
def get_axis_and_size(axis, array):
if axis == 1:
axis = 0
size = array.shape[0]
elif axis == 0:
axis = 1
size = array.shape[1]
elif axis == 2: # we will also traverse in sagittal (if RAS)
size = array.shape[0]
else:
raise RuntimeError(f'Axis "{axis}" is not valid')
return axis, size

@staticmethod
def get_slice(axis, array, slice_idx):
# Comments apply if RAS
if axis == 0: # sagittal (columns) - artifact AP
image_slice = array[slice_idx, ...]
elif axis == 1: # coronal (columns) - artifact LR
image_slice = array[:, slice_idx, :]
elif axis == 2: # sagittal (rows) - artifact IS
image_slice = array[slice_idx, ...].T
else:
raise RuntimeError(f'Axis "{axis}" is not valid')
return image_slice

def add_artifact(
self,
image: sitk.Image,
tensor: torch.Tensor,
num_ghosts: int,
axis: int,
intensity: float,
restore_center: float,
):
array = sitk.GetArrayFromImage(image).transpose()
# Leave first 5% of frequencies untouched. If the image is in RAS
# orientation, this helps applying the ghosting in the desired axis
# intuitively
# [Why? I forgot]
percentage_to_avoid = 0.05
axis, size = self.get_axis_and_size(axis, array)
for slice_idx in range(size):
image_slice = self.get_slice(axis, array, slice_idx)
spectrum = self.fourier_transform(image_slice)
for row_idx, row in enumerate(spectrum):
if row_idx % num_ghosts:
continue
progress = row_idx / array.shape[0]
if np.abs(progress - 0.5) < percentage_to_avoid / 2:
continue
row *= 1 - intensity
image_slice *= 0
image_slice += self.inv_fourier_transform(spectrum)
return array
array = tensor.numpy()
spectrum = self.fourier_transform(array)

ri, rj, rk = np.round(restore_center * np.array(array.shape)).astype(np.uint16)
mi, mj, mk = np.array(array.shape) // 2

# Variable "planes" is the part the spectrum that will be modified
if axis == 0:
planes = spectrum[::num_ghosts, :, :]
restore = spectrum[mi, :, :].copy()
elif axis == 1:
planes = spectrum[:, ::num_ghosts, :]
restore = spectrum[:, mj, :].copy()
elif axis == 2:
planes = spectrum[:, :, ::num_ghosts]
restore = spectrum[:, :, mk].copy()

# Multiply by 0 if intensity is 1
planes *= 1 - intensity

# Restore the center of k-space to avoid extreme artifacts
if axis == 0:
spectrum[mi, :, :] = restore
elif axis == 1:
spectrum[:, mj, :] = restore
elif axis == 2:
spectrum[:, :, mk] = restore

array_ghosts = self.inv_fourier_transform(spectrum)
array_ghosts = np.real(array_ghosts)
return torch.from_numpy(array_ghosts)
21 changes: 7 additions & 14 deletions torchio/transforms/augmentation/intensity/random_motion.py
Expand Up @@ -7,7 +7,6 @@
"""

import warnings
from typing import Tuple, Optional, List
import torch
import numpy as np
Expand Down Expand Up @@ -61,6 +60,12 @@ def __init__(
super().__init__(p=p, seed=seed)
self.degrees_range = self.parse_degrees(degrees)
self.translation_range = self.parse_translation(translation)
if not 0 < num_transforms or not isinstance(num_transforms, int):
message = (
'Number of transforms must be a natural number,'
f' not {num_transforms}'
)
raise ValueError(message)
self.num_transforms = num_transforms
self.image_interpolation = self.parse_interpolation(image_interpolation)

Expand All @@ -82,18 +87,6 @@ def apply_transform(self, sample: Subject) -> dict:
'translation': translation_params,
}
random_parameters_images_dict[image_name] = random_parameters_dict
if (data[0] < -0.1).any():
# I use -0.1 instead of 0 because Python was warning me when
# a value in a voxel was -7.191084e-35
# There must be a better way of solving this
message = (
f'Image "{image_name}" from "{image_dict["stem"]}"'
' has negative values.'
' Results can be unexpected because the transformed sample'
' is computed as the absolute values'
' of an inverse Fourier transform'
)
warnings.warn(message)
image = self.nib_to_sitk(
data[0],
image_dict[AFFINE],
Expand Down Expand Up @@ -227,7 +220,7 @@ def add_artifact(
for spectrum, fin in zip(spectra, indices):
result_spectrum[..., ini:fin] = spectrum[..., ini:fin]
ini = fin
result_image = self.inv_fourier_transform(result_spectrum)
result_image = np.real(self.inv_fourier_transform(result_spectrum))
return result_image.astype(np.float32)


Expand Down
36 changes: 16 additions & 20 deletions torchio/transforms/augmentation/intensity/random_spike.py
@@ -1,4 +1,3 @@
import warnings
from typing import Tuple, Optional, Union
import torch
import numpy as np
Expand Down Expand Up @@ -28,7 +27,7 @@ class RandomSpike(RandomTransform):
def __init__(
self,
num_spikes: Union[int, Tuple[int, int]] = 1,
intensity: Union[float, Tuple[float, float]] = (0.1, 1),
intensity: Union[float, Tuple[float, float]] = (1, 3),
p: float = 1,
seed: Optional[int] = None,
):
Expand All @@ -53,18 +52,6 @@ def apply_transform(self, sample: Subject) -> dict:
'spikes_positions': spikes_positions_param,
}
random_parameters_images_dict[image_name] = random_parameters_dict
if (image_dict[DATA][0] < -0.1).any():
# I use -0.1 instead of 0 because Python was warning me when
# a value in a voxel was -7.191084e-35
# There must be a better way of solving this
message = (
f'Image "{image_name}" from "{image_dict["stem"]}"'
' has negative values.'
' Results can be unexpected because the transformed sample'
' is computed as the absolute values'
' of an inverse Fourier transform'
)
warnings.warn(message)
image_dict[DATA] = self.add_artifact(
image_dict.as_sitk(),
spikes_positions_param,
Expand All @@ -84,7 +71,7 @@ def get_params(
ns_min, ns_max = num_spikes_range
num_spikes_param = torch.randint(ns_min, ns_max + 1, (1,)).item()
intensity_param = torch.FloatTensor(1).uniform_(*intensity_range)
spikes_positions = torch.rand(num_spikes_param).numpy()
spikes_positions = torch.rand(num_spikes_param, 3).numpy()
return spikes_positions, intensity_param.item()

def add_artifact(
Expand All @@ -94,10 +81,19 @@ def add_artifact(
intensity_factor: float,
):
array = sitk.GetArrayViewFromImage(image).transpose()
spectrum = self.fourier_transform(array).ravel()
indices = np.floor(spikes_positions * len(spectrum)).astype(int)
spectrum = self.fourier_transform(array)
shape = np.array(spectrum.shape)
mid_shape = shape // 2
indices = np.floor(spikes_positions * shape).astype(int)
for index in indices:
spectrum[index] = spectrum.max() * intensity_factor
spectrum = spectrum.reshape(array.shape)
result = self.inv_fourier_transform(spectrum)
diff = index - mid_shape
i, j, k = mid_shape + diff
spectrum[i, j, k] = spectrum.max() * intensity_factor
# If we wanted to add a pure cosine, we should add spikes to both
# sides of k-space. However, having only one is a better
# representation og the actual cause of the artifact in real
# scans.
#i, j, k = mid_shape - diff
#spectrum[i, j, k] = spectrum.max() * intensity_factor
result = np.real(self.inv_fourier_transform(spectrum))
return result.astype(np.float32)
2 changes: 1 addition & 1 deletion torchio/transforms/augmentation/random_transform.py
Expand Up @@ -98,4 +98,4 @@ def fourier_transform(array: np.ndarray):
def inv_fourier_transform(fshift: np.ndarray):
f_ishift = np.fft.ifftshift(fshift)
img_back = np.fft.ifftn(f_ishift)
return np.abs(img_back)
return img_back

0 comments on commit 8cb194e

Please sign in to comment.