In [1]:
%pylab inline
from solveMDP_richHigh import * 
# Array loaded from disk; `simulation` slices it on its last axis and hands the slice to V_solve.
Vgrid = np.load("richHigh.npy")
# Notebook-wide plotting defaults (plt.rcParams is the same object as matplotlib.rcParams).
plt.rcParams.update({'figure.figsize': [16, 8], 'font.size': 15})

Populating the interactive namespace from numpy and matplotlib




Model Solved! 


In [3]:
import pandas as pd
df_1999 = pd.read_csv("df_1999_30to60.csv")
# Keep only high-skill households without finance experience.
# .copy() makes `df` an independent frame, so the column assignments below
# do not write into a slice of df_1999 (avoids SettingWithCopyWarning).
df = df_1999[(df_1999["skillLevel"] == "High")&(df_1999["financeExperience"] == "No")].copy()
# every agent starts with ab = 30
df["ab"] = 30
df["wealth"] = df["liquidWealth"] + df["investmentAmount"]
# encode categorical columns as 0/1; labels missing from a mapping become NaN
codes = {'employed':1, 'unemployed': 0, "retired": 0}
df["employmentStatus"] = df["employmentStatus"].map(codes)
codes = {'owner':1, 'renter': 0}
df["ownership"] = df["ownership"].map(codes)
# column order: ageHead first, then the state vector [w,ab,s,e,o,z] that simulation() slices off with [1:]
initialStates = df[["ageHead","wealth","ab","year","employmentStatus","ownership","participation"]].copy()
# the "year" column is replaced by the first imagined economic state
initialStates["year"] = imaginedEconState[0]
initialStates = jnp.array(initialStates.values)

# risk free interest rate series (column 2 of econRate); transition_real indexes it by simulation year t
bondReturn = jnp.array(econRate[:,2])
# stock return series (column 1 of econRate); transition_real indexes it by simulation year t
stockReturn = jnp.array(econRate[:,1])

@partial(jit, static_argnums=(0,1))
def transition_real(t,age,a,x):
    '''
        One-period transition that uses the per-year series bondReturn[t],
        stockReturn[t] and econ[t+1] instead of a modelled economic state.

        Input:
            t   = simulation year index (static under jit)
            age = age passed by the caller (static under jit); t+age-20 is
                  compared against T_R
            x = [w,ab,s,e,o,z] single state
            x = [0,1, 2,3,4,5]
            a = [c,b,k,h,action] single action
            a = [0,1,2,3,4]
        Output:
            array with one row per next-period outcome and the columns
            [w_next, ab_next, s_next, e_next, o_next, z_next, prob_next]
            (two employment outcomes are built below, so nE is expected to be 2)
    '''
    econ_state = jnp.array(x[2], dtype = jnp.int8)
    employed = jnp.array(x[3], dtype = jnp.int8)
    owner = x[4]
    # actions taken
    bond = a[1]
    stock = a[2]
    house_action = a[4]

    # probability of the employment status flipping, given (s, e)
    flip_prob = Pe[econ_state, employed]
    prob_next = jnp.array([1-flip_prob, flip_prob])

    # wealth carried forward at this year's bond and stock returns
    w_next = ((1+bondReturn[t])*bond + (1+stockReturn[t])*stock).repeat(nE)
    # renter: ab becomes t when action == 1 (0 otherwise); owner: ab is unchanged
    ab_if_renter = (t*(house_action == 1)).repeat(nE)
    ab_if_owner = x[1]*jnp.ones(nE)
    ab_next = (1-owner)*ab_if_renter + owner*ab_if_owner
    # next economic state is read from the econ series
    s_next = econ[t+1].repeat(nE)
    # before T_R: stay in e or switch to 1-e; from T_R on: both outcomes are 0
    e_next = jnp.array([employed,(1-employed)])*(t+age-20<T_R) + jnp.array([0,0])*(t+age-20>=T_R)
    # z stays 1 once set, and switches to 1 when k > 0
    z_next = x[5]*jnp.ones(nE) + ((1-x[5]) * (stock > 0)).repeat(nE)
    # owner: o - action; renter: action
    o_if_owner = (owner - house_action).repeat(nE)
    o_if_renter = house_action.repeat(nE)
    o_next = owner * o_if_owner + (1-owner) * o_if_renter
    return jnp.column_stack((w_next,ab_next,s_next,e_next,o_next,z_next,prob_next))

'''
    # [w,ab,s,e,o,z]
    # w explicitly 
    # assume ab = 30 the strong assumption we made 
    # s is known 
    # e is known 
    # o is known
    # z is known
'''
from jax import random

