Skip to content

Commit

Permalink
Adding fill_value to interpolator (#231)
Browse files Browse the repository at this point in the history
* adding fill value to interpolator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixing fill_value logic

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfm and pre-commit-ci[bot] committed Sep 16, 2021
1 parent 40f973b commit 26debab
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/exoplanet/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import itertools

import aesara_theano_fallback.tensor as tt
import numpy as np

from .utils import as_tensor_variable


def regular_grid_interp(points, values, coords):
def regular_grid_interp(points, values, coords, *, fill_value=None):
"""Perform a linear interpolation in N-dimensions w a regular grid
The data must be defined on a filled regular grid, but the spacing may be
Expand Down Expand Up @@ -38,9 +39,11 @@ def regular_grid_interp(points, values, coords):
# Find where the points should be inserted
indices = []
norm_distances = []
out_of_bounds = tt.zeros(coords.shape[:-1], dtype=bool)
for n, grid in enumerate(points):
x = coords[..., n]
i = tt.extra_ops.searchsorted(grid, x) - 1
out_of_bounds |= (i < 0) | (i >= grid.shape[0] - 1)
i = tt.clip(i, 0, grid.shape[0] - 2)
indices.append(i)
norm_distances.append((x - grid[i]) / (grid[i + 1] - grid[i]))
Expand All @@ -51,6 +54,10 @@ def regular_grid_interp(points, values, coords):
for ei, i, yi in zip(edge_indices, indices, norm_distances):
weight *= tt.where(tt.eq(ei, i), 1 - yi, yi)
result += values[edge_indices] * weight

if fill_value is not None:
result = tt.switch(out_of_bounds, fill_value, result)

return result


Expand All @@ -68,10 +75,11 @@ class RegularGridInterpolator:
``(m1, ... mn, ..., nout)``.
"""

def __init__(self, points, values, **kwargs):
def __init__(self, points, values, fill_value=None, **kwargs):
self.ndim = len(points)
self.points = points
self.values = values
self.fill_value = fill_value

def evaluate(self, t):
"""Interpolate the data
Expand All @@ -81,4 +89,6 @@ def evaluate(self, t):
should be evaluated. This must have the shape
``(ntest, ndim)``.
"""
return regular_grid_interp(self.points, self.values, t)
return regular_grid_interp(
self.points, self.values, t, fill_value=self.fill_value
)
16 changes: 16 additions & 0 deletions tests/interp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,19 @@ def f(x, y, z):
op = RegularGridInterpolator((x, y, z), data)
f = np.squeeze(op.evaluate(pts).eval())
assert np.allclose(f, f0)


def test_fill_value():
def f(x, y, z):
return 2 * x ** 3 + 3 * y ** 2 - z

x = np.linspace(1, 4, 11)
y = np.linspace(4, 7, 22)
z = np.linspace(7, 9, 33)

data = f(*np.meshgrid(x, y, z, indexing="ij", sparse=True))
pts = np.array([[0.1, 6.2, 8.3], [3.3, 5.2, 10.1]])

op = RegularGridInterpolator((x, y, z), data, fill_value=np.nan)
f = op.evaluate(pts).eval()
assert np.all(np.isnan(f))

0 comments on commit 26debab

Please sign in to comment.