In [1]:
# Prep
from numba import njit # important if we want to keep this in Python (which seems to be an (un?)stated goal)
import numpy as np # we have to use numpy arrays to jit functions on arrays

# G2EGM
This notebook briefly shows some prototypes for functions that will be useful and necessary (though not sufficient) to people working with models that G2EGM can handle. To see why `@njit` might be important to get right, we also show some simple timings.

The central step in G2EGM requires keeping track of all the **candidate** policies ($C$) at the vertices of a collection of simplices ($S$). Elements of $C$ are vectors in $\mathbb{R}^n$ and elements of $S$ are simplices with vertices in $\mathbb{R}^n$, where in both cases $n$ is the number of (continuous) states. The simplices are used to calculate the barycentric weights for the common grid ($G$) when interpolating from the endogenous grids that each so-called *segment* calculation produces.
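As a sketch of how such weights are used (an assumption about the data layout, not code from G2EGM itself): if the candidate policy values are stored per vertex, the interpolated policy at a common-grid point is the weighted sum of the vertex values. The helper names `barycentric_weights_nd` and `interpolate_policy` are hypothetical. This version handles general $n$ by solving a small linear system with `np.linalg.solve`, rather than using closed-form triangle formulas.

```python
import numpy as np

def barycentric_weights_nd(vertices, x):
    """Barycentric weights of point x with respect to a simplex in R^n.

    vertices: (n+1, n) array, one vertex per row.
    x: (n,) array.
    Returns an (n+1,) array of weights that sum to one.
    """
    vertices = np.asarray(vertices, dtype=float)
    x = np.asarray(x, dtype=float)
    last = vertices[-1]
    # Columns of T are the edge vectors V_i - V_last for the first n vertices.
    T = (vertices[:-1] - last).T
    lam = np.linalg.solve(T, x - last)
    # The weight on the last vertex makes the weights sum to one.
    return np.append(lam, 1.0 - lam.sum())

def interpolate_policy(vertices, policy_at_vertices, x):
    """Interpolate policy values stored at the simplex vertices onto the point x."""
    w = barycentric_weights_nd(vertices, x)
    # Works for scalar policies (shape (n+1,)) and vector policies (shape (n+1, k)).
    return w @ np.asarray(policy_at_vertices, dtype=float)

simplex = np.array([[1.0, 3.0], [3.0, 4.0], [1.5, 3.0]])
print(barycentric_weights_nd(simplex, [2.0, 3.5]))                    # approximately [0.5, 0.5, 0.0]
print(interpolate_policy(simplex, [10.0, 20.0, 30.0], [2.0, 3.5]))    # approximately 15.0
```

In two dimensions this gives the same weights as the closed-form triangle formulas, only at the cost of a linear solve per point.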

Let us calculate the barycentric weights $(w_A, w_B, w_C)$ used to interpolate the policy onto a new point $x$, given a simplex $ABC$ (the triangle with vertices $A$, $B$ and $C$). We work in the two-state case, $n = 2$, where every simplex is a triangle.
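For reference, the arithmetic of the function below can be written as formulas, with $A = (n_A, m_A)$, $B = (n_B, m_B)$, $C = (n_C, m_C)$ and $x = (n, m)$. The symbol $D$ for the common denominator is introduced here only for readability:

$$
D = (n_B - n_C)(m_A - m_C) + (m_C - m_B)(n_A - n_C)
$$

$$
w_A = \frac{(n_B - n_C)(m - m_C) + (m_C - m_B)(n - n_C)}{D}, \qquad
w_B = \frac{(n_C - n_A)(m - m_C) + (m_A - m_C)(n - n_C)}{D}, \qquad
w_C = 1 - w_A - w_B
$$

These weights satisfy $x = w_A A + w_B B + w_C C$. Since $|D|$ is twice the area of the triangle, the formulas break down when the three vertices are collinear.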

In [1]:
def BarycentricWeightsNoJIT(ABC, x):
    # (nX,mX) are simply the two states in the "triangle"/R^2 case at vertex X
    nA, mA = ABC[0][0], ABC[0][1]
    nB, mB = ABC[1][0], ABC[1][1]
    nC, mC = ABC[2][0], ABC[2][1]
    
    # The next line suggests that we might want to input a collection of points X,
    # as inv_denom is the same for every x, and there might be many points to evaluate
    # for each simplex.
    inv_denom = 1/((nB - nC)*(mA - mC) + (mC - mB)*(nA - nC))
    
    # (n,m) is the common grid point in R^2
    n, m = x[0], x[1]
    
    # weights on A and B; the weight on C follows because the three weights sum to one
    wA = ((nB - nC)*(m - mC) + (mC - mB)*(n - nC))*inv_denom
    wB = ((nC - nA)*(m - mC) + (mA - mC)*(n - nC))*inv_denom

    return wA, wB, 1-wA-wB
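The comment inside the function hints at evaluating many points per simplex. Here is a minimal NumPy sketch of that idea, assuming the points come as a `(k, 2)` array. `barycentric_weights_many` is a hypothetical name, and the formulas are the same as above, applied to whole columns at once:

```python
import numpy as np

def barycentric_weights_many(ABC, X):
    """Barycentric weights of every row of X (shape (k, 2)) w.r.t. triangle ABC (shape (3, 2)).

    Returns a (k, 3) array; row i holds (wA, wB, wC) for the point X[i].
    """
    ABC = np.asarray(ABC, dtype=float)
    X = np.atleast_2d(np.asarray(X, dtype=float))
    (nA, mA), (nB, mB), (nC, mC) = ABC
    # The denominator depends only on the simplex, so it is computed once for all points.
    inv_denom = 1.0 / ((nB - nC) * (mA - mC) + (mC - mB) * (nA - nC))
    n, m = X[:, 0], X[:, 1]
    wA = ((nB - nC) * (m - mC) + (mC - mB) * (n - nC)) * inv_denom
    wB = ((nC - nA) * (m - mC) + (mA - mC) * (n - nC)) * inv_denom
    return np.column_stack((wA, wB, 1.0 - wA - wB))

ASimplex = np.array([[1.0, 3.0], [3.0, 4.0], [1.5, 3.0]])
points = np.array([[1.0, 3.0], [2.0, 3.5], [2.25, 3.5]])
print(barycentric_weights_many(ASimplex, points))
```

Vectorizing over points keeps the per-point work in NumPy instead of the Python interpreter, which is the same motivation as jitting.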

Then let's construct a simplex with vertices in $\mathbb{R}^2$.

In [3]:
ASimplex = np.array([[1.0,3.0],[3.,4.0],[1.5,3.0]])
print(ASimplex)

[[1.  3. ]
 [3.  4. ]
 [1.5 3. ]]


In [4]:
ASimplex.shape # print the shape to verify it is indeed 3 x 2

(3, 2)
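Before using a simplex, it can be worth checking that it is not degenerate: the weights divide by a denominator whose absolute value is twice the triangle's area, so three (nearly) collinear vertices lead to a division by (nearly) zero. A minimal sketch; the helper names `simplex_denominator` and `is_degenerate` and the absolute tolerance `tol` are hypothetical choices, and a scale-aware tolerance may be preferable in practice:

```python
import numpy as np

def simplex_denominator(ABC):
    """Denominator of the barycentric weights; its absolute value is twice the triangle's area."""
    (nA, mA), (nB, mB), (nC, mC) = np.asarray(ABC, dtype=float)
    return (nB - nC) * (mA - mC) + (mC - mB) * (nA - nC)

def is_degenerate(ABC, tol=1e-12):
    """True if the three vertices are (numerically) collinear, i.e. the triangle has ~zero area."""
    return abs(simplex_denominator(ABC)) <= tol

ASimplex = np.array([[1.0, 3.0], [3.0, 4.0], [1.5, 3.0]])
print(simplex_denominator(ASimplex))                        # 0.5, so the triangle's area is 0.25
print(is_degenerate(ASimplex))                              # False
print(is_degenerate([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]))  # True (collinear vertices)
```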

Simple checks of the function: use each vertex as $x$, the midpoints of two edges (convex combinations of two vertices), and a point outside the simplex.

In [5]:
# Should put all weight on first vertex
BarycentricWeightsNoJIT(ASimplex, np.array([1.0,3.0]))

(1.0, 0.0, 0.0)

In [6]:
# Should put all weight on second vertex
BarycentricWeightsNoJIT(ASimplex, np.array([3.0,4.0]))

(0.0, 1.0, 0.0)

In [7]:
# Should put all weight on third vertex
BarycentricWeightsNoJIT(ASimplex, np.array([1.5,3.0]))

(0.0, 0.0, 1.0)

In [8]:
# Should split the weight equally between the first two vertices
BarycentricWeightsNoJIT(ASimplex, np.array([2.0,3.5]))

(0.5, 0.5, 0.0)

In [9]:
# Should split the weight equally between the last two vertices
BarycentricWeightsNoJIT(ASimplex, np.array([4.5/2,3.5]))

(0.0, 0.5, 0.5)

In [10]:
# Should be outside the simplex, so at least one weight is negative.
BarycentricWeightsNoJIT(ASimplex, ASimplex[0]-1)

(0.0, -1.0, 2.0)
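The negative weight is what flags the point as outside. Here is a minimal sketch of a membership test built on that fact. The tolerance `tol` is a hypothetical parameter that absorbs rounding for points on an edge, and `barycentric_weights` and `in_simplex` are illustrative names; the weight formulas repeat those above so the cell runs on its own:

```python
import numpy as np

def barycentric_weights(ABC, x):
    # Same formulas as BarycentricWeightsNoJIT.
    (nA, mA), (nB, mB), (nC, mC) = np.asarray(ABC, dtype=float)
    n, m = x
    inv_denom = 1.0 / ((nB - nC) * (mA - mC) + (mC - mB) * (nA - nC))
    wA = ((nB - nC) * (m - mC) + (mC - mB) * (n - nC)) * inv_denom
    wB = ((nC - nA) * (m - mC) + (mA - mC) * (n - nC)) * inv_denom
    return wA, wB, 1.0 - wA - wB

def in_simplex(ABC, x, tol=1e-12):
    """x lies in the closed triangle exactly when all its barycentric weights are >= 0."""
    return min(barycentric_weights(ABC, x)) >= -tol

ASimplex = np.array([[1.0, 3.0], [3.0, 4.0], [1.5, 3.0]])
print(in_simplex(ASimplex, [2.0, 3.5]))        # True (midpoint of an edge)
print(in_simplex(ASimplex, ASimplex[0] - 1))   # False (one weight is -1)
```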

## JIT
So the function seems to work. It only uses indexing and scalar arithmetic on the NumPy arrays $ABC$ and $x$, which Numba supports in nopython mode, so it should be jittable. Let's try `@njit`.

In [11]:
@njit
def BarycentricWeights(ABC, x):
    # (nX,mX) are simply the two states in the "triangle"/R^2 case at vertex X
    nA, mA = ABC[0][0], ABC[0][1]
    nB, mB = ABC[1][0], ABC[1][1]
    nC, mC = ABC[2][0], ABC[2][1]

    # (n,m) is the common grid point in R^2
    n, m = x[0], x[1]
    
    inv_denom = 1/((nB - nC)*(mA - mC) + (mC - mB)*(nA - nC))  # same for every x in this simplex, so looping over many points per call would reuse it
    
    wA = ((nB - nC)*(m - mC) + (mC - mB)*(n - nC))*inv_denom
    wB = ((nC - nA)*(m - mC) + (mA - mC)*(n - nC))*inv_denom

    return wA, wB, 1-wA-wB

... and verify that it still works!

In [12]:
# Should put all weight on first vertex
bw1 = BarycentricWeights(ASimplex, np.array([1.0,3.0]))
# Should put all weight on second vertex
bw2 = BarycentricWeights(ASimplex, np.array([3.0,4.0]))
# Should put all weight on third vertex
bw3 = BarycentricWeights(ASimplex, np.array([1.5,3.0]))
# Should split the weight equally between the first two vertices
bw4 = BarycentricWeights(ASimplex, np.array([2.0,3.5]))
# Should split the weight equally between the last two vertices
bw5 = BarycentricWeights(ASimplex, np.array([4.5/2,3.5]))
# An internal point (the centroid) should have weights that sum to one
bw6 = BarycentricWeights(ASimplex, np.array(sum(ASimplex)/3))
print("1) Verify that all weight is on first vertex.")
print("    weights: ", bw1)
print("    status: ", np.array([bw1[0]==1.0, bw1[1] == 0.0, bw1[2] == 0.0]).all())
print("2) Verify that all weight is on second vertex.")
print("    weights: ", bw2)
print("    status: ", np.array([bw2[0]==0.0, bw2[1] == 1.0, bw2[2] == 0.0]).all())
print("3) Verify that all weight is on third vertex.")
print("    weights: ", bw3)
print("    status: ", np.array([bw3[0]==0.0, bw3[1] == 0.0, bw3[2] == 1.0]).all())
print("4) Verify that the weight is split equally between the first two vertices.")
print("    weights: ", bw4)
print("    status: ", np.array([bw4[0]==0.50, bw4[1] == 0.50, bw4[2] == 0.0]).all())
print("5) Verify that the weight is split equally between the last two vertices.")
print("    weights: ", bw5)
print("    status: ", np.array([bw5[0]==0.00, bw5[1] == 0.50, bw5[2] == 0.50]).all())
print("6) Verify that an internal point has weights that sum to one.")
print("    weights: ", bw6)
print("    status: ", np.array([bw6[0], bw6[1], bw6[2]]).sum())

1) Verify that all weight is on first vertex.
    weights:  (1.0, 0.0, 0.0)
    status:  True
2) Verify that all weight is on second vertex.
    weights:  (0.0, 1.0, 0.0)
    status:  True
3) Verify that all weight is on third vertex.
    weights:  (0.0, 0.0, 1.0)
    status:  True
4) Verify that the weight is split equally between the first two vertices.
    weights:  (0.5, 0.5, 0.0)
    status:  True
5) Verify that the weight is split equally between the last two vertices.
    weights:  (0.0, 0.5, 0.5)
    status:  True
6) Verify that an internal point has weights that sum to one.
    weights:  (0.3333333333333339, 0.3333333333333335, 0.3333333333333326)
    status:  1.0
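The exact `==` comparisons pass here because the test points give exactly representable weights (0, 0.5, 1), but for general points the weights carry rounding error. Here is a sketch of a tolerance-based check of two properties that any barycentric weights should satisfy: they sum to one, and the weighted sum of the vertices reproduces $x$. The random test points, the seed, and the helper name `barycentric_weights` are illustrative choices. The formulas repeat those of `BarycentricWeights` so the cell runs on its own (without Numba):

```python
import numpy as np

def barycentric_weights(ABC, x):
    # Same formulas as BarycentricWeights, returned as an array.
    (nA, mA), (nB, mB), (nC, mC) = np.asarray(ABC, dtype=float)
    n, m = x
    inv_denom = 1.0 / ((nB - nC) * (mA - mC) + (mC - mB) * (nA - nC))
    wA = ((nB - nC) * (m - mC) + (mC - mB) * (n - nC)) * inv_denom
    wB = ((nC - nA) * (m - mC) + (mA - mC) * (n - nC)) * inv_denom
    return np.array([wA, wB, 1.0 - wA - wB])

ASimplex = np.array([[1.0, 3.0], [3.0, 4.0], [1.5, 3.0]])
rng = np.random.default_rng(0)
for x in rng.uniform(-5.0, 5.0, size=(1000, 2)):
    w = barycentric_weights(ASimplex, x)
    # Compare up to floating-point error, hence isclose/allclose instead of ==.
    assert np.isclose(w.sum(), 1.0)
    assert np.allclose(w @ ASimplex, x)   # x = wA*A + wB*B + wC*C
print("all checks passed")
```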


## Time to run
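So it seems to work! But are the two versions equally fast? Let's find out. The point of jitting is that Numba compiles the function to machine code, specialized to the argument types it is called with, instead of having Python interpret it. Since `BarycentricWeights` was already called above with the same argument types, its compilation time is not part of the timings below. Very unscientifically, we use `timeit.timeit` to measure the total time of 1,000,000 calls (the default value of its `number` argument).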

In [13]:
from timeit import timeit

In [14]:
def timebw():
    return BarycentricWeights(ASimplex, np.array([1.0,2.0]))
def timebwnj():
    return BarycentricWeightsNoJIT(ASimplex, np.array([1.0,2.0]))

In [17]:
timeit(timebw)

1.4942797016078266

In [18]:
timeit(timebwnj)

4.862568766745028

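As we can see, this is no order-of-magnitude speedup (or magic), but jitting made the calls roughly three times faster (about 1.5 s versus 4.9 s for 1,000,000 calls). Both timings also include building a fresh `np.array` for $x$ on every call, and the jitted version pays Numba's dispatch overhead on each call, so for a function this small that overhead is likely a large share of what is being measured.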