<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/ipopt_colaboratory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install condacolab and use it to set up a conda (Mambaforge) environment in this
# Colab runtime. condacolab.install() restarts the kernel when it finishes, so run
# this cell on its own, before any other cell.
!pip install -q condacolab
import condacolab
condacolab.install()

⏬ Downloading https://github.com/jaimergp/miniforge/releases/latest/download/Mambaforge-colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:40
🔁 Restarting kernel...


In [1]:
# Install cyipopt (Python bindings for the IPOPT solver) from conda-forge.
# -y: don't stop at conda's "Proceed ([y]/n)?" prompt; -q: suppress the progress-spinner output.
!conda install -y -q -c conda-forge cyipopt

Collecting package metadata (current_repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | done
Solving environment: - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - cyipopt


The following packages will be downloaded:

    package       

In [9]:
# `from jax.config import config` no longer works in recent JAX releases; importing the
# config object from the top-level package works on both older and newer versions.
from jax import config

# Enable 64-bit floating point precision (JAX defaults to 32-bit).
config.update("jax_enable_x64", True)

# Run on the CPU instead of a GPU; this also mutes the warning JAX emits when no GPU/TPU is found.
config.update('jax_platform_name', 'cpu')

# scipy.optimize.minimize-style interface to the IPOPT solver (used in the solve cell).
from cyipopt import minimize_ipopt
# JAX transforms used to JIT-compile the problem and build its derivatives.
from jax import jit, grad, jacrev, jacfwd
import jax.numpy as jnp

from plotly.subplots import make_subplots
import plotly.io as pio
# Make plotly's dark theme the default template for every figure in this notebook.
pio.templates.default='plotly_dark'

In [18]:
n_grades = 100
p_guess = jnp.full(n_grades, 1/n_grades)
grades = jnp.arange(n_grades)

# Targets imposed on the grade distribution by the equality constraints below.
avg_target = 75.
std_target = 5.

# NOTE(review): p_min/p_max are not referenced in the cells shown here (the solver
# bounds are built separately as `bnds`); kept in case other code relies on them.
p_min = jnp.full(n_grades, -20.)
p_max = -p_min

def _mean_grade(p):
    """Return the expected grade ``sum(p[i] * grades[i])`` over the first ``n_grades`` entries of ``p``."""
    return jnp.sum(p[:n_grades]*grades)

def entropy(p):
    """Return the *negative* Shannon entropy ``sum(p * log(p))`` of ``p[:n_grades]``.

    Minimizing this value maximizes the entropy of the distribution. An entry of
    exactly 0 yields ``0 * log(0) = nan``.
    """
    return jnp.sum(p[:n_grades]*jnp.log(p[:n_grades]))

def p_constraint(p):
    """Equality-constraint residual: zero when the probabilities sum to 1."""
    return jnp.sum(p[:n_grades])-1.

def avg_constraint(p):
    """Equality-constraint residual: zero when the expected grade equals ``avg_target``."""
    return _mean_grade(p)-avg_target

def std_constraint(p):
    """Equality-constraint residual: zero when the grade standard deviation equals ``std_target``.

    The mean is ``sum(p * grades)`` without dividing by ``sum(p)``, so this is the true
    standard deviation only when ``p`` sums to 1 (which ``p_constraint`` enforces).
    """
    avg = _mean_grade(p)
    return jnp.sqrt(jnp.sum(p[:n_grades]*(grades - avg)**2)) - std_target

In [19]:

# JIT-compile the objective and the three equality-constraint residuals.
entropy_jit = jit(entropy)
p_constraint_jit = jit(p_constraint)
avg_constraint_jit = jit(avg_constraint)
std_constraint_jit = jit(std_constraint)

# Build the derivatives with JAX autodiff and JIT-compile them: gradient and Hessian
# of the objective, and the Jacobian (gradient) of each scalar constraint.
entropy_grad = jit(grad(entropy))
entropy_hess = jit(jacrev(jacfwd(entropy)))
p_constraint_jac = jit(jacfwd(p_constraint_jit))
avg_constraint_jac = jit(jacfwd(avg_constraint_jit))
std_constraint_jac = jit(jacfwd(std_constraint_jit))


def _multiplier_weighted(hess):
    """Wrap a scalar constraint's Hessian ``hess(x)`` as ``(x, v) -> hess(x) * v[0]``.

    cyipopt calls a constraint's ``'hess'`` entry with ``x`` and that constraint's
    Lagrange multiplier(s) ``v``. Every constraint here is scalar, so only ``v[0]`` is used.
    Despite the ``*_hessvp`` names below, the result is a multiplier-weighted Hessian
    matrix, not a Hessian-vector product.
    """
    return jit(lambda x, v: hess(x)*v[0])


p_constraint_hess = jacrev(jacfwd(p_constraint_jit))
p_constraint_hessvp = _multiplier_weighted(p_constraint_hess)

avg_constraint_hess = jacrev(jacfwd(avg_constraint_jit))
avg_constraint_hessvp = _multiplier_weighted(avg_constraint_hess)

std_constraint_hess = jacrev(jacfwd(std_constraint_jit))
std_constraint_hessvp = _multiplier_weighted(std_constraint_hess)

# Equality constraints (each residual must equal 0) in scipy-style dict format.
cons = [
    {'type': 'eq', 'fun': p_constraint_jit, 'jac': p_constraint_jac, 'hess': p_constraint_hessvp},
    {'type': 'eq', 'fun': avg_constraint_jit, 'jac': avg_constraint_jac, 'hess': avg_constraint_hessvp},
    {'type': 'eq', 'fun': std_constraint_jit, 'jac': std_constraint_jac, 'hess': std_constraint_hessvp}
]

# Initial guess: uniform distribution over all grades.
x0 = jnp.full(n_grades,1/n_grades)

# Variable bounds: 0 <= x[i] <= 1, since each x[i] is a probability.
bnds = [(0, 1)]*n_grades

# Minimize sum(p*log(p)) -- i.e. maximize entropy -- subject to the constraints above;
# 'disp' controls how verbose the solver output is.
res = minimize_ipopt(entropy_jit, jac=entropy_grad, hess=entropy_hess, x0=x0, bounds=bnds,
                     constraints=cons, options={'disp': 5})


In [20]:
# Bar chart of the optimal probability the solver assigned to each grade.
fig = make_subplots()
fig.add_bar(x=grades, y=res.x)
# update_layout returns the figure, so it is displayed as the cell's last expression.
fig.update_layout(
    width=800,
    title='Maximum-entropy grade distribution',
    xaxis_title='Grade',
    yaxis_title='Probability',
)