Skip to content

Commit

Permalink
Merge 53dba02 into 638b206
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Dec 17, 2020
2 parents 638b206 + 53dba02 commit a31b767
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 34 deletions.
11 changes: 9 additions & 2 deletions tests/transforms/augmentation/test_random_motion.py
Expand Up @@ -15,14 +15,21 @@ def test_no_movement(self):
num_transforms=1
)
transformed = transform(self.sample_subject)
self.assertTensorAlmostEqual(self.sample_subject.t1.data, transformed.t1.data)
self.assertTensorAlmostEqual(
self.sample_subject.t1.data,
transformed.t1.data,
decimal=4,
)

def test_with_movement(self):
transform = RandomMotion(
num_transforms=1
)
transformed = transform(self.sample_subject)
self.assertTensorNotEqual(self.sample_subject.t1.data, transformed.t1.data)
self.assertTensorNotEqual(
self.sample_subject.t1.data,
transformed.t1.data,
)

def test_negative_degrees(self):
with self.assertRaises(ValueError):
Expand Down
15 changes: 12 additions & 3 deletions tests/transforms/augmentation/test_random_spike.py
Expand Up @@ -7,17 +7,26 @@ class TestRandomSpike(TorchioTestCase):
def test_with_zero_intensity(self):
transform = RandomSpike(intensity=0)
transformed = transform(self.sample_subject)
self.assertTensorAlmostEqual(self.sample_subject.t1.data, transformed.t1.data)
self.assertTensorAlmostEqual(
self.sample_subject.t1.data,
transformed.t1.data,
)

def test_with_zero_spike(self):
transform = RandomSpike(num_spikes=0)
transformed = transform(self.sample_subject)
self.assertTensorAlmostEqual(self.sample_subject.t1.data, transformed.t1.data)
self.assertTensorAlmostEqual(
self.sample_subject.t1.data,
transformed.t1.data,
)

def test_with_spikes(self):
transform = RandomSpike()
transformed = transform(self.sample_subject)
self.assertTensorNotEqual(self.sample_subject.t1.data, transformed.t1.data)
self.assertTensorNotEqual(
self.sample_subject.t1.data,
transformed.t1.data,
)

