<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/MaxEntropy_ConstrainedOptimization_Broyden_Lagrange.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

N students take an exam scored out of 50 points.  The most probable distribution would be a uniform one where $n_i = \frac{N}{n_{grades}} = 50$.  The number of combinations in which N students can be distributed over $n_{grades}$ grades is

$$\Omega = \dfrac{N!}{\prod\limits_{i=0}^{n_{grades}} n_i}$$

Taking the logarithm of both sides then applying Stirling's approximation $\ln n! = n \ln n - n$ and noting that $N = \sum\limits_{i=0}^{n_{grades}} n_i $ yields the Shannon entropy of the system:

$$\ln \Omega = N \ln N - \sum\limits_{i=0}^{n_{grades}}  n_i \ln n_i = -N\sum\limits_{i=0}^{n_{grades}}p_i \ln p_i$$




In [53]:
import jax
import jax.numpy as jnp
from scipy.optimize import root
from jax.config import config
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'
eps=1e-12

In [99]:
def broyden3(func, x, J=None, tol=1e-10, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    Jf = jax.jacobian(func) if J is None else J
    J = Jf(x)
    Jinv = jnp.linalg.inv(J)
    f = func(x)

    for i in range(max_iter):

        dx = - Jinv @ f
        if verbose>0:
            print(f"\nIter: {i}  dx: {dx}")

        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        while alpha > 0.01:
            dx_try = alpha*dx
            xp = x + dx_try
            fp = func(xp)
            dnorm = jnp.linalg.norm(fp)-jnp.linalg.norm(f)
            if verbose>1:
                print(f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}")
            if dnorm > 0:
                alpha *= 0.5
            else:
                break
        if alpha <= 0.01:
            if verbose>0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(Jf(x))
            continue

        dx=dx_try
        f= fp
        x= xp
        if verbose>0:
          print(x, f)
        if jnp.all(jnp.abs(f)<tol):
          break

        u = jnp.expand_dims(fp,1)
        v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
    return x, f, i

In [143]:
n_grades = 50
p_guess = jnp.full(n_grades, 1/n_grades)
grades = jnp.arange(n_grades)

p_min = jnp.full(n_grades, -20.)
p_max = -p_min

def entropy(p):
    return jnp.sum(p[:n_grades]*jnp.log(p[:n_grades]))

def p_constraint(p):
    return jnp.sum(p[:n_grades])-1.

def avg_constraint(p):
    return jnp.sum(p[:n_grades]*grades)-25.

def std_constraint(p):
    avg = jnp.sum(p[:n_grades]*grades)
    std= jnp.sqrt(jnp.sum(p[:n_grades]*(grades - avg)**2 )) - 5.
    return std

In [144]:
def L(p):
    pp = jax.scipy.special.expit(p[:n_grades])
    return entropy(pp)+p[-1]*p_constraint(pp)+p[-2]*avg_constraint(pp)+p[-3]*std_constraint(pp)

dL = jax.jit(jax.grad(L))

In [145]:
dL_guess = jnp.concatenate([jax.scipy.special.logit(p_guess), jnp.full(3,1.)])

In [146]:
p,f, i =broyden3(dL, dL_guess, tol=1e-10, xmin=jnp.concatenate([p_min, jnp.full(3,-jnp.inf)]),
              xmax=jnp.concatenate([p_max, jnp.full(3,jnp.inf)]), max_iter=500)
print(i)

67


In [147]:
pp = jax.scipy.special.expit(p[:n_grades])
pp

DeviceArray([2.97398593e-07, 7.92393859e-07, 2.02848386e-06,
             4.98919509e-06, 1.17901115e-05, 2.67691042e-05,
             5.83953494e-05, 1.22391479e-04, 2.46463493e-04,
             4.76850801e-04, 8.86422824e-04, 1.58317106e-03,
             2.71670981e-03, 4.47906286e-03, 7.09511702e-03,
             1.07984262e-02, 1.57902813e-02, 2.21844039e-02,
             2.99456787e-02, 3.88373021e-02, 4.83941031e-02,
             5.79381114e-02, 6.66445625e-02, 7.36535367e-02,
             7.82079655e-02, 7.97878695e-02, 7.82080120e-02,
             7.36536244e-02, 6.66446814e-02, 5.79382492e-02,
             4.83942470e-02, 3.88374407e-02, 2.99458033e-02,
             2.21845095e-02, 1.57903658e-02, 1.07984905e-02,
             7.09516344e-03, 4.47909483e-03, 2.71673082e-03,
             1.58318424e-03, 8.86430733e-04, 4.76855339e-04,
             2.46465986e-04, 1.22392789e-04, 5.83960093e-05,
             2.67694227e-05, 1.17902588e-05, 4.98926026e-06,
             2.02851296e

In [148]:
fig=make_subplots()
fig.add_bar(x=grades,y=pp)

In [149]:
std_constraint(pp)

DeviceArray(2.48920884e-11, dtype=float64)

In [150]:
p_constraint(pp)

DeviceArray(4.82280882e-13, dtype=float64)

In [151]:
avg_constraint(pp)

DeviceArray(2.85105273e-11, dtype=float64)