# Hierarchical Model Inspection

Interactive notebook for inspecting trained hierarchical model behavior:
- Model loading and parameter summary
- Identity analysis (archetype distribution, centroids)
- Prediction variance by player identity
- Flat vs hierarchical model comparison
- Identity ablation experiments

## Section 1: Setup & Model Loading

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from fantasy_baseball_manager.contextual.model.config import ModelConfig
from fantasy_baseball_manager.contextual.model.hierarchical_config import HierarchicalModelConfig
from fantasy_baseball_manager.contextual.persistence import ContextualModelStore
from fantasy_baseball_manager.contextual.identity.archetypes import load_archetype_model
from fantasy_baseball_manager.contextual.training.config import (
    BATTER_TARGET_STATS,
    PITCHER_TARGET_STATS,
)

# Use seaborn's "whitegrid" theme for every plot drawn later in the notebook.
sns.set_theme(style="whitegrid")
# IPython magic (not plain Python): render matplotlib figures inline in cell output.
%matplotlib inline

In [None]:
# Configuration — edit these to match your trained model
PERSPECTIVE = "pitcher"  # or "batter"
if PERSPECTIVE not in ("pitcher", "batter"):
    # Everything below branches on the perspective and silently treats any value
    # other than "pitcher" as "batter", so reject typos up front.
    raise ValueError(f"PERSPECTIVE must be 'pitcher' or 'batter', got {PERSPECTIVE!r}")
HIER_CHECKPOINT = f"hierarchical_{PERSPECTIVE}_best"
FLAT_CHECKPOINT = f"finetune_{PERSPECTIVE}_best"
ARCHETYPE_MODEL_NAME = f"{PERSPECTIVE}_archetypes"
PROFILE_YEAR = 2023

# Architecture (must match training)
D_MODEL = 256
N_LAYERS = 4
N_HEADS = 8
FF_DIM = 1024
N_ARCHETYPES = 8

TARGET_STATS = PITCHER_TARGET_STATS if PERSPECTIVE == "pitcher" else BATTER_TARGET_STATS
N_TARGETS = len(TARGET_STATS)
# Width of the per-player stat input passed to the hierarchical model and window builder.
# NOTE(review): presumably must equal the profile feature-vector length used in training — confirm.
STAT_INPUT_DIM = 13 if PERSPECTIVE == "pitcher" else 19

In [None]:
# Device selection: probe CUDA first, then Apple MPS, and fall back to the CPU.
# The probes are lambdas so each backend is only queried if the previous one is absent.
_backend_probes = (
    ("cuda", lambda: torch.cuda.is_available()),
    ("mps", lambda: torch.backends.mps.is_available()),
)
device = torch.device(
    next((backend for backend, probe in _backend_probes if probe()), "cpu")
)
print(f"Device: {device}")

In [None]:
# Load the hierarchical model
from fantasy_baseball_manager.registry.factory import create_model_registry

registry = create_model_registry()
model_store = registry.contextual_store

# Read the checkpoint on the CPU just to recover the sequence length it was built with:
# dimension 1 of the positional-encoding buffer is taken as max_seq_len.
# NOTE(review): relies on the store's private `_model_path` helper.
state = torch.load(
    model_store._model_path(HIER_CHECKPOINT),
    weights_only=True,
    map_location="cpu",
)
max_seq_len = state["backbone.positional_encoding.pe"].shape[1]
print(f"Detected max_seq_len={max_seq_len}")

# Backbone and hierarchy configs; every value must agree with the training run.
model_config = ModelConfig(
    max_seq_len=max_seq_len,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    ff_dim=FF_DIM,
)
hier_config = HierarchicalModelConfig(n_archetypes=N_ARCHETYPES, level3_d_model=D_MODEL)

hier_model = model_store.load_hierarchical_model(
    HIER_CHECKPOINT,
    model_config,
    hier_config,
    n_targets=N_TARGETS,
    stat_input_dim=STAT_INPUT_DIM,
).to(device)
hier_model.eval()  # inference mode for all submodules
print("Hierarchical model loaded.")

In [None]:
# Load the fitted archetype model used to label each player with an archetype ID.
arch_model = load_archetype_model(ARCHETYPE_MODEL_NAME)
print(
    "Archetype model: "
    f"{arch_model.n_archetypes} archetypes, "
    f"fitted={arch_model.is_fitted}"
)

