Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Outsource apply() from transform objects #195

Merged
merged 16 commits into from
May 17, 2024
107 changes: 9 additions & 98 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from nibabel import funcs as _nbfuncs
from nibabel.nifti1 import intent_codes as INTENT_CODES
from nibabel.cifti2 import Cifti2Image
from scipy import ndimage as ndi

EQUALITY_TOL = 1e-5

Expand Down Expand Up @@ -178,7 +177,10 @@ def __ne__(self, other):
class TransformBase:
"""Abstract image class to represent transforms."""

__slots__ = ("_reference", "_ndim",)
__slots__ = (
"_reference",
"_ndim",
)

def __init__(self, reference=None):
"""Instantiate a transform."""
Expand Down Expand Up @@ -222,101 +224,6 @@ def ndim(self):
"""Access the dimensions of the reference space."""
raise TypeError("TransformBase has no dimensions")

def apply(
self,
spatialimage,
reference=None,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
output_dtype=None,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.

Parameters
----------
spatialimage : `spatialimage`
The image object containing the data to be resampled in reference
space
reference : spatial object, optional
The image, surface, or combination thereof containing the coordinates
of samples that will be sampled.
order : int, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional
Determines how the input image is extended when the resamplings overflows
a border. Default is 'constant'.
cval : float, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: bool, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
output_dtype: dtype specifier, optional
The dtype of the returned array or image, if specified.
If ``None``, the default behavior is to use the effective dtype of
the input image. If slope and/or intercept are defined, the effective
dtype is float64, otherwise it is equivalent to the input image's
``get_data_dtype()`` (on-disk type).
If ``reference`` is defined, then the return value is an image, with
a data array of the effective dtype but with the on-disk dtype set to
the input image's on-disk dtype.

Returns
-------
resampled : `spatialimage` or ndarray
The data imaged after resampling to reference space.

"""
if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

_ref = (
self.reference if reference is None else SpatialReference.factory(reference)
)

if _ref is None:
raise TransformError("Cannot apply transform without reference")

if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(spatialimage.dataobj)
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
)

resampled = ndi.map_coordinates(
data,
targets.T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
hdr = None
if _ref.header is not None:
hdr = _ref.header.copy()
hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype())
moved = spatialimage.__class__(
resampled.reshape(_ref.shape),
_ref.affine,
hdr,
)
return moved

return resampled

def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand Down Expand Up @@ -382,4 +289,8 @@ def _as_homogeneous(xyz, dtype="float32", dim=3):

def _apply_affine(x, affine, dim):
"""Get the image array's indexes corresponding to coordinates."""
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T
return np.tensordot(
affine,
_as_homogeneous(x, dim=dim).T,
axes=1,
)[:dim, ...]
4 changes: 3 additions & 1 deletion nitransforms/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .linear import load as linload
from .nonlinear import load as nlinload
from .resampling import apply


def cli_apply(pargs):
Expand Down Expand Up @@ -38,7 +39,8 @@ def cli_apply(pargs):
# ensure a reference is set
xfm.reference = pargs.ref or pargs.moving

moved = xfm.apply(
moved = apply(
xfm,
pargs.moving,
order=pargs.order,
mode=pargs.mode,
Expand Down
125 changes: 4 additions & 121 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
import warnings
import numpy as np
from pathlib import Path
from scipy import ndimage as ndi

from nibabel.loadsave import load as _nbload
from nibabel.affines import from_matvec
from nibabel.arrayproxy import get_obj_dtype

from nitransforms.base import (
ImageGrid,
TransformBase,
SpatialReference,
_as_homogeneous,
EQUALITY_TOL,
)
Expand Down Expand Up @@ -113,6 +109,10 @@ def __invert__(self):
"""
return self.__class__(self._inverse)

def __len__(self):
"""Enable using len()."""
return 1 if self._matrix.ndim == 2 else len(self._matrix)

def __matmul__(self, b):
"""
Compose two Affines.
Expand Down Expand Up @@ -330,10 +330,6 @@ def __getitem__(self, i):
"""Enable indexed access to the series of matrices."""
return Affine(self.matrix[i, ...], reference=self._reference)

def __len__(self):
"""Enable using len()."""
return len(self._matrix)

def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand Down Expand Up @@ -402,119 +398,6 @@ def to_filename(self, filename, fmt="X5", moving=None):
).to_filename(filename)
return filename

def apply(
self,
spatialimage,
reference=None,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
output_dtype=None,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.

Parameters
----------
spatialimage : `spatialimage`
The image object containing the data to be resampled in reference
space
reference : spatial object, optional
The image, surface, or combination thereof containing the coordinates
of samples that will be sampled.
order : int, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
mode : {"constant", "reflect", "nearest", "mirror", "wrap"}, optional
Determines how the input image is extended when the resamplings overflows
a border. Default is "constant".
cval : float, optional
Constant value for ``mode="constant"``. Default is 0.0.
prefilter: bool, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.

Returns
-------
resampled : `spatialimage` or ndarray
The data imaged after resampling to reference space.

"""

if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

_ref = (
self.reference if reference is None else SpatialReference.factory(reference)
)

if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

# Prepare physical coordinates of input (grid, points)
xcoords = _ref.ndcoords.astype("f4").T

# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)

if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), spatialimage.shape[-1])
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.shape[0], len(self)), dtype=output_dtype, order="F"
)

dataobj = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
if spatialimage.ndim in (2, 3)
else None
)

for t, xfm_t in enumerate(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been holding on this one - the new implementation is 4D, while the old implementation was 3D+t (meaning that you could not resample off-gried on the fourth dimension).

We should make sure we meet today @effigies, @jmarabotto and discuss this.

cc/ @sgiavasis who may be interested in learning more.

# Map the input coordinates on to timepoint t of the target (moving)
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
dataobj
if dataobj is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
newdata = resampled.reshape(_ref.shape + (len(self),))
moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header)
moved.header.set_data_dtype(output_dtype)
return moved

return resampled


def load(filename, fmt=None, reference=None, moving=None):
"""
Expand Down