In [0]:
from functools import partial

import jax
import jax.numpy as npx
import numpy as np0
from jax import lax

import valjax as vj

In [0]:
import matplotlib.pyplot as plt
import matplotlib as mpl
# project plot style (relative path: resolved against the notebook's working directory)
mpl.style.use('config/clean.mplstyle')
# render figures inline, at high resolution
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [0]:
import warnings
# NOTE(review): this silences *all* warnings, including JAX deprecation warnings
# (e.g. for removed APIs). Consider narrowing to the specific category/module
# that is actually noisy.
warnings.filterwarnings('ignore')

In [0]:
# fixed params: number of backward iterations (time periods) and capital-grid size
T, N = 300, 100

In [0]:
# capital grid: N evenly spaced points on [k_min, k_max] (JAX array)
k_min = 0.2
k_max = 10.0
k_grid = npx.linspace(k_min, k_max, num=N)

In [0]:
# baseline parameters:
#   β -- discount factor on the continuation value (see `value`)
#   δ -- depreciation rate in (1-δ)*k (see `solve`)
#   α -- exponent on capital in prod = z*k**α
#   z -- productivity multiplier in prod
par0 = {'β': 0.95, 'δ': 0.1, 'α': 0.35, 'z': 1.0}

In [0]:
# defined functions
def util(c, eps=1e-6, np=npx):
    """Log utility; consumption is floored at ``eps`` so the log stays finite.

    ``np`` selects the array backend (``jax.numpy`` by default).
    """
    return np.log(np.maximum(eps, c))


def prod(k, z, α):
    """Production function ``z * k**α``."""
    output = z * k**α
    return output

In [0]:
def value(par, grid, st, tv, np=npx):
    """One backward Bellman step on the capital grid.

    Parameters
    ----------
    par : dict
        Model parameters; only ``par['β']`` (discount factor) is read here.
    grid : dict
        Precomputed grids. ``grid['cp']`` is the consumption matrix; as built
        by ``solve``, rows index current capital and columns index next-period
        capital. ``grid['k']`` (optional) is the capital grid used to map the
        chosen column index to next-period capital. When absent, the
        module-level ``k_grid`` is used.
    st : dict
        Carried state; ``st['vn']`` is the next-period value on the grid.
    tv : any
        Per-step input from ``lax.scan``. It is unused, but required by the scan
        calling convention ``f(carry, x) -> (carry, y)``.
    np : module
        Array backend (``jax.numpy`` by default, or plain ``numpy``).

    Returns
    -------
    stp : dict
        New carry ``{'vn': v}``.
    out : dict
        ``'v'`` is the updated value, ``'kp'`` the chosen next-period capital,
        and ``'err'`` the sup-norm change ``max|v - vn|``.
    """
    β = par['β']
    cp = grid['cp']
    vn = st['vn']
    # Capital grid matching the columns of cp. Falling back to the global keeps
    # callers that only pass 'cp' working. A caller-supplied grid lets the
    # NumPy path stay pure NumPy.
    kg = grid['k'] if 'k' in grid else k_grid

    # utility today plus discounted continuation, for every (k, k') pair
    vp = util(cp, np=np) + β*vn[None,:]
    ip = np.argmax(vp, axis=1)

    # policy and value at the optimal choice
    kp = kg[ip]
    # pick vp at index ip along the last axis (presumably vp[i, ip[i]] -- see vj.address)
    v = vj.address(vp, (ip,), axis=(-1,), np=np)

    # sup-norm change between iterations (convergence diagnostic)
    err = np.max(np.abs(v-vn))

    # return state and output
    stp = {
        'vn': v,
    }
    out = {
        'v': v,
        'kp': kp,
        'err': err,
    }

    return stp, out

