In [23]:
import torch
import numpy as np
from transformer_lens import HookedTransformer, utils
import transformer_lens.patching as patching
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def generate_sentiment_prompts():
    """Build few-shot sentiment prompts that share one demonstration block.

    Every prompt consists of the same six labelled demonstration lines
    ("Review: <text> Sentiment: <label>", newline-separated) followed by one
    unlabelled query line ("Review: <text> Sentiment:").

    Returns:
        list[str]: One prompt per test review, in the order the reviews are
        listed below.
    """
    labelled_examples = (
        ("This restaurant was amazing! The food was delicious and service was top-notch.", "positive"),
        ("Terrible experience. Rude staff and cold food. Would not recommend.", "negative"),
        ("Pretty average place. Food was okay but nothing special.", "neutral"),
        ("Great atmosphere but the prices are way too high.", "mixed"),
        ("The best movie I've seen all year! A masterpiece!", "positive"),
        ("Complete waste of time and money. The plot made no sense.", "negative"),
    )
    unlabelled_reviews = (
        "Exceeded all my expectations. Outstanding in every way!",
        "Disappointing product. Broke after two uses.",
        "Decent value for money but shipping took forever.",
        "Absolutely fantastic service and quality!",
    )

    # One "Review: ... Sentiment: ..." line per labelled example.
    demonstration_lines = []
    for text, label in labelled_examples:
        demonstration_lines.append(f"Review: {text} Sentiment: {label}")
    demonstrations = '\n'.join(demonstration_lines)

    # Append the open-ended query line for each test review.
    prompts = []
    for review in unlabelled_reviews:
        query_line = f"Review: {review} Sentiment:"
        prompts.append(demonstrations + '\n' + query_line)
    return prompts

def corrupt_demonstrations(prompts, corruption_type='swap'):
    """Create corrupted versions of few-shot demonstration prompts.

    Each prompt is expected to be newline-separated demonstration lines
    followed by one final test line; the test line is never modified.

    Args:
        prompts: Iterable of prompt strings.
        corruption_type: 'swap' exchanges the demonstration labels
            (positive <-> negative, neutral <-> mixed); 'shuffle' randomly
            reorders the demonstration lines using NumPy's global RNG.

    Returns:
        list[str]: One corrupted prompt per input prompt, in the same order.

    Raises:
        ValueError: If ``corruption_type`` is not 'swap' or 'shuffle'.
    """
    import re

    sentiment_map = {
        'positive': 'negative',
        'negative': 'positive',
        'neutral': 'mixed',
        'mixed': 'neutral'
    }

    if corruption_type not in ('swap', 'shuffle'):
        # Previously an unknown type silently produced an empty list.
        raise ValueError(
            f"Unknown corruption_type {corruption_type!r}; expected 'swap' or 'shuffle'"
        )

    # All labels are matched in ONE pass. Replacing them one after another on
    # the same string (as a loop over sentiment_map would) turns
    # positive -> negative and then immediately negative -> positive again,
    # leaving 'positive' and 'neutral' demonstrations uncorrupted.
    label_pattern = re.compile(
        "Sentiment: (" + "|".join(re.escape(label) for label in sentiment_map) + ")"
    )

    def swap_label(match):
        return f"Sentiment: {sentiment_map[match.group(1)]}"

    corrupted = []
    for prompt in prompts:
        # Everything before the last line is a demonstration.
        *demo_lines, test_line = prompt.split('\n')
        if corruption_type == 'swap':
            demo_lines = [label_pattern.sub(swap_label, line) for line in demo_lines]
        else:
            # In-place shuffle of the demonstration order only.
            np.random.shuffle(demo_lines)
        corrupted.append('\n'.join(demo_lines + [test_line]))
    return corrupted

def get_sentiment_metric(logits, expected_sentiments, model):
    """Mean log-probability assigned to the expected sentiment label.

    Args:
        logits: Tensor indexed as ``logits[:, -1, :]``; only the last position
            along the second axis is scored.
        expected_sentiments: One label per batch row, each one of 'positive',
            'negative', 'neutral' or 'mixed'.
        model: Object providing ``to_single_token`` and ``cfg.device``.

    Returns:
        A scalar tensor on ``model.cfg.device``: the mean, over the rows named
        in ``expected_sentiments``, of the log-softmax value at the expected
        label's token.
    """
    # Resolve every label (with its leading space) to a vocabulary id.
    label_strings = (
        ('positive', " positive"),
        ('negative', " negative"),
        ('neutral', " neutral"),
        ('mixed', " mixed"),
    )
    token_id_for = {name: model.to_single_token(text) for name, text in label_strings}

    # Log-probabilities over the vocabulary at the last sequence position.
    last_position_log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)

    # Gather (row i, expected token of row i) for every labelled row at once.
    row_index = torch.arange(len(expected_sentiments), device=last_position_log_probs.device)
    target_ids = torch.tensor(
        [token_id_for[name] for name in expected_sentiments],
        dtype=torch.long,
        device=last_position_log_probs.device,
    )
    picked = last_position_log_probs[row_index, target_ids]

    # Keep the result on the model's device in the default float dtype.
    picked = picked.to(device=model.cfg.device, dtype=torch.get_default_dtype())
    return picked.mean()

def analyze_demonstration_influence(model, clean_prompts, corrupted_prompts, expected_sentiments):
    """Activation-patch clean activations into the corrupted run and score them.

    The metric is the mean log-probability of the expected sentiment token,
    normalised so that 0 is the corrupted baseline and 1 the clean baseline.

    Args:
        model: HookedTransformer used for tokenisation and forward passes.
        clean_prompts: List of prompt strings with correct demonstrations.
        corrupted_prompts: List of prompt strings with corrupted
            demonstrations; must tokenise to the same shape as the clean ones.
        expected_sentiments: Expected label for each prompt's test review.

    Returns:
        dict with 'resid_results' and 'attn_results' (CPU tensors returned by
        the patching helpers) plus 'clean_metric' and 'corrupted_metric'
        (floats, un-normalised baselines).

    Raises:
        ValueError: If clean and corrupted token batches differ in shape, or
            if both baselines are equal (normalisation would divide by zero).
    """
    clean_tokens = model.to_tokens(clean_prompts)
    corrupted_tokens = model.to_tokens(corrupted_prompts)
    if clean_tokens.shape != corrupted_tokens.shape:
        raise ValueError(
            "Clean and corrupted prompts must tokenize to the same shape for "
            f"patching, got {tuple(clean_tokens.shape)} vs {tuple(corrupted_tokens.shape)}"
        )

    def last_token_positions(prompts, padded_tokens):
        # Prompts of different lengths are padded when batched, so the last
        # column of the batch is a padding slot for the shorter prompts.
        # Tokenizing each prompt alone gives its true length.
        # NOTE(review): assumes the tokenizer pads on the right unless its
        # padding_side says "left" — confirm for the model in use.
        if getattr(model.tokenizer, "padding_side", "right") == "left":
            positions = [padded_tokens.shape[1] - 1] * len(prompts)
        else:
            positions = [model.to_tokens(prompt).shape[1] - 1 for prompt in prompts]
        return torch.tensor(positions, device=padded_tokens.device)

    clean_positions = last_token_positions(clean_prompts, clean_tokens)
    corrupted_positions = last_token_positions(corrupted_prompts, corrupted_tokens)

    def sentiment_metric_at(logits, positions):
        # Select each row's logits at its last real token and hand them over
        # as a length-1 sequence, which is what get_sentiment_metric scores.
        rows = torch.arange(logits.shape[0], device=logits.device)
        final_logits = logits[rows, positions.to(logits.device)].unsqueeze(1)
        return get_sentiment_metric(final_logits, expected_sentiments, model)

    # Pure inference: no autograd graph should be kept alive by the cache.
    with torch.no_grad():
        clean_logits, clean_cache = model.run_with_cache(clean_tokens)
        # Only the logits of the corrupted run are needed, so skip its cache.
        corrupted_logits = model(corrupted_tokens)

        clean_metric = sentiment_metric_at(clean_logits, clean_positions)
        corrupted_metric = sentiment_metric_at(corrupted_logits, corrupted_positions)

        baseline_gap = clean_metric - corrupted_metric
        if baseline_gap.item() == 0:
            raise ValueError(
                "Clean and corrupted baselines are identical; the normalised "
                "patching metric is undefined"
            )

        def patching_metric(logits):
            patched = sentiment_metric_at(logits, corrupted_positions)
            return (patched - corrupted_metric) / baseline_gap

        resid_results = patching.get_act_patch_resid_pre(
            model, corrupted_tokens, clean_cache, patching_metric
        )

        attn_results = patching.get_act_patch_attn_head_all_pos_every(
            model, corrupted_tokens, clean_cache, patching_metric
        )

    return {
        'resid_results': resid_results.cpu(),
        'attn_results': attn_results.cpu(),
        'clean_metric': clean_metric.item(),
        'corrupted_metric': corrupted_metric.item()
    }

def visualize_results(results, model):
    """Create visualizations of the patching analysis results.

    Args:
        results: Mapping with 'resid_results' and 'attn_results' entries, each
            exposing ``.numpy()``. 'attn_results' is iterated along its first
            axis, one slice per attention component.
        model: Unused; kept so existing callers keep working.

    Returns:
        (resid_fig, attn_fig): a Plotly Express heatmap of the residual-stream
        results and a subplot figure with one heatmap per attention component.
    """
    # Residual stream analysis
    resid_fig = px.imshow(
        results['resid_results'].numpy(),
        title="Residual Stream Impact on In-context Learning",
        labels={'x': 'Position', 'y': 'Layer'},
        color_continuous_scale='RdBu',
        range_color=[-1, 1]
    )

    # Attention head analysis: one subplot per labelled component.
    attn_labels = ['Output', 'Query', 'Key', 'Value', 'Pattern']
    fig = make_subplots(
        rows=1, cols=len(attn_labels),
        subplot_titles=attn_labels
    )

    attn_results = results['attn_results'].numpy()
    for col, (_, component) in enumerate(zip(attn_labels, attn_results), start=1):
        fig.add_trace(
            # A shared color axis gives a single colorbar; per-trace color
            # settings would draw one colorbar per heatmap on top of each other.
            go.Heatmap(z=component, coloraxis="coloraxis"),
            row=1, col=col
        )

    # NOTE(review): assumes each component slice is (layer, head) — confirm
    # against the patching helper's output layout.
    fig.update_xaxes(title_text="Head")
    fig.update_yaxes(title_text="Layer", row=1, col=1)

    fig.update_layout(
        title="Attention Component Impact on In-context Learning",
        height=400,
        width=1500,
        showlegend=False,
        coloraxis=dict(colorscale='RdBu', cmin=-1, cmax=1)
    )

    return resid_fig, fig


# Load the pretrained model that every later step runs on.
model = HookedTransformer.from_pretrained("gpt2")

# Expected label for each test review, in prompt order.
expected_sentiments = ['positive', 'negative', 'mixed', 'positive']

# Clean few-shot prompts and their label-swapped counterparts.
clean_prompts = generate_sentiment_prompts()
corrupted_prompts = corrupt_demonstrations(clean_prompts, corruption_type='swap')

# Patch clean activations into the corrupted run and collect the scores.
results = analyze_demonstration_influence(
    model=model,
    clean_prompts=clean_prompts,
    corrupted_prompts=corrupted_prompts,
    expected_sentiments=expected_sentiments,
)

# Build both figures; they are displayed by later cells.
resid_fig, attn_fig = visualize_results(results=results, model=model)


Loaded pretrained model gpt2 into HookedTransformer


  0%|          | 0/1668 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [27]:
# Display the attention-component figure returned by visualize_results.
attn_fig

In [29]:
# Display the residual-stream figure returned by visualize_results.
resid_fig

In [3]:
# Install the neel-plotly helper package straight from GitHub.
# NOTE(review): none of the visible cells import neel_plotly — confirm it is
# still needed before keeping this install step.
%pip install git+https://github.com/neelnanda-io/neel-plotly.git

Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /private/var/folders/z5/jt40b2257wng77l19wn1yl4m0000gn/T/pip-req-build-a5n94ema
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /private/var/folders/z5/jt40b2257wng77l19wn1yl4m0000gn/T/pip-req-build-a5n94ema
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: neel_plotly
  Building wheel for neel_plotly (setup.py) ... [?25ldone
[?25h  Created wheel for neel_plotly: filename=neel_plotly-0.0.0-py3-none-any.whl size=10189 sha256=52dfb84f99a47ef3ceccf2fb44015b95a44936ad65767c2a927bf4202c1fdc17
  Stored in directory: /private/var/folders/z5/jt40b2257wng77l19wn1yl4m0000gn/T/pip-ephem-wheel-cache-za_elf3d/wheels/32/cf/25/0103b4be02266c40faf008ffa9565a2ba07d1c63118fccc390
S

In [30]:
import torch
import numpy as np
from transformer_lens import HookedTransformer, utils
import transformer_lens.patching as patching
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Inference-only notebook: switch autograd off globally so forward passes and
# activation caches do not record gradient graphs.
torch.set_grad_enabled(False)
# Load the pretrained model used by all cells below.
# NOTE(review): an earlier cell loads "gpt2"; presumably both names refer to
# the same checkpoint — confirm.
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


# Prompt Generation

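This section builds few-shot sentiment-analysis prompts in the format `Review: [text] Sentiment: [label]`:
- Clean prompts: demonstrations with their correct sentiment labels, followed by an unlabelled test review
- Corrupted prompts: the same prompts with the demonstration labels swapped (positive ↔ negative, neutral ↔ mixed)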


In [None]:
import re

# Create demonstration examples: (review text, sentiment label) pairs.
demo_pairs = [
    ("This restaurant was amazing! The food was delicious and service was top-notch.", "positive"),
    ("Terrible experience. Rude staff and cold food. Would not recommend.", "negative"),
    ("Pretty average place. Food was okay but nothing special.", "neutral"),
    ("Great atmosphere but the prices are way too high.", "mixed"),
    ("The best movie I've seen all year! A masterpiece!", "positive"),
    ("Complete waste of time and money. The plot made no sense.", "negative")
]

# One "Review: ... Sentiment: ..." line per demonstration.
demo_str = '\n'.join([f"Review: {x} Sentiment: {y}" for (x, y) in demo_pairs])

# Test reviews
test_reviews = [
    "Exceeded all my expectations. Outstanding in every way!",
    "Disappointing product. Broke after two uses.",
    "Decent value for money but shipping took forever.",
    "Absolutely fantastic service and quality!"
]

# Each prompt: the shared demonstrations plus one unlabelled test review.
clean_prompts = [f"{demo_str}\nReview: {review} Sentiment:" for review in test_reviews]

# Create corrupted prompts by swapping sentiment labels
sentiment_map = {
    'positive': 'negative',
    'negative': 'positive',
    'neutral': 'mixed',
    'mixed': 'neutral'
}

# Match every label in a single pass. Applying the map entries one after
# another to the same line swaps positive -> negative and then straight back
# to positive (likewise neutral/mixed), leaving those demonstrations clean.
label_pattern = re.compile(
    "Sentiment: (" + "|".join(re.escape(label) for label in sentiment_map) + ")"
)


def swap_label(match):
    """Return the "Sentiment: <label>" text with the label exchanged."""
    return f"Sentiment: {sentiment_map[match.group(1)]}"


corrupted_prompts = []
for prompt in clean_prompts:
    # The final line is the test review and must stay untouched.
    *demo_lines, test_line = prompt.split('\n')
    swapped_lines = [label_pattern.sub(swap_label, line) for line in demo_lines]
    corrupted_prompts.append('\n'.join(swapped_lines + [test_line]))

# Expected label for each test review, in prompt order.
expected_sentiments = ['positive', 'negative', 'mixed', 'positive']

In [None]:
# Convert prompts to tokens
clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)

if len(expected_sentiments) != clean_tokens.shape[0]:
    raise ValueError("expected_sentiments must contain one label per prompt")


def last_token_positions(prompts, padded_tokens):
    """Return the index of each prompt's last real (non-padding) token.

    The prompts differ in length, so batching pads the shorter ones and the
    last column of the batch is not the end of every prompt.
    """
    # NOTE(review): assumes the tokenizer pads on the right unless its
    # padding_side says "left" — confirm for the model in use.
    if getattr(model.tokenizer, "padding_side", "right") == "left":
        positions = [padded_tokens.shape[1] - 1] * len(prompts)
    else:
        # Tokenizing a prompt on its own yields its unpadded length.
        positions = [model.to_tokens(prompt).shape[1] - 1 for prompt in prompts]
    return torch.tensor(positions, device=padded_tokens.device)


clean_final_positions = last_token_positions(clean_prompts, clean_tokens)
corrupted_final_positions = last_token_positions(corrupted_prompts, corrupted_tokens)

# Get baselines
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Get token IDs for sentiments
sentiment_tokens = {
    'positive': model.to_single_token(" positive"),
    'negative': model.to_single_token(" negative"),
    'neutral': model.to_single_token(" neutral"),
    'mixed': model.to_single_token(" mixed")
}
# Expected label token for every prompt, in batch order.
label_token_ids = torch.tensor(
    [sentiment_tokens[sentiment] for sentiment in expected_sentiments],
    device=clean_logits.device,
)

# Calculate baseline metrics at each prompt's last real token
batch_rows = torch.arange(clean_logits.shape[0], device=clean_logits.device)
final_clean_logits = clean_logits[batch_rows, clean_final_positions.to(clean_logits.device)]
final_corrupted_logits = corrupted_logits[
    batch_rows.to(corrupted_logits.device),
    corrupted_final_positions.to(corrupted_logits.device),
]
clean_log_probs = torch.log_softmax(final_clean_logits, dim=-1)
corrupted_log_probs = torch.log_softmax(final_corrupted_logits, dim=-1)

# Per-prompt log-probability of the expected label.
clean_metric = clean_log_probs[batch_rows, label_token_ids]
corrupted_metric = corrupted_log_probs[
    batch_rows.to(corrupted_log_probs.device),
    label_token_ids.to(corrupted_log_probs.device),
]

clean_baseline = clean_metric.mean().item()
corrupted_baseline = corrupted_metric.mean().item()

baseline_gap = clean_baseline - corrupted_baseline
if baseline_gap == 0:
    # Without a gap the normalised metric would be inf/nan everywhere.
    raise ValueError(
        "Clean and corrupted baselines are identical; the normalised "
        "patching metric is undefined"
    )


# Run activation patching
def metric_fn(logits):
    """Normalised metric: 0 at the corrupted baseline, 1 at the clean one."""
    rows = torch.arange(logits.shape[0], device=logits.device)
    final_logits = logits[rows, corrupted_final_positions.to(logits.device)]
    log_probs = torch.log_softmax(final_logits, dim=-1)
    correct_log_probs = log_probs[rows, label_token_ids.to(logits.device)]
    return (correct_log_probs.mean() - corrupted_baseline) / baseline_gap


resid_results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, metric_fn
)

attn_results = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric_fn
)

# Move results to CPU for visualization
resid_results = resid_results.cpu()
attn_results = attn_results.cpu()

In [None]:
# Residual stream visualization
resid_fig = px.imshow(
    np.asarray(resid_results),
    title="Residual Stream Impact on In-context Learning",
    labels={'x': 'Position', 'y': 'Layer'},
    color_continuous_scale='RdBu',
    range_color=[-1, 1]
)
resid_fig.show()

# Attention components visualization: one subplot per labelled component.
attn_labels = ['Output', 'Query', 'Key', 'Value', 'Pattern']
fig = make_subplots(
    rows=1, cols=len(attn_labels),
    subplot_titles=attn_labels
)

# np.asarray converts the CPU tensor on the first run and is a no-op on an
# array, so re-running this cell no longer fails on a second `.numpy()` call.
attn_results = np.asarray(attn_results)
for col, (_, component) in enumerate(zip(attn_labels, attn_results), start=1):
    fig.add_trace(
        # A shared color axis gives a single colorbar; per-trace color
        # settings would draw one colorbar per heatmap on top of each other.
        go.Heatmap(z=component, coloraxis="coloraxis"),
        row=1, col=col
    )

# NOTE(review): assumes each component slice is (layer, head) — confirm
# against the patching helper's output layout.
fig.update_xaxes(title_text="Head")
fig.update_yaxes(title_text="Layer", row=1, col=1)

fig.update_layout(
    title="Attention Component Impact on In-context Learning",
    height=400,
    width=1500,
    showlegend=False,
    coloraxis=dict(colorscale='RdBu', cmin=-1, cmax=1)
)
fig.show()