In [5]:
# Prep...
from numba import njit # important if we want to keep this in Python (which seems to be an (un?)stated goal)
import numpy as np # we have to use numpy arrays to jit functions on arrays

# G2EGM - interpolation
The previous notebook considered a very small subset of the G2EGM step: computing the barycentric weights. Here, we'll change the way the weights are computed to accommodate efficient computation in the full G2EGM step.

As noted in a comment in the previous notebook, the inverse of the denominator can be precomputed. It's a detail, but it improves overall performance. The reinterpolation step consists of three parts:

1. Calculating the inverse. It depends only on the simplex, not on the common-grid point, so it is shared by all points in the common grid.
2. Calculating the weights for the current point ($x$) in the common grid.
3. Taking an inner product between the policy at the simplex and the $x$-specific weights.

Since the first step involves a division, we really don't want to recompute it for every $x\in \mathbb{G}$ (the common grid).

We create a function that takes in the simplex and returns the inverse of the denominator of the weights.

In [25]:
@njit
def BarycentricInverseDenominator(ABC):
    """Return 1 / denominator of the barycentric weights for a triangle in R^2.

    ABC holds the three vertices A, B, C; each vertex is indexed as
    (n, m) = (vertex[0], vertex[1]). The result depends only on the simplex,
    so it can be reused for every query point inside it.

    A degenerate (collinear) simplex makes the denominator zero; under numba's
    default error model this presumably raises ZeroDivisionError.
    """
    vertex_a, vertex_b, vertex_c = ABC[0], ABC[1], ABC[2]

    # Edge differences that make up the 2x2 determinant.
    dn_bc = vertex_b[0] - vertex_c[0]
    dm_ac = vertex_a[1] - vertex_c[1]
    dm_cb = vertex_c[1] - vertex_b[1]
    dn_ac = vertex_a[0] - vertex_c[0]

    det = dn_bc*dm_ac + dm_cb*dn_ac
    return 1/det


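Precomputing this quantity should pay off. The function that calculates the weights then looks as below. Note that this version still calls the helper on every evaluation, so the inverse is not yet actually cached across grid points.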

In [26]:
@njit
def BarycentricWeights(ABC, x):
    """Barycentric weights (wA, wB, wC) of point x w.r.t. the triangle ABC.

    ABC holds the three vertices, each as (n, m); x is the common-grid point
    (n, m). The weights sum to one; the third is obtained as 1 - wA - wB.
    The inverse denominator is recomputed on every call here.
    """
    inverse = BarycentricInverseDenominator(ABC)

    a, b, c = ABC[0], ABC[1], ABC[2]

    # Offsets of the query point from vertex C, shared by both weights.
    off_m = x[1] - c[1]
    off_n = x[0] - c[0]

    weight_a = ((b[0] - c[0])*off_m + (c[1] - b[1])*off_n)*inverse
    weight_b = ((c[0] - a[0])*off_m + (a[1] - c[1])*off_n)*inverse

    # A tuple so the consumer can treat the weights generically.
    return (weight_a, weight_b, 1 - weight_a - weight_b)

Note that, unlike in the previous notebook, the weights are now returned as a tuple. This lets the interpolation function below consume them generically.

In [27]:
def BarycentricInterp(Simplex, SimplexValues, x):
    """Linearly interpolate SimplexValues (one per vertex) at point x.

    The result is the inner product of the barycentric weights of x with
    the vertex values.
    """
    weights = BarycentricWeights(Simplex, x)
    interpolated = np.inner(weights, SimplexValues)
    return interpolated

In [28]:
# Test triangle: one vertex (n, m) per row, in the order A, B, C.
ASimplex = np.array([[1.0, 3.0], [3.0, 4.0], [1.5, 3.0]])
print(ASimplex)

[[1.  3. ]
 [3.  4. ]
 [1.5 3. ]]


In [31]:
# At vertex A the interpolant must return the value attached to A, i.e. the
# first entry of the values array (1.0). A float query point matches the
# simplex dtype and avoids compiling an extra int specialization.
BarycentricInterp(ASimplex, np.array([1.0, 2.0, 3.0]), np.array([1.0, 3.0]))

1.0

In [43]:
# Evaluating at each vertex must reproduce that vertex's value. Compare with a
# tolerance (np.isclose) rather than exact float equality.
SimplexValues = np.array([1.0,2.0,3.0])
for i, vertex in enumerate(ASimplex):
    print("Values match at index ",i,": ", np.isclose(BarycentricInterp(ASimplex, SimplexValues, vertex), SimplexValues[i]))

Values match at index  0 :  True
Values match at index  1 :  True
Values match at index  2 :  True


In [49]:
# At the midpoints of edges A-B and B-C, linear interpolation must give the
# mean of the two endpoint values. Compare with a tolerance, not exact ==.
SimplexValues = np.array([1.0,2.0,3.0])
for i in range(2):
    midpoint = (ASimplex[i] + ASimplex[i+1])/2
    expected = (SimplexValues[i] + SimplexValues[i+1])/2
    print("Values match at max ",i+1,": ", np.isclose(BarycentricInterp(ASimplex, SimplexValues, midpoint), expected))

Values match at max  1 :  True
Values match at max  2 :  True


Both the inverse denominator and differences such as `(nB - nC)` can be precomputed. We should therefore structure the code either around small helper functions or, more naturally, as an object. However, ordinary Python objects cannot be used inside jitted functions. Numba's nopython mode cannot type a regular class instance such as `self`, so the code would fall back to slow Python (object) mode and kill performance. We therefore have to work with `jitclass`.

In [52]:
from numba import jitclass
from numba import float64

In [132]:
# numba needs the type of every jitclass attribute up front.
FieldSpec = [
    ('InvDenominator', float64),  # 1 / denominator of the barycentric weights
    ('dBC', float64),  # nB - nC
    ('dCB', float64),  # mC - mB
    ('dCA', float64),  # nC - nA
    ('dAC', float64),  # mA - mC
    ('Simplex', float64[:, :]),  # vertices A, B, C as rows (n, m)
    ('Values', float64[:])  # value at each vertex, same order as Simplex rows
]
@jitclass(FieldSpec)
class BarycentricInterpolant(object):
    """Barycentric (linear) interpolant on a single triangle in R^2.

    Everything that does not depend on the query point (the inverse
    denominator and the vertex differences) is computed once in __init__, so
    each call to Interp is only a few multiply-adds.
    """
    def __init__(self, Simplex, SimplexValues):
        # numba does not bounds-check array accesses by default, so an
        # undersized input would be read out of bounds silently. Reject it.
        if Simplex.shape[0] < 3 or Simplex.shape[1] < 2:
            raise ValueError("Simplex must have at least 3 rows and 2 columns")
        if SimplexValues.shape[0] < 3:
            raise ValueError("SimplexValues must have at least 3 entries")

        nA, mA = Simplex[0][0], Simplex[0][1]
        nB, mB = Simplex[1][0], Simplex[1][1]
        nC, mC = Simplex[2][0], Simplex[2][1]
        self.Simplex = Simplex
        self.InvDenominator = self.BarycentricInverseDenominator()
        self.dBC = nB - nC
        self.dCB = mC - mB
        self.dCA = nC - nA
        self.dAC = mA - mC
        self.Values = SimplexValues

    def BarycentricInverseDenominator(self):
        """Return 1 / denominator of the barycentric weights of self.Simplex."""
        # (nX, mX) are the two states at vertex X in the triangle/R^2 case
        ABC = self.Simplex
        nA, mA = ABC[0][0], ABC[0][1]
        nB, mB = ABC[1][0], ABC[1][1]
        nC, mC = ABC[2][0], ABC[2][1]

        # A degenerate (collinear) simplex makes this denominator zero.
        return 1.0/((nB - nC)*(mA - mC) + (mC - mB)*(nA - nC))

    def BarycentricWeights(self, x):
        """Barycentric weights (wA, wB, wC) of point x = (n, m); they sum to one."""
        # Only vertex C is needed here: the other vertex terms are already
        # folded into the precomputed differences dBC, dCB, dCA, dAC.
        ABC = self.Simplex
        nC, mC = ABC[2][0], ABC[2][1]

        # (n, m) is the common-grid point in R^2. Assumes x has at least
        # 2 entries (not checked, to keep this hot path lean).
        n, m = x[0], x[1]
        dm = m - mC
        dn = n - nC

        wA = (self.dBC*dm + self.dCB*dn)*self.InvDenominator
        wB = (self.dCA*dm + self.dAC*dn)*self.InvDenominator

        # tuple for generic capabilities in functions that consume this output
        return (wA, wB, 1.0-wA-wB)

    def Interp(self, x):
        """Interpolated value at x: vertex values weighted by the barycentric weights."""
        wA, wB, wC = self.BarycentricWeights(x)

        # Explicit weighted sum instead of np.inner: the weights are a tuple,
        # and np.inner did not work here in nopython mode (presumably not
        # supported by numba for these argument types -- TODO confirm).
        return wA*self.Values[0] + wB*self.Values[1] + wC*self.Values[2]

In [133]:
# Build the interpolant for ASimplex with vertex values 1, 2, 3 (for A, B, C).
# Evaluating it at a vertex should return that vertex's value.
bi = BarycentricInterpolant(ASimplex, np.array([1.0,2.0,3.0]))

In [134]:
# vertex A -> expect the value at A (1.0)
bi.Interp(np.array((1.0, 3.0)))

1.0

In [135]:
# vertex B -> expect the value at B (2.0)
bi.Interp(np.array((3.0, 4.0)))

2.0

In [136]:
# vertex C -> expect the value at C (3.0)
bi.Interp(np.array((1.5, 3.0)))

3.0

In [137]:
# midpoint of edge A-B -> expect the mean of 1.0 and 2.0 (1.5)
bi.Interp(np.array((2.0, 3.5)))

1.5

It works as expected! The code still has to be documented and generalized to support $N$-dimensional models, though.