# Weight Power Ablation: Variance Aggregation in Bayesian MCTS

This notebook tests different powers for the weight term in variance aggregation:

```
agg_sigma_sq = sum(w^p * (sigma_sq + disagreement))
```

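Here `w` is a child's normalized optimality weight, `sigma_sq` is that child's variance, and `disagreement` is the squared gap between the child's mean and the aggregated mean. The current implementation uses `p=2` (squared weights), which treats the children's value estimates as independent. Since those estimates are likely to be somewhat correlated, the best-performing power may lie between 1 and 2.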

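**Hypothesis:** If the children's value estimates are:
- Perfectly independent: p=2 is optimal (the variance of a weighted sum adds child variances with squared weights)
- Perfectly correlated: p=1 is optimal (the weighted sum's standard deviation is the weighted average of the child standard deviations)
- Partially correlated: the optimal p lies between 1 and 2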
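
The hypothesis can be checked numerically for a weighted sum of jointly Gaussian child values. The sketch below is not part of the repository: the helper name `weighted_sum_variance` and the equicorrelated model (one correlation `rho` shared by every pair of children) are assumptions made for illustration, and it looks only at the `sigma_sq` part of the formula, not the disagreement term.

```python
import numpy as np

def weighted_sum_variance(weights, sigma_sqs, rho):
    """Exact Var(sum_i w_i X_i) when every pair of children has correlation rho."""
    w = np.asarray(weights, dtype=float)
    sigma = np.sqrt(np.asarray(sigma_sqs, dtype=float))
    # Correlation matrix: 1 on the diagonal, rho everywhere else.
    corr = np.full((len(w), len(w)), rho)
    np.fill_diagonal(corr, 1.0)
    cov = corr * np.outer(sigma, sigma)
    return float(w @ cov @ w)

weights = np.array([0.5, 0.3, 0.2])
sigma_sqs = np.array([1.0, 1.0, 1.0])  # equal child variances (assumed)

var_p2 = float(np.sum(weights ** 2 * sigma_sqs))  # p=2 formula
var_p1 = float(np.sum(weights * sigma_sqs))       # p=1 formula

for rho in (0.0, 0.5, 1.0):
    exact = weighted_sum_variance(weights, sigma_sqs, rho)
    print(f"rho={rho:.1f}: exact={exact:.3f}  p=2 formula={var_p2:.3f}  p=1 formula={var_p1:.3f}")
```

With equal child variances, the exact value matches the p=2 formula at `rho=0` and the p=1 formula at `rho=1`, and falls between them for intermediate correlation. With unequal variances the p=1 formula is an upper bound on the fully correlated case rather than an exact match, because `(sum(w * sigma))^2 <= sum(w * sigma^2)` when the weights sum to 1.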

**Setup:** In Colab, select `Runtime > Change runtime type > GPU` for best performance.

In [None]:
# Check GPU
import torch
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install Rust toolchain
!curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
import os
os.environ["PATH"] = f"{os.environ['HOME']}/.cargo/bin:" + os.environ["PATH"]

# Verify Rust installation
!rustc --version

In [None]:
# Clone repository
!git clone https://github.com/caldred/nanozero.git
%cd nanozero

# Install Python dependencies
!pip install -q numpy scipy maturin

# Build and install Rust extension
%cd nanozero-mcts-rs
!maturin build --release
!pip install target/wheels/nanozero_mcts_rs-*.whl
%cd ..

# Verify Rust backend is available
!python -c "from nanozero.game import RUST_AVAILABLE; print(f'Rust backend available: {RUST_AVAILABLE}')"

---
## Train Connect4 Model

First, train a model we'll use for all ablation tests.

In [None]:
# Train Connect4 - A100 optimized settings
# ~5-10 minutes on A100
!python -m scripts.train \
    --game=connect4 \
    --n_layer=4 \
    --num_iterations=150 \
    --games_per_iteration=64 \
    --training_steps=200 \
    --mcts_simulations=100 \
    --batch_size=256 \
    --buffer_size=100000 \
    --parallel_games=128 \
    --eval_interval=25

---
## Modified Bayesian MCTS with Configurable Weight Power

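Define standalone variants of the node and search classes (`BayesianNodeWithPower` and `BayesianMCTSWithPower`) that accept the weight power `p` used in variance aggregation. They do not inherit from the repository's classes, so the ablation leaves the original implementation untouched.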
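
To see what the power changes in isolation, the variance step can be pulled out of the node class. The function name `aggregate_variance` and the example numbers are illustrative assumptions; the arithmetic follows the formula at the top of this notebook.

```python
import numpy as np

def aggregate_variance(weights, mus, sigma_sqs, weight_power):
    """Return (agg_mu, agg_sigma_sq) for normalized weights and child beliefs."""
    weights = np.asarray(weights, dtype=float)
    mus = np.asarray(mus, dtype=float)
    sigma_sqs = np.asarray(sigma_sqs, dtype=float)
    # Weighted mean of the children, then squared gap of each child to it.
    agg_mu = float(np.sum(weights * mus))
    disagreement = (mus - agg_mu) ** 2
    agg_sigma_sq = float(np.sum(weights ** weight_power * (sigma_sqs + disagreement)))
    return agg_mu, agg_sigma_sq

# Made-up child beliefs, already expressed from the parent's perspective.
weights = [0.6, 0.3, 0.1]
mus = [0.2, 0.0, -0.4]
sigma_sqs = [0.5, 0.8, 1.0]

for p in (1.0, 1.5, 2.0):
    agg_mu, agg_sigma_sq = aggregate_variance(weights, mus, sigma_sqs, p)
    print(f"p={p:.1f}: agg_mu={agg_mu:.3f}, agg_sigma_sq={agg_sigma_sq:.3f}")
```

Because every normalized weight is at most 1, raising `p` can only shrink each term. The aggregated variance is therefore non-increasing in `p`: `p=1` gives the widest parent belief and `p=2` the narrowest, while the aggregated mean does not depend on `p` at all.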

In [None]:
import math
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple
import sys
sys.path.insert(0, '.')

from nanozero.config import BayesianMCTSConfig
from nanozero.game import Game
from nanozero.mcts import TranspositionTable


def normal_cdf(x: float) -> float:
    """Standard normal CDF using error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


class BayesianNodeWithPower:
    """
    MCTS tree node with Gaussian belief over value.
    Modified to support configurable weight power in variance aggregation.
    """

    def __init__(self, prior: float = 0.0, mu: float = 0.0, sigma_sq: float = 1.0):
        self.prior = prior
        self.mu = mu
        self.sigma_sq = sigma_sq
        self.children: Dict[int, 'BayesianNodeWithPower'] = {}
        self.agg_mu: Optional[float] = None
        self.agg_sigma_sq: Optional[float] = None
        self.visits: int = 0

    def expanded(self) -> bool:
        return len(self.children) > 0

    def sample(self) -> float:
        return np.random.normal(self.mu, math.sqrt(self.sigma_sq))

    def update(self, value: float, obs_var: float, min_var: float = 1e-6) -> None:
        """Conjugate Gaussian update: precisions add, means combine precision-weighted."""
        precision_prior = 1.0 / max(self.sigma_sq, min_var)
        precision_obs = 1.0 / max(obs_var, min_var)
        new_precision = precision_prior + precision_obs
        self.mu = (precision_prior * self.mu + precision_obs * value) / new_precision
        self.sigma_sq = max(1.0 / new_precision, min_var)

    def precision(self) -> float:
        # Guard against division by zero if an aggregated variance collapses to 0.
        return 1.0 / max(self.sigma_sq, 1e-12)

    def aggregate_children(
        self, 
        prune_threshold: float = 0.01, 
        visited_only: bool = False,
        weight_power: float = 2.0  # NEW PARAMETER
    ) -> None:
        """
        Compute aggregated belief from children using optimality weights.
        
        Args:
            prune_threshold: Children with P(optimal) < threshold get weight 0
            visited_only: If True, only aggregate children that have been visited
            weight_power: Power to raise weights to in variance aggregation (1 to 2)
        """
        if not self.children:
            return

        if visited_only:
            children = [c for c in self.children.values() if c.visits > 0]
            if not children:
                return
        else:
            children = list(self.children.values())

        n = len(children)

        if n == 1:
            child = children[0]
            self.agg_mu = -child.mu
            self.agg_sigma_sq = child.sigma_sq
            return

        # Get child beliefs from parent's perspective (negate child values)
        mus = np.array([-c.mu for c in children])
        sigma_sqs = np.array([c.sigma_sq for c in children])

        # Find leader and challenger by mean
        sorted_idx = np.argsort(mus)[::-1]
        leader_idx = sorted_idx[0]
        challenger_idx = sorted_idx[1]

        # Compute optimality scores via pairwise Gaussian CDF comparisons
        scores = np.zeros(n)
        mu_L, sigma_sq_L = mus[leader_idx], sigma_sqs[leader_idx]
        mu_C, sigma_sq_C = mus[challenger_idx], sigma_sqs[challenger_idx]

        for i in range(n):
            if i == leader_idx:
                diff = mu_L - mu_C
                std = math.sqrt(sigma_sq_L + sigma_sq_C)
            else:
                diff = mus[i] - mu_L
                std = math.sqrt(sigma_sqs[i] + sigma_sq_L)

            if std > 1e-10:
                scores[i] = normal_cdf(diff / std)
            else:
                scores[i] = 1.0 if diff > 0 else 0.0

        # Soft prune and normalize
        scores[scores < prune_threshold] = 0.0
        total = scores.sum()
        if total < 1e-10:
            weights = np.ones(n) / n
        else:
            weights = scores / total

        # Aggregated mean (weighted average of children)
        self.agg_mu = float(np.sum(weights * mus))

        # Aggregated variance with CONFIGURABLE weight power
        disagreement = (mus - self.agg_mu) ** 2
        self.agg_sigma_sq = float(np.sum(weights ** weight_power * (sigma_sqs + disagreement)))


class BayesianMCTSWithPower:
    """
    Bayesian MCTS with configurable weight power in variance aggregation.
    """

    def __init__(
        self,
        game: Game,
        config: BayesianMCTSConfig,
        weight_power: float = 2.0,  # NEW PARAMETER
        use_transposition_table: bool = True
    ):
        self.game = game
        self.config = config
        self.weight_power = weight_power
        self.use_tt = use_transposition_table
        self.tt = TranspositionTable(game) if use_transposition_table else None

    def clear_cache(self):
        if self.tt:
            self.tt.clear()

    def search(
        self,
        states: np.ndarray,
        model: torch.nn.Module,
        num_simulations: Optional[int] = None
    ) -> np.ndarray:
        if num_simulations is None:
            num_simulations = self.config.num_simulations

        num_states = states.shape[0]
        device = next(model.parameters()).device
        policies = np.zeros((num_states, self.game.config.action_size), dtype=np.float32)

        # Handle terminal states
        non_terminal_indices = []
        non_terminal_states = []
        for i in range(num_states):
            if self.game.is_terminal(states[i]):
                legal = self.game.legal_actions(states[i])
                if legal:
                    for a in legal:
                        policies[i, a] = 1.0 / len(legal)
            else:
                non_terminal_indices.append(i)
                non_terminal_states.append(states[i])

        if not non_terminal_states:
            return policies

        # Batch expand roots
        non_terminal_states_arr = np.stack(non_terminal_states)
        roots, _ = self._batch_expand_roots(non_terminal_states_arr, model, device)

        # Simulation loop
        for sim in range(num_simulations):
            leaves_to_expand = []
            terminal_backups = []
            expansion_backups = []

            for local_idx in range(len(roots)):
                root = roots[local_idx]
                state = non_terminal_states[local_idx]
                leaf_node, search_path, leaf_state, is_terminal = self._select_to_leaf(root, state)

                if is_terminal:
                    value = self.game.terminal_reward(leaf_state)
                    terminal_backups.append((local_idx, search_path, value))
                elif not leaf_node.expanded():
                    leaves_to_expand.append((local_idx, leaf_node, leaf_state))
                    expansion_backups.append((local_idx, search_path))

            # Batch expand
            if leaves_to_expand:
                nodes = [item[1] for item in leaves_to_expand]
                leaf_states = [item[2] for item in leaves_to_expand]
                values = self._batch_expand_leaves(nodes, leaf_states, model, device)

                for (local_idx, search_path), value in zip(expansion_backups, values):
                    self._backup(search_path, value)

            for local_idx, search_path, value in terminal_backups:
                self._backup(search_path, value)

        # Extract policies
        for local_idx, state_idx in enumerate(non_terminal_indices):
            policies[state_idx] = self._get_policy(roots[local_idx])

        return policies

    def _select_to_leaf(
        self,
        root: BayesianNodeWithPower,
        state: np.ndarray
    ) -> Tuple[BayesianNodeWithPower, List[Tuple[BayesianNodeWithPower, int]], np.ndarray, bool]:
        node = root
        search_path = []
        current_state = state.copy()

        while node.expanded() and not self.game.is_terminal(current_state):
            action, child = self._select_child_thompson_ids(node)
            search_path.append((node, action))
            current_state = self.game.next_state(current_state, action)
            node = child

        is_terminal = self.game.is_terminal(current_state)
        return node, search_path, current_state, is_terminal

    def _create_children_from_policy(
        self,
        node: BayesianNodeWithPower,
        state: np.ndarray,
        policy: np.ndarray,
        value: float
    ) -> None:
        legal_actions = self.game.legal_actions(state)
        sigma_0 = self.config.sigma_0

        eps = 1e-8
        legal_probs = np.array([policy[a] for a in legal_actions])
        legal_probs = legal_probs / (legal_probs.sum() + eps)
        log_probs = np.log(legal_probs + eps)
        entropy = -np.sum(legal_probs * log_probs)
        scale = sigma_0 * (math.sqrt(6) / math.pi)

        for i, action in enumerate(legal_actions):
            mu = -value - scale * (log_probs[i] + entropy)
            sigma_sq = sigma_0 ** 2
            node.children[action] = BayesianNodeWithPower(
                prior=policy[action],
                mu=mu,
                sigma_sq=sigma_sq
            )

    def _batch_expand_leaves(
        self,
        nodes: List[BayesianNodeWithPower],
        states: List[np.ndarray],
        model: torch.nn.Module,
        device: torch.device
    ) -> List[float]:
        batch_size = len(nodes)
        if batch_size == 0:
            return []

        results = [None] * batch_size
        miss_indices = []
        miss_canonical_keys = []

        for i in range(batch_size):
            state = states[i]
            node = nodes[i]

            if self.tt:
                cached = self.tt.get(state)
                if cached is not None:
                    policy, value = cached
                    self._create_children_from_policy(node, state, policy, value)
                    node.aggregate_children(self.config.prune_threshold, weight_power=self.weight_power)
                    results[i] = value
                    continue

            if self.tt:
                canonical_key, _ = self.tt._canonical_key(state)
                miss_canonical_keys.append(canonical_key)
            miss_indices.append(i)

        if miss_indices:
            if self.tt and miss_canonical_keys:
                unique_keys = {}
                for j, (idx, key) in enumerate(zip(miss_indices, miss_canonical_keys)):
                    if key not in unique_keys:
                        unique_keys[key] = j
                unique_local_indices = list(unique_keys.values())
            else:
                unique_local_indices = list(range(len(miss_indices)))

            unique_states = [states[miss_indices[j]] for j in unique_local_indices]

            state_tensors = torch.stack([
                self.game.to_tensor(self.game.canonical_state(s)) for s in unique_states
            ]).to(device)

            action_masks = torch.stack([
                torch.from_numpy(self.game.legal_actions_mask(s))
                for s in unique_states
            ]).float().to(device)

            policies, values = model.predict(state_tensors, action_masks)
            policies = policies.cpu().numpy()
            values = values.cpu().numpy().flatten()

            if self.tt:
                for j, state in enumerate(unique_states):
                    self.tt.put(state, policies[j], values[j])

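            for local_idx, idx in enumerate(miss_indices):
                state = states[idx]
                node = nodes[idx]

                if self.tt:
                    policy, value = self.tt.get(state)
                else:
                    # Without a table nothing is deduplicated, so row local_idx
                    # of the batch output belongs to miss_indices[local_idx].
                    policy = policies[local_idx]
                    value = values[local_idx]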

                self._create_children_from_policy(node, state, policy, value)
                node.aggregate_children(self.config.prune_threshold, weight_power=self.weight_power)
                results[idx] = value

        return results

    def _batch_expand_roots(
        self,
        states: np.ndarray,
        model: torch.nn.Module,
        device: torch.device
    ) -> Tuple[List[BayesianNodeWithPower], np.ndarray]:
        batch_size = states.shape[0]
        roots = [BayesianNodeWithPower() for _ in range(batch_size)]

        cache_hits = [None] * batch_size
        miss_indices = []

        if self.tt:
            for i in range(batch_size):
                cached = self.tt.get(states[i])
                if cached is not None:
                    cache_hits[i] = cached
                else:
                    miss_indices.append(i)
        else:
            miss_indices = list(range(batch_size))

        if miss_indices:
            miss_states = np.stack([states[i] for i in miss_indices])

            state_tensors = torch.stack([
                self.game.to_tensor(self.game.canonical_state(s)) for s in miss_states
            ]).to(device)

            action_masks = torch.stack([
                torch.from_numpy(self.game.legal_actions_mask(s))
                for s in miss_states
            ]).float().to(device)

            policies, values = model.predict(state_tensors, action_masks)
            policies = policies.cpu().numpy()
            values = values.cpu().numpy().flatten()

            for j, idx in enumerate(miss_indices):
                if self.tt:
                    self.tt.put(states[idx], policies[j], values[j])
                cache_hits[idx] = (policies[j], values[j])

        all_values = np.zeros(batch_size, dtype=np.float32)
        for i, root in enumerate(roots):
            policy, value = cache_hits[i]
            self._create_children_from_policy(root, states[i], policy, value)
            root.aggregate_children(self.config.prune_threshold, weight_power=self.weight_power)
            all_values[i] = value

        return roots, all_values

    def _select_child_thompson_ids(
        self,
        node: BayesianNodeWithPower
    ) -> Tuple[int, BayesianNodeWithPower]:
        children = list(node.children.items())
        if len(children) == 1:
            return children[0]

        samples = [(action, child, -child.sample()) for action, child in children]
        samples.sort(key=lambda x: x[2], reverse=True)

        leader_action, leader_node, _ = samples[0]
        challenger_action, challenger_node, _ = samples[1]

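        # beta grows with the leader's precision: the more certain the leader,
        # the more often the challenger is tried. alpha pulls beta toward 0.5.
        alpha = self.config.ids_alpha
        precision_i = leader_node.precision()
        precision_j = challenger_node.precision()

        beta = (precision_i + alpha) / (precision_i + precision_j + 2 * alpha)

        # With probability beta explore the challenger, otherwise play the leader.
        if np.random.random() < beta:
            return challenger_action, challenger_node
        else:
            return leader_action, leader_node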

    def _backup(
        self,
        search_path: List[Tuple[BayesianNodeWithPower, int]],
        leaf_value: float
    ) -> None:
        for i, (parent, action) in enumerate(reversed(search_path)):
            child = parent.children[action]
            child.visits += 1

            if i == 0:
                if not child.expanded():
                    child.update(leaf_value, self.config.obs_var, self.config.min_variance)
                else:
                    if child.agg_mu is not None:
                        child.mu = child.agg_mu
                        child.sigma_sq = child.agg_sigma_sq

            parent.aggregate_children(self.config.prune_threshold, visited_only=True, weight_power=self.weight_power)

            if parent.agg_mu is not None:
                parent.mu = parent.agg_mu
                parent.sigma_sq = parent.agg_sigma_sq

    def _get_policy(self, root: BayesianNodeWithPower) -> np.ndarray:
        policy = np.zeros(self.game.config.action_size, dtype=np.float32)

        if not root.expanded():
            return policy

        actions = list(root.children.keys())
        children = [root.children[a] for a in actions]
        n = len(children)

        if n == 1:
            policy[actions[0]] = 1.0
            return policy

        mus = np.array([-c.mu for c in children])
        sigma_sqs = np.array([c.sigma_sq for c in children])

        sorted_idx = np.argsort(mus)[::-1]
        leader_idx = sorted_idx[0]
        challenger_idx = sorted_idx[1]

        scores = np.zeros(n)
        mu_L, sigma_sq_L = mus[leader_idx], sigma_sqs[leader_idx]
        mu_C, sigma_sq_C = mus[challenger_idx], sigma_sqs[challenger_idx]

        for i in range(n):
            if i == leader_idx:
                diff = mu_L - mu_C
                std = math.sqrt(sigma_sq_L + sigma_sq_C)
            else:
                diff = mus[i] - mu_L
                std = math.sqrt(sigma_sqs[i] + sigma_sq_L)

            if std > 1e-10:
                scores[i] = normal_cdf(diff / std)
            else:
                scores[i] = 1.0 if diff > 0 else 0.0

        total = scores.sum()
        if total < 1e-10:
            for action in actions:
                policy[action] = 1.0 / n
        else:
            for i, action in enumerate(actions):
                policy[action] = scores[i] / total

        return policy


print("BayesianMCTSWithPower class defined successfully.")

---
## Arena Test Function

In [None]:
from nanozero.game import get_game
from nanozero.model import AlphaZeroTransformer
from nanozero.mcts import BatchedMCTS, sample_action
from nanozero.config import get_model_config, MCTSConfig, BayesianMCTSConfig
from nanozero.common import load_checkpoint


def run_arena(game, model, puct_mcts, ttts_mcts, num_games, mcts_simulations):
    """
    Run arena between PUCT and TTTS, return results from TTTS perspective.
    
    Returns:
        Tuple of (wins, draws, losses) for TTTS
    """
    wins, draws, losses = 0, 0, 0

    for i in range(num_games):
        state = game.initial_state()
        ttts_turn = 1 if i % 2 == 0 else -1  # Alternate who goes first

        while not game.is_terminal(state):
            current = game.current_player(state)
            if current == ttts_turn:
                policy = ttts_mcts.search(
                    state[np.newaxis, ...], model,
                    num_simulations=mcts_simulations
                )[0]
            else:
                policy = puct_mcts.search(
                    state[np.newaxis, ...], model,
                    num_simulations=mcts_simulations, add_noise=False
                )[0]
            action = sample_action(policy, temperature=0)
            state = game.next_state(state, action)

        reward = game.terminal_reward(state)
        final_player = game.current_player(state)

        if final_player == ttts_turn:
            ttts_result = reward
        else:
            ttts_result = -reward

        if ttts_result > 0:
            wins += 1
        elif ttts_result < 0:
            losses += 1
        else:
            draws += 1

    return wins, draws, losses


print("Arena function defined.")

---
## Run Ablation: Test Different Weight Powers

Test powers from 1.0 to 2.0 in increments of 0.25.

In [None]:
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

game = get_game('connect4', use_rust=True)
print(f"Game backend: {game.backend}")

model_config = get_model_config(game.config, n_layer=4)
model = AlphaZeroTransformer(model_config).to(device)
load_checkpoint('checkpoints/connect4_final.pt', model)
model.eval()

# Setup PUCT baseline
puct_config = MCTSConfig()
puct_mcts = BatchedMCTS(game, puct_config)

print("Model loaded successfully.")

In [None]:
# Ablation configuration
weight_powers = [1.0, 1.25, 1.5, 1.75, 2.0]
num_games = 100  # Per configuration
mcts_simulations = 100
seed = 42

print(f"\nWeight Power Ablation")
print(f"=====================")
print(f"Games per config: {num_games}")
print(f"Simulations: {mcts_simulations}")
print(f"Config: sigma_0=1.0, obs_var=1.0")
print("=" * 50)

results = []

for power in weight_powers:
    # Create TTTS with this weight power
    ttts_config = BayesianMCTSConfig(
        sigma_0=1.0,
        obs_var=1.0,
    )
    ttts_mcts = BayesianMCTSWithPower(game, ttts_config, weight_power=power)
    
    # Clear caches
    puct_mcts.clear_cache()
    ttts_mcts.clear_cache()
    
    # Set seed for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    # Run arena
    wins, draws, losses = run_arena(
        game, model, puct_mcts, ttts_mcts,
        num_games=num_games, mcts_simulations=mcts_simulations
    )
    
    decisive = wins + losses
    win_rate = wins / decisive if decisive > 0 else 0.5
    
    results.append({
        'power': power,
        'wins': wins,
        'draws': draws,
        'losses': losses,
        'win_rate': win_rate,
    })
    
    print(f"p={power:.2f}: {wins}W/{draws}D/{losses}L  ({win_rate:.1%} decisive win rate)")

print("\nAblation complete!")

In [None]:
# Visualize results
import matplotlib.pyplot as plt

powers = [r['power'] for r in results]
win_rates = [r['win_rate'] for r in results]
wins = [r['wins'] for r in results]
draws = [r['draws'] for r in results]
losses = [r['losses'] for r in results]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Win rate plot
ax1.plot(powers, win_rates, 'bo-', linewidth=2, markersize=10)
ax1.axhline(y=0.5, color='gray', linestyle='--', label='50% baseline')
ax1.set_xlabel('Weight Power (p)', fontsize=12)
ax1.set_ylabel('TTTS Decisive Win Rate vs PUCT', fontsize=12)
ax1.set_title('Win Rate by Weight Power', fontsize=14)
ax1.set_ylim(0.0, 1.0)  # full range so no data point is clipped
ax1.legend()
ax1.grid(True, alpha=0.3)

# Stacked bar chart
x = np.arange(len(powers))
width = 0.6

ax2.bar(x, wins, width, label='Wins', color='green', alpha=0.8)
ax2.bar(x, draws, width, bottom=wins, label='Draws', color='gray', alpha=0.8)
ax2.bar(x, losses, width, bottom=[w+d for w, d in zip(wins, draws)], label='Losses', color='red', alpha=0.8)
ax2.set_xticks(x)
ax2.set_xticklabels([f'p={p}' for p in powers])
ax2.set_xlabel('Weight Power', fontsize=12)
ax2.set_ylabel('Games', fontsize=12)
ax2.set_title('Game Outcomes by Weight Power', fontsize=14)
ax2.legend()

plt.tight_layout()
plt.savefig('weight_power_ablation.png', dpi=150)
plt.show()

# Find best power
best_idx = np.argmax(win_rates)
best_power = powers[best_idx]
best_win_rate = win_rates[best_idx]

print(f"\nBest weight power: p={best_power:.2f} ({best_win_rate:.1%} decisive win rate)")

---
## Extended Ablation: Multiple Simulation Counts

Test the two endpoint powers and the midpoint (`p` = 1.0, 1.5, 2.0) across different simulation counts to see whether the optimal `p` varies with search budget.
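
Whether the best `p` shifts with the simulation budget can be read off by grouping the result records by simulation count. A minimal sketch, assuming records shaped like the dictionaries collected in `extended_results`; the helper name `best_power_by_sims` is hypothetical and the numbers in `example_records` are placeholders for illustration, not measured results.

```python
from collections import defaultdict

def best_power_by_sims(records):
    """Map each n_sims to the (power, win_rate) pair with the highest win rate."""
    by_sims = defaultdict(list)
    for r in records:
        by_sims[r['n_sims']].append((r['power'], r['win_rate']))
    # On ties, max() keeps the first pair encountered for that simulation count.
    return {n: max(pairs, key=lambda pw: pw[1]) for n, pairs in sorted(by_sims.items())}

# Placeholder values only; replace with extended_results after running the ablation.
example_records = [
    {'n_sims': 50, 'power': 1.0, 'win_rate': 0.45},
    {'n_sims': 50, 'power': 1.5, 'win_rate': 0.50},
    {'n_sims': 50, 'power': 2.0, 'win_rate': 0.48},
    {'n_sims': 200, 'power': 1.0, 'win_rate': 0.42},
    {'n_sims': 200, 'power': 1.5, 'win_rate': 0.47},
    {'n_sims': 200, 'power': 2.0, 'win_rate': 0.52},
]

for n_sims, (power, rate) in best_power_by_sims(example_records).items():
    print(f"{n_sims} sims: best p={power} ({rate:.1%})")
```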

In [None]:
# Extended ablation with different sim counts
weight_powers = [1.0, 1.5, 2.0]  # Just test the extremes and middle
sim_counts = [50, 100, 200]
num_games = 60  # Fewer games per cell for faster testing

print(f"\nExtended Ablation: Weight Power x Simulation Count")
print(f"="*60)

extended_results = []

for n_sims in sim_counts:
    print(f"\n--- {n_sims} simulations ---")
    for power in weight_powers:
        ttts_config = BayesianMCTSConfig(
            sigma_0=1.0,
            obs_var=1.0,
        )
        ttts_mcts = BayesianMCTSWithPower(game, ttts_config, weight_power=power)
        
        puct_mcts.clear_cache()
        ttts_mcts.clear_cache()
        
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        wins, draws, losses = run_arena(
            game, model, puct_mcts, ttts_mcts,
            num_games=num_games, mcts_simulations=n_sims
        )
        
        decisive = wins + losses
        win_rate = wins / decisive if decisive > 0 else 0.5
        
        extended_results.append({
            'n_sims': n_sims,
            'power': power,
            'wins': wins,
            'draws': draws,
            'losses': losses,
            'win_rate': win_rate,
        })
        
        print(f"  p={power:.1f}: {wins}W/{draws}D/{losses}L ({win_rate:.1%})")

In [None]:
# Visualize extended results
fig, ax = plt.subplots(figsize=(10, 6))

for power in weight_powers:
    power_results = [r for r in extended_results if r['power'] == power]
    sims = [r['n_sims'] for r in power_results]
    rates = [r['win_rate'] for r in power_results]
    ax.plot(sims, rates, 'o-', linewidth=2, markersize=8, label=f'p={power}')

ax.axhline(y=0.5, color='gray', linestyle='--', label='50% baseline')
ax.set_xlabel('MCTS Simulations', fontsize=12)
ax.set_ylabel('TTTS Decisive Win Rate vs PUCT', fontsize=12)
ax.set_title('Weight Power Effect Across Simulation Counts', fontsize=14)
ax.set_ylim(0.0, 1.0)  # full range so no data point is clipped
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('weight_power_vs_sims.png', dpi=150)
plt.show()

---
## Summary

### Theory

The variance aggregation formula:
```
agg_sigma_sq = sum(w^p * (sigma_sq + disagreement))
```

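- **p=2 (current)**: Matches the variance of a weighted sum of independent child values, where child variances add with squared weights
- **p=1**: Matches perfectly correlated child values when their variances are equal; with unequal variances it is an upper bound on the fully correlated case
- **1 < p < 2**: A heuristic interpolation for partial correlation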
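
There is a second reading of the p=1 case, given here as a standard identity rather than something this notebook derives. Suppose the parent's value $V$ is modeled as a mixture that equals child $i$'s value (mean $\mu_i$, variance $\sigma_i^2$) with probability $w_i$. The law of total variance then gives exactly the p=1 formula, disagreement term included:

```latex
\bar{\mu} = \sum_i w_i \mu_i,
\qquad
\operatorname{Var}(V)
  = \underbrace{\sum_i w_i \sigma_i^2}_{\text{expected child variance}}
  + \underbrace{\sum_i w_i (\mu_i - \bar{\mu})^2}_{\text{variance of child means}}
  = \sum_i w_i \left( \sigma_i^2 + (\mu_i - \bar{\mu})^2 \right)
```

Under this reading the disagreement term belongs to the mixture view, while the squared weights belong to the weighted-sum view, so the `p=2` formula combines elements of both. That is one reason to treat `p` as an empirical knob rather than a quantity fixed by theory.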

### Results Interpretation

1. **If p=2 is best**: Consistent with children's values being approximately independent (search explores different regions)
2. **If p=1 is best**: Consistent with children's values being highly correlated (NN errors are systematic)
3. **If intermediate p is best**: Consistent with moderate correlation
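
Before reading any of these cases into the results, it helps to check whether the gap between two powers exceeds sampling noise. The sketch below computes a Wilson score interval for a decisive win rate. The helper name `wilson_interval` and the example counts are illustrative assumptions, not output from this notebook.

```python
import math

def wilson_interval(wins, n, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 is about 95%)."""
    if n == 0:
        return 0.0, 1.0  # no decisive games: the rate is unconstrained
    p_hat = wins / n
    denom = 1.0 + z * z / n
    center = (p_hat + z * z / (2 * n)) / denom
    half = z * math.sqrt(p_hat * (1 - p_hat) / n + z * z / (4 * n * n)) / denom
    return center - half, center + half

# Hypothetical counts for illustration only: 50 wins in 100 decisive games.
low, high = wilson_interval(50, 100)
print(f"95% interval: [{low:.3f}, {high:.3f}]")
```

With 100 decisive games and an observed rate of 50%, the interval spans roughly ten percentage points on either side. Differences of a few points between powers are therefore within noise at this sample size, and the 60-game cells of the extended ablation are noisier still.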

### Next Steps

If an intermediate value of `p` performs best:
1. Add `weight_power` parameter to `BayesianMCTSConfig`
2. Update `aggregate_children` to use the parameter
3. Consider making `p` adaptive based on tree depth or visit count
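
One possible shape for the adaptive option is a schedule that maps a node's visit count to a power. Everything in this sketch is hypothetical: the class `PowerSchedule`, its fields, and the direction of the schedule (starting near `p=1` and moving toward `p=2` as visits accumulate) are assumptions, not part of `BayesianMCTSConfig`. The direction follows the interpretation above, where freshly expanded children share systematic NN error and deeper search makes their estimates more independent.

```python
from dataclasses import dataclass

@dataclass
class PowerSchedule:
    """Hypothetical schedule: move p from p_start toward p_end as visits grow."""
    p_start: float = 1.0
    p_end: float = 2.0
    half_visits: float = 16.0  # visits at which p is halfway between the two

    def __post_init__(self):
        if not (1.0 <= self.p_start <= 2.0 and 1.0 <= self.p_end <= 2.0):
            raise ValueError("powers must lie in [1, 2]")
        if self.half_visits <= 0:
            raise ValueError("half_visits must be positive")

    def power(self, visits: int) -> float:
        # Saturating fraction in [0, 1): 0 at no visits, 0.5 at half_visits.
        frac = visits / (visits + self.half_visits)
        return self.p_start + (self.p_end - self.p_start) * frac

schedule = PowerSchedule()
for visits in (0, 16, 160):
    print(f"visits={visits}: p={schedule.power(visits):.3f}")
```

A node could then call `aggregate_children(..., weight_power=schedule.power(self.visits))`. Whether such a schedule helps is an open question that the fixed-`p` ablation above would need to motivate first.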