Boilerplate for setup:

In [60]:

# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

In [76]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import itertools

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig

import oocl

In [77]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

In [63]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

We load a saved model at two different checkpoints and provide the subsets used in X2. Model_1 used below is post-X1, pre-X2 training and model_2 is 500 steps into X2 training.

In [65]:
model_path = "120_6layer_MLP_fixed_6val_step_499_.pt"

mod = oocl.DataParams.mod

transformer_config = dict(
    d_vocab=512,
    n_layers=6,
    d_model=2**10,
    d_head=2**7,
    n_heads=4,
    d_mlp=2**8,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type="LN",
    attn_only=False,
)
transformer_config.update(dict(
    d_vocab=2*mod + 4,  # 3 special tokens + mod vars
))
new_cfg = HookedTransformerConfig(**transformer_config)
new_model = HookedTransformer(new_cfg)
new_model.load_state_dict(torch.load(model_path, map_location=device))
new_model.eval()
model_1 = new_model

model_path = "120_6layer_MLP_fixed_6val_step_550_.pt"

transformer_config.update(dict(
    d_vocab=2*mod + 4,  # 3 special tokens + mod vars
))
new_cfg = HookedTransformerConfig(**transformer_config)
new_model = HookedTransformer(new_cfg)
new_model.load_state_dict(torch.load(model_path, map_location=device))
new_model.eval()
model_2 = new_model

'''
NEW MODEL

DtQ1
[102, 28, 30, 14, 88, 35, 19, 51, 8, 110, 80, 70, 61, 31, 117, 97, 82, 104, 111, 78, 116, 92, 113, 99, 119, 87, 40, 32, 11, 96]


DfQ2
[36, 20, 37, 109, 106, 1, 29, 53, 5, 90, 24, 69, 95, 89, 21, 64, 94, 42, 103, 100, 33, 60, 107, 55, 10, 67, 0, 54, 16, 25]


Dt3
[34, 72, 41, 6, 83, 118, 18, 63, 38, 105, 23, 43, 115, 22, 12, 48, 66, 74, 98, 3, 71, 13, 56, 91, 45, 68, 76, 77, 15, 17]


Df4
[52, 112, 7, 39, 44, 101, 26, 9, 114, 62, 2, 93, 46, 4, 86, 47, 79, 57, 49, 50, 73, 27, 81, 85, 58, 108, 75, 65, 84, 59]

'''

Dt3 = [34, 72, 41, 6, 83, 118, 18, 63, 38, 105, 23, 43, 115, 22, 12, 48, 66, 74, 98, 3, 71, 13, 56, 91, 45, 68, 76, 77, 15, 17]
Df4 = [52, 112, 7, 39, 44, 101, 26, 9, 114, 62, 2, 93, 46, 4, 86, 47, 79, 57, 49, 50, 73, 27, 81, 85, 58, 108, 75, 65, 84, 59]

Let's look at how the question accuracy evolves over X2 

In [66]:
from oocl import create_questions, evaluate
from torch.utils.data import DataLoader

Dt3_questions = {}
Df4_questions = {}

Dt3_acc_1 = {}
Df4_acc_1 = {}

Dt3_acc_2 = {}
Df4_acc_2 = {}

questions = {}

for num in Dt3:

    Dt3_questions[num] = create_questions([num])
    Dt3_acc_1[num] = evaluate(model_1, DataLoader(Dt3_questions[num].unsqueeze(0)), device)[0]

    # now drill down to individual questions hackily

    for q in Dt3_questions[num]:
        questions[q] = evaluate(model_1, DataLoader(q.unsqueeze(0).unsqueeze(0)), device)[0]

for num in Df4:

    Df4_questions[num] = create_questions([num])
    Df4_acc_1[num] = evaluate(model_1, DataLoader(Df4_questions[num].unsqueeze(0)), device)[0]

    # now drill down to individual questions hackily

    for q in Df4_questions[num]:

        questions[q] = evaluate(model_1, DataLoader(q.unsqueeze(0).unsqueeze(0)), device)[0]

