# Custom Causal Tracer

My goal is to replicate the "causal tracing" method proposed in the ROME work, using the open-source TransformerLens library:

*   https://rome.baulab.info/
*   https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb




# Set up

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-jiv65vnj
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-jiv65vnj
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 6413a81fee757c4915075c50e67cb6ee0afc1d4c
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting einops>=0.6.0
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
Collecting wandb>=0.13.5
  Downloading wandb-0.14.1-py3-none-any.whl (2.0 MB)

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Mansi")

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [5]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [6]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7faf41fd6af0>

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Get Model

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Define Prompts that we want to causally trace

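Explicit vs. obscure entities: the explicit prompt names the subject directly ("George Washington"), while the obscure prompt refers to the same subject through a description ("The first president of the united states").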

In [9]:
explicit_prompt = "George Washington fought in the"
obscure_prompt = "The first president of the united states fought in the"

# Cache all activations

In [10]:
explicit_tokens = model.to_tokens(explicit_prompt)
obscure_tokens = model.to_tokens(obscure_prompt)
print(explicit_tokens.device)
explicit_logits, explicit_cache = model.run_with_cache(explicit_tokens, remove_batch_dim=True)
obscure_logits, obscure_cache = model.run_with_cache(obscure_tokens, remove_batch_dim=True)

cpu


In [11]:
for i in explicit_cache:
  print(i, explicit_cache[i].shape)

hook_embed torch.Size([6, 768])
hook_pos_embed torch.Size([6, 768])
blocks.0.hook_resid_pre torch.Size([6, 768])
blocks.0.ln1.hook_scale torch.Size([6, 1])
blocks.0.ln1.hook_normalized torch.Size([6, 768])
blocks.0.attn.hook_q torch.Size([6, 12, 64])
blocks.0.attn.hook_k torch.Size([6, 12, 64])
blocks.0.attn.hook_v torch.Size([6, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([12, 6, 6])
blocks.0.attn.hook_pattern torch.Size([12, 6, 6])
blocks.0.attn.hook_z torch.Size([6, 12, 64])
blocks.0.hook_attn_out torch.Size([6, 768])
blocks.0.hook_resid_mid torch.Size([6, 768])
blocks.0.ln2.hook_scale torch.Size([6, 1])
blocks.0.ln2.hook_normalized torch.Size([6, 768])
blocks.0.mlp.hook_pre torch.Size([6, 3072])
blocks.0.mlp.hook_post torch.Size([6, 3072])
blocks.0.hook_mlp_out torch.Size([6, 768])
blocks.0.hook_resid_post torch.Size([6, 768])
blocks.1.hook_resid_pre torch.Size([6, 768])
blocks.1.ln1.hook_scale torch.Size([6, 1])
blocks.1.ln1.hook_normalized torch.Size([6, 768])
blocks.1.att

In [12]:
explicit_cache['hook_embed'].shape

torch.Size([6, 768])

In [13]:
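# Raw (pre-softmax) scores for head 0. Future positions are masked with a large negative constant,
# so the 15 masked entries of this 6x6 matrix dominate the sum (consistent with a mask value of about -1e5)
torch.sum(explicit_cache['blocks.0.attn.hook_attn_scores'][0])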

tensor(-1500019.2500)

In [14]:
# hook_pattern is softmax(Q K^T / sqrt(d_head)) over key positions, after causal masking.
# Each of the 6 query rows sums to 1, so one head's 6x6 pattern sums to 6
torch.sum(explicit_cache['blocks.9.attn.hook_pattern'][0])

tensor(6.)
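To make the two sums above concrete, here is a minimal sketch in plain PyTorch that uses random queries and keys rather than the model's weights. It follows TransformerLens's convention of scaling scores by sqrt(d_head) and masking future positions with a large negative constant. The value -1e5 is an assumption here, but it is consistent with the roughly -1.5e6 sum above: 15 masked entries × -1e5. The masked entries dominate the raw sum, while every row of the softmaxed pattern sums to 1.

```python
import torch

torch.manual_seed(0)
n_pos, n_heads, d_head = 6, 12, 64
IGNORE = -1e5  # assumed mask value for future positions

# Random queries and keys in TransformerLens layout: (pos, head, d_head)
q = torch.randn(n_pos, n_heads, d_head)
k = torch.randn(n_pos, n_heads, d_head)

# Raw scores: (head, query_pos, key_pos), scaled by sqrt(d_head)
scores = torch.einsum("qhd,khd->hqk", q, k) / d_head ** 0.5

# Causal mask: query i may only attend to keys j <= i
causal = torch.tril(torch.ones(n_pos, n_pos, dtype=torch.bool))
scores = scores.masked_fill(~causal, IGNORE)

# The pattern is a softmax over key positions
pattern = scores.softmax(dim=-1)

n_masked = int((~causal).sum())            # 15 masked entries for 6 positions
print(n_masked, scores[0].sum().item())    # raw sum is dominated by the mask
print(pattern[0].sum().item())             # ~6.0: each of the 6 rows sums to 1
```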

In [15]:
explicit_cache['blocks.9.attn.hook_pattern'].shape

torch.Size([12, 6, 6])

In [16]:
# hook_z is the pattern-weighted sum of values (pattern @ V), with shape (pos, head, d_head)
explicit_cache['blocks.0.attn.hook_z'].shape

torch.Size([6, 12, 64])

In [17]:
# hook_attn_out is z projected through W_O and summed over heads, with shape (pos, d_model).
# The batch dimension is missing only because run_with_cache was called with remove_batch_dim=True;
# otherwise every cached tensor would have an additional leading batch dim
explicit_cache['blocks.0.hook_attn_out'].shape

torch.Size([6, 768])
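The relationship between `hook_pattern`, `hook_z` and `hook_attn_out` can be sketched with random tensors. This sketch assumes TransformerLens's weight layout, with `W_O` of shape (n_heads, d_head, d_model) plus an output bias `b_O`. It illustrates the shapes only and is not the library's code.

```python
import torch

torch.manual_seed(0)
n_pos, n_heads, d_head, d_model = 6, 12, 64, 768

# A random causal pattern whose rows sum to 1: (head, query_pos, key_pos)
pattern = torch.rand(n_heads, n_pos, n_pos).tril()
pattern = pattern / pattern.sum(dim=-1, keepdim=True)
v = torch.randn(n_pos, n_heads, d_head)                     # (pos, head, d_head)
W_O = torch.randn(n_heads, d_head, d_model) / d_head ** 0.5  # assumed layout
b_O = torch.zeros(d_model)

# hook_z: each head's pattern-weighted sum of values -> (pos, head, d_head)
z = torch.einsum("hqk,khd->qhd", pattern, v)

# hook_attn_out: project every head into d_model and sum over heads -> (pos, d_model)
attn_out = torch.einsum("qhd,hdm->qm", z, W_O) + b_O

print(z.shape, attn_out.shape)
```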

In [18]:
print(type(explicit_cache))
explicit_attention_pattern = explicit_cache["pattern", 0, "attn"]
print(explicit_attention_pattern.shape)
explicit_str_tokens = model.to_str_tokens(explicit_prompt)

obscure_attention_pattern = obscure_cache["pattern", 0, "attn"]
obscure_str_tokens = model.to_str_tokens(obscure_prompt)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 6, 6])


In [19]:
print("Layer 0 Head Attention Patterns Explicit Prompt:")
cv.attention.attention_patterns(tokens=explicit_str_tokens, attention=explicit_attention_pattern)

Layer 0 Head Attention Patterns Explicit Prompt:


In [20]:
print("Layer 0 Head Attention Patterns Obscure Prompt:")
cv.attention.attention_patterns(tokens=obscure_str_tokens, attention=obscure_attention_pattern)

Layer 0 Head Attention Patterns Obscure Prompt:


### **Look above at heads 4 and 5 in the obscure prompt's patterns: a really interesting pattern. Head 1 is also worth noting.**

In [21]:
explicit_cache['blocks.0.hook_resid_post'].shape

torch.Size([6, 768])

We know that the model predicts the next token autoregressively for every prefix of the input: (input1), (input1, input2), and so on.
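A single forward pass yields a prediction for every prefix because of the causal mask: position i never sees positions after it. The minimal sketch below uses a toy single-head causal self-attention layer with random weights, not GPT-2. It shows that changing only the last input leaves the outputs at all earlier positions unchanged.

```python
import torch

torch.manual_seed(0)
d_model, n_pos = 16, 6

# Toy single-head causal self-attention with random projections (not GPT-2 weights)
W_Q, W_K, W_V = (torch.randn(d_model, d_model) / d_model ** 0.5 for _ in range(3))

def causal_attention(x: torch.Tensor) -> torch.Tensor:
    """x: (pos, d_model) -> (pos, d_model); position i attends only to positions j <= i."""
    q, k, v = x @ W_Q, x @ W_K, x @ W_V
    scores = (q @ k.T) / d_model ** 0.5
    mask = torch.tril(torch.ones(len(x), len(x), dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ v

x = torch.randn(n_pos, d_model)
x_changed = x.clone()
x_changed[-1] = torch.randn(d_model)  # replace only the last "token"

out, out_changed = causal_attention(x), causal_attention(x_changed)
# Earlier positions are identical; only the last position differs
print(torch.allclose(out[:-1], out_changed[:-1]), torch.allclose(out[-1], out_changed[-1]))
```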

In [22]:
explicit_cache['blocks.11.hook_resid_post'].shape
# After the last block, a final LayerNorm and an unembedding (W_U) project the residual stream
# into vocabulary logits. In the original GPT-2, the unembedding is tied to the embedding matrix

torch.Size([6, 768])
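Below is a shape-level sketch of that final step, using small made-up dimensions and random weights. It assumes the GPT-2 convention of a final LayerNorm followed by an unembedding tied to the token embedding (W_U = W_E^T). In TransformerLens, `W_U` is stored as a separate parameter, so treat the tying here as an illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_pos, d_model, d_vocab = 6, 32, 100  # toy sizes; GPT-2 small uses 768 and 50257

W_E = torch.randn(d_vocab, d_model)   # token embedding matrix
W_U = W_E.T                           # tied unembedding: (d_model, d_vocab)
ln_final = nn.LayerNorm(d_model)

resid_final = torch.randn(n_pos, d_model)   # stands in for blocks.11.hook_resid_post
logits = ln_final(resid_final) @ W_U        # (pos, d_vocab)
next_token_ids = logits.argmax(dim=-1)      # greedy prediction at every position
print(logits.shape, next_token_ids.shape)
```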

In [23]:
explicit_logits.shape

torch.Size([1, 6, 50257])

In [24]:
# This demonstrates that the model is able to predict next tokens
# autoregressively with varying amounts of context
for i in range(len(explicit_str_tokens)):
  print("Input token(s): ", explicit_str_tokens[0:i+1])
  print("Next token: ", model.to_string(torch.argmax(explicit_logits[0][i])))
  print("--------")

Input token(s):  ['<|endoftext|>']
Next token:  

--------
Input token(s):  ['<|endoftext|>', 'George']
Next token:   Washington
--------
Input token(s):  ['<|endoftext|>', 'George', ' Washington']
Next token:   University
--------
Input token(s):  ['<|endoftext|>', 'George', ' Washington', ' fought']
Next token:   for
--------
Input token(s):  ['<|endoftext|>', 'George', ' Washington', ' fought', ' in']
Next token:   the
--------
Input token(s):  ['<|endoftext|>', 'George', ' Washington', ' fought', ' in', ' the']
Next token:   Revolutionary
--------


In [25]:
# This demonstrates that the model is able to predict next tokens
# autoregressively with varying amounts of context
for i in range(len(obscure_str_tokens)):
  print("Input token(s): ", obscure_str_tokens[0:i+1])
  print("Next token: ", model.to_string(torch.argmax(obscure_logits[0][i])))
  print("--------")

Input token(s):  ['<|endoftext|>']
Next token:  

--------
Input token(s):  ['<|endoftext|>', 'The']
Next token:   first
--------
Input token(s):  ['<|endoftext|>', 'The', ' first']
Next token:   time
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president']
Next token:   of
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president', ' of']
Next token:   the
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president', ' of', ' the']
Next token:   United
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president', ' of', ' the', ' united']
Next token:   states
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president', ' of', ' the', ' united', ' states']
Next token:  ,
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president', ' of', ' the', ' united', ' states', ' fought']
Next token:   for
--------
Input token(s):  ['<|endoftext|>', 'The', ' first', ' president', ' of', ' the', ' united', ' sta

# Load Explicit vs Obscure Prompt Dataset

In [26]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [27]:
data = pd.read_csv('/content/drive/MyDrive/Research/Mechanistic Interpretability/data/multi_hop_100.csv')

In [28]:
data.head()

Unnamed: 0.1,Unnamed: 0,explicit_sent,obscure_sent,explicit_entity,obscure_entity,fact1,fact2,answer
0,0,The country of citizenship of Jaap Speyer is,The country of citizenship of the director of ...,Jaap Speyer,the director of Lilli's Marriage,"[""Lilli's Marriage"", 'director', 'Jaap Speyer']","['Jaap Speyer', 'country of citizenship', 'Dut...",Dutch
1,1,The place of birth of Dušan Hanák is,"The place of birth of the director of I Love, ...",Dušan Hanák,"the director of I Love, You Love","['I Love, You Love', 'director', 'Dušan Hanák']","['Dušan Hanák', 'place of birth', 'Bratislava']",Bratislava
2,2,The place of death of James Vincent is,The place of death of the director of Gold and...,James Vincent,the director of Gold and the Woman,"['Gold and the Woman', 'director', 'James Vinc...","['James Vincent', 'place of death', 'New York']",New York
3,3,The place of birth of Emil Loteanu is,The place of birth of the director of Lăutarii is,Emil Loteanu,the director of Lăutarii,"['Lăutarii', 'director', 'Emil Loteanu']","['Emil Loteanu', 'place of birth', 'Romania']",Romania
4,4,The country of citizenship of Archduke Karl Sa...,The country of citizenship of the father of Ar...,Archduke Karl Salvator of Austria,the father of Archduke Leopold Salvator of Aus...,"['Archduke Leopold Salvator of Austria', 'fath...","['Archduke Karl Salvator of Austria', 'country...",Italian


In [29]:
data = data.drop(['Unnamed: 0', 'fact1', 'fact2'], axis=1)
data.head()

Unnamed: 0,explicit_sent,obscure_sent,explicit_entity,obscure_entity,answer
0,The country of citizenship of Jaap Speyer is,The country of citizenship of the director of ...,Jaap Speyer,the director of Lilli's Marriage,Dutch
1,The place of birth of Dušan Hanák is,"The place of birth of the director of I Love, ...",Dušan Hanák,"the director of I Love, You Love",Bratislava
2,The place of death of James Vincent is,The place of death of the director of Gold and...,James Vincent,the director of Gold and the Woman,New York
3,The place of birth of Emil Loteanu is,The place of birth of the director of Lăutarii is,Emil Loteanu,the director of Lăutarii,Romania
4,The country of citizenship of Archduke Karl Sa...,The country of citizenship of the father of Ar...,Archduke Karl Salvator of Austria,the father of Archduke Leopold Salvator of Aus...,Italian


In [30]:
data['explicit_sent'][0]

'The country of citizenship of Jaap Speyer is'

In [31]:
logits = model(data['explicit_sent'][0])
probs = torch.nn.functional.softmax(logits[0][-1], dim=0)
next_pred_token_idx = torch.argmax(logits[0][-1], dim=0)
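# Caveat: the answer is tokenized without a leading space, so this looks up the token 'Dutch' rather than
# ' Dutch' (the form that would follow ' is'); the tiny probability below may partly reflect this.
# Multi-token answers (e.g. 'New York') would also need more than a single lookup.
answer_tok = model.to_tokens(data['answer'][0], prepend_bos=False)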
print("expected next token: ", data['answer'][0])
print("predicted next token: ", model.to_string(next_pred_token_idx))
print("probability of expected next token: ", probs[answer_tok])
print("probability of predicted next token: ", probs[next_pred_token_idx])

expected next token:  Dutch
predicted next token:   being
probability of expected next token:  tensor([[7.2476e-09]])
probability of predicted next token:  tensor(0.0633)
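One way to score multi-token answers such as ' New York' is the chain rule: run the model on prompt + answer, then sum the log-probabilities of each answer token, reading each one from the position just before it. The sketch below assumes you already have logits of shape (seq, vocab) for the concatenated sequence, with the answer tokenized with a leading space. `answer_log_prob` is a hypothetical helper, not part of TransformerLens.

```python
import torch
import torch.nn.functional as F

def answer_log_prob(logits: torch.Tensor, prompt_len: int, answer_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: log P(answer | prompt) via the chain rule.

    logits:     (seq, vocab) for the sequence prompt + answer
    prompt_len: number of prompt tokens (including BOS, if one was prepended)
    answer_ids: (n_answer,) token ids of the answer
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # The token at position prompt_len + i is predicted by the logits at prompt_len + i - 1
    positions = torch.arange(prompt_len - 1, prompt_len - 1 + len(answer_ids))
    return log_probs[positions, answer_ids].sum()

# Usage with random logits: 4 prompt tokens followed by a 2-token answer
torch.manual_seed(0)
logits = torch.randn(6, 50)
answer_ids = torch.tensor([7, 3])
print(answer_log_prob(logits, prompt_len=4, answer_ids=answer_ids).exp())
```

With the model, one way to build the sequence is `torch.cat([model.to_tokens(prompt), model.to_tokens(" " + answer, prepend_bos=False)], dim=1)`. Then pass `model(tokens)[0]` as `logits`.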


# Patching Hidden States

Create a hook function to do the patching.
-  Eventually, we want to do combinatorial corruptions (n choose r), where n is the number of tokens and r ranges from 1 to n.
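Enumerating those corruption sets is straightforward with `itertools.combinations`. The sketch below yields every non-empty subset of token positions, optionally leaving the BOS token alone. There are 2^k - 1 such subsets for k candidate positions, so this grows quickly. Each subset would also need a full layer × position restoration sweep. `corruption_sets` is a hypothetical helper name.

```python
from itertools import combinations
from typing import Iterator, Tuple

def corruption_sets(n_tokens: int, skip_bos: bool = True) -> Iterator[Tuple[int, ...]]:
    """Hypothetical helper: yield every non-empty subset of positions to corrupt,
    smallest subsets first. With skip_bos=True, position 0 (<|endoftext|>) is never corrupted."""
    positions = range(1 if skip_bos else 0, n_tokens)
    for r in range(1, len(positions) + 1):
        yield from combinations(positions, r)

subsets = list(corruption_sets(6))
print(len(subsets))  # 31 subsets of positions 1..5
```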

In [32]:
clean_prompt = "George Washington fought in the"
clean_tokens = model.to_tokens(clean_prompt)
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

In [33]:
# We define a residual stream patching hook
# We choose to act on the residual stream at the start of the layer, so we call it resid_pre
# The type annotations are a guide to the reader and are not necessary

#TODO: rename this function
def residual_stream_patching_hook(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    position: int,
    clean_cache: transformer_lens.ActivationCache
) -> Float[torch.Tensor, "batch pos d_model"]:
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre

In [34]:
layer = 10
clean_resid_pre = clean_cache[utils.get_act_name("resid_pre", layer)]
position = 5
clean_resid_pre[:, position, :].shape

torch.Size([1, 768])

In [35]:
num_positions = len(clean_tokens[0])
patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
patches = []

In [36]:
mu, sigma = 3, 4 # mean and standard deviation
noise = np.random.normal(mu, sigma, 1)
noise

array([4.48437774])

In [37]:
#The model predicts the correct next token
model.to_string(torch.argmax(model(clean_prompt)[0][5]))

' Revolutionary'

In [38]:
model.to_string(clean_tokens[0])

'<|endoftext|>George Washington fought in the'

In [39]:
model.to_tokens(" Revolutionary", prepend_bos=False)

tensor([[28105]])

In [40]:
def patcher(model, hook_func, clean_tokens, corrupt_tokens, correct_next_token_idx):
  '''
  Given clean tokens and corrupted tokens, patch the corrupted run's hidden states with
  the corresponding states from the clean run, one (layer, position) pair at a time, and
  record the probability of the correct next token (inspired by ROME's causal tracing method).

  The patching scheme is determined by hook_func. In this case, the hook function restores
  a corrupted state with the corresponding clean state, but this is not the only type of patching possible!
  '''
  clean_logits, clean_cache = model.run_with_cache(clean_tokens)
  num_positions = len(clean_tokens[0])
  patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

  for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
      
      # Use functools.partial to create a temporary hook function with the position fixed
      temp_hook_fn = partial(hook_func, position=position, clean_cache=clean_cache)

      #Get patched Logits
      patched_logits = model.run_with_hooks(corrupt_tokens, 
                          fwd_hooks=[
                                      (utils.get_act_name("resid_pre", layer), 
                                      temp_hook_fn)
                                      ]
                          )
      #Turn logits to probs via softmax
      probs = torch.nn.functional.softmax(patched_logits[0][-1], dim=0)
      #Get probability of the correct next token
      val = probs[correct_next_token_idx]      
      patching_result[layer, position] = val

  return patching_result

In [43]:
def patcher(model, hook_point, hook_func, clean_tokens, corrupt_tokens, correct_next_token_idx):
  '''
  Given clean tokens and corrupted tokens, patch the corrupted run's hidden states with
  the corresponding states from the clean run, one (layer, position) pair at a time, and
  record the probability of the correct next token (inspired by ROME's causal tracing method).

  hook_point selects which activation is patched (e.g. "resid_pre", "attn_out", "mlp_out").
  The patching scheme is determined by hook_func. In this case, the hook function restores
  a corrupted state with the corresponding clean state, but this is not the only type of patching possible!
  '''
  clean_logits, clean_cache = model.run_with_cache(clean_tokens)
  num_positions = len(clean_tokens[0])
  patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

  for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
      
      # Use functools.partial to create a temporary hook function with the position fixed
      temp_hook_fn = partial(hook_func, position=position, clean_cache=clean_cache)

      #Get patched Logits
      patched_logits = model.run_with_hooks(corrupt_tokens, 
                          fwd_hooks=[
                                      (utils.get_act_name(hook_point, layer),  #"resid_pre"
                                      temp_hook_fn)
                                      ]
                          )
      #Turn logits to probs via softmax
      probs = torch.nn.functional.softmax(patched_logits[0][-1], dim=0)
      #Get probability of the correct next token
      val = probs[correct_next_token_idx]      
      patching_result[layer, position] = val

  return patching_result
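ROME summarizes these sweeps as effects measured against the corrupted run. The total effect is P_clean(o) - P_corrupted(o). The indirect effect of a restored state is P_restored(o) - P_corrupted(o). Below is a minimal sketch of turning `patching_result` into indirect effects. It assumes the clean and corrupted probabilities of the target token are computed separately, for example with an unhooked `model(...)` call on each token set. `indirect_effects` is a hypothetical helper name.

```python
import torch

def indirect_effects(patching_result: torch.Tensor, p_clean: float, p_corrupted: float):
    """Hypothetical helper following ROME's definitions.

    patching_result: (n_layers, n_positions) probability of the target token after
                     restoring one clean state into the corrupted run
    Returns (indirect_effect, total_effect); indirect_effect has the same shape as patching_result.
    """
    total_effect = p_clean - p_corrupted
    indirect_effect = patching_result - p_corrupted
    return indirect_effect, total_effect

# Toy numbers for a 2-layer x 3-position sweep (not results from this notebook)
probs = torch.tensor([[0.02, 0.30, 0.05],
                      [0.02, 0.10, 0.25]])
ie, te = indirect_effects(probs, p_clean=0.40, p_corrupted=0.02)
print(te, ie)
```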

In [44]:
# Corrupt the subject tokens, then run the (layer, position) restoration sweep
def subject_patcher(model, hook_point, hook_func, clean_prompt, subject_prompt, correct_next_token_idx, mu, sigma):
  clean_tokens = model.to_tokens(clean_prompt)
  subject_tokens = model.to_tokens(subject_prompt, prepend_bos=False)
  num_positions = len(clean_tokens[0])

  # Corruption here shifts each matching token ID by a random offset (truncated to an integer),
  # i.e. it swaps the token for a nearby vocabulary entry; this differs from ROME's embedding noise
  corrupt_tokens = torch.clone(clean_tokens)
  corrupted_idx = []
  for token in range(num_positions):
    for sub in range(len(subject_tokens[0])):
      if corrupt_tokens[0][token] == subject_tokens[0][sub]:
        noise = np.random.normal(mu, sigma, 1)
        corrupt_tokens[0][token] += noise
        corrupted_idx.append(token)
        break  # corrupt each position at most once

  print(model.to_str_tokens(corrupt_tokens))

  patching_result = patcher(model,
                              hook_point,
                              hook_func,
                              clean_tokens,
                              corrupt_tokens,
                              correct_next_token_idx)

  return patching_result, corrupted_idx
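`subject_patcher` corrupts the subject by shifting token IDs, so each subject token is replaced by a nearby vocabulary entry. ROME instead adds Gaussian noise to the subject's input embeddings, with a standard deviation of three times the empirical standard deviation of the embeddings. Below is a minimal sketch of that variant written as a TransformerLens-style hook on `hook_embed`. The function name, the choice of hook, and estimating the std from `model.W_E` are all assumptions.

```python
import torch

def embed_noise_hook(embed: torch.Tensor, hook, positions, noise_std: float) -> torch.Tensor:
    """Hypothetical hook: add Gaussian noise to the token embeddings at `positions`.

    embed: (batch, pos, d_model) output of hook_embed. Returns the modified tensor.
    """
    noise = torch.randn(embed[:, positions, :].shape, device=embed.device, dtype=embed.dtype)
    embed[:, positions, :] = embed[:, positions, :] + noise_std * noise
    return embed

# Standalone check on a dummy tensor (the hook argument is unused, so None is passed)
torch.manual_seed(0)
dummy = torch.zeros(1, 6, 768)
out = embed_noise_hook(dummy.clone(), None, positions=[1, 2], noise_std=3.0)
print((out != 0).any(dim=-1))  # only positions 1 and 2 are changed

# With the model (not run here), one possible wiring:
#   noise_std = 3 * model.W_E.std().item()  # assumption: std over all token embeddings
#   hook_fn = partial(embed_noise_hook, positions=[1, 2], noise_std=noise_std)
#   model.run_with_hooks(clean_tokens, fwd_hooks=[("hook_embed", hook_fn)])
```

For a fair restoration sweep, the same noise sample should be reused in every (layer, position) run, for example by reseeding before each run. Each run would also need this hook and the restoration hook together in `fwd_hooks`.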

In [59]:
correct_next_token_idx = 28105
mu, sigma = 3, 4 # mean and standard deviation for noise
patches, corrupted_idx = subject_patcher(model,
                                         "mlp_out",  # alternatives: "resid_pre", "attn_out"
                                         residual_stream_patching_hook,
                                         clean_prompt,
                                         "George Washington",
                                         correct_next_token_idx,
                                         mu,
                                         sigma)

['<|endoftext|>', 'border', 'oor', ' fought', ' in', ' the']

In [60]:
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
for i in corrupted_idx:
    token_labels[i] += "*"
imshow(patches.T, y=token_labels, yaxis="Position (* indicates corrupted token)", xaxis="Layer", title="Probability for ' Revolutionary' next token prediction")

In [55]:
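# Corrupt each position in turn (position 0, the BOS token, is left uncorrupted)
# and run a full (layer, position) restoration sweep for each corruption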
def token_wise_patcher(model, hook_point, hook_func, clean_prompt, correct_next_token_idx, mu, sigma):
  patches = []
  clean_tokens = model.to_tokens(clean_prompt)
  num_positions = len(clean_tokens[0])

  for token in range(num_positions):
    noise = np.random.normal(mu, sigma, 1)
    if token == 0:
      noise = 0
    corrupt_tokens = torch.clone(clean_tokens)
    corrupt_tokens[0][token] += noise
    print(model.to_str_tokens(corrupt_tokens))
      
    patching_result = patcher(model, 
                              hook_point,
                              hook_func, 
                              clean_tokens, 
                              corrupt_tokens, 
                              correct_next_token_idx)  
    patches.append(torch.clone(patching_result))
  return patches


In [57]:
correct_next_token_idx = 28105
mu, sigma = 3, 4 # mean and standard deviation for noise
patches = token_wise_patcher(model,
                             "attn_out",
                             residual_stream_patching_hook,
                             clean_prompt,
                             correct_next_token_idx,
                             mu,
                             sigma)

['<|endoftext|>', 'George', ' Washington', ' fought', ' in', ' the']
['<|endoftext|>', 'border', ' Washington', ' fought', ' in', ' the']
['<|endoftext|>', 'George', ' Gl', ' fought', ' in', ' the']
['<|endoftext|>', 'George', ' Washington', 'arms', ' in', ' the']
['<|endoftext|>', 'George', ' Washington', ' fought', 'ion', ' the']
['<|endoftext|>', 'George', ' Washington', ' fought', ' in', 'on']

In [58]:
for corrupted_pos in range(num_positions):
  # Add the index to the end of the label, because plotly doesn't like duplicate labels
  token_labels = [f"{tok}_{index}" for index, tok in enumerate(model.to_str_tokens(clean_tokens))]
  if corrupted_pos > 0:  # position 0 (BOS) was never corrupted
    token_labels[corrupted_pos] += "*"
  imshow(patches[corrupted_pos].T, y=token_labels, yaxis="Position (* indicates corrupted token)", xaxis="Layer", title="Probability for ' Revolutionary' next token prediction")

# Game plan
- First, show that the model performs worse on obscure entities than on explicit ones (trivial)
- Show which parts of the model are exercised in making the final decision (by Friday)
- Show that some type of intervention in the obscure case can significantly boost the probability of the correct answer (next week)

# Training a small 2-layer GPT-style model on Shakespeare

Mansi Ramble: a good goal would be to have a trained and hooked model, plus a combinatorial causal tracing strategy (corrupted vs. uncorrupted) for prompts. This doesn't necessarily guarantee that we can relate how obscure vs. explicit prompts work, but we can develop a strategy for that next week. It might just mean doing a monstrous run of corrupted vs. uncorrupted prompts through the network and trying to interpret our results post hoc.

One approach would be to compare the residual-stream signals at various token positions in the obscure vs. explicit prompts. It would be nice to figure out how to apply a "correction" to the obscure prompts based on the explicit prompts, perhaps manually after a subject token. We could do something similar to causal tracing, but rather than swapping corrupted and uncorrupted states, we would try to match up parts of the two sentences and supply hidden states from the explicit prompt.
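A simple first step for that comparison is to measure, layer by layer, how similar the residual stream is at the final (prediction) position of the explicit and obscure prompts. The sketch assumes the per-layer residual vectors have already been stacked into tensors of shape (n_layers, d_model). `layerwise_cosine` is a hypothetical helper, and cosine similarity is just one possible metric.

```python
import torch
import torch.nn.functional as F

def layerwise_cosine(resid_a: torch.Tensor, resid_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cosine similarity per layer.

    resid_a, resid_b: (n_layers, d_model) residual-stream vectors, one per layer.
    Returns a (n_layers,) tensor.
    """
    return F.cosine_similarity(resid_a, resid_b, dim=-1)

# Toy check with random vectors standing in for the two prompts' residual streams
torch.manual_seed(0)
explicit_resid = torch.randn(12, 768)
obscure_resid = explicit_resid + 0.1 * torch.randn(12, 768)  # a nearby stream
print(layerwise_cosine(explicit_resid, obscure_resid))

# With real caches (shapes assume remove_batch_dim=True), one way to stack them:
#   torch.stack([explicit_cache[utils.get_act_name("resid_post", l)][-1]
#                for l in range(model.cfg.n_layers)])
```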