In [1]:
import jax 
import jax.numpy as jnp
from jax.scipy import optimize
from jax.experimental.optimizers import adam
from agent import log_utility, ces_utility

In [5]:
def V_hat(params):
    """Build the parametric value-function approximation for ``params``.

    Args:
        params: mapping with entries ``'theta'``, ``'alpha'`` and ``'gamma'``.

    Returns:
        A jitted function of ``X`` that evaluates
        ``alpha.T @ X_tilde - (X_tilde.T @ gamma) @ (gamma.T @ X_tilde)``
        with ``X_tilde = theta @ X``, with singleton dimensions squeezed away.
    """
    @jax.jit
    def value_fn(X):
        # Project the input through theta first.
        features = jnp.dot(params['theta'], X)

        # Linear part of the approximation.
        linear_term = jnp.dot(params['alpha'].T, features)

        # Quadratic part: (features' gamma)(gamma' features).
        left = jnp.dot(features.T, params['gamma'])
        right = jnp.dot(params['gamma'].T, features)
        quadratic_term = jnp.dot(left, right)

        return jnp.squeeze(linear_term - quadratic_term)

    return value_fn


@jax.jit
def future(params, beta, X):
    """Maximised value of ``u(c) + beta * V_hat(params)(X - sum(c))`` over ``c``.

    The maximisation is done by running BFGS on the negated objective,
    starting from a vector of ones.

    Args:
        params: parameter mapping forwarded to ``V_hat``.
        beta: weight applied to the continuation value.
        X: current state; ``sum(c)`` is subtracted from it.

    Returns:
        The objective evaluated at the point returned by the optimiser.

    Note:
        The shape of the starting point comes from the module-level name
        ``c_shape``, not from an argument. ``u`` is whatever
        ``ces_utility(2.)`` returns (defined in the ``agent`` module).
    """
    continuation = V_hat(params)
    utility = ces_utility(2.)

    def negative_objective(c):
        # minimize() minimises, so hand it the negated objective.
        return - (utility(c) + beta * continuation(X - jnp.sum(c)))

    initial_guess = jnp.ones(c_shape)
    solver_options = {'line_search_maxiter': 10000, 'gtol': 1e-2}

    result = optimize.minimize(
        negative_objective,
        initial_guess,
        method='BFGS',
        options=solver_options,
    )
    c_star = result.x

    return utility(c_star) + beta * continuation(X - jnp.sum(c_star))

  
@jax.jit
def epsilon(params, beta, X):
    """Squared difference between ``V_hat(params)(X)`` and ``future(params, beta, X)``.

    Args:
        params: parameter mapping forwarded to ``V_hat`` and ``future``.
        beta: forwarded to ``future``.
        X: point at which both quantities are evaluated.

    Returns:
        ``(V_hat(params)(X) - future(params, beta, X)) ** 2``.
    """
    residual = V_hat(params)(X) - future(params, beta, X)
    return residual**2
    
    
def e(params):
    """Loss as a function of ``params`` only.

    ``beta`` and ``X`` are looked up in module scope each time this is called.
    """
    return epsilon(params, beta, X)

In [6]:
# Problem sizes and discount weight.
c_shape = 1      # shape of the optimiser's starting point in `future`
k, m = 100, 8    # X has k rows; theta maps those k rows down to m
beta = 0.95

# State vector: exponentiated normal draws with a fixed key.
X = jnp.exp(jax.random.normal(jax.random.PRNGKey(123), (k, 1)))

# Initial parameter values, each drawn with its own fixed key.
theta0 = jax.random.normal(jax.random.PRNGKey(7), (m, k))
alpha0 = jax.random.normal(jax.random.PRNGKey(9), (m, 1))
gamma0 = jax.random.normal(jax.random.PRNGKey(2), (m, m))

params0 = dict(theta=theta0, alpha=alpha0, gamma=gamma0)

In [7]:
# Fit the parameters by running Adam on the loss `e`, starting from `params0`.
opt_init, opt_update, get_params = adam(step_size=0.01)
opt_state = opt_init(params0)

# Build the derivative function once; it does not change between iterations.
# NOTE(review): forward mode (jacfwd) is kept as in the original. Presumably
# reverse mode cannot be used because `future` runs an iterative BFGS solve
# inside the loss -- confirm before switching to jax.grad.
grad_fn = jax.jacfwd(e)

i = 0
max_iter = 1000
tol = 1e-8
err = jnp.inf

# Stop after max_iter steps or as soon as the loss drops to tol or below.
while i < max_iter and err > tol:
    params = get_params(opt_state)
    grad = grad_fn(params)
    err = e(params)
    # Inspect every gradient array on its own. Concatenating them along one
    # axis only works while all parameters happen to share the other
    # dimensions, and fails as soon as a parameter's shape changes.
    if any(bool(jnp.isnan(g).any()) for g in jax.tree_util.tree_leaves(grad)):
        print(f'Grad in iteration {i} is nan, terminating')
        break
    opt_state = opt_update(i, grad, opt_state)
    if i % 10 == 0:
        print(f'iteration: {i}\nerror: {err}')

    i += 1

iteration: 0
error: 269689152.0
iteration: 10
error: 2971398.25
iteration: 20
error: 211190.0625
iteration: 30
error: 804676.0
iteration: 40
error: 118723.65625
iteration: 50
error: 14383.0126953125
iteration: 60
error: 34965.98828125
iteration: 70
error: 11723.083984375
iteration: 80
error: 852.893798828125
iteration: 90
error: 65.23243713378906
iteration: 100
error: 299.40997314453125
iteration: 110
error: 206.85125732421875
iteration: 120
error: 84.374267578125
iteration: 130
error: 25.911479949951172
iteration: 140
error: 6.640435695648193
iteration: 150
error: 1.5134525299072266
iteration: 160
error: 0.32581353187561035
iteration: 170
error: 0.07172876596450806
iteration: 180
error: 0.01731640100479126
iteration: 190
error: 0.0055084228515625
iteration: 200
error: 0.0022202134132385254
iteration: 210
error: 0.001102447509765625
iteration: 220
error: 0.0005841851234436035
iteration: 230
error: 0.00022178888320922852
iteration: 240
error: 8.159875869750977e-05
iteration: 250
error: 

In [8]:
# Display the parameters read from the optimiser state in the last executed
# loop iteration (the value `err` was computed for).
params

{'alpha': DeviceArray([[ 0.8543441 ],
              [-0.72775424],
              [-1.6737317 ],
              [-0.3773765 ],
              [ 2.5686846 ],
              [ 0.5673227 ],
              [ 1.0751585 ],
              [ 1.535103  ]], dtype=float32),
 'gamma': DeviceArray([[ 2.0228012e-01, -7.4826962e-01,  1.8043938e-01,
                3.4251466e-01, -6.7078620e-02,  1.5271394e+00,
               -1.4843062e+00,  1.2805520e-01],
              [-1.0476917e-01,  8.5109037e-01, -1.5940050e+00,
               -3.6588252e-01, -1.1905104e+00,  8.0257165e-01,
                2.9651366e-02,  1.5829191e-02],
              [-6.3769870e-02, -1.3892974e+00, -4.0490654e-01,
                1.2577935e+00,  8.2991803e-01, -5.0333643e-01,
               -1.6242344e+00, -2.8344333e-01],
              [-1.6429379e+00,  1.1012220e+00, -6.8136024e-01,
                1.3712860e+00,  2.3882405e-01, -8.2709527e-01,
                3.5112992e-01, -1.6502929e+00],
              [ 1.1305002e+00,  2.653