for num in Dt3:

    Dt3_questions[num] = create_questions([num])
    Dt3_acc_2[num] = evaluate(model_2, DataLoader(Dt3_questions[num].unsqueeze(0)), device)[0]

for num in Df4:

    Df4_questions[num] = create_questions([num])
    Df4_acc_2[num] = evaluate(model_2, DataLoader(Df4_questions[num].unsqueeze(0)), device)[0]


In [67]:
print(Dt3_acc_1)
print(Dt3_acc_2)

differences = {k: abs(Dt3_acc_1[k] - Dt3_acc_2[k]) for k in Dt3_acc_1}

argmax_diff = max(differences, key=differences.get)
max_diff = differences[argmax_diff]

print("Int with maximum difference:", argmax_diff)
print("Maximum difference:", max_diff)

{34: tensor(0.), 72: tensor(0.), 41: tensor(0.0833), 6: tensor(0.), 83: tensor(0.1667), 118: tensor(0.), 18: tensor(0.), 63: tensor(0.), 38: tensor(0.1667), 105: tensor(0.), 23: tensor(0.), 43: tensor(0.), 115: tensor(0.), 22: tensor(0.), 12: tensor(0.0833), 48: tensor(0.), 66: tensor(0.), 74: tensor(0.), 98: tensor(0.), 3: tensor(0.0833), 71: tensor(0.), 13: tensor(0.0833), 56: tensor(0.), 91: tensor(0.3333), 45: tensor(0.1667), 68: tensor(0.), 76: tensor(0.), 77: tensor(0.), 15: tensor(0.0833), 17: tensor(0.)}
{34: tensor(0.), 72: tensor(0.), 41: tensor(0.4167), 6: tensor(0.), 83: tensor(0.5000), 118: tensor(0.), 18: tensor(0.), 63: tensor(0.1667), 38: tensor(0.0833), 105: tensor(0.2500), 23: tensor(0.5000), 43: tensor(0.), 115: tensor(0.0833), 22: tensor(0.0833), 12: tensor(0.0833), 48: tensor(0.4167), 66: tensor(0.), 74: tensor(0.), 98: tensor(0.), 3: tensor(0.5000), 71: tensor(0.5833), 13: tensor(0.3333), 56: tensor(0.), 91: tensor(0.1667), 45: tensor(0.1667), 68: tensor(0.), 76: 

We should pick a number that sees question accuracy increase a lot between the two models in order to start testing out gradient patching, as we know that this number must be being "internalised" at some point during X2.

71 sees accuracy increase by 0.58, so let's take that as our example.

# Gradient patching

First do some more setup

In [79]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

from neel_plotly import line, imshow, scatter

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload








Let's check what the average logit of the correct answer is for the model at step 999 vs. step 1200.

In [84]:
from gradient_patching import get_correct_logits
import copy

print(f"Avg logit for model 1: {get_correct_logits(model_1, oocl.create_questions([71]))}")
print(f"Avg logit for model 2: {get_correct_logits(model_2, oocl.create_questions([71]))}")

Avg logit for model 1: 3.9090282917022705
Avg logit for model 2: 19.699739456176758


As expected, the logit is far higher for the model after finetuning on X2 for a few hundred steps.

Now create our corrupted and clean tokens. These are definitions for 48 here. Corrupted = unreliable tag, clean = reliable tag.

In [143]:
reliable_tokens = torch.Tensor([241, 71+120, 71, 243]).to(torch.int64) # reliable tag, 71 + 120 (71's alias), 71, padding
unreliable_tokens = torch.Tensor([242, 71+120, 71, 243]).to(torch.int64) # unreliable tag, 71 + 120 (71's alias), 71, padding

Let's try updating model_1 on gradients generated through clean_tokens (a reliable definition) and corrupted_tokens (an unreliable definition), and see what happens to the average logit.

Note that we set up the below to match the training of the models (including grad norm etc.)

In [148]:
from oocl import loss_fn

model_1.zero_grad()

questions = oocl.create_questions([71])

reliable_model = copy.deepcopy(model_1)
unreliable_model = copy.deepcopy(model_1)

# set the same optimizers as during training

rel_optimizer = torch.optim.AdamW(reliable_model.parameters(), lr=0.0001, betas=(0.9, 0.98), weight_decay=0.1)
unrel_optimizer = torch.optim.AdamW(unreliable_model.parameters(), lr=0.0001, betas=(0.9, 0.98), weight_decay=0.1)

reliable_out = reliable_model(clean_tokens)
reliable_loss = loss_fn(reliable_out, clean_tokens.unsqueeze(0))
reliable_loss.backward()
torch.nn.utils.clip_grad_norm_(reliable_model.parameters(), 1.0)
rel_optimizer.step()
rel_optimizer.zero_grad()

unrel_out = unreliable_model(corrupted_tokens)
unrel_loss = loss_fn(unrel_out, corrupted_tokens.unsqueeze(0))
unrel_loss.backward()
torch.nn.utils.clip_grad_norm_(unreliable_model.parameters(), 1.0)
unrel_optimizer.step()
unrel_optimizer.zero_grad()

clean_avg_logit = get_correct_logits(reliable_model, questions)
corrupted_avg_logit = get_correct_logits(unreliable_model, questions)

print(f"Average correct logit pre update: {get_correct_logits(model_1, questions)}")
print(f"Average correct logit post reliable update: {clean_avg_logit}")
print(f"Average correct logit post unreliable update: {corrupted_avg_logit}")


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Average correct logit pre update: 3.9090282917022705
Average correct logit post reliable update: 6.247903347015381
Average correct logit post unreliable update: 2.0656609535217285


The average correct logit for our example int's questions increases when updating on a reliable definition and decreases when updating on an unreliable definition!

(This is kind of weird actually - why should it *decrease* for the unreliable definition instead of staying roughly the same?)

## Gradient patching layer by layer

Let's see if we can localise at which layer updating on a reliable vs. unreliable def makes a difference

Below we get RELIABLE gradients and UNRELIABLE activations, as we need to patch these in at different points in order to accomplish what we want

In [123]:
from gradient_patching import generic_gradient_patch, gradient_patching_metric
from oocl import loss_fn
from transformer_lens.patching import layer_head_vector_patch_setter


model_1.train()

reliable_cache, fwd, bwd = model_1.get_caching_hooks(names_filter=None, incl_bwd=True, device=device, remove_batch_dim=False)

model_1.reset_hooks()

with model_1.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=False
        ):

    model_out = model_1.forward(clean_tokens)
    loss = loss_fn(model_out, clean_tokens.unsqueeze(0))

    loss.backward()

