In [1]:
%pylab inline
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
from jax.scipy.ndimage import map_coordinates
from constant import * 
import warnings
from jax import jit, partial
warnings.filterwarnings("ignore")
# np.printoptions(...) returns a context manager and only takes effect inside a
# `with` block; called bare it did nothing except echo a _GeneratorContextManager.
# set_printoptions applies the setting globally for the rest of the notebook.
np.set_printoptions(precision=2)

Populating the interactive namespace from numpy and matplotlib
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


<contextlib._GeneratorContextManager at 0x7f88c4200690>

In [2]:
# Load the four action grids (consumption c, bond b, stock k, housing h),
# one .npy file per grid, in the order they are unpacked here.
cgrid, bgrid, kgrid, hgrid = (
    np.load(f"{grid_name}.npy") for grid_name in ("cgrid", "bgrid", "kgrid", "hgrid")
)

In [3]:
# Policy lookup: linearly interpolate the saved action grids at given states.

def actions(w, n, s, t):
    """Interpolate the actions (c, b, k, h) at states (w, n, s) for period t.

    The first three axes of cgrid/bgrid/kgrid/hgrid are interpolated. The
    fourth axis is selected by ``t``. Wealth ``w`` and the 401k balance ``n``
    are divided by ``scale`` to obtain fractional grid coordinates
    (presumably ``scale`` is the grid spacing -- TODO confirm). The state
    index ``s`` is used as a coordinate directly.

    Interpolation is linear (``order=1``). Points outside the grid take the
    value at the nearest edge (``mode='nearest'``).

    Parameters
    ----------
    w, n : array_like
        Wealth and 401k balance. Presumably 1-D and of equal length (one
        entry per simulated agent).
    s : array_like
        Economic-state indices, same length as ``w``.
    t : int
        Period index into the last axis of the grids.

    Returns
    -------
    tuple
        ``(c, b, k, h)``, one interpolated array per grid.
    """
    # The sample coordinates are the same for all four grids, so build them once.
    coords = np.vstack((w / scale, n / scale, s))
    c, b, k, h = (
        map_coordinates(grid[:, :, :, t], coords, order=1, mode='nearest')
        for grid in (cgrid, bgrid, kgrid, hgrid)
    )
    return c, b, k, h
      
@jit
def transition(b, k, s, s_next):
    """Return next-period wealth given bond holdings b and stock holdings k.

    Bonds earn ``r_b`` looked up at the current state ``s``. Stocks earn
    ``r_k`` looked up at the next state ``s_next``.
    """
    bond_value = b * (1 + r_b[s])
    stock_value = k * (1 + r_k[s_next])
    return bond_value + stock_value

# Earnings function: state-dependent while t <= T_R, flat afterwards.
@partial(jit, static_argnums=(0,))
def y(t, s):
    """Return earnings at period t for each economic state in s.

    ``t`` is a static jit argument (one compilation per distinct t), so the
    Python ``if`` below selects the branch at trace time.

    - For ``t <= T_R``: ``detEarning[t] * (1 + gGDP[s])``.
    - For ``t > T_R``: ``detEarning[t]`` for every entry, independent of the
      state values.

    The result has the same shape as ``s`` in both branches; a scalar ``s``
    is accepted.
    """
    if t <= T_R:
        return detEarning[t] * (1 + gGDP[jnp.array(s, dtype=jnp.int16)])
    # Use the full shape of s rather than len(s). len() fails on a scalar and
    # would drop trailing dimensions, unlike the gGDP[s] branch above.
    return detEarning[t] * jnp.ones(jnp.shape(s))
    
# Evolution of the 401k balance from period t to period t+1.
@partial(jit, static_argnums=(0,))
def gn(t, s, s_next, n):
    """Return the next-period 401k balance.

    ``t`` is a static jit argument, so the branch below is chosen at trace
    time.

    - While ``t <= T_R``: the contribution ``y(t, s) * yi`` is added to ``n``.
    - After ``T_R``: the amount ``n / Dt[t]`` is withdrawn from ``n``.

    The adjusted balance then grows at the *average* of the bond return
    ``r_b[s]`` and the stock return ``r_k[s_next]``, not at the stock rate
    alone.
    """
    growth = 1 + (r_b[s] + r_k[s_next]) / 2
    if t > T_R:
        return growth * (n - n / Dt[t])
    return growth * (n + y(t, s) * yi)

In [4]:
import quantecon as qe

# Number of simulated economies (independent state paths).
num = 500000
# Fixed seed so the simulated state paths, and everything derived from them,
# are reproducible under Restart & Run All.
SEED = 0
# Markov chain over economic states, built from the transition matrix Ps.
mc = qe.MarkovChain(Ps)
# Simulate all paths in one call instead of a 500k-iteration Python loop.
# With a scalar `init` and an int `num_reps`, the result has shape
# (num, T_max - T_min), and every path starts in state 0.
econState = mc.simulate(ts_length=T_max - T_min, init=0, num_reps=num, random_state=SEED)

In [None]:
# Forward-simulate every economy from the same start: wealth 5 (the original
# note says "5k") and an empty 401k account.
n_periods = T_max - T_min
w = np.ones(num) * 5
n = np.zeros(num)
# Per-period histories, shape (n_periods, num). The loop stops one period
# early, so the last row of each history stays zero.
ws, ns, cs, bs, ks, hs = (np.zeros((n_periods, num)) for _ in range(6))
for t in range(n_periods - 1):
    s, s_next = econState[:, t], econState[:, t + 1]
    # Record the state at the start of period t...
    ws[t, :], ns[t, :] = w, n
    # ...and the actions interpolated at that state.
    c, b, k, h = actions(w, n, s, t)
    cs[t, :], bs[t, :], ks[t, :], hs[t, :] = c, b, k, h
    # Advance wealth and the 401k balance to period t+1.
    w = transition(b, k, s, s_next)
    n = gn(t, s, s_next, n)

In [None]:
# Mean across simulated economies of each recorded variable, by period.
# The final period is dropped ([:-1]) because the simulation loop never fills it.
profiles = {
    "wealth": ws,
    "consumption": cs,
    "bond": bs,
    "stock": ks,
    # NOTE(review): presumably pr is the housing price, so hs*pr is housing value -- confirm.
    "housing": hs * pr,
}
fig, ax = plt.subplots(figsize=(16, 8))
for label, history in profiles.items():
    ax.plot(history.mean(axis=1)[:-1], label=label)
ax.set(xlabel="period t", ylabel="mean across economies", title="Simulated average profiles")
ax.legend();

In [None]:
# Mean 401k balance across simulated economies, by period. The final period
# is dropped because the simulation loop never fills it.
fig, ax = plt.subplots()
ax.plot(ns.mean(axis=1)[:-1], label="401k")
ax.set(xlabel="period t", ylabel="mean 401k balance", title="Simulated average 401k balance")
# The label above is only shown once a legend is drawn.
ax.legend();

In [None]:
# Display the 401k contribution fraction yi. gn adds y(t, s) * yi to the
# account while t <= T_R.
yi