# Phase 8: Model Interpretability Analysis

Understand what the multiclass rebracketing classifier learned.

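**Sections:**
1. Load trained model
2. Embedding space visualization (t-SNE/UMAP), colored by rebracketing type and by chromatic (rainbow) color
3. Attention visualization
4. Feature attribution (which words matter)
5. Class predictions analysis (confusion matrix, prediction confidence)
6. Misclassification analysis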

## Setup

In [None]:
# Install interpretability dependencies
# !pip install captum umap-learn plotly

In [None]:
import sys
sys.path.insert(0, ".")

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Pick the compute device: Apple-silicon MPS when present, otherwise CPU
# (CPU is fast enough for this analysis).
_use_mps = torch.backends.mps.is_available()
device = torch.device("mps") if _use_mps else torch.device("cpu")
print(f"Using device: {device}")

## 1. Load Trained Model

In [None]:
# Where the trained weights and the evaluation manifest live
CHECKPOINT_PATH = "output/checkpoint_best.pt"
DATA_PATH = "data/base_manifest_db.parquet"

# Deserialize the checkpoint, remapping saved tensors onto `device`.
# NOTE(review): torch.load is pickle-based; only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Quick summary of the checkpoint metadata (placeholders shown when a key is missing)
_ckpt_get = checkpoint.get
print(f"Checkpoint keys: {checkpoint.keys()}")
print(f"Target type: {_ckpt_get('target_type', 'unknown')}")
print(f"Num classes: {_ckpt_get('num_classes', 'unknown')}")
print(f"Class mapping: {_ckpt_get('class_mapping', {})}")

In [None]:
# Rebuild model from checkpoint config
import warnings

from models.text_encoder import TextEncoder
from models.multiclass_classifier import MultiClassRebracketingClassifier, MultiClassRainbowModel

config = checkpoint["config"]
# class_mapping maps class name -> integer label index; model logits are
# indexed by those integers.
class_mapping = checkpoint["class_mapping"]
num_classes = checkpoint["num_classes"]

# Order class names by their label index so that class_names[i] names logit i.
# Relying on dict insertion order instead silently mislabels predictions, the
# confusion-matrix axes and the classification report whenever the mapping was
# not stored in index order.
class_names = sorted(class_mapping, key=class_mapping.get)
if [class_mapping[name] for name in class_names] != list(range(num_classes)):
    warnings.warn(
        f"class_mapping indices {sorted(class_mapping.values())} are not "
        f"0..{num_classes - 1}; class_names may not line up with model outputs."
    )

# Text encoder, rebuilt with the hyper-parameters saved alongside the weights
text_config = config["model"]["text_encoder"]
text_encoder = TextEncoder(
    model_name=text_config["model_name"],
    hidden_size=text_config["hidden_size"],
    freeze_layers=text_config["freeze_layers"],
    pooling=text_config["pooling"],
)

# Classification head on top of the encoder output
clf_config = config["model"]["classifier"]
classifier = MultiClassRebracketingClassifier(
    input_dim=text_encoder.hidden_size,
    num_classes=num_classes,
    hidden_dims=clf_config["hidden_dims"],
    dropout=clf_config["dropout"],
    activation=clf_config["activation"],
)

# Combined model: restore trained weights and switch to inference mode
model = MultiClassRainbowModel(text_encoder=text_encoder, classifier=classifier)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

print("\nModel loaded successfully!")
print(f"Classes: {class_names}")

In [None]:
# Tokenizer matching the encoder checkpoint (slow/Python tokenizer, no prefix space)
from transformers import AutoTokenizer

_tok_kwargs = dict(use_fast=False, add_prefix_space=False)
tokenizer = AutoTokenizer.from_pretrained(text_config["model_name"], **_tok_kwargs)
print(f"Tokenizer loaded: {text_config['model_name']}")

In [None]:
# Load data
df = pd.read_parquet(DATA_PATH)

# Keep rows that have concept text. `.copy()` yields an independent frame, so the
# column assignment below does not write into a slice of the original frame
# (pandas SettingWithCopyWarning / chained assignment).
df = df[df["concept"].notna()].copy()

# Pull the label out of the per-row training_data dict; non-dict rows get None.
df["rebracketing_type"] = df["training_data"].apply(
    lambda x: x.get("rebracketing_type") if isinstance(x, dict) else None
)

# Keep only labels the checkpoint knows. Copy again so later cells can add
# columns (predicted / confidence / correct) without chained-assignment issues.
df = df[df["rebracketing_type"].isin(class_mapping.keys())].copy()

print(f"Loaded {len(df)} samples with known rebracketing types")
print(df["rebracketing_type"].value_counts())

## 2. Embedding Space Visualization

See how different rebracketing types cluster in the learned embedding space.

In [None]:
def get_embeddings(texts, batch_size=16):
    """Encode ``texts`` with the trained text encoder and stack the results.

    Texts are tokenized in chunks of ``batch_size`` (padded/truncated to
    ``text_config["max_length"]``) and passed through ``model.text_encoder``
    only, i.e. the representation before the classifier head.

    Args:
        texts: Sliceable sequence of input strings.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        NumPy array obtained by vertically stacking the per-batch encoder outputs.
    """
    chunks = []
    starts = range(0, len(texts), batch_size)
    for start in tqdm(starts, desc="Extracting embeddings"):
        batch = texts[start:start + batch_size]
        enc = tokenizer(
            batch,
            max_length=text_config["max_length"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        ids = enc["input_ids"].to(device)
        mask = enc["attention_mask"].to(device)

        # Inference only: no autograd graph needed
        with torch.no_grad():
            pooled = model.text_encoder(ids, mask)
        chunks.append(pooled.cpu().numpy())

    return np.vstack(chunks)

In [None]:
# Inputs for the embedding analysis: concept text, gold label and chromatic tag
texts = df["concept"].tolist()
labels = df["rebracketing_type"].tolist()
if "rainbow_color" in df.columns:
    colors = df["rainbow_color"].tolist()
else:
    colors = ["unknown"] * len(df)

embeddings = get_embeddings(texts)
print(f"Embeddings shape: {embeddings.shape}")

In [None]:
# t-SNE projection of the encoder embeddings, one colour per rebracketing type
from sklearn.manifold import TSNE

print("Running TSNE...")
# perplexity must stay below the number of samples, so cap it for tiny datasets
tsne = TSNE(
    n_components=2,
    random_state=42,
    perplexity=min(30, len(embeddings) - 1),
)
embeddings_2d = tsne.fit_transform(embeddings)

fig, ax = plt.subplots(figsize=(12, 8))

for rb_type in class_names:
    sel = np.array([lbl == rb_type for lbl in labels], dtype=bool)
    n_sel = int(sel.sum())
    if n_sel == 0:
        continue  # class absent from the loaded data
    ax.scatter(
        embeddings_2d[sel, 0],
        embeddings_2d[sel, 1],
        label=f"{rb_type} ({n_sel})",
        alpha=0.7,
        s=50,
    )

ax.set_xlabel("TSNE-1")
ax.set_ylabel("TSNE-2")
ax.set_title("Embedding Space by Rebracketing Type")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.savefig("output/embedding_tsne_rebracketing.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# UMAP visualization (an alternative 2-D projection that often preserves more
# global structure than t-SNE).
# Only the import is guarded: errors raised while fitting or plotting should
# surface as-is instead of being misreported as "UMAP not installed".
try:
    import umap
except ImportError:
    umap = None
    print("UMAP not installed. Run: pip install umap-learn")

if umap is not None:
    print("Running UMAP...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    embeddings_umap = reducer.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(12, 8))

    for rb_type in class_names:
        mask = [label == rb_type for label in labels]
        if sum(mask) > 0:
            ax.scatter(
                embeddings_umap[mask, 0],
                embeddings_umap[mask, 1],
                label=f"{rb_type} ({sum(mask)})",
                alpha=0.7,
                s=50,
            )

    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    ax.set_title("Embedding Space by Rebracketing Type (UMAP)")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig("output/embedding_umap_rebracketing.png", dpi=150, bbox_inches="tight")
    plt.show()

In [None]:
# Same t-SNE layout, coloured by each sample's chromatic (rainbow) label
if "rainbow_color" in df.columns:
    fig, ax = plt.subplots(figsize=(12, 8))

    # manifest colour name -> matplotlib colour; values not listed are not plotted
    color_map = dict(
        BLACK="black",
        RED="red",
        ORANGE="orange",
        YELLOW="gold",
        GREEN="green",
        BLUE="blue",
        INDIGO="indigo",
        VIOLET="violet",
    )

    for color_name, plot_color in color_map.items():
        sel = np.array([c == color_name for c in colors], dtype=bool)
        n_sel = int(sel.sum())
        if n_sel == 0:
            continue
        ax.scatter(
            embeddings_2d[sel, 0],
            embeddings_2d[sel, 1],
            label=f"{color_name} ({n_sel})",
            color=plot_color,
            alpha=0.7,
            s=50,
        )

    ax.set_xlabel("TSNE-1")
    ax.set_ylabel("TSNE-2")
    ax.set_title("Embedding Space by Chromatic Color")
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.savefig("output/embedding_tsne_chromatic.png", dpi=150, bbox_inches="tight")
    plt.show()

## 3. Attention Visualization

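See which tokens the model attends to. The heatmap shows the last-layer attention from the first token ([CLS]/`<s>`-style), averaged over heads. Attention weights are a rough diagnostic, not a faithful explanation of the prediction; see the attribution section for that.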

In [None]:
def get_attention_weights(text):
    """Return per-token last-layer attention for ``text`` and the predicted class.

    The text is tokenized like the other cells (padded/truncated to
    ``text_config["max_length"]``). It is run once through the underlying
    transformer with ``output_attentions=True`` and once through the full
    classifier to get the prediction.

    Args:
        text: A single input string.

    Returns:
        tokens: Token strings for the non-padding positions.
        cls_attention: 1-D NumPy array, same length as ``tokens``. It is the
            last-layer attention row of the first position, averaged over heads.
            For most encoders that position is the [CLS]/<s> token.
        pred_class: Name of the argmax class.

    Raises:
        RuntimeError: If the transformer returned no attention tensors.
    """
    encoding = tokenizer(
        text,
        max_length=text_config["max_length"],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        # Attention maps come from the raw transformer inside the text encoder
        outputs = model.text_encoder.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        # The prediction comes from the full encoder + classifier stack
        logits = model(input_ids, attention_mask)

    pred_idx = torch.argmax(logits, dim=-1).item()
    pred_class = class_names[pred_idx]

    # Tuple with one tensor per layer, each (batch, num_heads, seq_len, seq_len)
    # per the Hugging Face output convention.
    attentions = outputs.attentions
    if attentions is None:
        raise RuntimeError(
            "Transformer returned no attentions; its attention implementation may "
            "not support output_attentions (try attn_implementation='eager')."
        )

    # Only the last layer is used, averaged over heads -> (seq_len, seq_len)
    last_layer_attn = attentions[-1][0].mean(dim=0)

    # Row 0 = attention from the first token to every position
    cls_attention = last_layer_attn[0].cpu().numpy()

    # Drop padding via the attention mask itself instead of assuming all padding
    # sits at the end of the sequence.
    keep = attention_mask[0].bool().cpu().numpy()
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    tokens = [tok for tok, is_real in zip(all_tokens, keep) if is_real]

    return tokens, cls_attention[keep], pred_class

In [None]:
def visualize_attention(text, figsize=(16, 4)):
    """Draw a one-row heatmap of per-token attention for ``text``.

    Attention values are min-max scaled to [0, 1]. A small epsilon guards the
    division when every value is identical. The predicted class appears in the
    title.

    Returns:
        The matplotlib Figure.
    """
    tokens, attn, pred_class = get_attention_weights(text)

    lo, hi = attn.min(), attn.max()
    attn = (attn - lo) / (hi - lo + 1e-8)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow([attn], cmap="YlOrRd", aspect="auto")

    positions = range(len(tokens))
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=8)
    ax.set_yticks([])
    ax.set_title(f"Attention Weights (Predicted: {pred_class})")

    plt.colorbar(im, ax=ax, label="Attention")
    plt.tight_layout()
    return fig

In [None]:
# Attention heatmaps for a small, reproducible random sample of concepts
n_examples = min(5, len(df))
sample_indices = df.sample(n_examples, random_state=42).index

banner = "=" * 60
for idx in sample_indices:
    row = df.loc[idx]
    true_type = row["rebracketing_type"]
    concept = row["concept"][:500]  # shorten long texts for the plot

    print(f"\n{banner}")
    print(f"True type: {true_type}")
    print(f"Concept: {concept[:200]}...")

    fig = visualize_attention(concept)
    plt.show()

## 4. Feature Attribution

Which words drive the model's predictions?

In [None]:
# Integrated-gradients attribution is optional: it requires the `captum` package
try:
    from captum.attr import LayerIntegratedGradients
except ImportError:
    print("Captum not installed. Run: pip install captum")
    CAPTUM_AVAILABLE = False
else:
    CAPTUM_AVAILABLE = True

In [None]:
if CAPTUM_AVAILABLE:
    def forward_for_attribution(input_ids, attention_mask):
        """Forward pass returning class logits; this is the function Captum differentiates."""
        return model(input_ids, attention_mask)

    def get_attribution(text, target_class=None):
        """Compute layer integrated-gradients attribution per token for ``text``.

        Attribution is taken at ``model.text_encoder.model.embeddings``. The
        baseline is a sequence made entirely of the tokenizer's pad token.

        Args:
            text: A single input string.
            target_class: Class index to explain. Defaults to the model's argmax
                prediction.

        Returns:
            tokens: Token strings for the non-padding positions.
            attributions: 1-D NumPy array of per-token scores (summed over the
                embedding dimension), aligned with ``tokens``.
            class_name: Name of the explained class.

        Raises:
            ValueError: If the tokenizer has no pad token to build the baseline.
        """
        encoding = tokenizer(
            text,
            max_length=text_config["max_length"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        # Explain the predicted class when no target is given
        if target_class is None:
            with torch.no_grad():
                logits = model(input_ids, attention_mask)
            target_class = torch.argmax(logits, dim=-1).item()

        # Baseline: every position replaced by the pad token
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError(
                "Tokenizer has no pad_token_id; cannot build an all-padding baseline."
            )
        baseline_ids = torch.full_like(input_ids, pad_id)

        # Attribute at the output of the transformer's embedding layer
        embeddings_layer = model.text_encoder.model.embeddings
        lig = LayerIntegratedGradients(forward_for_attribution, embeddings_layer)

        attributions = lig.attribute(
            inputs=input_ids,
            baselines=baseline_ids,
            additional_forward_args=(attention_mask,),
            target=target_class,
            n_steps=50,
        )

        # Collapse the embedding dimension to one score per token
        per_token = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()

        # Drop padding using the attention mask rather than assuming right padding
        keep = attention_mask[0].bool().cpu().numpy()
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        tokens = [tok for tok, is_real in zip(all_tokens, keep) if is_real]

        return tokens, per_token[keep], class_names[target_class]

In [None]:
if CAPTUM_AVAILABLE:
    def visualize_attribution(text, figsize=(16, 4)):
        """Bar-plot integrated-gradients attribution per token for ``text``.

        Scores are divided by their largest absolute value, when that value is
        non-zero, so bars lie in [-1, 1]. Positive bars are green and negative
        bars are red.

        Returns:
            The matplotlib Figure.
        """
        tokens, scores, pred_class = get_attribution(text)

        peak = np.abs(scores).max()
        if peak > 0:
            scores = scores / peak

        fig, ax = plt.subplots(figsize=figsize)

        bar_colors = ["red" if s < 0 else "green" for s in scores]
        positions = range(len(tokens))
        ax.bar(positions, scores, color=bar_colors, alpha=0.7)

        ax.set_xticks(positions)
        ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=8)
        ax.set_ylabel("Attribution")
        ax.set_title(f"Word Importance (Predicted: {pred_class})")
        ax.axhline(y=0, color="black", linestyle="-", linewidth=0.5)

        plt.tight_layout()
        return fig

    # Attribution plots for the first few sampled examples
    for idx in sample_indices[:3]:
        row = df.loc[idx]
        true_type = row["rebracketing_type"]
        # Shorter text keeps the bar plot readable (inputs are still padded to
        # max_length, so this does not reduce attribution cost)
        concept = row["concept"][:300]

        print("\n" + "=" * 60)
        print(f"True type: {true_type}")
        print(f"Concept: {concept[:150]}...")

        fig = visualize_attribution(concept)
        plt.show()

## 5. Class Predictions Analysis

In [None]:
# Get predictions for all samples
def get_predictions(texts, batch_size=16):
    """Run the full encoder + classifier over ``texts`` in batches.

    Args:
        texts: Sliceable sequence of input strings.
        batch_size: Number of texts per forward pass.

    Returns:
        (preds, probs): NumPy arrays of argmax class indices and softmax
        probabilities, one entry/row per input text.
    """
    pred_chunks, prob_chunks = [], []

    for start in tqdm(range(0, len(texts), batch_size), desc="Getting predictions"):
        enc = tokenizer(
            texts[start:start + batch_size],
            max_length=text_config["max_length"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        ids = enc["input_ids"].to(device)
        mask = enc["attention_mask"].to(device)

        with torch.no_grad():
            logits = model(ids, mask)

        pred_chunks.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        prob_chunks.extend(torch.softmax(logits, dim=-1).cpu().numpy())

    return np.array(pred_chunks), np.array(prob_chunks)

predictions, probabilities = get_predictions(texts)

In [None]:
# Confusion matrix and per-class metrics over all loaded samples
from sklearn.metrics import confusion_matrix, classification_report

# Gold labels as integer indices (class_mapping: name -> index)
label_indices = [class_mapping[lbl] for lbl in labels]

# Fix the label set explicitly. Classes with no true and no predicted samples
# still get a row/column, and target_names always lines up with the labels.
# Without `labels=`, classification_report raises ValueError when a class is absent.
all_label_ids = list(range(num_classes))

cm = confusion_matrix(label_indices, predictions, labels=all_label_ids)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
plt.tight_layout()
plt.savefig("output/confusion_matrix_analysis.png", dpi=150)
plt.show()

print("\nClassification Report:")
print(
    classification_report(
        label_indices,
        predictions,
        labels=all_label_ids,
        target_names=class_names,
        zero_division=0,
    )
)

In [None]:
# Prediction confidence: overall distribution, and split by correct/incorrect
max_probs = np.max(probabilities, axis=1)
correct = predictions == np.array(label_indices)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_all, ax_split = axes

# Left: every prediction's top-class probability
ax_all.hist(max_probs, bins=30, edgecolor="black", alpha=0.7)
ax_all.set_xlabel("Prediction Confidence")
ax_all.set_ylabel("Count")
ax_all.set_title("Confidence Distribution")
ax_all.axvline(x=0.5, color="red", linestyle="--", label="50%")
ax_all.legend()

# Right: does confidence separate right from wrong predictions?
wrong = ~correct
ax_split.hist(max_probs[correct], bins=20, alpha=0.7, label=f"Correct ({correct.sum()})")
ax_split.hist(max_probs[wrong], bins=20, alpha=0.7, label=f"Incorrect ({wrong.sum()})")
ax_split.set_xlabel("Prediction Confidence")
ax_split.set_ylabel("Count")
ax_split.set_title("Confidence by Correctness")
ax_split.legend()

plt.tight_layout()
plt.savefig("output/confidence_distribution.png", dpi=150)
plt.show()

## 6. Misclassification Analysis

In [None]:
# Attach predictions to the frame and summarise the errors
pred_names = [class_names[p] for p in predictions]
df["predicted"] = pred_names
df["confidence"] = max_probs
df["correct"] = df["rebracketing_type"].eq(df["predicted"])

misclassified = df.loc[~df["correct"]]

n_wrong, n_total = len(misclassified), len(df)
print(f"Misclassified: {n_wrong} / {n_total} ({100*n_wrong/n_total:.1f}%)")

print("\nMost common confusions:")
confusion_counts = misclassified.groupby(["rebracketing_type", "predicted"]).size()
print(confusion_counts.sort_values(ascending=False).head(10))

In [None]:
# Print a handful of misclassified concepts for manual inspection
print("\nExample misclassifications:")
for record in misclassified.head(5).to_dict("records"):
    true_label = record["rebracketing_type"]
    pred_label = record["predicted"]
    conf = record["confidence"]
    print("\n" + "=" * 60)
    print(f"True: {true_label} | Predicted: {pred_label} | Conf: {conf:.2f}")
    print(f"Concept: {record['concept'][:300]}...")

## Summary

Questions this analysis answers (record the key findings for each):

1. **Embedding Space**: How well do rebracketing types cluster?
2. **Attention**: What words does the model focus on?
3. **Attribution**: Which words drive predictions?
4. **Confusions**: Where does the model struggle?

Save this notebook's outputs for documentation.