In [1]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re

#from plotly_utils import imshow, line, scatter, bar


In [2]:
## plotting functions
# Keyword arguments accepted by the plotting helpers that must be routed to
# `fig.update_layout(...)` instead of the `px.*` figure constructor.
update_layout_set = {
    # axis ranges and titles
    "xaxis_range", "yaxis_range", "xaxis_title", "yaxis_title",
    # tick formatting
    "xaxis_tickformat", "yaxis_tickformat",
    "xaxis_tickmode", "yaxis_tickmode",
    "xaxis_tickangle", "yaxis_tickangle",
    # grid lines and axis visibility
    "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor",
    "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    "xaxis_visible", "yaxis_visible",
    # colour
    "colorbar", "colorscale", "coloraxis",
    # titles, legend and general layout
    "title_x", "title_y", "legend_title_text", "showlegend", "hovermode", "margin",
    # bar spacing
    "bargap", "bargroupgap",
}

def imshow(tensor, renderer=None, **kwargs):
    """Display a tensor as a plotly heatmap.

    Keyword arguments whose names are in ``update_layout_set`` are applied with
    ``fig.update_layout``; everything else is forwarded to ``px.imshow``.
    Two extra keywords are consumed here and never forwarded:

    facet_labels: optional list of strings replacing the facet subplot titles.
    border: if True, draw a black frame around every subplot.

    The colour scale defaults to "RdBu" centred on 0. Both defaults can be
    overridden with ``color_continuous_scale`` / ``color_continuous_midpoint``.
    """
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    # Passed as a default rather than a literal keyword in the px.imshow call, so a
    # caller-supplied midpoint no longer raises "multiple values for keyword argument".
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)
    # An int margin is shorthand for the same margin on all four sides.
    if isinstance(kwargs_post.get("margin"), int):
        kwargs_post["margin"] = dict.fromkeys("tblr", kwargs_post["margin"])
    fig = px.imshow(utils.to_numpy(tensor), **kwargs_pre)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]["text"] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor="black", mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor="black", mirror=True)
    # Layout kwargs such as `xaxis_tickangle` only target the first subplot; copy them
    # onto xaxis2, xaxis3, ... (and likewise for y) so faceted plots stay consistent.
    for axis in ("xaxis", "yaxis"):
        for setting in ("tickangle",):
            key = f"{axis}_{setting}"
            if key in kwargs_post:
                i = 2
                while f"{axis}{i}" in fig["layout"]:
                    kwargs_post[f"{axis}{i}_{setting}"] = kwargs_post[key]
                    i += 1
    fig.update_layout(**kwargs_post)
    fig.show(renderer=renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of a tensor as a plotly line chart (values on the y axis).

    Any extra keyword arguments are forwarded unchanged to ``px.line``.
    """
    values = utils.to_numpy(tensor)
    fig = px.line(y=values, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot ``y`` against ``x`` with plotly.

    xaxis / yaxis / caxis: display names for the x axis, y axis and colour scale.
    Remaining keyword arguments are forwarded unchanged to ``px.scatter``.
    """
    axis_names = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_names,
        **kwargs,
    )
    fig.show(renderer)

In [3]:
## Inference only: switch autograd off globally so no computation graphs are stored (saves memory).
t.set_grad_enabled(False)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


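Generate example IOI (indirect object identification) prompts along with their correct and incorrect answers. It's important that all prompts have the same number of tokens, so they can be batched and compared position by position (prompts taken from the Exploratory Analysis Demo).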

In [4]:
# Each template hard-codes two names in its first clause and repeats one of them
# (the subject) in the `{}` slot. The correct completion is the other name: the
# indirect object (IO).
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# The two names appearing in each template (leading space so each is a single GPT-2 token).
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# zip() below would silently drop unmatched entries, so check the pairing up front.
assert len(prompt_format) == len(names), "need exactly one name pair per prompt template"

# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# Token ids (ints) for each answer, in the format (correct_token, incorrect_token)
answer_token_ids = []
for template, (name_a, name_b) in zip(prompt_format, names):
    # Two prompts per template: one with name_b as the repeated subject (so name_a
    # is correct), and one with the roles swapped.
    for correct, incorrect in [(name_a, name_b), (name_b, name_a)]:
        answers.append((correct, incorrect))
        answer_token_ids.append(
            (model.to_single_token(correct), model.to_single_token(incorrect))
        )
        # The subject (incorrect answer) is the name that gets repeated.
        prompts.append(template.format(incorrect))
# Shape (n_prompts, 2): column 0 = correct token id, column 1 = incorrect token id.
answer_tokens = t.tensor(answer_token_ids).to(device)

### check that all the prompts have the same number of tokens
prompt_lengths = {len(model.to_str_tokens(prompt)) for prompt in prompts}
assert len(prompt_lengths) == 1, f"prompts tokenize to different lengths: {sorted(prompt_lengths)}"
(prompt_len,) = prompt_lengths


In [5]:
### Print every prompt with its answers as a rich table (technique learned from Keith's notebook!)
from rich.table import Table, Column
from rich import print as rprint


# One row per prompt: the prompt text, then its correct and incorrect answer.
prompt_tab = Table('prompt', 'clean', 'corrupted', title = 'prompts and answers')
for prompt, (correct, incorrect) in zip(prompts, answers):
    prompt_tab.add_row(prompt, correct, incorrect)
rprint(prompt_tab)

Run the model once on all prompts, caching the logits and every intermediate activation.

In [6]:
# Tokenize the whole batch of prompts (BOS prepended) and move it to `device`.
tokens = (
    model.to_tokens(prompts, prepend_bos=True)
    .to(device)
)
# Clean forward pass: keep the output logits and cache all intermediate activations.
og_logits, cache = model.run_with_cache(tokens)


Design a metric to test model performance. Here we use the logit difference at the final position between the indirect object (the correct answer) and the subject (the incorrect answer).

In [7]:
# Model dimensions used throughout the notebook.
d_vocab, d_model, n_heads, n_layers = (
    model.cfg.d_vocab,
    model.cfg.d_model,
    model.cfg.n_heads,
    model.cfg.n_layers,
)
n_ex = len(prompts)  # number of prompts in the batch

# Logits are indexed (prompt, position, vocab); every prompt has prompt_len tokens.
assert og_logits.shape == (n_ex, prompt_len, d_vocab)

def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Correct-minus-incorrect logit difference at the last position of each prompt.

    logits: indexed (prompt, position, vocab); only the last position is used.
    answer_tokens: integer token ids, one row per prompt, with column 0 = correct
        answer and column 1 = incorrect answer.
    per_prompt: if True, return the per-prompt differences; otherwise their mean.
    """
    # Only the prediction made at the final position answers the prompt.
    last_position_logits = logits[:, -1, :]
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = picked[:, 0], picked[:, 1]
    diffs = correct_logits - incorrect_logits
    return diffs if per_prompt else diffs.mean()
    

og_logit_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=True)
og_logit_avg_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=False)

cols = [
    "Prompt",
    Column("Correct", style="rgb(0,200,0) bold"),
    Column("Incorrect", style="rgb(255,0,0) bold"),
    Column("Logit Difference", style="bold"),
    Column("Avg Logit Difference", style="bold"),
]
logit_diff_table = Table(*cols, title="Logit differences")

# Per-prompt rows fill only the first four columns; rich leaves the last cell blank.
for prompt, ans, logit_diff in zip(prompts, answers, og_logit_diff):
    logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}")
# Summary row so the "Avg Logit Difference" column is actually populated.
logit_diff_table.add_row("Average", "", "", "", f"{og_logit_avg_diff.item():.3f}")
rprint(logit_diff_table)



In [8]:
# Average (over prompts) correct-minus-incorrect logit difference on the clean run; shown as the cell output.
og_logit_avg_diff

tensor(3.5519)

What is going on in IOI? There are several ways to check. 

First, Direct Logit Attribution: every component reads from and writes to the residual stream through linear maps. The only exception is the final LayerNorm, which becomes linear once its scale is frozen at the cached value. The final logits can therefore be decomposed into a sum of contributions, one from each component that writes to the residual stream. Working backwards from the end of the model (logits = U(LN(final_residual))), we can see which components contribute most to the logit for the right token.

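- The metric for IOI is convenient: the difference of two log probabilities (log softmax) equals the difference of the corresponding logits, because log softmax subtracts the same normalizing term (logsumexp over the vocabulary) from every logit, and it cancels.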

- Getting an output logit = projecting the (normalized) final residual stream onto that token's unembedding direction.


In [9]:
# map answer tokens to the d_model residual stream directions
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print(answer_residual_directions.shape)  # (n_ex, 2, d_model) in the recorded output
# Per-prompt residual direction whose dot product gives (correct - incorrect) logit.
logit_diff_directions = answer_residual_directions[:,0] - answer_residual_directions[:,1]

## Make sure this works: apply the final LN to the residual stream and project onto those directions

# cache syntax: cache["resid_post", -1] is the residual stream at the end of the final layer.
# NOTE(review): keys look like [activation_name, layer_index(, sub_layer_type)] -- confirm in the TransformerLens docs.
final_residual = cache["resid_post", -1]
assert final_residual.shape == t.Size([n_ex, prompt_len, d_model])

# Only the final position predicts the answer.
final_tok_residual = final_residual[:,-1,:]

# Apply the final LayerNorm using the scale cached at the final token position (pos_slice=-1).
scaled_final_token_residual = cache.apply_ln_to_stack(final_tok_residual, layer = -1, pos_slice=-1)

## Project onto each prompt's logit-diff direction, sum over prompts and d_model, then divide by n_ex
average_avg_logit_diff = (einsum("b d_model, b d_model -> ", scaled_final_token_residual, logit_diff_directions).item())/n_ex

# The reconstruction should match the model's own logit diff up to float error
# (about 1e-6 in the recorded run). The looser 1e-3 tolerance leaves room for other devices.
logit_diff_discrepancy = average_avg_logit_diff - og_logit_avg_diff
print(logit_diff_discrepancy)
assert abs(logit_diff_discrepancy.item()) < 1e-3, (
    f"direct logit attribution disagrees with the model's logits by {logit_diff_discrepancy.item():.2e}"
)



torch.Size([8, 2, 768])
tensor(-1.1921e-06)


Logit Lens: apply the logit-difference projection to the accumulated residual stream at every layer. This tracks the accumulated effects of all attention and MLP layers up to each point in the network.

In [10]:
## general function from above: residual stack is a stack of residual stream components at each part of the network you want to look at
def resid_stack_to_logit_diff(
    resid_stack: Float[t.Tensor, 'components b d_model'],
    cache: ActivationCache,
    directions=None,
) -> Float[t.Tensor, 'components']:
    """Average logit difference contributed by each entry of a residual stack.

    resid_stack: residual-stream components, with leading dims indexing the
        components and the last two dims being (prompt, d_model).
    cache: activation cache whose final-LayerNorm scale (at the last position)
        is applied to the stack before projecting.
    directions: (prompt, d_model) logit-difference directions. Defaults to the
        notebook-level `logit_diff_directions`.

    Returns a tensor over the leading (component) dims: each component's dot
    product with the directions, averaged over prompts.
    """
    if directions is None:
        directions = logit_diff_directions
    scaled_residual_stack = cache.apply_ln_to_stack(resid_stack, layer=-1, pos_slice=-1)
    # Sum over prompts and d_model, then divide by the number of prompts -> mean per prompt.
    n_examples = directions.shape[0]
    return einsum("... b d_model, b d_model -> ...", scaled_residual_stack, directions) / n_examples

## get the residuals and labels from the cache
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = resid_stack_to_logit_diff(accumulated_residual, cache)

# Assumes the stack alternates resid_pre / resid_mid for each layer, followed by the
# final residual stream, so entry k sits at "layer" k/2 on the x axis.
assert len(labels) == 2 * n_layers + 1, f"unexpected number of residual checkpoints: {len(labels)}"

line(
    logit_lens_logit_diffs,
    x=np.arange(len(labels)) / 2,
    hover_name=labels,
    labels={"x": "Layer", "y": "Logit difference"},
    title="Logit Difference From Accumulated Residual Stream",
)

### The early layers (through 7) cannot accomplish the IOI task. After that, performance improves, with the biggest contribution coming from the attention layer in the 9th block (9_mid). Layers after 9 mainly decrease performance.  


Layer Attribution: repeat the above, but take each layer's output separately instead of the running sum. This is like looking at the differences between adjacent residual streams.

In [11]:
# Split the final residual stream into the contribution written by each component
# (presumably the embeddings plus each layer's attention and MLP outputs -- see `labels`).
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = resid_stack_to_logit_diff(per_layer_residual, cache)

# Plot against the component names so the x axis is readable without hovering.
line(
    per_layer_logit_diffs,
    x=labels,
    hover_name=labels,
    labels={"x": "Component", "y": "Logit difference"},
    title="Logit Difference From Each Layer",
)

### The attention layers matter much more than the MLPs, meaning this task values information transfer between tokens more than processing that information. Again, layers after 9 decrease performance 
#### This is a good test to run at the start of each new experiment! 

Head Attribution: Which attention heads matter most?
Attention heads act independently and each writes additively into the residual stream, so we can analyze the contribution of each head separately.

- Positive name mover heads (layer.head): 10.0, 9.6, 9.9. These have a positive logit diff.
- Negative name mover heads: 10.7, 11.10. These have a negative logit diff -- they push the prediction away from the correct (IO) name and towards the wrong one.

A full analysis of the per-head attention patterns would show that both kinds of head attend to the indirect object: the QK circuit decides which name to attend to, and the OV circuit copies that name.

We only look at the attention patterns from the final token, since that is the position whose residual stream feeds directly into the logits.

Note: Be careful about confusing the tokens themselves with their positions in the residual stream! 

Activation Patching! 

1. Run the model on two inputs (clean and corrupted prompts, defined above). For the denoising scheme described here, start with the corrupted prompt, which outputs the incorrect answer. 

2. Intervene on a specific activation by patching in the activation from the clean prompt. 

3. See how much this nudges the output toward the correct answer, and repeat for many activations. Which ones significantly increase the probability of the correct answer? This causally traces the output back to specific activations, allowing us to reverse engineer a circuit for a given behavior. However, it does not tell us exactly how these pieces interact.

TO DO: Compare this method (corrupted -> clean) to its reverse (clean -> corrupted). Naively: 
- clean -> corrupted (noising) starts with a circuit meant to achieve a task. If you add in "wrong" information and that doesn't change performance, it means that piece wasn't necessary. 
- corrupted -> clean (denoising) starts from a run that gives the wrong answer. Add clean information bit by bit; once performance is restored, you have found a sufficient circuit. Would this fail to detect backup name mover heads!?

Normalize the logit diff so we can see how much performance is recovered without referring back to the original. Subtract the corrupted logit diff to measure the improvement, then divide by the total improvement from corrupted to clean:

$$\text{metric} = \frac{\text{LD}_{\text{patched}} - \text{LD}_{\text{corrupted}}}{\text{LD}_{\text{clean}} - \text{LD}_{\text{corrupted}}}$$

- 0 means no change, negative means patching actively made things worse, 1 means clean performance is fully recovered, and >1 means patching actively *improved* on clean performance.

Does >1 performance come from averaging over many examples, or is it saying something real about the circuit?

Takeaways from intervening on various parts of the network:

- residual stream: patching in clean residual stream for all layers pretty much recovers performance exactly.  
- attention out: attention matters most in early layers at the second subject token, and in later layers at the final token (some positive, some negative). The effects are very localized.
- MLP_out: MLP layers don't seem to matter, except for layer 0. Neel guesses that this is because of the tied embedding? 
- All heads individually (results indexed by layer, head index and position): patch each head's activation z (the attention-weighted sum of its value vectors).
- Further patch the attention patterns (QK circuit -- where to move information; finding: early and late layers matter more) or the value vectors (OV circuit -- what information to move; finding: middle layers matter more). We probably don't need to do this -- what other information could it add?


Question: the IOI circuit apparently uses induction heads early in the network. Induction heads detect and copy repeated sequences rather than repeated words (names) specifically. The IOI circuit is therefore probably reusing general-purpose machinery learned during pretraining, rather than learning a new type of "copy token" head.
Two follow-up questions here: 

1. Are s-inhibition heads general properties of all transformers? (maybe this has been studied?)

2. Could something similar be happening for name mover heads? That is, is there a more general head that could accomplish this task and more? This may be simple, like proper-noun or capital word detection, or more complicated. Perhaps the +/- name mover heads are really the same type of more general head, with a different meaning. This may give us more intuition for superposition (or replace it?). Actually, it's easier to study the heads in smaller circuits and then build them up... 

Sanity-check experiment: how do these heads act off distribution (e.g. on other names, or other proper nouns)? They probably need a preposition -- perhaps they are really prepositional heads?

Getting to it: Attention head superposition!

Overall questions: 
1. Superposition as an idea makes sense, and seems likely. But for a specific case, given that attention head "features" are ill-defined, how can we tell whether it's really happening? Is there a change of basis for the z activations that would put all of these heads in the same layer?
2. What are the weights of superposition in IOI? 
3. Are backup name mover heads important here? If they recover some performance of the desired task, couldn't they take over here too if the negative name mover heads are too high to complete it? 
4. Are these heads a more specialized application of another head "type" ubiquitous in all language models, and does that matter for superposition? 
5. What is the relationship between superposition and polysemanticity for attention heads? (Clearer for neurons)
6. What is an attention head feature? 
7. Is a metric of 1 the best we can do? (Can we achieve better than original performance by patching a linear combination of weights?)

TODO: 
- Perform different kinds of patching (activation, attribution, path) and compare their results to the per-head patching results.
- test the uniform combination and one weighted by each head's attention paid to the IO, on average (Neel's idea), and maybe come up with one more
- learn the real gradients by gradient descent (L1 loss for sparsity). Does Atticus Geiger's method of GD and Causal Scrubbing help? 
- Come up with a toy implementation in 1 or 2 layer attention-only models a la Anthropic
- Test on out of distribution data (max activating examples)
- Try to reason about the role of non-linearities between the layers (if the only dimension that matters for superposition is the head index, maybe this is fine.)
