# 11a: Count Lock Rate Analysis

**Goal**: Determine which count dominoes a hand "locks" (captures with high probability across opponent configs).

**Data**: Marginalized shards - same P0 hand, 3 different opponent configurations.

**Key Questions**:
1. For each count domino, what's P(Team 0 captures) across opponent configs?
2. Which counts are "locked" (P > 0.95) vs "contested" (P ≈ 0.5)?
3. Does holding a count domino predict locking it?

**Reference**: Imperfect Information Analysis Suite (t42-q0be)

In [None]:
# === CONFIGURATION ===
# NOTE(review): both paths below are absolute and machine-specific; adjust them
# before running this notebook in another environment.
# Directory that is globbed for marginalized shard parquet files in section 1.
DATA_DIR = "/mnt/d/shards-marginalized/train"
# Prepended to sys.path below, presumably so the `forge` package imports resolve — confirm.
PROJECT_ROOT = "/home/jason/v2/mk5-tailwind"

# === Setup imports ===
import sys
# Guarded so that re-running this cell does not insert duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict
from tqdm.notebook import tqdm

from forge.analysis.utils import loading, features, viz, navigation
from forge.oracle import schema, tables
from forge.oracle.rng import deal_from_seed, deal_with_fixed_p0

# Apply the project's notebook plotting style, then confirm the setup cell completed.
viz.setup_notebook_style()
print("Ready")

## 1. Explore Marginalized Data Structure

In [None]:
# Count available shards
marg_dir = Path(DATA_DIR)
shard_files = list(marg_dir.glob("seed_*_opp*_decl_*.parquet"))
print(f"Total marginalized shard files: {len(shard_files)}")

# Extract unique base seeds (second '_'-separated token of the file name)
base_seeds = sorted({int(f.name.split('_')[1]) for f in shard_files})
print(f"Unique base seeds: {len(base_seeds)}")

# min()/max() raise ValueError on an empty list, so report the empty case
# explicitly instead of failing with an unrelated-looking traceback.
if base_seeds:
    print(f"Base seed range: {base_seeds[0]} - {base_seeds[-1]}")
else:
    print(f"Base seed range: n/a (no shard files matched under {marg_dir})")

In [None]:
# Define count dominoes: list every count domino with its pips, point value and id
print("Count dominoes:")
for domino_id in features.COUNT_DOMINO_IDS:
    domino_pips = schema.domino_pips(domino_id)
    first_pip, second_pip = domino_pips[0], domino_pips[1]
    domino_points = tables.DOMINO_COUNT_POINTS[domino_id]
    print(f"  {first_pip}-{second_pip}: {domino_points} points (id={domino_id})")

## 2. Load and Analyze Count Captures

For each base_seed:
1. Load all 3 opponent configs
2. Find root state and trace to terminal
3. Track which team captures each count domino

In [None]:
def get_root_captures_marginalized(df, p0_hand, opp_seed, decl_id):
    """Trace the principal variation from the root and record count captures.

    For marginalized data the four hands are reconstructed with
    ``deal_with_fixed_p0(p0_hand, opp_seed)``.

    Args:
        df: Shard DataFrame; the ``'state'`` and ``'V'`` columns are read here,
            and the whole frame is handed to ``navigation.build_state_lookup_fast``.
        p0_hand: Player 0's hand, forwarded to ``deal_with_fixed_p0``.
        opp_seed: Opponent-configuration seed, used both to rebuild the hands
            and as an argument to the principal-variation tracer.
        decl_id: Declaration id, forwarded to the tracer and to
            ``tables.resolve_trick``.

    Returns:
        ``(root_v, captures)`` where ``root_v`` is the ``'V'`` value of the
        first state at depth 28 and ``captures`` maps each count domino id seen
        in a completed trick to the winning team (``winner % 2``).
        ``(None, None)`` when the shard contains no state at depth 28.
    """
    # Build state lookup
    state_to_idx, V, Q = navigation.build_state_lookup_fast(df)

    # Find root state (depth = 28)
    depths = features.depth(df['state'].values)
    root_mask = depths == 28
    root_states = df.loc[root_mask, 'state'].values

    if len(root_states) == 0:
        return None, None

    root_state = root_states[0]
    root_v = df.loc[root_mask, 'V'].values[0]

    # Reconstruct hands for this opponent config
    hands = deal_with_fixed_p0(p0_hand, opp_seed)

    # NOTE(review): the tracer receives opp_seed as its second argument; confirm
    # it does not re-deal hands from that value, since the hands used below come
    # from deal_with_fixed_p0 rather than from a plain seed.
    pv = navigation.trace_principal_variation(
        root_state, opp_seed, decl_id, state_to_idx, V, Q
    )

    captures = {}

    # Walk the PV; only the play that completes a trick decides who captures it.
    for state, _v, move in pv:
        if move < 0:
            # A negative move ends the trace.
            break

        _remaining, leader, trick_len, p0, p1, p2 = navigation.unpack_state_single(state)

        if trick_len != 3:
            continue

        # p0..p2 are the plays already in the trick and `move` is the fourth;
        # each is used as an index into the hand of the player who made it,
        # in seat order starting from the leader.
        trick_plays = [p0, p1, p2, move]
        trick_dominos = tuple(
            hands[(leader + j) % 4][trick_plays[j]]
            for j in range(4)
        )

        outcome = tables.resolve_trick(trick_dominos[0], trick_dominos, decl_id)
        winner = (leader + outcome.winner_offset) % 4
        winning_team = winner % 2

        # Record the first capture of every count domino in this trick.
        for domino_id in trick_dominos:
            if domino_id in features.COUNT_DOMINO_IDS and domino_id not in captures:
                captures[domino_id] = winning_team

    return root_v, captures

