# Notebook 06: Bifurcation Detection

**Purpose**: Detect moves where game trajectories sharply diverge in feature space ("strategic phase transitions").

**Workflow**:
1. Load SAE features and game metadata
2. Compute trajectory divergence (cosine distance between direction vectors)
3. Detect bifurcations (moves whose divergence score exceeds the 95th-percentile threshold)
4. Statistical validation (permutation test for outcome correlation)
5. Baselines (position-shuffled null)
6. Visualization and analysis


## 1. Setup

In [None]:
# ============================================================================
# COLAB SETUP - Run this cell first!
# ============================================================================
import sys
from pathlib import Path

# Detect if running in Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running in Google Colab")

    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Set up paths
    DRIVE_ROOT = Path('/content/drive/MyDrive/chaos')
    DRIVE_ROOT.mkdir(parents=True, exist_ok=True)

    print(f"Drive mounted. Project root: {DRIVE_ROOT}")

    # Install dependencies
    print("Installing dependencies...")
    !pip install -q torch>=2.0.0 h5py>=3.8.0
    !pip install -q matplotlib>=3.7.0 tqdm>=4.65.0
    !pip install -q scikit-learn>=1.2.0 scipy>=1.10.0
    print("Dependencies installed!")

    # Unzip src.zip
    !unzip -n -q src.zip -d /content/
else:
    print("Running locally")
    DRIVE_ROOT = None

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import time
import h5py
from contextlib import contextmanager

# ============================================================================
# TIMING INFRASTRUCTURE
# ============================================================================
@contextmanager
def timed_section(name):
    """Time a block of code and report how long it took.

    Prints a ``[START]`` line on entry and a ``[DONE]`` line with the elapsed
    time (seconds and minutes) on normal exit. If the body raises, a
    ``[FAILED]`` line with the elapsed time is printed instead and the
    exception propagates unchanged.

    Args:
        name: Label shown in the progress lines.

    Yields:
        None.
    """
    # perf_counter is monotonic, so the measured duration is not distorted
    # when the system clock is adjusted while the section is running.
    start = time.perf_counter()
    print(f"[START] {name}...")
    try:
        yield
    except BaseException:
        # Also covers KeyboardInterrupt, so an interrupted long-running cell
        # still reports how far it got before being stopped.
        elapsed = time.perf_counter() - start
        print(f"[FAILED] {name}: {elapsed:.1f}s ({elapsed/60:.1f} min)")
        raise
    elapsed = time.perf_counter() - start
    print(f"[DONE] {name}: {elapsed:.1f}s ({elapsed/60:.1f} min)")

# Make the project's `src` package importable.
# Every root is inserted at the front of sys.path, so the one inserted last
# (the current working directory) ends up with the highest priority.
_import_roots = ['/content'] if IN_COLAB else []
_import_roots.append(str(Path('.').absolute()))
for _root in _import_roots:
    sys.path.insert(0, _root)

# Import bifurcation analysis module
from src.analysis.bifurcation_analysis import (
    BifurcationResult,
    CorrelationResult,
    compute_trajectory_divergence,
    detect_bifurcations,
    correlate_with_outcomes,
    BifurcationAnalyzer,
)

print("Modules loaded successfully")

# Notebook-wide settings; later cells read everything from this dict.
CONFIG = dict(
    block_idx=35,                    # Strategic layer

    # Bifurcation detection settings
    divergence_window=5,             # Moves to average for direction vectors
    divergence_metric='cosine',      # Cosine recommended for SAE features
    threshold_percentile=95,         # Per chaos theory literature
    min_distance=5,                  # Minimum moves between bifurcations

    # Statistical validation
    n_permutations=1000,             # For permutation test

    # Paths (local defaults; replaced below when running in Colab)
    features_dir='outputs/data/sae_features',
    metadata_dir='outputs/data/metadata',
    output_dir='outputs/analysis/bifurcation_detection',
)

# ============================================================================
# COLAB: Configure paths for Drive storage
# ============================================================================
if IN_COLAB:
    # Point all three directories at the mounted Drive project root.
    CONFIG.update(
        features_dir=str(DRIVE_ROOT / 'data' / 'sae_features'),
        metadata_dir=str(DRIVE_ROOT / 'data' / 'metadata'),
        output_dir=str(DRIVE_ROOT / 'analysis' / 'bifurcation_detection'),
    )

    print(f"Features dir: {CONFIG['features_dir']}")
    print(f"Output dir: {CONFIG['output_dir']}")

