<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/MaxEntropy_ConstrainedOptimization_Broyden_Lagrange.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

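$N$ students take an exam scored out of 50 points.  With no constraint other than normalization, the most probable distribution is the uniform one, in which every grade is equally populated: $n_i = \frac{N}{n_{grades}}$ (e.g. $n_i = 2$ for $N = 100$ students and $n_{grades} = 50$ grades).  The number of combinations in which $N$ students can be distributed over $n_{grades}$ grades is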

$$\Omega = \dfrac{N!}{\prod\limits_{i=0}^{n_{grades}} n_i!}$$

Taking the logarithm of both sides, applying Stirling's approximation $\ln n! \approx n \ln n - n$, noting that $N = \sum\limits_{i=0}^{n_{grades}} n_i $, and defining the probabilities $p_i = n_i / N$ yields the Shannon entropy of the system (scaled by $N$):

$$\ln \Omega = N \ln N - \sum\limits_{i=0}^{n_{grades}}  n_i \ln n_i = -N\sum\limits_{i=0}^{n_{grades}}p_i \ln p_i$$




In [None]:
import jax
import jax.numpy as jnp
# The `jax.config` submodule (`from jax.config import config`) was removed in
# newer JAX releases; the configuration object is available as `jax.config`.
# The name `config` is kept bound for any later cell that refers to it.
config = jax.config
# Use 64-bit floats: the solver tolerances used below (1e-10 / 1e-12) are
# beyond float32 precision.
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'
eps=1e-12

In [None]:
def broyden3(func, x, J=None, tol=1e-10, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    """Solve ``func(x) = 0`` with a damped, box-bounded Broyden iteration.

    The inverse Jacobian is computed once from ``J`` (or from
    ``jax.jacobian(func)``) and then maintained with rank-one Broyden updates
    applied through the Sherman-Morrison formula. Each step is first shortened
    so that it does not cross ``xmin``/``xmax`` and then halved until the
    residual norm no longer increases. If no acceptable step longer than 1% of
    the quasi-Newton step is found, the Jacobian is re-evaluated at the
    current point.

    Args:
        func: Maps a 1-D array to a 1-D array of residuals of the same length.
        x: Initial guess (1-D array).
        J: Optional callable returning the Jacobian matrix at ``x``; defaults
            to ``jax.jacobian(func)``.
        tol: Iteration stops once every ``|f_i| < tol``.
        max_iter: Maximum number of outer iterations.
        verbose: ``> 0`` prints per-iteration progress, ``> 1`` also prints
            every line-search trial.
        xmax: Upper bound(s) on ``x`` (scalar or array broadcastable to ``x``).
        xmin: Lower bound(s) on ``x`` (scalar or array broadcastable to ``x``).

    Returns:
        Tuple ``(x, f)``: the last accepted point and ``func`` evaluated there.
        Convergence is not guaranteed; callers should inspect ``f``.
    """
    Jf = jax.jacobian(func) if J is None else J
    Jinv = jnp.linalg.inv(Jf(x))
    f = func(x)
    # True while Jinv is the exact inverse Jacobian at the current x.
    jac_is_fresh = True

    for i in range(max_iter):

        dx = - Jinv @ f
        if verbose>0:
            print(f"\nIter: {i}  dx: {dx}")

        # Largest fraction of the step that keeps every component in bounds.
        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        # Backtracking line search on the residual norm.
        accepted = False
        while alpha > 0.01:
            dx_try = alpha*dx
            xp = x + dx_try
            fp = func(xp)
            dnorm = jnp.linalg.norm(fp)-jnp.linalg.norm(f)
            if verbose>1:
                print(f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}")
            # `dnorm <= 0` is False for NaN, so a non-finite residual is
            # treated as a failed trial instead of being accepted.
            if dnorm <= 0:
                accepted = True
                break
            alpha *= 0.5

        if not accepted:
            if jac_is_fresh:
                # The exact Newton direction at this x already failed the line
                # search; re-evaluating the same Jacobian would only repeat
                # the identical failing step until max_iter.
                if verbose>0:
                    print("line search failed with a fresh Jacobian; stopping")
                break
            if verbose>0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(Jf(x))
            jac_is_fresh = True
            continue

        f_prev = f
        dx = dx_try
        f = fp
        x = xp
        if verbose>0:
            print(x, f)
        if jnp.all(jnp.abs(f)<tol):
            break

        # Broyden rank-one update J <- J + u v^T with
        #   u = (fp - f_prev) - J dx   and   v = dx / |dx|^2.
        # Since dx = -alpha * Jinv @ f_prev, J dx = -alpha * f_prev, hence
        #   u = fp - (1 - alpha) * f_prev,
        # which reduces to fp only for a full step (alpha == 1).
        u = jnp.expand_dims(fp - (1 - alpha) * f_prev, 1)
        v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
        jac_is_fresh = False
    return x, f

In [None]:
# Problem size. N matches the "N students" of the accompanying text; it is not
# referenced by the code visible in this notebook.
N = 100
n_grades = 50

# Grade values 0 .. n_grades-1 and a uniform initial probability over them.
grades = jnp.arange(n_grades)
p_guess = jnp.full((n_grades,), 1/n_grades)

# Box bounds applied to every probability during the solve.
p_min = jnp.full((n_grades,), 0.0001)
p_max = jnp.full((n_grades,), 0.9999)

def entropy(p):
    """Return sum(p_i * ln p_i) over the first ``n_grades`` entries of ``p``.

    Note the sign: this is the negative of the Shannon entropy, so minimizing
    it maximizes the entropy. Entries past ``n_grades`` are ignored.
    """
    probs = p[:n_grades]
    return jnp.sum(probs * jnp.log(probs))

def p_constraint(p):
    """Normalization residual: sum of the first ``n_grades`` entries minus 1."""
    total = jnp.sum(p[:n_grades])
    return total - 1.

def avg_constraint(p):
    """Mean residual: probability-weighted average of ``grades`` minus 25."""
    probs = p[:n_grades]
    mean = jnp.sum(probs * grades)
    return mean - 25.

def std_constraint(p):
    """Spread residual: standard deviation of the grade distribution minus 5.

    The standard deviation is computed about the distribution's own mean
    (the probability-weighted average of ``grades``), not about the target 25.
    """
    probs = p[:n_grades]
    mean = jnp.sum(probs * grades)
    variance = jnp.sum(probs * (grades - mean)**2)
    return jnp.sqrt(variance) - 5.

In [None]:
def L(p):
    """Lagrangian of the constrained entropy problem.

    ``p`` packs the ``n_grades`` probabilities followed by three Lagrange
    multipliers; the last three entries multiply, from the end backwards, the
    normalization, mean and standard-deviation constraints.
    """
    lam_std, lam_avg, lam_norm = p[-3], p[-2], p[-1]
    total = entropy(p)
    total = total + lam_norm * p_constraint(p)
    total = total + lam_avg * avg_constraint(p)
    total = total + lam_std * std_constraint(p)
    return total

# Gradient of the Lagrangian w.r.t. probabilities and multipliers, compiled
# with jit; its root is a stationary point of L.
dL = jax.jit(jax.grad(L, argnums=0))

In [None]:
# Starting point for the root solve: uniform probabilities followed by three
# Lagrange multipliers, each initialized to 1.
dL_guess = jnp.concatenate([p_guess, jnp.ones(3)])

In [None]:
# Find a root of the Lagrangian gradient. Probabilities are boxed to
# [p_min, p_max]; the three multipliers are left unbounded. The returned
# residual is discarded.
p, _ = broyden3(
    dL,
    dL_guess,
    tol=1e-12,
    max_iter=300,
    xmin=jnp.concatenate([p_min, jnp.full(3, -jnp.inf)]),
    xmax=jnp.concatenate([p_max, jnp.full(3, jnp.inf)]),
)

In [None]:
fig=make_subplots()
# p holds the n_grades probabilities followed by three Lagrange multipliers;
# plot only the probabilities so that y has the same length as x=grades.
fig.add_bar(x=grades,y=p[:n_grades])