In [None]:
# Analyze first N base seeds (memory-constrained)
N_SEEDS = min(50, len(base_seeds))  # Analyze 50 base seeds
sample_seeds = base_seeds[:N_SEEDS]

results = []

for base_seed in tqdm(sample_seeds, desc="Analyzing base seeds"):
    decl_id = base_seed % 10
    p0_hand = deal_from_seed(base_seed)[0]

    # Count dominoes that sit in P0's hand
    p0_counts = {cid for cid in features.COUNT_DOMINO_IDS if cid in p0_hand}

    row = dict(base_seed=base_seed, decl_id=decl_id, p0_hand=tuple(p0_hand))

    # One boolean "holds" flag per count domino
    for cid in features.COUNT_DOMINO_IDS:
        pips = schema.domino_pips(cid)
        row[f'holds_{pips[0]}_{pips[1]}'] = cid in p0_counts

    # One shard per opponent configuration
    for opp_seed in range(3):
        v_key = f'V_opp{opp_seed}'
        path = marg_dir / f"seed_{base_seed:08d}_opp{opp_seed}_decl_{decl_id}.parquet"

        try:
            df = pd.read_parquet(path)

            root_v, captures = get_root_captures_marginalized(df, p0_hand, opp_seed, decl_id)
            row[v_key] = root_v

            if captures:
                for cid in features.COUNT_DOMINO_IDS:
                    pips = schema.domino_pips(cid)
                    # 1 if Team 0 captures, 0 if Team 1, None if not tracked
                    if cid not in captures:
                        outcome = None
                    elif captures[cid] == 0:
                        outcome = 1
                    else:
                        outcome = 0
                    row[f'capture_{pips[0]}_{pips[1]}_opp{opp_seed}'] = outcome

            # Release the shard before loading the next one
            del df
            gc.collect()

        except FileNotFoundError:
            row[v_key] = None

    results.append(row)

results_df = pd.DataFrame(results)
print(f"\nAnalyzed {len(results_df)} base seeds")

## 3. Compute Lock Rates

In [None]:
# For each count domino, compute capture rate across opponent configs
lock_rates = []

for d in features.COUNT_DOMINO_IDS:
    pips = schema.domino_pips(d)
    points = tables.DOMINO_COUNT_POINTS[d]

    # Get capture columns for this domino
    capture_cols = [f'capture_{pips[0]}_{pips[1]}_opp{i}' for i in range(3)]
    holds_col = f'holds_{pips[0]}_{pips[1]}'

    for _, row in results_df.iterrows():
        captures = [row.get(c) for c in capture_cols]
        # Missing outcomes arrive here as NaN, not None: pandas converts None to
        # NaN in numeric columns and NaN-fills keys absent from a row. An
        # `is not None` test would therefore count them as valid samples and
        # turn the mean into NaN. pd.notna rejects both None and NaN.
        valid_captures = [c for c in captures if pd.notna(c)]

        if len(valid_captures) >= 2:  # Need at least 2 valid samples
            lock_rates.append({
                'base_seed': row['base_seed'],
                'domino': f"{pips[0]}-{pips[1]}",
                'domino_id': d,
                'points': points,
                'p0_holds': row.get(holds_col, False),
                'capture_rate': np.mean(valid_captures),
                'n_samples': len(valid_captures),
                'all_captured': all(c == 1 for c in valid_captures),
                'all_lost': all(c == 0 for c in valid_captures),
            })

