From 4b1d85d93bbeee30b93a80a3b1d861492f5b9184 Mon Sep 17 00:00:00 2001 From: Fernando Date: Thu, 17 Dec 2020 11:36:32 +0000 Subject: [PATCH 1/2] Use PyTorch for FFT if available --- .../augmentation/test_random_motion.py | 11 ++++++-- .../augmentation/test_random_spike.py | 15 ++++++++--- .../augmentation/intensity/random_ghosting.py | 18 ++++++------- .../augmentation/intensity/random_motion.py | 14 +++++----- .../augmentation/intensity/random_spike.py | 12 +++++---- torchio/transforms/fourier.py | 27 +++++++++++++------ 6 files changed, 63 insertions(+), 34 deletions(-) diff --git a/tests/transforms/augmentation/test_random_motion.py b/tests/transforms/augmentation/test_random_motion.py index 163674ed6..9ce4a79b8 100644 --- a/tests/transforms/augmentation/test_random_motion.py +++ b/tests/transforms/augmentation/test_random_motion.py @@ -15,14 +15,21 @@ def test_no_movement(self): num_transforms=1 ) transformed = transform(self.sample_subject) - self.assertTensorAlmostEqual(self.sample_subject.t1.data, transformed.t1.data) + self.assertTensorAlmostEqual( + self.sample_subject.t1.data, + transformed.t1.data, + decimal=4, + ) def test_with_movement(self): transform = RandomMotion( num_transforms=1 ) transformed = transform(self.sample_subject) - self.assertTensorNotEqual(self.sample_subject.t1.data, transformed.t1.data) + self.assertTensorNotEqual( + self.sample_subject.t1.data, + transformed.t1.data, + ) def test_negative_degrees(self): with self.assertRaises(ValueError): diff --git a/tests/transforms/augmentation/test_random_spike.py b/tests/transforms/augmentation/test_random_spike.py index 04ed40b1e..3c335b200 100644 --- a/tests/transforms/augmentation/test_random_spike.py +++ b/tests/transforms/augmentation/test_random_spike.py @@ -7,17 +7,26 @@ class TestRandomSpike(TorchioTestCase): def test_with_zero_intensity(self): transform = RandomSpike(intensity=0) transformed = transform(self.sample_subject) - self.assertTensorAlmostEqual(self.sample_subject.t1.data, transformed.t1.data) + self.assertTensorAlmostEqual( + self.sample_subject.t1.data, + transformed.t1.data, + ) def test_with_zero_spike(self): transform = RandomSpike(num_spikes=0) transformed = transform(self.sample_subject) - self.assertTensorAlmostEqual(self.sample_subject.t1.data, transformed.t1.data) + self.assertTensorAlmostEqual( + self.sample_subject.t1.data, + transformed.t1.data, + ) def test_with_spikes(self): transform = RandomSpike() transformed = transform(self.sample_subject) - self.assertTensorNotEqual(self.sample_subject.t1.data, transformed.t1.data) + self.assertTensorNotEqual( + self.sample_subject.t1.data, + transformed.t1.data, + ) def test_negative_num_spikes(self): with self.assertRaises(ValueError): diff --git a/torchio/transforms/augmentation/intensity/random_ghosting.py b/torchio/transforms/augmentation/intensity/random_ghosting.py index 782fbed63..3ad8aa037 100644 --- a/torchio/transforms/augmentation/intensity/random_ghosting.py +++ b/torchio/transforms/augmentation/intensity/random_ghosting.py @@ -175,23 +175,22 @@ def add_artifact( if not num_ghosts or not intensity: return tensor - array = tensor.numpy() - spectrum = self.fourier_transform(array) + spectrum = self.fourier_transform(tensor) - shape = np.array(array.shape) + shape = np.array(tensor.shape) ri, rj, rk = np.round(restore_center * shape).astype(np.uint16) - mi, mj, mk = np.array(array.shape) // 2 + mi, mj, mk = np.array(tensor.shape) // 2 # Variable "planes" is the part of the spectrum that will be modified if axis == 0: planes = spectrum[::num_ghosts, :, :] - restore = spectrum[mi, :, :].copy() + restore = spectrum[mi, :, :].clone() elif axis == 1: planes = spectrum[:, ::num_ghosts, :] - restore = spectrum[:, mj, :].copy() + restore = spectrum[:, mj, :].clone() elif axis == 2: planes = spectrum[:, :, ::num_ghosts] - restore = spectrum[:, :, mk].copy() + restore = spectrum[:, :, mk].clone() # Multiply by 0 if intensity is 1 planes *= 1 - intensity @@ -204,9 +203,8 @@ def add_artifact( elif axis == 2: spectrum[:, :, mk] = restore - array_ghosts = self.inv_fourier_transform(spectrum) - array_ghosts = np.real(array_ghosts).astype(np.float32) - return torch.from_numpy(array_ghosts) + tensor_ghosts = self.inv_fourier_transform(spectrum) + return tensor_ghosts.real.float() def _parse_restore(restore): diff --git a/torchio/transforms/augmentation/intensity/random_motion.py b/torchio/transforms/augmentation/intensity/random_motion.py index 4e2b392bf..b274e5b99 100644 --- a/torchio/transforms/augmentation/intensity/random_motion.py +++ b/torchio/transforms/augmentation/intensity/random_motion.py @@ -256,11 +256,13 @@ def add_artifact( interpolation: str, ): images = self.resample_images(image, transforms, interpolation) - arrays = [sitk.GetArrayViewFromImage(im) for im in images] - arrays = [array.transpose() for array in arrays] # ITK to NumPy - spectra = [self.fourier_transform(array) for array in arrays] + spectra = [] + for image in images: + array = sitk.GetArrayFromImage(image).transpose() # sitk to np + spectrum = self.fourier_transform(torch.from_numpy(array)) + spectra.append(spectrum) self.sort_spectra(spectra, times) - result_spectrum = np.empty_like(spectra[0]) + result_spectrum = torch.empty_like(spectra[0]) last_index = result_spectrum.shape[2] indices = (last_index * times).astype(int).tolist() indices.append(last_index) @@ -268,5 +270,5 @@ def add_artifact( for spectrum, fin in zip(spectra, indices): result_spectrum[..., ini:fin] = spectrum[..., ini:fin] ini = fin - result_image = np.real(self.inv_fourier_transform(result_spectrum)) - return result_image.astype(np.float32) + result_image = self.inv_fourier_transform(result_spectrum).real.float() + return result_image diff --git a/torchio/transforms/augmentation/intensity/random_spike.py b/torchio/transforms/augmentation/intensity/random_spike.py index 81607b92f..2dbe60195 100644 --- a/torchio/transforms/augmentation/intensity/random_spike.py +++ b/torchio/transforms/augmentation/intensity/random_spike.py @@ -126,15 +126,17 @@ def add_artifact( spikes_positions: np.ndarray, intensity_factor: float, ): - array = np.asarray(tensor) - spectrum = self.fourier_transform(array) + if intensity_factor == 0 or len(spikes_positions) == 0: + return tensor + spectrum = self.fourier_transform(tensor) shape = np.array(spectrum.shape) mid_shape = shape // 2 indices = np.floor(spikes_positions * shape).astype(int) for index in indices: diff = index - mid_shape i, j, k = mid_shape + diff - artifact = spectrum.max() * intensity_factor + # As of torch 1.7, "max is not yet implemented for complex tensors" + artifact = spectrum.numpy().max() * intensity_factor if self.invert_transform: spectrum[i, j, k] -= artifact else: @@ -145,5 +147,5 @@ def add_artifact( # scans. Therefore the next two lines have been removed. # #i, j, k = mid_shape - diff # #spectrum[i, j, k] = spectrum.max() * intensity_factor - result = np.real(self.inv_fourier_transform(spectrum)) - return torch.from_numpy(result.astype(np.float32)) + result = self.inv_fourier_transform(spectrum).real.float() + return result diff --git a/torchio/transforms/fourier.py b/torchio/transforms/fourier.py index 0dafd3929..00ce3ae77 100644 --- a/torchio/transforms/fourier.py +++ b/torchio/transforms/fourier.py @@ -1,16 +1,27 @@ +import torch import numpy as np class FourierTransform: @staticmethod - def fourier_transform(array: np.ndarray) -> np.ndarray: - transformed = np.fft.fftn(array) - fshift = np.fft.fftshift(transformed) - return fshift + def fourier_transform(tensor: torch.Tensor) -> torch.Tensor: + try: + import torch.fft + return torch.fft.fftn(tensor) + except ModuleNotFoundError: + import torch + transformed = np.fft.fftn(tensor) + fshift = np.fft.fftshift(transformed) + return torch.from_numpy(fshift) @staticmethod - def inv_fourier_transform(fshift: np.ndarray) -> np.ndarray: - f_ishift = np.fft.ifftshift(fshift) - img_back = np.fft.ifftn(f_ishift) - return img_back + def inv_fourier_transform(tensor: torch.Tensor) -> torch.Tensor: + try: + import torch.fft + return torch.fft.ifftn(tensor) + except ModuleNotFoundError: + import torch + f_ishift = np.fft.ifftshift(tensor) + img_back = np.fft.ifftn(f_ishift) + return torch.from_numpy(img_back) From 53dba02777b445f4bd9488c5e27fd1105bbf3d30 Mon Sep 17 00:00:00 2001 From: Fernando Perez-Garcia Date: Thu, 17 Dec 2020 12:42:26 +0500 Subject: [PATCH 2/2] Make sure tensor in on CPU before calling numpy() --- torchio/transforms/augmentation/intensity/random_spike.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchio/transforms/augmentation/intensity/random_spike.py b/torchio/transforms/augmentation/intensity/random_spike.py index 2dbe60195..dfb7f2545 100644 --- a/torchio/transforms/augmentation/intensity/random_spike.py +++ b/torchio/transforms/augmentation/intensity/random_spike.py @@ -136,7 +136,7 @@ def add_artifact( diff = index - mid_shape i, j, k = mid_shape + diff # As of torch 1.7, "max is not yet implemented for complex tensors" - artifact = spectrum.numpy().max() * intensity_factor + artifact = spectrum.cpu().numpy().max() * intensity_factor if self.invert_transform: spectrum[i, j, k] -= artifact else: