# 03b: Basin Analysis

**Goal**: Partition states by their ultimate count capture outcomes.

**Key Questions**:
1. Can we trace states to terminal via principal variation?
2. What is the distribution of count capture outcomes?
3. How much V variance exists within capture-outcome basins?

A "basin" is a set of states that lead to the same count capture outcome under optimal play.

**Reference**: docs/analysis-draft.md Section 6

In [None]:
# === CONFIGURATION ===
# Defaults point at the author's machine; set FORGE_DATA_DIR / FORGE_PROJECT_ROOT
# in the environment to run elsewhere without editing the notebook.
import os

DATA_DIR = os.environ.get("FORGE_DATA_DIR", "/mnt/d/shards-standard/")
PROJECT_ROOT = os.environ.get("FORGE_PROJECT_ROOT", "/home/jason/v2/mk5-tailwind")

# === Setup imports ===
import sys
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm  # widget bar in Jupyter, plain-text bar elsewhere

from forge.analysis.utils import loading, features, viz, navigation
from forge.oracle import schema, tables

viz.setup_notebook_style()
print("Ready")

## 1. Load Single Seed for Deep Analysis

Basin analysis is computationally intensive, so we start with a single shard (one seed and declaration).

In [None]:
# Load the first shard returned by loading.find_shard_files.
# Fail with a clear message (instead of a bare IndexError) when DATA_DIR is
# wrong or empty.
shard_files = loading.find_shard_files(DATA_DIR)
if len(shard_files) == 0:
    raise FileNotFoundError(f"No shard files found in {DATA_DIR!r}; check DATA_DIR")
df, seed, decl_id = schema.load_file(shard_files[0])

print(f"Seed: {seed}")
print(f"Declaration: {decl_id} ({schema.DECL_NAMES[decl_id]})")
print(f"Total states: {len(df):,}")

In [None]:
# Build the state -> index lookup together with the V and Q arrays consumed by
# the navigation helpers; keep the raw state column for sampling later.
state_to_idx, V, Q = navigation.build_state_lookup_fast(df)
states = df['state'].values

n_indexed = len(state_to_idx)
print(f"Built lookup for {n_indexed:,} states")

## 2. Principal Variation Tracing

Trace a single sample state to terminal to verify that the navigation helpers work.

In [None]:
# Take the state halfway through the shard and follow its principal variation
sample_state = states[len(states) // 2]

pv = navigation.trace_principal_variation(
    sample_state, seed, decl_id, state_to_idx, V, Q
)

# Each PV entry is unpacked as (state, V, move)
start_v = pv[0][1]
terminal_v = pv[-1][1]

print("Principal variation from sample state:")
print(f"  Length: {len(pv)} moves")
print(f"  Start V: {start_v}")
print(f"  Terminal V: {terminal_v}")
print("\nFirst 5 moves:")
for step, (_, step_v, step_move) in enumerate(pv[:5]):
    print(f"  {step}: V={step_v:+3d}, move={step_move}")

In [None]:
# Follow the same PV and record which team ends up with each count domino
captures = navigation.track_count_captures(
    sample_state, seed, decl_id, state_to_idx, V, Q
)

print("Count captures along PV:")
for domino_id, team in captures.items():
    pips = schema.domino_pips(domino_id)
    label = f"{pips[0]}-{pips[1]}"
    points = tables.DOMINO_COUNT_POINTS[domino_id]
    print(f"  {label} ({points} pts): Team {team}")

# Collapse the per-domino captures into one value per team
t0, t1 = navigation.count_capture_signature(captures)
print(f"\nCapture signature: Team 0 = {t0}, Team 1 = {t1}")

## 3. Basin Computation

Trace a random sample of up to 50,000 states to terminal and record each one's capture outcome (tracing every state is too expensive).

In [None]:
# Sample states for basin analysis (full computation is expensive).
# A fixed seed makes the sample, and every statistic derived from it,
# reproducible under Restart & Run All.
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)