def simulation(key, period = yearCount):
    '''
        Simulate one agent for `period` years.

        Input:
            key    = JAX PRNG key; it selects the agent's row in initialStates
                     and drives the random employment draws
            period = number of simulated years
        Output:
            path = array of visited states [w,ab,s,e,o,z], period+1 rows
            move = array of chosen actions [c,b,k,h,action], period rows
    '''
    # Convert to a Python int so the row index is always an integer
    # (JAX raises a TypeError for float indexers).
    row = initialStates[int(key.sum()) % initialStates.shape[0]]
    age = int(row[0])
    x = row[1:]
    path = []
    move = []
    for t in range(0, period):
        key, subkey = random.split(key)
        # model period that corresponds to simulation year t for this agent
        model_t = t + age - 20
        _,a = V_solve(model_t,Vgrid[:,:,:,:,:,:,model_t],x)
        xp = transition_real(t,age,a,x)
        # last column holds the outcome probabilities, the rest the candidate next states
        p = xp[:,-1]
        x_next = xp[:,:-1]
        path.append(x)
        move.append(a)
        x = x_next[random.choice(a = nE, p=p, key = subkey)]
    # record the state reached after the final year as well
    path.append(x)
    return jnp.array(path), jnp.array(move)

# total number of agents
num = initialStates.shape[0] * 10
# simulation part: one PRNG key per agent
keys = vmap(random.PRNGKey)(jnp.arange(num))
KEY = [key for key in keys]
from multiprocessing import Pool
# The with-block shuts the worker processes down once map() has returned,
# and the pool handle no longer shares a name with the per-agent results below.
with Pool(processes=48) as pool:
    PathsMoves = list(pool.map(simulation,KEY))

# split the (path, move) pairs into two parallel lists
path = [agent_path for agent_path, _ in PathsMoves]
move = [agent_move for _, agent_move in PathsMoves]
Paths = jnp.array(path)
Moves = jnp.array(move)

# Paths is indexed [agent, time, state component]; each series below is transposed to [time, agent]
# x = [w,ab,s,e,o,z]
# x = [0,1, 2,3,4,5]
# NOTE: `os` here is the ownership series and shadows the standard-library module name
ws, ab, ss, es, os, zs = (Paths[:,:,col].T for col in range(6))
# Moves is indexed [agent, time, action component]; columns 0..3 are c, b, k, h
cs, bs, ks, hs = (Moves[:,:,col].T for col in range(4))
# Look up Ms at (year - ab) for every time/agent pair and zero it where os is 0;
# the year column is [0, 0, 1, ..., yearCount-1]
ms = Ms[
    jnp.append(jnp.array([0]),jnp.arange(yearCount)).reshape(-1,1)
    - jnp.array(ab, dtype = jnp.int8)
]*os

TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value 1764.0

In [33]:
# Spot check: simulate three agents serially and display the first agent's state path
[simulation(k) for k in (keys[0], keys[1], keys[3])][0][0]

DeviceArray([[ 14.09   ,  30.     ,   3.     ,   1.     ,   1.     ,
                0.     ],
             [ 89.89631,  30.     ,   1.     ,   1.     ,   1.     ,
                0.     ],
             [171.87746,  30.     ,   2.     ,   1.     ,   1.     ,
                0.     ],
             [253.1452 ,  30.     ,   0.     ,   1.     ,   1.     ,
                0.     ],
             [332.722  ,  30.     ,   6.     ,   1.     ,   1.     ,
                0.     ],
             [411.01544,  30.     ,   3.     ,   1.     ,   1.     ,
                0.     ],
             [493.3709 ,  30.     ,   3.     ,   1.     ,   1.     ,
                0.     ],
             [587.296  ,  30.     ,   4.     ,   1.     ,   1.     ,
                0.     ],
             [693.8182 ,  30.     ,   3.     ,   1.     ,   1.     ,
                0.     ],
             [802.1082 ,  30.     ,   0.     ,   1.     ,   1.     ,
                0.     ],
             [891.853  ,  30.     ,   4.     ,   0

In [34]:
# Spot check: simulate three agents serially and display the first agent's actions
[simulation(k) for k in (keys[0], keys[1], keys[3])][0][1]

DeviceArray([[8.5635990e-02, 8.5550354e+01, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [1.6214260e-01, 1.6198045e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [2.4485324e-01, 2.4460838e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [3.2652456e-01, 3.2619803e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [4.0638766e-01, 4.0598126e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [4.8470387e-01, 4.8421915e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [5.6734598e-01, 5.6677863e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [6.6181886e-01, 6.6115704e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [7.6811552e-01, 7.6734735e+02, 0.0000000e+00, 1.3000000e+03,
              0.0000000e+00],
             [8.7670213e-01, 8.7582538e+02, 0.0000000e+00, 1.300

In [3]:
# Persist the wealth series (np.save appends ".npy" to the file name)
np.save(file="w2", arr=ws)

In [4]:
# The state series (ws, ab, ss, es, os, zs, ms) have one more row than the action
# series (cs, bs, ks, hs), because simulation() records the state reached after the last year.
# np.array([...]) over such ragged inputs is an error on current NumPy, so the
# eleven series are stored in an explicit object array instead.
# Loading the file requires np.load(..., allow_pickle=True).
series = [ws,ab,ss,es,os,zs,cs,bs,ks,hs,ms]
stacked = np.empty(len(series), dtype=object)
for i, arr in enumerate(series):
    stacked[i] = np.asarray(arr)
np.save("waseozcbkhm2", stacked)