<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/ipopt_colaboratory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the condacolab helper with pip, then run its installer so that the
# `conda install` cell below can be used in this runtime.
!pip install -q condacolab
import condacolab
# NOTE(review): condacolab.install() presumably bootstraps conda and may restart
# the kernel — confirm against the condacolab documentation before relying on
# "Run All" working in a single pass.
condacolab.install()

[0m✨🍰✨ Everything looks OK!


In [2]:
# Install the cyipopt package (imported further down) from the conda-forge channel.
!conda install -c conda-forge cyipopt

Collecting package metadata (current_repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | done
Solving environment: - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - cyipopt


The following packages will be downloaded:

    package                    |            build
    ---------------------------|----

In [83]:
from jax.config import config
import jax.numpy as jnp

# Enable 64 bit floating point precision
config.update("jax_enable_x64", True)

# We use the CPU instead of GPU and mute all warnings if no GPU/TPU is found.
config.update('jax_platform_name', 'cpu')

import cyipopt
from cyipopt import minimize_ipopt
import jax
from jax import jit, grad, jacrev, jacfwd
import jax.numpy as jnp

from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'

from functools import partial

In [252]:
# Discrete grade scale: the integers 0 .. n_grades-1.
n_grades = 500
grades = jnp.arange(0, n_grades)

# Uniform probability over all grades, usable as a starting guess.
p_guess = jnp.full((n_grades,), 1/n_grades)


def entropy(p):
    """Return ``sum_i p_i * log(p_i)`` over the first ``n_grades`` entries of ``p``.

    Note the sign: this is the *negative* of the Shannon entropy, so
    minimising this value maximises the entropy of the distribution.
    """
    probs = p[:n_grades]
    log_probs = jnp.log(probs)
    return jnp.sum(probs * log_probs)

def p_constraint(p):
    """Normalisation residual: sum of the first ``n_grades`` entries of ``p`` minus one.

    The value is zero exactly when those probabilities sum to 1.
    """
    total_probability = jnp.sum(p[:n_grades])
    return total_probability - 1.

def avg_constraint(p, target=250.):
    """Mean-grade residual: expected grade under ``p`` minus ``target``.

    Args:
        p: Probability vector; only the first ``n_grades`` entries are used.
        target: Required mean grade. Defaults to 250, the value that was
            previously hard-coded, so one-argument callers (``jit``,
            ``jacfwd``, the constraint list) behave exactly as before.

    Returns:
        ``sum_i p_i * grades_i - target``; zero when the mean equals ``target``.
    """
    mean_grade = jnp.sum(p[:n_grades]*grades)
    return mean_grade - target

def std_constraint(p, target=50.):
    """Standard-deviation residual: std of the grades under ``p`` minus ``target``.

    Args:
        p: Probability vector; only the first ``n_grades`` entries are used.
        target: Required standard deviation. Defaults to 50, the value that
            was previously hard-coded, so one-argument callers are unaffected.

    Returns:
        ``sqrt(sum_i p_i * (grades_i - mean)**2) - target`` where ``mean`` is
        the probability-weighted mean grade; zero when the std equals ``target``.
    """
    probs = p[:n_grades]
    avg = jnp.sum(probs*grades)
    variance = jnp.sum(probs*(grades - avg)**2)
    return jnp.sqrt(variance) - target

In [260]:

def _scaled_hessian(hess_fn):
    """Wrap a scalar constraint's Hessian as ``(x, v) -> hess_fn(x) * v[0]``, jitted.

    ``v`` is presumably the multiplier vector that ``minimize_ipopt`` passes to
    a constraint's ``hess`` callback; each constraint here is scalar, so only
    ``v[0]`` is used.
    """
    return jit(lambda x, v: hess_fn(x)*v[0])


# Objective: compiled value, gradient and dense Hessian.
entropy_jit = jit(entropy)
entropy_grad = jit(grad(entropy))
entropy_hess = jit(jacrev(jacfwd(entropy)))

# Normalisation constraint: value, Jacobian, Hessian and scaled Hessian.
p_constraint_jit = jit(p_constraint)
p_constraint_jac = jit(jacfwd(p_constraint_jit))
p_constraint_hess = jacrev(jacfwd(p_constraint_jit))
p_constraint_hessvp = _scaled_hessian(p_constraint_hess)

# Mean constraint.
avg_constraint_jit = jit(avg_constraint)
avg_constraint_jac = jit(jacfwd(avg_constraint_jit))
avg_constraint_hess = jacrev(jacfwd(avg_constraint_jit))
avg_constraint_hessvp = _scaled_hessian(avg_constraint_hess)

# Standard-deviation constraint.
std_constraint_jit = jit(std_constraint)
std_constraint_jac = jit(jacfwd(std_constraint_jit))
std_constraint_hess = jacrev(jacfwd(std_constraint_jit))
std_constraint_hessvp = _scaled_hessian(std_constraint_hess)

# All three are equality constraints, in the dict format passed to minimize_ipopt.
cons = [
    {'type': 'eq', 'fun': fun, 'jac': jac, 'hess': hessvp}
    for fun, jac, hessvp in (
        (p_constraint_jit, p_constraint_jac, p_constraint_hessvp),
        (avg_constraint_jit, avg_constraint_jac, avg_constraint_hessvp),
        (std_constraint_jit, std_constraint_jac, std_constraint_hessvp),
    )
]

In [261]:
# Start from the uniform distribution and keep every probability inside [0, 1].
x0 = jnp.full(n_grades, 1/n_grades)
bnds = [(0, 1) for _ in range(n_grades)]

# Minimise sum(p*log(p)) subject to the three equality constraints in `cons`.
res = minimize_ipopt(
    entropy_jit,
    x0=x0,
    jac=entropy_grad,
    hess=entropy_hess,
    bounds=bnds,
    constraints=cons,
    options={'disp': 5},
)


In [262]:
# Bar chart of the optimised probability for each grade (minimize_ipopt result).
# A title and axis labels are set so the figure can be read on its own.
fig = make_subplots()
fig.add_bar(x=grades, y=res.x)
fig.update_layout(
    width=800,
    title_text="Maximum-entropy grade distribution (minimize_ipopt)",
    xaxis_title="Grade",
    yaxis_title="Probability",
)

In [253]:
class Problem():
    """Adapter that exposes a JAX objective and constraint functions through
    the callback methods used when an instance is passed as ``problem_obj``
    to ``cyipopt.Problem``.

    Derivatives (objective gradient, constraint Jacobian, Hessian of the
    Lagrangian) are generated with JAX automatic differentiation and jitted.

    Attributes set by ``__init__``:
        x0: Initial point; also used to size the constraint vector and to
            probe the Hessian sparsity pattern.
        obj / obj_jit / obj_grad: Objective, its jitted form, and its jitted gradient.
        constraints_functions: The sequence of constraint callables.
        constraint_jac: Jitted Jacobian of the stacked constraint vector.
        hess: Jitted dense Hessian (w.r.t. ``x``) of the Lagrangian ``L``.
        n_constraints: Length of the stacked constraint vector at ``x0``.
        row, col: Indices of the Hessian entries reported to the solver.
    """

    def __init__(self,obj,constraints,x0):
        """Build all jitted callbacks.

        Args:
            obj: Scalar objective ``obj(x)``, differentiable by JAX.
            constraints: Sequence of callables ``f(x)`` returning a scalar or
                1-D array. May be empty, in which case the stacked constraint
                vector has length zero.
            x0: Initial point (1-D array).
        """
        self.x0=x0
        self.obj=obj
        self.obj_jit=jax.jit(obj)
        self.obj_grad = jax.jit(jax.grad(obj))
        self.constraints_functions=constraints
        self.constraint_jac = jax.jit(jax.jacobian(self.combined_constraints))
        # Differentiates L with respect to its first argument (x) only.
        self.hess = jax.jit(jax.jacrev(jax.jacfwd(self.L)))
        self.n_constraints = self.combined_constraints(self.x0).size
        # Must come last: hessianstructure() needs self.hess and self.n_constraints.
        self.row, self.col = self.hessianstructure()


    def objective(self, x):
        """Return the objective value at ``x``."""
        return self.obj_jit(x)

    def gradient(self, x):
        """Return the gradient of the objective at ``x``."""
        return self.obj_grad(x)

    # `self` is declared static so jit treats the instance as a hashable
    # compile-time constant rather than trying to trace it.
    @partial(jax.jit, static_argnums=(0,))
    def combined_constraints(self, x):
        """Stack every constraint function's output into one 1-D array.

        With no constraint functions this returns a zero-length array instead
        of letting ``jnp.concatenate`` raise on an empty list.
        """
        if not self.constraints_functions:
            return jnp.zeros((0,), dtype=jnp.asarray(x).dtype)
        return jnp.concatenate([jnp.atleast_1d(f(x)) for f in self.constraints_functions])

    def constraints(self, x):
        """Return the stacked constraint values at ``x``."""
        return self.combined_constraints(x)

    def jacobian(self, x):
        """Return the Jacobian of the stacked constraints at ``x``."""
        return self.constraint_jac(x)



    @partial(jax.jit, static_argnums=(0,))
    def L(self, x, lagrange, obj_factor):
        """Lagrangian: ``obj_factor * obj(x) + sum(lagrange * constraints(x))``."""
        return obj_factor*self.obj(x) + jnp.sum(lagrange*self.combined_constraints(x))

    def hessian(self, x, lagrange, obj_factor):
        """Return the Lagrangian Hessian entries at the (row, col) positions
        reported by ``hessianstructure``."""
        return self.hess(x,lagrange,obj_factor)[self.row, self.col]

    def hessianstructure(self):
        """Return (rows, cols) of the non-zero lower-triangular Hessian entries.

        The pattern is probed numerically at ``x0`` with every multiplier and
        ``obj_factor`` set to 1.
        NOTE(review): an entry that happens to be exactly zero at ``x0`` but is
        non-zero elsewhere would be missing from the pattern — confirm this
        cannot occur for the problem being solved.
        """
        return jnp.nonzero(jnp.tril(self.hess(self.x0,jnp.ones(self.n_constraints), 1.)))

    def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
                     d_norm, regularization_size, alpha_du, alpha_pr,
                     ls_trials):
        """Prints information at every Ipopt iteration."""

        msg = "Objective value at iteration #{:d} is - {:g}"

        print(msg.format(iter_count, obj_value))


In [254]:
# Wrap the objective and the three equality constraints, starting from a
# uniform distribution over the grades.
problem = Problem(
    obj=entropy,
    constraints=[p_constraint, avg_constraint, std_constraint],
    x0=jnp.full(n_grades, 1/n_grades),
)

In [257]:
# Uniform initial guess; its length fixes the number of decision variables.
x0 = jnp.full(n_grades, 1/n_grades)

nlp = cyipopt.Problem(
    n=len(x0),
    m=problem.n_constraints,
    problem_obj=problem,
    # Variable bounds: 0 <= p_i <= 1.
    lb=len(x0) * [0.],
    ub=len(x0) * [1.],
    # Identical lower and upper limits of zero make every constraint an equality.
    cl=problem.n_constraints * [0.],
    cu=problem.n_constraints * [0.],
)

In [258]:
# Run the solver from the uniform starting point; the plotting cell below uses
# res[0] as the solution vector.
res = nlp.solve(x0)

Objective value at iteration #0 is - -23.0258
Objective value at iteration #1 is - -20.6738
Objective value at iteration #2 is - -17.2309
Objective value at iteration #3 is - -13.2973
Objective value at iteration #4 is - -10.2909
Objective value at iteration #5 is - -8.1091
Objective value at iteration #6 is - -6.60414
Objective value at iteration #7 is - -5.22766
Objective value at iteration #8 is - -4.21721
Objective value at iteration #9 is - -3.51208
Objective value at iteration #10 is - -2.71272
Objective value at iteration #11 is - -2.67005
Objective value at iteration #12 is - -2.42972
Objective value at iteration #13 is - -1.28043
Objective value at iteration #14 is - -1.28043
Objective value at iteration #15 is - -1.35322
Objective value at iteration #16 is - -1.35202
Objective value at iteration #16 is - -1.35202
Objective value at iteration #17 is - -1.3032
Objective value at iteration #18 is - -0.957784
Objective value at iteration #19 is - -0.711544
Objective value at iter

In [259]:
# Bar chart of the optimised probability for each grade (cyipopt.Problem result).
# A title and axis labels are set so the figure can be read on its own.
fig = make_subplots()
fig.add_bar(x=grades, y=res[0])
fig.update_layout(
    width=800,
    title_text="Maximum-entropy grade distribution (cyipopt.Problem)",
    xaxis_title="Grade",
    yaxis_title="Probability",
)