Skip to content

Commit

Permalink
Avoid using SimpleITK in some transforms (#334)
Browse files Browse the repository at this point in the history
* Use PyTorch to crop

* Use Crop transform to create patches

* Remove typing hint

* Use NumPy to pad

* Use scipy to blur

* Fix kwarg name and default values
  • Loading branch information
fepegar committed Oct 16, 2020
1 parent b5aa0ad commit 281b2ea
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 166 deletions.
22 changes: 0 additions & 22 deletions tests/data/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ def test_tensor_affine(self):
sample_input = torch.ones((4, 10, 10, 10))
RandomAffine()(sample_input)

def test_crop_attributes(self):
cropped = self.sample.crop((1, 1, 1), (5, 5, 5))
self.assertIs(self.sample.t1['pre_affine'], cropped.t1['pre_affine'])

def test_crop_does_not_create_wrong_path(self):
data = torch.ones((1, 10, 10, 10))
image = ScalarImage(tensor=data)
cropped = image.crop((1, 1, 1), (5, 5, 5))
self.assertIs(cropped.path, None)

def test_scalar_image_type(self):
data = torch.ones((1, 10, 10, 10))
image = ScalarImage(tensor=data)
Expand All @@ -68,18 +58,6 @@ def test_wrong_label_map_type(self):
with self.assertRaises(ValueError):
LabelMap(tensor=data, type=INTENSITY)

def test_crop_scalar_image_type(self):
data = torch.ones((1, 10, 10, 10))
image = ScalarImage(tensor=data)
cropped = image.crop((1, 1, 1), (5, 5, 5))
self.assertIs(cropped.type, INTENSITY)

def test_crop_label_map_type(self):
data = torch.ones((1, 10, 10, 10))
label = LabelMap(tensor=data)
cropped = label.crop((1, 1, 1), (5, 5, 5))
self.assertIs(cropped.type, LABEL)

def test_no_input(self):
with self.assertRaises(ValueError):
ScalarImage()
Expand Down
20 changes: 20 additions & 0 deletions tests/transforms/preprocessing/test_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
import SimpleITK as sitk
from torchio.utils import sitk_to_nib
from torchio.transforms import Pad
from ...utils import TorchioTestCase


class TestPad(TorchioTestCase):
"""Tests for `Pad`."""
def test_pad(self):
image = self.sample.t1
padding = 1, 2, 3, 4, 5, 6
sitk_image = image.as_sitk()
low, high = padding[::2], padding[1::2]
sitk_padded = sitk.ConstantPad(sitk_image, low, high, 0)
tio_padded = Pad(padding, padding_mode=0)(image)
sitk_tensor, sitk_affine = sitk_to_nib(sitk_padded)
tio_tensor, tio_affine = sitk_to_nib(tio_padded.as_sitk())
self.assertTensorEqual(sitk_tensor, tio_tensor)
self.assertTensorEqual(sitk_affine, tio_affine)
2 changes: 1 addition & 1 deletion torchio/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
def __len__(self):
return len(self.subjects)

def __getitem__(self, index: int) -> dict:
def __getitem__(self, index: int) -> Subject:
if not isinstance(index, int):
raise ValueError(f'Index "{index}" must be int, not {type(index)}')
subject = self.subjects[index]
Expand Down
18 changes: 0 additions & 18 deletions torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,24 +472,6 @@ def plot(self, **kwargs) -> None:
from ..visualization import plot_volume # avoid circular import
plot_volume(self, **kwargs)

def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
new_origin = nib.affines.apply_affine(self.affine, index_ini)
new_affine = self.affine.copy()
new_affine[:3, 3] = new_origin
i0, j0, k0 = index_ini
i1, j1, k1 = index_fin
patch = self.data[:, i0:i1, j0:j1, k0:k1].clone()
kwargs = dict(
tensor=patch,
affine=new_affine,
type=self.type,
path=self.path,
)
for key, value in self.items():
if key in PROTECTED_KEYS: continue
kwargs[key] = value # should I copy? deepcopy?
return self.__class__(**kwargs)


class ScalarImage(Image):
"""Alias for :py:class:`~torchio.Image` of type :py:attr:`torchio.INTENSITY`.
Expand Down
4 changes: 2 additions & 2 deletions torchio/data/inference/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class GridAggregator:
information about patch-based sampling.
"""
def __init__(self, sampler: GridSampler, overlap_mode: str = 'crop'):
sample = sampler.sample
subject = sampler.subject
self.volume_padded = sampler.padding_mode is not None
self.spatial_shape = sample.spatial_shape
self.spatial_shape = subject.spatial_shape
self._output_tensor = None
self.patch_overlap = sampler.patch_overlap
self.parse_overlap_mode(overlap_mode)
Expand Down
13 changes: 6 additions & 7 deletions torchio/data/inference/grid_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ def __init__(
patch_overlap: TypeTuple = (0, 0, 0),
padding_mode: Union[str, float, None] = None,
):
self.sample = sample
self.subject = sample
self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
self.padding_mode = padding_mode
if padding_mode is not None:
from ...transforms import Pad
border = self.patch_overlap // 2
padding = border.repeat(2)
pad = Pad(padding, padding_mode=padding_mode)
self.sample = pad(self.sample)
self.subject = pad(self.subject)
PatchSampler.__init__(self, patch_size)
sizes = self.sample.spatial_shape, self.patch_size, self.patch_overlap
sizes = self.subject.spatial_shape, self.patch_size, self.patch_overlap
self.parse_sizes(*sizes)
self.locations = self.get_patches_locations(*sizes)

Expand All @@ -68,10 +68,9 @@ def __getitem__(self, index):
# Assume 3D
location = self.locations[index]
index_ini = location[:3]
index_fin = location[3:]
cropped_sample = self.sample.crop(index_ini, index_fin)
cropped_sample[LOCATION] = location
return cropped_sample
cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
cropped_subject[LOCATION] = location
return cropped_subject

@staticmethod
def parse_sizes(
Expand Down
36 changes: 30 additions & 6 deletions torchio/data/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,38 @@ def __init__(self, patch_size: TypePatchSize):

def extract_patch(
self,
sample: Subject,
subject: Subject,
index_ini: TypeTripletInt,
) -> Subject:
cropped_subject = self.crop(subject, index_ini, self.patch_size)
cropped_subject['index_ini'] = np.array(index_ini).astype(int)
return cropped_subject

def crop(
self,
subject: Subject,
index_ini: TypeTripletInt,
patch_size: TypeTripletInt,
) -> Subject:
index_ini = np.array(index_ini)
index_fin = index_ini + self.patch_size
cropped_sample = sample.crop(index_ini, index_fin)
cropped_sample['index_ini'] = index_ini.astype(int)
return cropped_sample
transform = self.get_crop_transform(subject, index_ini, patch_size)
return transform(subject)

@staticmethod
def get_crop_transform(
subject,
index_ini,
patch_size: TypePatchSize,
):
from ...transforms.preprocessing.spatial.crop import Crop
shape = np.array(subject.spatial_shape, dtype=np.uint16)
index_ini = np.array(index_ini, dtype=np.uint16)
patch_size = np.array(patch_size, dtype=np.uint16)
index_fin = index_ini + patch_size
crop_ini = index_ini.tolist()
crop_fin = (shape - index_fin).tolist()
start = ()
cropping = sum(zip(crop_ini, crop_fin), start)
return Crop(cropping)


class RandomSampler(PatchSampler):
Expand Down
9 changes: 4 additions & 5 deletions torchio/data/sampler/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,14 @@ def get_cumulative_distribution_function(

def extract_patch(
self,
sample: Subject,
subject: Subject,
probability_map: np.ndarray,
cdf: np.ndarray
) -> Subject:
index_ini = self.get_random_index_ini(probability_map, cdf)
index_fin = index_ini + self.patch_size
cropped_sample = sample.crop(index_ini, index_fin)
cropped_sample['index_ini'] = index_ini.astype(int)
return cropped_sample
cropped_subject = self.crop(subject, index_ini, self.patch_size)
cropped_subject['index_ini'] = index_ini.astype(int)
return cropped_subject

def get_random_index_ini(
self,
Expand Down
14 changes: 0 additions & 14 deletions torchio/data/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,6 @@ def load(self):
for image in self.get_images(intensity_only=False):
image.load()

def crop(self, index_ini, index_fin):
"""Make a copy of the subject with a reduced field of view (patch)."""
result_dict = {}
for key, value in self.items():
if isinstance(value, Image):
# patch.clone() is much faster than copy.deepcopy(patch)
value = value.crop(index_ini, index_fin)
else:
value = copy.deepcopy(value)
result_dict[key] = value
new = Subject(result_dict)
new.history = self.history
return new

def update_attributes(self):
# This allows to get images using attribute notation, e.g. subject.t1
self.__dict__.update(self)
Expand Down
1 change: 1 addition & 0 deletions torchio/torchio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TypeNumber = Union[int, float]
TypeData = Union[torch.Tensor, np.ndarray]
TypeTripletInt = Tuple[int, int, int]
TypeSextetInt = Tuple[int, int, int, int, int, int]
TypeTripletFloat = Tuple[float, float, float]
TypeTuple = Union[int, TypeTripletInt]
TypeRangeInt = Union[int, Tuple[int, int]]
Expand Down
21 changes: 12 additions & 9 deletions torchio/transforms/augmentation/intensity/random_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch
import numpy as np
import SimpleITK as sitk
from ....utils import nib_to_sitk, sitk_to_nib
from ....torchio import DATA, AFFINE, TypeData
import scipy.ndimage as ndi
from ....torchio import DATA, AFFINE, TypeData, TypeTripletFloat
from ....data.subject import Subject
from ... import IntensityTransform
from .. import RandomTransform
Expand All @@ -25,7 +25,7 @@ class RandomBlur(RandomTransform, IntensityTransform):
"""
def __init__(
self,
std: Union[float, Tuple[float, float]] = (0, 4),
std: Union[float, Tuple[float, float]] = (0, 2),
p: float = 1,
seed: Optional[int] = None,
keys: Optional[List[str]] = None,
Expand All @@ -44,7 +44,7 @@ def apply_transform(self, sample: Subject) -> dict:
random_parameters_images_dict[key] = random_parameters_dict
transformed_tensor = blur(
tensor,
image[AFFINE],
image.spacing,
std,
)
transformed_tensors.append(transformed_tensor)
Expand All @@ -58,10 +58,13 @@ def get_params(std_range: Tuple[float, float]) -> np.ndarray:
return std


def blur(data: TypeData, affine: TypeData, std: np.ndarray) -> torch.Tensor:
def blur(
data: TypeData,
spacing: TypeTripletFloat,
std_voxel: np.ndarray,
) -> torch.Tensor:
assert data.ndim == 3
image = nib_to_sitk(data[np.newaxis], affine)
image = sitk.DiscreteGaussian(image, std.tolist())
array, _ = sitk_to_nib(image)
tensor = torch.from_numpy(array[0])
std_physical = np.array(std_voxel) * np.array(spacing)
blurred = ndi.gaussian_filter(data, std_physical)
tensor = torch.from_numpy(blurred)
return tensor
31 changes: 1 addition & 30 deletions torchio/transforms/preprocessing/spatial/bounds_transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Union, Tuple, List, Optional
import torch
import numpy as np
import SimpleITK as sitk
from ....data.subject import Subject
from ....torchio import DATA, AFFINE, TypeTripletInt
from ... import SpatialTransform
Expand Down Expand Up @@ -38,7 +37,7 @@ def bounds_function(self):
raise NotImplementedError

@staticmethod
def parse_bounds(bounds_parameters: TypeBounds) -> Tuple[int, ...]:
def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
try:
bounds_parameters = tuple(bounds_parameters)
except TypeError:
Expand All @@ -65,31 +64,3 @@ def parse_bounds(bounds_parameters: TypeBounds) -> Tuple[int, ...]:
f' 3 or 6 integers, not {bounds_parameters}'
)
raise ValueError(message)

def apply_transform(self, sample: Subject) -> dict:
low = self.bounds_parameters[::2]
high = self.bounds_parameters[1::2]
for image in self.get_images(sample):
itk_image = image.as_sitk()
result = self._apply_bounds_function(itk_image, low, high)
data, affine = self.sitk_to_nib(result)
tensor = torch.from_numpy(data)
image[DATA] = tensor
image[AFFINE] = affine
return sample

def _apply_bounds_function(self, image, low, high):
num_components = image.GetNumberOfComponentsPerPixel()
if self.bounds_function == sitk.Crop or num_components == 1:
result = self.bounds_function(image, low, high)
else: # padding not supported for vector images
components = [
sitk.VectorIndexSelectionCast(image, i)
for i in range(num_components)
]
components_padded = [
self.bounds_function(component, low, high)
for component in components
]
result = sitk.Compose(components_padded)
return result
22 changes: 16 additions & 6 deletions torchio/transforms/preprocessing/spatial/crop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable
import SimpleITK as sitk
import numpy as np
import nibabel as nib
from ....torchio import TypeTripletInt, DATA
from .bounds_transform import BoundsTransform


Expand All @@ -21,8 +22,17 @@ class Crop(BoundsTransform):
If only one value :math:`n` is provided, then
:math:`w_{ini} = w_{fin} = h_{ini} = h_{fin}
= d_{ini} = d_{fin} = n`.
"""
@property
def bounds_function(self) -> Callable:
return sitk.Crop
def apply_transform(self, sample) -> dict:
low = self.bounds_parameters[::2]
high = self.bounds_parameters[1::2]
index_ini = low
index_fin = np.array(sample.spatial_shape) - high
for image in self.get_images(sample):
new_origin = nib.affines.apply_affine(image.affine, index_ini)
new_affine = image.affine.copy()
new_affine[:3, 3] = new_origin
i0, j0, k0 = index_ini
i1, j1, k1 = index_fin
image[DATA] = image[DATA][:, i0:i1, j0:j1, k0:k1].clone()
return sample
Loading

0 comments on commit 281b2ea

Please sign in to comment.