In [0]:
import numpy as np0
import jax.numpy as np
import jax.random as random
import grad_constr_jax

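## Test Model

A toy problem for checking `grad_constr_jax`: the objective is $\sum_i x_i + \sum_i y_i$, and `con` returns the residuals $\|x\|_2^2 - 1$ and $\|y\|_2^2 - 1$.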

In [0]:
def con(x, y):
    """Constraint residuals for the test model.

    Returns a two-element list: the squared norm of ``x`` minus one, then the
    squared norm of ``y`` minus one.
    """
    residuals = []
    for v in (x, y):
        residuals.append(np.sum(v * v) - 1.0)
    return residuals


def obj(x, y):
    """Test-model objective: the sum of every entry of ``x`` and ``y``."""
    total_x = np.sum(x)
    total_y = np.sum(y)
    return total_x + total_y

In [0]:
# Random starting point for the test model. JAX keys are not stateful:
# reusing one key for both draws would make 'x' and 'y' identical, so each
# draw gets its own subkey.
key = random.PRNGKey(0)
key_x, key_y = random.split(key)
var0 = {
    'x': random.uniform(key_x, (250,)),
    'y': random.uniform(key_y, (250,))
}

In [0]:
# Run grad_constr_jax.constrained_gradient_descent on the test model starting
# from var0; `%time` reports how long the call takes.
# NOTE(review): presumably minimizes obj subject to the residuals from con
# being driven to zero — confirm against grad_constr_jax.
%time var1 = grad_constr_jax.constrained_gradient_descent(obj, con, var0, output=True)

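## Simple Model

A CES model on a grid of $N$ types $\theta \in [\theta_{\min}, \theta_{\max}]$ with uniform weights. `obj` is the squared distance between three model moments and the targets in `data`.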

In [0]:
# Number of theta grid points and the endpoints of the theta range.
N = 100
theta_min = 1
theta_max = 2

In [0]:
# Evenly spaced theta grid on [theta_min, theta_max], with weight 1/N on each
# grid point.
theta = np.linspace(theta_min, theta_max, num=N)
dist = np.ones(N) / N

In [0]:
# Target moments for the simple model: `obj` compares element k of this array
# against the k-th entry of its `theory` vector.
data = np.array([1.0, 0.1, 0.2])

In [0]:
def _quality(alpha, rho, r):
    """CES combination of ``r`` and the notebook-level ``theta`` grid.

    q = (alpha * r**(1 - rho) + (1 - alpha) * theta**(1 - rho))**(1 / (1 - rho))

    Shared by ``con`` and ``obj`` so both use the same formula. The exponent
    1 / (1 - rho) is singular at rho == 1.
    """
    return (alpha*r**(1-rho)+(1-alpha)*theta**(1-rho))**(1/(1-rho))


def con(alpha, rho, kappa, eta, zeta, r, qbar):
    """Constraint residuals for the simple model.

    Returns a two-element list:
      0. ``dq * qbar**zeta - kappa * r**eta`` elementwise over the grid, where
         ``dq`` is the partial derivative of ``q`` with respect to ``r``;
      1. ``qbar - sum(q * dist)``, the gap between ``qbar`` and the
         ``dist``-weighted sum of ``q``.
    NOTE(review): presumably the solver drives both residuals to zero.
    """
    q = _quality(alpha, rho, r)
    # For the CES form in _quality, dq/dr simplifies to alpha * (q / r)**rho.
    dq = alpha*(q/r)**rho
    return [
        dq*qbar**zeta - kappa*r**eta,
        qbar - np.sum(q*dist)
    ]


def obj(alpha, rho, kappa, eta, zeta, r, qbar):
    """Sum of squared differences between model moments and ``data``.

    The three model moments are the ``dist``-weighted sums of ``q``, of
    ``prof = qbar**zeta * q``, and of ``M``.
    """
    q = _quality(alpha, rho, r)
    # NOTE(review): the first residual in `con` uses kappa * r**eta, whose
    # antiderivative in r is kappa * r**(1+eta) / (1+eta). The denominator
    # here is (1+kappa); confirm this is intended and not a typo for (1+eta).
    M = kappa*(r**(1+eta))/(1+kappa)
    prof = (qbar**zeta)*q
    theory = np.array([
        np.sum(q*dist),
        np.sum(prof*dist),
        np.sum(M*dist)
    ])
    return np.sum((theory-data)**2)

In [0]:
# Initial guess for the simple model. Scalar parameters are length-1 arrays;
# `r` has N entries, matching the theta grid.
var0 = dict(
    alpha=np.array([0.5]),
    rho=np.array([1.5]),
    kappa=np.array([0.1]),
    eta=np.array([1.0]),
    zeta=np.array([0.02]),
    r=np.ones(N) * 0.1,
    qbar=np.array([1.0]),
)

In [0]:
# Run grad_constr_jax.constrained_gradient_descent on the simple model starting
# from var0; `%time` reports how long the call takes.
# NOTE(review): presumably minimizes obj subject to the residuals from con
# being driven to zero — confirm against grad_constr_jax.
%time var1 = grad_constr_jax.constrained_gradient_descent(obj, con, var0, output=True)

In [0]:
# Display var1, the value returned by the constrained_gradient_descent call above.
var1