<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/MaxEntropy_ConstrainedOptimization_Broyden_Lagrange.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

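N students take an exam whose possible scores are the $n_{grades}$ integer grades $0, 1, \dots, n_{grades}-1$ ($n_{grades} = 40$ in the code below). Without any constraints, the most probable distribution is the uniform one, $n_i = \frac{N}{n_{grades}}$. The number of ways (microstates) in which N students can be distributed over the $n_{grades}$ grades with occupancies $n_i$ is the multinomial coefficient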

$$\Omega = \dfrac{N!}{\prod\limits_{i=0}^{n_{grades}-1} n_i!}$$

Taking the logarithm of both sides, applying Stirling's approximation $\ln n! \approx n \ln n - n$, and using $N = \sum\limits_{i=0}^{n_{grades}-1} n_i$ gives $N$ times the Shannon entropy of the grade distribution $p_i = n_i/N$:

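$$\ln \Omega \approx N \ln N - \sum\limits_{i=0}^{n_{grades}-1} n_i \ln n_i = -N\sum\limits_{i=0}^{n_{grades}-1} p_i \ln p_i$$

The code below seeks the distribution that maximizes this entropy subject to normalization ($\sum_i p_i = 1$), a mean grade of 20 and a standard deviation of 5. The probabilities are parameterized as $p_i = \mathrm{expit}(z_i)$ so that they stay in $(0, 1)$. A bounded Broyden quasi-Newton solver then finds the stationary point of the Lagrangian in the logits $z$ and the Lagrange multipliers $\lambda$.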




In [53]:
import jax
import jax.numpy as jnp
from scipy.optimize import root
from jax.config import config
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'
eps=1e-12

In [99]:
def broyden3(func, x, J=None, tol=1e-10, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    """Solve ``func(x) = 0`` with a bounded, damped "good" Broyden method.

    The inverse Jacobian is computed once from ``J`` (or ``jax.jacobian(func)``). After
    that it is refined with rank-one Broyden updates applied through the
    Sherman-Morrison formula. Each quasi-Newton step is first shortened so that no
    component of ``x`` crosses ``xmin``/``xmax`` (assumes the current ``x`` is inside
    the bounds). The step is then halved until ``||func(x)||`` does not increase.
    If the step fraction ``alpha`` falls to 0.01 or below, the step is rejected and
    the inverse Jacobian is recomputed from scratch at the current ``x``.

    Args:
        func: Callable mapping a 1-D array to a residual array of the same length.
        x: Initial guess (1-D array).
        J: Optional callable returning the Jacobian matrix of ``func`` at a point.
            Defaults to ``jax.jacobian(func)``. It is used for the initial inverse
            and for re-initialisation after a rejected step.
        tol: Convergence threshold: stop once every ``|func(x)|`` component is below it.
        max_iter: Maximum number of iterations.
        verbose: 0 is silent, >0 prints per-iteration output, >1 also prints
            line-search details.
        xmax, xmin: Upper/lower bounds on ``x`` (scalars or arrays broadcastable to ``x``).

    Returns:
        ``(x, f, i)``: the final iterate, ``func`` evaluated there, and the index of
        the last iteration executed (0 if ``max_iter == 0``). ``i`` is
        ``max_iter - 1`` both when convergence happens on the last iteration and
        when the limit is hit, so check ``f`` against ``tol`` to confirm convergence.
    """
    Jf = jax.jacobian(func) if J is None else J
    Jinv = jnp.linalg.inv(Jf(x))
    f = func(x)
    i = 0  # keeps the return value defined when max_iter == 0

    for i in range(max_iter):

        dx = -Jinv @ f
        if verbose > 0:
            print(f"\nIter: {i}  dx: {dx}")

        # Largest fraction of the full step that keeps every component inside [xmin, xmax].
        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        # Backtracking line search: halve alpha until the residual norm does not grow.
        while alpha > 0.01:
            dx_try = alpha * dx
            xp = x + dx_try
            fp = func(xp)
            dnorm = jnp.linalg.norm(fp) - jnp.linalg.norm(f)
            if verbose > 1:
                print(f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}")
            if dnorm > 0:
                alpha *= 0.5
            else:
                break
        if alpha <= 0.01:
            # Step rejected: rebuild the inverse Jacobian at the current point and retry.
            if verbose > 0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(Jf(x))
            continue

        f_old = f
        dx = dx_try
        f = fp
        x = xp
        if verbose > 0:
            print(x, f)
        if jnp.all(jnp.abs(f) < tol):
            break

        # Good-Broyden rank-one update J_new = J + u v^T with u = df - J dx and
        # v = dx / ||dx||^2, applied to Jinv via Sherman-Morrison.
        # Since dx = alpha * (-Jinv @ f_old), J @ dx = -alpha * f_old, so
        # u = fp - (1 - alpha) * f_old. This is just fp for a full step (alpha == 1),
        # but damped steps need the correction or the secant condition is violated.
        u = jnp.expand_dims(fp - (1 - alpha) * f_old, 1)
        v = jnp.expand_dims(dx, 1) / jnp.linalg.norm(dx) ** 2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  # Sherman-Morrison
    return x, f, i

In [134]:
n_grades = 40                               # number of distinct grades: 0, 1, ..., n_grades - 1
grades = jnp.arange(n_grades)
p_guess = jnp.full(n_grades, 1/n_grades)    # start from the uniform distribution

# Box bounds for the solver's first n_grades unknowns. The solver works with
# logit(p) rather than p, so +/-20 limits the logits, not the probabilities.
p_min = jnp.full(n_grades, -20.)
p_max = -p_min

def entropy(p):
    """Return the *negative* Shannon entropy ``sum_i p_i ln p_i`` of ``p[:n_grades]``.

    Entries at index ``n_grades`` and beyond, if present, are ignored. ``xlogy``
    applies the limit ``0 ln 0 = 0``, so zero probabilities contribute 0 rather than
    the ``nan`` that ``p * log(p)`` would produce. For ``p > 0`` the value and
    gradient (``ln p + 1``) are the same as for ``p * log(p)``.
    """
    probs = p[:n_grades]
    return jnp.sum(jax.scipy.special.xlogy(probs, probs))

def p_constraint(p):
    """Normalization residual: sum of ``p[:n_grades]`` minus 1 (zero when normalized)."""
    total = jnp.sum(p[:n_grades])
    return total - 1.

def avg_constraint(p, target=20.):
    """Mean-grade residual: ``sum_i p_i * grades_i - target``.

    Args:
        p: Array whose first ``n_grades`` entries are grade probabilities (assumed
            normalized, since the sum is not divided by ``sum(p)``).
        target: Required expected grade. Defaults to 20, the value used in this notebook.

    Returns:
        Scalar residual that is zero when the expected grade equals ``target``.
    """
    mean = jnp.sum(p[:n_grades] * grades)
    return mean - target

def std_constraint(p, target=5.):
    """Standard-deviation residual: std of the grade distribution minus ``target``.

    Args:
        p: Array whose first ``n_grades`` entries are grade probabilities (assumed
            normalized; the mean and variance are plain probability-weighted sums).
        target: Required standard deviation. Defaults to 5, the value used in this notebook.

    Returns:
        Scalar residual that is zero when the standard deviation equals ``target``.
    """
    probs = p[:n_grades]
    avg = jnp.sum(probs * grades)
    std = jnp.sqrt(jnp.sum(probs * (grades - avg) ** 2))
    return std - target

In [135]:
def L(p):
    """Lagrangian of the constrained max-entropy problem over one flat vector.

    ``p[:n_grades]`` holds the logits of the grade probabilities, which are mapped
    through ``expit``. The last three entries are the Lagrange multipliers for the
    std, mean and normalization constraints, in that order. Because the derivative
    of ``L`` with respect to each multiplier is that constraint's residual, zeros of
    ``dL`` satisfy all three constraints.
    """
    probs = jax.scipy.special.expit(p[:n_grades])
    lam_std, lam_avg, lam_norm = p[-3], p[-2], p[-1]
    objective = entropy(probs)
    objective = objective + lam_norm * p_constraint(probs)
    objective = objective + lam_avg * avg_constraint(probs)
    objective = objective + lam_std * std_constraint(probs)
    return objective


dL = jax.jit(jax.grad(L))  # gradient w.r.t. the full vector (logits and multipliers)

In [136]:
# Solver start: logits of the uniform guess, then the three Lagrange multipliers, all set to 1.
dL_guess = jnp.concatenate([
    jax.scipy.special.logit(p_guess),
    jnp.full(3, 1.),
])

In [137]:
# The logits are boxed to [p_min, p_max]; the three multipliers are unbounded.
p, f, i = broyden3(
    dL,
    dL_guess,
    tol=1e-10,
    xmin=jnp.concatenate([p_min, jnp.full(3, -jnp.inf)]),
    xmax=jnp.concatenate([p_max, jnp.full(3, jnp.inf)]),
    max_iter=500,
)
print(i)

52


In [138]:
pp = jax.scipy.special.expit(p[:n_grades])  # map the solved logits back to probabilities
pp

DeviceArray([2.69832735e-05, 5.88128400e-05, 1.23167877e-04,
             2.47839515e-04, 4.79171998e-04, 8.90143652e-04,
             1.58882679e-03, 2.72483870e-03, 4.49006710e-03,
             7.10906771e-03, 1.08148440e-02, 1.58079536e-02,
             2.22013220e-02, 2.99591720e-02, 3.88444132e-02,
             4.83921716e-02, 5.79254539e-02, 6.66210644e-02,
             7.36209601e-02, 7.81698312e-02, 7.97488923e-02,
             7.81732076e-02, 7.36273200e-02, 6.66296975e-02,
             5.79354624e-02, 4.84026235e-02, 3.88544811e-02,
             2.99682314e-02, 2.22089947e-02, 1.58140998e-02,
             1.08195161e-02, 7.11244611e-03, 4.49239492e-03,
             2.72636911e-03, 1.58978783e-03, 8.90720543e-04,
             4.79503253e-04, 2.48021561e-04, 1.23263669e-04,
             5.88611253e-05], dtype=float64)

In [139]:
# Bar chart of the grade distribution found by the solver.
fig = make_subplots()
fig.add_bar(x=grades, y=pp)
fig.update_layout(
    title="Maximum-entropy grade distribution",
    xaxis_title="Grade",
    yaxis_title="Probability",
)

In [140]:
std_constraint(pp)  # std(pp) - 5; should be ~0 if the solver converged

DeviceArray(-3.70619091e-11, dtype=float64)

In [141]:
p_constraint(pp)  # sum(pp) - 1; should be ~0 if the solver converged

DeviceArray(-1.26698652e-12, dtype=float64)

In [142]:
avg_constraint(pp)  # mean grade - 20; should be ~0 if the solver converged

DeviceArray(-4.78479478e-11, dtype=float64)