# Phase 8: Model Interpretability Analysis

# Version 6

Understand what the multiclass rebracketing classifier learned.

**Sections:**
1. Setup
2. Load model (checkpoint) and data (HuggingFace)
3. Embedding space visualization (TSNE/UMAP)
4. Class predictions & confusion matrix
5. Misclassification analysis
6. Summary & cleanup

## 1. Setup

In [None]:
# Step 1: Install dependencies
# Installs the third-party packages this notebook imports later
# (transformers, datasets, huggingface_hub, wandb, umap-learn); -q keeps pip quiet.
# After this cell completes, do: Runtime -> Restart session
# (presumably so that any upgraded packages replace versions already loaded in the kernel)
# Then SKIP this cell and continue from the next one

!pip install -q transformers datasets huggingface_hub wandb umap-learn

# Prominent reminder banner for the manual restart step
print("\n" + "="*60)
print("RESTART RUNTIME NOW: Runtime -> Restart session")
print("Then SKIP this cell and run from the next cell")
print("="*60)

In [None]:
# Mount Google Drive at /content/drive so the checkpoint and the fallback
# parquet manifest referenced below are reachable from this runtime.
from google.colab import drive as colab_drive

colab_drive.mount('/content/drive')

In [None]:
import gc
import os

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# HuggingFace
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel

# Weights & Biases
import wandb

# Device: prefer CUDA when a GPU is visible, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report the device and core library versions for reproducibility.
for _label, _value in (
    ("Using device", device),
    ("NumPy version", np.__version__),
    ("PyTorch version", torch.__version__),
):
    print(f"{_label}: {_value}")

In [None]:
# Configuration: artifact locations and W&B project used throughout the notebook
CHECKPOINT_PATH = "/content/drive/MyDrive/Colab Notebooks/checkpoint_best.pt"
HF_DATASET = "earthlyframes/white-rebracketing"  # Your HF dataset
WANDB_PROJECT = "White"

# Toggle Weights & Biases logging (set to False to skip all W&B calls below)
USE_WANDB = True

if USE_WANDB:
    wandb.login()
    run_config = {"checkpoint": CHECKPOINT_PATH, "dataset": HF_DATASET}
    run = wandb.init(project=WANDB_PROJECT, job_type="interpretability", config=run_config)
    print(f"W&B run: {run.url}")

## 2. Load Model and Data

In [None]:
# Load the training checkpoint onto the active device and report its metadata.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

print(f"Checkpoint keys: {checkpoint.keys()}")
for _key, _title, _fallback in (
    ("target_type", "Target type", "unknown"),
    ("num_classes", "Num classes", "unknown"),
    ("class_mapping", "Class mapping", {}),
):
    print(f"{_title}: {checkpoint.get(_key, _fallback)}")

# Record label metadata on the W&B run (None when a key is absent).
if USE_WANDB:
    wandb.config.update({key: checkpoint.get(key) for key in ("num_classes", "class_mapping")})

In [None]:
# Extract config from checkpoint
config = checkpoint["config"]
class_mapping = checkpoint["class_mapping"]  # rebracketing type name -> integer label index
num_classes = checkpoint["num_classes"]

# Later cells look up class_names[i] with predicted label indices, so position i
# must hold the name whose index is i. Sort by index instead of relying on the
# dict's insertion order (list(class_mapping.keys()) is only correct by luck).
class_names = [name for name, _ in sorted(class_mapping.items(), key=lambda kv: kv[1])]
if sorted(class_mapping.values()) != list(range(len(class_mapping))):
    print(f"Warning: class_mapping indices are not 0..{len(class_mapping) - 1}; "
          "class_names positions may not line up with model outputs.")

print(f"Classes: {class_names}")

In [None]:
# Define model architecture (must match training exactly)
import torch.nn as nn

