<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/max_entropy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

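$N$ students take an exam scored out of 50 points, so each student receives one of $n_{grades} = 51$ integer grades $0, 1, \dots, 50$. If every assignment of students to grades is equally likely and nothing else is known, the most probable distribution is the uniform one, $n_i = N/n_{grades}$ (about 2 students per grade for the $N = 100$ students used below). The number of ways in which $N$ students can be distributed over the $n_{grades}$ grades, with $n_i$ students receiving grade $i$, is

$$\Omega = \dfrac{N!}{\prod\limits_{i=0}^{n_{grades}-1} n_i!}$$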
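
As a small sanity check of the counting formula, the sketch below enumerates every way of splitting a toy class of $N = 6$ students over 3 grades. These sizes are chosen for illustration only and are not the notebook's values. It confirms that the even split maximizes $\Omega$. The helper name `omega` is hypothetical.

```python
from itertools import product
from math import factorial, prod

def omega(counts):
    """Number of ways to assign sum(counts) distinguishable students to grades
    so that grade i receives counts[i] students: N! / prod(n_i!)."""
    return factorial(sum(counts)) // prod(factorial(n) for n in counts)

N, n_grades = 6, 3  # toy sizes for illustration, not the notebook's values

# enumerate every occupation vector (n_0, n_1, n_2) with n_0 + n_1 + n_2 = N
occupations = [c for c in product(range(N + 1), repeat=n_grades) if sum(c) == N]
best = max(occupations, key=omega)

print(best, omega(best))                   # (2, 2, 2) 90
print(omega((3, 2, 1)), omega((6, 0, 0)))  # 60 1
```

The uneven splits are reachable in far fewer ways. The lopsided split with every student in one grade is reachable in only one way.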

Take the logarithm of $\Omega$ and apply Stirling's approximation $\ln n! \approx n \ln n - n$. Because $N = \sum\limits_{i=0}^{n_{grades}-1} n_i$, the linear terms cancel. The result is $N$ times the Shannon entropy of the grade distribution $p_i = n_i/N$:

$$\ln \Omega \approx N \ln N - \sum\limits_{i=0}^{n_{grades}-1} n_i \ln n_i = -N\sum\limits_{i=0}^{n_{grades}-1} p_i \ln p_i$$
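
The sketch below is a quick numerical check of the Stirling step. It computes the exact $\ln\Omega$ with `math.lgamma`, since $\ln n! = \ln\Gamma(n+1)$, and compares it with $-N\sum_i p_i \ln p_i$. The split $p = (0.5, 0.3, 0.2)$ is an arbitrary assumption for illustration, not data from the notebook. The relative error should shrink as $N$ grows.

```python
from math import lgamma, log

def log_omega_exact(counts):
    """ln(N! / prod n_i!) computed via ln n! = lgamma(n + 1)."""
    N = sum(counts)
    return lgamma(N + 1) - sum(lgamma(n + 1) for n in counts)

def n_times_entropy(counts):
    """-N * sum_i p_i ln p_i with p_i = n_i / N (empty grades contribute 0)."""
    N = sum(counts)
    return -N * sum((n / N) * log(n / N) for n in counts if n > 0)

for N in (10, 1_000, 1_000_000):
    counts = [N // 2, 3 * N // 10, N // 5]  # p = (0.5, 0.3, 0.2), illustrative
    exact, approx = log_omega_exact(counts), n_times_entropy(counts)
    print(N, exact, approx, (approx - exact) / exact)
```

In this example the entropy form overestimates $\ln\Omega$ noticeably for a class of 10. For small classes, the $\tfrac12\ln(2\pi n)$ terms dropped from Stirling's formula are not negligible compared with $N$.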

In [1]:
from scipy.optimize import minimize
import scipy.stats as stats
from plotly.subplots import make_subplots
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)  # double precision for values and gradients passed to SLSQP

In [2]:
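def entropy(p):
    # Returns the *negative* Shannon entropy sum(p*log(p)); minimizing it maximizes
    # the entropy. Requires p > 0, since 0*log(0) evaluates to nan.
    return jnp.sum(p*jnp.log(p))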
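
The objective `entropy` computes $\sum_i p_i \ln p_i$ directly. That sum is undefined (`nan`) whenever some $p_i = 0$, for example if the optimizer steps onto a lower bound of 0. The sketch below uses NumPy/SciPy rather than JAX and shows the problem alongside `scipy.special.xlogy`, which defines $0 \ln 0$ as $0$. In JAX the analogous function is `jax.scipy.special.xlogy`.

```python
import numpy as np
from scipy.special import xlogy

p = np.array([0.0, 0.5, 0.5])  # a distribution with one empty grade

with np.errstate(divide="ignore", invalid="ignore"):
    naive = np.sum(p * np.log(p))  # 0 * log(0) = 0 * (-inf) -> nan

safe = np.sum(xlogy(p, p))  # xlogy(0, y) is 0, so this equals -ln 2

print(naive, safe)
```

Even with `xlogy`, the gradient $\ln p_i + 1$ diverges as $p_i \to 0$. A strictly positive lower bound on $p_i$ is therefore the simpler safeguard when handing gradients to SLSQP.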

In [3]:
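n_grades=51      # possible grades 0, 1, ..., 50
n_students=100
class_avg=25.    # target class average

# start the optimizer from the uniform distribution
p_guess=jnp.full(n_grades, 1/n_grades)
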
In [4]:
def tot(p):
    # normalization constraint: the probabilities sum to 1
    return jnp.sum(p)-1.

In [5]:
def avg(p):
    # mean constraint: the expected grade sum_i i*p_i equals class_avg
    return jnp.sum(p*jnp.arange(n_grades)) - class_avg

In [6]:
def std(p):
    # spread constraint: the standard deviation of the grades equals 5 points
    mean=jnp.sum(p*jnp.arange(n_grades))
    return jnp.sqrt(jnp.sum(p*(jnp.arange(n_grades)-mean)**2))-5.
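
The three constraint functions (normalization, mean, standard deviation) also admit an analytic treatment. Below is a sketch of the standard Lagrange-multiplier argument, using this notation:

- $i = 0, \dots, n_{grades}-1$ is the grade.
- $\mu$ and $\sigma$ are the target mean and standard deviation (25 and 5 in the code).
- $\alpha, \beta, \gamma$ are the multipliers.

Constraining $\sigma$ is equivalent to constraining the variance $\sigma^2$, and at the optimum the mean computed inside `std` equals $\mu$.

$$\mathcal{L} = -\sum_{i} p_i \ln p_i - \alpha\Big(\sum_{i} p_i - 1\Big) - \beta\Big(\sum_{i} i\,p_i - \mu\Big) - \gamma\Big(\sum_{i} (i-\mu)^2 p_i - \sigma^2\Big)$$

$$\frac{\partial \mathcal{L}}{\partial p_i} = -\ln p_i - 1 - \alpha - \beta\, i - \gamma\,(i-\mu)^2 = 0 \quad\Longrightarrow\quad p_i = e^{-1-\alpha-\beta i-\gamma (i-\mu)^2}$$

The maximum-entropy distribution is therefore the exponential of a quadratic in $i$, which is a Gaussian sampled on the integer grades. This is why the optimized distribution can be compared with a normal curve.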

In [7]:
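# equality constraints, each with an exact JAX gradient
constraints=[dict(type='eq',fun=tot,jac=jax.jacobian(tot)),
             dict(type='eq',fun=avg,jac=jax.jacobian(avg)),
             dict(type='eq',fun=std,jac=jax.jacobian(std))]

# minimize sum(p*log(p)), i.e. maximize the entropy; the small positive lower
# bound keeps log(p) and its gradient finite
res=minimize(entropy,p_guess,method='SLSQP',bounds=[(1e-12,1.)]*n_grades,
             jac=jax.jacobian(entropy),constraints=constraints,
             options=dict(maxiter=1000),tol=1e-10)
print(res.message)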
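
Under normalization, mean and variance constraints, the maximum-entropy solution has the form $p_i \propto \exp(-\beta i - \gamma (i-\mu)^2)$, a standard Lagrange-multiplier result. The 51-dimensional problem can therefore be reduced to two unknowns. The sketch below is an alternative to the notebook's SLSQP approach, not the notebook's method.

- It parametrizes $p_i \propto \exp\!\big(-(i-m)^2/(2s^2)\big)$.
- It uses `scipy.optimize.root` to choose $m$ and $s$ so that the mean is 25 and the standard deviation is 5.

The helper names `gibbs` and `moment_residuals` are hypothetical.

```python
import numpy as np
from scipy.optimize import root

grades = np.arange(51)               # grades 0..50, as in the notebook
target_mean, target_std = 25.0, 5.0  # same targets as the avg/std constraints

def gibbs(params):
    """Discretized Gaussian p_i ∝ exp(-(i - m)^2 / (2 s^2)) on the grade grid."""
    m, s = params
    logw = -(grades - m) ** 2 / (2.0 * s ** 2)
    w = np.exp(logw - logw.max())    # shift by the max to avoid underflow
    return w / w.sum()

def moment_residuals(params):
    """Differences between the distribution's mean/std and the targets."""
    p = gibbs(params)
    mean = np.sum(p * grades)
    std = np.sqrt(np.sum(p * (grades - mean) ** 2))
    return [mean - target_mean, std - target_std]

# solve two equations in the two unknowns (m, s), starting from the targets
sol = root(moment_residuals, x0=[target_mean, target_std])
p_gibbs = gibbs(sol.x)

# another distribution with the same mean and std (half the class at 20,
# half at 30) has entropy ln 2; the Gibbs form should have more
H_gibbs = -np.sum(p_gibbs * np.log(p_gibbs))
H_two_point = np.log(2.0)
print(sol.success, sol.x, H_gibbs, H_two_point)
```

If the SLSQP run above has converged, `res.x` should agree closely with `p_gibbs`. A check such as `abs(res.x - p_gibbs).max()` is one way to compare them.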

In [8]:
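fig=make_subplots()
# expected number of students per grade under the maximum-entropy distribution
fig.add_bar(x=jnp.arange(n_grades),y=n_students*res.x)
fig.update_layout(xaxis_title='grade',yaxis_title='students')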

In [9]:
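# discretized normal (mean 25, sd 5): probability mass on [i-0.5, i+0.5] for each grade i
edges=jnp.arange(n_grades+1)-0.5
cdf=stats.norm.cdf(edges,loc=25,scale=5)
fig=make_subplots()
fig.add_bar(x=jnp.arange(n_grades),y=n_students*res.x,name='max entropy')
fig.add_scatter(x=jnp.arange(n_grades),y=n_students*(cdf[1:]-cdf[:-1]),mode='lines',name='normal')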
