In [0]:
import numpy as np0
import jax.numpy as np
import jax.random as random
import grad_constr_jax

## Test model: $\sum_i x_i + \sum_i y_i$ with $x$ and $y$ constrained to the unit sphere

In [0]:
def con(x, y):
    """Constraint residuals for the test model.

    Returns a two-element list ``[|x|^2 - 1, |y|^2 - 1]``; each entry is zero
    exactly when the corresponding vector lies on the unit sphere.
    """
    return [np.sum(v * v) - 1.0 for v in (x, y)]
def obj(x, y):
    """Test-model objective: the sum of all entries of ``x`` plus all entries of ``y``."""
    total_x = np.sum(x)
    total_y = np.sum(y)
    return total_x + total_y

In [0]:
# Initial point for the test model. The seed key is split so that x and y get
# independent draws; passing the same key to both `uniform` calls (as before)
# makes var0['x'] and var0['y'] bit-for-bit identical.
key = random.PRNGKey(0)
key_x, key_y = random.split(key)
var0 = {
    'x': random.uniform(key_x, (100,)),
    'y': random.uniform(key_y, (100,))
}

In [0]:
# Solve the test problem once from var0 and report the elapsed time (IPython %time).
# NOTE: `var1` is reused later for the simple-model solution.
%time var1 = grad_constr_jax.constrained_gradient_descent(obj, con, var0, output=True)

In [0]:
# Spot-check the test-model solution: first 10 components of x.
var1['x'][:10]

## Simple model: fit parameters by matching four target moments

In [0]:
# Number of grid points, and the endpoints of the theta grid.
N = 100
theta_min = 1
theta_max = 2

In [0]:
# Evenly spaced theta grid with uniform weights: every grid point gets 1/N,
# so the weights in `dist` sum to one.
theta = np.linspace(theta_min, theta_max, num=N)
dist = np.ones(N) * (1 / N)

In [0]:
# Target moments, in the same order that `mmt` returns them:
#   mean(prof), std(prof), mean(M), std(M)
# NOTE(review): `std` in this notebook computes a variance; confirm the second
# and fourth targets are meant to be variances.
data = np.array([2.0, 0.3, 0.05, 0.02])

In [0]:
def mean(x):
    """Weighted expectation of ``x`` using the global weights ``dist``."""
    weighted = dist * x
    return weighted.sum()
def std(x):
    """Weighted second central moment of ``x`` under ``dist``: E[x**2] - E[x]**2.

    NOTE(review): despite the name, this is the variance (no square root is
    taken). The "std" targets in ``data`` are therefore compared against
    variances; confirm this is intended.
    """
    second_moment = np.sum((x**2) * dist)
    return second_moment - mean(x) ** 2
def q_func(alpha, rho, r, theta):
    """CES aggregate of ``r`` and ``theta``.

    Computes ``(alpha * r**(1-rho) + (1-alpha) * theta**(1-rho)) ** (1/(1-rho))``,
    i.e. a CES combination with weight ``alpha`` on ``r`` and elasticity of
    substitution ``1/rho``. The formula is singular at ``rho == 1``, where the
    exponent ``1/(1-rho)`` divides by zero.
    """
    expo = 1 - rho
    inner = alpha * r**expo + (1 - alpha) * theta**expo
    return inner ** (1 / expo)
def con(alpha, rho, kappa, eta, zeta, r, qbar):
    """Constraint residuals for the simple model (each is zero when satisfied).

    1. ``dq/dr * qbar**zeta - kappa * r**eta`` at each theta grid point:
       marginal benefit minus marginal cost of ``r`` (presumably the
       first-order condition for ``r``).
    2. ``qbar - E[q]``: ``qbar`` must equal the ``dist``-weighted average of ``q``.

    ``r`` is floored at 0.001 before use so that ``q / r`` and the powers of ``r``
    stay finite. Reads the globals ``theta`` and ``dist``.
    """
    r_safe = np.maximum(0.001, r)
    q = q_func(alpha, rho, r_safe, theta)
    # For the CES form in q_func, dq/dr = alpha * (q / r)**rho.
    marginal_q = alpha * (q / r_safe) ** rho
    foc_residual = marginal_q * qbar**zeta - kappa * r_safe**eta
    consistency_residual = qbar - np.sum(q * dist)
    return [foc_residual, consistency_residual]
def mmt(alpha, rho, kappa, eta, zeta, r, qbar):
    """Model moments that `obj` compares against `data`.

    Returns ``[mean(prof), std(prof), mean(M), std(M)]``, where over the theta grid
      prof = qbar**zeta * q(r, theta)           (benefit term)
      M    = kappa * r**(1+eta) / (1+eta)        (cost term)
    ``std`` as defined in this notebook returns E[x**2] - E[x]**2 (the variance).
    ``r`` is floored at 0.001, the same as in `con`.
    """
    r_safe = np.maximum(0.001, r)
    q = q_func(alpha, rho, r_safe, theta)
    # Divide by (1 + eta), not (1 + kappa). Then dM/dr = kappa * r**eta, which
    # is exactly the marginal-cost term in the first constraint of `con`.
    # Dividing by (1 + kappa) made the cost level inconsistent with that
    # marginal cost.
    M = kappa * (r_safe**(1 + eta)) / (1 + eta)
    prof = (qbar**zeta) * q
    return [
        mean(prof),
        std(prof),
        mean(M),
        std(M)
    ]
def obj(alpha, rho, kappa, eta, zeta, r, qbar):
    """Moment-matching score for the simple model.

    Returns minus the sum of squared gaps between the model moments from `mmt`
    and the targets in `data`. The score is non-positive, and 0 means a
    perfect fit.
    """
    moments = mmt(alpha, rho, kappa, eta, zeta, r, qbar)
    gaps = np.array(moments) - data
    return -np.sum(gaps ** 2)

In [0]:
# Per-variable bounds passed to the solver as `vmin`. Only `eta` has a bound
# here; presumably it is a lower bound that keeps eta >= 0 (confirm in
# grad_constr_jax).
vmin = dict(eta=np.array([0.0]))

In [0]:
# Initial guess for the simple model. The keys match the keyword parameters of
# `con` / `obj`, so a solution can be scored directly with `obj(**var1)`.
# Scalar parameters are length-1 arrays; `r` has one entry per theta grid point.
var0 = dict(
    alpha=np.array([0.5]),
    rho=np.array([1.5]),
    kappa=np.array([0.01]),
    eta=np.array([1.0]),
    zeta=np.array([0.02]),
    r=0.1 * np.ones(N),
    qbar=np.array([1.0]),
)

In [0]:
# Solve the simple model from var0 with the bounds in vmin.
# max_iter caps the number of solver iterations (per the parameter name).
var1 = grad_constr_jax.constrained_gradient_descent(
    obj,
    con,
    var0,
    vmin=vmin,
    max_iter=1000,
    output=True,
)

In [0]:
# Display the solved simple-model variables.
var1

In [0]:
# Score at the solution: non-positive; closer to 0 means a better moment fit.
obj(**var1)