# Initial Comparison

We run integrated gradients and activation patching on the same model and dataset, to compare attribution scores.

- Model: GPT2-Small (12 layers, 12 attention heads per layer, embedding size 768, 3,072 neurons per MLP layer)
- Dataset: Indirect Object Identification task

In [None]:
# Reload imported project modules automatically before each cell runs, so
# edits to the local helper modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

from attribution_methods import integrated_gradients, activation_patching, highlight_components
from testing import Task, TaskDataset, logit_diff_metric, average_correlation, measure_overlap
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_mean_diff

In [None]:
# Disable autograd tracking globally for the rest of the session.
# NOTE(review): the integrated-gradients cell further down needs gradients;
# presumably compute_layer_to_output_attributions re-enables them
# internally - confirm.
torch.set_grad_enabled(False)

# Load pretrained GPT2-Small onto the device chosen by get_device().
device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

## Experiment

In [None]:
# Build the IOI dataset and pull only the first batch (batch_size=100);
# every later cell works on this single batch.
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=100)
first_batch = next(iter(ioi_dataloader))
clean_input, corrupted_input, labels = first_batch

# Tokenise the clean and corrupted prompts with the model's tokenizer.
clean_tokens, corrupted_tokens = (
    model.to_tokens(prompts) for prompts in (clean_input, corrupted_input)
)

# Clean run: keep the activation cache and the metric value for later cells.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

# Corrupted run: same metric and labels, evaluated on the corrupted prompts.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

In [None]:
from transformer_lens.utils import get_act_name
from attribution_methods import compute_layer_to_output_attributions

# Standard integrated gradients with zero baseline


def _zero_baseline_attributions(layer, input_act, target_act):
    """Run integrated gradients for one component type in one layer.

    The cached clean activation at the ``input_act`` hook is used as the
    input and an all-zeros tensor of the same shape as the baseline; both
    are handed to ``compute_layer_to_output_attributions`` together with
    the ``target_act`` and ``input_act`` hook points of ``layer``.

    Reads the notebook globals ``model``, ``clean_tokens``, ``clean_cache``,
    ``logit_diff_metric`` and ``labels``.

    Args:
        layer: Index of the transformer block.
        input_act: Activation name (as accepted by ``get_act_name``) whose
            cached clean value is the integration input.
        target_act: Activation name of the hook point being attributed.

    Returns:
        Whatever ``compute_layer_to_output_attributions`` returns for this
        layer (a tensor whose first two axes are assumed to be
        [batch, seq_len] - TODO confirm against the helper).
    """
    input_hook_name = get_act_name(input_act, layer)
    target_hook_name = get_act_name(target_act, layer)

    layer_input = clean_cache[input_hook_name]
    # Use zero activations as the baseline
    layer_baseline = torch.zeros_like(layer_input)

    return compute_layer_to_output_attributions(
        model,
        clean_tokens,
        layer_input,
        layer_baseline,
        model.hook_dict[target_hook_name],
        model.hook_dict[input_hook_name],
        logit_diff_metric,
        labels,
    )


n_samples = clean_tokens.size(0)

# One score per (sample, layer, component).
ig_mlp_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.d_mlp)
ig_attn_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.n_heads)

# Calculate integrated gradients for each layer
for layer in range(model.cfg.n_layers):

    # Gradient attribution on heads.
    # The two means below must leave [batch, n_heads] to fit
    # ig_attn_results[:, layer], so axis 2 is the head axis (not d_head).
    # Expected shape [batch, seq_len, n_heads, d_model] - TODO confirm the
    # size of the last axis against the helper.
    head_attributions = _zero_baseline_attributions(
        layer, input_act="z", target_act="result"
    )
    # Mean over the last axis per token, then mean over tokens.
    per_token_head_score = head_attributions.mean(dim=3)
    ig_attn_results[:, layer] = per_token_head_score.mean(dim=1)

    # Gradient attribution on MLP neurons.
    # The token mean must leave [batch, d_mlp] to fit
    # ig_mlp_results[:, layer], so the expected shape is
    # [batch, seq_len, d_mlp] (not d_model).
    neuron_attributions = _zero_baseline_attributions(
        layer, input_act="mlp_in", target_act="post"
    )
    ig_mlp_results[:, layer] = neuron_attributions.mean(dim=1)

In [None]:
# Activation patching over the same batch, using the caches and metric
# values computed for the clean and corrupted runs above.
ap_mlp_results, ap_attn_results = activation_patching(
    model,
    clean_tokens,
    clean_cache,
    clean_logit_diff,
    corrupted_cache,
    corrupted_logit_diff,
    logit_diff_metric,
    labels,
)

## Analysis

To evaluate the similarity between standard integrated gradients and activation patching, we:

- Visualise the attention heads highlighted by each method for the first three samples
- Plot the correlation between the attribution scores
- Measure the amount of overlap between highlighted components
- Visualise the agreement between the two methods' scores with a mean-difference plot

In [None]:
# Side-by-side attention-head scores from both methods, first three samples only.
plot_attn_comparison(ig_attn_results[:3], ap_attn_results[:3], model)

In [None]:
# Correlation plots for the first three samples only.
plot_correlation_comparison(
    ig_mlp_results[:3],
    ap_mlp_results[:3],
    ig_attn_results[:3],
    ap_attn_results[:3],
    Task.IOI,
)

# Average correlation over the whole batch, MLP neurons first.
mlp_corr = average_correlation(ig_mlp_results, ap_mlp_results)
print(f"Average correlation between MLP neuron scores: {mlp_corr}")

# Same statistic for attention heads.
attn_corr = average_correlation(ig_attn_results, ap_attn_results)
print(f"Average correlation between attention head scores: {attn_corr}")

In [None]:
# highlight_components returns two values; only the first is used here.
# NOTE(review): presumably the first value keeps only the components judged
# significant - confirm against attribution_methods.
ig_attn_significant, _ = highlight_components(ig_attn_results)
ap_attn_significant, _ = highlight_components(ap_attn_results)

# Compare the highlighted heads of both methods for the first three samples.
plot_attn_comparison(ig_attn_significant[:3], ap_attn_significant[:3], model)

# Overlap is measured over the whole batch, not just the plotted samples.
attn_overlap = measure_overlap(ig_attn_significant, ap_attn_significant)
print(f"Overlap between IG and AP highlighted attention heads: {attn_overlap}")

In [None]:
# Same highlighting as for attention heads, applied to MLP neuron scores;
# the second return value of highlight_components is discarded.
ig_mlp_significant, _ = highlight_components(ig_mlp_results)
ap_mlp_significant, _ = highlight_components(ap_mlp_results)

# Overlap between the two methods' highlighted neurons over the whole batch.
mlp_overlap = measure_overlap(ig_mlp_significant, ap_mlp_significant)
print(f"Overlap between IG and AP highlighted MLP neurons: {mlp_overlap}")

In [None]:
# Mean-difference plots comparing IG and AP scores over the whole batch:
# first for MLP neurons, then for attention heads.
plot_mean_diff(ig_mlp_results, ap_mlp_results, "Mean-difference plot for neurons in IOI task")
plot_mean_diff(ig_attn_results, ap_attn_results, "Mean-difference plot for attention heads in IOI task")