<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/max_entropy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

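$N$ students take an exam scored from 0 to 50 points, so each student receives one of $n_{grades} = 51$ integer grades. Let $n_i$ be the number of students who score $i$. Without further constraints, the most probable distribution is the uniform one, $n_i = \frac{N}{n_{grades}}$, which is about 2 students per grade for $N = 100$. The number of ways to distribute $N$ students over the grades with occupation numbers $n_i$ is the multinomial coefficient

$$\Omega = \dfrac{N!}{\prod\limits_{i=0}^{n_{grades}-1} n_i!}$$

Taking the logarithm of both sides, applying Stirling's approximation $\ln n! \approx n \ln n - n$, and using $N = \sum\limits_{i=0}^{n_{grades}-1} n_i$ yields the Shannon entropy of the system:

$$\ln \Omega \approx N \ln N - \sum\limits_{i=0}^{n_{grades}-1}  n_i \ln n_i = -N\sum\limits_{i=0}^{n_{grades}-1}p_i \ln p_i$$

Here $p_i = n_i/N$ is the fraction of students with grade $i$. Maximizing $\ln \Omega$ is therefore the same as maximizing the entropy $-\sum_i p_i \ln p_i$. The cells below do this numerically. They minimize $\sum_i p_i \ln p_i$ with SLSQP subject to three equality constraints: the $p_i$ sum to 1, the class average is 25 points, and the standard deviation is 5 points.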
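The counting argument and the Stirling step can both be checked numerically. The plain-Python sketch below does this on small, hypothetical classes; the class sizes and counts are illustrative choices, not data from the notebook.

- **Uniform split:** brute force over every way to split 6 students among 3 grades confirms that the uniform split maximizes $\Omega$.
- **Stirling step:** for larger classes, the sketch compares the exact $\ln\Omega$ with the entropy form $-N\sum_i p_i \ln p_i$. The exact value is computed with `math.lgamma`, since $\ln n! = \ln\Gamma(n+1)$.

```python
from itertools import product
from math import factorial, lgamma, log, prod


def multiplicity(counts):
    """Omega = N! / prod_i n_i! for the occupation numbers n_i in counts."""
    return factorial(sum(counts)) // prod(factorial(n) for n in counts)


# Brute force over every way to split 6 students among 3 grades.
N, n_grades = 6, 3
splits = [c for c in product(range(N + 1), repeat=n_grades) if sum(c) == N]
best = max(splits, key=multiplicity)
print(best, multiplicity(best))  # (2, 2, 2) 90


def log_omega_exact(counts):
    """ln Omega via ln n! = lgamma(n + 1), without Stirling's approximation."""
    return lgamma(sum(counts) + 1) - sum(lgamma(n + 1) for n in counts)


def log_omega_entropy(counts):
    """Stirling form: -N * sum_i p_i ln p_i with p_i = n_i / N."""
    total = sum(counts)
    return -total * sum((n / total) * log(n / total) for n in counts if n > 0)


# The relative error of the entropy form shrinks as the class grows.
for scale in (1, 10, 100):
    counts = [100 * scale, 200 * scale, 300 * scale, 400 * scale]
    exact, approx = log_omega_exact(counts), log_omega_entropy(counts)
    print(sum(counts), exact, approx, abs(approx - exact) / exact)
```

The dropped Stirling terms grow only like $\ln N$, while $\ln\Omega$ grows like $N$. That is why the entropy form becomes accurate for large classes.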

In [16]:
from scipy.optimize import minimize
from plotly.subplots import make_subplots
import jax
import jax.numpy as jnp

# Use 64-bit floats so the tight SLSQP tolerance below is meaningful
jax.config.update("jax_enable_x64", True)

In [17]:
def entropy(p):
    # Returns sum(p ln p), the *negative* Shannon entropy;
    # minimizing it maximizes the entropy.
    return jnp.sum(p*jnp.log(p))
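The `entropy` cell returns $\sum_i p_i \ln p_i$, the negative of the Shannon entropy, which is why it is minimized. The bounds passed to SLSQP allow $p_i = 0$. At that point `jnp.log` gives `-inf` and `0 * -inf` is `nan`, so a version using the usual convention $0 \ln 0 = 0$ may be more robust.

Below is a minimal NumPy sketch; `neg_entropy` is a hypothetical name, not part of the notebook. In JAX, `jax.scipy.special.xlogy(p, p)` computes the same elementwise terms. The sketch also shows that the uniform distribution over 51 grades reaches the largest possible entropy, $\ln 51 \approx 3.93$ nats.

```python
import numpy as np


def neg_entropy(p):
    """Return sum_i p_i ln p_i using the convention 0 ln 0 = 0 (hypothetical helper)."""
    p = np.asarray(p, dtype=float)
    # Replace zeros by 1 inside the log so those terms become 0 * ln 1 = 0, not nan.
    safe = np.where(p > 0, p, 1.0)
    return float(np.sum(p * np.log(safe)))


n_grades = 51
uniform = np.full(n_grades, 1 / n_grades)
certain = np.zeros(n_grades)
certain[25] = 1.0  # every student gets the same grade

print(neg_entropy(uniform), -np.log(n_grades))  # equal up to rounding: maximal entropy
print(neg_entropy(certain))                      # 0.0: no uncertainty at all
```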

In [75]:
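n_grades=51      # integer grades 0..50
n_students=100   # N, the class size
class_avg=25.    # target class average

# Uniform starting guess, p_i = 1/n_grades
p_guess=jnp.full(n_grades, 1/n_grades)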


In [76]:
def tot(p):
    # Normalization: the probabilities must sum to 1
    return jnp.sum(p)-1.

In [77]:
def avg(p):
    # Mean grade sum_i i*p_i must equal class_avg
    return jnp.sum(p*jnp.arange(n_grades)) - class_avg


In [78]:
def std(p):
    # Standard deviation of the grades must equal 5 points
    mean=jnp.sum(p*jnp.arange(n_grades))
    return jnp.sqrt(jnp.sum(p*(jnp.arange(n_grades)-mean)**2)) - 5.
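Evaluating the three constraint residuals on the uniform starting guess shows which constraint the optimizer actually has to work on. Below is a NumPy sketch mirroring `tot`, `avg` and `std`. It is reimplemented without JAX, and `class_std` is a name introduced here for the 5 points hard-coded in `std`.

A uniform distribution over 0 to 50 already has mean 25. Its standard deviation, however, is $\sqrt{(51^2-1)/12} \approx 14.7$ points, so only the `std` residual starts far from zero.

```python
import numpy as np

n_grades, class_avg = 51, 25.0   # values used in the notebook
class_std = 5.0                  # hypothetical name for the 5 hard-coded in std()
grades = np.arange(n_grades)
p = np.full(n_grades, 1 / n_grades)             # uniform starting guess, like p_guess

mean = np.sum(p * grades)
std = np.sqrt(np.sum(p * (grades - mean) ** 2))

# Residuals of the three equality constraints (tot, avg, std in the notebook).
residuals = {"tot": np.sum(p) - 1.0, "avg": mean - class_avg, "std": std - class_std}
print(residuals)
print(std, np.sqrt((n_grades**2 - 1) / 12))      # std of a discrete uniform over 0..n-1
```

The optimizer therefore has to concentrate probability near the middle grades to bring the spread down from about 14.7 to 5 points.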

In [79]:
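# Equality constraints: normalization, class average, standard deviation.
# JAX supplies exact gradients for each constraint and for the objective.
constraints=[dict(type='eq', fun=tot, jac=jax.jacobian(tot)),
             dict(type='eq', fun=avg, jac=jax.jacobian(avg)),
             dict(type='eq', fun=std, jac=jax.jacobian(std))]

# Minimizing sum(p ln p) maximizes the entropy subject to the constraints
res=minimize(entropy, p_guess, method='SLSQP', bounds=[(0., 1.)]*n_grades,
             jac=jax.jacobian(entropy), constraints=constraints,
             options=dict(maxiter=1000), tol=1e-10)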
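The same problem has a closed-form answer that can cross-check SLSQP. The argument runs as follows:

- **Linear constraints:** once the mean is fixed at 25, the standard-deviation constraint is equivalent to the linear condition $\sum_i i^2 p_i = 25^2 + 5^2$.
- **Lagrange multipliers:** setting the gradient of $-\sum_i p_i \ln p_i$ minus the constraint terms to zero gives $p_i \propto e^{-\lambda_1 i - \lambda_2 i^2}$.
- **Symmetry:** the grades 0 to 50 and the constraints are symmetric about 25, so the maximizer is a discretized Gaussian centered at 25.

The width $s$ of that Gaussian is only approximately 5, because the grades are discrete and truncated at 0 and 50. The sketch below solves for it with `scipy.optimize.brentq`. It then compares the resulting entropy with the continuous Gaussian value $\tfrac12 \ln(2\pi e \sigma^2) \approx 3.03$ nats. All names here are illustrative.

```python
import numpy as np
from scipy.optimize import brentq

grades = np.arange(51)            # grades 0..50, as in the notebook
target_mean, target_std = 25.0, 5.0


def gaussian_p(s):
    """Discretized Gaussian of width s centred on the target mean, normalized over 0..50."""
    w = np.exp(-((grades - target_mean) ** 2) / (2 * s**2))
    return w / w.sum()


def std_residual(s):
    """Standard deviation of gaussian_p(s) minus the 5-point target."""
    p = gaussian_p(s)
    mean = np.sum(p * grades)
    return np.sqrt(np.sum(p * (grades - mean) ** 2)) - target_std


# The width is not exactly 5 (discrete, truncated grades), so solve for it
# with a bracketing root finder: the residual is negative at s=1, positive at s=20.
s = brentq(std_residual, 1.0, 20.0)
p_star = gaussian_p(s)
max_entropy = -np.sum(p_star * np.log(p_star))

print(s, max_entropy, 0.5 * np.log(2 * np.pi * np.e * target_std**2))
```

If SLSQP converged, `res.x` should be close to `p_star` and `res.fun` close to `-max_entropy`. Comparing them, for example with `np.max(np.abs(res.x - p_star))`, is a cheap sanity check on the numerical result.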

In [80]:
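# Expected number of students at each grade, N * p_i
fig=make_subplots()
fig.add_bar(x=list(range(n_grades)), y=n_students*res.x)
fig.update_layout(xaxis_title='grade', yaxis_title='number of students')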

In [68]:
res

     fun: DeviceArray(-2.05631842, dtype=float64)
     jac: array([-19.20475535, -19.20408209, -19.20466921, -19.20495998,
       -19.20504369, -19.20438188, -19.20496772, -17.10699402,
       -14.93022743, -12.46301984, -10.11378249,  -8.016216  ,
        -6.18428161,  -4.61808547,  -3.31765518,  -1.98301522,
        -0.47874807,  -1.02799214,  -0.7751226 ,  -0.80292762,
        -1.09476947,  -1.6552459 ,  -2.48052035,  -3.57164731,
        -4.92851625,  -6.55115204])
 message: 'Optimization terminated successfully'
    nfev: 329
     nit: 163
    njev: 163
  status: 0
 success: True
       x: array([1.51157170e-09, 1.51258973e-09, 1.51170191e-09, 1.51126242e-09,
       1.51115442e-09, 1.51217691e-09, 1.51131147e-09, 1.23162199e-08,
       1.08600767e-07, 1.40941166e-06, 1.49080862e-05, 1.21432517e-04,
       7.58439483e-04, 3.63161499e-03, 1.33310825e-02, 5.06399367e-02,
       2.27922835e-01, 1.31599482e-01, 1.69462653e-01, 1.64815598e-01,
       1.23098709e-01, 7.02816041e-02, 3.07