Skip to content

Commit

Permalink
Merge pull request #550 from effigies/enh/crop
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-brett committed Jun 2, 2018
2 parents 1584b3b + 919b71a commit 41e126a
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 6 deletions.
21 changes: 21 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@ Gerhard (SG) and Eric Larson (EL).

References like "pr/298" refer to github pull request numbers.

Upcoming Release
================

New features
------------
* Image slicing for SpatialImages (pr/550) (CM)

Enhancements
------------
* Simplfiy MGHImage and add footer fields (pr/569) (CM, reviewed by MB)

Bug fixes
---------

Maintenance
-----------

API changes and deprecations
----------------------------


2.2.1 (Wednesday 22 November 2017)
==================================

Expand Down
127 changes: 124 additions & 3 deletions nibabel/spatialimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
from .filebasedimages import ImageFileError # flake8: noqa; for back-compat
from .viewers import OrthoSlicer3D
from .volumeutils import shape_zoom_affine
from .fileslice import canonical_slicers
from .deprecated import deprecate_with_version
from .orientations import apply_orientation, inv_ornt_aff

Expand Down Expand Up @@ -321,9 +322,103 @@ class ImageDataError(Exception):
pass


class SpatialFirstSlicer(object):
''' Slicing interface that returns a new image with an updated affine
Checks that an image's first three axes are spatial
'''
def __init__(self, img):
# Local import to avoid circular import on module load
from .imageclasses import spatial_axes_first
if not spatial_axes_first(img):
raise ValueError("Cannot predict position of spatial axes for "
"Image type " + img.__class__.__name__)
self.img = img

def __getitem__(self, slicer):
try:
slicer = self.check_slicing(slicer)
except ValueError as err:
raise IndexError(*err.args)

dataobj = self.img.dataobj[slicer]
if any(dim == 0 for dim in dataobj.shape):
raise IndexError("Empty slice requested")

affine = self.slice_affine(slicer)
return self.img.__class__(dataobj.copy(), affine, self.img.header)

def check_slicing(self, slicer, return_spatial=False):
''' Canonicalize slicers and check for scalar indices in spatial dims
Parameters
----------
slicer : object
something that can be used to slice an array as in
``arr[sliceobj]``
return_spatial : bool
return only slices along spatial dimensions (x, y, z)
Returns
-------
slicer : object
Validated slicer object that will slice image's `dataobj`
without collapsing spatial dimensions
'''
slicer = canonical_slicers(slicer, self.img.shape)
# We can get away with this because we've checked the image's
# first three axes are spatial.
# More general slicers will need to be smarter, here.
spatial_slices = slicer[:3]
for subslicer in spatial_slices:
if subslicer is None:
raise IndexError("New axis not permitted in spatial dimensions")
elif isinstance(subslicer, int):
raise IndexError("Scalar indices disallowed in spatial dimensions; "
"Use `[x]` or `x:x+1`.")
return spatial_slices if return_spatial else slicer

def slice_affine(self, slicer):
""" Retrieve affine for current image, if sliced by a given index
Applies scaling if down-sampling is applied, and adjusts the intercept
to account for any cropping.
Parameters
----------
slicer : object
something that can be used to slice an array as in
``arr[sliceobj]``
Returns
-------
affine : (4,4) ndarray
Affine with updated scale and intercept
"""
slicer = self.check_slicing(slicer, return_spatial=True)

# Transform:
# sx 0 0 tx
# 0 sy 0 ty
# 0 0 sz tz
# 0 0 0 1
transform = np.eye(4, dtype=int)

for i, subslicer in enumerate(slicer):
if isinstance(subslicer, slice):
if subslicer.step == 0:
raise ValueError("slice step cannot be 0")
transform[i, i] = subslicer.step if subslicer.step is not None else 1
transform[i, 3] = subslicer.start or 0
# If slicer is None, nothing to do

return self.img.affine.dot(transform)