# Ensure the output directory and its figures/ subdirectory both exist
_output_root = Path(CONFIG['output_dir'])
for _directory in (_output_root, _output_root / 'figures'):
    _directory.mkdir(parents=True, exist_ok=True)

## 2. Load Data

In [None]:
# Load the SAE feature matrix for the configured block from HDF5.
# On success `features` is a float32 array and `n_samples` / `n_features`
# hold its two dimensions; on failure `features` stays None so that the
# later cells (all guarded by `features is not None`) are skipped.
features_path = Path(CONFIG['features_dir']) / f"block{CONFIG['block_idx']}_features.h5"

features = None
n_samples = 0
n_features = 0

if features_path.exists():
    with h5py.File(features_path, 'r') as f:
        block_key = f'block{CONFIG["block_idx"]}'
        if block_key not in f:
            # Fall back to the first key in the file, if there is one.
            available_keys = list(f.keys())
            if available_keys:
                block_key = available_keys[0]
                print(f"Note: Using key '{block_key}'")
            else:
                block_key = None

        if block_key is not None:
            dset = f[block_key]
            n_samples, n_features = dset.shape
            # copy=False skips the redundant second full-size copy when the
            # stored data is already float32.
            features = dset[()].astype(np.float32, copy=False)

    if features is not None:
        print(f"Loaded SAE features: {features.shape}")
        print(f"Memory: {features.nbytes / 1e9:.2f} GB")
    else:
        print(f"ERROR: No datasets found in {features_path}")
        print("Please run notebooks 01-03 first.")
else:
    print(f"ERROR: Features not found at {features_path}")
    print("Please run notebooks 01-03 first.")

In [None]:
# Load per-position game metadata (which game each feature row belongs to and,
# when available, its move number), then align it with `features`.
game_ids = None
move_numbers = None

# Try loading from activations.h5 (same pattern as NB05)
activations_h5 = Path(CONFIG['features_dir']).parent / 'activations.h5'
if not activations_h5.exists():
    activations_h5 = Path(CONFIG['features_dir']).parent.parent / 'data' / 'activations.h5'

if activations_h5.exists():
    with h5py.File(activations_h5, 'r') as f:
        if 'position_indices' in f:
            # Column 0 is used as the game id, column 1 as the move number.
            pos_indices = f['position_indices'][()]
            game_ids = pos_indices[:, 0]
            move_numbers = pos_indices[:, 1]
            print(f"Loaded metadata from h5: {len(game_ids):,} positions")
        elif 'game_ids' in f:
            # Only game ids are stored; move_numbers stays None.
            game_ids = f['game_ids'][()]
            print(f"Loaded game_ids from h5: {len(game_ids):,}")

if game_ids is None and n_samples > 0:
    # Synthetic fallback: cut the samples into consecutive 100-move games.
    # Ids are derived from the sample index itself so that EVERY sample gets
    # one, including a trailing partial game when n_samples is not a multiple
    # of 100 (otherwise the alignment step below would drop those rows).
    print("Creating synthetic game_ids...")
    moves_per_game = 100
    sample_index = np.arange(n_samples)
    game_ids = sample_index // moves_per_game
    move_numbers = sample_index % moves_per_game
    n_games = int(game_ids[-1]) + 1
    print(f"Created synthetic metadata: {n_games} games")

# Verify alignment
if features is not None and game_ids is not None:
    if len(features) != len(game_ids):
        # Best effort: keep the common prefix, but report both lengths so a
        # genuine misalignment between the two files does not go unnoticed.
        print(f"WARNING: Length mismatch: {len(features):,} feature rows "
              f"vs {len(game_ids):,} game ids")
        min_len = min(len(features), len(game_ids))
        features = features[:min_len]
        game_ids = game_ids[:min_len]
        if move_numbers is not None:
            move_numbers = move_numbers[:min_len]
        print(f"Truncated to {min_len:,} samples")

    n_games = len(np.unique(game_ids))
    print(f"\nData aligned: {len(features):,} samples, {n_games:,} games")

In [None]:
# Game outcomes are optional: `outcomes` stays None when the file is missing,
# which makes the outcome-correlation cell skip itself.
outcomes_path = Path(CONFIG['metadata_dir']) / 'game_outcomes.npy'
outcomes = np.load(outcomes_path) if outcomes_path.exists() else None

