# Initial Comparison

We run integrated gradients (IG) and activation patching (AP) on the same model and dataset, and compare the attribution scores each method assigns to MLP neurons and attention heads.

- Model: GPT-2 Small (12 layers, 12 attention heads per layer, embedding size 768, 3,072 neurons per MLP layer)
- Dataset: Indirect Object Identification task

In [1]:
# Automatically reload imported modules before each cell executes, so edits to the
# helper modules imported below (attribution_methods, testing, plotting) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

from attribution_methods import integrated_gradients, activation_patching, highlight_components
from testing import Task, TaskDataset, logit_diff_metric, average_correlation, measure_overlap
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_mean_diff

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Expose the MLP input (hook_mlp_in); the IG cell uses it as the input layer
# when attributing MLP neurons.
model.set_use_hook_mlp_in(True)

# Explicitly calculate and expose each attention head's individual output
# (hook_result); the IG cell attributes heads at this hook point.
model.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Experiment

In [None]:
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=100)

# Use a single batch: clean prompts, their corrupted counterparts, and the labels
# consumed by logit_diff_metric.
clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

# Reference forward passes. Nothing in this notebook backpropagates through these
# logits (the IG helper re-runs the model from clean_tokens itself), so autograd is
# disabled here. Otherwise the full-batch forward graphs would be retained for as
# long as clean_logits / corrupted_logits stay alive, i.e. for the whole notebook.
with torch.no_grad():
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    clean_logit_diff = logit_diff_metric(clean_logits, labels)

    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
    corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)

print(f"Clean logit difference: {clean_logit_diff}")
print(f"Corrupted logit difference: {corrupted_logit_diff}")

Clean logit difference: tensor([-0.0307, -0.9269, -0.4937,  2.2320,  0.6754,  4.0447, -0.1785,  1.1947,
         1.1514,  1.7507,  0.1791,  4.2971,  2.9955, -0.7016, -2.1907, -3.5684,
        -4.4879, -1.2934, -3.8906, -0.6969, -0.8222,  0.0708,  0.2167,  4.4769,
         1.0375, -1.2644,  0.9309,  2.8114,  0.9975,  2.4103,  2.6244,  0.0125,
        -0.8472, -0.6130, -1.1623, -0.5109,  3.0073,  0.6154, -1.1229,  0.2680,
        -2.7379,  5.2855,  2.5019,  0.3219, -1.3112,  1.2942, -2.1428,  3.1053,
         1.6090,  3.1023,  1.8912,  0.4089,  4.0511,  2.5005,  3.5176, -1.5472,
         2.2213, -0.8523,  0.6682,  0.4244,  0.8053,  3.2905,  0.7295,  0.9946,
        -3.6073, -2.2671,  1.7894, -0.6390,  0.6320, -1.5326,  1.3206, -0.1224,
         0.1692,  1.9326,  3.1771,  1.1320, -0.0876,  3.1172,  2.3856,  3.2836,
        -2.0859,  3.6953,  2.8494, -2.4261,  1.1299,  0.1732, -1.4748, -2.1046,
        -0.6516, -0.6167,  0.0277, -1.7128,  0.6374,  2.6352, -1.4080,  3.2583,
         0.6919,

In [None]:
from transformer_lens.utils import get_act_name
from attribution_methods import compute_layer_to_output_attributions

# IG needs autograd; make sure it is on regardless of earlier cells.
torch.set_grad_enabled(True)

# Standard integrated gradients with a zero baseline.
# For every layer we compute one score per attention head (input hook_z -> target
# hook_result) and one score per MLP neuron (input hook_mlp_in -> target hook_post).
#
# NOTE(review): the last recorded run of this cell failed with
# "AssertionError: Forward hook did not obtain any outputs for given layer".
# TODO confirm which hook pair never fires inside compute_layer_to_output_attributions.

n_samples = clean_tokens.size(0)

ig_mlp_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.d_mlp)
ig_attn_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.n_heads)


def _zero_baseline_ig(input_hook_name, target_hook_name):
    """Run compute_layer_to_output_attributions between two hook points.

    The clean-run activation cached at ``input_hook_name`` is the IG input, and an
    all-zero tensor of the same shape is the baseline. The attributions are computed
    for the module at ``target_hook_name`` using ``logit_diff_metric`` and ``labels``.
    The exact attribution semantics are defined by
    attribution_methods.compute_layer_to_output_attributions.
    """
    layer_input = clean_cache[input_hook_name]
    return compute_layer_to_output_attributions(
        model,
        clean_tokens,
        layer_input,
        torch.zeros_like(layer_input),
        model.hook_dict[target_hook_name],
        model.hook_dict[input_hook_name],
        logit_diff_metric,
        labels,
    )


for layer in range(model.cfg.n_layers):
    # Attention heads: attributions are [batch, seq_len, n_heads, d_model].
    # Average over d_model, then over token positions -> [batch, n_heads].
    head_attr = _zero_baseline_ig(get_act_name("z", layer), get_act_name("result", layer))
    # detach() so no autograd history leaks into the result buffer (which would
    # break numpy conversion in the plotting helpers); cpu() matches the buffer.
    ig_attn_results[:, layer] = head_attr.mean(dim=3).mean(dim=1).detach().cpu()

    # MLP neurons: attributions are [batch, seq_len, d_mlp].
    # Average over token positions -> [batch, d_mlp].
    neuron_attr = _zero_baseline_ig(get_act_name("mlp_in", layer), get_act_name("post", layer))
    ig_mlp_results[:, layer] = neuron_attr.mean(dim=1).detach().cpu()

tensor([[[[ 6.3183e-02,  5.9828e-02, -1.8017e-02,  ..., -3.1300e-02,
            1.3576e-01,  1.0932e-01],
          [ 1.8581e-01,  2.3438e-01, -2.4526e-01,  ..., -5.3240e-01,
           -3.8492e-01,  2.4711e-01],
          [-4.0007e-02, -1.3616e-01,  3.8954e-01,  ..., -6.7756e-02,
            1.2671e-01, -1.8432e-01],
          ...,
          [-2.8296e-01, -5.0378e-01,  1.6165e-01,  ..., -4.9117e-02,
           -1.7141e-01, -1.1625e-01],
          [-5.6480e-02, -4.1342e-01, -2.0710e-01,  ...,  1.9242e-01,
            2.5598e-01, -7.1256e-02],
          [ 1.5511e-01, -5.1378e-01, -1.7323e-01,  ...,  1.5156e-01,
           -3.0645e-01,  3.7273e-01]],

         [[ 6.7831e-02,  5.0939e-02, -1.1132e-02,  ..., -3.2761e-02,
            1.1142e-01,  1.0431e-01],
          [ 9.2455e-02,  4.3506e-02,  1.1513e-02,  ...,  3.6808e-01,
            2.6727e-01,  3.8849e-02],
          [-2.5921e-02, -1.4968e-01,  3.8238e-01,  ..., -4.0184e-02,
            1.2549e-01, -1.6114e-01],
          ...,
     

AssertionError: Forward hook did not obtain any outputs for given layer

In [None]:
# Activation patching, using the cached clean/corrupted runs and their reference
# logit differences. Returns MLP-neuron scores and attention-head scores, in that order.
ap_mlp_results, ap_attn_results = activation_patching(
    model,
    clean_tokens,
    clean_cache,
    clean_logit_diff,
    corrupted_cache,
    corrupted_logit_diff,
    logit_diff_metric,
    labels,
)

## Analysis

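To evaluate how similar the attributions from standard integrated gradients and activation patching are, we:

- Visualise each method's attention-head scores for the first three prompts, both raw and restricted to the heads each method highlights as significant
- Plot the correlation between the two methods' attribution scores, and report the average correlation for MLP neurons and attention heads
- Measure the overlap between the components highlighted as significant by each method
- Show mean-difference plots comparing the two methods, for MLP neurons and for attention heads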

In [None]:
# Side-by-side IG vs AP attention-head scores for the first few prompts.
n_examples = 3
plot_attn_comparison(ig_attn_results[:n_examples], ap_attn_results[:n_examples], model)

In [None]:
# Scatter/correlation plots of IG vs AP scores for the first few prompts.
n_examples = 3
plot_correlation_comparison(
    ig_mlp_results[:n_examples],
    ap_mlp_results[:n_examples],
    ig_attn_results[:n_examples],
    ap_attn_results[:n_examples],
    Task.IOI,
)

# Aggregate IG-vs-AP correlation over the whole batch (see testing.average_correlation).
mlp_corr = average_correlation(ig_mlp_results, ap_mlp_results)
attn_corr = average_correlation(ig_attn_results, ap_attn_results)

print(f"Average correlation between MLP neuron scores: {mlp_corr}")
print(f"Average correlation between attention head scores: {attn_corr}")

In [None]:
# Keep only the attention heads each method flags as significant
# (highlight_components' second return value is not needed here).
ig_attn_significant = highlight_components(ig_attn_results)[0]
ap_attn_significant = highlight_components(ap_attn_results)[0]

n_examples = 3
plot_attn_comparison(ig_attn_significant[:n_examples], ap_attn_significant[:n_examples], model)

attn_overlap = measure_overlap(ig_attn_significant, ap_attn_significant)
print(f"Overlap between IG and AP highlighted attention heads: {attn_overlap}")

In [None]:
# Keep only the MLP neurons each method flags as significant
# (highlight_components' second return value is not needed here).
ig_mlp_significant = highlight_components(ig_mlp_results)[0]
ap_mlp_significant = highlight_components(ap_mlp_results)[0]

mlp_overlap = measure_overlap(ig_mlp_significant, ap_mlp_significant)
print(f"Overlap between IG and AP highlighted MLP neurons: {mlp_overlap}")

In [None]:
# One mean-difference plot per component type, comparing IG against AP scores.
mean_diff_panels = (
    (ig_mlp_results, ap_mlp_results, "Mean-difference plot for neurons in IOI task"),
    (ig_attn_results, ap_attn_results, "Mean-difference plot for attention heads in IOI task"),
)
for ig_scores, ap_scores, plot_title in mean_diff_panels:
    plot_mean_diff(ig_scores, ap_scores, plot_title)