**Goal**: investigate the agreement between integrated gradients and activation patching when the baselines are similar, across a variety of circuit tasks.

**Tasks**:

- Indirect Object Identification (![Wang et al, 2023](https://arxiv.org/pdf/2211.00593)): consists of inputs like “When Mary and John went to the store, John gave a bottle of milk to”; models are expected to predict “Mary”. Performance measured using logit differences.

- Gender-Bias (![Vig et al, 2020](https://proceedings.neurips.cc/paper/2020/hash/92650b2e92217715fe312e6fa7b90d82-Abstract.html)): designed to study gender bias in LMs. Gives models inputs like “The nurse said that”; biased models tend to complete this sentence with “she”. Performance measured using logit differences.

- Greater-Than (![Hanna et al., 2023](https://arxiv.org/abs/2305.00586)): models receive input like “The war lasted from the year 1741 to the year 17”, and must predict a valid two-digit end year, i.e. one that is greater than 41. Performance measured using probability differences. 

- Capital–Country (![Hanna et al., 2024](https://arxiv.org/abs/2403.17806)): models receive input like “Tirana, the capital of” and must output the corresponding country (Albania). Corrupted instances contain another capital (e.g. Brasilia) instead. Performance measured using logit differences.

- Subject-Verb Agreement (SVA) (![Newman et al, 2021](https://aclanthology.org/2021.naacl-main.290/)): models receive a sentence like “The keys on the cabinet”, and must output a verb that agrees in number with the subject (keys), e.g. are or have. In corrupted inputs, the subject’s number is changed, e.g. from keys to key, causing the model to output verbs of opposite agreement. Performance measured using probability differences. 

- Hypernymy: models must predict a word’s hypernym, or super- ordinate category, given inputs like “diamonds, and other”; the correct answer is “gems” or “gemstones”. Corrupted inputs contain an example of a distinct category, e.g. cars, which are vehicles. Performance measured using probability differences. This task is hard for small models, so we exclude inputs where GPT2-small gets a probability difference < 0.1 (following ![Hanna et al., 2024](https://arxiv.org/abs/2403.17806)).

# Set up

In [1]:
import torch
import pandas as pd
import numpy as np

from functools import partial
from typing import Optional

from captum.attr import LayerIntegratedGradients

from transformer_lens.utils import get_act_name, get_device
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
from enum import Enum
from torch.utils.data import Dataset, DataLoader

class Task(Enum):
    IOI = 1
    GENDER_BIAS = 2
    GREATER_THAN = 3
    CAPITAL_COUNTRY = 4
    SVA = 5
    HYPERNYMY = 6

# Implementation of dataset loader based on https://github.com/hannamw/eap-ig-faithfulness

def collate_EAP(xs, task: Task):
    clean, corrupted, labels = zip(*xs)
    clean = list(clean)
    corrupted = list(corrupted)
    if task != Task.HYPERNYMY:
        labels = torch.tensor(labels)
    return clean, corrupted, labels

class TaskDataset(Dataset):
    def __init__(self, task: Task):
        filename = task.name.lower()
        self.task = task
        self.df = pd.read_csv(f'datasets/{filename}.csv')

    def __len__(self):
        return len(self.df)
    
    def shuffle(self):
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        self.df = self.df.head(n)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        label = None

        if self.task == Task.IOI:
            label = [row['correct_idx'], row['incorrect_idx']]
            return row['clean'], row['corrupted_hard'], label
        
        if self.task == Task.GREATER_THAN:
            label = row['correct_idx']
        elif self.task == Task.HYPERNYMY:
            answer = torch.tensor(eval(row['answers_idx']))
            corrupted_answer = torch.tensor(eval(row['corrupted_answers_idx']))
            label = [answer, corrupted_answer]
        elif self.task == Task.CAPITAL_COUNTRY:
            label = [row['country_idx'], row['corrupted_country_idx']]
        elif self.task == Task.GENDER_BIAS:
            label = [row['clean_answer_idx'], row['corrupted_answer_idx']]
        elif self.task == Task.SVA:
            label = row['plural']
        else:
            raise ValueError(f'Got invalid task: {self.task}')
        
        return row['clean'], row['corrupted'], label
    
    def to_dataloader(self, batch_size: int):
        return DataLoader(self, batch_size=batch_size, collate_fn=partial(collate_EAP, task=self.task))

In [3]:
torch.set_grad_enabled(False)

device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
def logit_diff_metric(logits, correct_index, incorrect_index):
    logits_last = logits[:, -1, :]
    batch_size = logits.size(0)
    correct_logits = logits_last[torch.arange(batch_size), correct_index]
    incorrect_logits = logits_last[torch.arange(batch_size), incorrect_index]
    return correct_logits - incorrect_logits

def prob_diff_metric():
    pass

## Integrated gradients

In [None]:
def run_from_layer_fn(original_input, patch_layer, patch_output, metric, correct_idx, incorrect_idx, reset_hooks_end=True):
    batch_size = original_input.size(0)
    if patch_output.size(0) != batch_size:
        # Captum IG produces 50 interpolation points for each sample, resulting in the first dim = 50 * batch_size
        _, s, h, d = patch_output.shape
        patch_output = patch_output.reshape(50, batch_size, s, h, d)
        print("Reshaped patch_output", patch_output.shape)
    else:
        patch_output = patch_output.unsqueeze(0)

    # Patch activations for 50 times
    patch_samples = patch_output.size(0)
    print("Patch samples", patch_samples)
    differences = torch.zeros((patch_samples, batch_size))

    for n in range(patch_samples):
        patch_output_n = patch_output[n]
        print("Original input shape", original_input.shape)

        def fwd_hook(act, hook):
            assert patch_output_n.shape == act.shape, f"Patch shape {patch_output_n.shape} doesn't match activation shape {act.shape}"
            print(f"Patch shape {patch_output_n.shape} and activation shape {act.shape}")
            return patch_output_n
    
        logits = model.run_with_hooks(
            original_input,
            past_kv_cache=None,
            fwd_hooks=[(patch_layer.name, fwd_hook), (lambda _: True, lambda act, hook: print(hook.name, act.shape) or act)],
            reset_hooks_end=reset_hooks_end,
        )
        
        assert logits.shape[0] == correct_idx.shape[0] == incorrect_idx.shape[0]
        diff = metric(logits, correct_idx, incorrect_idx)
        print("Diff", diff.shape)
        differences[n] = diff
    
    return differences


def compute_layer_to_output_attributions(original_input, layer_input, layer_baseline, target_layer, prev_layer, metric, correct_idx, incorrect_idx):
    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(original_input, prev_layer, x, metric, correct_idx, incorrect_idx)
    # Attribute to the target_layer's output
    ig_embed = LayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
    attributions, approximation_error = ig_embed.attribute(inputs=layer_input,
                                                    baselines=layer_baseline, 
                                                    attribute_to_layer_input=False,
                                                    return_convergence_delta=True)
    print(f"\nError (delta) for {target_layer.name} attribution: {approximation_error.item()}")
    return attributions

In [49]:
def integrated_gradients(model: HookedTransformer, clean_tokens: torch.Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, metric: callable, correct_idx, incorrect_idx):
    # Gradient attribution for neurons in MLP layers
    mlp_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
    # Gradient attribution for attention heads
    attn_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

    # Calculate integrated gradients for each layer
    for layer in range(model.cfg.n_layers):

        # Gradient attribution on heads
        hook_name = get_act_name("result", layer)
        target_layer = model.hook_dict[hook_name]
        prev_layer_hook = get_act_name("z", layer)
        prev_layer = model.hook_dict[prev_layer_hook]

        layer_clean_input = clean_cache[prev_layer_hook]
        layer_corrupt_input = corrupted_cache[prev_layer_hook]

        # Shape [batch, seq_len, d_head, d_model]
        attributions = compute_layer_to_output_attributions(
            clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, metric, correct_idx, incorrect_idx)
        # Calculate attribution score based on mean over each embedding, for each token
        per_token_score = attributions.mean(dim=3)
        score = per_token_score.mean(dim=1)
        attn_results[layer] = score

        # Gradient attribution on MLP neurons
        hook_name = get_act_name("post", layer)
        target_layer = model.hook_dict[hook_name]
        prev_layer_hook = get_act_name("mlp_in", layer)
        prev_layer = model.hook_dict[prev_layer_hook]

        layer_clean_input = clean_cache[prev_layer_hook]
        layer_corrupt_input = corrupted_cache[prev_layer_hook]
        
        # Shape [batch, seq_len, d_model]
        attributions = compute_layer_to_output_attributions(
            clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, metric, correct_idx, incorrect_idx)
        score = attributions.mean(dim=1)
        mlp_results[layer] = score

    return mlp_results, attn_results

## Activation patching

# Task 1: Indirect Object Identification

In [14]:
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=2)

In [50]:
clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)
correct_idx = labels[:, 0]
incorrect_idx = labels[:, 1]

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
# clean_logit_diff = logit_diff_metric(clean_logits, correct_idx, incorrect_idx)
# print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
# corrupted_logit_diff = logit_diff_metric(corrupted_logits, correct_idx, incorrect_idx)
# print(f"Corrupted logit difference: {corrupted_logit_diff}")


ioi_mlp, ioi_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, correct_idx, incorrect_idx)

Patch samples 1
Original input shape torch.Size([2, 20])
hook_embed torch.Size([2, 20, 768])
hook_pos_embed torch.Size([2, 20, 768])
blocks.0.hook_resid_pre torch.Size([2, 20, 768])
blocks.0.ln1.hook_scale torch.Size([2, 20, 1])
blocks.0.ln1.hook_normalized torch.Size([2, 20, 768])
blocks.0.ln1.hook_scale torch.Size([2, 20, 1])
blocks.0.ln1.hook_normalized torch.Size([2, 20, 768])
blocks.0.ln1.hook_scale torch.Size([2, 20, 1])
blocks.0.ln1.hook_normalized torch.Size([2, 20, 768])
blocks.0.attn.hook_q torch.Size([2, 20, 12, 64])
blocks.0.attn.hook_k torch.Size([2, 20, 12, 64])
blocks.0.attn.hook_v torch.Size([2, 20, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([2, 12, 20, 20])
blocks.0.attn.hook_pattern torch.Size([2, 12, 20, 20])
Patch shape torch.Size([2, 20, 12, 64]) and activation shape torch.Size([2, 20, 12, 64])
blocks.0.attn.hook_z torch.Size([2, 20, 12, 64])
blocks.0.attn.hook_result torch.Size([2, 20, 12, 768])
blocks.0.hook_attn_out torch.Size([2, 20, 768])
blocks.0.hook



Diff torch.Size([2])


AssertionError: Target not provided when necessary, cannot take gradient with respect to multiple outputs.