if outcomes is None:
    print("No game outcomes found. Skipping outcome correlation analysis.")
    print(f"Expected path: {outcomes_path}")
else:
    print(f"Loaded game outcomes: {len(outcomes)} games")
    print(f"Win rate: {outcomes.mean():.2%}")

## 3. Bifurcation Detection

In [None]:
# Run bifurcation detection over all games with the settings from CONFIG.
# Leaves `analyzer` and `bifurcation_result` available for later cells.
if features is not None:
    with timed_section("Bifurcation detection"):
        _rule = "=" * 60
        _banner = (
            _rule,
            "BIFURCATION DETECTION",
            f"Window: {CONFIG['divergence_window']} moves",
            f"Metric: {CONFIG['divergence_metric']}",
            f"Threshold: {CONFIG['threshold_percentile']}th percentile",
            _rule,
        )
        for _line in _banner:
            print(_line)

        # Initialize analyzer
        analyzer = BifurcationAnalyzer(
            window=CONFIG['divergence_window'],
            threshold_percentile=CONFIG['threshold_percentile'],
            min_distance=CONFIG['min_distance'],
            metric=CONFIG['divergence_metric'],
        )

        # Run analysis
        bifurcation_result = analyzer.analyze_all_games(features=features, game_ids=game_ids)

        # Report the headline numbers
        _n_detected = len(bifurcation_result.move_indices)
        print(f"\nResults:")
        print(f"  Total bifurcations: {_n_detected:,}")
        print(f"  Bifurcation rate: {_n_detected/len(features)*100:.2f}%")
        print(f"  Threshold: {bifurcation_result.threshold:.4f}")

In [None]:
# Summary statistics of the divergence scores, ignoring NaN entries.
if features is not None:
    divergences = bifurcation_result.divergence_scores
    valid_div = divergences[~np.isnan(divergences)]

    # (label, statistic) pairs; each statistic is evaluated only when its line
    # is printed, one at a time and in this order.
    _summary_stats = (
        ("Mean", np.mean),
        ("Std", np.std),
        ("Median", np.median),
        ("95th percentile", lambda scores: np.percentile(scores, 95)),
        ("99th percentile", lambda scores: np.percentile(scores, 99)),
        ("Max", np.max),
    )

    print(f"\nDivergence Statistics:")
    for _label, _stat in _summary_stats:
        print(f"  {_label}: {_stat(valid_div):.4f}")

## 4. Statistical Validation

In [None]:
# Permutation test for an association between bifurcations and game outcomes.
# `correlation_result` stays None when the test cannot be run.
correlation_result = None

if features is None or outcomes is None:
    print("Skipping outcome correlation (no outcomes data)")
else:
    with timed_section("Outcome correlation test"):
        _rule = "=" * 60
        print(_rule)
        print("BIFURCATION-OUTCOME CORRELATION")
        print(f"Permutation test: {CONFIG['n_permutations']} permutations")
        print(_rule)

        # Verify alignment between the number of games and the outcome labels.
        # NOTE(review): slicing can only shorten `outcomes`; if it has FEWER
        # entries than there are games, the two stay mismatched. Confirm how
        # correlate_with_outcomes handles that case.
        n_games_data = len(bifurcation_result.game_boundaries)
        outcomes_aligned = outcomes
        if n_games_data != len(outcomes):
            print(f"WARNING: Game count mismatch: {n_games_data} vs {len(outcomes)}")
            outcomes_aligned = outcomes[:n_games_data]

        correlation_result = correlate_with_outcomes(
            bifurcation_result,
            outcomes_aligned,
            n_permutations=CONFIG['n_permutations'],
        )

        print(f"\nResults:")
        print(f"  Correlation: {correlation_result.correlation:.4f}")
        print(f"  P-value: {correlation_result.p_value:.4f}")
        print(f"  Null mean: {correlation_result.null_mean:.4f}")
        print(f"  Null std: {correlation_result.null_std:.4f}")
        print(f"  Significant (p < 0.05): {correlation_result.significant}")

## 5. Per-Game Analysis

