In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from scipy.spatial import Delaunay

from cell_inference.config import paths, params

In [2]:
from scipy.spatial import qhull

def interp_weights(points, xi):
    """Precompute simplex vertices and barycentric weights for linear interpolation.

    A Delaunay triangulation of ``points`` is built once, the simplex
    containing each query point of ``xi`` is located, and the query point is
    expressed in barycentric coordinates of that simplex. The returned arrays
    can be reused for any number of value vectors sampled at ``points``.

    Parameters
    ----------
    points : array_like, shape (n_points, d)
        Coordinates of the data sites.
    xi : array_like, shape (n_query, d)
        Coordinates at which values will be interpolated.

    Returns
    -------
    vertices : ndarray of int, shape (n_query, d + 1)
        Indices into ``points`` of the corners of the simplex used for each
        query point.
    weights : ndarray of float, shape (n_query, d + 1)
        Barycentric weight of each corner; every row sums to 1.

    Notes
    -----
    ``find_simplex`` returns -1 for query points outside the triangulation.
    ``np.take`` then wraps that index to the last simplex, so those rows hold
    extrapolation weights relative to an unrelated simplex. Such rows normally
    contain at least one negative weight, which callers can use to mask them.
    """
    points = np.asarray(points)
    xi = np.asarray(xi)
    d = points.shape[1]
    # Use the public class; the ``scipy.spatial.qhull`` namespace is deprecated.
    tri = Delaunay(points)
    simplex = tri.find_simplex(xi)
    vertices = np.take(tri.simplices, simplex, axis=0)
    # transform[i, :d, :] is the inverse affine map of simplex i and
    # transform[i, d, :] is the offset vertex it is measured from.
    temp = np.take(tri.transform, simplex, axis=0)
    delta = xi - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    # The last barycentric coordinate follows from the sum-to-one constraint.
    weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))
    return vertices, weights

def interpolate(values, vertices, weights, fill_value=None):
    """Linearly interpolate ``values`` using precomputed vertices and weights.

    Parameters
    ----------
    values : array_like, shape (..., n_points)
        Data sampled at the triangulated points. The last axis indexes the
        points; any leading axes are treated as independent batches, so a
        whole stack of value vectors can be interpolated in one call.
    vertices : ndarray of int, shape (n_query, d + 1)
        Simplex corner indices for each query point.
    weights : ndarray of float, shape (n_query, d + 1)
        Barycentric weights matching ``vertices``.
    fill_value : scalar, optional
        If given, query points with any negative weight (i.e. lying outside
        their simplex) are set to this value. With the default ``None`` the
        weighted sum is returned unchanged for every query point.

    Returns
    -------
    ndarray, shape (..., n_query)
        Interpolated values; for 1-D ``values`` the result is 1-D.
    """
    values = np.asarray(values)
    # Index along the last axis so batch axes are carried through; a bare
    # np.take(values, vertices) would flatten multi-dimensional input.
    vi = np.einsum('...nj,nj->...n', values[..., vertices], weights)
    if fill_value is not None:
        vi[..., np.any(weights < 0, axis=1)] = fill_value
    return vi

In [3]:
# First two coordinate columns of the electrode positions, plus the grid axes.
elec_pos = params.ELECTRODE_POSITION[:, :2]
elec_grid = params.ELECTRODE_GRID

# Full rectangular grid flattened to rows of (x, y) pairs.
xx, yy = np.meshgrid(elec_grid[0], elec_grid[1], indexing='ij')
grid = np.stack((xx.ravel(), yy.ravel()), axis=1)

# Keep only grid points whose second coordinate lies within +/- Y_WINDOW_SIZE/2.
half_window = params.Y_WINDOW_SIZE / 2
in_window = (grid[:, 1] >= -half_window) & (grid[:, 1] <= half_window)
grid = grid[in_window]

print(elec_pos.shape)
print(grid.shape)

(384, 2)
(192, 2)


In [4]:
# Synthetic standard-normal data: WINDOW_SIZE rows, one column per row of elec_pos.
# A seeded Generator keeps the data identical across kernel restarts.
rng = np.random.default_rng(0)
f = rng.standard_normal((params.WINDOW_SIZE, elec_pos.shape[0]))

In [5]:
# Compute the simplex vertex indices and barycentric weights once, mapping
# the electrode positions onto the grid points, for reuse in later cells.
vtx, wts = interp_weights(elec_pos, grid)

In [6]:
# Sanity check on the first row of f: interpolation with the cached weights
# should agree with scipy's griddata at every grid point.
np.allclose(interpolate(f[0, :], vtx, wts), griddata(elec_pos, f[0, :], grid))

True

In [7]:
# Baseline timing: one griddata call per row of f.
%timeit for i in range(params.WINDOW_SIZE): griddata(elec_pos, f[i, :], grid)

1.87 s ± 68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Timing of a single interp_weights call (triangulation plus weight computation).
%timeit interp_weights(elec_pos, grid)

12 ms ± 571 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
# Timing of per-row interpolation that reuses the precomputed vtx / wts.
%timeit for i in range(params.WINDOW_SIZE): interpolate(f[i, :], vtx, wts)

3.48 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
