**Goal**: investigate how closely integrated gradients and activation patching agree when they use comparable baselines (here, clean and corrupted runs of the same prompt), across a variety of circuit tasks.

**Tasks**:

- Indirect Object Identification ([Wang et al., 2023](https://arxiv.org/pdf/2211.00593)): consists of inputs like “When Mary and John went to the store, John gave a bottle of milk to”; models are expected to predict “Mary”. Performance measured using logit differences.

- Gender-Bias ([Vig et al., 2020](https://proceedings.neurips.cc/paper/2020/hash/92650b2e92217715fe312e6fa7b90d82-Abstract.html)): designed to study gender bias in LMs. Gives models inputs like “The nurse said that”; biased models tend to complete this sentence with “she”. Performance measured using logit differences.

- Greater-Than ([Hanna et al., 2023](https://arxiv.org/abs/2305.00586)): models receive input like “The war lasted from the year 1741 to the year 17”, and must predict a valid two-digit end year, i.e. one that is greater than 41. Performance measured using probability differences.

- Capital–Country ([Hanna et al., 2024](https://arxiv.org/abs/2403.17806)): models receive input like “Tirana, the capital of” and must output the corresponding country (Albania). Corrupted instances contain another capital (e.g. Brasilia) instead. Performance measured using logit differences.

- Subject-Verb Agreement (SVA) ([Newman et al., 2021](https://aclanthology.org/2021.naacl-main.290/)): models receive a sentence like “The keys on the cabinet”, and must output a verb that agrees in number with the subject (keys), e.g. are or have. In corrupted inputs, the subject’s number is changed, e.g. from keys to key, causing the model to output verbs of opposite agreement. Performance measured using probability differences.

- Hypernymy: models must predict a word’s hypernym, or superordinate category, given inputs like “diamonds, and other”; the correct answer is “gems” or “gemstones”. Corrupted inputs contain an example of a distinct category, e.g. cars, which are vehicles. Performance measured using probability differences. This task is hard for small models, so we exclude inputs where GPT2-small gets a probability difference < 0.1 (following [Hanna et al., 2024](https://arxiv.org/abs/2403.17806)).

# Set up

In [1]:
import torch
import pandas as pd
import numpy as np

from functools import partial
from typing import Optional

from captum.attr import LayerIntegratedGradients

from transformer_lens.utils import get_act_name, get_device
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
from enum import Enum
from torch.utils.data import Dataset, DataLoader

class Task(Enum):
    IOI = 1
    GENDER_BIAS = 2
    GREATER_THAN = 3
    CAPITAL_COUNTRY = 4
    SVA = 5
    HYPERNYMY = 6

# Implementation of dataset loader based on https://github.com/hannamw/eap-ig-faithfulness

def collate_EAP(xs, task: Task):
    clean, corrupted, labels = zip(*xs)
    clean = list(clean)
    corrupted = list(corrupted)
    if task != Task.HYPERNYMY:
        labels = torch.tensor(labels)
    return clean, corrupted, labels

class TaskDataset(Dataset):
    def __init__(self, task: Task):
        filename = task.name.lower()
        self.task = task
        self.df = pd.read_csv(f'datasets/{filename}.csv')

    def __len__(self):
        return len(self.df)
    
    def shuffle(self):
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        self.df = self.df.head(n)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        label = None

        if self.task == Task.IOI:
            label = [row['correct_idx'], row['incorrect_idx']]
            return row['clean'], row['corrupted_hard'], label
        
        if self.task == Task.GREATER_THAN:
            label = row['correct_idx']
        elif self.task == Task.HYPERNYMY:
            answer = torch.tensor(eval(row['answers_idx']))
            corrupted_answer = torch.tensor(eval(row['corrupted_answers_idx']))
            label = [answer, corrupted_answer]
        elif self.task == Task.CAPITAL_COUNTRY:
            label = [row['country_idx'], row['corrupted_country_idx']]
        elif self.task == Task.GENDER_BIAS:
            label = [row['clean_answer_idx'], row['corrupted_answer_idx']]
        elif self.task == Task.SVA:
            label = row['plural']
        else:
            raise ValueError(f'Got invalid task: {self.task}')
        
        return row['clean'], row['corrupted'], label
    
    def to_dataloader(self, batch_size: int):
        return DataLoader(self, batch_size=batch_size, collate_fn=partial(collate_EAP, task=self.task))

In [3]:
torch.set_grad_enabled(False)

device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
# Expose the MLP input hook (hook_mlp_in), which the integrated gradients code patches
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
def logit_diff_metric(logits, correct_index, incorrect_index):
    logits_last = logits[:, -1, :]
    batch_size = logits.size(0)
    correct_logits = logits_last[torch.arange(batch_size), correct_index]
    incorrect_logits = logits_last[torch.arange(batch_size), incorrect_index]
    return correct_logits - incorrect_logits

def prob_diff_metric():
    pass
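
The tasks scored by probability difference need a counterpart to `logit_diff_metric`, which the stub above leaves unimplemented. One possible shape for it is sketched below. The name `prob_diff_sketch`, its signature, and the per-sample lists of token ids are assumptions made for illustration, not the notebook's eventual implementation; each task may need its own way of building the correct and incorrect token sets.

```python
import torch

def prob_diff_sketch(logits, correct_ids, incorrect_ids):
    """Hypothetical helper: summed probability of correct tokens minus incorrect tokens.

    logits: [batch, seq_len, vocab]
    correct_ids / incorrect_ids: one 1-D LongTensor of token ids per sample
    """
    # Next-token distribution at the final position, as in logit_diff_metric
    probs = torch.softmax(logits[:, -1, :], dim=-1)
    diffs = []
    for sample_probs, good, bad in zip(probs, correct_ids, incorrect_ids):
        diffs.append(sample_probs[good].sum() - sample_probs[bad].sum())
    return torch.stack(diffs)

# Toy check with a 4-token vocabulary: uniform logits give probability 0.25 per token
toy_logits = torch.zeros(1, 3, 4)
toy_result = prob_diff_sketch(toy_logits, [torch.tensor([0, 1])], [torch.tensor([2])])
```

In the toy check, two correct tokens carry 0.5 of the probability mass and one incorrect token carries 0.25, so the difference is 0.25.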

## Integrated gradients

In [53]:
def run_from_layer_fn(original_input, patch_layer, patch_output, metric, correct_idx, incorrect_idx, reset_hooks_end=True):
    def fwd_hook(act, hook):
        assert patch_output.shape == act.shape, f"Patch shape {patch_output.shape} doesn't match activation shape {act.shape}"
        return patch_output

    logits = model.run_with_hooks(
        original_input,
        fwd_hooks=[(patch_layer.name, fwd_hook)],
        reset_hooks_end=reset_hooks_end,
    )
    
    assert logits.shape[0] == correct_idx.shape[0] == incorrect_idx.shape[0]
    diff = metric(logits, correct_idx, incorrect_idx)
    return diff


def compute_layer_to_output_attributions(original_input, layer_input, layer_baseline, target_layer, prev_layer, metric, correct_idx, incorrect_idx):
    n_samples = original_input.size(0)
    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(original_input, prev_layer, x, metric, correct_idx, incorrect_idx)
    # Attribute to the target_layer's output
    ig_embed = LayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
    attributions, approximation_error = ig_embed.attribute(inputs=layer_input,
                                                    baselines=layer_baseline, 
                                                    internal_batch_size=n_samples,
                                                    attribute_to_layer_input=False,
                                                    return_convergence_delta=True)
    print(f"\nError (delta) for {target_layer.name} attribution: {approximation_error}")
    return attributions

In [58]:
def integrated_gradients(model: HookedTransformer, clean_tokens: torch.Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, metric: callable, correct_idx, incorrect_idx):
    n_samples = clean_tokens.size(0)
    
    # Gradient attribution for neurons in MLP layers
    mlp_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.d_mlp)
    # Gradient attribution for attention heads
    attn_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.n_heads)

    # Calculate integrated gradients for each layer
    for layer in range(model.cfg.n_layers):

        # Gradient attribution on heads
        hook_name = get_act_name("result", layer)
        target_layer = model.hook_dict[hook_name]
        prev_layer_hook = get_act_name("z", layer)
        prev_layer = model.hook_dict[prev_layer_hook]

        layer_clean_input = clean_cache[prev_layer_hook]
        layer_corrupt_input = corrupted_cache[prev_layer_hook]

        # Shape [batch, seq_len, n_heads, d_model]
        attributions = compute_layer_to_output_attributions(
            clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, metric, correct_idx, incorrect_idx)
        print(attributions.shape)
        # Calculate attribution score based on mean over each embedding, for each token
        per_token_score = attributions.mean(dim=3)
        score = per_token_score.mean(dim=1)
        attn_results[:, layer] = score

        # Gradient attribution on MLP neurons
        hook_name = get_act_name("post", layer)
        target_layer = model.hook_dict[hook_name]
        prev_layer_hook = get_act_name("mlp_in", layer)
        prev_layer = model.hook_dict[prev_layer_hook]

        layer_clean_input = clean_cache[prev_layer_hook]
        layer_corrupt_input = corrupted_cache[prev_layer_hook]
        
        # Shape [batch, seq_len, d_mlp]
        attributions = compute_layer_to_output_attributions(
            clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, metric, correct_idx, incorrect_idx)
        score = attributions.mean(dim=1)
        mlp_results[:, layer] = score

    return mlp_results, attn_results
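
For intuition about what `LayerIntegratedGradients` estimates, here is a dependency-free sketch of integrated gradients on a toy two-input function. `toy_model`, `toy_grad`, and `integrated_gradients_sketch` are hypothetical names, and the midpoint Riemann sum is a simplification chosen for readability; it is not necessarily the integration rule Captum applies. In the notebook's setup the `inputs` role is played by the corrupted activations and the `baseline` role by the clean activations.

```python
def toy_model(x):
    # Small nonlinear function of two "activations"
    return x[0] * x[1] + x[0] ** 2

def toy_grad(x):
    # Analytic gradient of toy_model with respect to each input
    return [x[1] + 2 * x[0], x[0]]

def integrated_gradients_sketch(inputs, baseline, n_steps=50):
    """Approximate IG_i = (x_i - b_i) * mean gradient along the straight path b -> x."""
    deltas = [xi - bi for xi, bi in zip(inputs, baseline)]
    totals = [0.0] * len(inputs)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps  # midpoint of each sub-interval of [0, 1]
        point = [bi + alpha * d for bi, d in zip(baseline, deltas)]
        grads = toy_grad(point)
        totals = [t + g for t, g in zip(totals, grads)]
    return [d * t / n_steps for d, t in zip(deltas, totals)]

attrs = integrated_gradients_sketch([1.0, 2.0], [0.0, 0.0])
# Completeness: attributions should sum to F(inputs) - F(baseline)
completeness_gap = abs(sum(attrs) - (toy_model([1.0, 2.0]) - toy_model([0.0, 0.0])))
```

The attributions sum to the change in output between baseline and input (the completeness property); the convergence delta printed per layer above measures how far the numerical approximation is from satisfying it.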

## Activation patching

In [66]:
def patch_hook(activations: torch.Tensor, hook: HookPoint, cache: ActivationCache, idx: int):
    # Replace the activations for the target head or neuron (index idx on dim 2) with activations from the cached run.
    cached_activations = cache[hook.name]
    activations[:, :, idx] = cached_activations[:, :, idx]
    return activations

def activation_patching(model: HookedTransformer, clean_tokens: torch.Tensor, clean_cache: ActivationCache, clean_logit_diff, corrupted_cache: ActivationCache, corrupted_logit_diff, metric: callable, correct_idx, incorrect_idx):
    n_samples = clean_tokens.size(0)
    
    mlp_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.d_mlp)
    attn_results = torch.zeros(n_samples, model.cfg.n_layers, model.cfg.n_heads)

    baseline_diff = clean_logit_diff - corrupted_logit_diff

    for layer in range(model.cfg.n_layers):
        # Activation patching on heads
        print(f"Activation patching on attention heads in layer {layer}")
        for head in range(model.cfg.n_heads):
            hook_name = get_act_name("result", layer)
            temp_hook = lambda act, hook: patch_hook(act, hook, corrupted_cache, head)

            with model.hooks(fwd_hooks=[(hook_name, temp_hook)]):
                patched_logits = model(clean_tokens)

            patched_logit_diff = metric(patched_logits, correct_idx, incorrect_idx).detach()
            # Normalise result by clean and corrupted logit difference
            attn_results[:, layer, head] = (patched_logit_diff - clean_logit_diff) / baseline_diff

        # Activation patching on MLP neurons
        print(f"Activation patching on MLP in layer {layer}")
        for neuron in range(model.cfg.d_mlp):
            hook_name = get_act_name("post", layer)
            temp_hook = lambda act, hook: patch_hook(act, hook, corrupted_cache, neuron)
            
            with model.hooks(fwd_hooks=[(hook_name, temp_hook)]):
                patched_logits = model(clean_tokens)

            patched_logit_diff = metric(patched_logits, correct_idx, incorrect_idx).detach()
            # Normalise result by clean and corrupted logit difference
            mlp_results[:, layer, neuron] = (patched_logit_diff - clean_logit_diff) / baseline_diff

    return mlp_results, attn_results
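
The patching loop and its normalisation can be illustrated on a toy two-neuron network without any model dependencies. Everything in this sketch (`hidden`, `readout`, `patching_effects`, and the weights) is hypothetical; only the scoring formula `(patched - clean) / (clean - corrupted)` mirrors the function above.

```python
def hidden(x):
    # Two toy "neurons" with ReLU activations
    return [max(0.0, 2.0 * x), max(0.0, 1.0 - x)]

def readout(h):
    # Linear readout standing in for the task metric
    return 3.0 * h[0] - 1.0 * h[1]

def patching_effects(clean_x, corrupted_x):
    clean_h, corrupted_h = hidden(clean_x), hidden(corrupted_x)
    clean_out, corrupted_out = readout(clean_h), readout(corrupted_h)
    baseline_diff = clean_out - corrupted_out
    effects = []
    for idx in range(len(clean_h)):
        # Start from the clean activations and overwrite a single neuron
        patched_h = list(clean_h)
        patched_h[idx] = corrupted_h[idx]
        effects.append((readout(patched_h) - clean_out) / baseline_diff)
    return effects

effects = patching_effects(clean_x=1.0, corrupted_x=0.0)
```

Under this normalisation a score of -1 corresponds to the full drop from clean to corrupted performance. Because the toy readout is linear in the hidden activations, the two single-neuron effects sum to exactly -1; in a real transformer, downstream nonlinearities mean this additivity generally does not hold.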
    

## Analysis

In [60]:
from sklearn.preprocessing import MaxAbsScaler

def plot_correlation(ig_scores, ap_scores, title=None):
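    x = ig_scores.flatten().numpy()
    y = ap_scores.flatten().numpy()

    # Pass x and y by keyword: recent seaborn versions reject them as positional arguments
    sns.regplot(x=x, y=y)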
    plt.xlabel("Integrated Gradients Attribution Scores")
    plt.ylabel("Activation Patching Attribution Scores")
    if title:
        plt.title(title)
    plt.show()

    print(f"Correlation coefficient: {np.corrcoef(x, y)[0, 1]}")

def plot_mean_diff(ig_scores, ap_scores, title=None):

    x = ig_scores.flatten().numpy()
    y = ap_scores.flatten().numpy()

    # Mean difference plot with scaled data

    scaled_ig_scores = MaxAbsScaler().fit_transform(x.reshape(-1, 1)).ravel()
    scaled_ap_scores = MaxAbsScaler().fit_transform(y.reshape(-1, 1)).ravel()

    mean = np.mean([scaled_ig_scores, scaled_ap_scores], axis=0)
    diff = scaled_ap_scores - scaled_ig_scores
    md = np.mean(diff) # Mean of the difference
    sd = np.std(diff) # Standard deviation of the difference

    sns.regplot(x=mean, y=diff, fit_reg=True, scatter=True)
    plt.axhline(md, color='gray', linestyle='--', label="Mean difference")
    plt.axhline(md + 1.96*sd, color='pink', linestyle='--', label="Mean + 1.96 SD of difference")
    plt.axhline(md - 1.96*sd, color='lightblue', linestyle='--', label="Mean - 1.96 SD of difference")
    plt.xlabel("Mean of attribution scores")
    plt.ylabel("Difference (activation patching - integrated gradients)")
    if title:
        plt.title(title)
    plt.legend()
    plt.show()
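
The numbers behind the mean-difference plot can be reproduced without plotting. The sketch below uses only the standard library; `max_abs_scale` and `limits_of_agreement` are hypothetical helpers that imitate, for a single column, the max-abs scaling and the 1.96 SD limits computed in `plot_mean_diff`.

```python
from statistics import mean, pstdev

def max_abs_scale(values):
    # Divide by the largest absolute value so scores lie in [-1, 1]
    largest = max(abs(v) for v in values)
    return [v / largest for v in values] if largest else list(values)

def limits_of_agreement(ig_scores, ap_scores):
    ig, ap = max_abs_scale(ig_scores), max_abs_scale(ap_scores)
    diffs = [a - b for a, b in zip(ap, ig)]
    md = mean(diffs)
    sd = pstdev(diffs)  # population SD, matching np.std's default (ddof=0)
    return md, md - 1.96 * sd, md + 1.96 * sd

# The second list is the first one halved: identical after scaling
md, lower, upper = limits_of_agreement([2.0, -4.0, 1.0], [1.0, -2.0, 0.5])
```

The toy inputs differ only by a constant factor, so after scaling every difference is zero. This is the reason for scaling before comparing: the two methods report scores in different units, and only disagreement beyond a global rescaling should show up in the plot.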

# Task 1: Indirect Object Identification

In [None]:
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=10)

In [None]:
clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)
correct_idx = labels[:, 0]
incorrect_idx = labels[:, 1]

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, correct_idx, incorrect_idx)
print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, correct_idx, incorrect_idx)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

In [None]:
ioi_ig_mlp, ioi_ig_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, correct_idx, incorrect_idx)

torch.save(ioi_ig_mlp, "saved_results/ioi_ig_mlp.pt")
torch.save(ioi_ig_attn, "saved_results/ioi_ig_attn.pt")

In [None]:
ioi_ap_mlp, ioi_ap_attn = activation_patching(
    model, clean_tokens, clean_cache, clean_logit_diff, corrupted_cache, corrupted_logit_diff, 
    logit_diff_metric, correct_idx, incorrect_idx)

torch.save(ioi_ap_mlp, "saved_results/ioi_ap_mlp.pt")
torch.save(ioi_ap_attn, "saved_results/ioi_ap_attn.pt")

In [None]:
plot_correlation(ioi_ig_mlp, ioi_ap_mlp, "IOI MLP Attribution Scores")

plot_correlation(ioi_ig_attn, ioi_ap_attn, "IOI Attention Heads Attribution Scores")

In [None]:
plot_mean_diff(ioi_ig_mlp, ioi_ap_mlp, "Mean-difference plot for IOI MLP attribution scores")

plot_mean_diff(ioi_ig_attn, ioi_ap_attn, "Mean-difference plot for IOI attention head attribution scores")

# Task 2: Gender Bias

In [76]:
gender_bias_dataset = TaskDataset(Task.GENDER_BIAS)
gender_bias_dataloader = gender_bias_dataset.to_dataloader(batch_size=10)

In [77]:
clean_input, corrupted_input, labels = next(iter(gender_bias_dataloader))

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)
correct_idx = labels[:, 0]
incorrect_idx = labels[:, 1]

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, correct_idx, incorrect_idx)
print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, correct_idx, incorrect_idx)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

Clean logit difference: tensor([-1.4348, -1.5338,  1.4572,  1.4659,  1.4542,  1.3636,  1.4376,  1.5102,
        -1.4453,  0.6453], device='mps:0')
Corrupted logit difference: tensor([-3.9465, -1.5776,  1.1531,  1.2531,  1.2493, -4.2106,  1.3220,  1.3344,
        -1.5273, -4.0531], device='mps:0')


In [None]:
gender_bias_ig_mlp, gender_bias_ig_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, correct_idx, incorrect_idx)

torch.save(gender_bias_ig_mlp, "saved_results/gender_bias_ig_mlp.pt")
torch.save(gender_bias_ig_attn, "saved_results/gender_bias_ig_attn.pt")

In [None]:
gender_bias_ap_mlp, gender_bias_ap_attn = activation_patching(
    model, clean_tokens, clean_cache, clean_logit_diff, corrupted_cache, corrupted_logit_diff, 
    logit_diff_metric, correct_idx, incorrect_idx)

torch.save(gender_bias_ap_mlp, "saved_results/gender_bias_ap_mlp.pt")
torch.save(gender_bias_ap_attn, "saved_results/gender_bias_ap_attn.pt")

In [None]:
plot_correlation(gender_bias_ig_mlp, gender_bias_ap_mlp, "Gender bias MLP Attribution Scores")

plot_correlation(gender_bias_ig_attn, gender_bias_ap_attn, "Gender bias Attention Heads Attribution Scores")

In [None]:
plot_mean_diff(gender_bias_ig_mlp, gender_bias_ap_mlp, "Mean-difference plot for gender bias MLP attribution scores")

plot_mean_diff(gender_bias_ig_attn, gender_bias_ap_attn, "Mean-difference plot for gender bias attention head attribution scores")