N_SAMPLE = min(50000, len(states))
sample_indices = rng.choice(len(states), N_SAMPLE, replace=False)
sample_states = states[sample_indices]

print(f"Computing basins for {N_SAMPLE:,} states...")

In [None]:
# Compute capture outcomes for sampled states.
# features.depth operates on an array of states, so compute all depths in one
# call instead of wrapping each state in a 1-element array inside the loop.
sample_depths = features.depth(sample_states)
assert len(sample_depths) == len(sample_states), "expected one depth per sampled state"

basin_data = []

for i, (state, state_depth) in enumerate(
    tqdm(zip(sample_states, sample_depths), total=len(sample_states), desc="Tracing PVs")
):
    captures = navigation.track_count_captures(
        state, seed, decl_id, state_to_idx, V, Q
    )
    t0_capture, t1_capture = navigation.count_capture_signature(captures)

    idx = sample_indices[i]  # positional row of this state in df; V is indexed the same way
    basin_data.append({
        'state': state,
        'V': V[idx],
        'depth': state_depth,
        't0_capture': t0_capture,
        't1_capture': t1_capture,
        'capture_diff': t0_capture - t1_capture,
    })

basin_df = pd.DataFrame(basin_data)
print(f"Basin computation complete: {len(basin_df):,} states")

## 4. Basin Distribution

In [None]:
# A basin is identified by its (team 0 capture, team 1 capture) pair
basin_df['basin'] = [
    (t0_val, t1_val)
    for t0_val, t1_val in zip(basin_df['t0_capture'], basin_df['t1_capture'])
]

# Number of sampled states falling in each basin, largest first
basin_counts = basin_df['basin'].value_counts()
n_basins = len(basin_counts)
print(f"Number of unique basins: {n_basins}")
print("\nTop 10 basins by state count:")
print(basin_counts.head(10))

In [None]:
# Visualize basin distribution
from pathlib import Path

# savefig raises FileNotFoundError if the target directory is missing,
# so create it (no-op when it already exists).
Path('../../results/figures').mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Capture difference distribution
axes[0].hist(basin_df['capture_diff'], bins=15, color='steelblue', alpha=0.7, edgecolor='black')
axes[0].axvline(x=0, color='red', linestyle='--', label='Balanced')
axes[0].set_xlabel('Capture Difference (Team 0 - Team 1)')
axes[0].set_ylabel('Number of States')
axes[0].set_title('Distribution of Count Capture Outcomes')
axes[0].legend()

# Basin size distribution (log y-axis: a few basins dominate)
axes[1].hist(basin_counts.values, bins=30, color='coral', alpha=0.7, edgecolor='black')
axes[1].set_xlabel('States per Basin')
axes[1].set_ylabel('Number of Basins')
axes[1].set_title('Basin Size Distribution')
axes[1].set_yscale('log')