In [0]:
def solve(par, T):
    """Run T backward value-function iterations with ``lax.scan``.

    Parameters
    ----------
    par : dict
        Model parameters. 'α', 'δ' and 'z' are read here; 'β' is read inside ``value``.
    T : int
        Number of iterations. It must be static under ``jax.jit`` because it
        sets the length of the scanned input (``npx.arange(T)``).

    Returns
    -------
    dict
        The per-iteration ``out`` dicts of ``value`` ('v', 'kp', 'err'),
        stacked by ``lax.scan`` along a leading axis of length T.
    """
    α = par['α']
    δ = par['δ']
    z = par['z']

    # precompute grids: output, output plus undepreciated capital, and
    # consumption for every (current k, next k') pair
    y_grid = prod(k_grid, z, α)
    yd_grid = y_grid + (1-δ)*k_grid
    cp_grid = yd_grid[:,None] - k_grid[None,:]

    # Bind the fixed arguments so the step has scan's f(carry, x) signature.
    # jax.partial was removed from JAX; functools.partial is the replacement.
    grid = {
        'cp': cp_grid,
        'k': k_grid,
    }
    value1 = partial(value, par, grid)

    # initial guess: utility of consuming current output
    st0 = {
        'vn': util(y_grid),
    }
    # The scanned input is unused by `value`; it only fixes the number of steps.
    # (`np` is not defined in this notebook -- use the JAX backend explicitly.)
    tv = {
        't': npx.arange(T)[::-1],
    }
    _, path = lax.scan(value1, st0, tv)

    return path

In [0]:
# T (argument 1) is static: solve uses it to build npx.arange(T), which needs a
# concrete Python int at trace time
jsolve = jax.jit(solve, static_argnums=(1,))

### Using JAX

In [0]:
# first call traces and compiles jsolve for this T, and doubles as a warm-up
# before the timing cell
ret = jsolve(par0, T)

In [0]:
%timeit -r 10 -n 10 jsolve(par0, T)

In [0]:
# convergence of the JAX solver: sup-norm change of the value function per iteration
fig, ax = plt.subplots()
ax.plot(ret['err'])
ax.set(yscale='log', xlabel='iteration', ylabel='max |v - vn|', title='Convergence (JAX)');

### Pure NumPy

In [0]:
# capital grid for the NumPy reference: same specification as k_grid, but a
# plain NumPy array (NumPy's default float64 dtype)
k_grid0 = np0.linspace(k_min, k_max, num=N)

In [0]:
def solve_numpy(par):
    """Pure-NumPy reference version of ``solve``, using an explicit Python loop.

    Runs ``T`` (module-level) backward iterations of ``value`` with the NumPy
    backend (``np=np0``).

    Parameters
    ----------
    par : dict
        Model parameters. 'α', 'δ' and 'z' are read here; 'β' is read inside ``value``.

    Returns
    -------
    dict
        'v' is an array of shape (T, N) holding the value function after each
        iteration. 'err' is an array of shape (T,) holding the sup-norm change
        per iteration.
    """
    z, α, δ = par['z'], par['α'], par['δ']

    # precompute grids (NumPy arrays throughout)
    y_grid0 = prod(k_grid0, z, α)
    yd_grid0 = y_grid0 + (1-δ)*k_grid0
    cp_grid0 = yd_grid0[:,None] - k_grid0[None,:]

    # Store history. `np` is not defined in this notebook, so use the NumPy alias.
    v_path = np0.zeros((T, N))
    err = np0.zeros(T)

    # Pass the NumPy capital grid so `value` does not index the JAX k_grid.
    grid0 = {'cp': cp_grid0, 'k': k_grid0}
    st0 = {'vn': util(y_grid0, np=np0)}
    tv0 = {}
    for t in range(T):
        # Use the caller's parameters (not the global par0) so β is consistent
        # with the α, δ and z used above. `value` already reports the error.
        st0, out = value(par, grid0, st0, tv0, np=np0)
        err[t] = out['err']
        v_path[t, :] = out['v']

    return {'v': v_path, 'err': err}

In [0]:
# wall-clock time of a single run of the NumPy reference solver
# NOTE(review): this rebinds `ret`, replacing the JAX result from above
%time ret = solve_numpy(par0)

In [0]:
# convergence of the NumPy solver: sup-norm change of the value function per iteration
fig, ax = plt.subplots()
ax.plot(ret['err'])
ax.set(yscale='log', xlabel='iteration', ylabel='max |v - vn|', title='Convergence (NumPy)');