model_1.zero_grad() # we can safely do this because values are stored in clean_cache

unreliable_cache, fwd, bwd = model_1.get_caching_hooks(names_filter=None, incl_bwd=False, device=device, remove_batch_dim=False)

model_1.reset_hooks()

with model_1.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=False
        ):

    model_out = model_1.forward(corrupted_tokens)


Now we patch in at each layer

In [95]:
def layer_patch_setter(corrupted_activation, index, clean_activation):
    """
    Applies the activation patch where index = [layer, pos]

    Implicitly assumes that the activation axis order is [batch, pos, ...], which is true of everything that is not an attention pattern shaped tensor.
    """
    assert len(index) == 1
    corrupted_activation = clean_activation
    return corrupted_activation


In [130]:
print(unreliable_cache['blocks.0.hook_resid_pre'].shape)
print(reliable_cache['blocks.0.hook_resid_pre_grad'].shape)

torch.Size([1, 4, 1024])
torch.Size([1, 4, 1024])


In [154]:
from transformer_lens.patching import layer_head_vector_patch_setter, layer_pos_patch_setter

model_1.reset_hooks()

questions = oocl.create_questions([71])

get_grad_patch_resid_pre = partial(
    generic_gradient_patch,
    patch_setter=layer_patch_setter,
    activation_name="resid_pre",
    index_axis_names=["layer"],
    lr=0.0001,
    loss_fn=loss_fn,
    questions=questions
)