plt.tight_layout()
plt.savefig('../../results/figures/03b_basin_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Within-Basin V Variance

Key question: How much V variance exists within each capture-outcome basin?

In [None]:
# Summarise V inside each basin, then drop basins with fewer than 10 sampled states
per_basin_v = basin_df.groupby('basin')['V'].agg(['mean', 'std', 'min', 'max', 'count'])
basin_v_stats = per_basin_v.loc[per_basin_v['count'] >= 10]

print("Basin V statistics (basins with >= 10 states):")
print(basin_v_stats.describe())

In [None]:
# Average within-basin variance, weighted by basin size.
# Only the basins kept in `basin_v_stats` (after its size filter) contribute to
# the within-basin term, so the variance-explained ratio is taken against the
# V variance of those same states. Dividing by the variance of *all* sampled
# states would let the states in excluded small basins inflate the fraction.
if basin_v_stats.empty:
    raise ValueError("No basins passed the size filter; increase N_SAMPLE or lower the threshold")

basin_n = basin_v_stats['count']
basin_var = basin_v_stats['std'].fillna(0.0) ** 2  # std is NaN only for single-state basins
weights = basin_n
avg_within_var = np.average(basin_var, weights=weights)
avg_within_std = np.sqrt(avg_within_var)

# Total variance over every sampled state (reported as "Total V std")
total_var = basin_df['V'].var()
total_std = basin_df['V'].std()

# Sample variance (ddof=1) of the retained states, rebuilt from per-basin moments:
#   SS_total = sum((n_b - 1) * s_b^2) + sum(n_b * (mean_b - grand_mean)^2)
grand_mean = np.average(basin_v_stats['mean'], weights=basin_n)
ss_total = ((basin_n - 1) * basin_var).sum() + (basin_n * (basin_v_stats['mean'] - grand_mean) ** 2).sum()
n_retained = basin_n.sum()
retained_var = ss_total / (n_retained - 1) if n_retained > 1 else 0.0

# Variance explained by basins, among the retained states
var_explained = 1 - avg_within_var / retained_var if retained_var > 0 else 0.0

print(f"Total V std: {total_std:.2f}")
print(f"Avg within-basin std: {avg_within_std:.2f}")
print(f"Variance explained by capture basins: {100*var_explained:.1f}%")

In [None]:
# Visualize V distribution within the six largest basins
top_basins = basin_counts.head(6).index.tolist()

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, basin in enumerate(top_basins):
    # Select rows via the underlying capture columns rather than comparing a
    # tuple-valued column against a tuple (which depends on pandas treating
    # the tuple as a scalar).
    t0_basin, t1_basin = basin
    in_basin = (basin_df['t0_capture'] == t0_basin) & (basin_df['t1_capture'] == t1_basin)
    basin_v = basin_df.loc[in_basin, 'V']

    ax = axes[i]
    ax.hist(basin_v, bins=30, color='steelblue', alpha=0.7, edgecolor='black')
    ax.axvline(x=basin_v.mean(), color='red', linestyle='--', label=f'Mean={basin_v.mean():.1f}')
    ax.set_xlabel('V')
    ax.set_ylabel('Count')
    ax.set_title(f'Basin {basin}\n(n={len(basin_v)}, std={basin_v.std():.1f})')
    ax.legend()

# With fewer than six basins, hide the leftover empty panels
for unused_ax in axes[len(top_basins):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.savefig('../../results/figures/03b_basin_v_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Capture Outcome vs V Relationship

In [None]:
# Pearson correlation between capture difference and V across the sample
corr_matrix = np.corrcoef(basin_df['capture_diff'], basin_df['V'])
corr = corr_matrix[0, 1]
print(f"Correlation(capture_diff, V) = {corr:.4f}")

# Mean / std / count of V for each distinct capture difference
v_by_capture = (
    basin_df
    .groupby('capture_diff')['V']
    .agg(['mean', 'std', 'count'])
)
print("\nMean V by capture difference:")
print(v_by_capture)

In [None]:
# Left: raw V vs capture difference with a least-squares line.
# Right: mean V per capture difference with standard-error bars.
diff_vals = basin_df['capture_diff']
v_vals = basin_df['V']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_scatter, ax_means = axes

ax_scatter.scatter(diff_vals, v_vals, alpha=0.1, s=1)
ax_scatter.set_xlabel('Capture Difference (Team 0 - Team 1)')
ax_scatter.set_ylabel('V')
ax_scatter.set_title(f'V vs Capture Difference (r={corr:.3f})')

# Degree-1 fit: fit_coeffs = [slope, intercept]
fit_coeffs = np.polyfit(diff_vals, v_vals, 1)
fit_line = np.poly1d(fit_coeffs)
x_line = np.linspace(diff_vals.min(), diff_vals.max(), 100)
ax_scatter.plot(
    x_line, fit_line(x_line), 'r-', linewidth=2,
    label=f'y = {fit_coeffs[0]:.2f}x + {fit_coeffs[1]:.2f}',
)
ax_scatter.legend()

# Only capture differences with at least 50 sampled states get a point
well_sampled = v_by_capture[v_by_capture['count'] >= 50]
sem = well_sampled['std'] / np.sqrt(well_sampled['count'])
ax_means.errorbar(
    well_sampled.index,
    well_sampled['mean'],
    yerr=sem,
    fmt='o-', markersize=8, capsize=4
)
ax_means.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax_means.set_xlabel('Capture Difference')
ax_means.set_ylabel('Mean V')
ax_means.set_title('Mean V by Capture Outcome')

plt.tight_layout()
plt.savefig('../../results/figures/03b_v_vs_capture.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Depth-Stratified Analysis

In [None]:
# Within-basin variance by depth.
# Per-depth quantities use their own names (depth_total_var, depth_var_explained)
# so this cell does not overwrite the notebook-wide `total_var` / `var_explained`
# from Section 5. The Summary cell reports `var_explained`, and previously it
# picked up whichever depth this loop visited last.
depth_variance = []

for d, depth_data in basin_df.groupby('depth'):
    # Skip depths with too few sampled states
    if len(depth_data) < 100:
        continue

    # Within-basin variance at this depth, ignoring basins with < 5 states
    depth_basin_stats = depth_data.groupby('basin')['V'].agg(['var', 'count'])
    depth_basin_stats = depth_basin_stats[depth_basin_stats['count'] >= 5]

    if len(depth_basin_stats) == 0:
        continue

    depth_within_var = np.average(depth_basin_stats['var'], weights=depth_basin_stats['count'])
    depth_total_var = depth_data['V'].var()
    depth_var_explained = 1 - depth_within_var / depth_total_var if depth_total_var > 0 else 0

    depth_variance.append({
        'depth': d,
        'total_var': depth_total_var,
        'within_var': depth_within_var,
        'var_explained': depth_var_explained,
        'n_states': len(depth_data),
        'n_basins': len(depth_basin_stats),
    })

# Explicit columns keep the frame (and its CSV export) well-formed even when no
# depth passes the filters; otherwise sort_values('depth') raises KeyError.
depth_var_df = pd.DataFrame(
    depth_variance,
    columns=['depth', 'total_var', 'within_var', 'var_explained', 'n_states', 'n_basins'],
).sort_values('depth')
print("Variance explained by basins at each depth:")
print(depth_var_df.to_string(index=False))

In [None]:
# Variance explained by capture basins (as a percentage) at each depth
pct_explained = 100 * depth_var_df['var_explained']

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(depth_var_df['depth'], pct_explained, 'o-', markersize=8)
ax.set(
    xlabel='Depth (dominoes remaining)',
    ylabel='Variance Explained by Capture Basins (%)',
    title='How Much V Variance is Explained by Count Capture Outcomes?',
)
ax.axhline(y=50, color='red', linestyle='--', alpha=0.5, label='50%')
ax.legend()
ax.set_ylim(0, 100)

plt.tight_layout()
plt.savefig('../../results/figures/03b_var_explained_by_depth.png', dpi=150, bbox_inches='tight')
plt.show()

## Summary

In [None]:
# Headline numbers for this notebook, rendered by the project's table helper
summary = {}
summary['States analyzed'] = f"{len(basin_df):,}"
summary['Unique basins'] = len(basin_counts)
summary['Total V std'] = f"{total_std:.2f}"
summary['Avg within-basin std'] = f"{avg_within_std:.2f}"
summary['Variance explained'] = f"{100*var_explained:.1f}%"
summary['Corr(capture_diff, V)'] = f"{corr:.4f}"

summary_table = viz.create_summary_table(summary, "Basin Analysis Summary")
print(summary_table)

In [None]:
# Save results. Create the output directory first so the export does not
# fail on a fresh checkout where results/tables does not exist yet.
from pathlib import Path

TABLES_DIR = Path('../../results/tables')
TABLES_DIR.mkdir(parents=True, exist_ok=True)

basin_df.to_csv(TABLES_DIR / '03b_basin_data.csv', index=False)
depth_var_df.to_csv(TABLES_DIR / '03b_depth_variance.csv', index=False)
print("Results saved to results/tables/")