class TextEncoder(nn.Module):
    """Pretrained transformer that returns one pooled vector per input sequence.

    Args:
        model_name: model id/path passed to ``AutoModel.from_pretrained``.
        hidden_size: accepted for config compatibility but unused; the width is
            read from the loaded transformer's config instead.
        freeze_layers: when > 0, turns off gradients for the embedding module and
            the first ``freeze_layers`` transformer layers (if those attributes exist).
        pooling: "mean" for attention-mask-weighted mean pooling; "cls" (and any
            other value) takes the hidden state of the first token.
        **kwargs: ignored, so extra keys in a saved config do not break construction.
    """

    def __init__(self, model_name, hidden_size=None, freeze_layers=0, pooling="cls", **kwargs):
        super().__init__()
        self.pooling = pooling

        # Attribute must be called `encoder` so state-dict keys match the checkpoint
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.encoder.config.hidden_size

        if not (freeze_layers and freeze_layers > 0):
            return

        # Collect the modules to freeze, then switch their gradients off in one pass
        to_freeze = []
        if hasattr(self.encoder, "embeddings"):
            to_freeze.append(self.encoder.embeddings)
        stack = getattr(self.encoder, "encoder", None)
        if stack is not None and hasattr(stack, "layer"):
            to_freeze.extend(
                block for idx, block in enumerate(stack.layer) if idx < freeze_layers
            )
        for module in to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        if self.pooling != "mean":
            # "cls" -- and, as a fallback, any unrecognised strategy -- uses token 0
            return hidden[:, 0, :]

        # Masked mean: zero out padding positions, divide by the count of real tokens
        weights = attention_mask.unsqueeze(-1).expand_as(hidden).float()
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts


class MultiClassRebracketingClassifier(nn.Module):
    """MLP head that maps pooled text embeddings to rebracketing-type logits.

    For each width in ``hidden_dims`` (default ``[256, 128]``) it stacks
    Linear -> activation (GELU when ``activation == "gelu"``, otherwise ReLU)
    -> Dropout (only when ``dropout > 0``). A final Linear emits
    ``num_classes`` logits. The stack is stored as ``self.mlp`` so parameter
    names match the checkpoint. Extra ``**kwargs`` are ignored.
    """

    def __init__(self, input_dim, num_classes, hidden_dims=None, dropout=0.3, activation="relu", **kwargs):
        super().__init__()
        widths = [256, 128] if hidden_dims is None else hidden_dims
        act_cls = nn.GELU if activation == "gelu" else nn.ReLU

        blocks = []
        in_features = input_dim
        for out_features in widths:
            blocks += [nn.Linear(in_features, out_features), act_cls()]
            if dropout > 0:
                blocks.append(nn.Dropout(dropout))
            in_features = out_features
        blocks.append(nn.Linear(in_features, num_classes))

        self.mlp = nn.Sequential(*blocks)

    def forward(self, x):
        return self.mlp(x)


class MultiClassRainbowModel(nn.Module):
    """End-to-end model: ``text_encoder`` pools tokens, ``classifier`` scores the result.

    The submodule attribute names are part of the checkpoint's state-dict keys.
    """

    def __init__(self, text_encoder, classifier):
        super().__init__()
        self.text_encoder = text_encoder
        self.classifier = classifier

    def forward(self, input_ids, attention_mask):
        pooled = self.text_encoder(input_ids, attention_mask)
        logits = self.classifier(pooled)
        return logits


print("Model classes defined.")

In [None]:
# Build and load model (architecture rebuilt from the checkpoint's saved config)
text_config = config["model"]["text_encoder"]
clf_config = config["model"]["classifier"]

text_encoder = TextEncoder(
    model_name=text_config["model_name"],
    hidden_size=text_config["hidden_size"],
    freeze_layers=text_config["freeze_layers"],
    pooling=text_config["pooling"],
)

