In [None]:
import jax 
import jax.numpy as jnp
from jax.scipy import optimize
from jax.experimental.optimizers import adam
from agent import disc_ces_utility

In [None]:
# Length of the choice vector `c`; `future` uses it to size its initial guess.
c_shape = 3
# Length of the state vector `X` and of the parameter vector `theta0`.
X_shape = 2
# First argument handed to `disc_ces_utility` below.
# NOTE(review): presumably the discount factor -- confirm in `agent`.
beta = 0.95
# State vector: exponentiated standard-normal draws (fixed seed 1), hence strictly positive.
X = jnp.exp(jax.random.normal(jax.random.PRNGKey(1), shape=(X_shape,)))
# Starting parameters for the linear value approximation: standard-normal draws (fixed seed 2).
theta0 = jax.random.normal(jax.random.PRNGKey(2), shape=(X_shape,))
# Utility callable, invoked later as `u(c, 0)`.
# NOTE(review): the meaning of the second argument `2` is defined in `agent` -- confirm.
u = disc_ces_utility(beta, 2)


def V_hat(theta):
    """Build a linear value-function approximation parameterised by `theta`.

    Args:
        theta: coefficient vector.

    Returns:
        A one-argument callable mapping a state `X` to `jnp.dot(theta, X)`.
    """
    def value(X):
        # Inner product of the fixed coefficients with the supplied state.
        return jnp.dot(theta, X)

    return value


def future(u, v_hat, X, c0=None):
    """Find the choice vector `c` that maximises `u(c, 0) + v_hat(X - sum(c))`.

    The maximisation is carried out by minimising the negated objective with
    BFGS (`jax.scipy.optimize.minimize`).

    Args:
        u: callable invoked as `u(c, 0)`.
            NOTE(review): the second argument is presumably a time index --
            confirm against the `agent` module.
        v_hat: callable mapping a state to its approximated value.
        X: current state; the scalar `jnp.sum(c)` is subtracted from it
            (broadcast over every component) to form the next state.
        c0: optional initial guess for `c`. When omitted, falls back to
            `jnp.ones(c_shape)` using the notebook-level `c_shape`, which is
            the original behaviour. Passing it explicitly removes the
            dependency on that global and allows other dimensions.

    Returns:
        The minimiser `.x` reported by the solver (same shape as `c0`). The
        solver's success flag is not inspected.

    NOTE(review): this returns the maximising `c`, not the maximised objective
    value; `epsilon` subtracts it from a scalar value estimate -- confirm that
    this is the intended residual.
    """
    def neg_objective(c):
        # Negated because the solver minimises while the problem is a maximisation.
        return -(u(c, 0) + v_hat(X - jnp.sum(c)))

    if c0 is None:
        c0 = jnp.ones(c_shape)

    result = optimize.minimize(
        neg_objective,
        c0,
        method='BFGS',
        options={'line_search_maxiter': 10000, 'gtol': 1e-2},
    )
    return result.x
       
    
def epsilon(theta, u, X):
    """Squared 2-norm of the gap between `V_hat(theta)(X)` and `future(...)`.

    Args:
        theta: coefficients of the linear value approximation.
        u: utility callable forwarded to `future`.
        X: state at which the gap is evaluated.

    Returns:
        `||V_hat(theta)(X) - future(u, V_hat(theta), X)||_2 ** 2`.

    NOTE(review): `future` returns the optimiser's solution vector, so the
    scalar value estimate is broadcast against that vector here -- confirm
    this is the intended residual.
    """
    value_fn = V_hat(theta)
    residual = value_fn(X) - future(u, value_fn, X)
    return jnp.linalg.norm(residual, ord=2) ** 2
    
    
def e(theta):
    """Evaluate `epsilon` at `theta` with the notebook-level `u` and `X` held fixed."""
    return epsilon(theta, u, X)

In [None]:
# Adam optimiser over `theta`, started from `theta0`.
# NOTE(review): newer JAX releases expose these optimisers under
# `jax.example_libraries.optimizers` -- the import may need updating on upgrade.
opt_init, opt_update, get_params = adam(step_size=0.01)
opt_state = opt_init(theta0)

# Iteration counter, iteration cap, and stopping tolerance on the loss.
i, max_iter, tol = 0, 1000, 1e-5
err = jnp.inf

# Run until the loss drops to `tol`, the iteration cap is hit, or the gradient turns NaN.
while i < max_iter and err > tol:
    theta = get_params(opt_state)
    err = e(theta)
    # Forward-mode Jacobian of the scalar loss with respect to `theta`.
    grad = jax.jacfwd(e)(theta)
    if jnp.any(jnp.isnan(grad)):
        print(f'Grad in iteration {i} is nan, terminating')
        break
    opt_state = opt_update(i, grad, opt_state)
    # Progress report every tenth iteration.
    if not i % 10:
        print(f'iteration: {i}\ntheta: {theta}\ngradient: {grad}\nerror: {err}')

    i = i + 1