## Setup

In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import haystack_utils
from transformer_lens import utils
from fancy_einsum import einsum
import einops
import json
import ipywidgets as widgets
from IPython.display import display
from datasets import load_dataset
import random
import math


pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

%reload_ext autoreload
%autoreload 2

In [2]:
model = HookedTransformer.from_pretrained(
    "roneneldan/TinyStories-33M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer


## A vs AN investigation

In [3]:
fruits = ["apple", "pear", "orange", "banana"]
for fruit in fruits:
    utils.test_prompt(f"Once upon a time there was a boy named Tim. Tim loved {fruit}s. He climbed an {fruit} tree and picked", "an", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Tim', '.', ' Tim', ' loved', ' apples', '.', ' He', ' climbed', ' an', ' apple', ' tree', ' and', ' picked']
Tokenized answer: [' an']


Top 0th token. Logit: 23.18 Prob: 42.98% Token: | the|
Top 1th token. Logit: 22.07 Prob: 14.17% Token: | an|
Top 2th token. Logit: 21.85 Prob: 11.35% Token: | a|
Top 3th token. Logit: 21.55 Prob:  8.43% Token: | one|
Top 4th token. Logit: 21.15 Prob:  5.65% Token: | lots|
Top 5th token. Logit: 20.98 Prob:  4.74% Token: | apples|
Top 6th token. Logit: 20.60 Prob:  3.26% Token: | some|
Top 7th token. Logit: 20.04 Prob:  1.86% Token: | many|
Top 8th token. Logit: 19.86 Prob:  1.55% Token: | all|
Top 9th token. Logit: 19.39 Prob:  0.97% Token: | two|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Tim', '.', ' Tim', ' loved', ' p', 'ears', '.', ' He', ' climbed', ' an', ' pear', ' tree', ' and', ' picked']
Tokenized answer: [' an']


Top 0th token. Logit: 24.18 Prob: 54.64% Token: | the|
Top 1th token. Logit: 23.26 Prob: 21.79% Token: | a|
Top 2th token. Logit: 22.08 Prob:  6.72% Token: | lots|
Top 3th token. Logit: 21.60 Prob:  4.15% Token: | one|
Top 4th token. Logit: 21.17 Prob:  2.68% Token: | some|
Top 5th token. Logit: 20.89 Prob:  2.03% Token: | p|
Top 6th token. Logit: 20.87 Prob:  2.01% Token: | many|
Top 7th token. Logit: 20.35 Prob:  1.18% Token: | all|
Top 8th token. Logit: 19.53 Prob:  0.52% Token: | an|
Top 9th token. Logit: 19.28 Prob:  0.41% Token: | as|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Tim', '.', ' Tim', ' loved', ' oranges', '.', ' He', ' climbed', ' an', ' orange', ' tree', ' and', ' picked']
Tokenized answer: [' an']


Top 0th token. Logit: 23.53 Prob: 30.05% Token: | an|
Top 1th token. Logit: 23.11 Prob: 19.82% Token: | the|
Top 2th token. Logit: 22.68 Prob: 12.89% Token: | one|
Top 3th token. Logit: 22.63 Prob: 12.20% Token: | oranges|
Top 4th token. Logit: 22.60 Prob: 11.88% Token: | a|
Top 5th token. Logit: 21.17 Prob:  2.83% Token: | some|
Top 6th token. Logit: 20.91 Prob:  2.19% Token: | lots|
Top 7th token. Logit: 20.50 Prob:  1.44% Token: | all|
Top 8th token. Logit: 20.00 Prob:  0.88% Token: | it|
Top 9th token. Logit: 19.89 Prob:  0.79% Token: | up|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Tim', '.', ' Tim', ' loved', ' bananas', '.', ' He', ' climbed', ' an', ' banana', ' tree', ' and', ' picked']
Tokenized answer: [' an']


Top 0th token. Logit: 23.39 Prob: 37.74% Token: | a|
Top 1th token. Logit: 22.88 Prob: 22.59% Token: | the|
Top 2th token. Logit: 22.86 Prob: 22.29% Token: | one|
Top 3th token. Logit: 21.14 Prob:  3.99% Token: | some|
Top 4th token. Logit: 21.06 Prob:  3.69% Token: | lots|
Top 5th token. Logit: 20.14 Prob:  1.47% Token: | all|
Top 6th token. Logit: 19.85 Prob:  1.09% Token: | two|
Top 7th token. Logit: 19.78 Prob:  1.02% Token: | bananas|
Top 8th token. Logit: 19.57 Prob:  0.83% Token: | an|
Top 9th token. Logit: 19.51 Prob:  0.78% Token: | many|


In [4]:
dataset = load_dataset("roneneldan/TinyStories", split='validation')
dataset = [x["text"] for x in dataset]

Downloading readme:   0%|          | 0.00/946 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/249M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [5]:
A = model.to_single_token(" a")
AN = model.to_single_token(" an")

In [22]:
prompts = []
fruits = ["apple", "orange", "avocado", "apricot", "olive"] + ["pear", "banana", "plum", "lemon", "cherry"]
#a_fruits = 
for fruit in fruits:
    prompts.append(f'Once upon a time there was a boy named Tim. Tim loved {fruit}s. He climbed {"an" if fruit[0] in ["a", "e", "i", "o", "u"] else "a"} {fruit} tree and picked')

tokens = model.to_tokens(prompts)
final_index = (tokens!=model.tokenizer.pad_token_id).sum(-1)


unembed_dir = model.W_U[:, AN] - model.W_U[:, A]
logits, cache = model.run_with_cache(model.to_tokens(prompts))
resid_stack, resid_labels = cache.get_full_resid_decomposition(expand_neurons=False, apply_ln=True, return_labels=True)
final_resid_stack = resid_stack[:, np.arange(len(prompts)), final_index, :]
(final_resid_stack @ unembed_dir).shape

from neel_plotly import *
line((final_resid_stack @ unembed_dir).T, line_labels=[fruit for fruit in fruits], x=resid_labels, title="Direct Logit Diff Attribution")

Tried to stack head results when they weren't cached. Computing head results now


In [23]:
logits, cache = model.run_with_cache(model.to_tokens(prompts))
final_ln_scale = cache["scale"][np.arange(len(prompts)), final_index, 0]
unembed_dir = model.W_U[:, AN] - model.W_U[:, A]

print(final_ln_scale.shape)
for layer in [3]:
    neuron_acts = cache["post", layer][np.arange(len(prompts)), final_index, :]
    neuron_wdla = (model.blocks[layer].mlp.W_out @ unembed_dir) / final_ln_scale[:, None]
    line(neuron_acts * neuron_wdla, line_labels=[fruit for fruit in fruits], title=f"Neuron DLA in L{layer}")

torch.Size([10])


In [19]:
logits, cache = model.run_with_cache(model.to_tokens(prompts))
final_ln_scale = cache["scale"][np.arange(len(prompts)), final_index, 0]
unembed_dir = model.W_U[:, AN] - model.W_U[:, A]

print(final_ln_scale.shape)
for layer in [3]:
    neuron_acts = cache["post", layer][np.arange(len(prompts)), final_index, :]
    neuron_wdla = (model.blocks[layer].mlp.W_out @ unembed_dir) / final_ln_scale[:, None]
    line(neuron_acts * neuron_wdla, line_labels=[fruit for fruit in fruits], title=f"Neuron DLA in L{layer}")

torch.Size([5])


In [23]:
prompts = []
animals = ["alligators", "owls", "octopuses", "eagles", "elephants", "bears", "foxes", "lions", "tigers"]
#a_animals = 
for animal in animals:
    prompts.append(f"Once upon a time there was a boy named Tim. Tim loved {animal}s. He went to the zoo to see")

tokens = model.to_tokens(prompts)
final_index = (tokens!=model.tokenizer.pad_token_id).sum(-1)


unembed_dir = model.W_U[:, AN] - model.W_U[:, A]
logits, cache = model.run_with_cache(model.to_tokens(prompts))
resid_stack, resid_labels = cache.get_full_resid_decomposition(expand_neurons=False, apply_ln=True, return_labels=True)
final_resid_stack = resid_stack[:, np.arange(len(prompts)), final_index, :]
(final_resid_stack @ unembed_dir).shape

from neel_plotly import *
line((final_resid_stack @ unembed_dir).T, line_labels=[fruit for fruit in animals], x=resid_labels, title="Direct Logit Diff Attribution")

Tried to stack head results when they weren't cached. Computing head results now


## Dataset investigation

In [6]:
test = load_dataset("roneneldan/TinyStories", split='validation')
test = [x["text"] for x in test]

train = load_dataset("roneneldan/TinyStories", split='train')
train = [x["text"] for x in train]

Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


In [7]:
print(len(test), len(train))

21990 2119719


In [15]:
train_data = []
for x in tqdm(train):
    x = model.to_tokens(x).flatten().cpu()
    if len(x) > 150:
        train_data.append(x[50:150])
    if len(train_data) >= 5000:
        break

test_data = []
for x in tqdm(test):
    x = model.to_tokens(x).flatten().cpu()
    if len(x) > 150:
        test_data.append(x[50:150])
    if len(test_data) >= 5000:
        break

train_data = torch.stack(train_data)
test_data = torch.stack(test_data)
y = torch.cat([torch.zeros(train_data.shape[0]), torch.ones(test_data.shape[0])])
x = torch.cat([train_data, test_data])
print(x.shape, y.shape)

  0%|          | 0/2119719 [00:00<?, ?it/s]

  0%|          | 0/21990 [00:00<?, ?it/s]

torch.Size([10000, 100]) torch.Size([10000])


In [30]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn import preprocessing
## Baseline probe on tokens
def train_linear_probe(x, y):
    x_numpy = x.numpy()
    y_numpy = y.numpy()
    x_numpy = preprocessing.StandardScaler().fit_transform(x_numpy)
    # Split the dataset into training and testing sets
    x_train, x_test, y_train, y_test = train_test_split(x_numpy, y_numpy, test_size=0.2, random_state=42)
    # Create and train a logistic regression model (since it's binary classification)
    model = LogisticRegression(max_iter=1000)
    model.fit(x_train, y_train)
    # Make predictions on the test set
    y_pred = model.predict(x_test)
    # Compute the F1 score
    f1 = f1_score(y_test, y_pred)
    print("F1 Score:", f1)
train_linear_probe(x, y)

F1 Score: 0.5189753320683111


In [24]:
haystack_utils.clean_cache()

In [28]:
new_x = []

for example in tqdm(x):
    _, cache = model.run_with_cache(example)
    res = cache["resid_post", 3][0, -5:-1]
    new_x.append(res.cpu())
new_x = torch.stack(new_x)
print(new_x.shape)

  0%|          | 0/10000 [00:00<?, ?it/s]

torch.Size([10000, 4, 768])


In [31]:
train_linear_probe(new_x.view(10000, -1), y)

F1 Score: 0.48068238835925736


## Factual recall

In [5]:
prompt = 'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "What color is a monkey?", he asked his mom. "A monkey is'
utils.test_prompt(prompt, " brown", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loves', ' to', ' learn', ' about', ' different', ' animals', '.', ' "', 'What', ' color', ' is', ' a', ' monkey', '?",', ' he', ' asked', ' his', ' mom', '.', ' "', 'A', ' monkey', ' is']
Tokenized answer: [' brown']


Top 0th token. Logit: 23.03 Prob: 50.34% Token: | a|
Top 1th token. Logit: 22.00 Prob: 17.96% Token: | an|
Top 2th token. Logit: 21.79 Prob: 14.48% Token: | brown|
Top 3th token. Logit: 21.04 Prob:  6.85% Token: | black|
Top 4th token. Logit: 19.54 Prob:  1.53% Token: | orange|
Top 5th token. Logit: 19.49 Prob:  1.45% Token: | red|
Top 6th token. Logit: 19.33 Prob:  1.24% Token: | yellow|
Top 7th token. Logit: 18.81 Prob:  0.74% Token: | so|
Top 8th token. Logit: 18.75 Prob:  0.70% Token: | very|
Top 9th token. Logit: 18.46 Prob:  0.52% Token: | bright|


In [20]:
# Eat
prompt = 'Once upon a time there was a boy named Bob. Bob loved to learn what different animals eat. "Birds eat'
utils.test_prompt(prompt, " worms", model)
prompt = 'Once upon a time there was a boy named Bob. Bob loved to learn what different animals eat. "Squirrels eat'
utils.test_prompt(prompt, " nuts", model)
prompt = 'Once upon a time there was a boy named Bob. Bob loved to learn what different animals eat. "Monkeys eat'
utils.test_prompt(prompt, " nuts", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' to', ' learn', ' what', ' different', ' animals', ' eat', '.', ' "', 'B', 'irds', ' eat']
Tokenized answer: [' worms']


Top 0th token. Logit: 17.36 Prob: 15.38% Token: | in|
Top 1th token. Logit: 16.93 Prob:  9.98% Token: | worms|
Top 2th token. Logit: 16.55 Prob:  6.82% Token: | nuts|
Top 3th token. Logit: 16.26 Prob:  5.11% Token: | squirrel|
Top 4th token. Logit: 16.09 Prob:  4.29% Token: | leaves|
Top 5th token. Logit: 15.88 Prob:  3.50% Token: | bugs|
Top 6th token. Logit: 15.86 Prob:  3.42% Token: | food|
Top 7th token. Logit: 15.80 Prob:  3.23% Token: |,"|
Top 8th token. Logit: 15.66 Prob:  2.81% Token: | grass|
Top 9th token. Logit: 15.45 Prob:  2.27% Token: | lettuce|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' to', ' learn', ' what', ' different', ' animals', ' eat', '.', ' "', 'Squ', 'irrel', 's', ' eat']
Tokenized answer: [' nuts']


Top 0th token. Logit: 20.87 Prob: 50.76% Token: | nuts|
Top 1th token. Logit: 20.32 Prob: 29.47% Token: | leaves|
Top 2th token. Logit: 17.26 Prob:  1.37% Token: | grass|
Top 3th token. Logit: 17.15 Prob:  1.23% Token: | the|
Top 4th token. Logit: 17.11 Prob:  1.18% Token: | a|
Top 5th token. Logit: 16.98 Prob:  1.04% Token: | apples|
Top 6th token. Logit: 16.98 Prob:  1.04% Token: | squirrel|
Top 7th token. Logit: 16.87 Prob:  0.93% Token: | oats|
Top 8th token. Logit: 16.74 Prob:  0.82% Token: | bugs|
Top 9th token. Logit: 16.54 Prob:  0.67% Token: | food|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' to', ' learn', ' what', ' different', ' animals', ' eat', '.', ' "', 'Mon', 'keys', ' eat']
Tokenized answer: [' nuts']


Top 0th token. Logit: 21.38 Prob: 64.68% Token: | apples|
Top 1th token. Logit: 20.28 Prob: 21.46% Token: | bananas|
Top 2th token. Logit: 18.52 Prob:  3.70% Token: | leaves|
Top 3th token. Logit: 17.49 Prob:  1.32% Token: |,"|
Top 4th token. Logit: 16.61 Prob:  0.54% Token: | banana|
Top 5th token. Logit: 16.45 Prob:  0.47% Token: |,|
Top 6th token. Logit: 16.34 Prob:  0.42% Token: | a|
Top 7th token. Logit: 16.30 Prob:  0.40% Token: | fruits|
Top 8th token. Logit: 16.18 Prob:  0.35% Token: | apple|
Top 9th token. Logit: 16.17 Prob:  0.35% Token: | meat|


In [23]:
# Sound
prompt = 'Once upon a time there was a boy named Bob. Bob loved to learn where different animals live. "Birds live in'
utils.test_prompt(prompt, " worms", model)
prompt = 'Once upon a time there was a boy named Bob. Bob loved to learn where different animals live. "Squirrels live in'
utils.test_prompt(prompt, " nuts", model)
prompt = 'Once upon a time there was a boy named Bob. Bob loved to learn where different animals live. "Monkeys live in'
utils.test_prompt(prompt, " nuts", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' to', ' learn', ' where', ' different', ' animals', ' live', '.', ' "', 'B', 'irds', ' live', ' in']
Tokenized answer: [' worms']


Top 0th token. Logit: 20.88 Prob: 76.24% Token: | the|
Top 1th token. Logit: 18.79 Prob:  9.44% Token: | trees|
Top 2th token. Logit: 17.71 Prob:  3.21% Token: | a|
Top 3th token. Logit: 17.56 Prob:  2.76% Token: | nests|
Top 4th token. Logit: 16.88 Prob:  1.40% Token: | cages|
Top 5th token. Logit: 16.41 Prob:  0.87% Token: | houses|
Top 6th token. Logit: 15.55 Prob:  0.37% Token: | parks|
Top 7th token. Logit: 15.50 Prob:  0.35% Token: | rivers|
Top 8th token. Logit: 15.40 Prob:  0.32% Token: | big|
Top 9th token. Logit: 15.35 Prob:  0.30% Token: | an|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' to', ' learn', ' where', ' different', ' animals', ' live', '.', ' "', 'Squ', 'irrel', 's', ' live', ' in']
Tokenized answer: [' nuts']