classifier = MultiClassRebracketingClassifier(
    # Width comes from the loaded transformer, not from the config value
    input_dim=text_encoder.hidden_size,
    num_classes=num_classes,
    hidden_dims=clf_config["hidden_dims"],
    dropout=clf_config["dropout"],
    activation=clf_config["activation"],
)

model = MultiClassRainbowModel(text_encoder=text_encoder, classifier=classifier)
# Default strict loading: any missing/unexpected key or shape mismatch raises here
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # disable dropout for inference

print("Model loaded successfully!")
print(f"Text encoder: {text_config['model_name']}")
# Report the width the model actually uses; TextEncoder ignores the config's hidden_size.
print(f"Hidden size: {text_encoder.hidden_size}")
if text_config["hidden_size"] != text_encoder.hidden_size:
    print(f"Warning: config hidden_size={text_config['hidden_size']} differs from "
          f"encoder hidden_size={text_encoder.hidden_size}")

In [None]:
# Load the tokenizer that belongs to the encoder checkpoint.
# use_fast=False selects the slow (Python) tokenizer implementation -- presumably
# to mirror how text was tokenized during training; confirm before changing.
_tokenizer_kwargs = {"use_fast": False, "add_prefix_space": False}
tokenizer = AutoTokenizer.from_pretrained(text_config["model_name"], **_tokenizer_kwargs)
print(f"Tokenizer loaded: {text_config['model_name']}")

In [None]:
# Load dataset from HuggingFace, falling back to the local parquet manifest
LOCAL_PARQUET = "/content/drive/MyDrive/Colab Notebooks/data/base_manifest_db.parquet"

try:
    dataset = load_dataset(HF_DATASET)
except Exception as e:  # deliberate best-effort: network/auth/missing dataset -> local copy
    print(f"Could not load from HuggingFace: {e}")
    print("Falling back to local parquet...")
    df = pd.read_parquet(LOCAL_PARQUET)
else:
    print(f"Loaded dataset from HuggingFace: {HF_DATASET}")
    print(dataset)

    # Convert to DataFrame for easier manipulation.
    # load_dataset() without split= returns a DatasetDict, which has no to_pandas();
    # prefer the "train" split and otherwise use the first split available.
    if hasattr(dataset, "to_pandas"):
        df = dataset.to_pandas()
    else:
        split_name = "train" if "train" in dataset else next(iter(dataset))
        df = dataset[split_name].to_pandas()

In [None]:
# Prepare data - extract rebracketing type
print(f"Dataset columns: {df.columns.tolist()}")
print(f"Total rows: {len(df)}")

# Handle different column formats: the label may live inside a dict-valued
# "training_data" column instead of a top-level "rebracketing_type" column.
if "rebracketing_type" not in df.columns and "training_data" in df.columns:
    df["rebracketing_type"] = df["training_data"].apply(
        lambda x: x.get("rebracketing_type") if isinstance(x, dict) else None
    )

# Fail with a clear message instead of a bare KeyError further down
missing_columns = [col for col in ("concept", "rebracketing_type") if col not in df.columns]
if missing_columns:
    raise KeyError(f"Dataset is missing required column(s): {missing_columns}")

# Keep rows that have text and a label the checkpoint knows about. reset_index
# returns a fresh frame, which avoids SettingWithCopyWarning on the column
# assignments made later. It also makes row positions match the
# embedding/prediction arrays.
keep = df["concept"].notna() & df["rebracketing_type"].isin(list(class_mapping))
df = df[keep].reset_index(drop=True)

print(f"\nFiltered to {len(df)} samples with known rebracketing types")
print(df["rebracketing_type"].value_counts())

## 3. Embedding Space Visualization

See how different rebracketing types cluster in the learned embedding space.

