Skip to content

Commit

Permalink
switching to pure-theano impl of regular grid interp (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Apr 22, 2021
1 parent d3828e8 commit 5bd4e84
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 473 deletions.
93 changes: 59 additions & 34 deletions src/exoplanet/interp.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,84 @@
# -*- coding: utf-8 -*-

__all__ = ["RegularGridInterpolator"]
__all__ = ["regular_grid_interp", "RegularGridInterpolator"]

import itertools

import aesara_theano_fallback.tensor as tt
from aesara_theano_fallback import aesara as theano

from .theano_ops.interp import RegularGridOp
from .utils import as_tensor_variable


class RegularGridInterpolator:
"""Linear interpolation on a regular grid in arbitrary dimensions
def regular_grid_interp(points, values, coords):
"""Perform a linear interpolation in N-dimensions w a regular grid
The data must be defined on a filled regular grid, but the spacing may be
uneven in any of the dimensions.
This implementation is based on the implementation in the
``scipy.interpolate.RegularGridInterpolator`` class which, in turn, is
based on the implementation from Johannes Buchner's ``regulargrid``
package https://github.com/JohannesBuchner/regulargrid.
Args:
points: A list of vectors with shapes ``(m1,), ... (mn,)``. These
define the grid points in each dimension.
values: A tensor defining the values at each point in the grid
defined by ``points``. This must have the shape
``(m1, ... mn, ..., nout)``.
xi: A matrix defining the coordinates where the interpolation
coords: A matrix defining the coordinates where the interpolation
should be evaluated. This must have the shape ``(ntest, ndim)``.
check_sorted: If ``True`` (default), check that the tensors in
``points`` are all sorted in ascending order. This can be set to
``False`` if the axes are known to be sorted, but the results will
be unpredictable if this ends up being wrong.
bounds_error: If ``False`` (default) extrapolate beyond the edges of
the grid. Otherwise raise an exception.
nout: An integer indicating the number of outputs if known at compile
time. The default is to allow any number of outputs, but
performance can be better if this is provided.
"""
points = [as_tensor_variable(p) for p in points]
ndim = len(points)
values = as_tensor_variable(values)
coords = as_tensor_variable(coords)

def __init__(
self, points, values, check_sorted=True, bounds_error=False, nout=-1
):
self.ndim = len(points)
self.nout = int(nout)
# Find where the points should be inserted
indices = []
norm_distances = []
for n, grid in enumerate(points):
x = coords[..., n]
i = tt.extra_ops.searchsorted(grid, x) - 1
i = tt.clip(i, 0, grid.shape[0] - 2)
indices.append(i)
norm_distances.append((x - grid[i]) / (grid[i + 1] - grid[i]))

result = tt.zeros(tuple(coords.shape[:-1]) + tuple(values.shape[ndim:]))
for edge_indices in itertools.product(*((i, i + 1) for i in indices)):
weight = tt.ones(coords.shape[:-1])
for ei, i, yi in zip(edge_indices, indices, norm_distances):
weight *= tt.where(tt.eq(ei, i), 1 - yi, yi)
result += values[edge_indices] * weight
return result

self.points = [theano.shared(p) for p in points]
self.values = theano.shared(values)
if self.values.ndim == self.ndim:
self.values = tt.shape_padright(self.values)

self.check_sorted = bool(check_sorted)
self.bounds_error = bool(bounds_error)
class RegularGridInterpolator:
"""Linear interpolation on a regular grid in arbitrary dimensions
self.interp_op = RegularGridOp(
self.ndim,
nout=self.nout,
check_sorted=self.check_sorted,
bounds_error=self.bounds_error,
)
The data must be defined on a filled regular grid, but the spacing may be
uneven in any of the dimensions.
Args:
points: A list of vectors with shapes ``(m1,), ... (mn,)``. These
define the grid points in each dimension.
values: A tensor defining the values at each point in the grid
defined by ``points``. This must have the shape
``(m1, ... mn, ..., nout)``.
"""

def __init__(self, points, values, **kwargs):
self.ndim = len(points)
self.points = points
self.values = values

def evaluate(self, t):
return self.interp_op(t, self.values, *self.points)[0]
"""Interpolate the data
Args:
t: A matrix defining the coordinates where the interpolation
should be evaluated. This must have the shape
``(ntest, ndim)``.
"""
return regular_grid_interp(self.points, self.values, t)
4 changes: 2 additions & 2 deletions src/exoplanet/theano_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

__all__ = ["kepler", "contact_points", "starry", "interp"]
__all__ = ["contact_points", "kepler", "starry"]

from . import contact_points, interp, kepler, starry
from . import contact_points, kepler, starry
35 changes: 0 additions & 35 deletions src/exoplanet/theano_ops/build_utils.py

This file was deleted.

5 changes: 0 additions & 5 deletions src/exoplanet/theano_ops/interp/__init__.py

This file was deleted.

88 changes: 0 additions & 88 deletions src/exoplanet/theano_ops/interp/include/theano_helpers.h

This file was deleted.

0 comments on commit 5bd4e84

Please sign in to comment.