# AIPI 590 - XAI | Assignment #9
### Hongxuan Li

[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yQK4vAmebF9_5Y0ArmPrPv0HI0bviuqx?usp=sharing)

# References

- TransformerLens: https://github.com/TransformerLensOrg/TransformerLens

# Dependencies

In [1]:
import torch
import numpy as np
from transformer_lens import HookedTransformer, utils
import transformer_lens.patching as patching
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gpt2-small")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded pretrained model gpt2-small into HookedTransformer


# In-Context Learning Analysis with TransformerLens & Activation Patching

## Motivation
Activation patching helps us understand how language models perform in-context learning by revealing which components (attention heads, layers, positions) are crucial for processing demonstration examples and applying that knowledge to new cases. By comparing model behavior between clean and corrupted demonstrations, we can trace how information flows from examples to predictions.


# Prompt Generation

This section creates prompts for sentiment analysis `Review: [text] Sentiment: [label]` :
- Clean prompts: Well-formed sentiment analysis in-context learning examples
- Corrupted prompts: Modified versions with swapped labels


In [2]:
# Create demonstration examples
demo_pairs = [
    ("This restaurant was amazing! The food was delicious and service was top-notch.", "positive"),
    ("Terrible experience. Rude staff and cold food. Would not recommend.", "negative"),
    ("Pretty average place. Food was okay but nothing special.", "neutral"),
    ("Great atmosphere but the prices are way too high.", "mixed"),
    ("The best movie I've seen all year! A masterpiece!", "positive"),
    ("Complete waste of time and money. The plot made no sense.", "negative")
]

demo_str = '\n'.join([f"Review: {x} Sentiment: {y}" for (x, y) in demo_pairs])

# Test reviews
test_reviews = [
    "Exceeded all my expectations. Outstanding in every way!",
    "Disappointing product. Broke after two uses.",
    "Decent value for money but shipping took forever.",
    "Absolutely fantastic service and quality!"
]

clean_prompts = [f"{demo_str}\nReview: {review} Sentiment:" for review in test_reviews]

# Create corrupted prompts by swapping sentiment labels
sentiment_map = {
    'positive': 'negative',
    'negative': 'positive',
    'neutral': 'mixed',
    'mixed': 'neutral'
}

corrupted_prompts = []
for prompt in clean_prompts:
    lines = prompt.split('\n')
    demo_lines = lines[:-1]
    for i in range(len(demo_lines)):
        for orig, new in sentiment_map.items():
            if f"Sentiment: {orig}" in demo_lines[i]:
                demo_lines[i] = demo_lines[i].replace(
                    f"Sentiment: {orig}",
                    f"Sentiment: {new}"
                )
    corrupted_prompts.append('\n'.join(demo_lines + [lines[-1]]))

expected_sentiments = ['positive', 'negative', 'mixed', 'positive']

# Activation Patching

1. Convert prompts to tokens and run the model to get baseline performance
2. Calculate metrics for clean and corrupted examples
3. Perform patching on residual stream and attention matrices
4. Normalize scores to show relative impact of each component

The section reveals how different parts of the model contribute to in-context learning of sentiment analysis.

In [3]:
# Convert text prompts to token
clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)

# Run model on both clean and corrupted inputs
# clean_logits shape: [batch_size, sequence_length, vocab_size]
# clean_cache contains intermediate activations for every layer
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Map sentiment labels to token ID
sentiment_tokens = {
    'positive': model.to_single_token(" positive"),
    'negative': model.to_single_token(" negative"),
    'neutral': model.to_single_token(" neutral"),
    'mixed': model.to_single_token(" mixed")
}

# Extract logits for final token prediction (sentiment label)
final_clean_logits = clean_logits[:, -1, :]
final_corrupted_logits = corrupted_logits[:, -1, :]

# softmax
clean_log_probs = torch.log_softmax(final_clean_logits, dim=-1)
corrupted_log_probs = torch.log_softmax(final_corrupted_logits, dim=-1)


clean_metric = torch.zeros(len(expected_sentiments), device=model.cfg.device)
corrupted_metric = torch.zeros(len(expected_sentiments), device=model.cfg.device)

# For each example, get the log probability of the correct sentiment token
for i, sentiment in enumerate(expected_sentiments):
    clean_metric[i] = clean_log_probs[i, sentiment_tokens[sentiment]]
    corrupted_metric[i] = corrupted_log_probs[i, sentiment_tokens[sentiment]]

# Calculate average performance across all examples as baselines
clean_baseline = clean_metric.mean().item()
corrupted_baseline = corrupted_metric.mean().item()

# Define metric function for activation patching
def metric_fn(logits):
    # Extract final token predictions
    final_logits = logits[:, -1, :]
    # Convert to probabilities
    log_probs = torch.log_softmax(final_logits, dim=-1)
    correct_log_probs = torch.zeros(len(expected_sentiments), device=model.cfg.device)

    # Get probability of correct sentiment for each example
    for i, sentiment in enumerate(expected_sentiments):
        correct_log_probs[i] = log_probs[i, sentiment_tokens[sentiment]]
    # - Values in between show partial recovery of clean performance
    return (correct_log_probs.mean() - corrupted_baseline) / (clean_baseline - corrupted_baseline)

# Run activation patching on residual stream
# tensor showing impact of patching each position at each layer
resid_results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, metric_fn
)

# Run activation patching on attention components
# tensor showing impact of patching each component for each head
attn_results = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric_fn
)
resid_results = resid_results.cpu()
attn_results = attn_results.cpu()

  0%|          | 0/1668 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

# Visualization
1. Residual Stream Impact: Shows how each layer and position affects model behavior
2. Attention Component Impact: Displays the influence of different attention mechanisms (Output, Query, Key, Value, Pattern)

In [4]:
# Residual stream visualization
resid_fig = px.imshow(
    resid_results.numpy(),
    title="Residual Stream Impact on In-context Learning",
    labels={'x': 'Position', 'y': 'Layer'},
    color_continuous_scale='RdBu',
    range_color=[-1, 1]
)
resid_fig.show()

# Attention components visualization
attn_labels = ['Output', 'Query', 'Key', 'Value', 'Pattern']
fig = make_subplots(
    rows=1, cols=5,
    subplot_titles=attn_labels
)

attn_results = attn_results.numpy()
for i in range(5):
    fig.add_trace(
        go.Heatmap(
            z=attn_results[i],
            colorscale='RdBu',
            zmin=-1,
            zmax=1
        ),
        row=1, col=i+1
    )

fig.update_layout(
    title="Attention Component Impact on In-context Learning",
    height=400,
    width=1500,
    showlegend=False
)
fig.show()


## Key Observations:
- Position 45: Strong positive across all layers (>0.2)
- Position 63: Consistent negative (~-0.2) across layers  
- Position 80: Strong positive (0.5) in shallow → negative (-0.5) in deep
- Position 120: Very strong positive (0.9) in shallow → negative (-0.2) in deep
- Final position: Near zero in shallow → strong positive (1.0) in deep

## Analysis:

1. **Processing Stages**  
- Early layers store raw demonstration patterns (pos 80, 120 strong positive)
- Deep layers reinterpret these patterns for task-specific use (pos 80, 120 shift negative)
- Final position pattern shows clear progression from storage to using demonstration information

2. **In-Context Learning Mechanism**
- Model first stores demonstrations in shallow layers
- Then aggregate this information in deeper layers
- Finally integrates across all demonstrations for prediction (final position pattern)