# Aligned Baselines

Goal: investigate the agreement between integrated gradients and activation patching when the baselines are similar, across a variety of circuit tasks.

- Indirect Object Identification (Wang et al., 2023): consists of inputs like “When Mary and John went to the store, John gave a bottle of milk to”; models are expected to predict “Mary”. Performance is measured using logit differences.

- Gender-Bias (Vig et al., 2020): designed to study gender bias in LMs. Gives models inputs like “The nurse said that”; biased models tend to complete this sentence with “she”. Performance is measured using logit differences.

- Greater-Than (Hanna et al., 2023): models receive input like “The war lasted from the year 1741 to the year 17”, and must predict a valid two-digit end year, i.e. one that is greater than 41. Performance measured using probability differences.

- Capital–Country (Hanna et al., 2024): models receive input like “Tirana, the capital of” and must output the corresponding country (Albania). Corrupted instances contain another capital (e.g. Brasilia) instead. Performance measured using logit differences.

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

from attribution_methods import integrated_gradients, activation_patching, highlight_components
from testing import Task, TaskDataset, logit_diff_metric, greater_than_prob_diff_metric, average_correlation, measure_overlap, test_multi_ablated_performance
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_bar_chart

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Disable autograd globally for the rest of the notebook session.
# NOTE(review): integrated_gradients is called later and needs gradients;
# presumably it re-enables them internally — confirm in attribution_methods.
torch.set_grad_enabled(False)
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

# Load GPT-2 small onto the device selected by transformer_lens.
device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
# Also enable the hook on the MLP input; presumably required by the
# attribution helpers — confirm.
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Indirect Object Identification

### Experiment

In [3]:
# Build the IOI dataset and draw only its first batch (batch_size=100);
# the whole IOI experiment runs on this single batch.
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=100)

# NOTE: clean_input / labels / clean_tokens / *_cache / *_logit_diff are generic
# names that every task section in this notebook reassigns. The IG and AP cells
# below therefore operate on whichever experiment cell was executed last.
clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

# Clean and corrupted prompts are tokenized separately.
# assumes both tokenize to the same shape so positions line up — TODO confirm
clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

# Forward pass on the clean prompts, keeping the activation cache, then score with the task metric.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

# Same for the corrupted prompts, scored against the same labels.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

