-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
switching to pure-theano impl of regular grid interp (#167)
- Loading branch information
Showing
7 changed files
with
61 additions
and
473 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,84 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
__all__ = ["RegularGridInterpolator"] | ||
__all__ = ["regular_grid_interp", "RegularGridInterpolator"] | ||
|
||
import itertools | ||
|
||
import aesara_theano_fallback.tensor as tt | ||
from aesara_theano_fallback import aesara as theano | ||
|
||
from .theano_ops.interp import RegularGridOp | ||
from .utils import as_tensor_variable | ||
|
||
|
||
class RegularGridInterpolator: | ||
"""Linear interpolation on a regular grid in arbitrary dimensions | ||
def regular_grid_interp(points, values, coords): | ||
"""Perform a linear interpolation in N-dimensions w a regular grid | ||
The data must be defined on a filled regular grid, but the spacing may be | ||
uneven in any of the dimensions. | ||
This implementation is based on the implementation in the | ||
``scipy.interpolate.RegularGridInterpolator`` class which, in turn, is | ||
based on the implementation from Johannes Buchner's ``regulargrid`` | ||
package https://github.com/JohannesBuchner/regulargrid. | ||
Args: | ||
points: A list of vectors with shapes ``(m1,), ... (mn,)``. These | ||
define the grid points in each dimension. | ||
values: A tensor defining the values at each point in the grid | ||
defined by ``points``. This must have the shape | ||
``(m1, ... mn, ..., nout)``. | ||
xi: A matrix defining the coordinates where the interpolation | ||
coords: A matrix defining the coordinates where the interpolation | ||
should be evaluated. This must have the shape ``(ntest, ndim)``. | ||
check_sorted: If ``True`` (default), check that the tensors in | ||
``points`` are all sorted in ascending order. This can be set to | ||
``False`` if the axes are known to be sorted, but the results will | ||
be unpredictable if this ends up being wrong. | ||
bounds_error: If ``False`` (default) extrapolate beyond the edges of | ||
the grid. Otherwise raise an exception. | ||
nout: An integer indicating the number of outputs if known at compile | ||
time. The default is to allow any number of outputs, but | ||
performance can be better if this is provided. | ||
""" | ||
points = [as_tensor_variable(p) for p in points] | ||
ndim = len(points) | ||
values = as_tensor_variable(values) | ||
coords = as_tensor_variable(coords) | ||
|
||
def __init__( | ||
self, points, values, check_sorted=True, bounds_error=False, nout=-1 | ||
): | ||
self.ndim = len(points) | ||
self.nout = int(nout) | ||
# Find where the points should be inserted | ||
indices = [] | ||
norm_distances = [] | ||
for n, grid in enumerate(points): | ||
x = coords[..., n] | ||
i = tt.extra_ops.searchsorted(grid, x) - 1 | ||
i = tt.clip(i, 0, grid.shape[0] - 2) | ||
indices.append(i) | ||
norm_distances.append((x - grid[i]) / (grid[i + 1] - grid[i])) | ||
|
||
result = tt.zeros(tuple(coords.shape[:-1]) + tuple(values.shape[ndim:])) | ||
for edge_indices in itertools.product(*((i, i + 1) for i in indices)): | ||
weight = tt.ones(coords.shape[:-1]) | ||
for ei, i, yi in zip(edge_indices, indices, norm_distances): | ||
weight *= tt.where(tt.eq(ei, i), 1 - yi, yi) | ||
result += values[edge_indices] * weight | ||
return result | ||
|
||
self.points = [theano.shared(p) for p in points] | ||
self.values = theano.shared(values) | ||
if self.values.ndim == self.ndim: | ||
self.values = tt.shape_padright(self.values) | ||
|
||
self.check_sorted = bool(check_sorted) | ||
self.bounds_error = bool(bounds_error) | ||
class RegularGridInterpolator: | ||
"""Linear interpolation on a regular grid in arbitrary dimensions | ||
self.interp_op = RegularGridOp( | ||
self.ndim, | ||
nout=self.nout, | ||
check_sorted=self.check_sorted, | ||
bounds_error=self.bounds_error, | ||
) | ||
The data must be defined on a filled regular grid, but the spacing may be | ||
uneven in any of the dimensions. | ||
Args: | ||
points: A list of vectors with shapes ``(m1,), ... (mn,)``. These | ||
define the grid points in each dimension. | ||
values: A tensor defining the values at each point in the grid | ||
defined by ``points``. This must have the shape | ||
``(m1, ... mn, ..., nout)``. | ||
""" | ||
|
||
def __init__(self, points, values, **kwargs): | ||
self.ndim = len(points) | ||
self.points = points | ||
self.values = values | ||
|
||
def evaluate(self, t): | ||
return self.interp_op(t, self.values, *self.points)[0] | ||
"""Interpolate the data | ||
Args: | ||
t: A matrix defining the coordinates where the interpolation | ||
should be evaluated. This must have the shape | ||
``(ntest, ndim)``. | ||
""" | ||
return regular_grid_interp(self.points, self.values, t) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
__all__ = ["kepler", "contact_points", "starry", "interp"] | ||
__all__ = ["contact_points", "kepler", "starry"] | ||
|
||
from . import contact_points, interp, kepler, starry | ||
from . import contact_points, kepler, starry |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.