In [None]:
# Parameter count summary
def count_params(model, requires_grad=None):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        requires_grad: when ``None`` every parameter is counted; otherwise only
            parameters whose ``requires_grad`` flag equals this value are counted.
    """
    return sum(
        param.numel()
        for param in model.parameters()
        if requires_grad is None or param.requires_grad == requires_grad
    )

total = count_params(hier_model)
frozen = count_params(hier_model, requires_grad=False)
trainable = count_params(hier_model, requires_grad=True)

# Left-align each label in a 22-character column, then the right-aligned value.
_line = "{:<22}{}".format
print(_line("Total parameters:", f"{total:>12,}"))
print(_line("Frozen (backbone):", f"{frozen:>12,}"))
print(_line("Trainable:", f"{trainable:>12,}"))
print(_line("Trainable fraction:", f"{trainable / total:.1%}"))

## Section 2: Identity Analysis

In [None]:
# Build profiles
from fantasy_baseball_manager.contextual.identity.stat_profile import (
    PlayerStatProfile,
    PlayerStatProfileBuilder,
)
from fantasy_baseball_manager.marcel.data_source import (
    create_batting_source,
    create_pitching_source,
)

profile_builder = PlayerStatProfileBuilder()
all_profiles = profile_builder.build_all_profiles(
    create_batting_source(),
    create_pitching_source(),
    PROFILE_YEAR,
    min_opportunities=50.0,
)
# Keep only the profiles for the perspective being inspected.
profiles = list(filter(lambda prof: prof.player_type == PERSPECTIVE, all_profiles))
print(f"{len(profiles)} {PERSPECTIVE} profiles")

In [None]:
# Assign every profile an archetype label from its feature vector.
X = np.array([prof.to_feature_vector() for prof in profiles])
labels = arch_model.predict(X)

# Histogram with one integer-aligned bin per archetype ID.
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(
    labels,
    bins=range(arch_model.n_archetypes + 1),
    align="left",
    edgecolor="black",
)
ax.set(
    xlabel="Archetype ID",
    ylabel="Count",
    title=f"Archetype Distribution ({PERSPECTIVE.title()}s)",
)
ax.set_xticks(range(arch_model.n_archetypes))
plt.tight_layout()
plt.show()

In [None]:
# Heatmap of archetype centroids: one column per archetype, one row per feature.
centroids = arch_model.centroids()
feature_names = PlayerStatProfile.feature_names(PERSPECTIVE)

fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(
    centroids.T,
    xticklabels=[f"A{k}" for k in range(centroids.shape[0])],
    yticklabels=feature_names,
    annot=True,
    fmt=".3f",
    cmap="RdBu_r",
    center=0,
    ax=ax,
)
ax.set_title("Archetype Centroids (Original Feature Space)")
ax.set_xlabel("Archetype")
plt.tight_layout()
plt.show()

In [None]:
# Print up to five example players for every archetype that has members.
for arch_id in range(arch_model.n_archetypes):
    members = [p for p, label in zip(profiles, labels) if label == arch_id]
    if not members:
        continue
    examples = members[:5]
    print(
        f"Archetype {arch_id} ({len(members)} players): "
        + ", ".join(ex.name for ex in examples)
    )

## Section 3: Prediction Variance by Identity

In [None]:
# Pick N diverse players (one from each archetype if possible)
from fantasy_baseball_manager.contextual.predictor import ContextualPredictor
from fantasy_baseball_manager.contextual.data.builder import GameSequenceBuilder
from fantasy_baseball_manager.statcast.store import StatcastStore
from fantasy_baseball_manager.statcast.models import DEFAULT_DATA_DIR

store = StatcastStore(data_dir=DEFAULT_DATA_DIR)
predictor = ContextualPredictor(model_store, store)

# For each archetype keep its member with the most career opportunities.
selected_players = []
for arch_id in range(arch_model.n_archetypes):
    members = [p for p, label in zip(profiles, labels) if label == arch_id]
    if not members:
        continue
    best = max(members, key=lambda cand: cand.opportunities_career)
    selected_players.append((best, arch_id))

print(f"Selected {len(selected_players)} players for analysis:")
for p, arch_id in selected_players:
    print(f"  {p.name} (ID: {p.player_id}, Archetype: {arch_id})")

In [None]:
# Run hierarchical inference for each selected player
from fantasy_baseball_manager.contextual.training.config import (
    DEFAULT_BATTER_CONTEXT_WINDOW,
    DEFAULT_PITCHER_CONTEXT_WINDOW,
)

# Context window taken from the training-config defaults for this perspective.
if PERSPECTIVE == "batter":
    context_window = DEFAULT_BATTER_CONTEXT_WINDOW
else:
    context_window = DEFAULT_PITCHER_CONTEXT_WINDOW

predictions = []
for profile, arch_id in selected_players:
    try:
        preds = predictor.predict_player_hierarchical(
            mlbam_id=int(profile.player_id),
            data_year=PROFILE_YEAR,
            perspective=PERSPECTIVE,
            model=hier_model,
            profile=profile,
            archetype_model=arch_model,
            context_window=context_window,
        )
        row = {"name": profile.name, "archetype": arch_id}
        row.update({stat: preds[idx].item() for idx, stat in enumerate(TARGET_STATS)})
        predictions.append(row)
    except Exception as e:
        # Best-effort: report and skip players that cannot be predicted.
        print(f"  Skipping {profile.name}: {e}")

pred_df = pd.DataFrame(predictions)
print("\nHierarchical model predictions:")
pred_df

In [None]:
# Diagnostic: how much do hierarchical predictions vary across the selected players?
# Each selected player comes from a different archetype, so a near-zero spread hints
# that identity has little effect on the predictions (e.g. high-K pitchers should get
# higher predicted K rates). The spread is meaningful for batters and pitchers alike,
# so it is printed for either perspective.
if len(pred_df) > 1:
    print("Prediction spread (std across players) per stat:")
    for stat in TARGET_STATS:
        std = pred_df[stat].std()
        print(f"  {stat}: std={std:.4f}")

## Section 4: Flat vs Hierarchical Comparison

In [None]:
# Load the flat fine-tuned model and move it to the selected device in one step.
# NOTE(review): reuses `model_config`, whose max_seq_len was read from the hierarchical
# checkpoint; this assumes both checkpoints share that sequence length — confirm.
flat_model = model_store.load_finetune_model(
    FLAT_CHECKPOINT, model_config, N_TARGETS,
).to(device)
flat_model.eval()
print("Flat fine-tuned model loaded.")

In [None]:
# Score the same players with the flat model (no profile / archetype inputs passed).
flat_predictions = []
for profile, arch_id in selected_players:
    try:
        preds = predictor.predict_player(
            mlbam_id=int(profile.player_id),
            data_year=PROFILE_YEAR,
            perspective=PERSPECTIVE,
            model=flat_model,
            context_window=context_window,
        )
        row = {"name": profile.name, "archetype": arch_id}
        row.update({stat: preds[idx].item() for idx, stat in enumerate(TARGET_STATS)})
        flat_predictions.append(row)
    except Exception as e:
        # Best-effort: report and skip players that cannot be predicted.
        print(f"  Skipping {profile.name}: {e}")

flat_df = pd.DataFrame(flat_predictions)
print("\nFlat model predictions:")
flat_df

In [None]:
# Side-by-side comparison, restricted to players scored by both models (inner join on name).
if len(pred_df) and len(flat_df):
    stat_cols = list(TARGET_STATS)
    hier_side = pred_df.set_index("name")[stat_cols].add_prefix("hier_")
    flat_side = flat_df.set_index("name")[stat_cols].add_prefix("flat_")
    comparison = hier_side.join(flat_side, how="inner")
    print("Side-by-side comparison:")
    display(comparison)

In [None]:
# Bar chart of each model's prediction spread (std across players) for every stat.
if min(len(pred_df), len(flat_df)) > 1:
    stat_cols = list(TARGET_STATS)
    hier_std = pred_df[stat_cols].std()
    flat_std = flat_df[stat_cols].std()
    spread_df = pd.DataFrame({"Hierarchical": hier_std, "Flat": flat_std})

    fig, ax = plt.subplots(figsize=(8, 5))
    spread_df.plot.bar(ax=ax)
    ax.set_title("Prediction Spread (std across players) per Stat")
    ax.set_ylabel("Standard Deviation")
    ax.set_xlabel("Target Stat")
    plt.xticks(rotation=0)
    plt.tight_layout()
    plt.show()

## Section 5: Identity Ablation

In [None]:
# Pick one player for ablation: the first player selected in Section 3.
# Every later ablation cell reads `ablation_profile`, so stop here with a clear message
# rather than failing further down with a NameError when nothing was selected.
if not selected_players:
    raise RuntimeError(
        "No players were selected for analysis; cannot run the identity ablation. "
        "Check that profiles were built and archetype labels were assigned."
    )
ablation_profile, ablation_arch = selected_players[0]
print(f"Ablation player: {ablation_profile.name} (Archetype {ablation_arch})")

In [None]:
# Build context for the ablation player
from fantasy_baseball_manager.contextual.data.vocab import (
    BB_TYPE_VOCAB,
    HANDEDNESS_VOCAB,
    PA_EVENT_VOCAB,
    PITCH_RESULT_VOCAB,
    PITCH_TYPE_VOCAB,
)
from fantasy_baseball_manager.contextual.model.tensorizer import Tensorizer
from fantasy_baseball_manager.contextual.training.dataset import build_player_contexts

seq_builder = GameSequenceBuilder(store)
tensorizer = Tensorizer(
    config=model_config,
    pitch_type_vocab=PITCH_TYPE_VOCAB,
    pitch_result_vocab=PITCH_RESULT_VOCAB,
    bb_type_vocab=BB_TYPE_VOCAB,
    handedness_vocab=HANDEDNESS_VOCAB,
    pa_event_vocab=PA_EVENT_VOCAB,
)

# Build contexts for the whole profile year / perspective, then lazily pick out the
# ablation player's (None when absent).
contexts = build_player_contexts(
    seq_builder, (PROFILE_YEAR,), (PERSPECTIVE,), min_pitch_count=10,
)
player_ctx = next(
    filter(lambda ctx: ctx.player_id == int(ablation_profile.player_id), contexts),
    None,
)
print(
    f"Found context for {ablation_profile.name}: {len(player_ctx.games)} games"
    if player_ctx
    else f"No context found for {ablation_profile.name}"
)

In [None]:
# Run ablation: real identity, zero identity, different archetype identity
from fantasy_baseball_manager.contextual.training.hierarchical_dataset import (
    build_hierarchical_windows,
    HierarchicalFineTuneDataset,
    collate_hierarchical_samples,
)
from fantasy_baseball_manager.contextual.training.config import HierarchicalFineTuneConfig

if player_ctx:
    ft_config = HierarchicalFineTuneConfig(
        perspective=PERSPECTIVE,
        context_window=context_window,
        min_games=context_window + 5,
    )

    # Windows are built only from this player's context, with a one-entry profile lookup.
    profile_lookup = {int(ablation_profile.player_id): ablation_profile}
    windows = build_hierarchical_windows(
        [player_ctx],
        tensorizer,
        ft_config,
        TARGET_STATS,
        profile_lookup,
        arch_model,
        STAT_INPUT_DIM,
    )

    if not windows:
        print("No windows could be built for this player.")
    else:
        # The first window is the ablation sample; unpack its five components.
        sample_window = windows[0]
        tensorized, targets, ctx_mean, identity_feat, arch_id = sample_window
        print(f"Built sample window (target: {targets.numpy()})")
        print(f"Real identity features shape: {identity_feat.shape}")
        print(f"Archetype ID: {arch_id}")

In [None]:
# Run model with three identity conditions
from fantasy_baseball_manager.contextual.training.hierarchical_dataset import (
    HierarchicalFineTuneSample,
)

def run_with_identity(model, tensorized, identity_feat, arch_id):
    """Run the hierarchical model on one context window under a chosen identity.

    The window is wrapped in a single-sample batch whose ``targets`` and
    ``context_mean`` are zero placeholders. The context tensors and identity
    inputs are moved to the notebook-level ``device``, and ``model`` runs
    without gradient tracking.

    Returns:
        ``output["performance_preds"]`` for the sample, moved to the CPU with
        dimension 0 squeezed away.

    NOTE(review): the placeholder ``context_mean`` is never passed to ``model``
    below. If the model's predictions are relative to the context mean, callers
    must add it back themselves — confirm.
    """
    from fantasy_baseball_manager.contextual.model.tensorizer import TensorizedBatch

    batch = collate_hierarchical_samples([
        HierarchicalFineTuneSample(
            context=tensorized,
            targets=torch.zeros(N_TARGETS),
            context_mean=torch.zeros(N_TARGETS),
            identity_features=identity_feat,
            archetype_id=arch_id,
        )
    ])

    # TensorizedBatch fields, each copied from the collated context onto `device`.
    field_names = (
        "pitch_type_ids",
        "pitch_result_ids",
        "bb_type_ids",
        "stand_ids",
        "p_throws_ids",
        "pa_event_ids",
        "numeric_features",
        "numeric_mask",
        "padding_mask",
        "player_token_mask",
        "game_ids",
        "seq_lengths",
    )
    src = batch.context
    dev_ctx = TensorizedBatch(
        **{name: getattr(src, name).to(device) for name in field_names}
    )

    with torch.no_grad():
        output = model(
            dev_ctx,
            batch.identity_features.to(device),
            batch.archetype_ids.to(device),
        )
    return output["performance_preds"].cpu().squeeze(0)

if player_ctx and windows:
    # 1. Real identity: the window's own identity features and archetype.
    preds_real = run_with_identity(hier_model, tensorized, identity_feat, arch_id)

    # 2. Zero identity (fallback): all-zero identity features paired with archetype 0.
    # NOTE(review): archetype 0 is an ordinary archetype label. If the player's own
    # archetype is 0, this condition only ablates the feature vector — confirm this is
    # the intended fallback.
    zero_feat = torch.zeros_like(identity_feat)
    preds_zero = run_with_identity(hier_model, tensorized, zero_feat, 0)

    # 3. Different archetype: the archetype whose centroid is farthest from the player's
    # own. Archetype IDs are arbitrary labels, so "most different" has to be measured
    # between centroids; arithmetic on the IDs says nothing about similarity.
    real_arch = int(arch_id)
    arch_centroids = np.asarray(arch_model.centroids(), dtype=float)
    # Centroids are plotted above as being in the original feature space. Dividing each
    # feature by its spread across the loaded profiles keeps large-magnitude stats from
    # dominating the distance.
    feature_scale = np.asarray(X, dtype=float).std(axis=0)
    feature_scale[feature_scale == 0] = 1.0
    centroid_dist = np.linalg.norm(
        (arch_centroids - arch_centroids[real_arch]) / feature_scale, axis=1
    )
    other_arch = int(np.argmax(centroid_dist))

    # Borrow identity features from a real player in that archetype. Skip this condition
    # when no distinct archetype exists (e.g. a single archetype or identical centroids).
    other_profiles = []
    if other_arch != real_arch:
        other_profiles = [p for p, label in zip(profiles, labels) if label == other_arch]
    if other_profiles:
        other_feat = torch.tensor(other_profiles[0].to_feature_vector(), dtype=torch.float32)
        preds_other = run_with_identity(hier_model, tensorized, other_feat, other_arch)
    else:
        preds_other = None

    # Display results
    ablation_rows = [
        {"condition": "Real identity", **{s: preds_real[i].item() for i, s in enumerate(TARGET_STATS)}},
        {"condition": "Zero identity", **{s: preds_zero[i].item() for i, s in enumerate(TARGET_STATS)}},
    ]
    if preds_other is not None:
        ablation_rows.append(
            {"condition": f"Archetype {other_arch} identity",
             **{s: preds_other[i].item() for i, s in enumerate(TARGET_STATS)}}
        )

    ablation_df = pd.DataFrame(ablation_rows).set_index("condition")
    print(f"\nIdentity Ablation for {ablation_profile.name}:")
    display(ablation_df)

    # Visualize: one group of bars per target stat, one bar per identity condition.
    fig, ax = plt.subplots(figsize=(8, 5))
    ablation_df.T.plot(kind="bar", ax=ax)
    ax.set_title(f"Identity Ablation: {ablation_profile.name}")
    ax.set_ylabel("Predicted Rate")
    ax.set_xlabel("Target Stat")
    plt.xticks(rotation=0)
    plt.tight_layout()
    plt.show()