# Variance Aggregation Backup: Interactive Demo

This notebook demonstrates the **variance aggregation backup** mechanism in Bayesian BAI-MCTS.

## Key Concepts

1. **Aggregated Beliefs**: Parent nodes maintain beliefs about their best child using optimality weights
2. **Optimality Weights**: Approximate probability that each child is optimal, computed from pairwise Gaussian CDF comparisons
3. **Variance Propagation**: `agg_sigma_sq` from children becomes observation variance for parents
4. **Ensemble Effect**: Squared-weight aggregation shrinks variance roughly as 1/N (standard deviation as 1/√N) when N children carry similar weight
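
To make the ensemble effect concrete, here is a minimal sketch in plain NumPy, independent of nanozero. It applies the squared-weight sum to N children that are assumed to have equal weights, identical variance, and identical means, so the disagreement term vanishes. The helper name `equal_weight_agg_variance` is hypothetical. Under these assumptions the aggregated variance comes out as σ²/N, so the standard deviation shrinks as 1/√N.

```python
import numpy as np

def equal_weight_agg_variance(n_children, sigma_sq):
    """Sum of w^2 * sigma^2 for n equally weighted children with equal means."""
    weights = np.full(n_children, 1.0 / n_children)
    child_vars = np.full(n_children, sigma_sq)
    return float(np.sum(weights**2 * child_vars))

# Quadrupling the number of children quarters the variance and halves the std.
for n in (1, 4, 16):
    var = equal_weight_agg_variance(n, sigma_sq=0.25)
    print(f"N={n:2d}: variance={var:.4f}, std={np.sqrt(var):.4f}")
```

In a real search the weights are rarely equal, so the reduction is usually smaller than this idealized case suggests.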

**Setup:** In Google Colab, select `Runtime > Change runtime type > GPU` for faster model inference.

In [None]:
# Clone repository
!git clone https://github.com/caldred/nanozero.git 2>/dev/null || echo "Already cloned"
%cd nanozero

In [None]:
import sys
sys.path.insert(0, '.')

import math
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Import from nanozero
from nanozero.bayesian_mcts import BayesianNode, BayesianMCTS, normal_cdf
from nanozero.config import BayesianMCTSConfig
from nanozero.game import get_game

print("Imports successful!")

---
## Part 1: BayesianNode Basics

A `BayesianNode` maintains:
- `mu`, `sigma_sq`: Gaussian belief about this node's value
- `agg_mu`, `agg_sigma_sq`: Aggregated belief from children (expected value of best child)

In [None]:
def print_node(node, name="Node", indent=0):
    """Pretty print a BayesianNode."""
    prefix = "  " * indent
    print(f"{prefix}{name}:")
    print(f"{prefix}  belief: μ={node.mu:.4f}, σ²={node.sigma_sq:.4f} (σ={math.sqrt(node.sigma_sq):.4f})")
    if node.agg_mu is not None:
        print(f"{prefix}  aggregated: μ_agg={node.agg_mu:.4f}, σ²_agg={node.agg_sigma_sq:.4f}")
    print(f"{prefix}  precision: {node.precision():.2f}")

# Create a simple node
node = BayesianNode(prior=0.3, mu=0.5, sigma_sq=0.25)
print_node(node, "Initial Node")

# Update with an observation
print("\n--- Updating with observation value=0.8, obs_var=0.1 ---\n")
node.update(value=0.8, obs_var=0.1)
print_node(node, "After Update")

---
## Part 2: Aggregation with Multiple Children

When a node has children, `aggregate_children()` computes:
1. **Optimality scores** via pairwise Gaussian CDF comparisons
2. **Weights** = normalized scores (soft-pruned)
3. **Aggregated mean** = weighted average of child means
4. **Aggregated variance** = `Σ w²[σ² + (μ - μ_agg)²]`
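
The last two steps can be sketched in a few lines of NumPy. This illustrates the formula above and is not nanozero's implementation: the helper `aggregate_moments` is hypothetical, the weights are supplied by hand, and the means are assumed to be expressed from the parent's perspective already.

```python
import numpy as np

def aggregate_moments(weights, mus, sigma_sqs):
    """Return (agg_mu, agg_sigma_sq) for given weights and child beliefs."""
    weights = np.asarray(weights, dtype=float)
    mus = np.asarray(mus, dtype=float)
    sigma_sqs = np.asarray(sigma_sqs, dtype=float)
    agg_mu = np.sum(weights * mus)
    # Each child contributes its own variance plus its squared
    # disagreement with the aggregated mean, scaled by w^2.
    agg_sigma_sq = np.sum(weights**2 * (sigma_sqs + (mus - agg_mu) ** 2))
    return float(agg_mu), float(agg_sigma_sq)

# Two children with equal weight and equal variance but different means.
agg_mu, agg_sigma_sq = aggregate_moments([0.5, 0.5], [0.8, 0.2], [0.04, 0.04])
print(f"agg_mu={agg_mu:.3f}, agg_sigma_sq={agg_sigma_sq:.3f}")
```

Without the disagreement term this example would give 0.02; the gap between the two child means raises it to 0.065.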

In [None]:
def create_tree_with_children(child_beliefs):
    """
    Create a parent node with children having specified beliefs.
    
    child_beliefs: list of (mu, sigma_sq) tuples
    """
    parent = BayesianNode()
    for i, (mu, sigma_sq) in enumerate(child_beliefs):
        parent.children[i] = BayesianNode(prior=1.0/len(child_beliefs), mu=mu, sigma_sq=sigma_sq)
    return parent

# Scenario 1: Clear winner (child 0 much better)
print("=" * 60)
print("Scenario 1: Clear Winner")
print("=" * 60)
# Note: children store values from child's perspective (opponent)
# So a child with mu=-0.8 means the parent expects value +0.8 from that action
parent1 = create_tree_with_children([
    (-0.8, 0.04),   # Child 0: looks great for parent (mu=+0.8 from parent view)
    (-0.2, 0.04),   # Child 1: mediocre
    (-0.1, 0.04),   # Child 2: mediocre
])

print("\nBefore aggregation:")
for a, c in parent1.children.items():
    print(f"  Child {a}: μ={c.mu:.3f}, σ²={c.sigma_sq:.3f} → parent view: {-c.mu:.3f}")

parent1.aggregate_children(prune_threshold=0.01)
print(f"\nAfter aggregation:")
print(f"  agg_mu={parent1.agg_mu:.4f} (expected value of best child)")
print(f"  agg_sigma_sq={parent1.agg_sigma_sq:.6f} (uncertainty about best)")

In [None]:
# Scenario 2: Uncertain (children have similar means but high variance)
print("\n" + "=" * 60)
print("Scenario 2: High Uncertainty")
print("=" * 60)

parent2 = create_tree_with_children([
    (-0.4, 0.25),   # High variance
    (-0.35, 0.25),
    (-0.3, 0.25),
])

print("\nBefore aggregation (high variance children):")
for a, c in parent2.children.items():
    print(f"  Child {a}: μ={c.mu:.3f}, σ²={c.sigma_sq:.3f}")

parent2.aggregate_children(prune_threshold=0.01)
print(f"\nAfter aggregation:")
print(f"  agg_mu={parent2.agg_mu:.4f}")
print(f"  agg_sigma_sq={parent2.agg_sigma_sq:.6f} (higher than Scenario 1!)")

In [None]:
# Scenario 3: Visualize optimality weights
print("\n" + "=" * 60)
print("Scenario 3: Optimality Weights Visualization")
print("=" * 60)

def compute_optimality_weights(child_beliefs, prune_threshold=0.01):
    """Compute optimality weights for visualization."""
    # Get beliefs from parent's perspective
    mus = np.array([-mu for mu, _ in child_beliefs])
    sigma_sqs = np.array([s for _, s in child_beliefs])
    n = len(child_beliefs)
    
    sorted_idx = np.argsort(mus)[::-1]
    leader_idx = sorted_idx[0]
    challenger_idx = sorted_idx[1]
    
    scores = np.zeros(n)
    mu_L, sigma_sq_L = mus[leader_idx], sigma_sqs[leader_idx]
    mu_C, sigma_sq_C = mus[challenger_idx], sigma_sqs[challenger_idx]
    
    for i in range(n):
        if i == leader_idx:
            diff = mu_L - mu_C
            std = math.sqrt(sigma_sq_L + sigma_sq_C)
        else:
            diff = mus[i] - mu_L
            std = math.sqrt(sigma_sqs[i] + sigma_sq_L)
        scores[i] = normal_cdf(diff / std) if std > 1e-10 else (1.0 if diff > 0 else 0.0)
    
    scores[scores < prune_threshold] = 0.0
    weights = scores / (scores.sum() + 1e-10)
    return mus, scores, weights

# Different scenarios
scenarios = {
    "Clear Winner": [(-0.8, 0.04), (-0.2, 0.04), (-0.1, 0.04)],
    "Close Race": [(-0.52, 0.04), (-0.50, 0.04), (-0.48, 0.04)],
    "High Variance": [(-0.6, 0.25), (-0.5, 0.25), (-0.4, 0.25)],
    "Mixed": [(-0.7, 0.01), (-0.5, 0.16), (-0.3, 0.04)],
}

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for ax, (name, beliefs) in zip(axes, scenarios.items()):
    mus, scores, weights = compute_optimality_weights(beliefs)
    
    x = np.arange(len(beliefs))
    width = 0.35
    
    bars1 = ax.bar(x - width/2, scores, width, label='Optimality Score', alpha=0.8)
    bars2 = ax.bar(x + width/2, weights, width, label='Normalized Weight', alpha=0.8)
    
    ax.set_xlabel('Action')
    ax.set_ylabel('Probability')
    ax.set_title(f'{name}\n(μ from parent view: {[f"{m:.2f}" for m in mus]})')
    ax.set_xticks(x)
    ax.set_xticklabels([f'Action {i}' for i in range(len(beliefs))])
    ax.legend()
    ax.set_ylim(0, 1.1)
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

---
## Part 3: Backup with Variance Propagation

The key insight: **`agg_sigma_sq` from children becomes `obs_var` for parents.**

This means:
- Uncertain subtrees contribute observations with higher variance
- Confident subtrees contribute observations with lower variance
- Parent precision reflects true epistemic uncertainty about the best action
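
A small sketch of why this matters, assuming `update` performs a standard conjugate Gaussian (precision-weighted) update. That is an assumption about nanozero, not something this notebook confirms, and `gaussian_update` below is a hypothetical stand-in. The same observed value moves the belief much further when it arrives with low variance.

```python
def gaussian_update(mu, sigma_sq, value, obs_var):
    """Conjugate Gaussian update: combine prior and observation by precision."""
    prior_precision = 1.0 / sigma_sq
    obs_precision = 1.0 / obs_var
    post_sigma_sq = 1.0 / (prior_precision + obs_precision)
    post_mu = post_sigma_sq * (prior_precision * mu + obs_precision * value)
    return post_mu, post_sigma_sq

# Same prior and same observed value, but different observation variance.
prior_mu, prior_sigma_sq = 0.0, 0.25
for label, obs_var in [("confident subtree", 0.05), ("uncertain subtree", 0.5)]:
    mu, sigma_sq = gaussian_update(prior_mu, prior_sigma_sq, value=0.5, obs_var=obs_var)
    print(f"{label}: obs_var={obs_var:.2f} -> mu={mu:.3f}, sigma_sq={sigma_sq:.3f}")
```

The confident observation pulls the mean most of the way to 0.5 and sharply reduces the variance, while the uncertain one leaves the belief close to the prior.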

In [None]:
def simulate_backup_chain(depths=3, children_per_node=3, n_backups=20):
    """
    Simulate backup chain and track variance propagation.
    
    Creates a linear chain of nodes (like a search path) and
    repeatedly backs up values, tracking how variance changes.
    """
    sigma_0 = 0.5
    base_obs_var = 0.25
    
    # Create chain of nodes (root -> level1 -> level2 -> leaf)
    nodes = []
    for d in range(depths):
        node = BayesianNode(mu=0.0, sigma_sq=sigma_0**2)
        # Add children
        for i in range(children_per_node):
            child_mu = np.random.normal(0, 0.3)
            node.children[i] = BayesianNode(mu=child_mu, sigma_sq=sigma_0**2)
        node.aggregate_children()
        nodes.append(node)
    
    # Track metrics over backups
    history = {
        'backup': [],
        'root_agg_mu': [],
        'root_agg_sigma_sq': [],
        'level1_agg_sigma_sq': [],
        'obs_var_at_root': [],
    }
    
    for backup_num in range(n_backups):
        # Simulate a value coming from the leaf
        leaf_value = np.random.normal(0.5, 0.3)  # Pretend leaf says +0.5 ish
        
        # Backup through the chain
        value = leaf_value
        obs_var = base_obs_var
        
        for d in range(depths - 1, -1, -1):
            node = nodes[d]
            # Update a random child (simulating MCTS path)
            action = np.random.randint(0, children_per_node)
            node.children[action].update(value, obs_var, min_var=1e-6)
            
            # Recompute aggregation
            node.aggregate_children()
            
            # Propagate
            value = -node.agg_mu
            obs_var = node.agg_sigma_sq
        
        # Record
        history['backup'].append(backup_num)
        history['root_agg_mu'].append(nodes[0].agg_mu)
        history['root_agg_sigma_sq'].append(nodes[0].agg_sigma_sq)
        if depths > 1:
            history['level1_agg_sigma_sq'].append(nodes[1].agg_sigma_sq)
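        # After the loop, obs_var holds the root's agg_sigma_sq,
        # i.e. the variance the root would pass further up the tree.
        history['obs_var_at_root'].append(obs_var)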
    
    return history

# Run simulation
np.random.seed(42)
history = simulate_backup_chain(depths=3, children_per_node=3, n_backups=50)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Root aggregated mean
axes[0].plot(history['backup'], history['root_agg_mu'], 'b-', linewidth=2)
axes[0].set_xlabel('Backup #')
axes[0].set_ylabel('Root agg_mu')
axes[0].set_title('Root Aggregated Mean Over Backups')
axes[0].grid(True, alpha=0.3)

# Plot 2: Aggregated variance at different levels
axes[1].plot(history['backup'], history['root_agg_sigma_sq'], 'b-', linewidth=2, label='Root')
axes[1].plot(history['backup'], history['level1_agg_sigma_sq'], 'r--', linewidth=2, label='Level 1')
axes[1].set_xlabel('Backup #')
axes[1].set_ylabel('agg_sigma_sq')
axes[1].set_title('Aggregated Variance at Different Levels')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot 3: Observation variance propagated to root
axes[2].plot(history['backup'], history['obs_var_at_root'], 'g-', linewidth=2)
axes[2].axhline(y=0.25, color='gray', linestyle='--', label='Initial obs_var')
axes[2].set_xlabel('Backup #')
axes[2].set_ylabel('obs_var at root level')
axes[2].set_title('Propagated Observation Variance')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal root agg_sigma_sq: {history['root_agg_sigma_sq'][-1]:.6f}")
print(f"Initial root agg_sigma_sq: {history['root_agg_sigma_sq'][0]:.6f}")
print(f"Reduction factor: {history['root_agg_sigma_sq'][0] / history['root_agg_sigma_sq'][-1]:.2f}x")

---
## Part 4: Full MCTS Search Visualization

Let's run actual MCTS on TicTacToe and visualize the beliefs.

In [None]:
import torch

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load game and create a simple model
game = get_game('tictactoe')
print(f"Game: {game.config.name}")
print(f"Action size: {game.config.action_size}")

In [None]:
from nanozero.model import AlphaZeroTransformer
from nanozero.config import get_model_config

# Create a small model (untrained, just for demo)
model_config = get_model_config(game.config, n_layer=2)
model = AlphaZeroTransformer(model_config).to(device)
model.eval()

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Create BayesianMCTS
config = BayesianMCTSConfig(
    num_simulations=100,
    sigma_0=0.5,
    obs_var=0.25,
    prune_threshold=0.01,
    early_stopping=False,  # Disable for visualization
)

mcts = BayesianMCTS(game, config)
print("BayesianMCTS created with variance aggregation!")

In [None]:
def visualize_mcts_tree(root, game, max_depth=2, title="MCTS Tree"):
    """Visualize the MCTS tree with beliefs."""
    
    def format_belief(node):
        agg_str = ""
        if node.agg_mu is not None:
            agg_str = f"\nagg: μ={node.agg_mu:.2f}, σ²={node.agg_sigma_sq:.3f}"
        return f"μ={node.mu:.2f}, σ²={node.sigma_sq:.3f}{agg_str}"
    
    def print_tree(node, depth=0, action=None, prefix=""):
        if depth > max_depth:
            return
        
        indent = "  " * depth
        action_str = f"Action {action}: " if action is not None else ""
        print(f"{indent}{prefix}{action_str}{format_belief(node)}")
        
        for a, child in sorted(node.children.items()):
            print_tree(child, depth + 1, a, "└─ ")
    
    print(f"\n{'='*60}")
    print(title)
    print('=' * 60)
    print_tree(root)

# Run search on empty board
state = game.initial_state()
print("TicTacToe board (X to move):")
print(game.display(state))

# Manually expand the root (simulations are run in a later cell)
root = BayesianNode()
mcts._expand(root, state, model, device)

print(f"\nAfter expansion (before simulations):")
print(f"Root aggregated belief: μ={root.agg_mu:.4f}, σ²={root.agg_sigma_sq:.6f}")
print(f"\nChildren beliefs (from root's perspective):")
for a, c in sorted(root.children.items()):
    parent_view_mu = -c.mu
    print(f"  Action {a}: μ={c.mu:.3f} (parent sees: {parent_view_mu:.3f}), σ²={c.sigma_sq:.3f}")

In [None]:
# Run simulations and track convergence
state = game.initial_state()
root = BayesianNode()

# Expand root
mcts._expand(root, state, model, device)

# Track metrics
sim_history = {
    'sim': [],
    'agg_mu': [],
    'agg_sigma_sq': [],
    'child_sigma_sqs': {a: [] for a in root.children.keys()},
}

# Record initial state
sim_history['sim'].append(0)
sim_history['agg_mu'].append(root.agg_mu)
sim_history['agg_sigma_sq'].append(root.agg_sigma_sq)
for a, c in root.children.items():
    sim_history['child_sigma_sqs'][a].append(c.sigma_sq)

# Run simulations
for sim in range(1, 101):
    mcts._run_simulation(root, state, model, device)
    
    sim_history['sim'].append(sim)
    sim_history['agg_mu'].append(root.agg_mu)
    sim_history['agg_sigma_sq'].append(root.agg_sigma_sq)
    for a, c in root.children.items():
        sim_history['child_sigma_sqs'][a].append(c.sigma_sq)

print(f"Ran 100 simulations.")
print(f"Final root agg_sigma_sq: {sim_history['agg_sigma_sq'][-1]:.6f}")
print(f"Initial root agg_sigma_sq: {sim_history['agg_sigma_sq'][0]:.6f}")

In [None]:
# Visualize convergence
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Root aggregated mean
axes[0].plot(sim_history['sim'], sim_history['agg_mu'], 'b-', linewidth=2)
axes[0].set_xlabel('Simulation #')
axes[0].set_ylabel('Root agg_mu')
axes[0].set_title('Root Value Estimate')
axes[0].grid(True, alpha=0.3)

# Plot 2: Root aggregated variance
axes[1].plot(sim_history['sim'], sim_history['agg_sigma_sq'], 'r-', linewidth=2)
axes[1].set_xlabel('Simulation #')
axes[1].set_ylabel('Root agg_sigma_sq')
axes[1].set_title('Root Variance (Uncertainty)')
axes[1].grid(True, alpha=0.3)

# Plot 3: Child variances
colors = plt.cm.tab10(np.linspace(0, 1, 9))
for i, (a, variances) in enumerate(sim_history['child_sigma_sqs'].items()):
    axes[2].plot(sim_history['sim'], variances, color=colors[a], 
                 linewidth=1.5, alpha=0.7, label=f'Action {a}')
axes[2].set_xlabel('Simulation #')
axes[2].set_ylabel('Child sigma_sq')
axes[2].set_title('Child Variances')
axes[2].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Final policy
policy = mcts._get_policy(root)

print("\nFinal Policy (probability each action is optimal):")
print("=" * 50)

# Display as 3x3 grid
print("\nTicTacToe grid (policy values):")
for row in range(3):
    row_str = "  "
    for col in range(3):
        action = row * 3 + col
        row_str += f" {policy[action]:.2f} "
    print(row_str)

print(f"\nBest action: {np.argmax(policy)} (probability: {np.max(policy):.3f})")

# Show child beliefs
print("\nChild beliefs after search:")
for a in sorted(root.children.keys()):
    c = root.children[a]
    parent_view = -c.mu
    print(f"  Action {a}: value={parent_view:.3f}, variance={c.sigma_sq:.4f}, policy={policy[a]:.3f}")

---
## Part 5: Comparing Prune Thresholds

The `prune_threshold` parameter controls soft-pruning of low-probability children.
Higher threshold = more aggressive pruning = sharper policies but potentially more bias.
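
As a rough illustration of that trade-off, the sketch below applies the same threshold-then-normalize step used in `compute_optimality_weights` to a hand-picked score vector. The scores and the helper names `soft_prune` and `entropy` are made up for the example; only the zeroing and renormalization mirror the notebook.

```python
import numpy as np

def soft_prune(scores, prune_threshold):
    """Zero scores below the threshold, then renormalize to sum to one."""
    scores = np.asarray(scores, dtype=float).copy()
    scores[scores < prune_threshold] = 0.0
    return scores / (scores.sum() + 1e-10)

def entropy(p):
    """Shannon entropy in nats, with a small epsilon to avoid log(0)."""
    return float(-np.sum(p * np.log(p + 1e-8)))

scores = [0.60, 0.30, 0.08, 0.02]  # illustrative optimality scores
for threshold in (0.0, 0.05, 0.1):
    weights = soft_prune(scores, threshold)
    print(f"threshold={threshold}: weights={np.round(weights, 3)}, "
          f"entropy={entropy(weights):.3f}")
```

Raising the threshold removes the weakest actions from the distribution, so the surviving actions gain weight and the entropy falls.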

In [None]:
def run_search_with_threshold(threshold, n_sims=50):
    """Run MCTS with a specific prune threshold."""
    config = BayesianMCTSConfig(
        num_simulations=n_sims,
        prune_threshold=threshold,
        early_stopping=False,
    )
    local_mcts = BayesianMCTS(game, config)
    
    state = game.initial_state()
    policy = local_mcts.search(state[np.newaxis, ...], model)[0]
    
    return policy

thresholds = [0.0, 0.01, 0.05, 0.1, 0.2]
policies_by_threshold = {}

for thresh in thresholds:
    policies_by_threshold[thresh] = run_search_with_threshold(thresh)

# Visualize
fig, axes = plt.subplots(1, len(thresholds), figsize=(15, 3))

for ax, thresh in zip(axes, thresholds):
    policy = policies_by_threshold[thresh]
    im = ax.imshow(policy.reshape(3, 3), cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'threshold={thresh}')
    ax.set_xticks([])
    ax.set_yticks([])
    
    # Annotate with values
    for i in range(3):
        for j in range(3):
            ax.text(j, i, f'{policy[i*3+j]:.2f}', ha='center', va='center', 
                   color='white' if policy[i*3+j] > 0.5 else 'black', fontsize=10)

plt.suptitle('Policy Heatmaps by Prune Threshold', fontsize=14)
plt.tight_layout()
plt.show()

# Entropy comparison
print("\nPolicy entropy by threshold:")
for thresh in thresholds:
    policy = policies_by_threshold[thresh]
    ent = -np.sum(policy * np.log(policy + 1e-8))
    print(f"  threshold={thresh}: entropy={ent:.3f}, max_prob={np.max(policy):.3f}")

---
## Part 6: Interactive Exploration

Play with different parameters and see how they affect the search.

In [None]:
# Interactive parameter exploration
def explore_parameters(sigma_0=0.5, obs_var=0.25, prune_threshold=0.01, n_sims=50):
    """Run search with custom parameters and display results."""
    config = BayesianMCTSConfig(
        num_simulations=n_sims,
        sigma_0=sigma_0,
        obs_var=obs_var,
        prune_threshold=prune_threshold,
        early_stopping=False,
    )
    local_mcts = BayesianMCTS(game, config)
    
    state = game.initial_state()
    
    # Manual search to access root
    root = BayesianNode()
    local_mcts._expand(root, state, model, device)
    
    for _ in range(n_sims):
        local_mcts._run_simulation(root, state, model, device)
    
    policy = local_mcts._get_policy(root)
    
    print(f"Parameters: sigma_0={sigma_0}, obs_var={obs_var}, prune_threshold={prune_threshold}, n_sims={n_sims}")
    print(f"Root aggregated: μ={root.agg_mu:.4f}, σ²={root.agg_sigma_sq:.6f}")
    print(f"Policy entropy: {-np.sum(policy * np.log(policy + 1e-8)):.3f}")
    print(f"Best action: {np.argmax(policy)} (prob={np.max(policy):.3f})")
    print("\nPolicy grid:")
    for row in range(3):
        print("  " + " ".join(f"{policy[row*3+col]:.2f}" for col in range(3)))
    
    return root, policy

# Try different configurations
print("="*60)
print("Low prior variance (sigma_0=0.1):")
print("="*60)
explore_parameters(sigma_0=0.1)

print("\n" + "="*60)
print("High prior variance (sigma_0=1.0):")
print("="*60)
explore_parameters(sigma_0=1.0)

print("\n" + "="*60)
print("Low observation variance (obs_var=0.05):")
print("="*60)
explore_parameters(obs_var=0.05)

print("\n" + "="*60)
print("More simulations (n_sims=200):")
print("="*60)
explore_parameters(n_sims=200)

---
## Summary

### Key Takeaways

1. **Variance Aggregation** computes parent beliefs from ALL children, not just visited ones

2. **Optimality Weights** use pairwise Gaussian CDF comparisons against the leader: `P(child > leader)` for each non-leading child, and `P(leader > runner-up)` for the leader itself

3. **Variance Formula**: `σ²_agg = Σ w²[σ² + (μ - μ_agg)²]`
   - Squared weights → ensemble effect (variance shrinks as 1/N, standard deviation as 1/√N, for N equally weighted children)
   - Disagreement term → uncertain when children disagree

4. **Variance Propagation**: `agg_sigma_sq` becomes `obs_var` for parent level
   - Uncertain subtrees contribute high-variance observations
   - Confident subtrees contribute low-variance observations

5. **Policy Extraction** uses optimality weights directly (deterministic, fast)

### Parameters That Matter

| Parameter | Effect |
|-----------|--------|
| `sigma_0` | Prior uncertainty - higher = slower convergence but more exploration |
| `obs_var` | Leaf observation noise - higher = slower belief updates |
| `prune_threshold` | Soft-prune weak children - higher = sharper policies |
| `n_sims` (`num_simulations` in `BayesianMCTSConfig`) | More simulations = lower variance = more confident policies |

In [None]:
print("Notebook complete! Feel free to modify and re-run cells to explore.")