In [None]:
# Build one row of summary statistics per game and collect them in
# `game_stats_df`. Games whose divergence scores are all NaN (or that have no
# scores at all) produce no row.
if features is not None:
    game_stats = []
    _bif_positions = bifurcation_result.move_indices
    _all_scores = bifurcation_result.divergence_scores

    for game_idx, (start, end) in enumerate(bifurcation_result.game_boundaries):
        # Non-NaN divergence scores inside this game's [start, end) slice
        game_div = _all_scores[start:end]
        _finite_div = game_div[~np.isnan(game_div)]
        if _finite_div.size == 0:
            continue

        # Number of detected bifurcations whose index falls inside this game
        n_bif = ((_bif_positions >= start) & (_bif_positions < end)).sum()
        _n_moves = end - start

        game_stats.append({
            'game_id': game_idx,  # position in game_boundaries
            'n_moves': _n_moves,
            'n_bifurcations': n_bif,
            'bifurcation_rate': n_bif / _n_moves,
            'mean_divergence': np.mean(_finite_div),
            'max_divergence': np.max(_finite_div),
        })

    game_stats_df = pd.DataFrame(game_stats)

    print("\nPer-Game Statistics:")
    print(game_stats_df.describe())

## 6. Visualization

In [None]:
# Two-panel figure: distribution of divergence scores with the detection
# threshold marked (left) and number of bifurcations per game (right).
# Saved to <output_dir>/figures/divergence_histogram.png.
if features is not None:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    valid_div = bifurcation_result.divergence_scores[~np.isnan(bifurcation_result.divergence_scores)]

    # Histogram
    # The legend and axis text are taken from CONFIG (the same values the
    # analyzer was constructed with) rather than hard-coded, so the figure
    # stays correct when the percentile or the metric is changed.
    ax1 = axes[0]
    ax1.hist(valid_div, bins=100, edgecolor='black', alpha=0.7, density=True)
    ax1.axvline(x=bifurcation_result.threshold, color='r', linestyle='--',
                linewidth=2,
                label=f"{CONFIG['threshold_percentile']}th percentile: "
                      f"{bifurcation_result.threshold:.3f}")
    ax1.set_xlabel(f"Divergence Score ({CONFIG['divergence_metric']} distance)", fontsize=12)
    ax1.set_ylabel('Density', fontsize=12)
    ax1.set_title('Distribution of Trajectory Divergence', fontsize=14)
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Per-game bifurcation count
    ax2 = axes[1]
    mean_bifurcations = game_stats_df['n_bifurcations'].mean()
    ax2.hist(game_stats_df['n_bifurcations'], bins=30, edgecolor='black', alpha=0.7)
    ax2.axvline(x=mean_bifurcations, color='r', linestyle='--',
                linewidth=2, label=f"Mean: {mean_bifurcations:.1f}")
    ax2.set_xlabel('Number of Bifurcations per Game', fontsize=12)
    ax2.set_ylabel('Frequency', fontsize=12)
    ax2.set_title('Bifurcations per Game Distribution', fontsize=14)
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(Path(CONFIG['output_dir']) / 'figures' / 'divergence_histogram.png',
                dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Plot the divergence curve of one example game with its bifurcations marked.
# Saved to <output_dir>/figures/example_game_trajectory.png.
if features is not None:
    # Use the first game with more than two bifurcations; fall back to game 0
    # when no game qualifies.
    _busy_games = game_stats_df[game_stats_df['n_bifurcations'] > 2]
    game_with_bifs = int(_busy_games.iloc[0]['game_id']) if len(_busy_games) > 0 else 0

    start, end = bifurcation_result.game_boundaries[game_with_bifs]
    game_div = bifurcation_result.divergence_scores[start:end]

    # Bifurcation indices inside this game, shifted to be relative to its start
    _all_bifs = bifurcation_result.move_indices
    game_mask = (_all_bifs >= start) & (_all_bifs < end)
    game_bifs = _all_bifs[game_mask] - start

    fig, ax = plt.subplots(figsize=(12, 5))

    moves = np.arange(len(game_div))
    ax.plot(moves, game_div, 'b-', linewidth=1, alpha=0.7, label='Divergence')
    ax.axhline(y=bifurcation_result.threshold, color='r', linestyle='--',
               linewidth=1, label='Threshold')

    # Mark each bifurcation that has a non-NaN score
    for bif in game_bifs:
        _score = game_div[bif]
        if np.isnan(_score):
            continue
        ax.scatter([bif], [_score], color='red', s=100, zorder=5)

    ax.set_xlabel('Move Number', fontsize=12)
    ax.set_ylabel('Divergence Score', fontsize=12)
    ax.set_title(f'Game {game_with_bifs}: Trajectory Divergence with Bifurcations', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    _figure_path = Path(CONFIG['output_dir']) / 'figures' / 'example_game_trajectory.png'
    plt.savefig(_figure_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Histogram of where bifurcations fall within their game, as a fraction of the
# game's length. Saved to <output_dir>/figures/bifurcation_timing.png.
if features is not None:
    # Compute relative timing (0-1 scale within each game)
    relative_timings = []

    for bif_idx in bifurcation_result.move_indices:
        # First game whose [start, end) range contains this index, if any
        _owner = next(
            ((lo, hi) for lo, hi in bifurcation_result.game_boundaries if lo <= bif_idx < hi),
            None,
        )
        if _owner is None:
            continue
        _lo, _hi = _owner
        relative_timings.append((bif_idx - _lo) / (_hi - _lo))

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(relative_timings, bins=50, edgecolor='black', alpha=0.7)
    ax.set_xlabel('Relative Position in Game (0=start, 1=end)', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('When Do Bifurcations Occur?', fontsize=14)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    _figure_path = Path(CONFIG['output_dir']) / 'figures' / 'bifurcation_timing.png'
    plt.savefig(_figure_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\nBifurcation timing (relative to game length):")
    for _label, _stat in (("Mean", np.mean), ("Median", np.median), ("Std", np.std)):
        print(f"  {_label}: {_stat(relative_timings):.2f}")

## 7. Save Results

In [None]:
# Persist all results of this notebook under CONFIG['output_dir']:
#   divergence_all.npy        - divergence scores
#   bifurcation_indices.json  - detected indices, threshold and rate
#   correlation_result.json   - permutation-test result (only if it was run)
#   per_game_stats.csv        - per-game statistics table
#   analysis_summary.json     - config plus headline numbers
if features is not None:
    output_path = Path(CONFIG['output_dir'])

    # Shared by the index file and the summary; compute once.
    n_bifurcations = len(bifurcation_result.move_indices)
    bifurcation_rate = n_bifurcations / len(features)

    # Save divergence scores
    np.save(output_path / 'divergence_all.npy', bifurcation_result.divergence_scores)

    # Save bifurcation indices
    bifurcation_data = {
        'move_indices': bifurcation_result.move_indices.tolist(),
        'threshold': float(bifurcation_result.threshold),
        'n_bifurcations': n_bifurcations,
        'n_samples': len(features),
        'bifurcation_rate': bifurcation_rate,
    }
    with open(output_path / 'bifurcation_indices.json', 'w') as f:
        json.dump(bifurcation_data, f, indent=2)

    # Save correlation result (if available).
    # corr_data is always defined so the summary below never depends on a
    # leftover value from an earlier run of this cell.
    corr_data = None
    if correlation_result is not None:
        corr_data = {
            'correlation': float(correlation_result.correlation),
            'p_value': float(correlation_result.p_value),
            'null_mean': float(correlation_result.null_mean),
            'null_std': float(correlation_result.null_std),
            # Coerced like the floats above: json cannot serialize NumPy
            # scalar types, which these fields may hold.
            'n_permutations': int(correlation_result.n_permutations),
            'significant': bool(correlation_result.significant),
        }
        with open(output_path / 'correlation_result.json', 'w') as f:
            json.dump(corr_data, f, indent=2)

    # Save per-game stats
    game_stats_df.to_csv(output_path / 'per_game_stats.csv', index=False)

    # Save summary
    summary = {
        'config': CONFIG,
        'n_samples': len(features),
        'n_games': len(bifurcation_result.game_boundaries),
        'n_bifurcations': n_bifurcations,
        'bifurcation_rate': bifurcation_rate,
        'threshold': float(bifurcation_result.threshold),
        'divergence_stats': {
            'mean': float(np.nanmean(bifurcation_result.divergence_scores)),
            'std': float(np.nanstd(bifurcation_result.divergence_scores)),
            'max': float(np.nanmax(bifurcation_result.divergence_scores)),
        },
        'correlation': corr_data,
    }
    with open(output_path / 'analysis_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\nResults saved to {output_path}")
    print(f"  - divergence_all.npy")
    print(f"  - bifurcation_indices.json")
    print(f"  - per_game_stats.csv")
    print(f"  - analysis_summary.json")
    if corr_data is not None:
        print(f"  - correlation_result.json")