# Stochastic Dynamics: Wright-Fisher Diffusion

This notebook demonstrates the stochastic dynamics of mode competition in RLVR (reinforcement learning with verifiable rewards).

Within the good modes, competition follows Wright-Fisher diffusion - a classical model from population genetics.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Notebook-wide plotting defaults: whitegrid theme, 12pt text, wide figures.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.size'] = 12
plt.rcParams['figure.figsize'] = (12, 5)

## Wright-Fisher Diffusion on the Simplex

For $K$ modes with shares $y = (y_1, \dots, y_K)$ on the probability simplex, the SDE is:

$$dy_i = \lambda (u_i - y_i) dt + \sqrt{V} \left( \sqrt{y_i} dB_i - y_i \sum_k \sqrt{y_k} dB_k \right)$$

where:
- $\lambda$ is the KL regularization strength
- $u$ is the reference distribution (uniform in the simulations below)
- $V$ is the noise intensity (related to batch size)
- $B_1, \dots, B_K$ are independent standard Brownian motions

In [None]:
def simulate_wf(y0, lam, V, T, dt, seed=42, u=None):
    """
    Simulate Wright-Fisher diffusion with diversity drift.

    Uses an explicit Euler-Maruyama step: the drift lam * (u - y) * dt plus
    the Wright-Fisher noise sqrt(V * dt) * (sqrt(y) * z - y * <sqrt(y), z>)
    with z ~ N(0, I). After each step negative shares are clipped to zero
    and the state is renormalized so that it sums to 1.

    Args:
        y0: Initial distribution (K-dim, sums to 1)
        lam: KL regularization strength
        V: Noise intensity
        T: Total time (expected to be a multiple of dt)
        dt: Time step (must be positive)
        seed: Seed for the NumPy random generator
        u: Reference distribution (K-dim, sums to 1); uniform when None

    Returns:
        t: Array of n_steps + 1 evenly spaced times from 0 to T
        Y: Array of shape (n_steps + 1, K); row k is the state at t[k]

    Raises:
        ValueError: If dt is not positive, or u does not have K entries.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")

    rng = np.random.default_rng(seed)
    K = len(y0)
    if u is None:
        u = np.ones(K) / K  # Uniform reference
    else:
        u = np.asarray(u, dtype=float)
        if u.shape != (K,):
            raise ValueError("u must have the same length as y0")

    # Round instead of truncating: T / dt is frequently a hair below an
    # integer in floating point (e.g. 0.3 / 0.1), and truncation would then
    # drop the final step while the time axis still ran all the way to T.
    n_steps = int(round(T / dt))
    t = np.linspace(0, T, n_steps + 1)
    Y = np.zeros((n_steps + 1, K))
    Y[0] = y0

    sqrt_V_dt = np.sqrt(V * dt)

    for k in range(n_steps):
        # Floor the shares so sqrt(y) is well defined, then restore sum = 1
        y = np.maximum(Y[k], 1e-10)
        y = y / y.sum()

        # Drift (KL regularization towards the reference distribution)
        drift = lam * (u - y) * dt

        # Noise (Wright-Fisher); subtracting y * proj makes the noise
        # components sum to zero, so the step stays on the sum = 1 plane
        z = rng.normal(size=K)
        sqrt_y = np.sqrt(y)
        proj = np.dot(sqrt_y, z)
        noise = sqrt_V_dt * (sqrt_y * z - y * proj)

        # The discrete step can overshoot below zero: clip and renormalize
        y_next = y + drift + noise
        y_next = np.maximum(y_next, 0)
        y_next = y_next / y_next.sum()
        Y[k + 1] = y_next

    return t, Y

def entropy(y):
    """Return the Shannon entropy (in nats) of the distribution y.

    Entries are clipped into [1e-12, 1] before the logarithm is taken, so
    zero shares contribute a negligible amount instead of producing NaN.
    """
    probs = np.clip(y, 1e-12, 1)
    weighted_logs = probs * np.log(probs)
    return -weighted_logs.sum()

## Effect of KL Regularization

Without KL regularization ($\lambda = 0$), the system eventually fixates on one mode (entropy drops).

With KL regularization ($\lambda > 0$), diversity is maintained.

In [None]:
# Experiment settings (K, y0, V and dt are reused by later cells)
K = 5
y0 = np.ones(K) / K  # Every mode starts with an equal share
V = 0.3
T = 10.0
dt = 0.01

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Both runs share seed=42, so they are driven by the same noise draws and
# differ only in the regularization strength lam.
t1, Y1 = simulate_wf(y0, lam=0.0, V=V, T=T, dt=dt, seed=42)  # No KL
t2, Y2 = simulate_wf(y0, lam=1.0, V=V, T=T, dt=dt, seed=42)  # With KL

# Entropy of every state along each trajectory
H1 = [entropy(state) for state in Y1]
H2 = [entropy(state) for state in Y2]
H_max = np.log(K)

# One figure column per case: mode shares on top, entropy underneath
cases = [
    (t1, Y1, H1,
     'No KL Regularization ($\\lambda = 0$)',
     'Entropy Over Time ($\\lambda = 0$)'),
    (t2, Y2, H2,
     'With KL Regularization ($\\lambda = 1$)',
     'Entropy Over Time ($\\lambda = 1$)'),
]

for col, (times, shares, entropies, share_title, entropy_title) in enumerate(cases):
    share_ax = axes[0, col]
    entropy_ax = axes[1, col]

    # Top row: share of each mode over time
    for i, series in enumerate(shares.T):
        share_ax.plot(times, series, label=f'Mode {i+1}')
    share_ax.set_title(share_title, fontsize=12)
    share_ax.set_xlabel('Time')
    share_ax.set_ylabel('Mode share')
    share_ax.legend(loc='upper right')

    # Bottom row: entropy against the ln(K) ceiling
    entropy_ax.plot(times, entropies, 'b-', linewidth=2)
    entropy_ax.axhline(H_max, color='gray', linestyle='--', label=f'Max entropy (ln {K})')
    entropy_ax.set_title(entropy_title, fontsize=12)
    entropy_ax.set_xlabel('Time')
    entropy_ax.set_ylabel('Entropy H(y)')
    entropy_ax.set_ylim(0, H_max * 1.1)
    entropy_ax.legend()

plt.suptitle('Wright-Fisher Diffusion: Mode Competition', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Monte Carlo: Entropy Distribution

In [None]:
# Monte Carlo over n_mc seeds: entropy trajectories with and without KL.
# Relies on y0, V, dt and H_max defined in the previous cell.
n_mc = 50
T_long = 20.0

H_no_kl = []
H_with_kl = []

for seed in range(n_mc):
    # The same seed is used for both settings, so each pair of runs is
    # driven by the same noise draws and differs only in lam.
    _, Y = simulate_wf(y0, lam=0.0, V=V, T=T_long, dt=dt, seed=seed)
    H_no_kl.append([entropy(state) for state in Y])

    _, Y = simulate_wf(y0, lam=1.0, V=V, T=T_long, dt=dt, seed=seed)
    H_with_kl.append([entropy(state) for state in Y])

# Rows index Monte Carlo runs, columns index time points
H_no_kl = np.array(H_no_kl)
H_with_kl = np.array(H_with_kl)
t_long = np.linspace(0, T_long, H_no_kl.shape[1])

fig, ax = plt.subplots(figsize=(10, 5))

# Plot the across-run mean and the 10th-90th percentile band. The band is
# the spread of individual runs, not a confidence interval for the mean.
ax.plot(t_long, H_no_kl.mean(axis=0), 'r-', label='$\\lambda = 0$ (no KL)', linewidth=2)
ax.fill_between(t_long,
                np.percentile(H_no_kl, 10, axis=0),
                np.percentile(H_no_kl, 90, axis=0),
                color='red', alpha=0.2)

ax.plot(t_long, H_with_kl.mean(axis=0), 'b-', label='$\\lambda = 1$ (with KL)', linewidth=2)
ax.fill_between(t_long,
                np.percentile(H_with_kl, 10, axis=0),
                np.percentile(H_with_kl, 90, axis=0),
                color='blue', alpha=0.2)

ax.axhline(H_max, color='gray', linestyle='--', label='Max entropy')
ax.set_xlabel('Time')
ax.set_ylabel('Entropy H(y)')
ax.set_title('Monte Carlo: Entropy Dynamics (mean, 10th-90th percentile band)', fontsize=12)
ax.legend()
ax.set_ylim(0, H_max * 1.1)
plt.tight_layout()
plt.show()

## Key Takeaway

- **Without KL regularization**: Stochastic sampling noise causes modes to compete until one dominates (entropy collapse)
- **With KL regularization**: The system maintains diversity near the reference distribution

This is why KL regularization in RLHF/GRPO helps maintain response diversity.