class SpatialImage(DataobjImage):
''' Template class for volumetric (3D/4D) images '''
header_class = SpatialHeader
ImageSlicer = SpatialFirstSlicer

def __init__(self, dataobj, affine, header=None,
extra=None, file_map=None):
Expand Down Expand Up @@ -461,12 +556,38 @@ def from_image(klass, img):
klass.header_class.from_header(img.header),
extra=img.extra.copy())

@property
def slicer(self):
""" Slicer object that returns cropped and subsampled images
The image is resliced in the current orientation; no rotation or
resampling is performed, and no attempt is made to filter the image
to avoid `aliasing`_.
The affine matrix is updated with the new intercept (and scales, if
down-sampling is used), so that all values are found at the same RAS
locations.
Slicing may include non-spatial dimensions.
However, this method does not currently adjust the repetition time in
the image header.
.. _aliasing: https://en.wikipedia.org/wiki/Aliasing
"""
return self.ImageSlicer(self)


def __getitem__(self, idx):
''' No slicing or dictionary interface for images
Use the slicer attribute to perform cropping and subsampling at your
own risk.
'''
raise TypeError("Cannot slice image objects; consider slicing image "
"array data with `img.dataobj[slice]` or "
"`img.get_data()[slice]`")
raise TypeError(
"Cannot slice image objects; consider using `img.slicer[slice]` "
"to generate a sliced image (see documentation for caveats) or "
"slicing image array data with `img.dataobj[slice]` or "
"`img.get_data()[slice]`")

