# Information Theory Demos: Divergences and Mutual Information

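What: Visualize the KL divergence, the Pinsker upper bound on total variation (TV ≤ sqrt(KL/2)), and the Jensen–Shannon (JS) divergence as the mean of a unit-variance Gaussian is shifted; estimate I(X;Y) for a Binary Symmetric Channel (BSC) by simulation.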

Why: Build intuition for f-divergences and data processing—core tools in CAS 751.

How: Closed-form KL between Gaussians, numerical integration for JS, and plug-in mutual-information estimates from Monte Carlo samples of the channel.

TODO: Try different variances and larger grids; compare numeric JS to approximate formulas.


In [None]:
# Setup and Divergences over mean shifts
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Seed NumPy's legacy global RNG for reproducibility.
# NOTE(review): the code visible in this file draws random numbers only from a
# local np.random.default_rng(seed) Generator, so this global seed may have no
# effect here — confirm before relying on it.
np.random.seed(0)

# KL between N(m1,s1^2) and N(m2,s2^2)
def kl_gauss(m1, s1, m2, s2):
    """Return the closed-form KL divergence KL(N(m1, s1^2) || N(m2, s2^2)).

    ``s1`` and ``s2`` are standard deviations. The result is in nats because
    the natural logarithm is used.
    """
    log_scale_ratio = np.log(s2/s1)
    # Variance of the first Gaussian plus the squared mean gap, both measured
    # against twice the variance of the second Gaussian.
    quadratic_term = (s1**2 + (m1-m2)**2)/(2*s2**2)
    return log_scale_ratio + quadratic_term - 0.5

# JS via numeric integration

def js_gauss_numeric(m1, s1, m2, s2, xmin=-6, xmax=6, n=4000):
    """Numerically estimate the Jensen-Shannon divergence of two Gaussians.

    Computes JS(P, Q) = 0.5*KL(P || M) + 0.5*KL(Q || M) with M = (P + Q)/2,
    where P = N(m1, s1^2) and Q = N(m2, s2^2), by trapezoidal integration on
    a uniform grid. The result is in nats.

    Parameters
    ----------
    m1, s1 : mean and standard deviation of P.
    m2, s2 : mean and standard deviation of Q.
    xmin, xmax : integration limits. Probability mass outside this interval
        is ignored, so widen it when a mean or scale pushes mass past it.
    n : number of grid points.

    Returns
    -------
    The estimated JS divergence (a NumPy float).
    """
    grid = np.linspace(xmin, xmax, n)

    # Work with log-densities so that far-tail points need no epsilon fudge:
    # log m = log((p + q)/2) = logaddexp(log p, log q) - log 2.
    log_p = norm.logpdf(grid, m1, s1)
    log_q = norm.logpdf(grid, m2, s2)
    log_m = np.logaddexp(log_p, log_q) - np.log(2.0)

    # np.trapz was renamed np.trapezoid in NumPy 2.0; prefer the new name and
    # fall back to the old one on earlier releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    kl_pm = integrate(np.exp(log_p)*(log_p - log_m), grid)
    kl_qm = integrate(np.exp(log_q)*(log_q - log_m), grid)
    return 0.5*(kl_pm+kl_qm)

# Sweep the mean shift of the second Gaussian, N(0,1) vs N(shift,1), and
# record each quantity as a list aligned with `shifts`.
shifts = np.linspace(0, 2.5, 26)
kl_vals = [kl_gauss(0, 1, delta, 1) for delta in shifts]
# Pinsker's inequality: TV <= sqrt(KL / 2).
tv_bounds = [np.sqrt(0.5*kl_val) for kl_val in kl_vals]
js_vals = [js_gauss_numeric(0, 1, delta, 1) for delta in shifts]

# Overlay the three curves on one figure.
plt.figure(figsize=(7,4))
for series, label in (
    (kl_vals, 'KL'),
    (tv_bounds, 'TV bound (Pinsker)'),
    (js_vals, 'JS (numeric)'),
):
    plt.plot(shifts, series, label=label)
plt.xlabel('Mean shift')
plt.ylabel('Divergence')
plt.title('Divergences vs Mean Shift for N(0,1) vs N(Δ,1)')
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Mutual Information for a Binary Symmetric Channel (BSC)

def mi_bsc(p, n=200000, seed=1):
    """Estimate I(X;Y) in bits for a binary symmetric channel by simulation.

    Draws ``n`` input bits X uniformly from {0, 1}, flips each one
    independently with probability ``p`` to obtain Y, and returns the plug-in
    mutual information of the empirical joint distribution of (X, Y).

    Parameters
    ----------
    p : flip (crossover) probability of the channel.
    n : number of simulated channel uses; must be at least 1.
    seed : seed for the local NumPy random Generator, so repeated calls with
        the same arguments return the same estimate.

    Returns
    -------
    The estimated mutual information in bits (a NumPy float).

    Raises
    ------
    ValueError
        If ``n`` is less than 1 (the empirical distribution is undefined).
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    rng = np.random.default_rng(seed)
    X = rng.integers(0, 2, size=n)
    N = rng.binomial(1, p, size=n)
    Y = X ^ N  # XOR applies the channel noise

    # Joint histogram in one vectorized pass: the pair (x, y) maps to the flat
    # index 2*x + y, which reshapes to a table with rows = x and columns = y.
    counts = np.bincount(2*X + Y, minlength=4).reshape(2, 2)
    Pxy = counts / n
    Px = Pxy.sum(axis=1, keepdims=True)
    Py = Pxy.sum(axis=0, keepdims=True)

    # Sum only over observed cells: empty cells contribute 0 by the convention
    # 0*log(0) = 0, and a nonzero cell implies nonzero marginals.
    product = Px * Py
    seen = Pxy > 0
    return np.sum(Pxy[seen] * np.log2(Pxy[seen] / product[seen]))

# Estimate I(X;Y) at 26 evenly spaced flip probabilities in [0, 0.5].
ps = np.linspace(0,0.5,26)
mis = list(map(mi_bsc, ps))

# Set up the figure and its decorations, then draw the estimates.
plt.figure(figsize=(7,4))
plt.title('MI of BSC vs Flip Probability')
plt.xlabel('Flip probability p')
plt.ylabel('I(X;Y) [bits]')
plt.grid(True)
plt.plot(ps, mis, marker='o')
plt.show()
