In [1]:
import time
import jax 
import jax.numpy as jnp
from jax.scipy import optimize
from jax.experimental.optimizers import adam
from agent import log_utility, ces_utility

In [2]:
def V_hat(params):
    """Build a jitted feed-forward approximator from a parameter dict.

    Parameters
    ----------
    params : dict
        Must provide the keys 'theta', 'w0', 'b0', 'w1', 'b1', 'w2', 'b2',
        'wf' and 'bf'. Each value is used as a matrix/offset in the forward
        pass below.

    Returns
    -------
    callable
        A `jax.jit`-compiled function of `X` that closes over `params`.
        It multiplies `X` by 'theta', applies three affine layers each
        clamped below at zero, applies the final affine layer ('wf', 'bf')
        and squeezes singleton dimensions out of the result.
    """
    # (weight key, bias key) of each hidden layer, in application order.
    hidden_layers = (('w0', 'b0'), ('w1', 'b1'), ('w2', 'b2'))

    @jax.jit
    def forward(X):
        # Linear projection of the raw input before the hidden layers.
        activation = jnp.dot(X, params['theta'])
        # The Python loop is unrolled at trace time, so the compiled graph
        # is the same chain of three layers.
        for weight_key, bias_key in hidden_layers:
            pre_activation = jnp.dot(activation, params[weight_key]) + params[bias_key]
            # Lower bound of 0 and no upper bound: negative entries become 0.
            activation = jnp.clip(pre_activation, 0)
        output = jnp.dot(activation, params['wf']) + params['bf']
        return jnp.squeeze(output)

    return forward


@jax.jit
def future(params, beta, X):
    """Maximise `u(c) + beta * V_hat(params)(X - sum(c))` over `c` and return the optimum value.

    Parameters
    ----------
    params : dict
        Parameter dict forwarded to `V_hat`.
    beta : float
        Weight applied to the `V_hat` term.
    X : array
        Input passed to the approximator after subtracting `jnp.sum(c)`.

    Returns
    -------
    The objective `u(c) + beta * v_hat(X - sum(c))` evaluated at the point
    returned by the BFGS solver.

    Notes
    -----
    The starting point is `jnp.ones(c_shape)`, where `c_shape` is a
    module-level name rather than an argument. `u` is `ces_utility(2.)`
    from the `agent` module; its exact form is not visible here.
    """
    value_fn = V_hat(params)
    utility = ces_utility(2.)

    def total_value(c):
        # Utility of `c` plus the weighted approximate value of what remains.
        return utility(c) + beta * value_fn(X - jnp.sum(c))

    def objective(c):
        # `optimize.minimize` minimises, so hand it the negated value.
        return - total_value(c)

    initial_guess = jnp.ones(c_shape)
    solver_options = {'line_search_maxiter': 10000, 'gtol': 1e-2}
    result = optimize.minimize(objective, initial_guess, method='BFGS', options=solver_options)
    c_star = result.x

    return total_value(c_star)

  
@jax.jit
def epsilon(params, beta, X):
    """Squared difference between `V_hat(params)(X)` and `future(params, beta, X)`.

    Parameters
    ----------
    params : dict
        Parameter dict used both for the direct evaluation and inside `future`.
    beta : float
        Forwarded unchanged to `future`.
    X : array
        Point at which both quantities are evaluated.

    Returns
    -------
    `(V_hat(params)(X) - future(params, beta, X)) ** 2`.
    """
    approx_value = V_hat(params)(X)
    target_value = future(params, beta, X)
    residual = approx_value - target_value
    return residual ** 2
    
    
def e(params):
    """Evaluate `epsilon` at `params` using the module-level `beta` and `X`.

    `beta` and `X` are looked up when `e` is called, not when it is defined.
    """
    return epsilon(params, beta, X)

In [3]:
# --- Problem configuration -------------------------------------------------
c_shape = 1    # shape given to jnp.ones() for the solver's starting point in `future`
k = 100        # number of columns of X / rows of theta
m = 10         # number of columns of theta / rows of w0
# Hidden-layer widths. Kept as a plain tuple of Python ints: they are only
# ever used as static array shapes, so a device array adds nothing and would
# force each element to be converted back to an int.
nn_shapes = (10, 20, 10)
beta = 0.95    # weight on the V_hat term inside `future`

# --- Input point -------------------------------------------------------------
# exp() of a normal draw, so every entry of X is strictly positive.
X = jnp.exp(jax.random.normal(jax.random.PRNGKey(123), shape=(1, k)))

# --- Initial network parameters (each drawn from its own fixed seed) ----------
theta0 = jax.random.normal(jax.random.PRNGKey(129), shape=(k, m))
w00 = jax.random.normal(jax.random.PRNGKey(6), shape=(m, nn_shapes[0]))
w01 = jax.random.normal(jax.random.PRNGKey(7), shape=(nn_shapes[0], nn_shapes[1]))
w02 = jax.random.normal(jax.random.PRNGKey(8), shape=(nn_shapes[1], nn_shapes[2]))
w0f = jax.random.normal(jax.random.PRNGKey(9), shape=(nn_shapes[2], 1))
b00 = jax.random.normal(jax.random.PRNGKey(52), shape=(1, nn_shapes[0]))
b01 = jax.random.normal(jax.random.PRNGKey(51), shape=(1, nn_shapes[1]))
b02 = jax.random.normal(jax.random.PRNGKey(58), shape=(1, nn_shapes[2]))
b0f = jax.random.normal(jax.random.PRNGKey(48), shape=(1, 1))

