Skip to content

Commit

Permalink
Merge pull request #7864 from andfoy/add_neighbors_rbf
Browse files Browse the repository at this point in the history
Add neighbors option to RbfInterpolator
  • Loading branch information
takagi committed Mar 11, 2024
2 parents 272dfda + 98c0df6 commit 640a834
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 45 deletions.
51 changes: 47 additions & 4 deletions cupyx/scipy/interpolate/_rbfinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import combinations_with_replacement

import cupy as cp
from cupyx.scipy.spatial import KDTree


# Define the kernel functions.
Expand Down Expand Up @@ -652,7 +653,6 @@ def __init__(self, y, d,
if neighbors is None:
nobs = ny
else:
raise NotImplementedError("neighbors is not implemented yet")
# Make sure the number of nearest neighbors used for interpolation
# does not exceed the number of observations.
neighbors = int(min(neighbors, ny))
Expand All @@ -679,8 +679,7 @@ def __init__(self, y, d,
self._coeffs = coeffs

else:
raise NotImplementedError
# self._tree = KDTree(y)
self._tree = KDTree(y)

self.y = y
self.d = d
Expand Down Expand Up @@ -783,7 +782,51 @@ def __call__(self, x):
self._scale,
self._coeffs, memory_budget=memory_budget)
else:
raise NotImplementedError # XXX: needs KDTree
# Get the indices of the k nearest observation points to each
# evaluation point.
_, yindices = self._tree.query(x, self.neighbors)

if self.neighbors == 1:
# `KDTree` squeezes the output when neighbors=1.
yindices = yindices[:, None]

# Multiple evaluation points may have the same neighborhood of
# observation points. Make the neighborhoods unique so that we only
# compute the interpolation coefficients once for each
# neighborhood.
yindices = cp.sort(yindices, axis=1)
yindices, inv = cp.unique(yindices, return_inverse=True, axis=0)
# `inv` tells us which neighborhood will be used by each evaluation
# point. Now we find which evaluation points will be using each
# neighborhood.
xindices = [[] for _ in range(len(yindices))]
for i, j in enumerate(inv.tolist()):
xindices[j].append(i)

out = cp.empty((nx, self.d.shape[1]), dtype=float)
for xidx, yidx in zip(xindices, yindices):
# `yidx` are the indices of the observations in this
# neighborhood. `xidx` are the indices of the evaluation points
# that are using this neighborhood.
xnbr = x[xidx]
ynbr = self.y[yidx]
dnbr = self.d[yidx]
snbr = self.smoothing[yidx]
shift, scale, coeffs = _build_and_solve_system(
ynbr,
dnbr,
snbr,
self.kernel,
self.epsilon,
self.powers,
)
out[xidx] = self._chunk_evaluator(
xnbr,
ynbr,
shift,
scale,
coeffs,
memory_budget=memory_budget)

out = out.view(self.d_dtype)
out = out.reshape((nx, ) + self.d_shape)
Expand Down
74 changes: 33 additions & 41 deletions tests/cupyx_tests/scipy_tests/interpolate_tests/test_rbfinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def test_conditionally_positive_definite(xp, scp, kernel):

@testing.with_requires("scipy>=1.7.0")
class _TestRBFInterpolator:
@testing.numpy_cupy_allclose(scipy_name='scp')
rtol = 5e-5
atol = 5e-5

@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
@pytest.mark.parametrize('kernel', sorted(_SCALE_INVARIANT))
def test_scale_invariance_1d(self, xp, scp, kernel):
# Verify that the functions in _SCALE_INVARIANT are insensitive to the
Expand All @@ -127,7 +130,7 @@ def test_scale_invariance_1d(self, xp, scp, kernel):
yitp2 = self.build(scp, x, y, epsilon=2.0, kernel=kernel)(xitp)
return yitp1, yitp2

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
@pytest.mark.parametrize('kernel', sorted(_SCALE_INVARIANT))
def test_scale_invariance_2d(self, xp, scp, kernel):
# Verify that the functions in _SCALE_INVARIANT are insensitive to the
Expand All @@ -140,7 +143,7 @@ def test_scale_invariance_2d(self, xp, scp, kernel):
yitp2 = self.build(scp, x, y, epsilon=2.0, kernel=kernel)(xitp)
return yitp1, yitp2

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
@pytest.mark.parametrize('kernel', sorted(_AVAILABLE))
def test_extreme_domains(self, xp, scp, kernel):
# Make sure the interpolant remains numerically stable for very
Expand Down Expand Up @@ -169,7 +172,7 @@ def test_extreme_domains(self, xp, scp, kernel):

return yitp1, yitp2

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
def test_polynomial_reproduction(self, xp, scp):
# If the observed data comes from a polynomial, then the interpolant
# should be able to reproduce the polynomial exactly, provided that
Expand Down Expand Up @@ -197,7 +200,7 @@ def test_polynomial_reproduction(self, xp, scp):

return yitp1, yitp2

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
def test_vector_data(self, xp, scp):
# Make sure interpolating a vector field is the same as interpolating
# each component separately.
Expand All @@ -215,7 +218,7 @@ def test_vector_data(self, xp, scp):

return yitp1, yitp2, yitp3

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
def test_complex_data(self, xp, scp):
# Interpolating complex input should be the same as interpolating the
# real and complex components.
Expand All @@ -232,7 +235,7 @@ def test_complex_data(self, xp, scp):

return yitp1, yitp2, yitp3

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize('kernel', sorted(_AVAILABLE))
def test_interpolation_misfit_1d(self, xp, scp, kernel):
# Make sure that each kernel, with its default `degree` and an
Expand All @@ -250,7 +253,7 @@ def test_interpolation_misfit_1d(self, xp, scp, kernel):
assert mse < 1.0e-4
return yitp

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
@pytest.mark.parametrize('kernel', sorted(_AVAILABLE))
def test_interpolation_misfit_2d(self, xp, scp, kernel):
# Make sure that each kernel, with its default `degree` and an
Expand All @@ -268,7 +271,7 @@ def test_interpolation_misfit_2d(self, xp, scp, kernel):
assert mse < 2.0e-4
return yitp

@testing.numpy_cupy_allclose(scipy_name='scp', atol=1e-8)
@testing.numpy_cupy_allclose(scipy_name='scp', atol=5e-2, rtol=5e-2)
@pytest.mark.parametrize('kernel', sorted(_AVAILABLE))
def test_smoothing_misfit(self, xp, scp, kernel):
# Make sure we can find a smoothing parameter for each kernel that
Expand Down Expand Up @@ -299,7 +302,7 @@ def test_smoothing_misfit(self, xp, scp, kernel):
assert rmse_within_tol
return ysmooth

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
def test_array_smoothing(self, xp, scp):
# Test using an array for `smoothing` to give less weight to a known
# outlier.
Expand Down Expand Up @@ -401,7 +404,7 @@ def test_rank_error(self, xp, scp):
with cupyx.errstate(linalg='raise'):
self.build(scp, y, d, kernel='thin_plate_spline')(y)

@testing.numpy_cupy_allclose(scipy_name='scp')
@testing.numpy_cupy_allclose(scipy_name='scp', rtol=rtol, atol=atol)
@pytest.mark.parametrize('dim', [1, 2, 3])
def test_single_point(self, xp, scp, dim):
# Make sure interpolation still works with only one point (in 1, 2, and
Expand Down Expand Up @@ -529,48 +532,37 @@ def _chunk_evaluator(*args, **kwargs):
testing.assert_allclose(yitp1, yitp2, atol=1e-8)


"""
# Disable `all neighbors not None` tests : they need KDTree
class TestRBFInterpolatorNeighbors20(_TestRBFInterpolator):
# RBFInterpolator using 20 nearest neighbors.
def build(self, *args, **kwargs):
return RBFInterpolator(*args, **kwargs, neighbors=20)
def build(self, scp, *args, **kwargs):
return scp.interpolate.RBFInterpolator(*args, **kwargs, neighbors=20)

def test_equivalent_to_rbf_interpolator(self):
@testing.numpy_cupy_allclose(scipy_name='scp')
def test_equivalent_to_rbf_interpolator(self, xp, scp):
seq = Halton(2, scramble=False, seed=_np.random.RandomState())

x = cp.asarray(seq.random(100))
xitp = cp.asarray(seq.random(100))
y = _2d_test_function(x)
yitp1 = self.build(x, y)(xitp)
yitp2 = []
tree = cKDTree(x)
for xi in xitp:
_, nbr = tree.query(xi, 20)
yitp2.append(RBFInterpolator(x[nbr], y[nbr])(xi[None])[0])
x = xp.asarray(seq.random(100))
xitp = xp.asarray(seq.random(100))

assert_allclose(yitp1, yitp2, atol=1e-8)
y = _2d_test_function(x, xp)
yitp1 = self.build(scp, x, y)(xitp)
return yitp1


class TestRBFInterpolatorNeighborsInf(TestRBFInterpolatorNeighborsNone):
# RBFInterpolator using neighbors=np.inf. This should give exactly the same
# results as neighbors=None, but it will be slower.
def build(self, *args, **kwargs):
return RBFInterpolator(*args, **kwargs, neighbors=cp.inf)
def build(self, scp, *args, **kwargs):
return scp.interpolate.RBFInterpolator(
*args, **kwargs, neighbors=cp.inf)

def test_equivalent_to_rbf_interpolator(self):
@testing.numpy_cupy_allclose(scipy_name='scp')
def test_equivalent_to_rbf_interpolator(self, xp, scp):
seq = Halton(1, scramble=False, seed=_np.random.RandomState())

x = cp.asarray(3*seq.random(50))
xitp = cp.asarray(3*seq.random(50))
y = _1d_test_function(x)
yitp1 = self.build(x, y)(xitp)
yitp2 = RBFInterpolator(x, y)(xitp)
x = xp.asarray(3*seq.random(50))
xitp = xp.asarray(3*seq.random(50))

assert_allclose(yitp1, yitp2, atol=1e-8)
"""
y = _1d_test_function(x, xp)
yitp1 = self.build(scp, x, y)(xitp)
return yitp1

0 comments on commit 640a834

Please sign in to comment.