def orthoview(self):
"""Plot the image using OrthoSlicer3D
Expand Down
134 changes: 131 additions & 3 deletions nibabel/tests/test_spatialimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from io import BytesIO
from ..spatialimages import (SpatialHeader, SpatialImage, HeaderDataError,
Header, ImageDataError)
from ..imageclasses import spatial_axes_first

from unittest import TestCase
from nose.tools import (assert_true, assert_false, assert_equal,
Expand Down Expand Up @@ -385,9 +386,10 @@ def test_get_data(self):
img[0, 0, 0]
# Make sure the right message gets raised:
assert_equal(str(exception_manager.exception),
("Cannot slice image objects; consider slicing image "
"array data with `img.dataobj[slice]` or "
"`img.get_data()[slice]`"))
"Cannot slice image objects; consider using "
"`img.slicer[slice]` to generate a sliced image (see "
"documentation for caveats) or slicing image array data "
"with `img.dataobj[slice]` or `img.get_data()[slice]`")
assert_true(in_data is img.dataobj)
out_data = img.get_data()
assert_true(in_data is out_data)
Expand All @@ -411,6 +413,132 @@ def test_get_data(self):
assert_false(rt_img.get_data() is out_data)
assert_array_equal(rt_img.get_data(), in_data)

def test_slicer(self):
img_klass = self.image_class
in_data_template = np.arange(240, dtype=np.int16)
base_affine = np.eye(4)
t_axis = None
for dshape in ((4, 5, 6, 2), # Time series
(8, 5, 6)): # Volume
in_data = in_data_template.copy().reshape(dshape)
img = img_klass(in_data, base_affine.copy())

if not spatial_axes_first(img):
with assert_raises(ValueError):
img.slicer
continue

assert_true(hasattr(img.slicer, '__getitem__'))

# Note spatial zooms are always first 3, even when
spatial_zooms = img.header.get_zooms()[:3]

# Down-sample with [::2, ::2, ::2] along spatial dimensions
sliceobj = [slice(None, None, 2)] * 3 + \
[slice(None)] * (len(dshape) - 3)
downsampled_img = img.slicer[tuple(sliceobj)]
assert_array_equal(downsampled_img.header.get_zooms()[:3],
np.array(spatial_zooms) * 2)

max4d = (hasattr(img.header, '_structarr') and
'dims' in img.header._structarr.dtype.fields and
img.header._structarr['dims'].shape == (4,))
# Check newaxis and single-slice errors
with assert_raises(IndexError):
img.slicer[None]
with assert_raises(IndexError):
img.slicer[0]
# Axes 1 and 2 are always spatial
with assert_raises(IndexError):
img.slicer[:, None]
with assert_raises(IndexError):
img.slicer[:, 0]
with assert_raises(IndexError):
img.slicer[:, :, None]
with assert_raises(IndexError):
img.slicer[:, :, 0]
if len(img.shape) == 4:
if max4d:
with assert_raises(ValueError):
img.slicer[:, :, :, None]
else:
# Reorder non-spatial axes
assert_equal(img.slicer[:, :, :, None].shape,
img.shape[:3] + (1,) + img.shape[3:])
# 4D to 3D using ellipsis or slices
assert_equal(img.slicer[..., 0].shape, img.shape[:-1])
assert_equal(img.slicer[:, :, :, 0].shape, img.shape[:-1])
else:
# 3D Analyze/NIfTI/MGH to 4D
assert_equal(img.slicer[:, :, :, None].shape, img.shape + (1,))
if len(img.shape) == 3:
# Slices exceed dimensions
with assert_raises(IndexError):
img.slicer[:, :, :, :, None]
elif max4d:
with assert_raises(ValueError):
img.slicer[:, :, :, :, None]
else:
assert_equal(img.slicer[:, :, :, :, None].shape,
img.shape + (1,))

# Crop by one voxel in each dimension
sliced_i = img.slicer[1:]
sliced_j = img.slicer[:, 1:]
sliced_k = img.slicer[:, :, 1:]
sliced_ijk = img.slicer[1:, 1:, 1:]

# No scaling change
assert_array_equal(sliced_i.affine[:3, :3], img.affine[:3, :3])
assert_array_equal(sliced_j.affine[:3, :3], img.affine[:3, :3])
assert_array_equal(sliced_k.affine[:3, :3], img.affine[:3, :3])
assert_array_equal(sliced_ijk.affine[:3, :3], img.affine[:3, :3])
# Translation
assert_array_equal(sliced_i.affine[:, 3], [1, 0, 0, 1])
assert_array_equal(sliced_j.affine[:, 3], [0, 1, 0, 1])
assert_array_equal(sliced_k.affine[:, 3], [0, 0, 1, 1])
assert_array_equal(sliced_ijk.affine[:, 3], [1, 1, 1, 1])

# No change to affines with upper-bound slices
assert_array_equal(img.slicer[:1, :1, :1].affine, img.affine)

# Yell about step = 0
with assert_raises(ValueError):
img.slicer[:, ::0]
with assert_raises(ValueError):
img.slicer.slice_affine((slice(None), slice(None, None, 0)))

# Don't permit zero-length slices
with assert_raises(IndexError):
img.slicer[:0]

# No fancy indexing
with assert_raises(IndexError):
img.slicer[[0]]
with assert_raises(IndexError):
img.slicer[[-1]]
with assert_raises(IndexError):
img.slicer[[0], [-1]]

# Check data is consistent with slicing numpy arrays
slice_elems = (None, Ellipsis, 0, 1, -1, [0], [1], [-1],
slice(None), slice(1), slice(-1), slice(1, -1))
for n_elems in range(6):
for _ in range(1 if n_elems == 0 else 10):
sliceobj = tuple(
np.random.choice(slice_elems, n_elems).tolist())
try:
sliced_img = img.slicer[sliceobj]
except (IndexError, ValueError):
# Only checking valid slices
pass
else:
sliced_data = in_data[sliceobj]
assert_array_equal(sliced_data, sliced_img.get_data())
assert_array_equal(sliced_data, sliced_img.dataobj)
assert_array_equal(sliced_data, img.dataobj[sliceobj])
assert_array_equal(sliced_data, img.get_data()[sliceobj])

def test_api_deprecations(self):

class FakeImage(self.image_class):
Expand Down

0 comments on commit 41e126a

Please sign in to comment.