In [1]:
# TODO: 
# move X, E, Z, etc into one vector called state

# Calculate partical filter --- whatever that is
# Estimated partical filter with another NN
# Use estimated partical filter to generate posterior for parameters given data.

# Collect data
# Decide on model
# Estimate parameters and compare to data for various numbers of agents
# ???
# Profit

In [2]:
import sys
sys.path.append('/home/emmet/Documents/code/hetero_simulation/lib')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import jax
import jax.numpy as jnp
from jax.example_libraries.optimizers import adam, unpack_optimizer_state, pack_optimizer_state
from hetero_simulation.archive.agent import log_utility
from hetero_simulation.ml.utils import *



In [12]:
# Parameters of wealth distribution
A = 0.5
B = 0.2

# Preference and production parameters
alpha = 0.36 # (0, 1)
beta = 0.96 # (0, 1)
delta = 0.025 # (0, 1)
rho_z = 0.95 # (0, 1)
rho_e = 0.9 # (0, 1)
sigma_z = 0.01 # (0, inf)
sigma_e = 0.2 * jnp.sqrt(1 - rho_e**2) # (0, inf)

STRUCT_PARAM = {
    'alpha': alpha, 'beta': beta, 'delta': delta, 'a': A, 'b': B,
    'sigma_z': sigma_z, 'sigma_e': sigma_e, 'rho_z': rho_z, 'rho_e': rho_e
}

STRUCT_PARAM_IDXS = {k: v for k, v in zip(STRUCT_PARAM.keys(), np.arange(len(STRUCT_PARAM.keys())))}
STRUCT_PARAM_ARR = jnp.asarray(np.fromiter(STRUCT_PARAM.values(), dtype=jnp.float32))

AGG_IDXS = {'Xs': jnp.array(list(range(5))), 'Es': jnp.array(list(range(5, 10))), 'Zs': jnp.array(list(range(10, 11)))}
IDO_IDXS = {'xs': jnp.array(list(range(1))), 'es': jnp.array(list(range(1, 2)))}