# Explicit columns keep the schema intact even when no observation qualified.
lock_df = pd.DataFrame(lock_rates, columns=[
    'base_seed', 'domino', 'domino_id', 'points', 'p0_holds',
    'capture_rate', 'n_samples', 'all_captured', 'all_lost',
])
print(f"Lock rate observations: {len(lock_df)}")
lock_df.head(10)

In [None]:
# Aggregate statistics by domino and holding status.
# Named aggregation yields the flat column names directly.
summary = (
    lock_df
    .groupby(['domino', 'points', 'p0_holds'])
    .agg(
        mean_rate=('capture_rate', 'mean'),
        std_rate=('capture_rate', 'std'),
        n_hands=('capture_rate', 'count'),
        pct_locked=('all_captured', 'mean'),  # % where always captured
        pct_lost=('all_lost', 'mean'),  # % where always lost
    )
    .round(3)
)

print("\nCapture rate by domino and holding status:")
print(summary.to_string())

## 4. Visualize Lock Rates

In [None]:
# Plot capture rate distribution by domino
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Box plot of capture rates by domino
ax = axes[0]
dominoes = sorted(lock_df['domino'].unique())
# NaN rates are dropped so they cannot blank out a whole box.
data_by_dom = [
    lock_df.loc[lock_df['domino'] == d, 'capture_rate'].dropna().values
    for d in dominoes
]
bp = ax.boxplot(data_by_dom, patch_artist=True)
# Tick labels are set separately: the `labels=` keyword of boxplot is
# deprecated since Matplotlib 3.9, and its replacement does not exist in older
# releases, whereas set_xticklabels works on both.
ax.set_xticklabels(dominoes)
ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Contested (50%)')
ax.axhline(y=0.95, color='green', linestyle='--', alpha=0.5, label='Locked (95%)')
ax.set_xlabel('Count Domino')
ax.set_ylabel('Capture Rate (Team 0)')
ax.set_title('Capture Rate Distribution by Count Domino')
ax.legend()
ax.set_ylim(-0.05, 1.05)

# Right: Compare holding vs not holding
ax = axes[1]
# Select each group by mask rather than indexing a groupby result, so a missing
# holding status gives an empty box instead of a KeyError.
holds_groups = [
    lock_df.loc[lock_df['p0_holds'] == flag, 'capture_rate'].dropna().values
    for flag in (False, True)
]
bp = ax.boxplot(holds_groups, patch_artist=True)
ax.set_xticklabels(['P0 does NOT hold', 'P0 HOLDS'])
bp['boxes'][0].set_facecolor('lightcoral')
bp['boxes'][1].set_facecolor('lightgreen')
ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.5)
ax.set_ylabel('Capture Rate (Team 0)')
ax.set_title('Capture Rate: Holding vs Not Holding Count')
ax.set_ylim(-0.05, 1.05)

