# 08c: Count Capture Predictors

**Goal**: Identify what features predict count capture outcomes.

**Key Question**: What predicts count capture?

**Candidates**: Trump control, suit length, who holds the count, position.

**Method**: Logistic regression and a random forest (an ensemble of decision trees), each fit separately to every count domino's capture outcome.

**Output**: Feature importance for each count domino.

**Key insight we're testing**: If counts are predictable from initial features, then the "game" is essentially decided at declaration time, not through clever play.

In [None]:
# === CONFIGURATION ===
# Machine-specific absolute paths; edit both before running on another host.
# DATA_DIR is the directory searched by loading.find_shard_files below.
DATA_DIR = "/mnt/d/shards-standard/"
# PROJECT_ROOT is the repository root; it is put on sys.path so `forge` imports resolve.
PROJECT_ROOT = "/home/jason/v2/mk5-tailwind"

# === Setup imports ===
import sys
# Prepend (not append) so the project's `forge` package takes precedence,
# and only once so re-running the cell does not grow sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm

# Project-local helpers (require PROJECT_ROOT on sys.path).
from forge.analysis.utils import loading, features, viz, navigation
from forge.oracle import schema, tables

viz.setup_notebook_style()
print("Ready")

## 1. Feature Engineering

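For each seed/declaration, extract features that might predict count capture:

1. **Who holds the count**: Is it held by Team 0 or Team 1?
2. **Trump control**: How many trumps does each team have?
3. **Suit distribution**: Who has length in each suit?
4. **Count domino specifics**: High- vs. low-pip count dominoes

`extract_deal_features` currently computes only items 1 and 2 (plus whether each count domino is itself a trump). Items 3 and 4 remain candidate features for future work.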

In [None]:
def extract_deal_features(seed, decl_id):
    """Build the pre-play feature row for one deal.

    Args:
        seed: Deal seed, passed to ``schema.deal_from_seed`` to obtain the four hands.
        decl_id: Declaration id. 0-6 select a trump suit (blank through sixes);
            7 = follow-me, 8 = nello, 9 = splash. Any value above 6 is treated
            as "no trump suit".

    Returns:
        dict whose keys are, in order: ``seed``, ``decl_id``, ``p0_trumps`` ..
        ``p3_trumps``, ``team0_trumps``, ``team1_trumps``, ``trump_advantage``,
        then for every domino in ``features.COUNT_DOMINO_IDS`` the triple
        ``holder_a_b`` (player index, or -1 if not found), ``team_a_b``
        (holder % 2, or -1) and ``is_trump_a_b`` (0/1).
    """
    hands = schema.deal_from_seed(seed)
    trump_suit = -1 if decl_id > 6 else decl_id
    has_trump = trump_suit >= 0

    row = {'seed': seed, 'decl_id': decl_id}

    # A domino is a trump when either of its ends shows the trump suit.
    for player in range(4):
        row[f'p{player}_trumps'] = (
            sum(1 for dom in hands[player] if trump_suit in schema.domino_pips(dom))
            if has_trump
            else 0
        )

    # Partnerships are players {0, 2} and {1, 3}.
    row['team0_trumps'] = row['p0_trumps'] + row['p2_trumps']
    row['team1_trumps'] = row['p1_trumps'] + row['p3_trumps']
    row['trump_advantage'] = row['team0_trumps'] - row['team1_trumps']

    for domino_id in features.COUNT_DOMINO_IDS:
        pips = schema.domino_pips(domino_id)
        tag = f'{pips[0]}_{pips[1]}'

        # First player whose hand contains the domino, or -1 if nobody holds it.
        holder = next((p for p in range(4) if domino_id in hands[p]), -1)

        row[f'holder_{tag}'] = holder
        row[f'team_{tag}'] = -1 if holder < 0 else holder % 2
        row[f'is_trump_{tag}'] = int(has_trump and trump_suit in pips)

    return row

# Smoke test: show the feature row for one concrete deal (seed 0, fives as trump).
test_features = extract_deal_features(0, 5)
print("Sample features for seed=0, decl=5 (fives):")
# The row always contains at least 'seed' and 'decl_id', so the join is never empty.
print("\n".join(f"  {name}: {val}" for name, val in test_features.items()))

## 2. Collect Training Data

For each shard (one seed/declaration pair), extract features from the initial deal and record which team ultimately captures each count domino.

In [None]:
# Load shards and collect (features, capture_outcome) pairs.
# Each shard holds one seed/declaration and contributes exactly one data point,
# so many shards are needed. Per-shard memory is released before the next shard
# is loaded, which is what makes processing more shards feasible.
shard_files = loading.find_shard_files(DATA_DIR)
N_SHARDS = 20

# Report what will actually be processed. Slicing shard_files[:N_SHARDS]
# silently yields fewer files when fewer exist.
n_available = len(shard_files)
if n_available == 0:
    print(f"WARNING: no shard files found under {DATA_DIR}")
elif n_available < N_SHARDS:
    print(f"WARNING: only {n_available} shard files available (requested {N_SHARDS})")

print(f"Processing {min(N_SHARDS, n_available)} shards...")

In [None]:
# Collect data: one row per shard = deal features + which team captured each count.


def _collect_shard_observation(shard_file):
    """Load one shard and return its feature/outcome row, or None if it has no initial state.

    Every large per-shard object (state table, lookup dict, V/Q arrays, depth
    mask) is a local of this function. All of them are released when it
    returns, including on an exception, instead of lingering as notebook
    globals between iterations.
    """
    df, seed, decl_id = schema.load_file(shard_file)

    # Build state lookup
    state_to_idx, V, Q = navigation.build_state_lookup_fast(df)
    states = df['state'].values

    # Get initial state (depth=28, all dominoes remaining)
    depths = features.depth(states)
    initial_mask = depths == 28
    if not initial_mask.any():
        return None

    # Take first initial state
    initial_idx = np.where(initial_mask)[0][0]
    initial_state = states[initial_idx]

    # Track captures from initial state
    captures = navigation.track_count_captures(
        initial_state, seed, decl_id, state_to_idx, V, Q
    )

    row = extract_deal_features(seed, decl_id)

    # Add capture outcomes
    for domino_id in features.COUNT_DOMINO_IDS:
        pips = schema.domino_pips(domino_id)
        col_name = f'capture_{pips[0]}_{pips[1]}'

        if domino_id in captures:
            # 1 if team 0 captures, 0 if team 1.
            # NOTE(review): assumes captures[domino_id] is the capturing *team*
            # (0/1). If it is a player index, player 2's captures would be
            # scored as team 1. Confirm against navigation.track_count_captures.
            row[col_name] = 1 if captures[domino_id] == 0 else 0
        else:
            row[col_name] = np.nan  # Count not tracked (shouldn't happen)

    return row


training_data = []
n_skipped = 0

for shard_file in tqdm(shard_files[:N_SHARDS], desc="Processing shards"):
    row = _collect_shard_observation(shard_file)
    if row is None:
        n_skipped += 1
        continue
    training_data.append(row)

train_df = pd.DataFrame(training_data)
print(f"Collected {len(train_df)} seed observations")
if n_skipped:
    print(f"Skipped {n_skipped} shard(s) with no depth-28 initial state")
print(f"Columns: {list(train_df.columns)}")
print(f"Columns: {list(train_df.columns)}")

## 3. Baseline: Does Holding a Count Predict Who Captures It?

In [None]:
# Simple analysis: Does holding the count predict capturing it?
# team_* is the holder's team (0/1, or -1 if no holder was found).
# capture_* is 1 when team 0 captured the count and 0 when team 1 did.
print("=== Baseline: Holder Predicts Capture ===")

for domino_id in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(domino_id)
    points = tables.DOMINO_COUNT_POINTS[domino_id]

    team_col = f'team_{pips[0]}_{pips[1]}'
    capture_col = f'capture_{pips[0]}_{pips[1]}'
    label = f"{pips[0]}-{pips[1]} ({points}pts)"

    # Drop rows with an untracked capture (NaN). Also drop rows with no known
    # holder (team == -1), which would otherwise count as "holder did not capture".
    valid = train_df[[team_col, capture_col]].dropna()
    valid = valid[valid[team_col].isin([0, 1])]

    if valid.empty:
        print(f"{label}: no valid observations")
        continue

    # Holder's team captured iff (team 0 holds & capture == 1) or
    # (team 1 holds & capture == 0), i.e. capture == 1 - team.
    holder_wins = valid[capture_col] == 1 - valid[team_col]

    accuracy = holder_wins.mean()
    print(f"{label}: Holder's team captures {accuracy*100:.1f}% of the time (n={len(valid)})")

## 4. Logistic Regression for Each Count

In [None]:
# Feature columns for prediction: the three team-level trump features first,
# then a (holder team, is-trump) pair per count domino in COUNT_DOMINO_IDS order.
feature_cols = ['trump_advantage', 'team0_trumps', 'team1_trumps']

for domino_id in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(domino_id)
    suffix = f'{pips[0]}_{pips[1]}'
    feature_cols.extend([f'team_{suffix}', f'is_trump_{suffix}'])

print(f"Feature columns: {feature_cols}")

In [None]:
# Fit logistic regression for each count domino.
# The StandardScaler is wrapped in a Pipeline with the classifier so that,
# inside cross-validation, it is fit on each training fold only. Scaling the
# whole dataset before cross_val_score leaks held-out-fold statistics into
# training and inflates the CV accuracy.
from sklearn.pipeline import make_pipeline

results = []

for domino_id in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(domino_id)
    points = tables.DOMINO_COUNT_POINTS[domino_id]
    capture_col = f'capture_{pips[0]}_{pips[1]}'

    # Prepare data (rows whose capture outcome is NaN are dropped)
    valid = train_df[feature_cols + [capture_col]].dropna()
    X = valid[feature_cols].values
    y = valid[capture_col].values

    # A classifier needs both outcomes present
    if len(np.unique(y)) < 2:
        print(f"{pips[0]}-{pips[1]}: Skipping (only one class)")
        continue

    pipeline = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=1000, random_state=42),
    )

    # Cross-validation score (scaler re-fit within every fold)
    cv_scores = cross_val_score(pipeline, X, y, cv=5, scoring='accuracy')

    # Fit on all data for coefficients. Keep the fitted LogisticRegression step
    # itself as `model` so downstream code can read `coef_`; the coefficients
    # are on standardized features.
    pipeline.fit(X, y)
    model = pipeline[-1]

    results.append({
        'domino': f"{pips[0]}-{pips[1]}",
        'points': points,
        'cv_accuracy': cv_scores.mean(),
        'cv_std': cv_scores.std(),
        'n_samples': len(y),
        'model': model,
        'feature_names': feature_cols,
    })

    print(f"{pips[0]}-{pips[1]} ({points}pts): CV accuracy = {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")

## 5. Feature Importance Analysis

In [None]:
# Plot the (up to) ten largest-magnitude logistic-regression coefficients per count.
# Features were standardized before fitting, so magnitudes are comparable within
# a panel. Green = positive coefficient, which raises the predicted probability
# that capture_* == 1 (team 0 captures). Red = zero or negative.
from pathlib import Path

n_cols = 3
n_panels = len(results)
# One panel per fitted model (ceil division), with at least one row so
# plt.subplots is valid even when no model was fit.
n_rows = max(1, -(-n_panels // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, result in zip(axes, results):
    model = result['model']
    names = result['feature_names']

    # Get coefficients (binary problem -> a single row)
    coefs = model.coef_[0]

    # Sort by absolute value, largest first
    sorted_idx = np.argsort(np.abs(coefs))[::-1]

    # Plot top features
    top_n = min(10, len(coefs))
    top_idx = sorted_idx[:top_n]

    colors = ['green' if c > 0 else 'red' for c in coefs[top_idx]]
    ax.barh(range(top_n), coefs[top_idx], color=colors, alpha=0.7)
    ax.set_yticks(range(top_n))
    ax.set_yticklabels([names[j] for j in top_idx], fontsize=8)
    ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
    ax.set_xlabel('Coefficient')
    ax.set_title(f"{result['domino']} ({result['points']}pts)\nCV Acc: {result['cv_accuracy']:.3f}")
    ax.invert_yaxis()

# Hide every panel without a model, not just the last one
for ax in axes[n_panels:]:
    ax.set_visible(False)

plt.suptitle('Feature Importance for Count Capture Prediction', fontsize=14, y=1.02)
plt.tight_layout()

# savefig does not create missing directories
fig_dir = Path('../../results/figures')
fig_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(fig_dir / '08c_feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Random Forest for Comparison

In [None]:
# Try Random Forest for potentially better feature importance.
# No feature scaling is applied here (tree splits are threshold-based).
rf_results = []

for domino_id in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(domino_id)
    points = tables.DOMINO_COUNT_POINTS[domino_id]
    domino_label = f"{pips[0]}-{pips[1]}"
    target_col = f'capture_{pips[0]}_{pips[1]}'

    subset = train_df[feature_cols + [target_col]].dropna()
    X_rf = subset[feature_cols].values
    y_rf = subset[target_col].values

    # A classifier needs both outcomes; single-class counts are skipped silently
    if np.unique(y_rf).size < 2:
        continue

    forest = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
    fold_scores = cross_val_score(forest, X_rf, y_rf, cv=5, scoring='accuracy')
    mean_acc, std_acc = fold_scores.mean(), fold_scores.std()

    # Refit on every row so feature_importances_ reflects all the data
    forest.fit(X_rf, y_rf)

    rf_results.append({
        'domino': domino_label,
        'points': points,
        'cv_accuracy': mean_acc,
        'cv_std': std_acc,
        'feature_importance': forest.feature_importances_,
    })

    print(f"{domino_label} ({points}pts): RF CV accuracy = {mean_acc:.3f} (+/- {std_acc:.3f})")

In [None]:
# Compare models. Match the Random Forest results to the logistic results by
# domino name rather than by list position, so a count skipped by only one
# model cannot silently misalign rows or raise a length-mismatch error.
rf_acc_by_domino = {r['domino']: r['cv_accuracy'] for r in rf_results}

comparison_df = pd.DataFrame({
    'domino': [r['domino'] for r in results],
    'points': [r['points'] for r in results],
    'logistic_acc': [r['cv_accuracy'] for r in results],
    # NaN when the Random Forest produced no result for this count
    'rf_acc': [rf_acc_by_domino.get(r['domino'], np.nan) for r in results],
})

print("Model Comparison:")
print(comparison_df.to_string(index=False))
print()
print(f"Mean Logistic Regression accuracy: {comparison_df['logistic_acc'].mean():.3f}")
print(f"Mean Random Forest accuracy: {comparison_df['rf_acc'].mean():.3f}")

## 7. Most Important Features Across All Counts

In [None]:
# Aggregate feature importance across all counts: average each feature's Random
# Forest importance over every count domino that produced a model.
from pathlib import Path

if rf_results:
    all_importances = np.mean([r['feature_importance'] for r in rf_results], axis=0)
else:
    # Avoid a 0/0 division; NaN marks "no models to average".
    print("WARNING: no Random Forest results to aggregate")
    all_importances = np.full(len(feature_cols), np.nan)

# Sort by importance, descending (`sorted_idx` is also read by the summary cell)
sorted_idx = np.argsort(all_importances)[::-1]

# Plot
fig, ax = plt.subplots(figsize=(10, 6))

ax.barh(range(len(feature_cols)), all_importances[sorted_idx], color='steelblue', alpha=0.7)
ax.set_yticks(range(len(feature_cols)))
ax.set_yticklabels([feature_cols[i] for i in sorted_idx])
ax.set_xlabel('Mean Feature Importance (Random Forest)')
ax.set_title('Feature Importance for Count Capture Prediction (Averaged Across All Counts)')
ax.invert_yaxis()

plt.tight_layout()

# savefig does not create missing directories
fig_dir = Path('../../results/figures')
fig_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(fig_dir / '08c_aggregate_importance.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Summary

In [None]:
# Headline numbers for this notebook. The key finding is derived from the
# aggregated Random Forest importances rather than printed unconditionally.
summary = {
    'Seeds analyzed': len(train_df),
    'Mean Logistic CV accuracy': f"{comparison_df['logistic_acc'].mean():.3f}",
    'Mean RF CV accuracy': f"{comparison_df['rf_acc'].mean():.3f}",
    'Most important feature': feature_cols[sorted_idx[0]],
    'Second most important': feature_cols[sorted_idx[1]],
}

print("\n" + "="*60)
print("08c SUMMARY: Count Capture Predictors")
print("="*60)
for k, v in summary.items():
    print(f"{k}: {v}")
print("="*60)
print()

ranked_features = [feature_cols[i] for i in sorted_idx]
top_feature = ranked_features[0]

# Per-count holder features are named team_<a>_<b>. The team0_/team1_trumps
# columns do not match this prefix.
if top_feature.startswith('team_'):
    print("KEY FINDING: Who holds the count is the strongest predictor.")
else:
    print(f"KEY FINDING: The strongest predictor is '{top_feature}', not who holds the count.")

# Report where the best trump-related feature actually ranks instead of
# asserting that it adds predictive power.
trump_rank = next(
    (rank for rank, name in enumerate(ranked_features, start=1) if 'trump' in name),
    None,
)
if trump_rank is not None:
    print(
        f"Best trump-related feature: '{ranked_features[trump_rank - 1]}' "
        f"(rank {trump_rank} of {len(ranked_features)})."
    )

In [None]:
# Save result tables. The figures were already written by the plotting cells.
from pathlib import Path

tables_dir = Path('../../results/tables')
# DataFrame.to_csv does not create missing directories
tables_dir.mkdir(parents=True, exist_ok=True)

comparison_df.to_csv(tables_dir / '08c_model_comparison.csv', index=False)

importance_df = pd.DataFrame({
    'feature': feature_cols,
    'mean_importance': all_importances,
}).sort_values('mean_importance', ascending=False)
importance_df.to_csv(tables_dir / '08c_feature_importance.csv', index=False)

print(f"Results saved under {tables_dir.parent}:")
print("  - figures/08c_feature_importance.png")
print("  - figures/08c_aggregate_importance.png")
print("  - tables/08c_model_comparison.csv")
print("  - tables/08c_feature_importance.csv")