# E[Q] Visualization Suite

Beautiful visualizations of the E[Q] imperfect-information learning pipeline.

**Outputs:**
- Static snapshots of game decisions with rendered domino graphics
- 3D animated trajectory through decision space (rotating, color-coded by E[Q])
- 3D animated belief cloud showing uncertainty collapse as game progresses

All animations exported as GIFs for easy phone sharing via MMS.

In [None]:
# === CONFIGURATION ===
# Absolute path to the project checkout; the dataset and render locations
# are both derived from it.
PROJECT_ROOT = "/home/jason/v2/mk5-tailwind"
DATASET_PATH = PROJECT_ROOT + "/forge/data/eq_v2.2_250g.pt"
OUTPUT_DIR = PROJECT_ROOT + "/forge/eq/renders"

# === Setup ===
import sys

# Put the project root at the front of the import path, but only once.
if sys.path.count(PROJECT_ROOT) == 0:
    sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.animation import FuncAnimation
from matplotlib.colors import LinearSegmentedColormap
from pathlib import Path
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA

# Create the render directory (and any missing parents) before anything is saved
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)

# Style: dark theme, one shared background colour for figure, axes and saved
# files, and monospace text everywhere
plt.style.use('dark_background')
plt.rcParams.update({
    'figure.facecolor': '#0d1117',
    'axes.facecolor': '#0d1117',
    'savefig.facecolor': '#0d1117',
    'font.family': 'monospace',
})

print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Load E[Q] dataset
print(f"Loading {DATASET_PATH}...")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only point DATASET_PATH at files from a trusted source.
data = torch.load(DATASET_PATH, weights_only=False)

# Required arrays: a missing key raises KeyError
tokens = data['transcript_tokens']
print(f"Loaded {len(tokens):,} decisions")
print(f"Keys: {list(data.keys())}")

lengths = data['transcript_lengths']
e_q_mean = data['e_q_mean']
legal_mask = data['legal_mask']
action_taken = data['action_taken']
game_idx = data['game_idx']
decision_idx = data['decision_idx']

# Optional entries: absent keys become None (metadata falls back to {})
e_q_var = data.get('e_q_var')
ess = data.get('ess')
metadata = data.get('metadata', {})

print(f"\nDataset version: {metadata.get('version', 'unknown')}")
# Game ids are treated as zero-based, so the count is the largest id plus one
print(f"Games: {game_idx.max().item() + 1}")

## 1. Domino Rendering

Beautiful domino graphics with actual pip dots.

In [None]:
# Pip layout for each face value (0-6) as (x, y) fractions of a unit square.
# The position of each list in the sequence below is the pip count it draws.
PIP_POSITIONS = dict(enumerate([
    [],
    [(0.5, 0.5)],
    [(0.25, 0.75), (0.75, 0.25)],
    [(0.25, 0.75), (0.5, 0.5), (0.75, 0.25)],
    [(0.25, 0.25), (0.25, 0.75), (0.75, 0.25), (0.75, 0.75)],
    [(0.25, 0.25), (0.25, 0.75), (0.5, 0.5), (0.75, 0.25), (0.75, 0.75)],
    [(0.25, 0.2), (0.25, 0.5), (0.25, 0.8), (0.75, 0.2), (0.75, 0.5), (0.75, 0.8)],
]))

def draw_domino(ax, x, y, high, low, width=1.0, height=2.0, 
                face_color='#1a1a2e', edge_color='#4a9eff', pip_color='white',
                highlight=False, alpha=1.0):
    """Draw one upright domino on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        x, y: Data coordinates of the domino's lower-left corner.
        high: Pip count shown in the upper half (key into PIP_POSITIONS).
        low: Pip count shown in the lower half (key into PIP_POSITIONS).
        width, height: Size of the domino in data units.
        face_color, edge_color, pip_color: Colours for the body, the outline
            and divider, and the pips.
        highlight: When true, the face and edge colours are replaced by the
            highlight palette regardless of what was passed in.
        alpha: Opacity applied to every element drawn.
    """
    if highlight:
        face_color, edge_color = '#2a2a4e', '#ffcc00'

    # Rounded body
    body = mpatches.FancyBboxPatch(
        (x, y), width, height,
        boxstyle="round,pad=0.02,rounding_size=0.1",
        facecolor=face_color, edgecolor=edge_color, linewidth=2,
        alpha=alpha
    )
    ax.add_patch(body)

    # Divider between the two halves, inset slightly from the edges
    divider_y = y + height/2
    ax.plot([x + 0.1, x + width - 0.1], [divider_y, divider_y],
            color=edge_color, linewidth=1.5, alpha=alpha)

    # Upper half (high value) first, then lower half (low value); each half
    # maps the unit-square pip layout onto a width x height/2 cell.
    pip_radius = width * 0.08
    for value, half_bottom in ((high, y + height/2), (low, y)):
        for px, py in PIP_POSITIONS[value]:
            pip = plt.Circle(
                (x + px * width, half_bottom + py * height/2),
                pip_radius, color=pip_color, alpha=alpha
            )
            ax.add_patch(pip)


def domino_id_to_pips(domino_id):
    """Map a triangular-numbered domino ID (0-27) to its (high, low) pips.

    IDs are laid out row by row: row ``high`` starts at ``high * (high + 1) // 2``
    and holds the dominoes (high, 0) .. (high, high).
    """
    high = 0
    row_start = 0  # first ID in the current row, i.e. high * (high + 1) // 2
    # Advance while the next row's first ID is still <= domino_id.
    while row_start + high + 1 <= domino_id:
        row_start += high + 1
        high += 1
    return high, domino_id - row_start


# Demo: Draw all 28 dominoes in a 4 x 7 grid
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(-0.5, 14)
# Rows are placed at y = 0, 2.5, 5 and 7.5 and each domino is 2 units tall,
# so the view has to reach y = 9.5; a smaller upper limit cuts off the top rows.
ax.set_ylim(-0.5, 10)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('All 28 Dominoes', fontsize=16, color='white', pad=20)

for domino_id in range(28):
    high, low = domino_id_to_pips(domino_id)
    row = domino_id // 7
    col = domino_id % 7
    # Row 0 is drawn at the top; doubles are highlighted
    draw_domino(ax, col * 2, (3 - row) * 2.5, high, low, highlight=(high == low))

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/all_dominoes.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {OUTPUT_DIR}/all_dominoes.png")

## 2. Decision Snapshot

Single-frame visualization of a game decision with E[Q] values.

In [None]:
# Token feature indices (from transcript_tokenize.py)
FEAT_HIGH, FEAT_LOW = 0, 1
FEAT_PLAYER, FEAT_IS_IN_HAND, FEAT_DECL, FEAT_TOKEN_TYPE = 4, 5, 6, 7

# Token-type values that decode_decision tells apart
TOKEN_TYPE_HAND, TOKEN_TYPE_PLAY = 1, 2

# Display names, keyed by declaration id (the position in this list)
DECL_NAMES = dict(enumerate([
    "Blanks", "Ones", "Twos", "Threes", "Fours",
    "Fives", "Sixes", "Doubles", "Doubles (suit)", "No Trump",
]))
# Display names indexed by the player feature of a play token
PLAYER_NAMES = ["ME", "LEFT", "PARTNER", "RIGHT"]


def decode_decision(tok, length):
    """Decode the first *length* transcript tokens into a game-state dict.

    Args:
        tok: 2-D token tensor indexed as ``tok[position, feature]``.
        length: Number of leading tokens that are valid.

    Returns:
        A dict with ``'decl_id'`` (declaration feature of the first token),
        ``'hand'`` (list of ``(high, low)`` for hand tokens) and ``'plays'``
        (list of ``(player, high, low)`` for play tokens), both in token order.
        Tokens of any other type are ignored.
    """
    tok = tok[:length]
    state = {'decl_id': tok[0, FEAT_DECL].item(), 'hand': [], 'plays': []}

    for pos in range(length):
        token_type = tok[pos, FEAT_TOKEN_TYPE].item()
        if token_type not in (TOKEN_TYPE_HAND, TOKEN_TYPE_PLAY):
            continue

        pips = (tok[pos, FEAT_HIGH].item(), tok[pos, FEAT_LOW].item())
        if token_type == TOKEN_TYPE_HAND:
            state['hand'].append(pips)
        else:
            state['plays'].append((tok[pos, FEAT_PLAYER].item(),) + pips)

    return state


def render_decision_snapshot(idx, save_path=None):
    """Render a snapshot of a single decision.

    The figure contains a header (game, decision, trick, declaration), the
    hand with the played domino highlighted and illegal dominoes dimmed, a
    horizontal bar chart of E[Q] for the legal actions (best first, with
    standard-deviation error bars when variance is available), the trick in
    progress, and a statistics panel.

    Args:
        idx: Row index into the module-level dataset arrays.
        save_path: Optional output path; when given, the figure is also saved
            at 150 dpi.

    Returns:
        The matplotlib Figure, after it has been shown with ``plt.show()``.
    """
    # Decode state
    tok = tokens[idx]
    length = lengths[idx].item()
    state = decode_decision(tok, length)
    
    eq = e_q_mean[idx]
    var = e_q_var[idx] if e_q_var is not None else None
    mask = legal_mask[idx]
    action = action_taken[idx].item()
    game = game_idx[idx].item()
    decision = decision_idx[idx].item()
    
    hand = state['hand']
    plays = state['plays']
    decl_id = state['decl_id']
    
    # Calculate trick info: every four plays complete a trick, the remainder
    # is the trick currently in progress
    n_complete_tricks = len(plays) // 4
    trick_num = n_complete_tricks + 1
    current_trick = plays[-(len(plays) % 4):] if len(plays) % 4 != 0 else []
    
    # Create figure
    fig = plt.figure(figsize=(16, 10))
    
    # Custom grid
    gs = fig.add_gridspec(3, 2, height_ratios=[1, 2, 2], width_ratios=[1.5, 1],
                          hspace=0.3, wspace=0.2)
    
    # === Header ===
    ax_header = fig.add_subplot(gs[0, :])
    ax_header.axis('off')
    ax_header.text(0.5, 0.7, f"Game {game + 1}  •  Decision {decision + 1}/28  •  Trick {trick_num}/7",
                   ha='center', va='center', fontsize=20, color='white', weight='bold')
    # Fall back to a generic label instead of raising on an unexpected id
    decl_name = DECL_NAMES.get(decl_id, f"Unknown ({decl_id})")
    ax_header.text(0.5, 0.2, f"Declaration: {decl_name}",
                   ha='center', va='center', fontsize=16, color='#4a9eff')
    
    # === Hand Display ===
    ax_hand = fig.add_subplot(gs[1, 0])
    ax_hand.set_xlim(-0.5, len(hand) * 1.5 + 0.5)
    ax_hand.set_ylim(-0.5, 3)
    ax_hand.set_aspect('equal')
    ax_hand.axis('off')
    ax_hand.set_title('My Hand', fontsize=14, color='white', pad=10)
    
    # Hand slots that have a mask entry; slots beyond the mask count as illegal
    n_masked_slots = min(len(hand), len(mask))
    
    for i, (high, low) in enumerate(hand):
        is_selected = (i == action)
        is_legal = mask[i].item() if i < n_masked_slots else False
        alpha = 1.0 if is_legal else 0.4
        draw_domino(ax_hand, i * 1.5, 0.5, high, low, highlight=is_selected, alpha=alpha)
        if is_selected:
            ax_hand.text(i * 1.5 + 0.5, -0.3, "▲ PLAYED", ha='center', 
                        fontsize=10, color='#ffcc00', weight='bold')
    
    # === E[Q] Bar Chart ===
    ax_eq = fig.add_subplot(gs[1, 1])
    
    # Get legal actions sorted by E[Q], best first (same bounds rule as the
    # hand display above)
    legal_indices = [i for i in range(n_masked_slots) if mask[i].item()]
    sorted_legal = sorted(legal_indices, key=lambda i: eq[i].item(), reverse=True)
    eq_vals = [eq[i].item() for i in sorted_legal]
    has_legal = bool(eq_vals)
    
    if has_legal:
        # Color gradient: green (high) to red (low)
        cmap = LinearSegmentedColormap.from_list('eq', ['#ff4444', '#888888', '#44ff44'])
        eq_min, eq_max = min(eq_vals), max(eq_vals)
        # Avoid dividing by zero when every legal action has the same value
        eq_range = eq_max - eq_min if eq_max != eq_min else 1
        
        y_pos = np.arange(len(sorted_legal))
        colors = [cmap((value - eq_min) / eq_range) for value in eq_vals]
        
        bars = ax_eq.barh(y_pos, eq_vals, 
                          color=colors, edgecolor='white', linewidth=0.5)
        
        # Error bars for uncertainty (negative variances are clamped to zero)
        if var is not None:
            stds = [np.sqrt(max(0, var[i].item())) for i in sorted_legal]
            ax_eq.errorbar(eq_vals, y_pos, 
                           xerr=stds, fmt='none', color='white', alpha=0.5, capsize=3)
        
        # Labels
        labels = [f"[{hand[i][0]}:{hand[i][1]}]" for i in sorted_legal]
        ax_eq.set_yticks(y_pos)
        ax_eq.set_yticklabels(labels, fontsize=12, color='white')
        
        # Highlight selected
        for i, idx_action in enumerate(sorted_legal):
            if idx_action == action:
                bars[i].set_edgecolor('#ffcc00')
                bars[i].set_linewidth(3)
                ax_eq.text(eq_vals[i] + 1, i, '← PLAYED', 
                          va='center', fontsize=10, color='#ffcc00')
    else:
        # Nothing to chart: min()/max() of an empty list would raise
        ax_eq.set_yticks([])
        ax_eq.text(0.5, 0.5, 'No legal actions', transform=ax_eq.transAxes,
                   ha='center', va='center', fontsize=12, color='white')
    
    ax_eq.set_xlabel('E[Q] (expected points)', fontsize=12, color='white')
    ax_eq.set_title('Action Values', fontsize=14, color='white', pad=10)
    ax_eq.axvline(0, color='white', linestyle='--', alpha=0.3)
    ax_eq.tick_params(colors='white')
    
    # === Current Trick ===
    ax_trick = fig.add_subplot(gs[2, 0])
    ax_trick.set_xlim(-0.5, 8)
    ax_trick.set_ylim(-0.5, 3)
    ax_trick.set_aspect('equal')
    ax_trick.axis('off')
    
    if current_trick:
        ax_trick.set_title(f'Current Trick ({len(current_trick)}/4 played)', 
                          fontsize=14, color='white', pad=10)
        for i, (player, high, low) in enumerate(current_trick):
            draw_domino(ax_trick, i * 1.8, 0.5, high, low)
            ax_trick.text(i * 1.8 + 0.5, 2.8, PLAYER_NAMES[player], 
                         ha='center', fontsize=10, color='#4a9eff')
        # Show "ME: ?" for the decision point
        ax_trick.text(len(current_trick) * 1.8 + 0.5, 1.5, "?", 
                     ha='center', fontsize=24, color='#ffcc00')
        ax_trick.text(len(current_trick) * 1.8 + 0.5, 2.8, "ME", 
                     ha='center', fontsize=10, color='#ffcc00')
    else:
        ax_trick.set_title('Leading new trick...', fontsize=14, color='white', pad=10)
        ax_trick.text(2, 1.5, "Your lead!", ha='center', fontsize=16, color='#44ff44')
    
    # === Stats Panel ===
    ax_stats = fig.add_subplot(gs[2, 1])
    ax_stats.axis('off')
    
    stats_lines = [
        "Statistics",
        "─────────────────",
        f"Legal actions: {len(legal_indices)}",
    ]
    if has_legal:
        stats_lines += [
            f"Best E[Q]: {eq_max:+.1f} pts",
            f"Worst E[Q]: {eq_min:+.1f} pts",
            f"Gap: {eq_max - eq_min:.1f} pts",
        ]
    else:
        stats_lines.append("Best / worst E[Q]: n/a")
    stats_text = "\n".join(stats_lines) + "\n"
    
    if var is not None and has_legal:
        best_idx = sorted_legal[0]
        stats_text += f"Best σ: ±{np.sqrt(max(0, var[best_idx].item())):.1f} pts\n"
    if ess is not None:
        stats_text += f"\nESS: {ess[idx].item():.1f}"
    
    ax_stats.text(0.1, 0.9, stats_text, transform=ax_stats.transAxes,
                  fontsize=12, color='white', va='top', family='monospace',
                  bbox=dict(boxstyle='round', facecolor='#1a1a2e', edgecolor='#4a9eff'))
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    
    plt.show()
    return fig


# Render a sample decision. The returned figure is bound to a name so that it
# is not the cell's last expression; otherwise the notebook would echo the
# figure a second time after plt.show() has already displayed it.
snapshot_fig = render_decision_snapshot(100, save_path=f"{OUTPUT_DIR}/decision_snapshot.png")

## 3. 3D Game Trajectory Animation

Mean-pool each decision's transcript tokens, project the result to 3D with PCA, and trace the game's path through decision space.

In [None]:
def create_trajectory_animation(game_id=0, save_path=None, fps=10):
    """Create a rotating 3D trajectory animation for a single game.

    Each decision's transcript tokens are mean-pooled over their valid length
    and the pooled vectors are projected to 3D with PCA. Frame ``k`` shows the
    first ``k + 1`` decisions joined by a line, coloured by the best legal
    E[Q] and sized by the mean legal standard deviation when variance is
    available.

    Args:
        game_id: Value of ``game_idx`` selecting the game to animate.
        save_path: Optional output path; when given the animation is written
            with PillowWriter (GIF, no ffmpeg needed).
        fps: Frames per second for playback and for the saved file.

    Returns:
        The FuncAnimation, or None when the game has no decisions.
    """
    from matplotlib.animation import PillowWriter
    
    # Find all decisions for this game
    game_mask = (game_idx == game_id)
    indices = torch.where(game_mask)[0].numpy()
    
    if len(indices) == 0:
        print(f"No data for game {game_id}")
        return None
    
    print(f"Game {game_id}: {len(indices)} decisions")
    
    game_tokens = tokens[indices].numpy()
    game_lengths = lengths[indices].numpy()
    game_eq = e_q_mean[indices].numpy()
    game_var = e_q_var[indices].numpy() if e_q_var is not None else None
    # Cast so the row selections below are boolean masks whatever dtype the
    # dataset stored (an integer mask would otherwise act as fancy indexing)
    game_legal = legal_mask[indices].numpy().astype(bool)
    
    # Create embeddings: mean pool the valid tokens of each decision
    # (assumes each row of game_tokens is (seq, features) — TODO confirm)
    embeddings = np.array([
        tok[:length].mean(axis=0)
        for tok, length in zip(game_tokens, game_lengths)
    ])
    
    # PCA to 3D. PCA cannot return more components than samples or features,
    # so short games get fewer components and the remaining axes stay at zero.
    n_samples, n_features = embeddings.shape
    n_components = min(3, n_samples, n_features)
    coords_3d = np.zeros((n_samples, 3))
    if n_samples >= 2 and n_components >= 1:
        pca = PCA(n_components=n_components)
        coords_3d[:, :n_components] = pca.fit_transform(embeddings)
    
    # Best legal E[Q] per decision (for coloring); 0 when nothing is legal
    best_eq = np.array([
        row[legal].max() if legal.any() else 0
        for row, legal in zip(game_eq, game_legal)
    ])
    
    # Uncertainty for point sizes. The mean variance is clamped at zero so a
    # slightly negative estimate cannot turn the marker size into NaN.
    if game_var is not None:
        uncertainties = np.array([
            np.sqrt(max(0.0, row[legal].mean())) if legal.any() else 0
            for row, legal in zip(game_var, game_legal)
        ])
        point_sizes = 50 + uncertainties * 10
    else:
        point_sizes = np.full(len(indices), 80)
    
    # Normalize colors to [0, 1]; the epsilon guards a constant best_eq
    eq_norm = (best_eq - best_eq.min()) / (best_eq.max() - best_eq.min() + 1e-8)
    
    # Create figure
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    
    # Color map
    cmap = plt.cm.plasma
    
    # Fixed axis limits so the view does not rescale between frames
    margin = 0.1
    limits = [
        (coords_3d[:, axis].min() - margin, coords_3d[:, axis].max() + margin)
        for axis in range(3)
    ]
    
    def apply_axes_style():
        """Apply the dark styling, labels and limits (ax.clear() resets them)."""
        ax.set_facecolor('#0d1117')
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.fill = False
            axis.pane.set_edgecolor('#333333')
        ax.tick_params(colors='#666666', labelsize=8)
        ax.set_xlabel('PC1', color='#666666', fontsize=9)
        ax.set_ylabel('PC2', color='#666666', fontsize=9)
        ax.set_zlabel('PC3', color='#666666', fontsize=9)
        ax.set_xlim(*limits[0])
        ax.set_ylim(*limits[1])
        ax.set_zlim(*limits[2])
    
    apply_axes_style()
    
    def update(frame):
        ax.clear()
        apply_axes_style()
        
        n = frame + 1
        xs, ys, zs = coords_3d[:n, 0], coords_3d[:n, 1], coords_3d[:n, 2]
        
        # Draw trajectory line
        ax.plot(xs, ys, zs, color='#4a9eff', linewidth=2, alpha=0.6)
        
        # Draw points
        colors = cmap(eq_norm[:n])
        ax.scatter(xs, ys, zs, c=colors, s=point_sizes[:n], alpha=0.8, 
                   edgecolors='white', linewidth=0.5)
        
        # Highlight current point
        ax.scatter([xs[-1]], [ys[-1]], [zs[-1]], c='#ffcc00', s=200, 
                   marker='*', edgecolors='white', linewidth=2, zorder=10)
        
        # Rotate view by 4 degrees per frame
        ax.view_init(elev=20, azim=frame * 4)
        
        # Title (trick number assumes four decisions per trick)
        trick = (frame // 4) + 1
        ax.set_title(f'Game {game_id + 1} • Decision {frame + 1}/28 • Trick {trick}/7\n'
                    f'E[Q] = {best_eq[frame]:+.1f} pts',
                    fontsize=12, color='white', pad=10)
        
        return []
    
    # Create animation
    n_frames = len(indices)
    anim = FuncAnimation(fig, update, frames=n_frames, interval=1000//fps, blit=False)
    
    if save_path:
        print(f"Saving animation to {save_path}...")
        # Use Pillow for GIF output (no ffmpeg needed)
        writer = PillowWriter(fps=fps)
        anim.save(save_path, writer=writer, dpi=80,
                  savefig_kwargs={'facecolor': '#0d1117'})
        print(f"Saved: {save_path}")
    
    # Close this figure specifically, not whichever one happens to be current
    plt.close(fig)
    return anim


# Render game 0's trajectory and export it as a GIF (the MMS-compatible format)
create_trajectory_animation(
    game_id=0,
    save_path=f"{OUTPUT_DIR}/trajectory_game0.gif",
    fps=6,
)

## 4. 3D Belief Cloud Animation

Visualize how uncertainty collapses as the game progresses.

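The per-decision E[Q] variance (averaged over the legal actions) sets the spread of a synthetic particle "belief cloud", and ESS (effective sample size) is shown in the title,
illustrating how concentrated vs. diffuse our beliefs are about the true game state.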

In [None]:
def create_belief_cloud_animation(game_id=0, n_particles=150, save_path=None, fps=6):
    """Create a 3D belief cloud animation showing uncertainty evolution.

    For every decision of the game a synthetic Gaussian particle cloud is
    drawn. Its spread scales with the mean standard deviation of the legal
    actions' E[Q] (a fixed value of 5 when no variance is available) and it is
    centred, along z, on the mean legal E[Q]. The particles are illustrative:
    they are sampled from a seeded generator, not taken from the dataset.

    Args:
        game_id: Value of ``game_idx`` selecting the game to animate.
        n_particles: Number of particles drawn per frame.
        save_path: Optional output path; when given the animation is written
            with PillowWriter.
        fps: Frames per second for playback and for the saved file.

    Returns:
        The FuncAnimation, or None when the game has no decisions.
    """
    from matplotlib.animation import PillowWriter
    
    # Find all decisions for this game
    game_mask = (game_idx == game_id)
    indices = torch.where(game_mask)[0].numpy()
    
    if len(indices) == 0:
        print(f"No data for game {game_id}")
        return None
    
    print(f"Game {game_id}: {len(indices)} decisions")
    
    # Get uncertainty metrics
    game_var = e_q_var[indices].numpy() if e_q_var is not None else None
    game_ess = ess[indices].numpy() if ess is not None else None
    game_eq = e_q_mean[indices].numpy()
    # Cast so the row selections below are boolean masks whatever dtype the
    # dataset stored (an integer mask would otherwise act as fancy indexing)
    game_legal = legal_mask[indices].numpy().astype(bool)
    
    # Per decision: mean uncertainty and mean E[Q] over the legal actions
    uncertainties = []
    centers = []
    for i, legal in enumerate(game_legal):
        has_legal = legal.any()
        if game_var is not None and has_legal:
            # Clamp at zero so a slightly negative variance cannot yield NaN
            uncertainties.append(np.sqrt(max(0.0, game_var[i][legal].mean())))
        else:
            uncertainties.append(5)
        centers.append(game_eq[i][legal].mean() if has_legal else 0)
    uncertainties = np.array(uncertainties)
    centers = np.array(centers)
    
    # Normalize uncertainty for cloud spread: map [0, max] onto
    # [min_spread, max_spread]; the epsilon guards an all-zero input
    max_spread = 2.0
    min_spread = 0.15
    u_norm = uncertainties / (uncertainties.max() + 1e-8)
    spreads = min_spread + u_norm * (max_spread - min_spread)
    
    # Create figure
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    
    def update(frame):
        ax.clear()
        
        # Styling (re-applied every frame because ax.clear() resets it)
        ax.set_facecolor('#0d1117')
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.fill = False
            axis.pane.set_edgecolor('#333333')
        ax.tick_params(colors='#666666', labelsize=8)
        ax.set_xlabel('Belief X', color='#666666', fontsize=9)
        ax.set_ylabel('Belief Y', color='#666666', fontsize=9)
        ax.set_zlabel('E[Q]', color='#666666', fontsize=9)
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)
        ax.set_zlim(centers.min() - 5, centers.max() + 5)
        
        # Current state
        spread = spreads[frame]
        center_eq = centers[frame]
        
        # Generate particle cloud; seeding by frame keeps every frame
        # reproducible no matter in which order frames are rendered
        rng_frame = np.random.default_rng(42 + frame)
        xs = rng_frame.normal(0, spread, n_particles)
        ys = rng_frame.normal(0, spread, n_particles)
        zs = rng_frame.normal(center_eq, spread * 2, n_particles)
        
        # Weight by distance from center (simulates particle weight)
        dists = np.sqrt(xs**2 + ys**2 + (zs - center_eq)**2)
        weights = np.exp(-dists / (spread + 0.1))
        weights = weights / weights.max()
        
        # Particle colors: upper part of the Blues colormap, by weight
        colors = plt.cm.Blues(0.3 + weights * 0.7)
        sizes = 15 + weights * 50
        
        # Draw particles
        ax.scatter(xs, ys, zs, c=colors, s=sizes, alpha=0.6, edgecolors='none')
        
        # Draw center (mean belief) as a star
        ax.scatter([0], [0], [center_eq], c='#ffcc00', s=400, marker='*', 
                   edgecolors='white', linewidth=2, zorder=10)
        
        # Draw trajectory of the center so far
        if frame > 0:
            traj_z = centers[:frame+1]
            traj_x = np.zeros(frame+1)
            traj_y = np.zeros(frame+1)
            ax.plot(traj_x, traj_y, traj_z, color='#ffcc00', linewidth=3, alpha=0.8)
        
        # Rotate view by 5 degrees per frame
        ax.view_init(elev=25, azim=frame * 5 + 45)
        
        # Title (trick number assumes four decisions per trick)
        trick = (frame // 4) + 1
        ess_str = f"ESS={game_ess[frame]:.0f}" if game_ess is not None else ""
        ax.set_title(f'Belief Cloud • Decision {frame + 1}/28 • Trick {trick}/7\n'
                    f'Uncertainty: ±{uncertainties[frame]:.1f} pts  {ess_str}',
                    fontsize=12, color='white', pad=10)
        
        return []
    
    # Create animation
    n_frames = len(indices)
    anim = FuncAnimation(fig, update, frames=n_frames, interval=1000//fps, blit=False)
    
    if save_path:
        print(f"Saving animation to {save_path}...")
        writer = PillowWriter(fps=fps)
        anim.save(save_path, writer=writer, dpi=80,
                  savefig_kwargs={'facecolor': '#0d1117'})
        print(f"Saved: {save_path}")
    
    # Close this figure specifically, not whichever one happens to be current
    plt.close(fig)
    return anim


# Render game 0's belief cloud and export it as a GIF (for MMS)
create_belief_cloud_animation(
    game_id=0,
    save_path=f"{OUTPUT_DIR}/belief_cloud_game0.gif",
    fps=5,
)

## 5. Generate Gallery

Create snapshots for the five decisions with the highest mean E[Q] variance, then list every rendered file.

In [None]:
# Generate snapshots for interesting decisions
# "Interesting" here means high uncertainty: a large mean variance over the
# decision's legal actions. Skipped entirely when the dataset has no variance.
if e_q_var is not None:
    # Variance values of the legal actions, one array per decision
    legal_variances = (
        e_q_var[row].numpy()[legal_mask[row].numpy()]
        for row in range(len(tokens))
    )
    # Mean variance per decision (0 when a decision has no legal action)
    mean_var = np.array([
        values.mean() if len(values) > 0 else 0
        for values in legal_variances
    ])
    
    # argsort is ascending, so the last five entries are the most uncertain
    top_uncertain = np.argsort(mean_var)[-5:]
    print("Generating snapshots for high-uncertainty decisions...")
    for rank, idx in enumerate(top_uncertain, start=1):
        render_decision_snapshot(idx, save_path=f"{OUTPUT_DIR}/snapshot_uncertain_{rank}.png")

print("\nDone! Check the renders directory.")

In [None]:
# List all generated files
from pathlib import Path
output_path = Path(OUTPUT_DIR)

print("Generated files:")
print("=" * 50)
# One line per directory entry, sorted by path, with its size in MB
for entry in sorted(output_path.glob("*")):
    megabytes = entry.stat().st_size / (1024 * 1024)
    print(f"  {entry.name:40} {megabytes:6.2f} MB")

print("\n" + "=" * 50)
print("\n".join([
    "MMS Tips:",
    "  - GIFs under 3MB work well for MMS",
    "  - If too large, reduce fps or n_particles",
    "  - PNGs can be sent as regular images",
]))
print(f"\nFiles ready at: {OUTPUT_DIR}")