plt.tight_layout()
plt.savefig('../../results/figures/11a_count_lock_rate.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Detailed heatmap: domino x holding status
pivot = lock_df.groupby(['domino', 'p0_holds'])['capture_rate'].mean().unstack()
# Force both holding statuses to exist, in a fixed order, before relabelling.
# Without this the rename fails (or mislabels the single column) whenever the
# sample contains only one status.
pivot = pivot.reindex(columns=[False, True])
pivot.columns = ['Not Holding', 'Holding']

fig, ax = plt.subplots(figsize=(8, 5))
im = ax.imshow(pivot.values, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
ax.set_xticks(range(len(pivot.columns)))
ax.set_xticklabels(pivot.columns)
ax.set_yticks(range(len(pivot.index)))
ax.set_yticklabels(pivot.index)
ax.set_xlabel('P0 Status')
ax.set_ylabel('Count Domino')
ax.set_title('Mean Capture Rate by Domino and Holding Status')

# Add text annotations
for i in range(len(pivot.index)):
    for j in range(len(pivot.columns)):
        val = pivot.values[i, j]
        if pd.isna(val):
            # No observation for this (domino, status) pair.
            ax.text(j, i, 'n/a', ha='center', va='center', color='black', fontsize=12)
            continue
        color = 'white' if val < 0.3 or val > 0.7 else 'black'
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', color=color, fontsize=12)

plt.colorbar(im, label='Capture Rate')
plt.tight_layout()
plt.savefig('../../results/figures/11a_lock_rate_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Lock Classification

In [None]:
# Classify each (hand, count) pair
def classify_lock(rate):
    """Map a Team 0 capture rate to a lock category.

    Args:
        rate: Capture rate in [0, 1]; may be NaN or None when no valid
            sample exists.

    Returns:
        'locked' (rate >= 0.95), 'lost' (rate <= 0.05), 'contested'
        (0.4 <= rate <= 0.6), 'likely' (rate > 0.6), 'unlikely' (otherwise),
        or 'unknown' when the rate is missing.
    """
    # NaN fails every comparison below and would otherwise fall through to
    # 'unlikely'. `rate != rate` is true only for NaN.
    if rate is None or rate != rate:
        return 'unknown'
    if rate >= 0.95:
        return 'locked'
    elif rate <= 0.05:
        return 'lost'
    elif 0.4 <= rate <= 0.6:
        return 'contested'
    elif rate > 0.6:
        return 'likely'
    else:
        return 'unlikely'

lock_df['classification'] = lock_df['capture_rate'].apply(classify_lock)

# Summary by classification: rows are holding status, columns are categories
class_counts = lock_df.groupby(['p0_holds', 'classification']).size()
class_summary = class_counts.unstack(fill_value=0)
print("Lock classification by holding status:")
print(class_summary)
print()

# Percentages: each row normalised by its own total
row_totals = class_summary.sum(axis=1)
class_pct = class_summary.div(row_totals, axis=0) * 100
print("Percentages:")
print(class_pct.round(1))

## 6. V Distribution Analysis

In [None]:
# Analyze V variance across opponent configs
v_cols = ['V_opp0', 'V_opp1', 'V_opp2']

# A V column whose every entry is None (e.g. all shards for that opponent
# config were missing) can be left as object dtype, which the row-wise
# reductions below may not handle. Coerce to numeric so missing values are NaN.
v_values = results_df[v_cols].apply(pd.to_numeric, errors='coerce')

results_df['V_mean'] = v_values.mean(axis=1)
results_df['V_std'] = v_values.std(axis=1)
results_df['V_spread'] = v_values.max(axis=1) - v_values.min(axis=1)

print("V statistics across opponent configs:")
print(f"  Mean V spread: {results_df['V_spread'].mean():.1f} points")
print(f"  Max V spread: {results_df['V_spread'].max():.0f} points")
print(f"  Hands with spread > 40: {(results_df['V_spread'] > 40).sum()}")
print(f"  Hands with spread < 10: {(results_df['V_spread'] < 10).sum()}")

In [None]:
# Plot V spread distribution
fig, (ax_hist, ax_scatter) = plt.subplots(1, 2, figsize=(12, 5))

v_spread = results_df['V_spread']
mean_spread = v_spread.mean()

# Left panel: histogram of the spread with its mean marked
ax_hist.hist(v_spread, bins=20, color='steelblue', edgecolor='black', alpha=0.7)
ax_hist.axvline(x=mean_spread, color='red', linestyle='--',
                label=f"Mean: {mean_spread:.1f}")
ax_hist.set(
    xlabel='V Spread (max - min across 3 opponent configs)',
    ylabel='Count',
    title='How Much Does V Vary by Opponent Hands?',
)
ax_hist.legend()

# Right panel: per-hand mean against standard deviation
ax_scatter.scatter(results_df['V_mean'], results_df['V_std'], alpha=0.6, s=30)
ax_scatter.set(
    xlabel='Mean V (across opponent configs)',
    ylabel='Std Dev of V',
    title='V Mean vs Variance',
)
ax_scatter.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig('../../results/figures/11a_v_variance.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Summary

In [None]:
# Key findings: mean capture rate split by whether P0 holds the count
rate_by_holding = {
    flag: lock_df.loc[lock_df['p0_holds'] == flag, 'capture_rate'].mean()
    for flag in (True, False)
}
holds_rate = rate_by_holding[True]
not_holds_rate = rate_by_holding[False]

v_spread = results_df["V_spread"]

summary_stats = {}
summary_stats['Base seeds analyzed'] = len(results_df)
summary_stats['Lock observations'] = len(lock_df)
summary_stats['Mean capture rate (P0 holds)'] = f'{holds_rate:.1%}'
summary_stats['Mean capture rate (P0 not holds)'] = f'{not_holds_rate:.1%}'
summary_stats['Mean V spread'] = f'{v_spread.mean():.1f} points'
summary_stats['Max V spread'] = f'{v_spread.max():.0f} points'

print(viz.create_summary_table(summary_stats, "Count Lock Rate Analysis Summary"))

In [None]:
# Save results: one CSV per result frame under the shared tables directory
tables_dir = '../../results/tables'
csv_outputs = (
    (lock_df, '11a_count_lock_rates.csv'),
    (results_df, '11a_base_seed_analysis.csv'),
)
for frame, filename in csv_outputs:
    frame.to_csv(f'{tables_dir}/{filename}', index=False)

print("Results saved to results/tables/")