# Keys must match the ones read by `V_hat`.
params0 = {
    'theta': theta0,
    'w0': w00, 'b0': b00,
    'w1': w01, 'b1': b01,
    'w2': w02, 'b2': b02,
    'wf': w0f, 'bf': b0f,
}

In [4]:
# Minimise e(params) with Adam, starting from params0.
opt_init, opt_update, get_params = adam(step_size=0.01)
opt_state = opt_init(params0)

# Build the derivative function once instead of re-wrapping `e` on every pass.
# NOTE(review): forward mode is kept from the original code; presumably the
# BFGS solve inside `future` cannot be differentiated in reverse mode - confirm.
grad_fn = jax.jacfwd(e)

i = 0
tol = 1e-8
# Upper bound on the number of Adam steps. Without it the loop never ends if
# the error stalls above `tol` (e.g. at the resolution of float32).
max_iter = 10_000
err = jnp.inf
st = time.time()

while err > tol and i < max_iter:
    params = get_params(opt_state)
    grad = grad_fn(params)
    err = e(params)
    # Stop before applying an update that would poison the optimizer state.
    if any(jnp.isnan(g).any() for g in grad.values()):
        print(f'Grad in iteration {i} is nan, terminating')
        break
    opt_state = opt_update(i, grad, opt_state)
    if i % 10 == 0:
        print(f'iteration: {i}\nerror: {err}\ntime elapsed: {time.time() - st}')

    i += 1

if i >= max_iter and err > tol:
    print(f'Stopped after {i} iterations without reaching tol={tol}; last error: {err}')

iteration: 0
error: 1034323.125
time elapsed: 30.449758291244507
iteration: 10
error: 1091.4166259765625
time elapsed: 50.65397047996521
iteration: 20
error: 5184.83056640625
time elapsed: 51.92743492126465
iteration: 30
error: 2643.070068359375
time elapsed: 53.17137289047241
iteration: 40
error: 2548.939453125
time elapsed: 54.57697010040283
iteration: 50
error: 2307.89697265625
time elapsed: 55.841893911361694
iteration: 60
error: 2029.8502197265625
time elapsed: 57.087433099746704
iteration: 70
error: 1762.4273681640625
time elapsed: 58.34239077568054
iteration: 80
error: 1522.6334228515625
time elapsed: 59.53592491149902
iteration: 90
error: 1314.093505859375
time elapsed: 60.635592222213745
iteration: 100
error: 1135.1519775390625
time elapsed: 61.86457347869873
iteration: 110
error: 982.4012451171875
time elapsed: 63.160945892333984
iteration: 120
error: 852.146240234375
time elapsed: 64.40913248062134
iteration: 130
error: 740.9515380859375
time elapsed: 65.49854183197021
itera

iteration: 1120
error: 4.641764098778367e-05
time elapsed: 186.9206063747406
iteration: 1130
error: 3.690447556436993e-05
time elapsed: 188.22129464149475
iteration: 1140
error: 2.946659515146166e-05
time elapsed: 189.47118425369263
iteration: 1150
error: 2.4103341274894774e-05
time elapsed: 190.7383041381836
iteration: 1160
error: 1.985207200050354e-05
time elapsed: 191.99877405166626
iteration: 1170
error: 1.6412290278822184e-05
time elapsed: 193.2021505832672
iteration: 1180
error: 1.3523036614060402e-05
time elapsed: 194.52229738235474
iteration: 1190
error: 1.1205065675312653e-05
time elapsed: 195.75098299980164
iteration: 1200
error: 9.34818308451213e-06
time elapsed: 196.86135005950928
iteration: 1210
error: 7.947106496430933e-06
time elapsed: 198.109069108963
iteration: 1220
error: 6.858059350633994e-06
time elapsed: 199.34833765029907
iteration: 1230
error: 5.821584636578336e-06
time elapsed: 200.53320002555847
iteration: 1240
error: 5.065521690994501e-06
time elapsed: 201.827

iteration: 2180
error: 2.8816430130973458e-08
time elapsed: 316.70996737480164
iteration: 2190
error: 2.8816430130973458e-08
time elapsed: 317.89557671546936
iteration: 2200
error: 2.8172507882118225e-08
time elapsed: 319.04934000968933
iteration: 2210
error: 2.7535861590877175e-08
time elapsed: 320.1972715854645
iteration: 2220
error: 2.7535861590877175e-08
time elapsed: 321.4601557254791
iteration: 2230
error: 2.7535861590877175e-08
time elapsed: 322.5030462741852
iteration: 2240
error: 2.506203600205481e-08
time elapsed: 323.69172954559326
iteration: 2250
error: 2.270462573505938e-08
time elapsed: 324.8165624141693
iteration: 2260
error: 2.1569576347246766e-08
time elapsed: 326.04656767845154
iteration: 2270
error: 2.213346306234598e-08
time elapsed: 327.40186738967896
iteration: 2280
error: 1.7826096154749393e-08
time elapsed: 328.6032919883728
iteration: 2290
error: 1.7826096154749393e-08
time elapsed: 329.79892897605896
iteration: 2300
error: 1.7826096154749393e-08
time elapsed: 