diff --git a/nitransforms/base.py b/nitransforms/base.py index 96f00edb..26c0d475 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -15,7 +15,6 @@ from nibabel import funcs as _nbfuncs from nibabel.nifti1 import intent_codes as INTENT_CODES from nibabel.cifti2 import Cifti2Image -from scipy import ndimage as ndi EQUALITY_TOL = 1e-5 @@ -178,7 +177,10 @@ def __ne__(self, other): class TransformBase: """Abstract image class to represent transforms.""" - __slots__ = ("_reference", "_ndim",) + __slots__ = ( + "_reference", + "_ndim", + ) def __init__(self, reference=None): """Instantiate a transform.""" @@ -222,101 +224,6 @@ def ndim(self): """Access the dimensions of the reference space.""" raise TypeError("TransformBase has no dimensions") - def apply( - self, - spatialimage, - reference=None, - order=3, - mode="constant", - cval=0.0, - prefilter=True, - output_dtype=None, - ): - """ - Apply a transformation to an image, resampling on the reference spatial object. - - Parameters - ---------- - spatialimage : `spatialimage` - The image object containing the data to be resampled in reference - space - reference : spatial object, optional - The image, surface, or combination thereof containing the coordinates - of samples that will be sampled. - order : int, optional - The order of the spline interpolation, default is 3. - The order has to be in the range 0-5. - mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional - Determines how the input image is extended when the resamplings overflows - a border. Default is 'constant'. - cval : float, optional - Constant value for ``mode='constant'``. Default is 0.0. - prefilter: bool, optional - Determines if the image's data array is prefiltered with - a spline filter before interpolation. The default is ``True``, - which will create a temporary *float64* array of filtered values - if *order > 1*. If setting this to ``False``, the output will be - slightly blurred if *order > 1*, unless the input is prefiltered, - i.e. it is the result of calling the spline filter on the original - input. - output_dtype: dtype specifier, optional - The dtype of the returned array or image, if specified. - If ``None``, the default behavior is to use the effective dtype of - the input image. If slope and/or intercept are defined, the effective - dtype is float64, otherwise it is equivalent to the input image's - ``get_data_dtype()`` (on-disk type). - If ``reference`` is defined, then the return value is an image, with - a data array of the effective dtype but with the on-disk dtype set to - the input image's on-disk dtype. - - Returns - ------- - resampled : `spatialimage` or ndarray - The data imaged after resampling to reference space. - - """ - if reference is not None and isinstance(reference, (str, Path)): - reference = _nbload(str(reference)) - - _ref = ( - self.reference if reference is None else SpatialReference.factory(reference) - ) - - if _ref is None: - raise TransformError("Cannot apply transform without reference") - - if isinstance(spatialimage, (str, Path)): - spatialimage = _nbload(str(spatialimage)) - - data = np.asanyarray(spatialimage.dataobj) - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim) - ) - - resampled = ndi.map_coordinates( - data, - targets.T, - output=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) - - if isinstance(_ref, ImageGrid): # If reference is grid, reshape - hdr = None - if _ref.header is not None: - hdr = _ref.header.copy() - hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype()) - moved = spatialimage.__class__( - resampled.reshape(_ref.shape), - _ref.affine, - hdr, - ) - return moved - - return resampled - def map(self, x, inverse=False): r""" Apply :math:`y = f(x)`. @@ -382,4 +289,8 @@ def _as_homogeneous(xyz, dtype="float32", dim=3): def _apply_affine(x, affine, dim): """Get the image array's indexes corresponding to coordinates.""" - return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T + return np.tensordot( + affine, + _as_homogeneous(x, dim=dim).T, + axes=1, + )[:dim, ...] diff --git a/nitransforms/cli.py b/nitransforms/cli.py index 63b8bed4..8f8f5ce0 100644 --- a/nitransforms/cli.py +++ b/nitransforms/cli.py @@ -5,6 +5,7 @@ from .linear import load as linload from .nonlinear import load as nlinload +from .resampling import apply def cli_apply(pargs): @@ -38,7 +39,8 @@ def cli_apply(pargs): # ensure a reference is set xfm.reference = pargs.ref or pargs.moving - moved = xfm.apply( + moved = apply( + xfm, pargs.moving, order=pargs.order, mode=pargs.mode, diff --git a/nitransforms/linear.py b/nitransforms/linear.py index af14f396..71df6a16 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -10,16 +10,12 @@ import warnings import numpy as np from pathlib import Path -from scipy import ndimage as ndi -from nibabel.loadsave import load as _nbload from nibabel.affines import from_matvec -from nibabel.arrayproxy import get_obj_dtype from nitransforms.base import ( ImageGrid, TransformBase, - SpatialReference, _as_homogeneous, EQUALITY_TOL, ) @@ -113,6 +109,10 @@ def __invert__(self): """ return self.__class__(self._inverse) + def __len__(self): + """Enable using len().""" + return 1 if self._matrix.ndim == 2 else len(self._matrix) + def __matmul__(self, b): """ Compose two Affines. @@ -330,10 +330,6 @@ def __getitem__(self, i): """Enable indexed access to the series of matrices.""" return Affine(self.matrix[i, ...], reference=self._reference) - def __len__(self): - """Enable using len().""" - return len(self._matrix) - def map(self, x, inverse=False): r""" Apply :math:`y = f(x)`. @@ -402,119 +398,6 @@ def to_filename(self, filename, fmt="X5", moving=None): ).to_filename(filename) return filename - def apply( - self, - spatialimage, - reference=None, - order=3, - mode="constant", - cval=0.0, - prefilter=True, - output_dtype=None, - ): - """ - Apply a transformation to an image, resampling on the reference spatial object. - - Parameters - ---------- - spatialimage : `spatialimage` - The image object containing the data to be resampled in reference - space - reference : spatial object, optional - The image, surface, or combination thereof containing the coordinates - of samples that will be sampled. - order : int, optional - The order of the spline interpolation, default is 3. - The order has to be in the range 0-5. - mode : {"constant", "reflect", "nearest", "mirror", "wrap"}, optional - Determines how the input image is extended when the resamplings overflows - a border. Default is "constant". - cval : float, optional - Constant value for ``mode="constant"``. Default is 0.0. - prefilter: bool, optional - Determines if the image's data array is prefiltered with - a spline filter before interpolation. The default is ``True``, - which will create a temporary *float64* array of filtered values - if *order > 1*. If setting this to ``False``, the output will be - slightly blurred if *order > 1*, unless the input is prefiltered, - i.e. it is the result of calling the spline filter on the original - input. - - Returns - ------- - resampled : `spatialimage` or ndarray - The data imaged after resampling to reference space. - - """ - - if reference is not None and isinstance(reference, (str, Path)): - reference = _nbload(str(reference)) - - _ref = ( - self.reference if reference is None else SpatialReference.factory(reference) - ) - - if isinstance(spatialimage, (str, Path)): - spatialimage = _nbload(str(spatialimage)) - - # Avoid opening the data array just yet - input_dtype = get_obj_dtype(spatialimage.dataobj) - output_dtype = output_dtype or input_dtype - - # Prepare physical coordinates of input (grid, points) - xcoords = _ref.ndcoords.astype("f4").T - - # Invert target's (moving) affine once - ras2vox = ~Affine(spatialimage.affine) - - if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]): - raise ValueError( - "Attempting to apply %d transforms on a file with " - "%d timepoints" % (len(self), spatialimage.shape[-1]) - ) - - # Order F ensures individual volumes are contiguous in memory - # Also matches NIfTI, making final save more efficient - resampled = np.zeros( - (xcoords.shape[0], len(self)), dtype=output_dtype, order="F" - ) - - dataobj = ( - np.asanyarray(spatialimage.dataobj, dtype=input_dtype) - if spatialimage.ndim in (2, 3) - else None - ) - - for t, xfm_t in enumerate(self): - # Map the input coordinates on to timepoint t of the target (moving) - ycoords = xfm_t.map(xcoords)[..., : _ref.ndim] - - # Calculate corresponding voxel coordinates - yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] - - # Interpolate - resampled[..., t] = ndi.map_coordinates( - ( - dataobj - if dataobj is not None - else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) - ), - yvoxels.T, - output=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) - - if isinstance(_ref, ImageGrid): # If reference is grid, reshape - newdata = resampled.reshape(_ref.shape + (len(self),)) - moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header) - moved.header.set_data_dtype(output_dtype) - return moved - - return resampled - def load(filename, fmt=None, reference=None, moving=None): """ diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 69c19d35..f4b95142 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -18,7 +18,6 @@ TransformBase, TransformError, ImageGrid, - SpatialReference, _as_homogeneous, ) from scipy.ndimage import map_coordinates @@ -71,21 +70,18 @@ def __init__(self, field=None, is_deltas=True, reference=None): is_deltas = True try: - self.reference = ImageGrid( - reference if reference is not None else field - ) + self.reference = ImageGrid(reference if reference is not None else field) except AttributeError: raise TransformError( "Field must be a spatial image if reference is not provided" - if reference is None else - "Reference is not a spatial image" + if reference is None + else "Reference is not a spatial image" ) - ndim = self._field.ndim - 1 - if self._field.shape[-1] != ndim: + if self._field.shape[-1] != self.ndim: raise TransformError( "The number of components of the field (%d) does not match " - "the number of dimensions (%d)" % (self._field.shape[-1], ndim) + "the number of dimensions (%d)" % (self._field.shape[-1], self.ndim) ) if is_deltas: @@ -167,19 +163,22 @@ def map(self, x, inverse=False): indexes = np.round(ijk).astype("int") if np.all(np.abs(ijk - indexes) < 1e-3): - indexes = tuple(tuple(i) for i in indexes.T) + indexes = tuple(tuple(i) for i in indexes) return self._field[indexes] - new_map = np.vstack(tuple( - map_coordinates( - self._field[..., i], - ijk.T, - order=3, - mode="constant", - cval=np.nan, - prefilter=True, - ) for i in range(self.reference.ndim) - )).T + new_map = np.vstack( + tuple( + map_coordinates( + self._field[..., i], + ijk.T, + order=3, + mode="constant", + cval=np.nan, + prefilter=True, + ) + for i in range(self.reference.ndim) + ) + ).T # Set NaN values back to the original coordinates value = no displacement new_map[np.isnan(new_map)] = x[np.isnan(new_map)] @@ -205,9 +204,9 @@ def __matmul__(self, b): True """ - retval = b.map( - self._field.reshape((-1, self._field.shape[-1])) - ).reshape(self._field.shape) + retval = b.map(self._field.reshape((-1, self._field.shape[-1]))).reshape( + self._field.shape + ) return DenseFieldTransform(retval, is_deltas=False, reference=self.reference) def __eq__(self, other): @@ -246,7 +245,7 @@ def from_filename(cls, filename, fmt="X5"): class BSplineFieldTransform(TransformBase): """Represent a nonlinear transform parameterized by BSpline basis.""" - __slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving'] + __slots__ = ["_coeffs", "_knots", "_weights", "_order", "_moving"] def __init__(self, coefficients, reference=None, order=3): """Create a smooth deformation field using B-Spline basis.""" @@ -261,10 +260,11 @@ def __init__(self, coefficients, reference=None, order=3): if reference is not None: self.reference = reference - if coefficients.shape[-1] != self.ndim: + if coefficients.shape[-1] != self.reference.ndim: raise TransformError( - 'Number of components of the coefficients does ' - 'not match the number of dimensions') + "Number of components of the coefficients does " + "not match the number of dimensions" + ) @property def ndim(self): @@ -274,20 +274,17 @@ def ndim(self): def to_field(self, reference=None, dtype="float32"): """Generate a displacements deformation field from this B-Spline field.""" _ref = ( - self.reference if reference is None else - ImageGrid(_ensure_image(reference)) + self.reference if reference is None else ImageGrid(_ensure_image(reference)) ) if _ref is None: raise TransformError("A reference must be defined") - ndim = self._coeffs.shape[-1] - if self._weights is None: self._weights = grid_bspline_weights(_ref, self._knots) - field = np.zeros((_ref.npoints, ndim)) + field = np.zeros((_ref.npoints, self.ndim)) - for d in range(ndim): + for d in range(self.ndim): # 1 x Nvox : (1 x K) @ (K x Nvox) field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights @@ -295,47 +292,6 @@ def to_field(self, reference=None, dtype="float32"): field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref ) - def apply( - self, - spatialimage, - reference=None, - order=3, - mode="constant", - cval=0.0, - prefilter=True, - output_dtype=None, - ): - """Apply a B-Spline transform on input data.""" - - _ref = ( - self.reference if reference is None else - SpatialReference.factory(_ensure_image(reference)) - ) - spatialimage = _ensure_image(spatialimage) - - # If locations to be interpolated are not on a grid, run map() - if not isinstance(_ref, ImageGrid): - return super().apply( - spatialimage, - reference=_ref, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - output_dtype=output_dtype, - ) - - # If locations to be interpolated are on a grid, generate a displacements field - return self.to_field(reference=reference).apply( - spatialimage, - reference=reference, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - output_dtype=output_dtype, - ) - def map(self, x, inverse=False): r""" Apply the transformation to a list of physical coordinate points. @@ -386,9 +342,9 @@ def _map_xyz(x, reference, knots, coeffs): # Probably this will change if the order of the B-Spline is different w_start, w_end = np.ceil(ijk - 2).astype(int), np.floor(ijk + 2).astype(int) # Generate a grid of indexes corresponding to the window - nonzero_knots = tuple([ - np.arange(start, end + 1) for start, end in zip(w_start, w_end) - ]) + nonzero_knots = tuple( + [np.arange(start, end + 1) for start, end in zip(w_start, w_end)] + ) nonzero_knots = tuple(np.meshgrid(*nonzero_knots, indexing="ij")) window = np.array(nonzero_knots).reshape((ndim, -1)) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py new file mode 100644 index 00000000..9de0d2d6 --- /dev/null +++ b/nitransforms/resampling.py @@ -0,0 +1,139 @@ +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +"""Resampling utilities.""" +from pathlib import Path +import numpy as np +from nibabel.loadsave import load as _nbload +from scipy import ndimage as ndi + +from nitransforms.base import ( + ImageGrid, + TransformError, + SpatialReference, + _as_homogeneous, +) + + +def apply( + transform, + spatialimage, + reference=None, + order=3, + mode="constant", + cval=0.0, + prefilter=True, + output_dtype=None, +): + """ + Apply a transformation to an image, resampling on the reference spatial object. + + Parameters + ---------- + spatialimage : `spatialimage` + The image object containing the data to be resampled in reference + space + reference : spatial object, optional + The image, surface, or combination thereof containing the coordinates + of samples that will be sampled. + order : int, optional + The order of the spline interpolation, default is 3. + The order has to be in the range 0-5. + mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional + Determines how the input image is extended when the resamplings overflows + a border. Default is 'constant'. + cval : float, optional + Constant value for ``mode='constant'``. Default is 0.0. + prefilter: bool, optional + Determines if the image's data array is prefiltered with + a spline filter before interpolation. The default is ``True``, + which will create a temporary *float64* array of filtered values + if *order > 1*. If setting this to ``False``, the output will be + slightly blurred if *order > 1*, unless the input is prefiltered, + i.e. it is the result of calling the spline filter on the original + input. + output_dtype: dtype specifier, optional + The dtype of the returned array or image, if specified. + If ``None``, the default behavior is to use the effective dtype of + the input image. If slope and/or intercept are defined, the effective + dtype is float64, otherwise it is equivalent to the input image's + ``get_data_dtype()`` (on-disk type). + If ``reference`` is defined, then the return value is an image, with + a data array of the effective dtype but with the on-disk dtype set to + the input image's on-disk dtype. + + Returns + ------- + resampled : `spatialimage` or ndarray + The data imaged after resampling to reference space. + + """ + if reference is not None and isinstance(reference, (str, Path)): + reference = _nbload(str(reference)) + + _ref = ( + transform.reference + if reference is None + else SpatialReference.factory(reference) + ) + + if _ref is None: + raise TransformError("Cannot apply transform without reference") + + if isinstance(spatialimage, (str, Path)): + spatialimage = _nbload(str(spatialimage)) + + data = np.asanyarray(spatialimage.dataobj) + + if data.ndim == 4 and data.shape[-1] != len(transform): + raise ValueError( + "The fourth dimension of the data does not match the tranform's shape." + ) + + if data.ndim < transform.ndim: + data = data[..., np.newaxis] + + # For model-based nonlinear transforms, generate the corresponding dense field + if hasattr(transform, "to_field") and callable(transform.to_field): + targets = ImageGrid(spatialimage).index( + _as_homogeneous( + transform.to_field(reference=reference).map(_ref.ndcoords.T), + dim=_ref.ndim, + ) + ) + else: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + ) + + if transform.ndim == 4: + targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + + resampled = ndi.map_coordinates( + data, + targets, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + + if isinstance(_ref, ImageGrid): # If reference is grid, reshape + hdr = None + if _ref.header is not None: + hdr = _ref.header.copy() + hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype()) + moved = spatialimage.__class__( + resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1,)), + _ref.affine, + hdr, + ) + return moved + + return resampled diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index 07a7e4ec..d32ce7f9 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -4,8 +4,14 @@ import pytest import h5py -from ..base import SpatialReference, SampledSpatialData, ImageGrid, TransformBase +from ..base import ( + SpatialReference, + SampledSpatialData, + ImageGrid, + TransformBase, +) from .. import linear as nitl +from ..resampling import apply def test_SpatialReference(testdata_path): @@ -42,10 +48,10 @@ def test_ImageGrid(get_testdata, image_orientation): ijk = [[10, 10, 10], [40, 4, 20], [0, 0, 0], [s - 1 for s in im.shape[:3]]] xyz = [img._affine.dot(idx + [1])[:-1] for idx in ijk] - assert np.allclose(img.ras(ijk[0]), xyz[0]) + assert np.allclose(np.squeeze(img.ras(ijk[0])), xyz[0]) assert np.allclose(np.round(img.index(xyz[0])), ijk[0]) - assert np.allclose(img.ras(ijk), xyz) - assert np.allclose(np.round(img.index(xyz)), ijk) + assert np.allclose(img.ras(ijk).T, xyz) + assert np.allclose(np.round(img.index(xyz)).T, ijk) # nd index / coords idxs = img.ndindex @@ -91,31 +97,23 @@ def _to_hdf5(klass, x5_root): img = nb.load(fname) imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype()) - # Test identity transform + # Test identity transform - setting reference xfm = TransformBase() - xfm.reference = fname with pytest.raises(TypeError): _ = xfm.ndim - moved = xfm.apply(fname, order=0) - assert np.all( - imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()) - ) - # Test identity transform - setting reference - xfm = TransformBase() + # Test to_filename + xfm.to_filename("data.x5") + + # Test identity transform + xfm = nitl.Affine() xfm.reference = fname - with pytest.raises(TypeError): - _ = xfm.ndim - moved = xfm.apply(str(fname), reference=fname, order=0) - assert np.all( - imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()) - ) + moved = apply(xfm, fname, order=0) + assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())) # Test ndim returned by affine assert nitl.Affine().ndim == 3 - assert nitl.LinearTransformsMapping( - [nitl.Affine(), nitl.Affine()] - ).ndim == 4 + assert nitl.LinearTransformsMapping([nitl.Affine(), nitl.Affine()]).ndim == 4 # Test applying to Gifti gii = nb.gifti.GiftiImage( @@ -126,11 +124,11 @@ def _to_hdf5(klass, x5_root): ) ] ) - giimoved = xfm.apply(fname, reference=gii, order=0) + giimoved = apply(xfm, fname, reference=gii, order=0) assert np.allclose(giimoved.reshape(xfm.reference.shape), moved.get_fdata()) # Test to_filename - xfm.to_filename("data.x5") + xfm.to_filename("data.xfm", fmt="itk") def test_SampledSpatialData(testdata_path): diff --git a/nitransforms/tests/test_io.py b/nitransforms/tests/test_io.py index bcee9198..0cc79d15 100644 --- a/nitransforms/tests/test_io.py +++ b/nitransforms/tests/test_io.py @@ -28,6 +28,8 @@ ) from nitransforms.io.base import LinearParameters, TransformIOError, TransformFileError from nitransforms.conftest import _datadir, _testdir +from nitransforms.resampling import apply + LPS = np.diag([-1, -1, 1, 1]) ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS)) @@ -497,10 +499,13 @@ def test_afni_oblique(tmpdir, parameters, swapaxes, testdata_path, dir_x, dir_y, assert np.allclose(card_aff, nb.load("deob_3drefit.nii.gz").affine) # Check that nitransforms can emulate 3drefit -deoblique - nt3drefit = Affine( - afni._cardinal_rotation(img.affine, False), - reference="deob_3drefit.nii.gz", - ).apply("orig.nii.gz") + nt3drefit = apply( + Affine( + afni._cardinal_rotation(img.affine, False), + reference="deob_3drefit.nii.gz", + ), + "orig.nii.gz", + ) diff = ( np.asanyarray(img.dataobj, dtype="uint8") @@ -509,10 +514,13 @@ def test_afni_oblique(tmpdir, parameters, swapaxes, testdata_path, dir_x, dir_y, assert np.sqrt((diff[10:-10, 10:-10, 10:-10] ** 2).mean()) < 0.1 # Check that nitransforms can revert 3drefit -deoblique - nt_undo3drefit = Affine( - afni._cardinal_rotation(img.affine, True), - reference="orig.nii.gz", - ).apply("deob_3drefit.nii.gz") + nt_undo3drefit = apply( + Affine( + afni._cardinal_rotation(img.affine, True), + reference="orig.nii.gz", + ), + "deob_3drefit.nii.gz", + ) diff = ( np.asanyarray(img.dataobj, dtype="uint8") @@ -531,16 +539,21 @@ def test_afni_oblique(tmpdir, parameters, swapaxes, testdata_path, dir_x, dir_y, assert np.allclose(deobaff, deobnii.affine) # Check resampling in deobliqued grid - ntdeobnii = Affine(np.eye(4), reference=deobnii.__class__( - np.zeros(deobshape, dtype="uint8"), - deobaff, - deobnii.header - )).apply(img, order=0) + ntdeobnii = apply( + Affine(np.eye(4), reference=deobnii.__class__( + np.zeros(deobshape, dtype="uint8"), + deobaff, + deobnii.header + )), + img, + order=0, + ) # Generate an internal box to exclude border effects box = np.zeros(img.shape, dtype="uint8") box[10:-10, 10:-10, 10:-10] = 1 - ntdeobmask = Affine(np.eye(4), reference=ntdeobnii).apply( + ntdeobmask = apply( + Affine(np.eye(4), reference=ntdeobnii), nb.Nifti1Image(box, img.affine, img.header), order=0, ) diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 2957f59c..50cc5371 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -13,6 +13,7 @@ from nibabel.affines import from_matvec from nitransforms import linear as nitl from nitransforms import io +from nitransforms.resampling import apply from .utils import assert_affines_by_filename RMSE_TOL = 0.1 @@ -285,7 +286,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient assert exit_code == 0 sw_moved_mask = nb.load("resampled_brainmask.nii.gz") - nt_moved_mask = xfm.apply(msk, order=0) + nt_moved_mask = apply(xfm, msk, order=0) nt_moved_mask.set_data_dtype(msk.get_data_dtype()) nt_moved_mask.to_filename("ntmask.nii.gz") diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) @@ -305,7 +306,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient sw_moved = nb.load("resampled.nii.gz") sw_moved.set_data_dtype(img.get_data_dtype()) - nt_moved = xfm.apply(img, order=0) + nt_moved = apply(xfm, img, order=0) diff = ( np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) @@ -314,7 +315,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient # A certain tolerance is necessary because of resampling at borders assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL - nt_moved = xfm.apply("img.nii.gz", order=0) + nt_moved = apply(xfm, "img.nii.gz", order=0) diff = ( np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) @@ -343,8 +344,8 @@ def test_LinearTransformsMapping_apply(tmp_path, data_path, testdata_path): assert isinstance(hmc, nitl.LinearTransformsMapping) # Test-case: realign functional data on to sbref - nii = hmc.apply( - testdata_path / "func.nii.gz", order=1, reference=testdata_path / "sbref.nii.gz" + nii = apply( + hmc, testdata_path / "func.nii.gz", order=1, reference=testdata_path / "sbref.nii.gz" ) assert nii.dataobj.shape[-1] == len(hmc) @@ -352,13 +353,17 @@ def test_LinearTransformsMapping_apply(tmp_path, data_path, testdata_path): hmcinv = nitl.LinearTransformsMapping( np.linalg.inv(hmc.matrix), reference=testdata_path / "func.nii.gz" ) - nii = hmcinv.apply(testdata_path / "fmap.nii.gz", order=1) + + nii = apply( + hmcinv, testdata_path / "fmap.nii.gz", order=1 + ) assert nii.dataobj.shape[-1] == len(hmc) # Ensure a ValueError is issued when trying to do weird stuff hmc = nitl.LinearTransformsMapping(hmc.matrix[:1, ...]) with pytest.raises(ValueError): - hmc.apply( + apply( + hmc, testdata_path / "func.nii.gz", order=1, reference=testdata_path / "sbref.nii.gz", diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 93d3fd4c..cfaa12c2 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -8,6 +8,7 @@ import numpy as np import nibabel as nb +from nitransforms.resampling import apply from nitransforms.base import TransformError from nitransforms.io.base import TransformFileError from nitransforms.nonlinear import ( @@ -28,7 +29,7 @@ 3dNwarpApply -nwarp {transform} -source {moving} \ -master {reference} -interp NN -prefix {output} {extra}\ """.format, - 'fsl': """\ + "fsl": """\ applywarp -i {moving} -r {reference} -o {output} {extra}\ -w {transform} --interp=nn""".format, } @@ -38,7 +39,9 @@ def test_itk_disp_load(size): """Checks field sizes.""" with pytest.raises(TransformFileError): - ITKDisplacementsField.from_image(nb.Nifti1Image(np.zeros(size), np.eye(4), None)) + ITKDisplacementsField.from_image( + nb.Nifti1Image(np.zeros(size), np.eye(4), None) + ) @pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3), (20, 20, 20, 1, 4)]) @@ -96,15 +99,17 @@ def test_bsplines_references(testdata_path): ).to_field() with pytest.raises(TransformError): - BSplineFieldTransform( - testdata_path / "someones_bspline_coefficients.nii.gz" - ).apply(testdata_path / "someones_anatomy.nii.gz") + apply( + BSplineFieldTransform( + testdata_path / "someones_bspline_coefficients.nii.gz" + ), + testdata_path / "someones_anatomy.nii.gz", + ) - BSplineFieldTransform( - testdata_path / "someones_bspline_coefficients.nii.gz" - ).apply( + apply( + BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"), testdata_path / "someones_anatomy.nii.gz", - reference=testdata_path / "someones_anatomy.nii.gz" + reference=testdata_path / "someones_anatomy.nii.gz", ) @@ -168,7 +173,7 @@ def test_displacements_field1( nt_moved_mask.set_data_dtype(msk.get_data_dtype()) diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj) - assert np.sqrt((diff ** 2).mean()) < RMSE_TOL + assert np.sqrt((diff**2).mean()) < RMSE_TOL brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool) # Then apply the transform and cross-check with software @@ -177,7 +182,7 @@ def test_displacements_field1( reference=tmp_path / "reference.nii.gz", moving=tmp_path / "reference.nii.gz", output=tmp_path / "resampled.nii.gz", - extra="--output-data-type uchar" if sw_tool == "itk" else "" + extra="--output-data-type uchar" if sw_tool == "itk" else "", ) exit_code = check_call([cmd], shell=True) @@ -188,10 +193,9 @@ def test_displacements_field1( nt_moved.set_data_dtype(nii.get_data_dtype()) nt_moved.to_filename("nt_resampled.nii.gz") sw_moved.set_data_dtype(nt_moved.get_data_dtype()) - diff = ( - np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) - - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) - ) + diff = np.asanyarray( + sw_moved.dataobj, dtype=sw_moved.get_data_dtype() + ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) # A certain tolerance is necessary because of resampling at borders assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL @@ -228,12 +232,11 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool): nt_moved = xfm.apply(img_fname, order=0) nt_moved.to_filename("nt_resampled.nii.gz") sw_moved.set_data_dtype(nt_moved.get_data_dtype()) - diff = ( - np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype()) - - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) - ) + diff = np.asanyarray( + sw_moved.dataobj, dtype=sw_moved.get_data_dtype() + ) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype()) # A certain tolerance is necessary because of resampling at borders - assert np.sqrt((diff ** 2).mean()) < RMSE_TOL + assert np.sqrt((diff**2).mean()) < RMSE_TOL def test_bspline(tmp_path, testdata_path): @@ -247,12 +250,16 @@ def test_bspline(tmp_path, testdata_path): bsplxfm = BSplineFieldTransform(bs_name, reference=img_name) dispxfm = DenseFieldTransform(disp_name) - out_disp = dispxfm.apply(img_name) - out_bspl = bsplxfm.apply(img_name) + out_disp = apply(dispxfm, img_name) + out_bspl = apply(bsplxfm, img_name) out_disp.to_filename("resampled_field.nii.gz") out_bspl.to_filename("resampled_bsplines.nii.gz") - assert np.sqrt( - (out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32")) ** 2 - ).mean() < 0.2 + assert ( + np.sqrt( + (out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32")) + ** 2 + ).mean() + < 0.2 + )