# Simplex Flow Visualization

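This notebook visualizes how probability mass flows on the simplex for different values of Youden's index $J$ (sensitivity + specificity − 1) of the reward signal: $J = 1$ is a perfect reward, $J = 0$ is uninformative, and $J < 0$ is systematically inverted.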

We model $K = 3$ modes (2 good + 1 bad; the code below is written for exactly this case) and show how probability mass flows under noisy rewards.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import to_rgba
from matplotlib.patches import Polygon

# Global plotting defaults: whitegrid theme, then a base font size of 12.
# The style is applied first so the explicit font size is not overridden by it.
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams["font.size"] = 12

## Simplex Coordinates

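For K = 3 modes, a point $p = (p_{\text{good1}}, p_{\text{good2}}, p_{\text{bad}})$ is drawn at the barycentric combination of the corners of an equilateral triangle: Good 1 at the bottom-left, Good 2 at the bottom-right, and Bad at the top.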

In [None]:
def to_cartesian(p):
    """Convert simplex point(s) to 2D Cartesian coordinates for plotting.

    Each point is mapped to the barycentric combination of the corners of an
    equilateral triangle: component 0 -> (0, 0), component 1 -> (1, 0),
    component 2 -> (0.5, sqrt(3)/2).

    Parameters
    ----------
    p : array-like, shape (3,) or (..., 3)
        Simplex coordinates. A batch of points (e.g. a whole trajectory of
        shape (T, 3)) is converted in one call.

    Returns
    -------
    np.ndarray, shape (2,) or (..., 2)
        Cartesian (x, y) coordinates.

    Raises
    ------
    ValueError
        If the last axis of ``p`` does not have length 3.
    """
    # Rows are the triangle corners, so ``p @ vertices`` is the weighted sum
    # sum_i p[..., i] * vertices[i], broadcasting over any leading axes.
    vertices = np.array([[0.0, 0.0],
                         [1.0, 0.0],
                         [0.5, np.sqrt(3) / 2]])
    p = np.asarray(p, dtype=float)
    if p.ndim == 0 or p.shape[-1] != 3:
        raise ValueError(f"expected simplex point(s) with last axis of length 3, got shape {p.shape}")
    return p @ vertices

def draw_simplex(ax):
    """Outline the probability simplex on ``ax`` and label its corners.

    The triangle has corners (0, 0), (1, 0) and (0.5, sqrt(3)/2), labelled
    'Good 1', 'Good 2' and 'Bad' (the last in red). The aspect ratio is
    locked to 'equal' so the triangle stays equilateral, and the axes
    decorations are hidden.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; drawn on in place. Nothing is returned.
    """
    height = np.sqrt(3) / 2
    outline = np.array([[0, 0], [1, 0], [0.5, height], [0, 0]])  # closed loop
    xs, ys = outline[:, 0], outline[:, 1]
    ax.plot(xs, ys, 'k-', linewidth=2)
    ax.set_aspect('equal')
    ax.axis('off')

    # (x, y, text, extra kwargs) for each corner label.
    corner_labels = [
        (-0.05, -0.05, 'Good 1', {}),
        (1.05, -0.05, 'Good 2', {}),
        (0.5, height + 0.05, 'Bad', {'color': 'red'}),
    ]
    for x, y, label, extra in corner_labels:
        ax.text(x, y, label, ha='center', fontsize=10, **extra)

## Compute Flow Field

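The drift on the simplex depends on J. The bad mode follows a logistic law,
$\dot p_{\text{bad}} = -\eta\, J\, p_{\text{bad}}(1 - p_{\text{bad}})$,
and the mass it loses (or gains) is split between the two good modes in proportion to their current mass:
- J > 0: Flow towards good modes
- J = 0: No flow (the drift is identically zero)
- J < 0: Flow towards bad mode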

In [None]:
def compute_drift(p, J, eta=0.1):
    """
    Return the drift vector at simplex point ``p``.

    ``p`` is ordered as [p_good1, p_good2, p_bad]. It is first clipped into
    [1e-6, 1 - 1e-6] and renormalized, so boundary points are handled safely.

    The bad mode's mass changes by ``-eta * J * p_bad * (1 - p_bad)``. It
    shrinks for J > 0 and grows for J < 0. The opposite amount is shared
    between the two good modes in proportion to their current mass, so the
    components sum to (numerically) zero and the point stays on the simplex.

    Returns
    -------
    np.ndarray of length 3: [d_good1, d_good2, d_bad].
    """
    eps = 1e-6
    probs = np.clip(np.array(p), eps, 1 - eps)
    probs = probs / probs.sum()

    mass_bad = probs[2]
    mass_good = probs[0] + probs[1]

    # Logistic-style decay (J > 0) or growth (J < 0) of the bad mode's mass.
    delta_bad = -eta * J * mass_bad * (1 - mass_bad)
    gain = -delta_bad

    # Split the gain between good modes by their relative mass, falling back
    # to an even split when there is (essentially) no good mass to weigh by.
    if mass_good > 1e-6:
        shares = probs[:2] / mass_good
    else:
        shares = np.array([0.5, 0.5])

    return np.array([gain * shares[0], gain * shares[1], delta_bad])

## Simulate Trajectories

In [None]:
def simulate_trajectory(p0, J, T=100, dt=0.5, eta=0.1, drift_fn=None):
    """Simulate a trajectory on the simplex with explicit Euler steps.

    After each step the point is clipped to at least 1e-6 per component and
    renormalized to sum to 1, so it stays on the simplex.

    Parameters
    ----------
    p0 : array-like of length 3
        Starting point (stored as the first row unchanged).
    J : float
        Youden's index passed to the drift function.
    T : float
        Total integration time.
    dt : float
        Step size; must be positive.
    eta : float
        Learning-rate-like scale passed to the drift function.
    drift_fn : callable, optional
        ``drift_fn(p, J, eta) -> np.ndarray`` giving the drift at ``p``.
        Defaults to ``compute_drift``.

    Returns
    -------
    np.ndarray of shape (n_steps + 1, 3), where n_steps = round(T / dt).

    Raises
    ------
    ValueError
        If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    if drift_fn is None:
        drift_fn = compute_drift

    # round() rather than int(): T / dt can land just below an integer
    # (e.g. 0.3 / 0.1 == 2.9999999999999996), which int() would truncate.
    n_steps = max(int(round(T / dt)), 0)

    p = np.array(p0, dtype=float)
    trajectory = [p.copy()]
    for _ in range(n_steps):
        p = p + drift_fn(p, J, eta) * dt
        p = np.clip(p, 1e-6, 1)
        p = p / p.sum()
        trajectory.append(p.copy())

    return np.array(trajectory)

## Visualize Flow for Different J Values

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

J_values = [1.0, 0.0, -0.5]
titles = ['J = 1.0 (Learning)', 'J = 0 (Neutral)', 'J = -0.5 (Anti-learning)']
colors = ['blue', 'gray', 'red']

# Initial points drawn from Dirichlet(1, 1, 1), i.e. uniformly over the
# simplex. They are seeded so every panel starts from the same 8 points.
np.random.seed(42)
n_traj = 8
init_points = [np.random.dirichlet([1, 1, 1]) for _ in range(n_traj)]

for ax, J, title, color in zip(axes, J_values, titles, colors):
    draw_simplex(ax)
    ax.set_title(title, fontsize=14, fontweight='bold')
    base_rgba = np.array(to_rgba(color))

    for p0 in init_points:
        traj = simulate_trajectory(p0, J)
        xy = np.array([to_cartesian(p) for p in traj])

        # Draw the path as individual segments whose alpha ramps from faint
        # (start) to opaque (end), so the direction of flow is visible.
        segments = np.stack([xy[:-1], xy[1:]], axis=1)  # (n_seg, 2 points, xy)
        seg_colors = np.tile(base_rgba, (len(segments), 1))
        seg_colors[:, 3] = np.linspace(0.15, 0.9, len(segments))
        ax.add_collection(LineCollection(segments, colors=seg_colors, linewidths=1.5))

        ax.scatter(xy[0, 0], xy[0, 1], color='green', s=30, zorder=5)  # Start
        ax.scatter(xy[-1, 0], xy[-1, 1], color=color, s=50, zorder=5, marker='x')  # End

fig.suptitle('Simplex Flow: How Probability Mass Moves Under Noisy Rewards', fontsize=14)
fig.tight_layout()
plt.show()

## Interpretation

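- **J > 0 (Blue)**: All trajectories flow towards the bottom edge (good modes) as the bad-mode probability decays. The gain is split between the good modes in proportion to their current mass, so the ratio $p_{\text{good1}} / p_{\text{good2}}$ is preserved. Each trajectory is therefore a straight line pointing away from the Bad vertex.
- **J = 0 (Gray)**: No flow at all. The drift is exactly zero, so every trajectory stays at its starting point.
- **J < 0 (Red)**: Trajectories flow towards the top vertex (bad mode) along the same straight lines, and the model anti-learns.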