Top 0th token. Logit: 21.53 Prob: 61.82% Token: | the|
Top 1th token. Logit: 20.47 Prob: 21.36% Token: | trees|
Top 2th token. Logit: 19.57 Prob:  8.69% Token: | a|
Top 3th token. Logit: 17.83 Prob:  1.54% Token: | houses|
Top 4th token. Logit: 17.31 Prob:  0.91% Token: | nests|
Top 5th token. Logit: 17.14 Prob:  0.77% Token: | an|
Top 6th token. Logit: 16.59 Prob:  0.44% Token: | homes|
Top 7th token. Logit: 16.44 Prob:  0.38% Token: | tree|
Top 8th token. Logit: 16.31 Prob:  0.33% Token: | big|
Top 9th token. Logit: 16.22 Prob:  0.31% Token: | groups|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' to', ' learn', ' where', ' different', ' animals', ' live', '.', ' "', 'Mon', 'keys', ' live', ' in']
Tokenized answer: [' nuts']


Top 0th token. Logit: 21.02 Prob: 77.31% Token: | the|
Top 1th token. Logit: 18.47 Prob:  6.04% Token: | trees|
Top 2th token. Logit: 18.46 Prob:  5.99% Token: | a|
Top 3th token. Logit: 17.17 Prob:  1.65% Token: | China|
Top 4th token. Logit: 16.75 Prob:  1.08% Token: | Africa|
Top 5th token. Logit: 16.46 Prob:  0.81% Token: | tree|
Top 6th token. Logit: 16.38 Prob:  0.75% Token: | big|
Top 7th token. Logit: 16.20 Prob:  0.62% Token: | houses|
Top 8th token. Logit: 15.71 Prob:  0.38% Token: | Mexico|
Top 9th token. Logit: 15.46 Prob:  0.30% Token: | different|


