In [8]:
# %%
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import haystack_utils
from transformer_lens import utils


pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)


%reload_ext autoreload
%autoreload 2

In [161]:
model = HookedTransformer.from_pretrained(
    "roneneldan/TinyStories-33M",
    #center_unembed=True,
    #center_writing_weights=True,
    #fold_ln=True,
    #refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer


In [5]:
haystack_utils.print_tokenized_word(" The elephant is grey", model)

[' The', ' elephant', ' is', ' grey']


## Color of animals

In [70]:
#test = '"Strawberries are red", Bob said. "Correct", his mother answered. "Do you also know what color the sky is?" "Of course", Bob answered proudly, "it is'
test = "'Strawberries are red', Bob said. 'Correct', his mother answered. 'Do you also know what color the sky is?' 'Of course', Bob answered proudly, 'it is"

In [71]:
utils.test_prompt(test, " blue", model)

Tokenized prompt: ['<|endoftext|>', "'", 'St', 'raw', 'berries', ' are', ' red', "',", ' Bob', ' said', '.', " '", 'Correct', "',", ' his', ' mother', ' answered', '.', " '", 'Do', ' you', ' also', ' know', ' what', ' color', ' the', ' sky', ' is', "?'", " '", 'Of', ' course', "',", ' Bob', ' answered', ' proudly', ',', " '", 'it', ' is']
Tokenized answer: [' blue']


Top 0th token. Logit: 18.04 Prob: 33.08% Token: | blue|
Top 1th token. Logit: 17.46 Prob: 18.39% Token: | red|
Top 2th token. Logit: 16.04 Prob:  4.44% Token: | strawberry|
Top 3th token. Logit: 15.80 Prob:  3.51% Token: | a|
Top 4th token. Logit: 15.73 Prob:  3.26% Token: |!'|
Top 5th token. Logit: 15.70 Prob:  3.17% Token: | the|
Top 6th token. Logit: 15.23 Prob:  1.99% Token: | pink|
Top 7th token. Logit: 15.07 Prob:  1.70% Token: | purple|
Top 8th token. Logit: 14.92 Prob:  1.46% Token: |...|
Top 9th token. Logit: 14.90 Prob:  1.43% Token: | my|


In [72]:
objects = ["sky", "grass", "the sun", "an apple", "a carrot", "coal", "a strawberry", "the ocean", "a pumpkin", "a lemon", "a tomato"]
colors = ["blue", "green", "yellow", "red", "orange", "black", "red", "blue", "orange", "yellow", "red"]

def make_prompt(object, color):
    prompt = f"'Strawberries are red', Bob said. 'Correct', his mother answered. 'Do you also know what color {object} is?' 'Of course', Bob answered proudly, 'it is"
    answer = f" {color}"
    return prompt, answer

print(make_prompt(objects[0], colors[0]))

("'Strawberries are red', Bob said. 'Correct', his mother answered. 'Do you also know what color sky is?' 'Of course', Bob answered proudly, 'it is", ' blue')


In [73]:
def test_prompt(prompt, answer, model, object):
    logits = model(prompt)
    answer_token = model.to_single_token(answer)
    pred = logits[0, -1, :].softmax(dim=-1)
    prob = pred[answer_token].item()
    rank = (pred >= pred[answer_token]).sum().item()
    print(f"{object} -> {answer}: Rank {rank} (p={prob:.2f})")

for object, color in zip(objects, colors):
    prompt, answer = make_prompt(object, color)
    test_prompt(prompt, answer, model, object)

sky ->  blue: Rank 1 (p=0.32)
grass ->  green: Rank 3 (p=0.10)
the sun ->  yellow: Rank 9 (p=0.02)
an apple ->  red: Rank 1 (p=0.37)
a carrot ->  orange: Rank 5 (p=0.05)
coal ->  black: Rank 2 (p=0.14)
a strawberry ->  red: Rank 1 (p=0.28)
the ocean ->  blue: Rank 1 (p=0.30)
a pumpkin ->  orange: Rank 4 (p=0.09)
a lemon ->  yellow: Rank 2 (p=0.11)
a tomato ->  red: Rank 1 (p=0.43)


## Sizes of animals

In [191]:
def make_prompt(animal_1, animal_2):
    prompt = f'Once upon a time there was a boy named Bob. Bob loved animals and was proud that he knew all the sizes of his favorite animals. "A tiger is larger than a cat", Bob said. "Correct", his mother answered. "Do you also know if {animal_1} is larger than {animal_2}?" "Of course", Bob answered proudly, "The'
    return prompt

In [192]:
animals = ["dog", "cat", "cow", "horse", "sheep", "elephant", "lion", "tiger", "bear", "duck", "chicken", "fish", "turtle", "rabbit", "monkey"]

for animal in animals:
    haystack_utils.print_tokenized_word(" "+animal, model)

[' dog']
[' cat']
[' cow']
[' horse']
[' sheep']
[' elephant']
[' lion']
[' tiger']
[' bear']
[' duck']
[' chicken']
[' fish']
[' turtle']
[' rabbit']
[' monkey']


In [193]:
animal_pairs = [
    ("a horse", "a dog", " horse", " dog"),
    ("a bear", "a duck", " bear", " duck"),
    ("an elephant", "a rabbit", " elephant", " rabbit"),
    ("a sheep", "a bug", " sheep", " bug"),
    ("a cow", "a fish", " cow", " fish"),
]

def get_ranks(prompt, correct, incorrect):
    logits = model(prompt)
    answer_token = model.to_single_token(correct)
    incorrect_token = model.to_single_token(incorrect)
    pred = logits[0, -1, :].softmax(dim=-1)
    prob = pred[answer_token].item()
    rank = (pred >= pred[answer_token]).sum().item()
    incorrect_prob = pred[incorrect_token].item()
    incorrect_rank = (pred >= pred[incorrect_token]).sum().item()
    print(f"{correct}>{incorrect}: Rank {rank}/{incorrect_rank} (p={prob:.2f}, {incorrect_prob:.2f})")


for animal_1, animal_2, correct, incorrect in animal_pairs:
    prompt_1 = make_prompt(animal_1, animal_2)
    prompt_2 = make_prompt(animal_2, animal_1)
    correct_token = model.to_single_token(correct)
    incorrect_token = model.to_single_token(incorrect)

    get_ranks(prompt_1, correct, incorrect)
    get_ranks(prompt_2, correct, incorrect)

 horse> dog: Rank 1/5 (p=0.81, 0.01)
 horse> dog: Rank 54/1 (p=0.00, 0.34)
 bear> duck: Rank 1/2 (p=0.40, 0.13)
 bear> duck: Rank 4/1 (p=0.00, 0.91)
 elephant> rabbit: Rank 1/5 (p=0.74, 0.02)
 elephant> rabbit: Rank 2/1 (p=0.07, 0.74)
 sheep> bug: Rank 1/223 (p=0.32, 0.00)
 sheep> bug: Rank 271/3 (p=0.00, 0.09)
 cow> fish: Rank 1/27 (p=0.79, 0.00)
 cow> fish: Rank 9/7 (p=0.01, 0.01)


Ideas
- Opposites
- Categorize fruit or vegetable
- Match animals to habitates
- analogies

In [77]:
utils.test_prompt("A bird has wings and a cat has", " paws", model)

Tokenized prompt: ['<|endoftext|>', 'A', ' bird', ' has', ' wings', ' and', ' a', ' cat', ' has']
Tokenized answer: [' paws']


Top 0th token. Logit: 17.92 Prob: 43.50% Token: | a|
Top 1th token. Logit: 16.25 Prob:  8.20% Token: | wings|
Top 2th token. Logit: 15.94 Prob:  6.03% Token: | no|
Top 3th token. Logit: 14.92 Prob:  2.16% Token: | stripes|
Top 4th token. Logit: 14.69 Prob:  1.72% Token: | many|
Top 5th token. Logit: 14.63 Prob:  1.63% Token: | feathers|
Top 6th token. Logit: 14.48 Prob:  1.39% Token: | an|
Top 7th token. Logit: 14.42 Prob:  1.32% Token: | been|
Top 8th token. Logit: 14.36 Prob:  1.23% Token: | four|
Top 9th token. Logit: 14.20 Prob:  1.06% Token: | big|


In [196]:
prompt = "Once upon a time there was a boy named Bob. Bob loved playing games with words. 'Let's play a game' his mother said. 'I will tell you a word and you have to tell me the opposite. Ready?' 'Yes! That sounds fun', answered Bob." + \
    "His mother nodded. 'Ok, great. What is the opposite of loud?' Bob answered immediately: 'It's"

utils.test_prompt(prompt, " cold", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Bob', '.', ' Bob', ' loved', ' playing', ' games', ' with', ' words', '.', " '", 'Let', "'s", ' play', ' a', ' game', "'", ' his', ' mother', ' said', '.', " '", 'I', ' will', ' tell', ' you', ' a', ' word', ' and', ' you', ' have', ' to', ' tell', ' me', ' the', ' opposite', '.', ' Ready', "?'", " '", 'Yes', '!', ' That', ' sounds', ' fun', "',", ' answered', ' Bob', '.', 'His', ' mother', ' nodded', '.', " '", 'Ok', ',', ' great', '.', ' What', ' is', ' the', ' opposite', ' of', ' loud', "?'", ' Bob', ' answered', ' immediately', ':', " '", 'It', "'s"]
Tokenized answer: [' cold']


Top 0th token. Logit: 16.88 Prob: 25.39% Token: | a|
Top 1th token. Logit: 16.30 Prob: 14.20% Token: | '|
Top 2th token. Logit: 15.63 Prob:  7.29% Token: | the|
Top 3th token. Logit: 15.48 Prob:  6.30% Token: | called|
Top 4th token. Logit: 15.41 Prob:  5.84% Token: | two|
Top 5th token. Logit: 15.40 Prob:  5.78% Token: | so|
Top 6th token. Logit: 14.42 Prob:  2.18% Token: | not|
Top 7th token. Logit: 14.41 Prob:  2.15% Token: | very|
Top 8th token. Logit: 14.21 Prob:  1.77% Token: | my|
Top 9th token. Logit: 13.86 Prob:  1.24% Token: | an|


In [99]:
prompt = "I know a lot about animals. For example, cows live in the"
prompt = "I know a lot about animals. For example, monkeys like to eat"
utils.test_prompt(prompt, " grass", model)

Tokenized prompt: ['<|endoftext|>', 'I', ' know', ' a', ' lot', ' about', ' animals', '.', ' For', ' example', ',', ' monkeys', ' like', ' to', ' eat']
Tokenized answer: [' grass']


Top 0th token. Logit: 24.13 Prob: 94.55% Token: | bananas|
Top 1th token. Logit: 19.70 Prob:  1.13% Token: | fruit|
Top 2th token. Logit: 19.17 Prob:  0.67% Token: | apples|
Top 3th token. Logit: 18.68 Prob:  0.41% Token: | fruits|
Top 4th token. Logit: 18.68 Prob:  0.41% Token: | peanuts|
Top 5th token. Logit: 18.30 Prob:  0.28% Token: | banana|
Top 6th token. Logit: 18.28 Prob:  0.27% Token: | the|
Top 7th token. Logit: 18.15 Prob:  0.24% Token: | nuts|
Top 8th token. Logit: 18.07 Prob:  0.22% Token: | a|
Top 9th token. Logit: 17.23 Prob:  0.10% Token: | leaves|


In [198]:
prompt = "Once upon a time there was a girl named Sarah. Sarah loved animals and really wanted a pet. Sarah wanted to get a giraffe or a rabbit. Her mother hates giraffes so she got Sarah a"
utils.test_prompt(prompt, " rabbit", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' girl', ' named', ' Sarah', '.', ' Sarah', ' loved', ' animals', ' and', ' really', ' wanted', ' a', ' pet', '.', ' Sarah', ' wanted', ' to', ' get', ' a', ' gir', 'affe', ' or', ' a', ' rabbit', '.', ' Her', ' mother', ' hates', ' gir', 'aff', 'es', ' so', ' she', ' got', ' Sarah', ' a']
Tokenized answer: [' rabbit']


Top 0th token. Logit: 17.73 Prob: 14.75% Token: | dog|
Top 1th token. Logit: 17.70 Prob: 14.34% Token: | gir|
Top 2th token. Logit: 16.65 Prob:  5.00% Token: | monkey|
Top 3th token. Logit: 16.46 Prob:  4.14% Token: | par|
Top 4th token. Logit: 16.44 Prob:  4.05% Token: | toy|
Top 5th token. Logit: 16.41 Prob:  3.92% Token: | pig|
Top 6th token. Logit: 16.32 Prob:  3.60% Token: | big|
Top 7th token. Logit: 16.23 Prob:  3.30% Token: | stuffed|
Top 8th token. Logit: 16.07 Prob:  2.81% Token: | pet|
Top 9th token. Logit: 15.95 Prob:  2.49% Token: | puppy|


In [200]:
prompt = 'Once upon a time there was a girl named Alice. Alice loved learning about different animals. "Are chicken smaller than cows?", Alice asked her mother. "'

yes_token = model.to_single_token("Yes")
no_token = model.to_single_token("No")
def get_yes_no_logits(prompt):
    logits = model(prompt, return_type="logits")
    yes_logits = logits[0, -1, yes_token].item()
    no_logits = logits[0, -1, no_token].item()
    return yes_logits - no_logits

print(get_yes_no_logits(prompt))

-0.8930225372314453


In [135]:
utils.test_prompt(prompt, "No", model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '"', 'Are', ' cows', ' smaller', ' than', ' chicken', '?",', ' Alice', ' asked', ' her', ' mother', '.', ' "']
Tokenized answer: ['No']


Top 0th token. Logit: 18.77 Prob: 26.39% Token: |Yes|
Top 1th token. Logit: 18.10 Prob: 13.57% Token: |No|
Top 2th token. Logit: 17.47 Prob:  7.17% Token: |What|
Top 3th token. Logit: 17.30 Prob:  6.09% Token: |I|
Top 4th token. Logit: 17.28 Prob:  5.93% Token: |Why|
Top 5th token. Logit: 16.89 Prob:  4.02% Token: |That|
Top 6th token. Logit: 16.87 Prob:  3.96% Token: |It|
Top 7th token. Logit: 16.62 Prob:  3.08% Token: |We|
Top 8th token. Logit: 16.46 Prob:  2.62% Token: |The|
Top 9th token. Logit: 16.17 Prob:  1.97% Token: |Well|


In [202]:
prompt = "Once upon a time there was a girl named Alice. Alice had a carrot and a banana. She ate the banana. Then, she only had the"
#prompt = "Once upon a time there was a girl named Alice. Alice had a carrot and a banana. She gave the banana to Bob. Then, Bob had the"
utils.test_prompt(prompt, " banana", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' girl', ' named', ' Alice', '.', ' Alice', ' had', ' a', ' carrot', ' and', ' a', ' banana', '.', ' She', ' ate', ' the', ' banana', '.', ' Then', ',', ' she', ' only', ' had', ' the']
Tokenized answer: [' banana']


Top 0th token. Logit: 18.71 Prob: 20.61% Token: | apple|
Top 1th token. Logit: 18.05 Prob: 10.68% Token: | small|
Top 2th token. Logit: 17.12 Prob:  4.22% Token: | peel|
Top 3th token. Logit: 17.10 Prob:  4.13% Token: | little|
Top 4th token. Logit: 16.97 Prob:  3.61% Token: | carrot|
Top 5th token. Logit: 16.55 Prob:  2.39% Token: | two|
Top 6th token. Logit: 16.10 Prob:  1.52% Token: | left|
Top 7th token. Logit: 16.05 Prob:  1.45% Token: | original|
Top 8th token. Logit: 16.04 Prob:  1.43% Token: | banana|
Top 9th token. Logit: 15.97 Prob:  1.34% Token: | toy|


In [204]:
prompt = 'Once upon a time there was a boy named Jack. On Sunday, Jack visited a lot of family members. In the morning, Jack visited his grandmother.' + \
    ' At noon, he played with his father.' +\
    ' In the evening, he ate dinner with his mother.' + \
    ' The next day in school, he told his best friend about his weekend.' + \
    ' "Yesterday evening, I had dinner with my'

utils.test_prompt(prompt, " mother", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' boy', ' named', ' Jack', '.', ' On', ' Sunday', ',', ' Jack', ' visited', ' a', ' lot', ' of', ' family', ' members', '.', ' In', ' the', ' morning', ',', ' Jack', ' visited', ' his', ' grandmother', '.', ' At', ' noon', ',', ' he', ' played', ' with', ' his', ' father', '.', ' In', ' the', ' evening', ',', ' he', ' ate', ' dinner', ' with', ' his', ' mother', '.', ' The', ' next', ' day', ' in', ' school', ',', ' he', ' told', ' his', ' best', ' friend', ' about', ' his', ' weekend', '.', ' "', 'Yesterday', ' evening', ',', ' I', ' had', ' dinner', ' with', ' my']
Tokenized answer: [' mother']


Top 0th token. Logit: 21.07 Prob: 39.04% Token: | grandmother|
Top 1th token. Logit: 20.61 Prob: 24.77% Token: | family|
Top 2th token. Logit: 19.37 Prob:  7.18% Token: | grand|
Top 3th token. Logit: 18.85 Prob:  4.24% Token: | grandma|
Top 4th token. Logit: 18.65 Prob:  3.48% Token: | grandfather|
Top 5th token. Logit: 18.39 Prob:  2.67% Token: | friend|
Top 6th token. Logit: 18.37 Prob:  2.63% Token: | friends|
Top 7th token. Logit: 18.16 Prob:  2.14% Token: | father|
Top 8th token. Logit: 18.11 Prob:  2.02% Token: | dad|
Top 9th token. Logit: 18.08 Prob:  1.96% Token: | Grand|


In [205]:
prompt = "Once upon a time there was a girl named Alice. Alice went to nearby apple tree to collect some apples. Alice had four apples. She ate two of them. Then, she had exactly"
#prompt = "Alice had two apples. She found one more apple. Then, she had"
#prompt = "Jack had three apples. He gave two to Jill. Then, he had"
utils.test_prompt(prompt, " three", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ' there', ' was', ' a', ' girl', ' named', ' Alice', '.', ' Alice', ' went', ' to', ' nearby', ' apple', ' tree', ' to', ' collect', ' some', ' apples', '.', ' Alice', ' had', ' four', ' apples', '.', ' She', ' ate', ' two', ' of', ' them', '.', ' Then', ',', ' she', ' had', ' exactly']
Tokenized answer: [' three']


Top 0th token. Logit: 20.84 Prob: 26.38% Token: | the|
Top 1th token. Logit: 20.37 Prob: 16.57% Token: | four|
Top 2th token. Logit: 20.25 Prob: 14.70% Token: | enough|
Top 3th token. Logit: 20.02 Prob: 11.69% Token: | one|
Top 4th token. Logit: 19.05 Prob:  4.44% Token: | where|
Top 5th token. Logit: 19.00 Prob:  4.21% Token: | five|
Top 6th token. Logit: 18.83 Prob:  3.56% Token: | ten|
Top 7th token. Logit: 18.77 Prob:  3.35% Token: | three|
Top 8th token. Logit: 18.53 Prob:  2.62% Token: | seven|
Top 9th token. Logit: 18.37 Prob:  2.24% Token: | what|


In [190]:
prompt = "Once upon a time, there was a fish called Craig. He was very different from the other fish. All the other fish were fast swimmers, but Craig was"

utils.test_prompt(prompt, " weak", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' fish', ' called', ' Craig', '.', ' He', ' was', ' very', ' different', ' from', ' the', ' other', ' fish', '.', ' All', ' the', ' other', ' fish', ' were', ' fast', ' sw', 'immers', ',', ' but', ' Craig', ' was']
Tokenized answer: [' weak']


Top 0th token. Logit: 21.87 Prob: 40.34% Token: | the|
Top 1th token. Logit: 20.64 Prob: 11.84% Token: | very|
Top 2th token. Logit: 20.62 Prob: 11.63% Token: | still|
Top 3th token. Logit: 19.97 Prob:  6.06% Token: | always|
Top 4th token. Logit: 19.43 Prob:  3.54% Token: | not|
Top 5th token. Logit: 19.18 Prob:  2.76% Token: | slow|
Top 6th token. Logit: 19.15 Prob:  2.67% Token: | special|
Top 7th token. Logit: 18.85 Prob:  1.97% Token: | a|
Top 8th token. Logit: 18.80 Prob:  1.88% Token: | really|
Top 9th token. Logit: 18.79 Prob:  1.86% Token: | fast|
