Skip to content

Commit

Permalink
DOC: interpolate: add an example of CT + nearest-neighbor extrapolation
Browse files Browse the repository at this point in the history
The example is contributed by Ajay Shanker Tripathi in scipygh-14386
  • Loading branch information
ev-br committed Oct 28, 2022
1 parent 1124f77 commit c991283
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions doc/source/tutorial/interpolate/extrapolation_examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,96 @@ However the basic idea is the same.



.. _tutorial-extrapolation-CT_NN:

Exrapolation in ``D > 1``
=========================

The basic idea of implementing extrapolation manually in a wrapper class
or function can be easily generalized to higher dimensions. As an
example, we consider a C1-smooth interpolation problem of 2D data using
`CloughTocher2DInterpolator`. By default, it fills the out of bounds values
with ``nan``\ s, and we want to instead use for each query
point the value of its nearest neighbor.

Since `CloughTocher2DInterpolator` accepts either 2D data or a Delaunay
triangulation of the data points, the efficient way of finding nearest
neighbors of query points would be to construct the triangulation (using
`scipy.spatial` tools) and use it to find nearest neighbors on the convex hull
of the data.

We will instead use a simpler, naive method and rely on looping over the
whole dataset using NumPy broadcasting.

.. plot::

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CloughTocher2DInterpolator as CT

def my_CT(xy, z):
"""CT interpolator + nearest-neighbor extrapolation.

Parameters
----------
xy : ndarray, shape (npoints, ndim)
Coordinates of data points
z : ndarray, shape (npoints)
Values at data points

Returns
-------
func : callable
A callable object which mirrors the CT behavior,
with an additional neareast-neighbor extrapolation
outside of the data range.
"""
x = xy[:, 0]
y = xy[:, 1]
f = CT(xy, z)

# this inner function will be returned to a user
def new_f(xx, yy):
# evaluate the CT interpolator. Out-of-bounds values are nan.
zz = f(xx, yy)
nans = np.isnan(zz)

if nans.any():
# for each nan point, find its nearest neighbor
inds = np.argmin(
(x[:, None] - xx[nans])**2 +
(y[:, None] - yy[nans])**2
, axis=0)
# ... and use its value
zz[nans] = z[inds]
return zz

return new_f

# Now illustrate the difference between the original ``CT`` interpolant
# and ``my_CT`` on a small example:

x = np.array([1, 1, 1, 2, 2, 2, 4, 4, 4])
y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
z = np.array([0, 7, 8, 3, 4, 7, 1, 3, 4])

xy = np.c_[x, y]
lut = CT(xy, z)
lut2 = my_CT(xy, z)

X = np.linspace(min(x) - 0.5, max(x) + 0.5, 71)
Y = np.linspace(min(y) - 0.5, max(y) + 0.5, 71)
X, Y = np.meshgrid(X, Y)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.plot_wireframe(X, Y, lut(X, Y), label='CT')
ax.plot_wireframe(X, Y, lut2(X, Y), color='m',
cstride=10, rstride=10, alpha=0.7, label='CT + n.n.')

ax.scatter(x, y, z, 'o', color='k', s=48, label='data')
ax.legend()
plt.tight_layout()


0 comments on commit c991283

Please sign in to comment.