results = get_grad_patch_resid_pre(model_1, unreliable_tokens, reliable_tokens, reliable_cache, unreliable_cache, partial(gradient_patching_metric, clean_avg_logit=clean_avg_logit, corrupted_avg_logit=corrupted_avg_logit))

imshow(results,
       yaxis="Layer",
        x=[f"{tok} {i}" for i, tok in enumerate(reliable_tokens)],
       title="resid_pre Gradient Patching")   

  0%|          | 0/6 [00:00<?, ?it/s]

both bwd and forward


 17%|█▋        | 1/6 [00:00<00:03,  1.30it/s]

both bwd and forward


 33%|███▎      | 2/6 [00:01<00:03,  1.19it/s]

both bwd and forward


 50%|█████     | 3/6 [00:02<00:02,  1.27it/s]

both bwd and forward


 67%|██████▋   | 4/6 [00:03<00:01,  1.22it/s]

both bwd and forward


100%|██████████| 6/6 [00:04<00:00,  1.31it/s]

both bwd and forward





ValueError: px.imshow only accepts 2D single-channel, RGB or RGBA images. An image of shape (6,) was provided. Alternatively, 3- or 4-D single or multichannel datasets can be visualized using the `facet_col` or/and `animation_frame` arguments.

In [150]:
print(results)

tensor([1., 1., 1., 1., 1., 1.])


Above = no hooks

In [151]:
print(results)

tensor([1., 1., 1., 1., 1., 1.])


above only backward hooks

In [153]:
print(results)

tensor([-0.0589,  0.3849,  0.4423,  0.4417,  0.4414,  0.4386])


above = both forward and backward

In [155]:
print(results)

tensor([-0.0589,  0.3849,  0.4423,  0.4417,  0.4414,  0.4386])


Above = only forward

What the fuck

So what is this telling us?

Swapping out the residual stream gradient for the definition token never matters at all (!?!?)

Wait, what does it even mean to change the gradient for a particular position ....?

# Gradient patching step by step

First we import the generic_gradient_patch from gradient_patching and a layer_head_vector_patch_setter from transformer_lens. 

We get the caching hooks using the built in transformer lens method, and then with these hooks in place ("with model_1.hooks(...): ..."), we do a forward pass and a backward pass. Now our gradients have been stored in clean_cache by the caching hooks.

In [71]:
from gradient_patching import generic_gradient_patch, gradient_patching_metric
from oocl import loss_fn


model_1.train()

clean_cache, fwd, bwd = model_1.get_caching_hooks(names_filter=None, incl_bwd=True, device=device, remove_batch_dim=False)

model_1.reset_hooks()

with model_1.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=False
        ):

  model_out = model_1.forward(clean_tokens)
  loss = loss_fn(model_out, clean_tokens.unsqueeze(0))

  loss.backward()

model_1.zero_grad() # we can safely do this because values are stored in clean_cache


In [81]:
from transformer_lens.patching import layer_head_vector_patch_setter, layer_pos_patch_setter
model_1.reset_hooks()

clean_cache_grads = {k:clean_cache[k] for k in clean_cache if "_grad" in k}

questions = oocl.create_questions([71])

get_grad_patch_resid_pre = partial(
    generic_gradient_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
    lr=0.0001,
    loss_fn=loss_fn,
    questions=questions
)

results = get_grad_patch_resid_pre(model_1, corrupted_tokens, clean_cache_grads, partial(gradient_patching_metric, pre_patch_logit=pre_update_logit))


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

 58%|█████▊    | 14/24 [00:10<00:07,  1.38it/s]


KeyboardInterrupt: 

In [80]:
imshow(results,
       yaxis="Layer",
       xaxis="Head",
       title="Head patching")