def test_negative_num_spikes(self):
with self.assertRaises(ValueError):
Expand Down
18 changes: 8 additions & 10 deletions torchio/transforms/augmentation/intensity/random_ghosting.py
Expand Up @@ -175,23 +175,22 @@ def add_artifact(
if not num_ghosts or not intensity:
return tensor

array = tensor.numpy()
spectrum = self.fourier_transform(array)
spectrum = self.fourier_transform(tensor)

shape = np.array(array.shape)
shape = np.array(tensor.shape)
ri, rj, rk = np.round(restore_center * shape).astype(np.uint16)
mi, mj, mk = np.array(array.shape) // 2
mi, mj, mk = np.array(tensor.shape) // 2

# Variable "planes" is the part of the spectrum that will be modified
if axis == 0:
planes = spectrum[::num_ghosts, :, :]
restore = spectrum[mi, :, :].copy()
restore = spectrum[mi, :, :].clone()
elif axis == 1:
planes = spectrum[:, ::num_ghosts, :]
restore = spectrum[:, mj, :].copy()
restore = spectrum[:, mj, :].clone()
elif axis == 2:
planes = spectrum[:, :, ::num_ghosts]
restore = spectrum[:, :, mk].copy()
restore = spectrum[:, :, mk].clone()

# Multiply by 0 if intensity is 1
planes *= 1 - intensity
Expand All @@ -204,9 +203,8 @@ def add_artifact(
elif axis == 2:
spectrum[:, :, mk] = restore

array_ghosts = self.inv_fourier_transform(spectrum)
array_ghosts = np.real(array_ghosts).astype(np.float32)
return torch.from_numpy(array_ghosts)
tensor_ghosts = self.inv_fourier_transform(spectrum)
return tensor_ghosts.real.float()


def _parse_restore(restore):
Expand Down
14 changes: 8 additions & 6 deletions torchio/transforms/augmentation/intensity/random_motion.py
Expand Up @@ -256,17 +256,19 @@ def add_artifact(
interpolation: str,
):
images = self.resample_images(image, transforms, interpolation)
arrays = [sitk.GetArrayViewFromImage(im) for im in images]
arrays = [array.transpose() for array in arrays] # ITK to NumPy
spectra = [self.fourier_transform(array) for array in arrays]
spectra = []
for image in images:
array = sitk.GetArrayFromImage(image).transpose() # sitk to np
spectrum = self.fourier_transform(torch.from_numpy(array))
spectra.append(spectrum)
self.sort_spectra(spectra, times)
result_spectrum = np.empty_like(spectra[0])
result_spectrum = torch.empty_like(spectra[0])
last_index = result_spectrum.shape[2]
indices = (last_index * times).astype(int).tolist()
indices.append(last_index)
ini = 0
for spectrum, fin in zip(spectra, indices):
result_spectrum[..., ini:fin] = spectrum[..., ini:fin]
ini = fin
result_image = np.real(self.inv_fourier_transform(result_spectrum))
return result_image.astype(np.float32)
result_image = self.inv_fourier_transform(result_spectrum).real.float()
return result_image
12 changes: 7 additions & 5 deletions torchio/transforms/augmentation/intensity/random_spike.py
Expand Up @@ -126,15 +126,17 @@ def add_artifact(
spikes_positions: np.ndarray,
intensity_factor: float,
):
array = np.asarray(tensor)
spectrum = self.fourier_transform(array)
if intensity_factor == 0 or len(spikes_positions) == 0:
return tensor
spectrum = self.fourier_transform(tensor)
shape = np.array(spectrum.shape)
mid_shape = shape // 2
indices = np.floor(spikes_positions * shape).astype(int)
for index in indices:
diff = index - mid_shape
i, j, k = mid_shape + diff
artifact = spectrum.max() * intensity_factor
# As of torch 1.7, "max is not yet implemented for complex tensors"
artifact = spectrum.cpu().numpy().max() * intensity_factor
if self.invert_transform:
spectrum[i, j, k] -= artifact
else:
Expand All @@ -145,5 +147,5 @@ def add_artifact(
# scans. Therefore the next two lines have been removed.
# #i, j, k = mid_shape - diff
# #spectrum[i, j, k] = spectrum.max() * intensity_factor
result = np.real(self.inv_fourier_transform(spectrum))
return torch.from_numpy(result.astype(np.float32))
result = self.inv_fourier_transform(spectrum).real.float()
return result
27 changes: 19 additions & 8 deletions torchio/transforms/fourier.py
@@ -1,16 +1,27 @@
import torch
import numpy as np


class FourierTransform:

@staticmethod
def fourier_transform(array: np.ndarray) -> np.ndarray:
transformed = np.fft.fftn(array)
fshift = np.fft.fftshift(transformed)
return fshift
def fourier_transform(tensor: torch.Tensor) -> torch.Tensor:
try:
import torch.fft
return torch.fft.fftn(tensor)
except ModuleNotFoundError:
import torch
transformed = np.fft.fftn(tensor)
fshift = np.fft.fftshift(transformed)
return torch.from_numpy(fshift)

@staticmethod
def inv_fourier_transform(fshift: np.ndarray) -> np.ndarray:
f_ishift = np.fft.ifftshift(fshift)
img_back = np.fft.ifftn(f_ishift)
return img_back
def inv_fourier_transform(tensor: torch.Tensor) -> torch.Tensor:
try:
import torch.fft
return torch.fft.ifftn(tensor)
except ModuleNotFoundError:
import torch
f_ishift = np.fft.ifftshift(tensor)
img_back = np.fft.ifftn(f_ishift)
return torch.from_numpy(img_back)

0 comments on commit a31b767

Please sign in to comment.