In [84]:
data = [
    ("cows", "grass", "field", "brown"),
    ("elephants", "leaves", "desert", "gray"),
    ("rabbits", "carrots", "garden", "white"),
    ("pigs", "potatos", "farm", "pink"),
    ("monkeys", "bananas", "jungle", "brown"),
    ("birds", "worms", "tree", "black"),
    ("crabs", "fish", "ocean", "red"),
    ("zebra", "grass", "savannah", "black and white"),
    ("sharks", "fish", "ocean", "gray"),
    ("wolfs", "meat", "forest", "brown"),
    ("tiger", "meat", "jungle", "orange"),
    ("frog", "flies", "pond", "green"),
    ("flamingo", "fish", "lake", "pink"),
]

In [118]:
animal = "crabs"
prompt = f'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "Where do {animal} live?", his mother asked. "They live'
length = len(prompt)
answers = []
for i in range(10):
    answer = model.generate(prompt, max_new_tokens=15, temperature=1, verbose=False)
    answer = answer[length:-1]
    answers.append(answer.split(".")[0])    
for answer in answers:
    print(answer)

 in the ocean", Bob replied
 in the ocean," Bob explained
 in the shell top of the sea", Bob replied
 in the ocean," answered Bob
 in the sea
"They like to live in the sand", Bob ha
 in the sea!" Bob said
 in the sea
 in the ocean", he replied
 in the ocean," Bob answered
 inside the shell," replied Bob


In [119]:
places = ["ocean", "garden", "forest", "jungle", "pond"]

In [120]:
index = -1
animal = data[index][0]
capitalize = lambda name: name[0].upper() + name[1:]
eat_prompt = f'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "What do {animal} eat?", his mother asked. "They eat'
live_prompt = f'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "Where do {animal} live?", his mother asked. "They live in the'
color_prompt = f'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "What color are {animal}?", his mother asked. "They are'
utils.test_prompt(eat_prompt, " "+data[index][1], model)
utils.test_prompt(live_prompt, " "+data[index][2], model)
utils.test_prompt(color_prompt, " "+data[index][3], model)

NameError: name 'data' is not defined

In [122]:
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache, dir) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einops.einsum(scaled_residual_stack, dir, "components batch d_model, batch d_model -> components")/len(prompts)


animal = "cows"
answer = " grass"
prompt = f'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "What do {animal} eat?", his mother asked. "They eat'
prompts = [prompt]
answer_token = model.to_single_token(answer)
_, cache = model.run_with_cache(prompts)
answer_dir = model.tokens_to_residual_directions(answer_token).unsqueeze(0)
print("Answer dir shape", answer_dir.shape)
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer = -1, pos_slice=-1)
print("Scaled residual stack shape", scaled_residual_stack.shape)
logit_lens_logit_diffs =  einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, answer_dir)/len(prompts)
print("Logit lens logit diffs shape", logit_lens_logit_diffs.shape)
print("Logit lens logit diffs", logit_lens_logit_diffs)
px.line(y=logit_lens_logit_diffs.cpu().numpy(), x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="DLA From Accumulate Residual Stream")


per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache, answer_dir)
px.line(y=per_layer_logit_diffs.cpu().numpy(), hover_name=labels, title="DLA From Each Layer")

Answer dir shape torch.Size([1, 768])
Scaled residual stack shape torch.Size([9, 1, 768])
Logit lens logit diffs shape torch.Size([9])
Logit lens logit diffs tensor([ 0.3530,  1.1197,  3.7302,  4.6555,  4.9084,  7.4523,  8.3859, 15.0320,
        21.8727], device='cuda:0')


In [123]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache, answer_dir)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
px.imshow(per_head_logit_diffs.cpu().numpy(), labels={"x":"Head", "y":"Layer"}, title="DLA From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


In [124]:
per_layer_residual, labels = cache.get_full_resid_decomposition(layer=-1, pos_slice=-1, return_labels=True, expand_neurons=True)
print(per_layer_residual.shape)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache, answer_dir)
px.line(y=per_layer_logit_diffs.cpu().numpy(), hover_name=labels, title="Component wise DLA")

torch.Size([12355, 1, 768])


In [69]:
per_layer_residual, labels
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache, answer_dir)
neuron_slice = np.s_[64:12352]
neuron_labels = labels[neuron_slice]
neuron_dla = per_layer_logit_diffs[neuron_slice]

top_values, top_indices = torch.topk(neuron_dla, 15, dim=0)
for label in top_indices:
    print(neuron_labels[label])

L3N1372
L3N1657
L0N2995
L3N1669
L3N1062
L3N2311
L3N1451
L3N2675
L3N271
L3N1714
L3N724
L3N1802
L3N859
L2N2679
L3N1244


