Skip to content

Commit

Permalink
Merge 27caf58 into 53ab14d
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Aug 17, 2020
2 parents 53ab14d + 27caf58 commit f096297
Show file tree
Hide file tree
Showing 22 changed files with 182 additions and 137 deletions.
27 changes: 18 additions & 9 deletions tests/data/test_image.py
Expand Up @@ -112,7 +112,7 @@ def test_nans_tensor(self):
def test_nans_file(self):
image = ScalarImage(self.get_image_path('repr_test', add_nans=True))
with self.assertWarns(UserWarning):
image._load()
image.load()

def test_get_center(self):
tensor = torch.rand(1, 3, 3, 3)
Expand All @@ -139,19 +139,28 @@ def test_with_a_list_of_images_with_different_shapes(self):
path2 = self.get_image_path('path2', shape=(7, 5, 5))
image = ScalarImage(path=[path1, path2])
with self.assertRaises(RuntimeError):
image._load()
image.load()

def test_with_a_list_of_images_with_different_affines(self):
path1 = self.get_image_path('path1', spacing=(1, 1, 1))
path2 = self.get_image_path('path2', spacing=(1, 2, 1))
image = ScalarImage(path=[path1, path2])
with self.assertWarns(RuntimeWarning):
image._load()
image.load()

def test_with_a_list_of_2d_paths(self):
shape = (5, 5)
path1 = self.get_image_path('path1', shape=shape)
path2 = self.get_image_path('path2', shape=shape)
image = ScalarImage(path=[path1, path2])
self.assertEqual(image.shape, (2, 1, 5, 5))
self.assertEqual(image[STEM], ['path1', 'path2'])
shape = (5, 6)
path1 = self.get_image_path('path1', shape=shape, suffix='.nii')
path2 = self.get_image_path('path2', shape=shape, suffix='.img')
path3 = self.get_image_path('path3', shape=shape, suffix='.hdr')
image = ScalarImage(path=[path1, path2, path3])
self.assertEqual(image.shape, (3, 5, 6, 1))
self.assertEqual(image[STEM], ['path1', 'path2', 'path3'])

def test_axis_name_2d(self):
path = self.get_image_path('im2d', shape=(5, 6))
image = ScalarImage(path)
height_idx = image.axis_name_to_index('h')
width_idx = image.axis_name_to_index('w')
self.assertEqual(image.height, image.shape[height_idx])
self.assertEqual(image.width, image.shape[width_idx])
16 changes: 8 additions & 8 deletions tests/test_utils.py
Expand Up @@ -53,7 +53,7 @@ def test_apply_transform_to_file(self):
)

def test_sitk_to_nib(self):
data = np.random.rand(10, 10)
data = np.random.rand(10, 12)
image = sitk.GetImageFromArray(data)
tensor, affine = sitk_to_nib(image)
self.assertAlmostEqual(data.sum(), tensor.sum())
Expand All @@ -64,36 +64,36 @@ def setUp(self):
super().setUp()
self.affine = np.eye(4)

def test_wrong_dims(self):
def test_wrong_num_dims(self):
with self.assertRaises(ValueError):
nib_to_sitk(np.random.rand(10, 10), self.affine)

def test_2d_single(self):
data = np.random.rand(1, 1, 10, 12)
data = np.random.rand(1, 10, 12, 1)
image = nib_to_sitk(data, self.affine)
assert image.GetDimension() == 2
assert image.GetSize() == (10, 12)
assert image.GetNumberOfComponentsPerPixel() == 1

def test_2d_multi(self):
data = np.random.rand(5, 1, 10, 12)
data = np.random.rand(5, 10, 12, 1)
image = nib_to_sitk(data, self.affine)
assert image.GetDimension() == 2
assert image.GetSize() == (10, 12)
assert image.GetNumberOfComponentsPerPixel() == 5

def test_2d_3d_single(self):
data = np.random.rand(1, 1, 10, 12)
data = np.random.rand(1, 10, 12, 1)
image = nib_to_sitk(data, self.affine, force_3d=True)
assert image.GetDimension() == 3
assert image.GetSize() == (1, 10, 12)
assert image.GetSize() == (10, 12, 1)
assert image.GetNumberOfComponentsPerPixel() == 1

def test_2d_3d_multi(self):
data = np.random.rand(5, 1, 10, 12)
data = np.random.rand(5, 10, 12, 1)
image = nib_to_sitk(data, self.affine, force_3d=True)
assert image.GetDimension() == 3
assert image.GetSize() == (1, 10, 12)
assert image.GetSize() == (10, 12, 1)
assert image.GetNumberOfComponentsPerPixel() == 5

