Boilerplate for setup:

In [1]:

# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

In [2]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import itertools

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig

import oocl

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def imshow(tensor, **kwargs):
    """Show `tensor` as a Plotly heatmap on a red/blue scale centred at zero.

    `tensor` is converted with `utils.to_numpy`; any extra keyword arguments
    are forwarded unchanged to `px.imshow`. The figure is displayed
    immediately and nothing is returned.
    """
    diverging_scale = dict(
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
    )
    figure = px.imshow(utils.to_numpy(tensor), **diverging_scale, **kwargs)
    figure.show()


def line(tensor, **kwargs):
    """Show `tensor` as a Plotly line chart, using its values as the y data.

    Extra keyword arguments are forwarded unchanged to `px.line`. The figure
    is displayed immediately and nothing is returned.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Show a Plotly scatter plot of `y` against `x`.

    Args:
        x, y: Data for the two axes; each is converted with `utils.to_numpy`.
        xaxis, yaxis, caxis: Display labels for the x axis, y axis and colour
            axis respectively.
        **kwargs: Forwarded unchanged to `px.scatter`.

    The figure is displayed immediately and nothing is returned.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    figure.show()

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

We load a saved model at two different checkpoints and provide the subsets used in X2. `model_1` used below is post-X1, pre-X2 training (checkpoint step 999) and `model_2` is roughly 200 steps into X2 training (checkpoint step 1200).

In [5]:
mod = oocl.DataParams.mod

# Architecture shared by both checkpoints.
transformer_config = dict(
    # 2*mod + 4 token ids.
    # NOTE(review): the original comment said "3 special tokens + mod vars",
    # but the code adds 4 -- confirm which is intended.
    d_vocab=2*mod + 4,
    n_layers=6,
    d_model=2**10,
    d_head=2**7,
    n_heads=4,
    d_mlp=2**8,
    n_ctx=5,
    act_fn="relu",  # gelu?
    normalization_type="LN",
    attn_only=False,
)


def load_checkpoint(checkpoint_path, config_kwargs=transformer_config):
    """Build a HookedTransformer from `config_kwargs` and load saved weights.

    Args:
        checkpoint_path: Path to a state dict saved with `torch.save`.
        config_kwargs: Keyword arguments for `HookedTransformerConfig`.

    Returns:
        The model with the checkpoint's weights loaded, in eval mode.
    """
    cfg = HookedTransformerConfig(**config_kwargs)
    model = HookedTransformer(cfg)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    return model


model_1 = load_checkpoint("oocl_120_step_999.pt")

model_path = "oocl_120_step_1200.pt"
model_2 = load_checkpoint(model_path)

# Names the original cell left behind; kept in case later cells rely on them.
new_model = model_2
new_cfg = new_model.cfg

# Number subsets, for reference:
#
# DtQ1
# [39, 73, 32, 10, 2, 51, 118, 24, 108, 19, 109, 43, 110, 117, 33, 116, 101, 15, 35, 41, 1, 75, 97, 3, 13, 91, 0, 88, 84, 65]
#
# DfQ2
# [58, 47, 27, 113, 17, 34, 22, 115, 29, 21, 68, 36, 96, 87, 5, 57, 103, 56, 106, 79, 76, 67, 90, 105, 52, 98, 69, 85, 59, 74]
#
# Dt3 and Df4 are the two subsets actually used in the cells below.

Dt3 = [94, 37, 92, 28, 30, 70, 9, 16, 111, 40, 38, 11, 112, 20, 48, 60, 64, 45, 83, 95, 81, 82, 44, 8, 77, 66, 54, 71, 89, 53]
Df4 = [50, 49, 100, 4, 12, 78, 18, 119, 14, 6, 63, 107, 46, 23, 104, 62, 93, 31, 102, 61, 55, 72, 114, 26, 7, 86, 25, 42, 80, 99]

Let's look at how the question accuracy evolves over X2 training.

In [6]:
from oocl import create_questions, evaluate
from torch.utils.data import DataLoader

def _question_accuracy(model, question_batch):
    """Score `model` on `question_batch` using `evaluate`.

    The batch gets a leading dimension, is wrapped in a DataLoader, and the
    first element of `evaluate`'s result is returned (the value this notebook
    reports as accuracy).
    """
    return evaluate(model, DataLoader(question_batch.unsqueeze(0)), device)[0]


def _evaluate_subset(numbers, per_question_acc):
    """Evaluate both checkpoints on the questions for each number in `numbers`.

    Questions are created once per number and reused for model_1 and model_2,
    so both checkpoints are scored on exactly the same questions.

    Args:
        numbers: Iterable of numbers to build questions for.
        per_question_acc: Dict updated in place with model_1's result on every
            individual question.

    Returns:
        (questions_by_number, acc_model_1, acc_model_2), each keyed by number.
    """
    questions_by_number, acc_model_1, acc_model_2 = {}, {}, {}
    for num in numbers:
        number_questions = create_questions([num])
        questions_by_number[num] = number_questions
        acc_model_1[num] = _question_accuracy(model_1, number_questions)
        acc_model_2[num] = _question_accuracy(model_2, number_questions)

        # Drill down to individual questions (model_1 only).
        # NOTE(review): the keys are tensor objects, which hash by identity,
        # so this dict can be iterated but not looked up by question value.
        for q in number_questions:
            per_question_acc[q] = _question_accuracy(model_1, q.unsqueeze(0))
    return questions_by_number, acc_model_1, acc_model_2


questions = {}

Dt3_questions, Dt3_acc_1, Dt3_acc_2 = _evaluate_subset(Dt3, questions)
Df4_questions, Df4_acc_1, Df4_acc_2 = _evaluate_subset(Df4, questions)


  result_tensor = torch.tensor(Z).view(N, 1)


In [7]:
# Per-number accuracy on the Dt3 subset: model_1 first, then model_2.
for accuracy_by_number in (Dt3_acc_1, Dt3_acc_2):
    print(accuracy_by_number)

{94: tensor(0.), 37: tensor(0.), 92: tensor(0.), 28: tensor(0.), 30: tensor(0.), 70: tensor(0.), 9: tensor(0.), 16: tensor(0.0833), 111: tensor(0.), 40: tensor(0.), 38: tensor(0.), 11: tensor(0.0833), 112: tensor(0.), 20: tensor(0.0833), 48: tensor(0.), 60: tensor(0.), 64: tensor(0.), 45: tensor(0.2500), 83: tensor(0.), 95: tensor(0.2500), 81: tensor(0.4167), 82: tensor(0.), 44: tensor(0.0833), 8: tensor(0.), 77: tensor(0.0833), 66: tensor(0.), 54: tensor(0.), 71: tensor(0.2500), 89: tensor(0.), 53: tensor(0.0833)}
{94: tensor(0.), 37: tensor(0.3333), 92: tensor(0.), 28: tensor(0.), 30: tensor(0.1667), 70: tensor(0.0833), 9: tensor(0.3333), 16: tensor(0.), 111: tensor(0.2500), 40: tensor(0.3333), 38: tensor(0.), 11: tensor(0.1667), 112: tensor(0.), 20: tensor(0.1667), 48: tensor(0.5000), 60: tensor(0.0833), 64: tensor(0.), 45: tensor(1.), 83: tensor(0.7500), 95: tensor(0.2500), 81: tensor(0.8333), 82: tensor(0.0833), 44: tensor(0.0833), 8: tensor(0.5000), 77: tensor(0.4167), 66: tensor

We should pick a number that sees question accuracy increase a lot between the two models in order to start testing out gradient patching, as we know that this number must be being "internalised" at some point during X2.

48 goes from 0% accuracy before X2 to 50% accuracy after 200 steps, so we'll use that as our example.

# Gradient patching

First do some more setup

In [8]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

from neel_plotly import line, imshow, scatter

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


Let's check what the average logit of the correct answer is for the model at step 999 vs. step 1200.

In [15]:
from gradient_patching import get_correct_logits
import copy

# Average correct-answer logit on the questions about 48 for each checkpoint.
# The model_1 value is kept as the baseline used by the patching metric later.
pre_update_logit = get_correct_logits(model_1, oocl.create_questions([48]))
print(f"Avg logit for model 1: {pre_update_logit}")

post_update_logit = get_correct_logits(model_2, oocl.create_questions([48]))
print(f"Avg logit for model 2: {post_update_logit}")

Avg logit for model 1: 10.83239459991455
Avg logit for model 2: 0.08232060819864273


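We expect the logit to be far higher for the model after finetuning on X2 for a few hundred steps. Note, however, that the output recorded above (from an out-of-order execution, `In [15]`) shows the opposite ordering — about 10.83 for model 1 versus 0.08 for model 2 — so re-run this cell on a fresh kernel to confirm before relying on it.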

Now create our corrupted and clean tokens. These are definitions for 48 here. Corrupted = unreliable tag, clean = reliable tag.

In [10]:
# Definition sequences for 48. Layout: [tag, 48 + 120 (48's alias), 48, padding].
# The two sequences differ only in the leading tag token.
_definition_tail = [168, 48, 243]
clean_tokens = torch.tensor([241, *_definition_tail], dtype=torch.int64)      # 241 = reliable tag
corrupted_tokens = torch.tensor([242, *_definition_tail], dtype=torch.int64)  # 242 = unreliable tag

Let's try updating model_1 on gradients generated through clean_tokens (a reliable definition) and corrupted_tokens (an unreliable definition), and see what happens to the average logit.

# Gradient patching step by step

First we import the generic_gradient_patch from gradient_patching and a layer_head_vector_patch_setter from transformer_lens. 

We get the caching hooks using the built in transformer lens method, and then with these hooks in place ("with model_1.hooks(...): ..."), we do a forward pass and a backward pass. Now our gradients have been stored in clean_cache by the caching hooks.

In [11]:
from gradient_patching import generic_gradient_patch, gradient_patching_metric
from oocl import loss_fn
from transformer_lens.patching import layer_head_vector_patch_setter


# Run one forward and one backward pass of model_1 on the clean (reliable)
# definition with caching hooks attached, so that the cached values end up in
# `clean_cache` for the patching step that follows.

# Switch model_1 to training mode for the forward/backward pass below.
model_1.train()

# Ask TransformerLens for a cache plus the forward and backward caching hooks
# that fill it (incl_bwd=True requests the backward hooks as well).
clean_cache, fwd, bwd = model_1.get_caching_hooks(names_filter=None, incl_bwd=True, device=device, remove_batch_dim=False)

# Start from a clean slate before attaching the caching hooks.
model_1.reset_hooks()

# NOTE(review): reset_hooks_end=False presumably leaves the hooks attached when
# the block exits; the next cell calls model_1.reset_hooks() explicitly.
with model_1.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=False
        ):

  # Forward pass on the clean definition; the tokens get a leading dimension
  # via unsqueeze(0) before being handed to loss_fn.
  model_out = model_1.forward(clean_tokens)
  loss = loss_fn(model_out, clean_tokens.unsqueeze(0))

  # Backward pass, during which the backward hooks are in place.
  loss.backward()

model_1.zero_grad() # we can safely do this because values are stored in clean_cache


In [12]:


# Remove the caching hooks from the previous cell before patching.
model_1.reset_hooks()

# Keep only the cache entries whose key contains "_grad".
clean_cache_grads = {k:clean_cache[k] for k in clean_cache if "_grad" in k}

# Evaluate on the questions about 48. This is the number the clean/corrupted
# definitions describe and the number `pre_update_logit` was computed on;
# the original built questions for 81, which made the metric compare logits
# for two different numbers.
questions = oocl.create_questions([48])

# NOTE(review): despite the "resid_pre" in its name, this is configured to
# patch the "z" activation, indexed by (layer, head).
get_grad_patch_resid_pre = partial(
    generic_gradient_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
    lr=0.0001,
    loss_fn=loss_fn,
    questions=questions
)

# The metric is given model_1's pre-patch logit as its baseline.
results = get_grad_patch_resid_pre(model_1, corrupted_tokens, clean_cache_grads, partial(gradient_patching_metric, pre_patch_logit=pre_update_logit))

100%|██████████| 24/24 [00:16<00:00,  1.48it/s]


In [13]:
# Heatmap of the patching results, with layers on the y axis and heads on the x axis.
imshow(results, yaxis="Layer", xaxis="Head", title="Head patching")