Skip to content

Commit

Permalink
Merge 1cd8bba into 3c06229
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-brett committed Mar 2, 2021
2 parents 3c06229 + 1cd8bba commit e522ad7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 38 deletions.
54 changes: 19 additions & 35 deletions nipy/algorithms/interpolation.py
Expand Up @@ -4,25 +4,25 @@
"""
from __future__ import absolute_import

import os
import tempfile

from distutils.version import LooseVersion as LV

import numpy as np

import scipy
from scipy import ndimage

from ..fixes.scipy.ndimage import map_coordinates
from ..utils import seq_prod


# Earlier versions of Scipy don't have mode for spline_filter
SPLINE_FILTER_HAS_MODE = True
# Ensure bare call does not cause error.
ndimage.spline_filter(np.zeros((2,2)))
try: # Does adding mode cause error?
ndimage.spline_filter(np.zeros((2,2)), mode='constant')
except TypeError:
SPLINE_FILTER_HAS_MODE = False
SCIPY_VERSION = LV(scipy.__version__)
SPLINE_FILTER_HAS_MODE = SCIPY_VERSION >= LV('1.2')
# Fixes in interpolation in scipy >= 1.6 force pre-padding
# in knot calculations.
SPLINE_FILTER_NEEDS_PAD = SCIPY_VERSION >= LV('1.6')


class ImageInterpolator(object):
Expand All @@ -33,7 +33,7 @@ class ImageInterpolator(object):

# Padding for prefilter calculation in 'nearest' and 'grid-constant' mode.
# See: https://github.com/scipy/scipy/issues/13600
n_prepad_if_needed = 12
n_prepad_if_needed = 12 if SPLINE_FILTER_NEEDS_PAD else 0

def __init__(self, image, order=3, mode='constant', cval=0.0):
"""
Expand Down Expand Up @@ -75,40 +75,24 @@ def order(self):
return self._order

def _buildknots(self):
data = np.nan_to_num(self.image.get_data()).astype(np.float64)
if self.order > 1:
in_data = np.nan_to_num(self.image.get_data())
if self.mode in ('nearest', 'grid-constant'):
# See: https://github.com/scipy/scipy/issues/13600
self._n_prepad = self.n_prepad_if_needed
in_data = np.pad(in_data, self._n_prepad, mode='edge')
if self._n_prepad != 0:
data = np.pad(data, self._n_prepad, mode='edge')
kwargs = dict(order=self.order)
if SPLINE_FILTER_HAS_MODE:
kwargs['mode'] = self.mode
data = ndimage.spline_filter(in_data, **kwargs)
else:
data = np.nan_to_num(self.image.get_data())
if self._datafile is None:
_, fname = tempfile.mkstemp()
self._datafile = open(fname, mode='wb')
else:
self._datafile = open(self._datafile.name, 'wb')
data = np.nan_to_num(data.astype(np.float64))
data = ndimage.spline_filter(data, **kwargs)
self._datafile = tempfile.TemporaryFile()
data.tofile(self._datafile)
datashape = data.shape
dtype = data.dtype
self._data = np.memmap(self._datafile,
dtype=data.dtype,
mode='r+',
shape=data.shape)
del(data)
self._datafile.close()
self._datafile = open(self._datafile.name)
self._data = np.memmap(self._datafile.name, dtype=dtype,
mode='r+', shape=datashape)

def __del__(self):
if self._datafile:
self._datafile.close()
try:
os.remove(self._datafile.name)
except:
pass

def evaluate(self, points):
""" Resample image at points in world space
Expand All @@ -133,7 +117,7 @@ def evaluate(self, points):
order=self.order,
mode=self.mode,
cval=self.cval,
prefilter=False)
prefilter=self.order < 2)
# ndimage.map_coordinates returns a flat array,
# it needs to be reshaped to the original shape
V.shape = output_shape
Expand Down
43 changes: 40 additions & 3 deletions nipy/algorithms/tests/test_interpolator.py
@@ -1,8 +1,12 @@
""" Testing interpolation module
"""

from itertools import product

import numpy as np

from scipy.ndimage import map_coordinates

from nipy.core.api import Image, vox2mni

from ..interpolation import ImageInterpolator
Expand Down Expand Up @@ -38,17 +42,50 @@ def test_interp_obj():


def test_interpolator():
arr = np.arange(24).reshape((2, 3, 4))
shape = (2, 3, 4)
arr = np.arange(24).reshape(shape)
coordmap = vox2mni(np.eye(4))
img = Image(arr, coordmap)
isx = np.indices(arr.shape)
ixs = np.indices(arr.shape).astype(float)
for order in range(5):
interp = ImageInterpolator(img, mode='nearest', order=order)
# Interpolate at existing points.
assert_almost_equal(interp.evaluate(isx), arr)
assert_almost_equal(interp.evaluate(ixs), arr)
# Interpolate at half voxel shift
if order == 2:
continue
ixs_x_shift = ixs.copy()
# Interpolate inside and outside at knots
ixs_x_shift[0] += 1
res = interp.evaluate(ixs_x_shift)
assert_almost_equal(res, np.tile(arr[1], (2, 1, 1)))
ixs_x_shift[0] -= 2
res = interp.evaluate(ixs_x_shift)
assert_almost_equal(res, np.tile(arr[0], (2, 1, 1)))
# Interpolate at mid-points inside and outside
ixs_x_shift[0] += 0.5
res = interp.evaluate(ixs_x_shift)
# Check inside.
mid_arr = np.mean(arr, axis=0) if order > 0 else arr[1]
assert_almost_equal(res[1], mid_arr)
# Interpolate off top right corner with different modes
assert_almost_equal(interp.evaluate([0, 0, 4]), arr[0, 0, -1])
interp = ImageInterpolator(img, mode='constant', order=order, cval=0)
assert_array_equal(interp.evaluate([0, 0, 4]), 0)
interp = ImageInterpolator(img, mode='constant', order=order, cval=1)
assert_array_equal(interp.evaluate([0, 0, 4]), 1)
# Check against direct ndimage interpolation
# Need floating point input array to replicate
# our floating point backing store.
farr = arr.astype(float)
for offset, axis, mode in product(np.linspace(-2, 2, 15),
range(3),
('nearest', 'constant')):
interp = ImageInterpolator(img, mode=mode, order=order)
coords = ixs.copy()
slicer = tuple(None if i == axis else 0 for i in range(3))
coords[slicer] = coords[slicer] + offset
actual = interp.evaluate(coords)
expected = map_coordinates(farr, coords, mode=mode, order=order)
assert_almost_equal(actual, expected)
del interp

0 comments on commit e522ad7

Please sign in to comment.