def test_3d_single(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/transforms/test_transforms.py
Expand Up @@ -14,11 +14,11 @@ def get_transform(self, channels, is_3d=True, labels=True):
landmarks_dict = {
channel: np.linspace(0, 100, 13) for channel in channels
}
disp = 1 if is_3d else (0.01, 1, 1)
disp = 1 if is_3d else (1, 1, 0.01)
elastic = torchio.RandomElasticDeformation(max_displacement=disp)
cp_args = (9, 21, 30) if is_3d else (1, 21, 30)
cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
flip_axes = (0, 1, 2) if is_3d else (0, 1)
swap_patch = (2, 3, 4) if is_3d else (1, 3, 4)
swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
transforms = [
Expand Down
9 changes: 6 additions & 3 deletions tests/utils.py
Expand Up @@ -63,7 +63,7 @@ def setUp(self):
def make_2d(self, sample):
sample = copy.deepcopy(sample)
for image in sample.get_images(intensity_only=False):
image[DATA] = image[DATA][:, 0:1, ...]
image[DATA] = image[DATA][..., :1]
return sample

def make_4d(self, sample):
Expand Down Expand Up @@ -135,15 +135,18 @@ def get_image_path(
shape=(10, 20, 30),
spacing=(1, 1, 1),
components=1,
add_nans=False
add_nans=False,
suffix=None,
):
shape = (*shape, 1) if len(shape) == 2 else shape
data = np.random.rand(components, *shape)
if binary:
data = (data > 0.5).astype(np.uint8)
if add_nans:
data[:] = np.nan
affine = np.diag((*spacing, 1))
suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img'))
if suffix is None:
suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img'))
path = self.dir / f'{stem}{suffix}'
if np.random.rand() > 0.5:
path = str(path)
Expand Down
39 changes: 30 additions & 9 deletions torchio/data/image.py
Expand Up @@ -61,7 +61,7 @@ class Image(dict):
:py:class:`~torchio.data.sampler.weighted.WeightedSampler`.
tensor: If :py:attr:`path` is not given, :attr:`tensor` must be a 4D
:py:class:`torch.Tensor` or NumPy array with dimensions
:math:`(C, D, H, W)`. If it is not 4D, TorchIO will try to guess
:math:`(C, H, W, D)`. If it is not 4D, TorchIO will try to guess
the dimensions meanings. If 2D, the shape will be interpreted as
:math:`(H, W)`. If 3D, the number of spatial dimensions should be
determined in :attr:`num_spatial_dims`. If :attr:`num_spatial_dims`
Expand Down Expand Up @@ -174,7 +174,7 @@ def __repr__(self):
def __getitem__(self, item):
if item in (DATA, AFFINE):
if item not in self:
self._load()
self.load()
return super().__getitem__(item)

def __array__(self):
Expand Down Expand Up @@ -217,6 +217,21 @@ def shape(self) -> Tuple[int, int, int, int]:
def spatial_shape(self) -> TypeTripletInt:
return self.shape[1:]

def check_is_2d(self):
if not self.is_2d():
message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
raise RuntimeError(message)

@property
def height(self) -> int:
self.check_is_2d()
return self.spatial_shape[0]

@property
def width(self) -> int:
self.check_is_2d()
return self.spatial_shape[1]

@property
def orientation(self):
return nib.aff2axcodes(self.affine)
Expand Down Expand Up @@ -248,11 +263,11 @@ def axis_name_to_index(self, axis: str):
raise ValueError('Axis must be a string')
axis = axis[0].upper()

# Generally, TorchIO tensors are (C, D, H, W)
# Generally, TorchIO tensors are (C, H, W, D)
if axis == 'H':
return -2
return 1
elif axis == 'W':
return -1
return 2
else:
try:
index = self.orientation.index(axis)
Expand Down Expand Up @@ -350,11 +365,11 @@ def parse_affine(affine: np.ndarray) -> np.ndarray:
raise ValueError(f'Affine shape must be (4, 4), not {affine.shape}')
return affine

def _load(self) -> None:
def load(self) -> None:
r"""Load the image from disk.
Returns:
Tuple containing a 4D tensor of size :math:`(C, D, H, W)` and a 2D
Tuple containing a 4D tensor of size :math:`(C, H, W, D)` and a 2D
:math:`4 \times 4` affine matrix to convert voxel indices to world
coordinates.
"""
Expand All @@ -377,7 +392,13 @@ def _load(self) -> None:
warnings.warn(f'NaNs found in file "{path}"')

if not np.array_equal(affine, new_affine):
message = 'Files have different affine matrices'
message = (
'Files have different affine matrices.'
f'\nMatrix of {paths[0]}:'
f'\n{affine}'
f'\nMatrix of {path}:'
f'\n{new_affine}'
)
warnings.warn(message, RuntimeWarning)

if not tensor.shape[1:] == new_tensor.shape[1:]:
Expand Down Expand Up @@ -414,7 +435,7 @@ def save(self, path, squeeze=True, channels_last=True):
)

def is_2d(self) -> bool:
return self.shape[-3] == 1
return self.shape[-1] == 1

def numpy(self) -> np.ndarray:
"""Get a NumPy array containing the image data."""
Expand Down
10 changes: 5 additions & 5 deletions torchio/data/inference/grid_sampler.py
Expand Up @@ -18,18 +18,18 @@ class GridSampler(PatchSampler, Dataset):
Args:
sample: Instance of :py:class:`~torchio.data.subject.Subject`
from which patches will be extracted.
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
patch_size: Tuple of integers :math:`(h, w, d)` to generate patches
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided,
:math:`d = h = w = n`.
patch_overlap: Tuple of even integers :math:`(d_o, h_o, w_o)` specifying
:math:`h = w = d = n`.
patch_overlap: Tuple of even integers :math:`(h_o, w_o, d_o)` specifying
the overlap between patches for dense inference. If a single number
:math:`n` is provided, :math:`d_o = h_o = w_o = n`.
:math:`n` is provided, :math:`h_o = w_o = d_o = n`.
padding_mode: Same as :attr:`padding_mode` in
:py:class:`~torchio.transforms.Pad`. If ``None``, the volume will
not be padded before sampling and patches at the border will not be
cropped by the aggregator. Otherwise, the volume will be padded with
:math:`\left(\frac{d_o}{2}, \frac{h_o}{2}, \frac{w_o}{2}\right)`
:math:`\left(\frac{h_o}{2}, \frac{w_o}{2}, \frac{d_o}{2} \right)`
on each side before sampling. If the sampler is passed to a
:py:class:`~torchio.data.GridAggregator`, it will crop the output
to its original size.
Expand Down
24 changes: 19 additions & 5 deletions torchio/data/io.py
Expand Up @@ -37,11 +37,16 @@ def _read_nibabel(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
return tensor, affine


def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
def _read_sitk(
path: TypePath,
transpose_2d: bool = True,
) -> Tuple[torch.Tensor, np.ndarray]:
if Path(path).is_dir(): # assume DICOM
image = _read_dicom(path)
else:
image = sitk.ReadImage(str(path))
if image.GetDimension() == 2 and transpose_2d:
image = sitk.PermuteAxes(image, (1, 0))
data, affine = sitk_to_nib(image, keepdim=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
Expand Down Expand Up @@ -103,10 +108,16 @@ def _write_nibabel(
if channels_last:
tensor = tensor.permute(1, 2, 3, 0)
tensor = tensor.squeeze() if squeeze else tensor
nii = nib.Nifti1Image(np.asarray(tensor), affine)
nii.header['qform_code'] = 1
nii.header['sform_code'] = 0
nii.to_filename(str(path))
suffix = Path(str(path).replace('.gz', '')).suffix
if '.nii' in suffix:
img = nib.Nifti1Image(np.asarray(tensor), affine)
elif '.hdr' in suffix or '.img' in suffix:
img = nib.Nifti1Pair(np.asarray(tensor), affine)
else:
raise nib.loadsave.ImageFileError
img.header['qform_code'] = 1
img.header['sform_code'] = 0
img.to_filename(str(path))


def _write_sitk(
Expand All @@ -115,13 +126,16 @@ def _write_sitk(
path: TypePath,
squeeze: bool = True,
use_compression: bool = True,
transpose_2d: bool = True,
) -> None:
assert tensor.ndim == 4
path = Path(path)
if path.suffix in ('.png', '.jpg', '.jpeg'):
warnings.warn(f'Casting to uint 8 before saving to {path}')
tensor = tensor.numpy().astype(np.uint8)
image = nib_to_sitk(tensor, affine, squeeze=squeeze)
if image.GetDimension() == 2 and transpose_2d:
image = sitk.PermuteAxes(image, (1, 0))
sitk.WriteImage(image, str(path), use_compression)


Expand Down
12 changes: 6 additions & 6 deletions torchio/data/sampler/sampler.py
Expand Up @@ -12,9 +12,9 @@ class PatchSampler:
r"""Base class for TorchIO samplers.
Args:
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided, :math:`d = h = w = n`.
patch_size: Tuple of integers :math:`(h, w, d)` to generate patches
of size :math:`h \times w \times d`.
If a single number :math:`n` is provided, :math:`h = w = d = n`.
"""
def __init__(self, patch_size: TypePatchSize):
patch_size_array = np.array(to_tuple(patch_size, length=3))
Expand Down Expand Up @@ -43,9 +43,9 @@ class RandomSampler(PatchSampler):
r"""Base class for TorchIO samplers.
Args:
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided, :math:`d = h = w = n`.
patch_size: Tuple of integers :math:`(h, w, d)` to generate patches
of size :math:`h \times w \times d`.
If a single number :math:`n` is provided, :math:`h = w = d = n`.
"""
def __call__(
self,
Expand Down
2 changes: 1 addition & 1 deletion torchio/data/subject.py
Expand Up @@ -148,7 +148,7 @@ def add_transform(

def load(self):
for image in self.get_images(intensity_only=False):
image._load()
image.load()

def crop(self, index_ini, index_fin):
result_dict = {}
Expand Down
Expand Up @@ -92,7 +92,7 @@ def apply_transform(self, sample: Subject) -> dict:
for image_name, image in self.get_images_dict(sample).items():
transformed_tensors = []
is_2d = image.is_2d()
axes = [a for a in self.axes if a != 0] if is_2d else self.axes
axes = [a for a in self.axes if a != 2] if is_2d else self.axes
for channel_idx, tensor in enumerate(image[DATA]):
params = self.get_params(
self.num_ghosts_range,
Expand Down

0 comments on commit f096297

Please sign in to comment.