# 03c: Capture Probability Model

**Goal**: Fit a linear model predicting V from count capture indicators.

**Model**: V = sum(count_value * capture_indicator) + trick_points + error

**Key Questions**:
1. What is the R^2 of a simple capture-based model?
2. Are the coefficients close to the count point values?
3. What does the residual structure look like?

**Reference**: docs/analysis-draft.md Section 6

In [None]:
# === CONFIGURATION ===
# Both paths are machine-specific absolute paths; edit them before running
# this notebook on another machine.
# DATA_DIR is the directory passed to loading.find_shard_files in the load cell.
DATA_DIR = "/mnt/d/shards-standard/"
# PROJECT_ROOT is placed on sys.path below; presumably the checkout that
# contains the `forge` package imported later in this cell — confirm.
PROJECT_ROOT = "/home/jason/v2/mk5-tailwind"

# === Setup imports ===
import sys
# Guarded so that re-running this cell does not add duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error
from tqdm.notebook import tqdm

from forge.analysis.utils import loading, features, viz, navigation
from forge.oracle import schema, tables

# Apply the project's notebook plotting style (defined in forge.analysis.utils.viz).
viz.setup_notebook_style()
# Printed only if configuration, imports and styling all ran without raising.
print("Ready")

## 1. Load Data and Compute Capture Outcomes

In [None]:
# Load first shard
shard_files = loading.find_shard_files(DATA_DIR)

# Fail with an actionable message instead of a bare IndexError on
# shard_files[0] when DATA_DIR is wrong or empty.
if len(shard_files) == 0:
    raise FileNotFoundError(f"No shard files found under DATA_DIR={DATA_DIR!r}")

# load_file returns the shard's table together with the seed and declaration id
# that the later navigation calls need.
df, seed, decl_id = schema.load_file(shard_files[0])

print(f"Seed: {seed}")
print(f"Declaration: {decl_id} ({schema.DECL_NAMES[decl_id]})")
print(f"Total states: {len(df):,}")

In [None]:
# Raw state column of the shard; sampled from in the next cell.
states = df['state'].values

# Lookup structures used by the navigation helpers: a state -> index mapping
# plus V and Q as returned by the builder (V is indexed by row position later).
(state_to_idx, V, Q) = navigation.build_state_lookup_fast(df)

print("Built lookup for {:,} states".format(len(state_to_idx)))

In [None]:
# Sample states and compute capture outcomes
# A fixed seed keeps the sampled subset — and therefore every metric, table
# and figure derived from it — identical across Restart & Run All.
SAMPLE_SEED = 0
N_SAMPLE = min(30000, len(states))

# Sample row positions without replacement so no state is counted twice.
sample_rng = np.random.default_rng(SAMPLE_SEED)
sample_indices = sample_rng.choice(len(states), size=N_SAMPLE, replace=False)
sample_states = states[sample_indices]

print(f"Sampling {N_SAMPLE:,} states for modeling...")

In [None]:
# Compute capture outcomes with individual domino tracking

# Depth for the whole sample in one call; features.depth takes an array of
# states, so there is no need to wrap each state in its own 1-element array.
sample_depths = features.depth(sample_states)
assert len(sample_depths) == len(sample_states), "depth() must return one value per state"

# Column names depend only on the domino, so build them once rather than
# once per sampled state.
capture_columns = []
for domino_id in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(domino_id)
    capture_columns.append((domino_id, f"capture_{pips[0]}_{pips[1]}"))

model_data = []

for idx, state, depth in tqdm(
    zip(sample_indices, sample_states, sample_depths),
    total=len(sample_states),
    desc="Computing captures",
):
    captures = navigation.track_count_captures(
        state, seed, decl_id, state_to_idx, V, Q
    )

    # V is indexed by the state's row position in the shard, not by sample order.
    row = {
        'state': state,
        'V': V[idx],
        'depth': depth,
    }

    # Individual capture indicators: +1 if team 0 captures, -1 if team 1, 0 if already played
    for domino_id, col_name in capture_columns:
        if domino_id in captures:
            row[col_name] = 1 if captures[domino_id] == 0 else -1
        else:
            row[col_name] = 0  # Already played before this state

    model_data.append(row)

# One DataFrame construction from the list of records (not row-by-row growth).
model_df = pd.DataFrame(model_data)
print(f"Model data ready: {len(model_df):,} samples")

## 2. Build Feature Matrix

In [None]:
# Identify capture columns
capture_cols = [c for c in model_df.columns if c.startswith('capture_')]
print(f"Capture features: {capture_cols}")

# Feature matrix: raw capture indicators (+1 / -1 / 0), one column per count
# domino. The point-value weighting is applied later, not here.
X_simple = model_df[capture_cols].values
y = model_df['V'].values

# Map (high, low) pips -> count points once, keyed the same way the column
# names were built (pips[0], pips[1]). setdefault keeps the first domino if
# two ever shared the same pips.
points_by_pips = {}
for d in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(d)
    points_by_pips.setdefault(
        (int(pips[0]), int(pips[1])), tables.DOMINO_COUNT_POINTS[d]
    )

# Get point values for each capture column, in column order. A column with no
# matching domino raises instead of being skipped, which would otherwise leave
# point_values shorter than (and misaligned with) capture_cols.
point_values = []
for col in capture_cols:
    # Parse domino from column name: "capture_<high>_<low>"
    parts = col.split('_')
    key = (int(parts[1]), int(parts[2]))
    if key not in points_by_pips:
        raise KeyError(f"No count domino with pips {key} for column {col!r}")
    point_values.append(points_by_pips[key])

point_values = np.array(point_values)
print(f"Point values: {point_values}")

## 3. Model 1: Simple Capture Model

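Fixed-coefficient baseline with no fitted parameters: $V \approx \sum_d \text{capture}_d \cdot \text{points}_d$, where $\text{capture}_d$ is +1 if team 0 captures count domino $d$, -1 if team 1 captures it, and 0 if it was already played.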

In [None]:
# Fixed-coefficient baseline: nothing is fitted. Each capture indicator
# (+1 for team 0, -1 for team 1) is scaled by that domino's point value and
# the row sum is the predicted net count-capture advantage.
weighted_captures = X_simple * point_values
V_pred_simple = weighted_captures.sum(axis=1)

# Goodness of fit of the baseline against the true V.
corr_simple = np.corrcoef(y, V_pred_simple)[0, 1]
r2_simple = r2_score(y, V_pred_simple)
rmse_simple = np.sqrt(mean_squared_error(y, V_pred_simple))

print(
    "Simple Capture Model: V = sum(capture * points)\n"
    f"  R^2: {r2_simple:.4f}\n"
    f"  RMSE: {rmse_simple:.2f}\n"
    f"  Correlation: {corr_simple:.4f}"
)

## 4. Model 2: Learned Coefficients

In [None]:
# Ordinary least squares on the raw capture indicators: one learned
# coefficient per count domino plus an intercept.
reg = LinearRegression().fit(X_simple, y)

# In-sample predictions and fit metrics (same data used for fitting).
V_pred_learned = reg.predict(X_simple)
rmse_learned = np.sqrt(mean_squared_error(y, V_pred_learned))
r2_learned = r2_score(y, V_pred_learned)

print(
    "Learned Coefficient Model:\n"
    f"  R^2: {r2_learned:.4f}\n"
    f"  RMSE: {rmse_learned:.2f}\n"
    f"  Intercept: {reg.intercept_:.2f}"
)

# Side-by-side view of each learned coefficient and the domino's point value.
print("\nCoefficients vs True Point Values:")
for k in range(len(capture_cols)):
    print(f"  {capture_cols[k]}: learned={reg.coef_[k]:.2f}, true={point_values[k]}")

## 5. Model 3: With Depth Feature

In [None]:
# Extend the capture indicators with depth as one extra (last) column.
depth_column = model_df['depth'].values.reshape(-1, 1)
X_with_depth = np.hstack([X_simple, depth_column])

# Same OLS fit as before, now with the additional depth feature.
reg_depth = LinearRegression().fit(X_with_depth, y)

# In-sample predictions and fit metrics.
V_pred_depth = reg_depth.predict(X_with_depth)
rmse_depth = np.sqrt(mean_squared_error(y, V_pred_depth))
r2_depth = r2_score(y, V_pred_depth)

# Depth was appended last, so its coefficient is the final entry of coef_.
print(
    "Model with Depth:\n"
    f"  R^2: {r2_depth:.4f}\n"
    f"  RMSE: {rmse_depth:.2f}\n"
    f"  Depth coefficient: {reg_depth.coef_[-1]:.4f}"
)

## 6. Model Comparison

In [None]:
# One (label, R^2, RMSE) record per model, in the order they were fitted.
comparison_rows = [
    ('Simple (fixed coef)', r2_simple, rmse_simple),
    ('Learned coef', r2_learned, rmse_learned),
    ('Learned + depth', r2_depth, rmse_depth),
]

# Transpose the records into columns so the table is built column-wise.
labels, r2_values, rmse_values = (list(column) for column in zip(*comparison_rows))
model_comparison = pd.DataFrame({
    'Model': labels,
    'R^2': r2_values,
    'RMSE': rmse_values,
})

print("Model Comparison:")
print(model_comparison.to_string(index=False))

In [None]:
# Visualize model comparison: predicted vs. true V, one panel per model.
# Create the output directory first so savefig does not fail on a fresh checkout.
comparison_fig_dir = Path('../../results/figures')
comparison_fig_dir.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

models = [
    ('Simple', V_pred_simple, r2_simple),
    ('Learned', V_pred_learned, r2_learned),
    ('Learned+Depth', V_pred_depth, r2_depth),
]

for ax, (name, pred, r2) in zip(axes, models):
    ax.scatter(y, pred, alpha=0.1, s=1)
    # Identity line; points on it are perfectly predicted.
    # NOTE(review): +/-42 is presumably the range of V — confirm.
    ax.plot([-42, 42], [-42, 42], 'r--', linewidth=2, label='Perfect')
    ax.set_xlabel('True V')
    ax.set_ylabel('Predicted V')
    ax.set_title(f'{name} Model\nR^2 = {r2:.4f}')
    # Same limits on both axes so the three panels are directly comparable.
    ax.set_xlim(-45, 45)
    ax.set_ylim(-45, 45)
    ax.legend()

plt.tight_layout()
plt.savefig(comparison_fig_dir / '03c_model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Residual Analysis

In [None]:
# Residuals of the capture + depth model (true V minus its prediction).
residuals = y - V_pred_depth

# (label, value, format spec) triples; the mean gets more digits than the rest.
residual_stats = (
    ('Mean', residuals.mean(), '.4f'),
    ('Std', residuals.std(), '.2f'),
    ('Min', residuals.min(), '.2f'),
    ('Max', residuals.max(), '.2f'),
)

print("Residual statistics:")
for label, value, spec in residual_stats:
    print(f"  {label}: {format(value, spec)}")

In [None]:
# Residual plots: a 2x2 diagnostic panel for the capture + depth model.
# `stats` (scipy) is imported with the other dependencies at the top of the
# notebook rather than in the middle of this cell.
# Create the output directory first so savefig does not fail on a fresh checkout.
residual_fig_dir = Path('../../results/figures')
residual_fig_dir.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Residual histogram
axes[0, 0].hist(residuals, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
axes[0, 0].axvline(x=0, color='red', linestyle='--')
axes[0, 0].set_xlabel('Residual (V - V_pred)')
axes[0, 0].set_ylabel('Count')
axes[0, 0].set_title('Residual Distribution')

# Residuals vs predicted
axes[0, 1].scatter(V_pred_depth, residuals, alpha=0.1, s=1)
axes[0, 1].axhline(y=0, color='red', linestyle='--')
axes[0, 1].set_xlabel('Predicted V')
axes[0, 1].set_ylabel('Residual')
axes[0, 1].set_title('Residuals vs Predicted')

# Residuals vs depth
axes[1, 0].scatter(model_df['depth'], residuals, alpha=0.1, s=1)
axes[1, 0].axhline(y=0, color='red', linestyle='--')
axes[1, 0].set_xlabel('Depth')
axes[1, 0].set_ylabel('Residual')
axes[1, 0].set_title('Residuals vs Depth')

# Q-Q plot against a normal distribution; the title is set afterwards so it
# replaces the one probplot puts on the axes.
stats.probplot(residuals, dist="norm", plot=axes[1, 1])
axes[1, 1].set_title('Q-Q Plot (Normality Check)')

plt.tight_layout()
plt.savefig(residual_fig_dir / '03c_residual_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. What Do Residuals Capture?

The residuals below come from the learned + depth model (Section 5), so they represent the V variance NOT explained by count captures or depth.

In [None]:
# Pair every sampled state's depth with its residual and the residual's magnitude.
residual_df = pd.DataFrame({
    'depth': model_df['depth'],
    'residual': residuals,
}).assign(abs_residual=lambda frame: frame['residual'].abs())

# Per-depth mean, spread and sample count of the absolute residual.
abs_residual_groups = residual_df.groupby('depth')['abs_residual']
residual_by_depth = abs_residual_groups.agg(['mean', 'std', 'count'])

print("Mean absolute residual by depth:")
print(residual_by_depth)

In [None]:
# Plot residual vs depth: mean absolute residual with a +/- 1 std band.
# Create the output directory first so savefig does not fail on a fresh checkout.
depth_fig_dir = Path('../../results/figures')
depth_fig_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(10, 6))

depth_mean = residual_by_depth['mean']
depth_std = residual_by_depth['std']

ax.plot(residual_by_depth.index, depth_mean, 'o-', markersize=8)
# Shaded band of one standard deviation either side of the per-depth mean.
ax.fill_between(
    residual_by_depth.index,
    depth_mean - depth_std,
    depth_mean + depth_std,
    alpha=0.3
)
ax.set_xlabel('Depth (dominoes remaining)')
ax.set_ylabel('Mean Absolute Residual')
ax.set_title('Unexplained V Variance by Depth')

plt.tight_layout()
plt.savefig(depth_fig_dir / '03c_residual_by_depth.png', dpi=150, bbox_inches='tight')
plt.show()

## Summary

In [None]:
# Headline numbers for the notebook, pre-formatted as strings in display order.
summary = dict([
    ('States modeled', format(len(model_df), ',')),
    ('Simple model R^2', format(r2_simple, '.4f')),
    ('Learned model R^2', format(r2_learned, '.4f')),
    ('Full model R^2', format(r2_depth, '.4f')),
    ('Full model RMSE', format(rmse_depth, '.2f')),
    ('Residual std', format(residuals.std(), '.2f')),
])

# Render through the project's summary-table helper and print the result.
summary_text = viz.create_summary_table(summary, "Capture Probability Model Summary")
print(summary_text)

In [None]:
# Save results
# Create the output directory first so to_csv does not fail on a fresh checkout.
tables_dir = Path('../../results/tables')
tables_dir.mkdir(parents=True, exist_ok=True)

model_comparison.to_csv(tables_dir / '03c_model_comparison.csv', index=False)

# Save coefficients of the capture + depth model. Depth has no point value,
# so its true_value is NaN. The intercept is not written.
coef_df = pd.DataFrame({
    'feature': capture_cols + ['depth'],
    'coefficient': list(reg_depth.coef_),
    'true_value': list(point_values) + [np.nan],
})
coef_df.to_csv(tables_dir / '03c_coefficients.csv', index=False)

print("Results saved to results/tables/")