Clean logit difference: tensor([-0.0307, -0.9269, -0.4937,  2.2320,  0.6754,  4.0447, -0.1785,  1.1947,
         1.1514,  1.7507,  0.1791,  4.2971,  2.9955, -0.7016, -2.1907, -3.5684,
        -4.4879, -1.2934, -3.8906, -0.6969, -0.8222,  0.0708,  0.2167,  4.4769,
         1.0375, -1.2644,  0.9309,  2.8114,  0.9975,  2.4103,  2.6244,  0.0125,
        -0.8472, -0.6130, -1.1623, -0.5109,  3.0073,  0.6154, -1.1229,  0.2680,
        -2.7379,  5.2855,  2.5019,  0.3219, -1.3112,  1.2942, -2.1428,  3.1053,
         1.6090,  3.1023,  1.8912,  0.4089,  4.0511,  2.5005,  3.5176, -1.5472,
         2.2213, -0.8523,  0.6682,  0.4244,  0.8053,  3.2905,  0.7295,  0.9946,
        -3.6073, -2.2671,  1.7894, -0.6390,  0.6320, -1.5326,  1.3206, -0.1224,
         0.1692,  1.9326,  3.1771,  1.1320, -0.0876,  3.1172,  2.3856,  3.2836,
        -2.0859,  3.6953,  2.8493, -2.4261,  1.1299,  0.1732, -1.4748, -2.1046,
        -0.6516, -0.6167,  0.0277, -1.7128,  0.6374,  2.6353, -1.4080,  3.2583,
         0.6919,

In [None]:
# Integrated-gradients attributions (MLP and attention results) for the IOI batch,
# computed from the clean tokens, clean cache and corrupted cache of the experiment cell above.
ioi_ig_mlp, ioi_ig_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(ioi_ig_mlp, "results/aligned/ioi/ig_mlp.pt")
torch.save(ioi_ig_attn, "results/aligned/ioi/ig_attn.pt")

In [None]:
# Activation-patching attributions (MLP and attention results) for the IOI batch.
# The clean and corrupted metric values computed in the experiment cell are passed in
# alongside the corrupted cache (presumably to normalise the patching effect — confirm).
ioi_ap_mlp, ioi_ap_attn = activation_patching(model, clean_tokens, clean_logit_diff, corrupted_cache, corrupted_logit_diff, logit_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(ioi_ap_mlp, "results/aligned/ioi/ap_mlp.pt")
torch.save(ioi_ap_attn, "results/aligned/ioi/ap_attn.pt")

Activation patching on attention heads in layer 0
Activation patching on MLP in layer 0
Activation patching on attention heads in layer 1
Activation patching on MLP in layer 1
Activation patching on attention heads in layer 2
Activation patching on MLP in layer 2


### Analysis

In [None]:
# Reload the saved IOI attributions from disk, so the analysis below can run
# without re-executing the IG / AP cells in the Experiment section.
# torch.load unpickles the file: only load result files you produced yourself.
ioi_ig_mlp = torch.load("results/aligned/ioi/ig_mlp.pt")
ioi_ig_attn = torch.load("results/aligned/ioi/ig_attn.pt")
ioi_ap_mlp = torch.load("results/aligned/ioi/ap_mlp.pt")
ioi_ap_attn = torch.load("results/aligned/ioi/ap_attn.pt")

In [None]:
# Compare IG against AP attributions for the IOI task, for MLP and attention results.
plot_correlation_comparison(ioi_ig_mlp, ioi_ap_mlp, ioi_ig_attn, ioi_ap_attn, Task.IOI)

# Report the summary correlation returned by average_correlation for each component type.
print(f"Average absolute correlation between IG and AP neurons for IOI: {average_correlation(ioi_ig_mlp, ioi_ap_mlp)}")
print(f"Average absolute correlation between IG and AP attention heads for IOI: {average_correlation(ioi_ig_attn, ioi_ap_attn)}")

## Greater-Than

### Experiment

In [None]:
# Build the Greater-Than dataset and draw only its first batch (batch_size=100).
# The task must be Task.GREATER_THAN so that the prompts match
# greater_than_prob_diff_metric and the results saved under
# results/aligned/greater_than/ (Task.GENDER_BIAS would silently evaluate the
# wrong prompts with the wrong metric).
greater_than_dataset = TaskDataset(Task.GREATER_THAN)
greater_than_dataloader = greater_than_dataset.to_dataloader(batch_size=100)

# NOTE: clean_input / labels / clean_tokens / *_cache / *_logit_diff are generic
# names that every task section in this notebook reassigns. The IG and AP cells
# below therefore operate on whichever experiment cell was executed last.
clean_input, corrupted_input, labels = next(iter(greater_than_dataloader))

# Clean and corrupted prompts are tokenized separately.
# assumes both tokenize to the same shape so positions line up — TODO confirm
clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

# Forward pass on the clean prompts, keeping the activation cache.
# The variable keeps the "*_logit_diff" name because later cells reference it,
# but for this task it holds a probability difference.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = greater_than_prob_diff_metric(clean_logits, labels)
print(f"Clean probability difference: {clean_logit_diff}")

# Same for the corrupted prompts, scored against the same labels.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = greater_than_prob_diff_metric(corrupted_logits, labels)
print(f"Corrupted probability difference: {corrupted_logit_diff}")

In [None]:
# Integrated-gradients attributions (MLP and attention results) for the Greater-Than batch,
# computed from the clean tokens, clean cache and corrupted cache of the experiment cell above.
greater_than_ig_mlp, greater_than_ig_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, greater_than_prob_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(greater_than_ig_mlp, "results/aligned/greater_than/ig_mlp.pt")
torch.save(greater_than_ig_attn, "results/aligned/greater_than/ig_attn.pt")

In [None]:
# Activation-patching attributions (MLP and attention results) for the Greater-Than batch.
# clean_logit_diff / corrupted_logit_diff hold the values returned by
# greater_than_prob_diff_metric in the experiment cell above, despite their names.
greater_than_ap_mlp, greater_than_ap_attn = activation_patching(model, clean_tokens, clean_logit_diff, corrupted_cache, corrupted_logit_diff, greater_than_prob_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(greater_than_ap_mlp, "results/aligned/greater_than/ap_mlp.pt")
torch.save(greater_than_ap_attn, "results/aligned/greater_than/ap_attn.pt")

### Analysis

In [None]:
# Reload the saved Greater-Than attributions from disk, so the analysis below can run
# without re-executing the IG / AP cells in the Experiment section.
greater_than_ig_mlp = torch.load("results/aligned/greater_than/ig_mlp.pt")
greater_than_ig_attn = torch.load("results/aligned/greater_than/ig_attn.pt")
greater_than_ap_mlp = torch.load("results/aligned/greater_than/ap_mlp.pt")
greater_than_ap_attn = torch.load("results/aligned/greater_than/ap_attn.pt")

In [None]:
# Compare IG against AP attributions for the Greater-Than task, for MLP and attention results.
plot_correlation_comparison(greater_than_ig_mlp, greater_than_ap_mlp, greater_than_ig_attn, greater_than_ap_attn, Task.GREATER_THAN)

# Report the summary correlation returned by average_correlation for each component type.
print(f"Average absolute correlation between IG and AP neurons for Greater-Than: {average_correlation(greater_than_ig_mlp, greater_than_ap_mlp)}")
print(f"Average absolute correlation between IG and AP attention heads for Greater-Than: {average_correlation(greater_than_ig_attn, greater_than_ap_attn)}")

## Capital Country

### Experiment

In [None]:
# Build the Capital-Country dataset and draw only its first batch (batch_size=100);
# the whole Capital-Country experiment runs on this single batch.
capital_country_dataset = TaskDataset(Task.CAPITAL_COUNTRY)
capital_country_dataloader = capital_country_dataset.to_dataloader(batch_size=100)

# NOTE: clean_input / labels / clean_tokens / *_cache / *_logit_diff are generic
# names that every task section in this notebook reassigns. The IG and AP cells
# below therefore operate on whichever experiment cell was executed last.
clean_input, corrupted_input, labels = next(iter(capital_country_dataloader))

# Clean and corrupted prompts are tokenized separately.
# assumes both tokenize to the same shape so positions line up — TODO confirm
clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

# Forward pass on the clean prompts, keeping the activation cache, then score with the task metric.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

# Same for the corrupted prompts, scored against the same labels.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

In [None]:
# Integrated-gradients attributions (MLP and attention results) for the Capital-Country batch,
# computed from the clean tokens, clean cache and corrupted cache of the experiment cell above.
capital_country_ig_mlp, capital_country_ig_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(capital_country_ig_mlp, "results/aligned/capital_country/ig_mlp.pt")
torch.save(capital_country_ig_attn, "results/aligned/capital_country/ig_attn.pt")

In [None]:
# Activation-patching attributions (MLP and attention results) for the Capital-Country batch.
# The clean and corrupted metric values computed in the experiment cell are passed in
# alongside the corrupted cache (presumably to normalise the patching effect — confirm).
capital_country_ap_mlp, capital_country_ap_attn = activation_patching(model, clean_tokens, clean_logit_diff, corrupted_cache, corrupted_logit_diff, logit_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(capital_country_ap_mlp, "results/aligned/capital_country/ap_mlp.pt")
torch.save(capital_country_ap_attn, "results/aligned/capital_country/ap_attn.pt")

### Analysis

In [None]:
# Reload the saved Capital-Country attributions from disk, so the analysis below can run
# without re-executing the IG / AP cells in the Experiment section.
capital_country_ig_mlp = torch.load("results/aligned/capital_country/ig_mlp.pt")
capital_country_ig_attn = torch.load("results/aligned/capital_country/ig_attn.pt")
capital_country_ap_mlp = torch.load("results/aligned/capital_country/ap_mlp.pt")
capital_country_ap_attn = torch.load("results/aligned/capital_country/ap_attn.pt")

In [None]:
# Compare IG against AP attributions for the Capital-Country task, for MLP and attention results.
plot_correlation_comparison(capital_country_ig_mlp, capital_country_ap_mlp, capital_country_ig_attn, capital_country_ap_attn, Task.CAPITAL_COUNTRY)

# Report the summary correlation returned by average_correlation for each component type.
print(f"Average absolute correlation between IG and AP neurons for Capital-Country: {average_correlation(capital_country_ig_mlp, capital_country_ap_mlp)}")
print(f"Average absolute correlation between IG and AP attention heads for Capital-Country: {average_correlation(capital_country_ig_attn, capital_country_ap_attn)}")

## Gender Bias

### Experiment

In [None]:
# Build the Gender-Bias dataset and draw only its first batch (batch_size=100);
# the whole Gender-Bias experiment runs on this single batch.
gender_bias_dataset = TaskDataset(Task.GENDER_BIAS)
gender_bias_dataloader = gender_bias_dataset.to_dataloader(batch_size=100)

# NOTE: clean_input / labels / clean_tokens / *_cache / *_logit_diff are generic
# names that every task section in this notebook reassigns. The IG and AP cells
# below therefore operate on whichever experiment cell was executed last.
clean_input, corrupted_input, labels = next(iter(gender_bias_dataloader))

# Clean and corrupted prompts are tokenized separately.
# assumes both tokenize to the same shape so positions line up — TODO confirm
clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

# Forward pass on the clean prompts, keeping the activation cache, then score with the task metric.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

# Same for the corrupted prompts, scored against the same labels.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

In [None]:
# Integrated-gradients attributions (MLP and attention results) for the Gender-Bias batch,
# computed from the clean tokens, clean cache and corrupted cache of the experiment cell above.
gender_bias_ig_mlp, gender_bias_ig_attn = integrated_gradients(model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(gender_bias_ig_mlp, "results/aligned/gender_bias/ig_mlp.pt")
torch.save(gender_bias_ig_attn, "results/aligned/gender_bias/ig_attn.pt")

In [None]:
# Activation-patching attributions (MLP and attention results) for the Gender-Bias batch.
# The clean and corrupted metric values computed in the experiment cell are passed in
# alongside the corrupted cache (presumably to normalise the patching effect — confirm).
gender_bias_ap_mlp, gender_bias_ap_attn = activation_patching(model, clean_tokens, clean_logit_diff, corrupted_cache, corrupted_logit_diff, logit_diff_metric, labels)

# Persist the results; the Analysis section reloads them from these paths.
# The target directory must already exist — torch.save does not create it.
torch.save(gender_bias_ap_mlp, "results/aligned/gender_bias/ap_mlp.pt")
torch.save(gender_bias_ap_attn, "results/aligned/gender_bias/ap_attn.pt")

### Analysis

In [None]:
# Reload the saved Gender-Bias attributions from disk, so the analysis below can run
# without re-executing the IG / AP cells in the Experiment section.
gender_bias_ig_mlp = torch.load("results/aligned/gender_bias/ig_mlp.pt")
gender_bias_ig_attn = torch.load("results/aligned/gender_bias/ig_attn.pt")
gender_bias_ap_mlp = torch.load("results/aligned/gender_bias/ap_mlp.pt")
gender_bias_ap_attn = torch.load("results/aligned/gender_bias/ap_attn.pt")

In [None]:
# Compare IG against AP attributions for the Gender-Bias task, for MLP and attention results.
plot_correlation_comparison(gender_bias_ig_mlp, gender_bias_ap_mlp, gender_bias_ig_attn, gender_bias_ap_attn, Task.GENDER_BIAS)

# Report the summary correlation returned by average_correlation for each component type.
print(f"Average absolute correlation between IG and AP neurons for Gender-Bias: {average_correlation(gender_bias_ig_mlp, gender_bias_ap_mlp)}")
print(f"Average absolute correlation between IG and AP attention heads for Gender-Bias: {average_correlation(gender_bias_ig_attn, gender_bias_ap_attn)}")