In [None]:
def get_embeddings(texts, batch_size=8):  # Reduced batch size to save RAM
    """Extract pooled text-encoder embeddings for a list of texts.

    Relies on the notebook globals ``tokenizer``, ``model``, ``text_config``
    and ``device``. Every text is padded/truncated to ``text_config["max_length"]``.

    Args:
        texts: list of strings to embed.
        batch_size: texts per forward pass; kept small to limit memory use.

    Returns:
        np.ndarray with one row per input text. For empty ``texts`` an empty
        ``(0, hidden_size)`` array is returned, because np.vstack raises on an
        empty list.
    """
    if len(texts) == 0:
        return np.empty((0, model.text_encoder.hidden_size), dtype=np.float32)

    all_embeddings = []

    for i in tqdm(range(0, len(texts), batch_size), desc="Extracting embeddings"):
        batch_texts = texts[i:i+batch_size]

        encoding = tokenizer(
            batch_texts,
            max_length=text_config["max_length"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        with torch.no_grad():
            emb = model.text_encoder(input_ids, attention_mask)
        all_embeddings.append(emb.cpu().numpy())

        # Free GPU memory each batch
        del input_ids, attention_mask, emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.vstack(all_embeddings)

In [None]:
# Extract embeddings (with caching to avoid RAM issues)
EMBEDDINGS_CACHE = "/content/embeddings_cache.npz"

texts = df["concept"].tolist()
labels = df["rebracketing_type"].tolist()
colors = df.get("rainbow_color", pd.Series(["unknown"] * len(df))).tolist()

embeddings = None
if os.path.exists(EMBEDDINGS_CACHE):
    print(f"Loading cached embeddings from {EMBEDDINGS_CACHE}")
    # The cache holds a plain float array, so pickle support is not needed;
    # the context manager closes the underlying .npz file handle.
    with np.load(EMBEDDINGS_CACHE) as cache:
        embeddings = cache["embeddings"]
    print(f"Loaded embeddings shape: {embeddings.shape}")
    # A cache built from a different filtering/dataset would silently misalign
    # with `labels`; recompute instead of using it.
    if len(embeddings) != len(texts):
        print(f"Cached embeddings ({len(embeddings)}) do not match current data ({len(texts)}); recomputing.")
        embeddings = None

if embeddings is None:
    print("Computing embeddings (this will be cached for future runs)...")
    embeddings = get_embeddings(texts)
    print(f"Embeddings shape: {embeddings.shape}")

    # Save to cache
    np.savez_compressed(EMBEDDINGS_CACHE, embeddings=embeddings)
    print(f"Saved to {EMBEDDINGS_CACHE}")

# Clear some memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# TSNE visualization
from sklearn.manifold import TSNE

print("Running TSNE...")
# perplexity has to stay below the number of samples, so cap it for tiny datasets
tsne = TSNE(
    n_components=2,
    random_state=42,
    perplexity=min(30, len(embeddings) - 1),
)
embeddings_2d = tsne.fit_transform(embeddings)

# Persist the 2-D projection so it can be reused without re-running TSNE
np.save("/content/tsne_2d.npy", embeddings_2d)

# One scatter series per rebracketing type, in class order
label_array = np.array(labels, dtype=object)
fig, ax = plt.subplots(figsize=(12, 8))
for rb_type in class_names:
    mask = label_array == rb_type
    count = mask.sum()
    if not count:
        continue
    xs, ys = embeddings_2d[mask].T
    ax.scatter(xs, ys, label=f"{rb_type} ({count})", alpha=0.7, s=50)

ax.set_xlabel("TSNE-1")
ax.set_ylabel("TSNE-2")
ax.set_title("Embedding Space by Rebracketing Type")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

if USE_WANDB:
    wandb.log({"tsne_rebracketing": wandb.Image(fig)})

plt.savefig("/content/embedding_tsne_rebracketing.png", dpi=150, bbox_inches="tight")
plt.show()

# Clean up
del tsne
gc.collect()

In [None]:
# UMAP visualization (skip if RAM is tight - TSNE is usually sufficient)
import umap

print("Running UMAP...")
umap_params = dict(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
reducer = umap.UMAP(**umap_params)
embeddings_umap = reducer.fit_transform(embeddings)

# One scatter series per rebracketing type, in class order
label_array = np.array(labels, dtype=object)
fig, ax = plt.subplots(figsize=(12, 8))
for rb_type in class_names:
    in_class = label_array == rb_type
    n_in_class = in_class.sum()
    if n_in_class == 0:
        continue
    ax.scatter(*embeddings_umap[in_class].T, label=f"{rb_type} ({n_in_class})", alpha=0.7, s=50)

ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
ax.set_title("Embedding Space by Rebracketing Type (UMAP)")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

if USE_WANDB:
    wandb.log({"umap_rebracketing": wandb.Image(fig)})

plt.savefig("/content/embedding_umap_rebracketing.png", dpi=150, bbox_inches="tight")
plt.show()

# Clean up
del reducer, embeddings_umap
gc.collect()

In [None]:
# Plot by chromatic color if available (reuses the TSNE projection)
if "rainbow_color" in df.columns:
    fig, ax = plt.subplots(figsize=(12, 8))

    # Colour names in the data -> matplotlib colours used for plotting
    color_map = {
        "BLACK": "black",
        "RED": "red",
        "ORANGE": "orange",
        "YELLOW": "gold",
        "GREEN": "green",
        "BLUE": "blue",
        "INDIGO": "indigo",
        "VIOLET": "violet",
    }

    for color_name, plot_color in color_map.items():
        mask = np.array([c == color_name for c in colors], dtype=bool)
        if mask.sum() > 0:
            ax.scatter(
                embeddings_2d[mask, 0],
                embeddings_2d[mask, 1],
                label=f"{color_name} ({mask.sum()})",
                color=plot_color,
                alpha=0.7,
                s=50,
            )

    # Values outside color_map (other names, different casing, missing values)
    # were previously dropped from the plot silently; show them as one grey group.
    other_mask = np.array([c not in color_map for c in colors], dtype=bool)
    if other_mask.any():
        ax.scatter(
            embeddings_2d[other_mask, 0],
            embeddings_2d[other_mask, 1],
            label=f"OTHER ({other_mask.sum()})",
            color="lightgray",
            alpha=0.7,
            s=50,
        )

    ax.set_xlabel("TSNE-1")
    ax.set_ylabel("TSNE-2")
    ax.set_title("Embedding Space by Chromatic Color")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()

    if USE_WANDB:
        wandb.log({"tsne_chromatic": wandb.Image(fig)})

    plt.savefig("/content/embedding_tsne_chromatic.png", dpi=150, bbox_inches="tight")
    plt.show()

## 4. Class Predictions Analysis

In [None]:
# Get predictions (with caching)
PREDICTIONS_CACHE = "/content/predictions_cache.npz"

def get_predictions(input_texts, batch_size=8):
    """Run the full model over ``input_texts`` and return class predictions.

    Relies on the notebook globals ``tokenizer``, ``model``, ``text_config``,
    ``device`` and ``num_classes``.

    Args:
        input_texts: list of strings to classify.
        batch_size: texts per forward pass; kept small to limit memory use.

    Returns:
        (predictions, probabilities): a 1-D array of argmax label indices and a
        2-D array of softmax probabilities with one row per text. Empty input
        yields shapes (0,) and (0, num_classes), so ``probabilities.max(axis=1)``
        still works downstream.
    """
    if len(input_texts) == 0:
        return np.empty((0,), dtype=np.int64), np.empty((0, num_classes), dtype=np.float32)

    all_preds = []
    all_probs = []

    for i in tqdm(range(0, len(input_texts), batch_size), desc="Getting predictions"):
        batch_texts = input_texts[i:i+batch_size]

        encoding = tokenizer(
            batch_texts,
            max_length=text_config["max_length"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(logits, dim=-1)
        all_preds.append(preds.cpu().numpy())
        all_probs.append(probs.cpu().numpy())

        # Free memory
        del input_ids, attention_mask, logits, probs, preds
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.concatenate(all_preds), np.concatenate(all_probs)

# Load from cache or compute. A cache whose length differs from the current data
# (e.g. built from a different filtering) is discarded. Delete the file manually
# after changing the checkpoint, because that cannot be detected here.
predictions = probabilities = None
if os.path.exists(PREDICTIONS_CACHE):
    print(f"Loading cached predictions from {PREDICTIONS_CACHE}")
    with np.load(PREDICTIONS_CACHE) as cache:  # closes the .npz file handle
        predictions = cache["predictions"]
        probabilities = cache["probabilities"]
    if len(predictions) != len(texts):
        print(f"Cached predictions ({len(predictions)}) do not match current data ({len(texts)}); recomputing.")
        predictions = probabilities = None
    else:
        print(f"Loaded {len(predictions)} predictions")

if predictions is None:
    print("Computing predictions (will be cached)...")
    predictions, probabilities = get_predictions(texts)
    np.savez_compressed(PREDICTIONS_CACHE, predictions=predictions, probabilities=probabilities)
    print(f"Saved to {PREDICTIONS_CACHE}")

gc.collect()

In [None]:
# Confusion matrix
from sklearn.metrics import confusion_matrix, classification_report

# Encode true labels with the checkpoint's name -> index mapping (same space as predictions)
label_indices = np.array([class_mapping[lbl] for lbl in labels])

print(f"Labels: {len(label_indices)}, Predictions: {len(predictions)}")
# Explicit check rather than `assert`, which is skipped under `python -O`
if len(label_indices) != len(predictions):
    raise ValueError("Mismatch between labels and predictions!")

# index -> name built from class_mapping itself, so tick labels are correct even
# when the mapping is not stored in index order.
index_to_name = {idx: name for name, idx in class_mapping.items()}

# Only use classes that appear in the data (as a true label or as a prediction)
unique_labels = np.unique(np.concatenate([label_indices, predictions]))
used_class_names = [index_to_name[i] for i in unique_labels]

cm = confusion_matrix(label_indices, predictions, labels=unique_labels)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=used_class_names,
    yticklabels=used_class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
plt.tight_layout()

if USE_WANDB:
    wandb.log({"confusion_matrix": wandb.Image(fig)})

plt.savefig("/content/confusion_matrix_analysis.png", dpi=150)
plt.show()

report_kwargs = dict(labels=unique_labels, target_names=used_class_names, zero_division=0)

print("\nClassification Report:")
print(classification_report(label_indices, predictions, **report_kwargs))

# Get report as dict for wandb
report = classification_report(label_indices, predictions, output_dict=True, **report_kwargs)

if USE_WANDB:
    wandb.log({
        "accuracy": report.get("accuracy", 0),
        "macro_f1": report.get("macro avg", {}).get("f1-score", 0),
        "weighted_f1": report.get("weighted avg", {}).get("f1-score", 0),
    })

In [None]:
# Prediction confidence distribution
max_probs = probabilities.max(axis=1)        # top-class probability per sample
correct = predictions == label_indices       # per-sample correctness mask (reused later)

fig, (ax_all, ax_split) = plt.subplots(1, 2, figsize=(14, 5))

# Left: overall confidence histogram with a 50% reference line
ax_all.hist(max_probs, bins=30, edgecolor='black', alpha=0.7)
ax_all.set(xlabel="Prediction Confidence", ylabel="Count", title="Confidence Distribution")
ax_all.axvline(x=0.5, color='red', linestyle='--', label='50%')
ax_all.legend()

# Right: confidence split by whether the prediction was right
for subset, name in ((correct, "Correct"), (~correct, "Incorrect")):
    ax_split.hist(max_probs[subset], bins=20, alpha=0.7, label=f"{name} ({subset.sum()})")
ax_split.set(xlabel="Prediction Confidence", ylabel="Count", title="Confidence by Correctness")
ax_split.legend()

plt.tight_layout()

if USE_WANDB:
    wandb.log({"confidence_distribution": wandb.Image(fig)})

plt.savefig("/content/confidence_distribution.png", dpi=150)
plt.show()

## 5. Misclassification Analysis

In [None]:
# Find misclassified examples
# index -> name built from class_mapping so names match the predicted indices
# regardless of the mapping's stored order.
index_to_name = {idx: name for name, idx in class_mapping.items()}
df["predicted"] = [index_to_name[p] for p in predictions]
df["confidence"] = max_probs
df["correct"] = df["predicted"] == df["rebracketing_type"]

misclassified = df[~df["correct"]]
n_total = len(df)
# Guard the rate against an empty frame (avoids ZeroDivisionError)
miss_rate = len(misclassified) / n_total if n_total else 0.0

print(f"Misclassified: {len(misclassified)} / {n_total} ({100*miss_rate:.1f}%)")
print("\nMost common confusions:")
confusion_counts = misclassified.groupby(["rebracketing_type", "predicted"]).size().sort_values(ascending=False)
print(confusion_counts.head(10))

if USE_WANDB:
    wandb.log({
        "misclassification_rate": miss_rate,
        "num_misclassified": len(misclassified),
    })

In [None]:
# Show some misclassified examples (first 5), truncating long concepts
PREVIEW_CHARS = 300
print("\nExample misclassifications:")
for _, row in misclassified.head(5).iterrows():
    concept = str(row['concept'])
    # Only add an ellipsis when text was actually cut off
    preview = concept if len(concept) <= PREVIEW_CHARS else concept[:PREVIEW_CHARS] + "..."
    print(f"\n{'='*60}")
    print(f"True: {row['rebracketing_type']} | Predicted: {row['predicted']} | Conf: {row['confidence']:.2f}")
    print(f"Concept: {preview}")

In [None]:
# Log up to 50 misclassified examples to W&B as a table (concepts clipped to 500 chars)
if USE_WANDB:
    table_rows = [
        [str(rec["concept"])[:500], rec["rebracketing_type"], rec["predicted"], rec["confidence"]]
        for _, rec in misclassified.head(50).iterrows()
    ]
    misclassified_table = wandb.Table(
        columns=["concept", "true_type", "predicted", "confidence"],
        data=table_rows,
    )
    wandb.log({"misclassified_examples": misclassified_table})

## 6. Summary & Cleanup

In [None]:
# Summary statistics
def _mean_or_zero(values):
    """Return ``values.mean()``, or 0 for an empty array (avoids NaN + RuntimeWarning)."""
    return values.mean() if values.size else 0


accuracy = _mean_or_zero(correct)
avg_confidence = _mean_or_zero(max_probs)
# Guard both subsets: a perfect (or entirely wrong) model leaves one of them empty
avg_confidence_correct = _mean_or_zero(max_probs[correct])
avg_confidence_incorrect = _mean_or_zero(max_probs[~correct])

print("="*60)
print("INTERPRETABILITY ANALYSIS SUMMARY")
print("="*60)
print(f"Total samples: {len(df)}")
print(f"Classes: {class_names}")
print(f"Accuracy: {accuracy:.2%}")
print(f"Average confidence: {avg_confidence:.2%}")
print(f"Avg confidence (correct): {avg_confidence_correct:.2%}")
print(f"Avg confidence (incorrect): {avg_confidence_incorrect:.2%}")
print("="*60)

if USE_WANDB:
    wandb.log({
        "final_accuracy": accuracy,
        "avg_confidence": avg_confidence,
        "avg_confidence_correct": avg_confidence_correct,
        "avg_confidence_incorrect": avg_confidence_incorrect,
    })

In [None]:
# Close the W&B run (marks it finished and uploads any remaining data)
if USE_WANDB:
    wandb.finish()
    print("W&B run finished!")