In [0]:
import valjax as vj

In [0]:
from functools import partial

import jax
import jax.numpy as np
from jax import lax

In [0]:
import matplotlib.pyplot as plt
import matplotlib as mpl
# style sheet path is relative to the notebook's working directory
mpl.style.use('config/clean.mplstyle')
%matplotlib inline
# render inline figures at high DPI
%config InlineBackend.figure_format = 'retina'

In [0]:
# grid sizes
K = 100  # number of z values
N = 200  # number of grid points

# proportional grid range: multipliers applied to the steady-state capital
f_lo = 0.5
f_hi = 2.0

In [0]:
# baseline parameters
par0 = dict([
    ('β', 0.95),                      # discount factor
    ('δ', 0.1),                       # depreciation rate
    ('α', 0.35),                      # capital exponent in z*k**α
    ('z', np.linspace(0.9, 1.1, K)),  # K evenly spaced productivity levels
])

In [0]:
# find steady state
def get_kss(par):
    """Steady-state capital for each productivity level in par['z'].

    Solves α*z*k**(α-1) = (1-β)/β + δ for k, elementwise over z.
    """
    α = par['α']
    rate = (1 - par['β'])/par['β'] + par['δ']
    return (α*par['z']/rate)**(1/(1 - α))
# steady-state capital, one entry per z in par0
k_ss = get_kss(par=par0)

In [0]:
# construct capital grid spanning f_lo*k_ss[0] .. f_hi*k_ss[-1] with N points
k_min = f_lo*k_ss[0]
k_max = f_hi*k_ss[-1]
k_grid = np.linspace(k_min, k_max, num=N)

In [0]:
def util(c, eps=1e-6):
    """Log utility with consumption floored at `eps`.

    Values below `eps` (including zero or negative consumption) are clipped
    to `eps`, so they map to the finite value log(eps) instead of -inf/nan.
    """
    return np.log(np.maximum(eps, c))

In [0]:
def prod(z, k, α):
    """Output z * k**α; broadcasts when z and k are arrays."""
    output = z * k**α
    return output

In [0]:
def value(par, grid, st, tv):
    """One backward Bellman step; used as the `lax.scan` body in `solve`.

    Parameters
    ----------
    par : dict
        Model parameters; only the discount factor 'β' is read here.
    grid : dict
        'cp' : consumption for every (z, k, k') combination; the last axis is
            assumed to be the next-period capital choice k'.
        'k' (optional) : capital grid used to map the chosen index back to a
            capital level. Falls back to the module-level `k_grid` when absent.
    st : dict
        Scan carry; 'vn' is the next-period value on the (z, k) grid.
    tv : dict
        Per-step scan input; its contents are not needed by this step.

    Returns
    -------
    stp : dict
        New carry, {'vn': v}.
    out : dict
        Per-step output stacked by the scan: value 'v' and capital policy 'kp'.
    """
    β = par['β']
    cp = grid['cp']
    vn = st['vn']
    # Take the grid from `grid` so kp is consistent with the grid used to build
    # 'cp'; the global is only a fallback for callers that don't pass 'k'.
    k_vals = grid['k'] if 'k' in grid else k_grid

    # insert the current-k axis so vn is indexed by the choice k' (last axis)
    vp = util(cp) + β*vn[:, None, :]

    # NOTE(review): vj.smoothmax presumably returns a differentiable
    # (fractional) argmax index along the last axis — confirm against valjax
    ip = vj.smoothmax(vp, -1)

    # presumably interpolate the capital level and the attained value at ip
    kp = vj.interp(k_vals, ip)
    v = vj.interp_address(vp, ip, -1)

    stp = {'vn': v}
    out = {'v': v, 'kp': kp}
    return stp, out

In [0]:
def solve(par, T):
    """Run T backward Bellman iterations on the module-level capital grid `k_grid`.

    Parameters
    ----------
    par : dict
        Model parameters; 'α', 'δ' and 'z' are read here ('β' is read by `value`).
    T : int
        Number of iterations. It sets the scan length, so it must be a concrete
        (static) Python int under `jax.jit`.

    Returns
    -------
    dict
        Per-step outputs of `value` ('v' and 'kp'), stacked by `lax.scan` along a
        leading axis of length T.
    """
    α = par['α']
    δ = par['δ']
    z = par['z']

    # output and output plus undepreciated capital, shape (len(z), len(k_grid))
    y_grid = prod(z[:, None], k_grid[None, :], α)
    yd_grid = y_grid + (1-δ)*k_grid[None, :]
    # consumption for every (z, k, k') triple; last axis is the choice k'
    cp_grid = yd_grid[:, :, None] - k_grid[None, None, :]

    grid = {
        'cp': cp_grid,
        'k': k_grid,
    }
    # functools.partial: jax.partial was only a re-export and is gone from JAX
    step = partial(value, par, grid)

    # initial carry: utility of consuming all output, with no continuation value
    st0 = {
        'vn': util(y_grid),
    }
    # only the length of 't' matters to the scan (it counts down T-1 .. 0)
    tv = {
        't': np.arange(T)[::-1],
    }
    _, path = lax.scan(step, st0, tv)

    return path

In [0]:
# T (argument 1) sets the scan length inside solve, so it must be static under jit
jsolve = jax.jit(solve, static_argnums=1)

In [0]:
# solve over T = 30 backward iterations (the first call also compiles)
ret = jsolve(
    par0,
    30,
)

In [0]:
%timeit -r 10 -n 10 jsolve(par0, 30)

In [0]:
# last 10 iterations of the value function at the lowest productivity level (z index 0)
fig, ax = plt.subplots()
ax.plot(k_grid, ret['v'][-10:, 0, :].T)
ax.set(xlabel='$k$', ylabel='$v(z_0, k)$', title='Value function, last 10 iterations, lowest $z$');

In [0]:
fig, ax = plt.subplots()
z_mid = K // 2
ax.plot(k_grid, ret['kp'][-1, z_mid, :] - k_grid)
# axhline always spans the full axes width; hlines drawn between the x-limits
# read before the line was added left gaps once autoscale margins expanded
ax.axhline(0, linewidth=1, linestyle='--', color='k')
# steady state for z_mid, where k' - k is expected to cross zero
ax.scatter(k_ss[z_mid], 0, color='k', zorder=10)
ax.set(xlabel='$k$', ylabel="$k' - k$", title='Capital change at middle $z$, final iteration');

In [0]:
def moments(par, T):
    """Mean of k' - k over every iteration, productivity level and grid point.

    The mean runs over all T stored iterations of `solve`, not only the final
    one. `T` must be a concrete int because it sets the scan length.
    """
    path = solve(par, T)
    kp = path['kp']
    # k_grid broadcasts along the trailing (k) axis of kp
    invest = kp - k_grid[None, :]
    return np.mean(invest)

In [0]:
# gradient of moments w.r.t. par, returned as a 1-tuple because argnums=(0,);
# T (argument 1) stays static under jit
dmoments = jax.grad(moments, argnums=(0,))
gmoment = jax.jit(dmoments, static_argnums=(1,))

In [0]:
# gmoment returns a 1-tuple (argnums=(0,)); unpack the gradient w.r.t. par
(grad,) = gmoment(par0, 30)

In [0]:
%timeit -r 10 -n 10 grad, = gmoment(par0, 30);

In [0]:
# gradient of the moment with respect to each entry of par['z']
grad_z = grad['z']
grad_z