In [26]:
preprompt = 'Once upon a time there was a boy named Bob. Bob loves to learn about different food. When he did not know something, he asked his mother for help.'
prompt = preprompt + ' Are cucumbers a fruit?", he asked his mother. "'
utils.test_prompt(prompt, "Yes", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loves', ' to', ' learn', ' about', ' different', ' food', '.', ' When', ' he', ' did', ' not', ' know', ' something', ',', ' he', ' asked', ' his', ' mother', ' for', ' help', '.', ' Are', ' cuc', 'umbers', ' a', ' fruit', '?",', ' he', ' asked', ' his', ' mother', '.', ' "']
Tokenized answer: [' Yes']


Top 0th token. Logit: 18.47 Prob: 29.53% Token: |Yes|
Top 1th token. Logit: 17.61 Prob: 12.58% Token: |No|
Top 2th token. Logit: 17.28 Prob:  9.02% Token: |Oh|
Top 3th token. Logit: 17.23 Prob:  8.53% Token: |That|
Top 4th token. Logit: 16.69 Prob:  5.01% Token: |Those|
Top 5th token. Logit: 16.36 Prob:  3.60% Token: |Of|
Top 6th token. Logit: 16.07 Prob:  2.67% Token: |Sure|
Top 7th token. Logit: 15.90 Prob:  2.25% Token: |Why|
Top 8th token. Logit: 15.56 Prob:  1.61% Token: |Eat|
Top 9th token. Logit: 15.49 Prob:  1.50% Token: |Ve|


## Color of animals

In [27]:
test = '"Strawberries are red", Bob said. "Correct", his mother answered. "Do you also know what color the sky is?" "Of course", Bob answered proudly, "it is'

In [28]:
utils.test_prompt(test, " blue", model)

Tokenized prompt: ['<|endoftext|>', '"', 'St', 'raw', 'berries', ' are', ' red', '",', ' Bob', ' said', '.', ' "', 'Correct', '",', ' his', ' mother', ' answered', '.', ' "', 'Do', ' you', ' also', ' know', ' what', ' color', ' the', ' sky', ' is', '?"', ' "', 'Of', ' course', '",', ' Bob', ' answered', ' proudly', ',', ' "', 'it', ' is']
Tokenized answer: [' blue']


Top 0th token. Logit: 23.22 Prob: 87.59% Token: | blue|
Top 1th token. Logit: 19.81 Prob:  2.90% Token: | red|
Top 2th token. Logit: 19.81 Prob:  2.89% Token: | the|
Top 3th token. Logit: 18.82 Prob:  1.07% Token: |!"|
Top 4th token. Logit: 18.05 Prob:  0.50% Token: | yellow|
Top 5th token. Logit: 17.81 Prob:  0.39% Token: | bright|
Top 6th token. Logit: 17.60 Prob:  0.32% Token: | brown|
Top 7th token. Logit: 17.56 Prob:  0.31% Token: | pink|
Top 8th token. Logit: 17.51 Prob:  0.29% Token: |."|
Top 9th token. Logit: 17.41 Prob:  0.26% Token: |!|


In [29]:
objects = ["sky", "grass", "the sun", "an apple", "a carrot", "coal", "a strawberry", "the ocean", "a pumpkin", "a lemon", "a tomato"]
colors = ["blue", "green", "yellow", "red", "orange", "black", "red", "blue", "orange", "yellow", "red"]

def make_prompt(object, color):
    prompt = f"'Strawberries are red', Bob said. 'Correct', his mother answered. 'Do you also know what color {object} is?' 'Of course', Bob answered proudly, 'it is"
    answer = f" {color}"
    return prompt, answer

print(make_prompt(objects[0], colors[0]))

("'Strawberries are red', Bob said. 'Correct', his mother answered. 'Do you also know what color sky is?' 'Of course', Bob answered proudly, 'it is", ' blue')


In [73]:
def test_prompt(prompt, answer, model, object):
    logits = model(prompt)
    answer_token = model.to_single_token(answer)
    pred = logits[0, -1, :].softmax(dim=-1)
    prob = pred[answer_token].item()
    rank = (pred >= pred[answer_token]).sum().item()
    print(f"{object} -> {answer}: Rank {rank} (p={prob:.2f})")

for object, color in zip(objects, colors):
    prompt, answer = make_prompt(object, color)
    test_prompt(prompt, answer, model, object)

sky ->  blue: Rank 1 (p=0.32)
grass ->  green: Rank 3 (p=0.10)
the sun ->  yellow: Rank 9 (p=0.02)
an apple ->  red: Rank 1 (p=0.37)
a carrot ->  orange: Rank 5 (p=0.05)
coal ->  black: Rank 2 (p=0.14)
a strawberry ->  red: Rank 1 (p=0.28)
the ocean ->  blue: Rank 1 (p=0.30)
a pumpkin ->  orange: Rank 4 (p=0.09)
a lemon ->  yellow: Rank 2 (p=0.11)
a tomato ->  red: Rank 1 (p=0.43)


## Sizes of animals

In [191]:
def make_prompt(animal_1, animal_2):
    prompt = f'Once upon a time there was a boy named Bob. Bob loved animals and was proud that he knew all the sizes of his favorite animals. "A tiger is larger than a cat", Bob said. "Correct", his mother answered. "Do you also know if {animal_1} is larger than {animal_2}?" "Of course", Bob answered proudly, "The'
    return prompt

In [192]:
animals = ["dog", "cat", "cow", "horse", "sheep", "elephant", "lion", "tiger", "bear", "duck", "chicken", "fish", "turtle", "rabbit", "monkey"]

for animal in animals:
    haystack_utils.print_tokenized_word(" "+animal, model)

[' dog']
[' cat']
[' cow']
[' horse']
[' sheep']
[' elephant']
[' lion']
[' tiger']
[' bear']
[' duck']
[' chicken']
[' fish']
[' turtle']
[' rabbit']
[' monkey']


In [193]:
animal_pairs = [
    ("a horse", "a dog", " horse", " dog"),
    ("a bear", "a duck", " bear", " duck"),
    ("an elephant", "a rabbit", " elephant", " rabbit"),
    ("a sheep", "a bug", " sheep", " bug"),
    ("a cow", "a fish", " cow", " fish"),
]

def get_ranks(prompt, correct, incorrect):
    logits = model(prompt)
    answer_token = model.to_single_token(correct)
    incorrect_token = model.to_single_token(incorrect)
    pred = logits[0, -1, :].softmax(dim=-1)
    prob = pred[answer_token].item()
    rank = (pred >= pred[answer_token]).sum().item()
    incorrect_prob = pred[incorrect_token].item()
    incorrect_rank = (pred >= pred[incorrect_token]).sum().item()
    print(f"{correct}>{incorrect}: Rank {rank}/{incorrect_rank} (p={prob:.2f}, {incorrect_prob:.2f})")


for animal_1, animal_2, correct, incorrect in animal_pairs:
    prompt_1 = make_prompt(animal_1, animal_2)
    prompt_2 = make_prompt(animal_2, animal_1)
    correct_token = model.to_single_token(correct)
    incorrect_token = model.to_single_token(incorrect)

    get_ranks(prompt_1, correct, incorrect)
    get_ranks(prompt_2, correct, incorrect)

 horse> dog: Rank 1/5 (p=0.81, 0.01)
 horse> dog: Rank 54/1 (p=0.00, 0.34)
 bear> duck: Rank 1/2 (p=0.40, 0.13)
 bear> duck: Rank 4/1 (p=0.00, 0.91)
 elephant> rabbit: Rank 1/5 (p=0.74, 0.02)
 elephant> rabbit: Rank 2/1 (p=0.07, 0.74)
 sheep> bug: Rank 1/223 (p=0.32, 0.00)
 sheep> bug: Rank 271/3 (p=0.00, 0.09)
 cow> fish: Rank 1/27 (p=0.79, 0.00)
 cow> fish: Rank 9/7 (p=0.01, 0.01)


## Random failures

Ideas
- Opposites
- Categorize fruit or vegetable
- Match animals to habitates
- analogies

In [77]:
utils.test_prompt("A bird has wings and a cat has", " paws", model)

Tokenized prompt: ['<|endoftext|>', 'A', ' bird', ' has', ' wings', ' and', ' a', ' cat', ' has']
Tokenized answer: [' paws']


Top 0th token. Logit: 17.92 Prob: 43.50% Token: | a|
Top 1th token. Logit: 16.25 Prob:  8.20% Token: | wings|
Top 2th token. Logit: 15.94 Prob:  6.03% Token: | no|
Top 3th token. Logit: 14.92 Prob:  2.16% Token: | stripes|
Top 4th token. Logit: 14.69 Prob:  1.72% Token: | many|
Top 5th token. Logit: 14.63 Prob:  1.63% Token: | feathers|
Top 6th token. Logit: 14.48 Prob:  1.39% Token: | an|
Top 7th token. Logit: 14.42 Prob:  1.32% Token: | been|
Top 8th token. Logit: 14.36 Prob:  1.23% Token: | four|
Top 9th token. Logit: 14.20 Prob:  1.06% Token: | big|


In [196]:
prompt = "Once upon a time there was a boy named Bob. Bob loved playing games with words. 'Let's play a game' his mother said. 'I will tell you a word and you have to tell me the opposite. Ready?' 'Yes! That sounds fun', answered Bob." + \
    "His mother nodded. 'Ok, great. What is the opposite of loud?' Bob answered immediately: 'It's"

utils.test_prompt(prompt, " cold", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' playing', ' games', ' with', ' words', '.', " '", 'Let', "'s", ' play', ' a', ' game', "'", ' his', ' mother', ' said', '.', " '", 'I', ' will', ' tell', ' you', ' a', ' word', ' and', ' you', ' have', ' to', ' tell', ' me', ' the', ' opposite', '.', ' Ready', "?'", " '", 'Yes', '!', ' That', ' sounds', ' fun', "',", ' answered', ' Bob', '.', 'His', ' mother', ' nodded', '.', " '", 'Ok', ',', ' great', '.', ' What', ' is', ' the', ' opposite', ' of', ' loud', "?'", ' Bob', ' answered', ' immediately', ':', " '", 'It', "'s"]
Tokenized answer: [' cold']


Top 0th token. Logit: 16.88 Prob: 25.39% Token: | a|
Top 1th token. Logit: 16.30 Prob: 14.20% Token: | '|
Top 2th token. Logit: 15.63 Prob:  7.29% Token: | the|
Top 3th token. Logit: 15.48 Prob:  6.30% Token: | called|
Top 4th token. Logit: 15.41 Prob:  5.84% Token: | two|
Top 5th token. Logit: 15.40 Prob:  5.78% Token: | so|
Top 6th token. Logit: 14.42 Prob:  2.18% Token: | not|
Top 7th token. Logit: 14.41 Prob:  2.15% Token: | very|
Top 8th token. Logit: 14.21 Prob:  1.77% Token: | my|
Top 9th token. Logit: 13.86 Prob:  1.24% Token: | an|


In [99]:
prompt = "I know a lot about animals. For example, cows live in the"
prompt = "I know a lot about animals. For example, monkeys like to eat"
utils.test_prompt(prompt, " grass", model)

Tokenized prompt: ['<|endoftext|>', 'I', ' know', ' a', ' lot', ' about', ' animals', '.', ' For', ' example', ',', ' monkeys', ' like', ' to', ' eat']
Tokenized answer: [' grass']


Top 0th token. Logit: 24.13 Prob: 94.55% Token: | bananas|
Top 1th token. Logit: 19.70 Prob:  1.13% Token: | fruit|
Top 2th token. Logit: 19.17 Prob:  0.67% Token: | apples|
Top 3th token. Logit: 18.68 Prob:  0.41% Token: | fruits|
Top 4th token. Logit: 18.68 Prob:  0.41% Token: | peanuts|
Top 5th token. Logit: 18.30 Prob:  0.28% Token: | banana|
Top 6th token. Logit: 18.28 Prob:  0.27% Token: | the|
Top 7th token. Logit: 18.15 Prob:  0.24% Token: | nuts|
Top 8th token. Logit: 18.07 Prob:  0.22% Token: | a|
Top 9th token. Logit: 17.23 Prob:  0.10% Token: | leaves|


In [198]:
prompt = "Once upon a time there was a girl named Sarah. Sarah loved animals and really wanted a pet. Sarah wanted to get a giraffe or a rabbit. Her mother hates giraffes so she got Sarah a"
utils.test_prompt(prompt, " rabbit", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' girl', ' named', ' Sarah', '.', ' Sarah', ' loved', ' animals', ' and', ' really', ' wanted', ' a', ' pet', '.', ' Sarah', ' wanted', ' to', ' get', ' a', ' gir', 'affe', ' or', ' a', ' rabbit', '.', ' Her', ' mother', ' hates', ' gir', 'aff', 'es', ' so', ' she', ' got', ' Sarah', ' a']
Tokenized answer: [' rabbit']


Top 0th token. Logit: 17.73 Prob: 14.75% Token: | dog|
Top 1th token. Logit: 17.70 Prob: 14.34% Token: | gir|
Top 2th token. Logit: 16.65 Prob:  5.00% Token: | monkey|
Top 3th token. Logit: 16.46 Prob:  4.14% Token: | par|
Top 4th token. Logit: 16.44 Prob:  4.05% Token: | toy|
Top 5th token. Logit: 16.41 Prob:  3.92% Token: | pig|
Top 6th token. Logit: 16.32 Prob:  3.60% Token: | big|
Top 7th token. Logit: 16.23 Prob:  3.30% Token: | stuffed|
Top 8th token. Logit: 16.07 Prob:  2.81% Token: | pet|
Top 9th token. Logit: 15.95 Prob:  2.49% Token: | puppy|


In [200]:
prompt = 'Once upon a time there was a girl named Alice. Alice loved learning about different animals. "Are chicken smaller than cows?", Alice asked her mother. "'

yes_token = model.to_single_token("Yes")
no_token = model.to_single_token("No")
def get_yes_no_logits(prompt):
    logits = model(prompt, return_type="logits")
    yes_logits = logits[0, -1, yes_token].item()
    no_logits = logits[0, -1, no_token].item()
    return yes_logits - no_logits

print(get_yes_no_logits(prompt))

-0.8930225372314453


In [135]:
utils.test_prompt(prompt, "No", model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '"', 'Are', ' cows', ' smaller', ' than', ' chicken', '?",', ' Alice', ' asked', ' her', ' mother', '.', ' "']
Tokenized answer: ['No']


Top 0th token. Logit: 18.77 Prob: 26.39% Token: |Yes|
Top 1th token. Logit: 18.10 Prob: 13.57% Token: |No|
Top 2th token. Logit: 17.47 Prob:  7.17% Token: |What|
Top 3th token. Logit: 17.30 Prob:  6.09% Token: |I|
Top 4th token. Logit: 17.28 Prob:  5.93% Token: |Why|
Top 5th token. Logit: 16.89 Prob:  4.02% Token: |That|
Top 6th token. Logit: 16.87 Prob:  3.96% Token: |It|
Top 7th token. Logit: 16.62 Prob:  3.08% Token: |We|
Top 8th token. Logit: 16.46 Prob:  2.62% Token: |The|
Top 9th token. Logit: 16.17 Prob:  1.97% Token: |Well|


In [202]:
prompt = "Once upon a time there was a girl named Alice. Alice had a carrot and a banana. She ate the banana. Then, she only had the"
#prompt = "Once upon a time there was a girl named Alice. Alice had a carrot and a banana. She gave the banana to Bob. Then, Bob had the"
utils.test_prompt(prompt, " banana", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' girl', ' named', ' Alice', '.', ' Alice', ' had', ' a', ' carrot', ' and', ' a', ' banana', '.', ' She', ' ate', ' the', ' banana', '.', ' Then', ',', ' she', ' only', ' had', ' the']
Tokenized answer: [' banana']


Top 0th token. Logit: 18.71 Prob: 20.61% Token: | apple|
Top 1th token. Logit: 18.05 Prob: 10.68% Token: | small|
Top 2th token. Logit: 17.12 Prob:  4.22% Token: | peel|
Top 3th token. Logit: 17.10 Prob:  4.13% Token: | little|
Top 4th token. Logit: 16.97 Prob:  3.61% Token: | carrot|
Top 5th token. Logit: 16.55 Prob:  2.39% Token: | two|
Top 6th token. Logit: 16.10 Prob:  1.52% Token: | left|
Top 7th token. Logit: 16.05 Prob:  1.45% Token: | original|
Top 8th token. Logit: 16.04 Prob:  1.43% Token: | banana|
Top 9th token. Logit: 15.97 Prob:  1.34% Token: | toy|


In [204]:
prompt = 'Once upon a time there was a boy named Jack. On Sunday, Jack visited a lot of family members. In the morning, Jack visited his grandmother.' + \
    ' At noon, he played with his father.' +\
    ' In the evening, he ate dinner with his mother.' + \
    ' The next day in school, he told his best friend about his weekend.' + \
    ' "Yesterday evening, I had dinner with my'

utils.test_prompt(prompt, " mother", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Jack', '.', ' On', ' Sunday', ',', ' Jack', ' visited', ' a', ' lot', ' of', ' family', ' members', '.', ' In', ' the', ' morning', ',', ' Jack', ' visited', ' his', ' grandmother', '.', ' At', ' noon', ',', ' he', ' played', ' with', ' his', ' father', '.', ' In', ' the', ' evening', ',', ' he', ' ate', ' dinner', ' with', ' his', ' mother', '.', ' The', ' next', ' day', ' in', ' school', ',', ' he', ' told', ' his', ' best', ' friend', ' about', ' his', ' weekend', '.', ' "', 'Yesterday', ' evening', ',', ' I', ' had', ' dinner', ' with', ' my']
Tokenized answer: [' mother']


Top 0th token. Logit: 21.07 Prob: 39.04% Token: | grandmother|
Top 1th token. Logit: 20.61 Prob: 24.77% Token: | family|
Top 2th token. Logit: 19.37 Prob:  7.18% Token: | grand|
Top 3th token. Logit: 18.85 Prob:  4.24% Token: | grandma|
Top 4th token. Logit: 18.65 Prob:  3.48% Token: | grandfather|
Top 5th token. Logit: 18.39 Prob:  2.67% Token: | friend|
Top 6th token. Logit: 18.37 Prob:  2.63% Token: | friends|
Top 7th token. Logit: 18.16 Prob:  2.14% Token: | father|
Top 8th token. Logit: 18.11 Prob:  2.02% Token: | dad|
Top 9th token. Logit: 18.08 Prob:  1.96% Token: | Grand|


In [205]:
prompt = "Once upon a time there was a girl named Alice. Alice went to nearby apple tree to collect some apples. Alice had four apples. She ate two of them. Then, she had exactly"
#prompt = "Alice had two apples. She found one more apple. Then, she had"
#prompt = "Jack had three apples. He gave two to Jill. Then, he had"
utils.test_prompt(prompt, " three", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' girl', ' named', ' Alice', '.', ' Alice', ' went', ' to', ' nearby', ' apple', ' tree', ' to', ' collect', ' some', ' apples', '.', ' Alice', ' had', ' four', ' apples', '.', ' She', ' ate', ' two', ' of', ' them', '.', ' Then', ',', ' she', ' had', ' exactly']
Tokenized answer: [' three']


Top 0th token. Logit: 20.84 Prob: 26.38% Token: | the|
Top 1th token. Logit: 20.37 Prob: 16.57% Token: | four|
Top 2th token. Logit: 20.25 Prob: 14.70% Token: | enough|
Top 3th token. Logit: 20.02 Prob: 11.69% Token: | one|
Top 4th token. Logit: 19.05 Prob:  4.44% Token: | where|
Top 5th token. Logit: 19.00 Prob:  4.21% Token: | five|
Top 6th token. Logit: 18.83 Prob:  3.56% Token: | ten|
Top 7th token. Logit: 18.77 Prob:  3.35% Token: | three|
Top 8th token. Logit: 18.53 Prob:  2.62% Token: | seven|
Top 9th token. Logit: 18.37 Prob:  2.24% Token: | what|


In [190]:
prompt = "Once upon a time, there was a fish called Craig. He was very different from the other fish. All the other fish were fast swimmers, but Craig was"

utils.test_prompt(prompt, " weak", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' fish', ' called', ' Craig', '.', ' He', ' was', ' very', ' different', ' from', ' the', ' other', ' fish', '.', ' All', ' the', ' other', ' fish', ' were', ' fast', ' sw', 'immers', ',', ' but', ' Craig', ' was']
Tokenized answer: [' weak']


Top 0th token. Logit: 21.87 Prob: 40.34% Token: | the|
Top 1th token. Logit: 20.64 Prob: 11.84% Token: | very|
Top 2th token. Logit: 20.62 Prob: 11.63% Token: | still|
Top 3th token. Logit: 19.97 Prob:  6.06% Token: | always|
Top 4th token. Logit: 19.43 Prob:  3.54% Token: | not|
Top 5th token. Logit: 19.18 Prob:  2.76% Token: | slow|
Top 6th token. Logit: 19.15 Prob:  2.67% Token: | special|
Top 7th token. Logit: 18.85 Prob:  1.97% Token: | a|
Top 8th token. Logit: 18.80 Prob:  1.88% Token: | really|
Top 9th token. Logit: 18.79 Prob:  1.86% Token: | fast|


## Dataset examples


In [4]:
def get_mlp_examples(dataset, model):
    dataset_max_activations = torch.zeros(len(dataset), model.cfg.n_layers, model.cfg.d_mlp)
    for i, prompt in tqdm(enumerate(dataset), total=len(dataset)):
        _, cache = model.run_with_cache(prompt)
        for layer in range(model.cfg.n_layers):
            cache_name = f"blocks.{layer}.mlp.hook_post"
            max_activations = cache[cache_name].max(dim=1).values
            dataset_max_activations[i, layer] = max_activations.flatten()
    return dataset_max_activations

dataset_max_activations = get_mlp_examples(dataset, model)

  0%|          | 0/21990 [00:00<?, ?it/s]

In [5]:
def print_story(prompt, layer, neuron):
    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(tokens)
    _, cache = model.run_with_cache(tokens)
    activations = cache[f"blocks.{layer}.mlp.hook_post"][0, :, neuron]
    haystack_utils.print_strings_as_html(str_tokens, activations, max_value=4)

def print_top_stories(max_activations, layer, neuron, k=5):
    dataset_activations = max_activations[:, layer, neuron]
    top_values, top_indices = torch.topk(dataset_activations, k)
    for index in top_indices:
        print_story(dataset[index], layer, neuron)

In [6]:
def interactive_print_story(layer, neuron, k):
    print_top_stories(dataset_max_activations, layer, neuron, k)

def increment_neuron(b):
    neuron_widget.value += 1

def decrement_neuron(b):
    neuron_widget.value -= 1

def random_neuron(b):
    neuron_widget.value = random.randint(0, model.cfg.d_mlp - 1)

layer_widget = widgets.IntSlider(value=0, min=0, max=3, step=1, description='Layer:')
neuron_widget = widgets.IntSlider(value=0, min=0, max=model.cfg.d_mlp - 1, step=1, description='Neuron:')
k_widget = widgets.IntSlider(value=1, min=1, max=20, step=1, description='Num stories:')
increment_button = widgets.Button(description="Next")
decrement_button = widgets.Button(description="Prev")
random_button = widgets.Button(description="Random Neuron")
increment_button.on_click(increment_neuron)
decrement_button.on_click(decrement_neuron)
random_button.on_click(random_neuron)

interactive_plot = widgets.interactive(interactive_print_story, layer=layer_widget, neuron=neuron_widget, k=k_widget)
buttons = widgets.HBox([decrement_button, increment_button, random_button])
display(buttons, interactive_plot)


HBox(children=(Button(description='Prev', style=ButtonStyle()), Button(description='Next', style=ButtonStyle()…

interactive(children=(IntSlider(value=0, description='Layer:', max=3), IntSlider(value=0, description='Neuron:…

In [79]:
layer = 3
neuron = 1657
neuron_weight = model.W_out[layer, neuron]
neuron_boosts = einops.einsum(neuron_weight, model.unembed.W_U, "d_mlp, d_mlp d_vocab -> d_vocab")
max_boosts, max_boosted = torch.topk(neuron_boosts, 20)
print(model.to_str_tokens(max_boosted))

[' random', ' vel', ' reflects', ' alcohol', ' plat', ' crosses', ' diseases', ' phot', ' rubbish', ' litter', ' mish', ' pepp', ' excellent', ' liquids', 'good', ' Blood', ' sap', ' compliments', ' other', ' hay']


In [None]:
activations = {}
prompts = {}
k = 5
pbar = tqdm(total=5*model.cfg.d_mlp*model.cfg.n_layers)
for layer in range(model.cfg.n_layers):
    activations[layer] = {}
    for neuron in range(model.cfg.d_mlp):
        activations[layer][neuron] = {}
        dataset_activations = dataset_max_activations[:, layer, neuron]
        top_values, top_indices = torch.topk(dataset_activations, k)
        for index in top_indices:
            prompt = dataset[index]
            tokens = model.to_tokens(prompt)
            if index not in prompts.keys():
                str_tokens = model.to_str_tokens(tokens)
                prompts[index] = str_tokens
            _, cache = model.run_with_cache(tokens)
            activation = cache[f"blocks.{layer}.mlp.hook_post"][0, :, neuron].tolist()
            activations[layer][neuron][index] = activation
            pbar.update(1)
pbar.close()

new_examples = {"activations": activations, "prompts": prompts}
file_path = "./data/tiny_stories_eval"
with open(file_path, 'w') as outfile:
    json.dump(new_examples, outfile)

## Find high DLA promtps

In [7]:
def get_neuron_dla(prompt, model:HookedTransformer):
    prompts = [prompt]
    tokens = model.to_tokens(prompts)
    _, cache = model.run_with_cache(tokens)
    W_U_token = model.W_U[:, tokens.flatten()]
    dla = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
    for layer in range(model.cfg.n_layers):
        W_out_U_token = model.W_out[layer] @ W_U_token
        neuron_dla = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1] * W_out_U_token[:, 1:].T
        scale = cache["ln_final.hook_scale"][0, :-1]
        neuron_dla = neuron_dla / scale
        dla[layer] = neuron_dla.max(dim=0).values
    return dla

def get_dla_examples(dataset, model):
    dataset_max_dlas = torch.zeros(len(dataset), model.cfg.n_layers, model.cfg.d_mlp).cpu()
    for i, prompt in tqdm(enumerate(dataset), total=len(dataset)):
        dataset_max_dlas[i] = get_neuron_dla(prompt, model)
    return dataset_max_dlas

#dataset_max_dla = get_dla_examples(dataset, model)

In [44]:
def get_tokenwise_dla(prompt, model, layer, neuron):
    prompts = [prompt]
    tokens = model.to_tokens(prompts)
    _, cache = model.run_with_cache(tokens)
    W_U_token = model.W_U[:, tokens.flatten()]
    W_out_U_token = model.W_out[layer] @ W_U_token
    neuron_dla = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1] * W_out_U_token[:, 1:].T
    scale = cache["ln_final.hook_scale"][0, :-1]
    neuron_dla = neuron_dla / scale
    return neuron_dla[:, neuron]

def get_max_neuron_dla_examples(layer, neuron, k=10):
    max_dlas = dataset_max_dla[:, layer, neuron]
    top_dlas, top_indices = torch.topk(max_dlas, k=k, dim=0)
    for top_index in top_indices:
        prompt = dataset[top_index]
        str_prompt = model.to_str_tokens(prompt)
        tokenwise_dla = [0] + get_tokenwise_dla(prompt, model, layer, neuron).tolist()
        haystack_utils.print_strings_as_html(str_prompt, tokenwise_dla, max_value=4)

#get_max_neuron_dla_examples(3, 0)

In [46]:
def interactive_print_dla(layer, neuron, k):
    get_max_neuron_dla_examples(layer, neuron, k)

def increment_neuron(b):
    neuron_widget.value += 1

def decrement_neuron(b):
    neuron_widget.value -= 1

def random_neuron(b):
    neuron_widget.value = random.randint(0, model.cfg.d_mlp - 1)

layer_widget = widgets.IntSlider(value=0, min=0, max=3, step=1, description='Layer:')
neuron_widget = widgets.IntSlider(value=0, min=0, max=model.cfg.d_mlp - 1, step=1, description='Neuron:')
k_widget = widgets.IntSlider(value=1, min=1, max=20, step=1, description='Num stories:')
increment_button = widgets.Button(description="Next")
decrement_button = widgets.Button(description="Prev")
random_button = widgets.Button(description="Random Neuron")
increment_button.on_click(increment_neuron)
decrement_button.on_click(decrement_neuron)
random_button.on_click(random_neuron)

interactive_plot = widgets.interactive(interactive_print_dla, layer=layer_widget, neuron=neuron_widget, k=k_widget)
buttons = widgets.HBox([decrement_button, increment_button, random_button])
display(buttons, interactive_plot)

HBox(children=(Button(description='Prev', style=ButtonStyle()), Button(description='Next', style=ButtonStyle()…

interactive(children=(IntSlider(value=0, description='Layer:', max=3), IntSlider(value=0, description='Neuron:…

## Variance


In [29]:
# Neuron vocab mean

n_tokens = torch.zeros(model.cfg.d_vocab)
neuron_means = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp, model.cfg.d_vocab)
for prompt in tqdm(dataset[:2]):
    tokens = model.to_tokens(prompt)
    n_tokens += tokens.shape[-1]
    _, cache = model.run_with_cache(tokens)
    for layer in range(model.cfg.n_layers):
        activations = cache[f"blocks.{layer}.mlp.hook_post"]
        for token_index in range(tokens.shape[-1]):
            #neuron_means[layer, :, tokens[0, token_index]] += activations[0, token_index].cpu()
            neuron_means[layer].index_add_(0, tokens[0, token_index], activations[0, token_index])


neuron_means /= n_tokens

  0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA_index_add_)

In [21]:
n_tokens = 0
neuron_variances = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp, model.cfg.d_vocab).cpu()
for prompt in tqdm(dataset[:2000]):
    tokens = model.to_tokens(prompt)
    n_tokens += tokens.shape[-1]
    _, cache = model.run_with_cache(tokens)
    for layer in range(model.cfg.n_layers):
        activations = cache[f"blocks.{layer}.mlp.hook_post"]
        for token_index in range(tokens.shape[-1]):
            diff = (activations[0, token_index].cpu()-neuron_means[layer, :, tokens[0, token_index]])**2
            neuron_variances[layer, :, tokens[0, token_index]] += diff
neuron_variances /= (n_tokens-1)

  0%|          | 0/2000 [00:00<?, ?it/s]

## Find ngrams


In [92]:
from collections import Counter
from nltk.util import ngrams

all_ngrams = {n:[] for n in range(10, 4, -1)}

for sentence in tqdm(dataset):
    tokens = model.to_str_tokens(sentence)
    for n in range(10, 4, -1):
        x_grams = ngrams(tokens, n)
        all_ngrams[n].extend(x_grams)

  0%|          | 0/21990 [00:00<?, ?it/s]

In [105]:
common_phrases = Counter(all_ngrams[5]).most_common(100)
for phrase in common_phrases:
    print(phrase)

(('<|endoftext|>', 'Once', ' upon', ' a', ' time'), 13526)
(('Once', ' upon', ' a', ' time', ','), 10649)
((' upon', ' a', ' time', ',', ' there'), 10035)
((' a', ' time', ',', ' there', ' was'), 9567)
((' time', ',', ' there', ' was', ' a'), 9360)
((',', ' there', ' was', ' a', ' little'), 5150)
(('\n', '\n', 'One', ' day', ','), 4480)
((' there', ' was', ' a', ' little', ' girl'), 4223)
(('.', '\n', '\n', 'One', ' day'), 4131)
((' was', ' a', ' little', ' girl', ' named'), 3366)
(('.', ' ', '\n', '\n', 'The'), 3040)
(('Once', ' upon', ' a', ' time', ' there'), 2848)
((' a', ' little', ' girl', ' named', ' Lily'), 2752)
((' upon', ' a', ' time', ' there', ' was'), 2534)
((' little', ' girl', ' named', ' Lily', '.'), 2519)
((' a', ' time', ' there', ' was', ' a'), 2408)
((' �', '�', '€', '�', '�'), 2370)
((' girl', ' named', ' Lily', '.', ' She'), 2326)
(('.', '\n', '\n', 'L', 'ily'), 2164)
((' smiled', ' and', ' said', ',', ' "'), 2125)
((' named', ' Lily', '.', ' She', ' loved'), 171

In [107]:
def get_mem_probs(prompt, model):
    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(prompt)
    answer_tokens = tokens[0, 1:].tolist()
    question_tokens = tokens[:-1]
    probs = model(tokens, return_type="logits").softmax(-1)[0]
    for i, token in enumerate(answer_tokens):
        print(f"{str_tokens[i]}->{str_tokens[i+1]}: {probs[i, token]:.2f}")

get_mem_probs("Once upon a time, there was a", model)

<|endoftext|>->Once: 0.00
Once-> upon: 1.00
 upon-> a: 1.00
 a-> time: 1.00
 time->,: 1.00
,-> there: 0.94
 there-> was: 0.99
 was-> a: 0.99


In [74]:
utils.test_prompt("Once upon a time,", "there", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',']
Tokenized answer: [' there']


Top 0th token. Logit: 27.54 Prob: 93.59% Token: | there|
Top 1th token. Logit: 24.62 Prob:  5.03% Token: | in|
Top 2th token. Logit: 23.27 Prob:  1.32% Token: | a|
Top 3th token. Logit: 19.14 Prob:  0.02% Token: | on|
Top 4th token. Logit: 17.94 Prob:  0.01% Token: | two|
Top 5th token. Logit: 17.77 Prob:  0.01% Token: | it|
Top 6th token. Logit: 17.11 Prob:  0.00% Token: | Tim|
Top 7th token. Logit: 17.05 Prob:  0.00% Token: | an|
Top 8th token. Logit: 16.56 Prob:  0.00% Token: | the|
Top 9th token. Logit: 16.14 Prob:  0.00% Token: | Tom|


## Sparsity

In [4]:
def get_pos_neuron_dla(prompt, model:HookedTransformer):
    prompts = [prompt]
    tokens = model.to_tokens(prompts)
    _, cache = model.run_with_cache(tokens)
    W_U_token = model.W_U[:, tokens.flatten()]
    dla = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp, tokens.shape[-1])
    for layer in range(model.cfg.n_layers):
        W_out_U_token = model.W_out[layer] @ W_U_token
        neuron_dla = cache[f"blocks.{layer}.mlp.hook_post"][0, :-1] * W_out_U_token[:, 1:].T
        scale = cache["ln_final.hook_scale"][0, :-1]
        neuron_dla = neuron_dla / scale
        dla[layer, :, 1:] = neuron_dla.T
    return dla

In [5]:
prompt = "Once upon a time, there was a"

dlas = get_pos_neuron_dla(prompt, model)
dlas.shape

torch.Size([4, 3072, 9])

In [6]:
_, cache = model.run_with_cache(prompt)
str_tokens = model.to_str_tokens(prompt)
decomp, labels = cache.get_full_resid_decomposition(layer=-1, expand_neurons=True, return_labels=True)


Tried to stack head results when they weren't cached. Computing head results now


In [24]:
pos = 3
prompt = "".join(str_tokens[1:pos+1])
answer_token = str_tokens[pos+1]
dir = model.tokens_to_residual_directions(model.to_single_token(answer_token))

per_layer_residual, labels = cache.get_full_resid_decomposition(layer=-1, pos_slice=pos, return_labels=True, expand_neurons=True)
scaled_residual_stack = cache.apply_ln_to_stack(per_layer_residual, layer = -1, pos_slice=pos)
per_layer_logit_diffs = einops.einsum(scaled_residual_stack, dir.unsqueeze(0), "components batch d_model, batch d_model -> components")

px.line(y=per_layer_logit_diffs.cpu().numpy(), hover_name=labels, title=f"Neuron DLA for \"{str_tokens[pos]}\"->\"{str_tokens[pos+1]}\" on 1M model")

In [46]:
index_to_neuron = lambda x: (x // model.cfg.d_mlp, x % model.cfg.d_mlp)
pos_neurons = [[]]
for pos in range(2, dlas.shape[-1]):
    pos_dla = dlas[:, :, pos].flatten()
    top_dlas, top_indices = torch.topk(pos_dla, 60)
    #print(top_dlas, top_indices)
    pos_neurons.append([index_to_neuron(x.item()) for x in top_indices])
    
print(pos_neurons[3])

[(1, 191), (3, 2447), (1, 75), (0, 2995), (1, 202), (3, 1489), (2, 2471), (3, 1009), (0, 1304), (2, 952), (0, 197), (3, 2009), (1, 2984), (1, 1089), (3, 262), (0, 983), (3, 342), (3, 658), (1, 2046), (1, 854), (2, 1515), (3, 1052), (3, 1558), (0, 2505), (2, 1368), (2, 2308), (1, 1901), (1, 1876), (1, 846), (3, 2007), (1, 1221), (1, 33), (3, 901), (0, 459), (1, 1381), (0, 1085), (0, 1322), (3, 485), (0, 1116), (1, 110), (0, 2949), (1, 1917), (1, 569), (3, 624), (2, 3053), (3, 1482), (3, 2215), (1, 2374), (0, 563), (0, 1371), (3, 70), (0, 2556), (0, 2918), (3, 155), (0, 1540), (2, 632), (1, 538), (2, 1477), (3, 465), (1, 2100)]


In [47]:
# Ablate neurons
original_loss = model(prompt, return_type="loss", loss_per_token=True).flatten()

def get_zero_ablate_hook(layer, neuron, pos):
    def hook_fn(value, hook):
        value[:, pos, neuron] = 0
        return value
    return [(f"blocks.{layer}.mlp.hook_post",hook_fn)]

def get_hooks_for_pos(neurons, pos_list):
    hooks = []
    for pos in pos_list:
        pos_neurons = neurons[pos]
        for layer, neuron in pos_neurons:
            hooks.extend(get_zero_ablate_hook(layer, neuron, pos))
    return hooks

ablate_neuron_hook = get_hooks_for_pos(pos_neurons, [3])
utils.test_prompt(prompt, answer_token, model)

with model.hooks(ablate_neuron_hook):
    ablated_loss = model(prompt, return_type="loss", loss_per_token=True).flatten()
    utils.test_prompt(prompt, answer_token, model)
print(f"Loss increase {(ablated_loss - original_loss)}")

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a']
Tokenized answer: [' time']


Top 0th token. Logit: 32.11 Prob: 100.00% Token: | time|
Top 1th token. Logit: 19.24 Prob:  0.00% Token: | tim|
Top 2th token. Logit: 18.28 Prob:  0.00% Token: | nice|
Top 3th token. Logit: 18.19 Prob:  0.00% Token: | ti|
Top 4th token. Logit: 18.01 Prob:  0.00% Token: | summer|
Top 5th token. Logit: 17.79 Prob:  0.00% Token: | small|
Top 6th token. Logit: 17.62 Prob:  0.00% Token: | day|
Top 7th token. Logit: 17.58 Prob:  0.00% Token: | Sunday|
Top 8th token. Logit: 17.29 Prob:  0.00% Token: | little|
Top 9th token. Logit: 17.15 Prob:  0.00% Token: | land|


Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a']
Tokenized answer: [' time']


Top 0th token. Logit: 16.99 Prob: 25.29% Token: | there|
Top 1th token. Logit: 16.78 Prob: 20.50% Token: | time|
Top 2th token. Logit: 15.23 Prob:  4.35% Token: | good|
Top 3th token. Logit: 14.97 Prob:  3.36% Token: | story|
Top 4th token. Logit: 14.97 Prob:  3.35% Token: | while|
Top 5th token. Logit: 14.78 Prob:  2.76% Token: | cat|
Top 6th token. Logit: 14.36 Prob:  1.81% Token: |,|
Top 7th token. Logit: 14.16 Prob:  1.49% Token: | mom|
Top 8th token. Logit: 14.11 Prob:  1.42% Token: | me|
Top 9th token. Logit: 14.01 Prob:  1.28% Token: | it|


Loss increase tensor([0., 0., 0.], device='cuda:0')


In [143]:
print(model.generate("Once", max_new_tokens=10, temperature=0, prepend_bos=True, verbose=False))
with model.hooks(ablate_neuron_hook):
    print(model.generate("Once", max_new_tokens=10, temperature=0, prepend_bos=True, verbose=False, use_past_kv_cache=False))

Once upon a time, there was a little girl named
Once, there was a little girl named Lily. She


## Low loss ngrams

In [51]:
def longest_true_streak(tensor):
    tensor_padded = torch.cat((torch.tensor([False], device=tensor.device), tensor, torch.tensor([False], device=tensor.device)))
    diff = tensor_padded[1:] != tensor_padded[:-1]
    indices = diff.nonzero(as_tuple=False).squeeze()
    streak_lengths = indices[1::2] - indices[::2]
    return int(streak_lengths.max().item()) if len(streak_lengths) > 0 else 0

cutoff = -math.log(0.989)
streak_lengths = []
for prompt in tqdm(dataset):
    loss = model(prompt, return_type="loss", loss_per_token=True).flatten()
    streak = longest_true_streak(loss <= cutoff)
    streak_lengths.append(streak)

px.histogram(streak_lengths)

  0%|          | 0/21990 [00:00<?, ?it/s]

In [61]:
top_streaks, top_indices = torch.topk(torch.LongTensor(streak_lengths), 50)

thank_you_prompts = []
def print_streaks(index):
    prompt = dataset[index]
    if "Thank you," in prompt:
        thank_you_prompts.append(prompt)
        str_tokens = model.to_str_tokens(prompt)[1:]
        loss = model(prompt, return_type="loss", loss_per_token=True).flatten()
        loss = (loss <= cutoff).to(torch.float32).tolist()
        haystack_utils.print_strings_as_html(str_tokens, loss, max_value=1)
        # answer_tokens = model.to_tokens(prompt)[0, 1:]
        # logprob = model(prompt, return_type="logits").softmax(dim=-1)
        # answer_probs = []
        # for pos in range(answer_tokens.shape[0]):
        #     answer_probs.append(logprob[0, pos, answer_tokens[pos]].item())
        # haystack_utils.print_strings_as_html(str_tokens, answer_probs, max_value=1)

for index in top_indices:
    print_streaks(index)

In [53]:
top_streaks, top_indices = torch.topk(torch.LongTensor(streak_lengths), 10000)#
len_tokens = 7
thank_you_prompts = []
for index in top_indices:
    prompt = dataset[index]
    pattern = r'"Thank you, \w+\.'
    match = re.search(pattern, prompt)
    if match:
        str_match = prompt[match.start():match.end()]
        if len(model.to_str_tokens(str_match)) == len_tokens:
            up_to_match = prompt[:match.end()]
            thank_you_prompts.append(up_to_match)
print(thank_you_prompts[:10])

['Tom and Amy are youth. They like to play with toys and books. They go to school every day. But today, school is delay. There is snow on the road. The bus cannot come. Tom and Amy are happy. They can stay at home and play more.\n\nThey put on their coats and boots. They go outside. They make a snowman. They give him a hat and a scarf. They make snowballs. They throw them at each other. They laugh and have fun.\n\nBut then, they hear a voice. It is Mom. She says, "Tom, Amy, come inside. It is time for lunch. You can play more later." Tom and Amy are sad. They want to play more. They say, "Mom, please, can we play more? Just a little more?"\n\nMom says, "No, you have to come inside. It is cold outside. You need to eat and warm up. The snow is mild today. It will not melt soon. You can play more after lunch. Come on, I made your favorite soup."\n\nTom and Amy listen to Mom. They say goodbye to their snowman. They promise to come back. They go inside. They take off their coats and boots. 

In [117]:
data = []
columns = ["\"", "Thank", "you", ",", "[name]", ".", "smile"]
for prompt in tqdm(thank_you_prompts):
    loss = model(prompt, return_type="loss", loss_per_token=True).flatten()
    smile = "smile" in prompt[-min(len(prompt), 100):]
    data.append(loss.tolist()[-len_tokens+1:] + [smile])

df = pd.DataFrame(data, columns=columns)
px.box(df, color="smile")

  0%|          | 0/307 [00:00<?, ?it/s]

In [85]:


def left_pad(prompts, model):
    tokens = model.to_tokens(prompts)
    target_length = tokens.shape[1]

    results = []
    for prompt in prompts:
        tokens = model.to_tokens(prompt)[0]
        padded_tokens = torch.cat([torch.zeros((target_length - tokens.shape[0],), dtype=int).cuda()+model.tokenizer.bos_token_id, tokens])
        results.append(padded_tokens)

    return torch.stack(results)

tokens = left_pad(thank_you_smile_data, model)

In [73]:
query = "Thank you, Ben."
index = thank_you_prompts[3].index(query)
prompt = thank_you_prompts[3][:index+len(query)]

loss = model(prompt, return_type="loss", loss_per_token=True).flatten()
px.line(loss.cpu().numpy()[-50:])

In [79]:
_, cache = model.run_with_cache(prompt)
str_tokens = model.to_str_tokens(prompt)
decomp, labels = cache.get_full_resid_decomposition(layer=-1, expand_neurons=True, return_labels=True)
pos = len(str_tokens)-4
answer_token = str_tokens[pos+1]
dir = model.tokens_to_residual_directions(model.to_single_token(answer_token))

per_layer_residual, labels = cache.get_full_resid_decomposition(layer=-1, pos_slice=pos, return_labels=True, expand_neurons=True)
scaled_residual_stack = cache.apply_ln_to_stack(per_layer_residual, layer = -1, pos_slice=pos)
per_layer_logit_diffs = einops.einsum(scaled_residual_stack, dir.unsqueeze(0), "components batch d_model, batch d_model -> components")

px.line(y=per_layer_logit_diffs.cpu().numpy(), hover_name=labels, title=f"Neuron DLA for \"{str_tokens[pos]}\"->\"{str_tokens[pos+1]}\" on 33M model")


Tried to stack head results when they weren't cached. Computing head results now


In [94]:

def get_zero_ablate_hook(layer, neuron, pos):
    def hook_fn(value, hook):
        value[:, :, neuron] = 0
        return value
    return [(f"blocks.{layer}.mlp.hook_post",hook_fn)]

def get_hooks_for_pos(neurons, pos_list):
    hooks = []
    for pos in pos_list:
        for layer, neuron in neurons:
            hooks.extend(get_zero_ablate_hook(layer, neuron, pos))
    return hooks

ablate_neuron_hook = get_hooks_for_pos([(3, 326), (3, 502), (3, 609)], [-2, -1])
utils.test_prompt(prompt[:-6], answer_token, model, prepend_space_to_answer=False)

with model.hooks(ablate_neuron_hook):
    utils.test_prompt(prompt[:-6], answer_token, model, prepend_space_to_answer=False)


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' are', ' friends', '.', ' They', ' like', ' to', ' play', ' with', ' toys', ' and', ' books', '.', ' One', ' day', ',', ' they', ' find', ' a', ' big', ' box', ' in', ' the', ' yard', '.', ' They', ' open', ' the', ' box', ' and', ' see', ' many', ' things', ' inside', '.', ' They', ' see', ' a', ' hat', ',', ' a', ' ball', ',', ' a', ' doll', ',', ' a', ' car', ',', ' and', ' a', ' bear', '.', '\n', '\n', 'S', 'ara', ' likes', ' the', ' bear', '.', ' It', ' is', ' soft', ' and', ' brown', '.', ' She', ' picks', ' it', ' up', ' and', ' hugs', ' it', '.', ' She', ' says', ',', ' "', 'This', ' is', ' my', ' bear', '.', ' I', ' love', ' it', '."', ' Ben', ' likes', ' the', ' car', '.', ' It', ' is', ' red', ' and', ' shiny', '.', ' He', ' picks', ' it', ' up', ' and', ' rolls', ' it', '.', ' He', ' says', ',', ' "', 'This', ' is', ' my', ' car', '.', ' I', ' can', ' go', ' fast', '."', '\n', '\n', 'They', ' play', ' with', '

Top 0th token. Logit: 34.27 Prob: 100.00% Token: |,|
Top 1th token. Logit: 22.79 Prob:  0.00% Token: | for|
Top 2th token. Logit: 22.74 Prob:  0.00% Token: |.|
Top 3th token. Logit: 21.03 Prob:  0.00% Token: | Ben|
Top 4th token. Logit: 19.42 Prob:  0.00% Token: | so|
Top 5th token. Logit: 18.40 Prob:  0.00% Token: |!|
Top 6th token. Logit: 17.86 Prob:  0.00% Token: | ,|
Top 7th token. Logit: 17.43 Prob:  0.00% Token: |."|
Top 8th token. Logit: 16.99 Prob:  0.00% Token: | very|
Top 9th token. Logit: 16.79 Prob:  0.00% Token: |Ben|


Tokenized prompt: ['<|endoftext|>', 'S', 'ara', ' and', ' Ben', ' are', ' friends', '.', ' They', ' like', ' to', ' play', ' with', ' toys', ' and', ' books', '.', ' One', ' day', ',', ' they', ' find', ' a', ' big', ' box', ' in', ' the', ' yard', '.', ' They', ' open', ' the', ' box', ' and', ' see', ' many', ' things', ' inside', '.', ' They', ' see', ' a', ' hat', ',', ' a', ' ball', ',', ' a', ' doll', ',', ' a', ' car', ',', ' and', ' a', ' bear', '.', '\n', '\n', 'S', 'ara', ' likes', ' the', ' bear', '.', ' It', ' is', ' soft', ' and', ' brown', '.', ' She', ' picks', ' it', ' up', ' and', ' hugs', ' it', '.', ' She', ' says', ',', ' "', 'This', ' is', ' my', ' bear', '.', ' I', ' love', ' it', '."', ' Ben', ' likes', ' the', ' car', '.', ' It', ' is', ' red', ' and', ' shiny', '.', ' He', ' picks', ' it', ' up', ' and', ' rolls', ' it', '.', ' He', ' says', ',', ' "', 'This', ' is', ' my', ' car', '.', ' I', ' can', ' go', ' fast', '."', '\n', '\n', 'They', ' play', ' with', '

Top 0th token. Logit: 30.88 Prob: 99.85% Token: |,|
Top 1th token. Logit: 23.81 Prob:  0.09% Token: |.|
Top 2th token. Logit: 23.24 Prob:  0.05% Token: | for|
Top 3th token. Logit: 21.67 Prob:  0.01% Token: | Ben|
Top 4th token. Logit: 19.04 Prob:  0.00% Token: |!|
Top 5th token. Logit: 19.02 Prob:  0.00% Token: | so|
Top 6th token. Logit: 18.58 Prob:  0.00% Token: |."|
Top 7th token. Logit: 17.68 Prob:  0.00% Token: |Ben|
Top 8th token. Logit: 17.55 Prob:  0.00% Token: | very|
Top 9th token. Logit: 17.35 Prob:  0.00% Token: | to|