# Hyper-parameters
N = 2 ** 10
MB = 2 ** 8
N_EPOCH = 1000
N_ITER = 10 * (N // MB)
N_FORWARD = 50

K = 5 # number of agents (~ size of state space)
M = 4
nn_shapes = jnp.array([M, M, M, M])

In [41]:
@jax.jit
def neural_network(model_params, struct_params, agg_state, ido_state):
    x = ido_state[IDO_IDXS['xs']]
    X_tilde = jnp.concatenate([struct_params.reshape(1, -1),
                               agg_state.reshape(1, -1),
                               ido_state.reshape(1, -1)], 
                              axis=1)
    X_tilde = X_tilde @ model_params['theta']
    # l1 = tanh(X_tilde, model_params['w0'], model_params['b0'])
    l2 = tanh(X_tilde, model_params['w1'], model_params['b1'])
    l3 = tanh(l2, model_params['w2'], model_params['b2'])
    # l4 = tanh(jnp.concatenate((l3, e[..., jnp.newaxis], x[..., jnp.newaxis])), model_params['w3'], model_params['b3'])
    return jnp.array([jnp.squeeze(x * sigmoid(l3, model_params['cwf'], model_params['cbf'])),
                      jnp.squeeze(exp(l3, model_params['lwf'], model_params['lbf']))])


@jax.jit
def fischer_burmeister(a, b):
    return a + b - jnp.sqrt(jnp.power(a, 2) + jnp.power(b, 2))


@jax.jit
def next_X(model_params, struct_params, agg_state):
    R, W = prices(struct_params, agg_state)
    w = jax.vmap(lambda x, e: (R * x) + (W * jnp.exp(e)))(agg_state[AGG_IDXS['Xs']], agg_state[AGG_IDXS['Es']])
    ido_state = jnp.concatenate((agg_state[AGG_IDXS['Es']], jnp.squeeze(w))).reshape(5, 2, -1)
    out = jax.vmap(neural_network, in_axes=(None, None, None, 0))(model_params, struct_params, agg_state, ido_state)
    c = out[..., 0]
    return w - c


@jax.jit
def prices(struct_params, agg_state):
    sumk = jnp.sum(agg_state[AGG_IDXS['Xs']])
    sumexpl = jnp.sum(jnp.exp(agg_state[AGG_IDXS['Es']]))
    w = (1 - struct_params[STRUCT_PARAM_IDXS['alpha']]) * jnp.exp(agg_state[AGG_IDXS['Zs']]) * jnp.power(sumk, struct_params[STRUCT_PARAM_IDXS['alpha']]) * jnp.power(sumexpl, -1 * struct_params[STRUCT_PARAM_IDXS['alpha']])
    r = 1 - struct_params[STRUCT_PARAM_IDXS['delta']] + struct_params[STRUCT_PARAM_IDXS['alpha']] * jnp.exp(agg_state[AGG_IDXS['Zs']]) * jnp.power(sumk, struct_params[STRUCT_PARAM_IDXS['alpha']] - 1) * jnp.power(sumexpl, 1 - struct_params[STRUCT_PARAM_IDXS['alpha']])
    return r, w


@jax.jit
def loss(model_params, struct_params, agg_state, key):
    X = agg_states[AGG_IDXS['Xs']]
    E = agg_states[AGG_IDXS['Es']]
    Z = agg_states[AGG_IDXS['Zs']]
    
    Z1, E1 = next_state(Z, E, struct_params, key)
    R, W = prices(struct_params, X, Z, E)
    w = jax.vmap(lambda x, e: (R * x) + (W * jnp.exp(e)))(X, E)
    outputs = jax.vmap(lambda i: neural_network(model_params, struct_params, X, E, Z, E[i], w[i]))(jnp.arange(K))
    c = outputs[..., 0]
    lm = outputs[..., 1]
    c_rel = c / w
    X1 = w - c
    R1, W1 = prices(struct_params, X1, Z1, E1)
    w1 = jax.vmap(lambda x, e: (R * x) + (W * jnp.exp(e)))(X1, E1)
    c1 = jax.vmap(lambda i: neural_network(model_params, struct_params, X1, E1, Z1, E1[i], w1[i])[0])(jnp.arange(K))

    u = lambda c: log_utility()(c)
    g = jax.vmap(lambda c: struct_params[STRUCT_PARAM_IDXS['beta']] * R1 * jax.grad(u)(c))(c1)
    up = jax.vmap(jax.grad(u))(c)
    g_diff = jax.vmap(lambda g, up, lm: (g / up) - lm)(g.reshape(-1, 1), up.reshape(-1, 1), lm.reshape(-1, 1))
    lm_diff = jax.vmap(lambda c, lm: fischer_burmeister(1 - c, 1 - lm))(c_rel.reshape(-1, 1), lm.reshape(-1, 1))

    return g_diff, lm_diff, c_rel


@jax.jit
def batch_loss(model_params, struct_params, agg_states, keys):
    g_diff_1, lm_diff_1, c_rels = jax.vmap(loss, in_axes=(None, 0, 0, None))(model_params, struct_params, agg_states, keys[0])
    g_diff_2, lm_diff_2, c_rels = jax.vmap(loss, in_axes=(None, 0, 0, None))(model_params, struct_params, agg_states, keys[1])
    g2 = g_diff_1 * g_diff_2
    lm2 = lm_diff_1 * lm_diff_2

    return jnp.squeeze(jnp.mean(g2 + lm2)), (jnp.squeeze(jnp.mean(g2)), jnp.squeeze(jnp.mean(lm2)), c_rels)


@jax.jit
def next_state(model_params, struct_params, agg_states, key):
    print(struct_params[STRUCT_PARAM_IDXS['rho_e']].shape)
    print(agg_states[:, AGG_IDXS['Es']])
    
    Xs_prime = jax.vmap(next_X, in_axes=(None, 0, 0))(model_params, struct_params, agg_states)
    Zs_prime = struct_params[STRUCT_PARAM_IDXS['rho_z']] * agg_states[:, AGG_IDXS['Zs']] + struct_params[STRUCT_PARAM_IDXS['sigma_z']] * jax.random.normal(key,)
    Es_prime = struct_params[STRUCT_PARAM_IDXS['rho_e']] * agg_states[:, AGG_IDXS['Es']] + struct_params[STRUCT_PARAM_IDXS['sigma_e']] * jax.random.normal(key, shape=(K,))
    
    agg_states = jnp.concatenate((Xs_prime, Es_prime, Zs_prime.reshape(-1, 1)), axis=1)
    return agg_states


@jax.jit
def simulate_state_forward(model_params, struct_params, agg_states, key):
    keys = jax.random.split(key, N_FORWARD)
    
    @jax.jit
    def inner_loop(agg_states, i):
        return next_state(model_params, struct_params, agg_states, keys[i]), None
    
    agg_states, _ = jax.lax.scan(inner_loop, agg_states, jnp.arange(N_FORWARD))
    
    return agg_states, keys[-1, -1]


@jax.jit
def generate_struct_params(key):
    keys = jax.random.split(key, 3)
    
    alpha = jax.random.uniform(keys[0], shape=(N // 2, 1))
    beta = jax.random.uniform(keys[1], shape=(N // 2, 1))
    delta = jax.random.uniform(keys[2], shape=(N // 2, 1))
    a = STRUCT_PARAM['a'] * jnp.ones((N // 2, 1))
    b = STRUCT_PARAM['b'] * jnp.ones((N // 2, 1))
    rho_z = STRUCT_PARAM['rho_z'] * jnp.ones((N // 2, 1))
    rho_e = STRUCT_PARAM['rho_e'] * jnp.ones((N // 2, 1))
    sigma_z = STRUCT_PARAM['sigma_z'] * jnp.ones((N // 2, 1))
    sigma_e = STRUCT_PARAM['sigma_e'] * jnp.ones((N // 2, 1))
    
    struct_params = jnp.concatenate((alpha, beta, delta, a, b, rho_z, rho_e, sigma_z, sigma_e), axis=1)
        
    return struct_params


@jax.jit
def generate_random_state(model_params, key): 
    struct_params = generate_struct_params(key)
    Zs = jnp.zeros(shape=(N // 2,))
    Es = jnp.zeros(shape=(N // 2, K))
    Xs = jnp.exp(STRUCT_PARAM['a'] * jax.random.normal(
        jax.random.PRNGKey(np.random.randint(1, int(1e8))), 
        shape=(N // 2, K))) + STRUCT_PARAM['b']
    
    agg_states = jnp.concatenate((Xs, Es, Zs.reshape(-1, 1)), axis=1)
 
    if N_FORWARD > 0:
        agg_states, key = simulate_state_forward(model_params, struct_params, agg_states, key)
        
    return agg_states, struct_params, key

In [42]:
scale = 0.5
init_keys = jax.random.split(jax.random.PRNGKey(5), 11)
theta0 = jax.random.gamma(init_keys[0], scale, shape=(2 * K + 3 + len(STRUCT_PARAM.values()), nn_shapes[0]))
w00 = scale * jax.random.normal(init_keys[1], shape=(2 * K + 3 + len(STRUCT_PARAM.values()), nn_shapes[0]))
w01 = scale * jax.random.normal(init_keys[2], shape=(nn_shapes[0], nn_shapes[1]))
w02 = scale * jax.random.normal(init_keys[3], shape=(nn_shapes[1], nn_shapes[2]))
w03 = scale * jax.random.normal(init_keys[4], shape=(nn_shapes[2] + 2, nn_shapes[3]))
w0f = scale * jax.random.normal(init_keys[5], shape=(nn_shapes[1], 1))
b00 = scale * jax.random.normal(init_keys[6], shape=(1, nn_shapes[0]))
b01 = scale * jax.random.normal(init_keys[7], shape=(1, nn_shapes[1]))
b02 = scale * jax.random.normal(init_keys[8], shape=(1, nn_shapes[2]))
b03 = scale * jax.random.normal(init_keys[9], shape=(1, nn_shapes[3]))
b0f = scale * jax.random.normal(init_keys[10], shape=(1, 1))

params0 = {
    'theta': theta0, 'w0': w00, 'w1': w01, 'w2': w02, 'w3': w03, 'cwf': w0f, 'lwf': w0f, 'b0': b00, 'b1': b01, 'b2': b02, 'b3': b03, 'cbf': b0f, 'lbf': b0f
}

In [43]:
def training_loop(opt_state, tol=1e-10, max_iter=10 ** 4):
    j = 0
    key = jax.random.PRNGKey(np.random.randint(1, int(1e8)))
    val_loss = jnp.inf
    grad = {'0': jnp.inf}
    opt_init, opt_update, get_params = adam(step_size=0.001)
    model_params = get_params(opt_state)

    agg_states, struct_params, key = generate_random_state(model_params, key)

    while j < max_iter and max([jnp.max(jnp.abs(v)) for k, v in grad.items()]) > tol and jnp.abs(val_loss) > tol:
        jj = 0
        while jj < N_ITER:
            keys = jax.random.split(key, 2)
            key = keys[-1]
            model_params = get_params(opt_state)

            sample = jax.random.choice(jax.random.PRNGKey(np.random.randint(1, int(1e8))), jnp.arange(N // 2), shape=(2, MB // 2,))
            val, grad = jax.value_and_grad(batch_loss, has_aux=True)(model_params, struct_params[sample[0]], agg_states[sample[1]], keys)
            val_loss = jnp.abs(val[0])
            c_star_rel = val[1][-4]
            assert (c_star_rel < 1).all()
            if jnp.isnan(val_loss):
                raise ValueError('Loss is nan')

            c_val = jnp.abs(val[1][0])
            kt_val = jnp.abs(val[1][1])
            opt_state = opt_update(j * N_ITER + jj, grad, opt_state)

            jj += 1

        # Start from a new random position, before moving into the current implied ergodic set
        params = get_params(opt_state)
        if N_FORWARD > 0:
            agg_states, struct_params, key = generate_random_state(params, key)

        if j % 10 == 0:
            trained_params = unpack_optimizer_state(opt_state)
            pickle.dump(trained_params, open(f'./ks_cont_models/ks_cont_model_{K}_{j}.pkl', 'wb'))
            print(f'Iteration: {j}  Total Loss: {val_loss:.2e}  C Loss: {c_val:.2e}  KT Loss: {kt_val:.2e}' +\
                  f'  Max Grad: {max([jnp.max(jnp.abs(v)) for k, v in grad.items()]):.2e}' +\
                  f'  Max Param: {max([jnp.max(jnp.abs(v)) for k, v in params.items()]):.2e}')
        j += 1
        
    return opt_state

In [44]:
opt_init, opt_update, get_params = adam(step_size=0.001)
# saved_params = pickle.load(open(f'./share/models/ks_model_{K}_final.pkl', 'rb'))
# opt_state = pack_optimizer_state(saved_params)
opt_state = opt_init(params0)

opt_state = training_loop(opt_state, max_iter=N_EPOCH)
# params = training_loop(max_iter=n_epoch)
# params = unpack_optimizer_state(opt_state)
# pickle.dump(params, open(f'./share/models/ks_cont_model_{K}_final.pkl', 'wb'))

(9,)
Traced<ShapedArray(float32[512,5])>with<DynamicJaxprTrace(level=1/4)>


ValueError: Incompatible shapes for broadcasting: ((1, 9), (512, 5))

In [None]:
Xs, Zs, Es, struct_params, key = generate_random_state(params0, jax.random.PRNGKey(1))

In [None]:
agg_state = jnp.concatenate((Xs, Es, Zs.reshape(-1, 1)), axis=1)
agg_idxs = [(0, 5), (5, 10), (10, 11)]

In [None]:
agg_idxs = {'Xs': list(range(5)), 'Es': list(range(5, 10)), 'Zs': list(range(10, 11))}

In [None]:
jnp.all(agg_state[:, agg_idxs['Xs']] == Xs)

In [None]:
jnp.all(agg_state[:, agg_idxs['Es']] == Es)

In [None]:
jnp.all(agg_state[:, agg_idxs['Zs']] == Zs.reshape(-1, 1))

In [None]:
_, _, get_params = adam(step_size=0.001)
saved_state = pickle.load(open(f'./ks_cont_models/ks_cont_model_{K}_990.pkl', 'rb'))
opt_state = pack_optimizer_state(saved_state)
params = get_params(opt_state)

key = jax.random.PRNGKey(np.random.randint(1, int(1e8)))
keys = jax.random.split(key, N // 2)
Xs, Zs, Es, struct_params, key = generate_random_state(params, key)

In [None]:
key = jax.random.PRNGKey(np.random.randint(1, int(1e8)))
idx = jax.random.choice(key, jnp.arange(N), shape=(5,))

X = Xs
Z = Zs
E = Es
e = E[:, 0]

# x_range = jnp.linspace(jnp.min(ws), jnp.max(ws), 100)
x_range = jnp.linspace(0, 10, 100)
c_range = jax.vmap(lambda X, E, Z, e: jax.vmap(lambda x: neural_network(params, STRUCT_PARAM_ARR, X, E, Z, e, x)[0])(x_range))(X, Z, E, e)
data = pd.DataFrame(jnp.concatenate((x_range.reshape(-1, 1), c_range.T), axis=1), dtype=np.float32)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
_ = data.plot(x=0, y=jnp.arange(Xs.shape[0]) + 1, ax=ax, legend=False)
_ = fig.suptitle(f'Optimal consumption (as a function of idiosyncratic state) across some aggregate states drawn from the ergodic distribution', fontsize=16)

In [None]:
key = jax.random.PRNGKey(np.random.randint(1, int(1e8)))
idx = jax.random.choice(key, jnp.arange(N), shape=(5,))

X = Xs[0]
Z = Zs[0]
E = Es[0]
e = E[0]

# x_range = jnp.linspace(jnp.min(ws), jnp.max(ws), 100)
x_range = jnp.linspace(0, 10, 100)
c_range = jax.vmap(lambda config: jax.vmap(lambda x: neural_network(params, config, X, E, Z, e, x)[0])(x_range))(struct_params)

data = pd.DataFrame(jnp.concatenate((x_range.reshape(-1, 1), c_range.T), axis=1), dtype=np.float32)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
_ = data.plot(x=0, y=jnp.arange(struct_params.shape[0]) + 1, ax=ax, legend=False)
_ = fig.suptitle(f'Optimal consumption (as a function of idiosyncratic state) across some structural parameters drawn', fontsize=16)