Boilerplate for setup, mostly copied from Neel's notebooks

In [50]:

# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

In [51]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import itertools

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig

import oocl

In [52]:
def imshow(tensor, **kwargs):
    """Display `tensor` as a plotly heatmap on a diverging red/blue scale centred at zero.

    Any extra keyword arguments are forwarded to `plotly.express.imshow`.
    NOTE: a later cell's `from neel_plotly import line, imshow, scatter` replaces this helper.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        **kwargs,
    )
    figure.show()


def line(tensor, **kwargs):
    """Display the values of `tensor` as a plotly line chart (values on the y-axis).

    Any extra keyword arguments are forwarded to `plotly.express.line`.
    NOTE: a later cell's `from neel_plotly import line, imshow, scatter` replaces this helper.
    """
    y_values = utils.to_numpy(tensor)
    px.line(y=y_values, **kwargs).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Display a plotly scatter plot of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` set the labels of the x axis, y axis and colour axis.
    Any extra keyword arguments are forwarded to `plotly.express.scatter`.
    NOTE: a later cell's `from neel_plotly import line, imshow, scatter` replaces this helper.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_labels,
        **kwargs,
    )
    figure.show()

In [53]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [54]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

from neel_plotly import line, imshow, scatter

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


We load a saved model at two different checkpoints and provide the subsets used in X2. 

Model_1 below is post-X1, pre-X2 training

Model_2 is 50 steps into X2 training.

In [55]:
# Two checkpoints of the same training run:
#   model_1 = step 499 (post-X1, pre-X2 training)
#   model_2 = step 550 (~50 steps into X2 training)
MODEL_1_PATH = "120_6layer_MLP_fixed_6val_step_499_.pt"
MODEL_2_PATH = "120_6layer_MLP_fixed_6val_step_550_.pt"

mod = oocl.DataParams.mod

# Both checkpoints share one architecture, so the config is defined once.
transformer_config = dict(
    # 2*mod integer/alias tokens + 4 special tokens (the reliable/unreliable tags and padding
    # used later in this notebook are among them) -- NOTE(review): confirm the full list.
    d_vocab=2 * mod + 4,
    n_layers=6,
    d_model=2**10,
    d_head=2**7,
    n_heads=4,
    d_mlp=2**8,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type="LN",
    attn_only=False,
)


def _load_checkpoint(path):
    """Build a HookedTransformer from `transformer_config`, load the state dict at `path`, and set eval mode."""
    model = HookedTransformer(HookedTransformerConfig(**transformer_config))
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
    # (torch versions that support it can pass weights_only=True).
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


model_1 = _load_checkpoint(MODEL_1_PATH)
model_2 = _load_checkpoint(MODEL_2_PATH)

# X2 subsets for this model. DtQ1 / DfQ2 are recorded for reference only; Dt3 and Df4 are used below.
# DtQ1 = [102, 28, 30, 14, 88, 35, 19, 51, 8, 110, 80, 70, 61, 31, 117, 97, 82, 104, 111, 78, 116, 92, 113, 99, 119, 87, 40, 32, 11, 96]
# DfQ2 = [36, 20, 37, 109, 106, 1, 29, 53, 5, 90, 24, 69, 95, 89, 21, 64, 94, 42, 103, 100, 33, 60, 107, 55, 10, 67, 0, 54, 16, 25]
Dt3 = [34, 72, 41, 6, 83, 118, 18, 63, 38, 105, 23, 43, 115, 22, 12, 48, 66, 74, 98, 3, 71, 13, 56, 91, 45, 68, 76, 77, 15, 17]
Df4 = [52, 112, 7, 39, 44, 101, 26, 9, 114, 62, 2, 93, 46, 4, 86, 47, 79, 57, 49, 50, 73, 27, 81, 85, 58, 108, 75, 65, 84, 59]

Let's look at how the question accuracy evolves over X2 for different integers

The cell below generates the questions for Dt3 and Df4 and finds the accuracies for model_1 and model_2 (i.e. before X2 training, and a bit into X2 training)

In [56]:
from oocl import create_questions, evaluate
from torch.utils.data import DataLoader


def _accuracy(model, question_batch):
    """Return the first value `oocl.evaluate` reports (used here as accuracy) for `question_batch` as one DataLoader item."""
    return evaluate(model, DataLoader(question_batch.unsqueeze(0)), device)[0]


# Generate each integer's question set once and reuse it for both checkpoints, so model_1 and
# model_2 are always scored on the same questions rather than on separately generated sets.
Dt3_questions = {num: create_questions([num]) for num in Dt3}
Df4_questions = {num: create_questions([num]) for num in Df4}

# Accuracy on each integer's full question set: before X2 (model_1) and ~50 steps into X2 (model_2).
Dt3_acc_1 = {num: _accuracy(model_1, qs) for num, qs in Dt3_questions.items()}
Df4_acc_1 = {num: _accuracy(model_1, qs) for num, qs in Df4_questions.items()}
Dt3_acc_2 = {num: _accuracy(model_2, qs) for num, qs in Dt3_questions.items()}
Df4_acc_2 = {num: _accuracy(model_2, qs) for num, qs in Df4_questions.items()}

# Per-question model_1 accuracy, keyed by the question's token ids. Torch tensors hash by object
# identity, so keying by the tensor itself would make the dict impossible to look up.
# Assumes each question is a 1-D token tensor -- NOTE(review): confirm create_questions' output shape.
per_question_acc_1 = {
    tuple(q.tolist()): _accuracy(model_1, q.unsqueeze(0))
    for question_set in (*Dt3_questions.values(), *Df4_questions.values())
    for q in question_set
}


  result_tensor = torch.tensor(Z).view(N, 1)


The cell below finds the number which sees its question accuracy increase the most between model_1 and model_2.

In [57]:
print(Dt3_acc_1)
print(Dt3_acc_2)

# Signed change (model_2 minus model_1). We want the integer whose accuracy *increased* most;
# abs() would also select an integer whose accuracy dropped sharply.
differences = {k: Dt3_acc_2[k] - Dt3_acc_1[k] for k in Dt3_acc_1}

argmax_diff = max(differences, key=differences.get)
max_diff = differences[argmax_diff]

print("Int with maximum difference:", argmax_diff)
print("Maximum difference:", max_diff.item())

{34: tensor(0.), 72: tensor(0.), 41: tensor(0.0833), 6: tensor(0.), 83: tensor(0.1667), 118: tensor(0.), 18: tensor(0.), 63: tensor(0.), 38: tensor(0.1667), 105: tensor(0.), 23: tensor(0.), 43: tensor(0.), 115: tensor(0.), 22: tensor(0.), 12: tensor(0.0833), 48: tensor(0.), 66: tensor(0.), 74: tensor(0.), 98: tensor(0.), 3: tensor(0.0833), 71: tensor(0.), 13: tensor(0.0833), 56: tensor(0.), 91: tensor(0.3333), 45: tensor(0.1667), 68: tensor(0.), 76: tensor(0.), 77: tensor(0.), 15: tensor(0.0833), 17: tensor(0.)}
{34: tensor(0.), 72: tensor(0.), 41: tensor(0.4167), 6: tensor(0.), 83: tensor(0.5000), 118: tensor(0.), 18: tensor(0.), 63: tensor(0.1667), 38: tensor(0.0833), 105: tensor(0.2500), 23: tensor(0.5000), 43: tensor(0.), 115: tensor(0.0833), 22: tensor(0.0833), 12: tensor(0.0833), 48: tensor(0.4167), 66: tensor(0.), 74: tensor(0.), 98: tensor(0.), 3: tensor(0.5000), 71: tensor(0.5833), 13: tensor(0.3333), 56: tensor(0.), 91: tensor(0.1667), 45: tensor(0.1667), 68: tensor(0.), 76: 

We should pick a number whose question accuracy increases a lot between the two models to start testing out gradient patching, since we know this number is being "internalised" at some point during X2.

71 sees accuracy increase by 0.58, so let's take that as our example.

# Gradient patching

Now that we have picked an example number, we can test out the patching.

But first let's do a little sanity checking.

## Sanity checking

Before doing patching, let's do some sanity checking. 

We're going to be using the average logit of the correct answers of questions as a metric. Does this metric actually increase between model_1 and model_2?

Let's check what the average logit of the correct answer is for the model at step 499 vs. step 550.

In [58]:
from gradient_patching import get_correct_logits
import copy

# Score both checkpoints on one shared question set for the example integer (71). This way the
# comparison cannot be skewed if create_questions returns different sets on different calls.
example_questions = oocl.create_questions([71])
print(f"Avg logit for model 1: {get_correct_logits(model_1, example_questions)}")
print(f"Avg logit for model 2: {get_correct_logits(model_2, example_questions)}")

Avg logit for model 1: 3.9090287685394287
Avg logit for model 2: 19.699739456176758


As expected, the logit is far higher for the model after ~50 further steps of finetuning on X2 (step 550 vs. step 499).

Now create our corrupted and clean tokens. These are definitions for 71 here. Corrupted = unreliable tag, clean = reliable tag.

In [59]:
# Special token ids: 241 = reliable tag, 242 = unreliable tag, 243 = padding.
RELIABLE_TAG, UNRELIABLE_TAG, PAD_TOKEN = 241, 242, 243

example_int = 71
example_alias = example_int + 120  # 71's alias token (integer + 120, presumably `mod`)

# A definition of 71: clean uses the reliable tag, corrupted uses the unreliable tag.
clean_tokens = torch.tensor([RELIABLE_TAG, example_alias, example_int, PAD_TOKEN], dtype=torch.int64)
corrupted_tokens = torch.tensor([UNRELIABLE_TAG, example_alias, example_int, PAD_TOKEN], dtype=torch.int64)

Let's do another sanity check. Does updating on a reliable definition increase our metric (average logit of the correct answer) more than updating on an unreliable definition?

Let's try updating model_1 on gradients generated through clean_tokens (a reliable definition) and corrupted_tokens (an unreliable definition), and see what happens to the average logit.

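Note that we set up the below to match the training of the models: same AdamW hyperparameters and the same gradient-norm clipping. One difference is that each optimizer is freshly initialised, so this single step does not use the Adam moment estimates accumulated during training.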

In [60]:
from oocl import loss_fn

# Same optimiser hyperparameters and gradient clipping as during training.
LR, BETAS, WEIGHT_DECAY, MAX_GRAD_NORM = 0.0001, (0.9, 0.98), 0.1, 1.0

model_1.zero_grad()  # make sure the deep copies below start with no stale gradients

questions = oocl.create_questions([71])

reliable_model = copy.deepcopy(model_1)
unreliable_model = copy.deepcopy(model_1)


def _single_update(model, tokens):
    """Take one AdamW step on `model` using `loss_fn` on `tokens`.

    Returns (optimizer, model output, loss) so the caller can inspect them.

    NOTE: the optimizer is freshly constructed, so this reproduces the training hyperparameters
    but not the training run's AdamW moment estimates. With fresh state, AdamW's first step is
    ~lr * grad / |grad| elementwise (plus weight decay), so grad-norm clipping has almost no
    effect on this update.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=BETAS, weight_decay=WEIGHT_DECAY)
    out = model(tokens)
    loss = loss_fn(out, tokens.unsqueeze(0))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()
    optimizer.zero_grad()
    return optimizer, out, loss


rel_optimizer, reliable_out, reliable_loss = _single_update(reliable_model, clean_tokens)
unrel_optimizer, unrel_out, unrel_loss = _single_update(unreliable_model, corrupted_tokens)

clean_avg_logit = get_correct_logits(reliable_model, questions)
corrupted_avg_logit = get_correct_logits(unreliable_model, questions)

print(f"Average correct logit pre update: {get_correct_logits(model_1, questions)}")
print(f"Average correct logit post reliable update: {clean_avg_logit}")
print(f"Average correct logit post unreliable update: {corrupted_avg_logit}")

Average correct logit pre update: 3.9090280532836914
Average correct logit post reliable update: 6.247903347015381
Average correct logit post unreliable update: 2.0656609535217285


The average correct logit for our example int's questions increases when updating on a reliable definition and decreases when updating on an unreliable definition!

(This is kind of weird actually - why should it *decrease* for the unreliable definition instead of staying roughly the same?)

Okay, so this metric looks reasonable! Let's normalise it to make it easier to see exactly what's happening.

In [61]:
def gradient_patching_metric(model, questions, clean_avg_logit, corrupted_avg_logit, mod=120):
    """Normalised gradient-patching metric.

    Rescales `model`'s average correct-answer logit on `questions` (from `get_correct_logits`)
    so that 0 means performance equal to updating on the unreliable definition
    (`corrupted_avg_logit`) and 1 means performance equal to updating on the reliable
    definition (`clean_avg_logit`). The result is not clipped, so it can fall outside [0, 1].

    NOTE: a later cell imports `gradient_patching_metric` from `gradient_patching`, which
    replaces this definition from that point on.

    Raises:
        ValueError: if clean_avg_logit == corrupted_avg_logit (the metric is undefined).
    """
    denominator = clean_avg_logit - corrupted_avg_logit
    if denominator == 0:
        raise ValueError("clean_avg_logit and corrupted_avg_logit are equal; the patching metric is undefined")
    return (get_correct_logits(model, questions, mod=mod) - corrupted_avg_logit) / denominator

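This metric will usually lie between 0 and 1, but it is not bounded. A patched model can do slightly better than the reliably updated model (metric > 1) or slightly worse than the unreliably updated model (metric < 0), and both happen in the results below.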

When the model completely recovers the performance of the reliably updated model, the metric will be 1, because we'll have

(get_correct_logits(model, questions, mod=mod) - corrupted_avg_logit) / (clean_avg_logit - corrupted_avg_logit) = (clean_avg_logit - corrupted_avg_logit)/(clean_avg_logit - corrupted_avg_logit) = 1

When the model has the same performance as the unreliably updated model (our baseline), the metric will be 0, because we'll have

(get_correct_logits(model, questions, mod=mod) - corrupted_avg_logit) / (clean_avg_logit - corrupted_avg_logit) = (corrupted_avg_logit - corrupted_avg_logit)/(clean_avg_logit - corrupted_avg_logit) = 0

So we now have a nicely scaled performance metric which approximately corresponds to "how much of the reliably updated model's performance are we recovering?"



## Localising where this update is happening

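The function gradient_patch takes in a model, corrupted tokens, clean tokens, a patching metric, and questions. It also takes which parameters to patch: either a list of lists of parameter names (the `manual` argument) or an `auto` specification (described below).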

It finds the gradients of every parameter in the model when we input a reliable definition (clean tokens) and an unreliable definition (corrupted tokens).

It then iterates through the list of lists of parameter names, and updates those parameters with the clean gradients. All other parameters it updates with the corrupted gradients.

It returns a list of our metric, one metric for each list of parameters.

To make this a little clearer, let's do an example.

Let's get all of the named parameters in our model

In [62]:
# Every parameter name in the model, in registration order.
param_names = [name for name, _param in model_1.named_parameters()]
print(param_names)

['embed.W_E', 'pos_embed.W_pos', 'blocks.0.ln1.w', 'blocks.0.ln1.b', 'blocks.0.ln2.w', 'blocks.0.ln2.b', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_O', 'blocks.0.attn.b_Q', 'blocks.0.attn.b_O', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_K', 'blocks.0.attn.b_V', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.ln1.w', 'blocks.1.ln1.b', 'blocks.1.ln2.w', 'blocks.1.ln2.b', 'blocks.1.attn.W_Q', 'blocks.1.attn.W_O', 'blocks.1.attn.b_Q', 'blocks.1.attn.b_O', 'blocks.1.attn.W_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_K', 'blocks.1.attn.b_V', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.ln1.w', 'blocks.2.ln1.b', 'blocks.2.ln2.w', 'blocks.2.ln2.b', 'blocks.2.attn.W_Q', 'blocks.2.attn.W_O', 'blocks.2.attn.b_Q', 'blocks.2.attn.b_O', 'blocks.2.attn.W_K', 'blocks.2.attn.W_V', 'blocks.2.attn.b_K', 'blocks.2.attn.b_V', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.m

Okay, so now let's try doing gradient patching for the following:

1) Only the embeddings
2) The embeddings and the first block
3) The embeddings and the first two blocks

In [63]:
# NOTE: this import replaces any gradient_patching_metric defined earlier in the notebook.
from gradient_patching import gradient_patch, gradient_patching_metric

# Build the parameter lists from the model instead of typing every name by hand.
all_param_names = [name for name, _ in model_1.named_parameters()]
embed_params = ['embed.W_E', 'pos_embed.W_pos']
block_0_params = [name for name in all_param_names if name.startswith('blocks.0.')]
block_1_params = [name for name in all_param_names if name.startswith('blocks.1.')]

# 1) embeddings only, 2) embeddings + block 0, 3) embeddings + blocks 0 and 1
manual = [
    embed_params,
    embed_params + block_0_params,
    embed_params + block_0_params + block_1_params,
]

patching_metric = partial(
    gradient_patching_metric,
    clean_avg_logit=clean_avg_logit,
    corrupted_avg_logit=corrupted_avg_logit,
)
patch_metrics = gradient_patch(model_1, corrupted_tokens, clean_tokens, patching_metric, questions, manual=manual)

print(patch_metrics)

[0.09190098941326141, 0.8862426280975342, 1.021612286567688]


(Note that whenever I say we are "updating a parameter" from now on, I mean "updating a parameter with the reliable definition while updating the rest of the network with the unreliable definition")

So, what do these results mean?

1) If we update only the embeddings, we recover 9.2% of the performance of the fully reliably updated model
2) If we update only the embeddings and block 0, we recover 89%
3) If we update the embeddings and blocks 0 and 1, we recover ~100% (102%)!

In other words, ~all of the work is being done in the embeddings and the first two blocks.

It's pretty cumbersome to manually write out these lists of parameters. You can instead pass a parameter "auto" which will do all of the blocks up to N without you having to manually specify them. In "auto" you can also choose whether to exclude particular sets of parameters (e.g. only update attention layers or MLP layers).

The below uses auto to do gradient patching for 

1) Everything up to and including block 0
2) Everything up to and including block 1
3) Everything up to and including block 2

(Note that auto is kinda janky, I'll get around to making it better at some point. If you want to be 100% sure about what parameters you're including in your patching I would recommend just manually specifying them for now.)

In [64]:
# Per the accompanying notes: patch everything up to and including block N, for each N up to
# `blocks_up_to`. Unembedding and the final layer norm are left on the corrupted gradients.
auto = dict(
    blocks_up_to=2,
    attn=True,
    mlp=True,
    ln=True,
    embed=True,
    unembed=False,
    ln_final=False,
)

patching_metric = partial(
    gradient_patching_metric,
    clean_avg_logit=clean_avg_logit,
    corrupted_avg_logit=corrupted_avg_logit,
)
patch_metrics = gradient_patch(model_1, corrupted_tokens, clean_tokens, patching_metric, questions, auto=auto)

print(patch_metrics)

[0.8862426280975342, 1.021612286567688, 1.0279396772384644]


Okay, let's try and narrow down exactly which parameter updates are most important for the increased internalisation we see with reliable definitions vs. unreliable definitions.

We already know from above that a lot is happening in the embedding layer + block 0.

So let's look at what happens when we update

1) Block 0 (excluding embeddings)
2) Only the block 0 attention layer
3) Only the block 0 MLP layer


In [65]:
# All of block 0's parameters (layer norms, attention, MLP), in the model's registration order.
block_0_params = [name for name, _ in model_1.named_parameters() if name.startswith('blocks.0.')]

manual = [
    block_0_params,                                                          # 1) whole block 0
    [name for name in block_0_params if name.startswith('blocks.0.attn.')],  # 2) attention only
    [name for name in block_0_params if name.startswith('blocks.0.mlp.')],   # 3) MLP only
]

patching_metric = partial(
    gradient_patching_metric,
    clean_avg_logit=clean_avg_logit,
    corrupted_avg_logit=corrupted_avg_logit,
)
patch_metrics = gradient_patch(model_1, corrupted_tokens, clean_tokens, patching_metric, questions, manual=manual)

print(patch_metrics)

[0.6604235768318176, 0.563898503780365, 0.06102854013442993]


So:

1) If we only update block 0, we recover 66% of performance
2) If we only update block 0's attention layer, we recover 56%
3) If we only update block 0's MLP, we recover 6%

So it seems like the attention layer in block 0 is super important. Let's see if we can narrow this down even further.

Let's update

1) The QK circuit
2) The OV circuit

in isolation

(if you're not familiar with this terminology, check this paper out: https://transformer-circuits.pub/2021/framework/index.html
essentially, the QK circuit is the bit that decides which tokens attend to which other tokens, and the OV circuit is the bit that decides what information to move from one token to another token given the attention weight)


In [66]:
# Block 0 attention split into its QK circuit (query/key) and OV circuit (value/output).
qk_circuit = [f'blocks.0.attn.{suffix}' for suffix in ('W_Q', 'b_Q', 'W_K', 'b_K')]
ov_circuit = [f'blocks.0.attn.{suffix}' for suffix in ('W_O', 'b_O', 'W_V', 'b_V')]
manual = [qk_circuit, ov_circuit]

patching_metric = partial(
    gradient_patching_metric,
    clean_avg_logit=clean_avg_logit,
    corrupted_avg_logit=corrupted_avg_logit,
)
patch_metrics = gradient_patch(model_1, corrupted_tokens, clean_tokens, patching_metric, questions, manual=manual)

print(patch_metrics)

[-1.4365853530762251e-05, 0.5637693405151367]


So essentially all of the effect of block 0's attention layer comes from the OV circuit. Patching the OV circuit alone recovers 56.4%, the same as patching the whole attention layer, while updating the QK weights does essentially nothing for internalisation.

Maybe we can narrow this down even further?

Let's try updating only the O weights/biases and the V weights/biases in isolation.


In [67]:
# The OV circuit of block 0's attention, split into output (O) and value (V) parameters.
o_params = [f'blocks.0.attn.{suffix}' for suffix in ('W_O', 'b_O')]
v_params = [f'blocks.0.attn.{suffix}' for suffix in ('W_V', 'b_V')]
manual = [o_params, v_params]

patching_metric = partial(
    gradient_patching_metric,
    clean_avg_logit=clean_avg_logit,
    corrupted_avg_logit=corrupted_avg_logit,
)
patch_metrics = gradient_patch(model_1, corrupted_tokens, clean_tokens, patching_metric, questions, manual=manual)

print(patch_metrics)

[0.07081949710845947, 0.4153412878513336]


If you update only the O parameters, you recover 7% of the performance, while if you update only the V parameters, you recover 41% of the performance!

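So updating only the V parameters of block 0's attention layer recovers fully 41% of the performance of updating the entire network! Note that the O and V effects are not additive: 7% + 41% is less than the 56% recovered by patching the OV circuit as a whole.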