# Setup

In [1]:
!git clone https://github.com/amakelov/serimats.git

Cloning into 'serimats'...


Username for 'https://github.com': ^C


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the TransformerLens code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [12]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode
# depending on the renderer - this is bad for developing demos! Thus creating a debug mode:
# static "png" output only when developing locally with DEBUG_MODE on, "colab" otherwise.
pio.renderers.default = "png" if (not IN_COLAB and DEBUG_MODE) else "colab"

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import pandas as pd


import itertools

import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens.patching as patching

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# Nothing below calls backward(), so turn autograd off globally; forward passes then skip
# building computation graphs.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fd17c222da0>

In [16]:
# Load TinyStories-33M with centered unembedding/writing weights and LayerNorm folded into the
# following weights. Fall back to CPU so the notebook still runs on machines without a GPU.
model = HookedTransformer.from_pretrained(
    "tiny-stories-33M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/968 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/291M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-33M into HookedTransformer


In [9]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on a red/blue diverging colour scale centred at 0."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of `tensor` with `px.line` (passed as the `y` data)."""
    values = utils.to_numpy(tensor)
    fig = px.line(y=values, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`; `xaxis`/`yaxis`/`caxis` label the x, y and colour axes."""
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(x=utils.to_numpy(x), y=utils.to_numpy(y), labels=axis_labels, **kwargs)
    fig.show(renderer)

In [10]:
MIXED_HAPPY_ENDING = [
  "Once upon a time, a little bunny lost its way in the forest. But then, a friendly squirrel showed the bunny the path back home.",
  "A little bird had a broken wing and couldn't fly. But a kind girl found the bird and took it to a vet who made its wing all better.",
  "A mischievous puppy accidentally knocked over a vase. The puppy's owner wasn't upset and gave the puppy a big hug instead.",
  "A little girl couldn't find her favorite toy. But her mom helped her search, and they found it hiding under the bed.",
  "A boy got a small cut on his knee while playing. His mom cleaned it with a magic ointment that made it feel better quickly.",
  "A little duckling got separated from its family at the pond. But then, a wise old turtle helped guide the duckling back to its parents.",
  "A boy dropped his ice cream on the ground. But the ice cream vendor saw what happened and gave him a new cone for free.",
  "A little girl's balloon flew away into the sky. But then, a kind man who saw it gave her a brand new balloon.",
  "A little frog got stuck in a puddle. But some friendly kids came by and gently lifted the frog out to safety.",
  "A girl's kite got caught in a tree. But then, a friendly dog ran up and tugged the string, freeing the kite.",
  "A little bear felt scared during a thunderstorm. But its mom hugged it tight and sang a comforting lullaby, and the bear felt safe.",
  "A girl spilled juice on her favorite dress. But her grandma said it was okay and washed the dress, and it looked good as new.",
  "A little bee got lost and couldn't find its hive. But then, a helpful butterfly showed the bee the way back home.",
  "A boy's soccer ball got stuck on the roof. But his dad climbed up and retrieved it, and they played soccer together.",
  "A little squirrel couldn't find any acorns to eat. But then, a friendly bird shared some of its seeds with the squirrel.",
  "A girl's puzzle pieces got mixed up. But her brother helped her sort them out, and they finished the puzzle together.",
  "A little snail was stuck in its shell and couldn't come out. But then, a gentle rain shower helped the snail slide out easily.",
  "A boy's favorite book had a torn page. But his sister fixed it with tape, and the book was whole again.",
  "A little butterfly got caught in a spider's web. But a clever spider helped the butterfly escape and fly away.",
  "A girl lost her tooth, and she couldn't find it under her pillow. But when she woke up, the tooth fairy had left a special gift for her.",
  "A little fish was stuck in a small puddle. But then, a friendly otter splashed water into the puddle, and the fish could swim again.",
  "A boy's toy car's wheels got stuck. But his dad fixed it and made it zoom around even faster.",
  "A little mouse couldn't find its way through a maze. But a wise owl gave the mouse helpful directions, and it found the exit.",
  "A girl's birthday balloon flew away. But then, a group of birds carried the balloon back to her as a surprise.",
  "A little caterpillar was hungry and couldn't find any leaves. But then, a kind gardener gave it a fresh leaf to eat.",
  "A boy's kite got tangled in a tree. But with the help of his friends, they untangled it and flew it high in the sky."
]

UNIFORM_HAPPY_ENDING = [
  "Once upon a time, a little bunny found a shiny treasure while hopping in the meadow. The bunny hopped back home happily.",
  "A little bird built a cozy nest on a tree branch and laid tiny eggs inside. Soon, little chicks hatched and chirped joyfully.",
  "A playful puppy found a new friend to run and jump with in the park. They wagged their tails and played all day long.",
  "A little girl discovered a colorful butterfly in her garden. It danced around her, spreading magic and making her smile.",
  "A boy found a lost teddy bear in the playground. He gave it a warm hug and returned it to its owner, who was very grateful.",
  "A baby duckling took its first swim in the pond, following its mother. The duckling paddled happily and quacked with joy.",
  "A kind farmer planted seeds in the field, and soon beautiful flowers bloomed everywhere, filling the air with sweet scents.",
  "A little kitten climbed a tree but got scared. The kitten's mom came to the rescue and gently carried it back down to safety.",
  "A boy and his friends had a picnic in the park, sharing yummy sandwiches and laughing together under the bright sun.",
  "A girl blew soap bubbles and watched them float in the air. The bubbles sparkled with rainbow colors, making her giggle.",
  "A squirrel found a big acorn and buried it in the ground, saving it for the winter. The squirrel's cheeks puffed up happily.",
  "A group of children built a sandcastle on the beach, complete with towers and a moat. They clapped their hands, proud of their creation.",
  "A little bee buzzed from flower to flower, collecting sweet nectar. It flew back to the hive and danced with joy, sharing the treasure.",
  "A boy rode his bicycle without training wheels for the first time. He pedaled confidently and shouted, 'Look, no hands!'",
  "A girl went on a treasure hunt in her backyard and found a shiny seashell. She held it to her ear and heard the sound of the ocean.",
  "A baby turtle hatched from its shell and crawled toward the ocean. It swam with tiny flippers and splashed happily in the waves.",
  "A boy and his dad went fishing and caught a big fish. They high-fived each other and released it back into the water.",
  "A girl had a tea party with her stuffed animals. They sat in a circle, sipped imaginary tea, and shared imaginary cake.",
  "A little frog hopped and landed on a lily pad. It croaked a cheerful song, and the other frogs joined in a chorus.",
  "A boy and his sister jumped in puddles after the rain. They laughed as water splashed, creating ripples all around them.",
  "A girl painted a beautiful picture with bright colors. She held it up, proud of her masterpiece, and her family applauded.",
  "A family of ducks swam together in a pond. They quacked happily and waddled out to explore the world with their little ducklings.",
  "A boy blew out candles on his birthday cake, surrounded by loved ones. They sang a birthday song, and he made a wish.",
  "A girl found a lost kitten and took it home. She gave it a cozy bed and a bowl of milk, and they became best friends.",
  "A baby elephant sprayed water from its trunk, creating a playful shower. It trumpeted joyfully, splashing everyone nearby.",
  "A boy and his mom planted a vegetable garden. They watched the plants grow tall and harvested delicious vegetables to share.",
]

UNIFORM_SAD_ENDING = [
    "Once upon a time, a little birdie lost its way and couldn't find its family again.",
    "A cute puppy couldn't find its favorite toy and felt very lonely.",
    "A tiny caterpillar worked hard to become a butterfly but couldn't fly as high as it wanted to.",
    "A little fish swam too far from its friends and got lost in the big ocean.",
    "A little girl planted a flower, but it didn't grow and made her feel sad.",
    "A little boy's ice cream fell on the ground before he could take a bite, and he cried.",
    "A colorful balloon floated away from a child's hand and disappeared into the sky.",
    "A little bunny lost its way in the dark forest and couldn't find its home.",
    "A toy bear got forgotten at the playground and nobody came to pick it up.",
    "A little duckling got separated from its mother and couldn't find her again.",
    "A small kitten got stuck in a tree and couldn't get down on its own.",
    "A little boy dropped his favorite teddy bear in a puddle, and it got all wet.",
    "A tiny snail wished it could run fast like the other animals but was too slow.",
    "A little girl's balloon popped suddenly, and she couldn't play with it anymore.",
    "A cute squirrel lost all its acorns and couldn't find food for the winter.",
    "A little girl's sandcastle got washed away by the big waves at the beach.",
    "A little boy's cookie broke in half before he could eat it, and he felt disappointed.",
    "A little spider tried to build a beautiful web but kept getting tangled up.",
    "A small bird tried to fly but fell from the tree branch and hurt its wing.",
    "A little girl's kite got stuck in a tree, and she couldn't fly it anymore.",
    "A little frog hopped away from its friends and got lost in the tall grass.",
    "A little girl's toy car got broken, and she couldn't play with it anymore.",
    "A small mouse lost its cheese and had nothing to eat for dinner.",
    "A little girl's favorite book had missing pages, and she couldn't finish the story.",
    "A little boy's bubble popped before he could catch it with his hands.",
    "A cute ladybug got blown away by the wind and couldn't find its way back.",
    "A little girl's puzzle had a missing piece, and she couldn't complete it.",
]

MIXED_SAD_ENDING = [  "Once upon a time, a little bunny found a shiny red balloon. He held it tightly and laughed as it floated higher and higher, but then it popped and made him cry.",  "There was a little girl who loved ice cream. She savored each lick of her chocolate cone, but then a bird swooped down and snatched it away, leaving her empty-handed.",  "In a cozy forest, a squirrel discovered a big acorn. He hugged it with joy, but then it rolled away and disappeared, making him feel lost.",  "A happy little duckling splashed in a puddle after the rain. It was fun until the water drained away, and the duckling felt lonely.",  "Once upon a time, a curious kitten found a colorful butterfly. It danced and fluttered around, but then it flew away, leaving the kitten feeling empty.",  "A friendly puppy discovered a bouncy ball. He chased it with excitement, but then it bounced into the river, making him sad.",  "In a sunny garden, a little girl picked a bouquet of flowers. She smiled as she held them, but then she accidentally dropped them, and they got all messy.",  "A mischievous monkey swung from vine to vine. It was thrilling, but then he slipped and fell, hurting his arm and making him cry.",  "Once upon a time, a tiny turtle found a shiny seashell. He admired its beauty, but then it cracked into pieces, making him feel disappointed.",  "A happy little bird built a nest high up in a tree. She chirped with joy, but then a strong wind blew it down, leaving her homeless.",  "In a magical forest, a friendly elf discovered a treasure chest. He opened it, expecting something amazing, but it was empty, and he felt let down.",  "A cheerful butterfly fluttered among the flowers. It was beautiful, but then it got stuck in a spider's web, and it couldn't fly away.",  "Once upon a time, a little boy flew a colorful kite. It soared in the sky, but then the string broke, and the kite flew away forever.",  "A happy little fish swam in a clear pond. 
It played with its friends, but then a big fish scared them away, leaving the little fish feeling scared.",  "In a vibrant garden, a caterpillar nibbled on tasty leaves. It felt satisfied, but then it turned into a chrysalis, and everything changed.",  "A playful squirrel gathered a pile of acorns. It was a great collection, but then a bigger squirrel took them all, leaving the little squirrel with nothing.",  "Once upon a time, a joyful bee buzzed from flower to flower. It collected nectar, but then it got stuck in a spider's web, unable to escape.",  "A happy little girl blew soap bubbles into the air. They floated and sparkled, but then they popped one by one, leaving her feeling empty.",  "In a peaceful meadow, a deer pranced with grace. It was a lovely sight, but then it got tangled in a thorny bush, hurting its leg.",  "A friendly puppy found a cozy spot to nap. It slept soundly, but then it woke up and couldn't find its way back home, feeling lost.",  "Once upon a time, a curious kitten climbed a tall tree. It explored the branches, but then it got stuck and couldn't come down, feeling scared.",  "A happy little snail slid across a shiny leaf. It was slow and steady, but then a bird swooped down and took it away, making the snail feel helpless.",  "In a beautiful garden, a ladybug landed on a flower petal. It was peaceful, but then a gust of wind blew it away, leaving it all alone.",  "A brave little mouse found a piece of cheese. It nibbled happily, but then a big cat appeared, and the mouse had to run away in fear.",  "Once upon a time, a contented squirrel found a secret stash of nuts. It had plenty to eat, but then a mischievous raccoon stole them, making the squirrel sad.",  "A happy little bird sang its sweet melody. The sound filled the air, but then it flew into a window, hurting itself and feeling scared.",  "In a magical forest, a playful fairy danced with joy. 
She twirled and spun, but then she accidentally tripped and fell, hurting her wing.",  "A friendly turtle found a cozy spot to rest. It closed its eyes, but then a noisy lawnmower scared it away, making the turtle feel unsettled."]

In [13]:
def tabulate_dataset():
  """Gather the four hand-written story lists into one labelled DataFrame.

  Columns:
    'x':     the story text.
    'y':     1 for the *_HAPPY_ENDING lists, 0 for the *_SAD_ENDING lists.
    'mixed': True for the MIXED_* lists, False for the UNIFORM_* lists.
  Rows keep list order: mixed happy, uniform happy, mixed sad, uniform sad.
  """
  sources = (
      (MIXED_HAPPY_ENDING, 1, True),
      (UNIFORM_HAPPY_ENDING, 1, False),
      (MIXED_SAD_ENDING, 0, True),
      (UNIFORM_SAD_ENDING, 0, False),
  )
  records = [
      {'x': story, 'y': label, 'mixed': is_mixed}
      for stories, label, is_mixed in sources
      for story in stories
  ]
  return pd.DataFrame(records)

DF = tabulate_dataset()

In [14]:
# Class balance: number of stories for each (mixed, y) combination.
DF[['mixed', 'y']].value_counts()

mixed  y
True   0    28
False  0    27
       1    26
True   1    26
Name: count, dtype: int64

In [15]:
# Mixed sad-ending stories, each followed by the " The end! What a" completion cue.
examples = [f"{story} The end! What a" for story in MIXED_SAD_ENDING]

In [17]:
def modify_prompt(p: str) -> str:
    """Append the completion cue whose next-token logits are read as a sentiment prediction."""
    return p + " The end! What a"

def eval_sentiment_prediction(df=None, positive_token=' happy', negative_token=' sad', hooked_model=None):
    """Accuracy of the model's zero-shot happy/sad-ending classification.

    Each story in `df['x']` is wrapped with `modify_prompt` and run through the model. The
    story is predicted happy (1) when the last-position logit of `positive_token` exceeds
    that of `negative_token`, otherwise sad (0). Predictions are compared with `df['y']`.

    Args:
        df: DataFrame with a text column 'x' and a 0/1 label column 'y'. Defaults to the
            global `DF`. Columns are read by name, so extra or reordered columns are fine.
        positive_token, negative_token: strings passed to `to_single_token`.
        hooked_model: model to evaluate; defaults to the global `model`.

    Returns:
        0-dim float tensor: fraction of stories classified correctly.
    """
    if df is None:
        df = DF
    if hooked_model is None:
        hooked_model = model
    pos_idx = hooked_model.to_single_token(positive_token)
    neg_idx = hooked_model.to_single_token(negative_token)
    correct = []
    for story, label in zip(df['x'], df['y']):
        # Logits at the final position, i.e. for the token that would follow "What a".
        logits = hooked_model([modify_prompt(story)])[0, -1, :]
        predicted_happy = bool(logits[pos_idx] > logits[neg_idx])
        correct.append(int(predicted_happy) == int(label))
    return torch.tensor(correct, dtype=torch.float).mean()

In [18]:
eval_sentiment_prediction()

tensor(0.9720)

In [24]:
from typing import Tuple

def get_important_neurons(A_train: np.ndarray, y_train: np.ndarray, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
  """Rank activation dimensions by how far apart their class means are.

  Each column of `A_train` is scored by |mean over class-1 rows - mean over class-0 rows|,
  and the `k` highest-scoring columns are returned.

  Args:
    A_train: (n_examples, n_dims) activations.
    y_train: (n_examples,) binary labels (0 / 1).
    k: number of dimensions to return.

  Returns:
    (top_k_indices, top_k_scores), ordered by descending score.

  Raises:
    ValueError: if class 0 or class 1 has no examples, since its mean would be NaN and
      the ranking meaningless.
  """
  acts = torch.as_tensor(A_train, dtype=torch.float32)
  labels = torch.as_tensor(y_train)
  is_class_0, is_class_1 = labels == 0, labels == 1
  if not (bool(is_class_0.any()) and bool(is_class_1.any())):
    raise ValueError("get_important_neurons needs at least one example of each class (0 and 1)")
  scores = (acts[is_class_1].mean(dim=0) - acts[is_class_0].mean(dim=0)).abs()
  top_k_scores, top_k_indices = torch.topk(scores, k)
  return top_k_indices, top_k_scores

def get_activations(examples: List[str], layer: int, act_type: str) -> torch.Tensor:
  """Final-position activations of the global `model` for each prompt-wrapped example.

  Each example is wrapped with `modify_prompt` and run through `model`. The activation
  vector at the last position is read from one hook:
    'attn' -> blocks.{layer}.hook_resid_mid
    'mlp'  -> blocks.{layer}.mlp.hook_post

  Returns:
    Tensor with one activation vector per example, stacked along dim 0.

  Raises:
    ValueError: if `act_type` is not 'attn' or 'mlp'.
  """
  hook_names = {
      'attn': f'blocks.{layer}.hook_resid_mid',
      'mlp': f'blocks.{layer}.mlp.hook_post',
  }
  if act_type not in hook_names:
    raise ValueError(f"act_type must be one of {sorted(hook_names)}, got {act_type!r}")
  act_label = hook_names[act_type]
  activations = []
  for example in examples:
    # Cache only the hook we read instead of every activation in the model.
    _, cache = model.run_with_cache([modify_prompt(example)], names_filter=act_label)
    activations.append(cache[act_label][0, -1, :])
  return torch.stack(activations, dim=0)

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

def train_sparse_probe(A: torch.Tensor, y: torch.Tensor, k: int, num_repeats: int,
                       seed: Optional[int] = None) -> Tuple[float, float, torch.Tensor, torch.Tensor]:
  """Fit logistic-regression probes restricted to the k most class-separating dimensions.

  Each repeat makes a random 60/40 train/test split, picks the top-k dimensions on the
  training split with `get_important_neurons`, and fits `LogisticRegression` on them.

  Args:
    A: (n_examples, n_dims) activations.
    y: (n_examples,) 0/1 labels.
    k: number of dimensions the probe may use.
    num_repeats: number of random splits to average over (must be >= 1).
    seed: optional base seed; repeat i uses random_state=seed + i so results are
      reproducible. None (default) leaves the splits unseeded.

  Returns:
    (mean train accuracy, mean test accuracy, top_k_indices, top_k_scores). The
    indices and scores come from the LAST repeat only.

  Raises:
    ValueError: if num_repeats < 1.
  """
  if num_repeats < 1:
    raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")
  # Loop-invariant: move to host memory once, not once per repeat.
  A_np, y_np = A.cpu().numpy(), y.cpu().numpy()
  train_accs, test_accs = [], []
  for i in range(num_repeats):
    split_seed = None if seed is None else seed + i
    A_train, A_test, y_train, y_test = train_test_split(A_np, y_np, test_size=0.4, random_state=split_seed)
    top_k_indices, top_k_scores = get_important_neurons(A_train, y_train, k)
    topk = top_k_indices.cpu().numpy().tolist()
    A_train, A_test = A_train[:, topk], A_test[:, topk]
    # Named `probe` so it is not confused with the global HookedTransformer `model`.
    probe = LogisticRegression()
    probe.fit(A_train, y_train)
    test_accs.append(probe.score(A_test, y_test))
    train_accs.append(probe.score(A_train, y_train))
  return sum(train_accs) / num_repeats, sum(test_accs) / num_repeats, top_k_indices, top_k_scores


In [20]:
# Layer-1 resid_mid activations (act_type='attn') at the last prompt position for every story,
# plus the 0/1 happy/sad labels.
A = get_activations(examples=DF['x'].values.tolist(), layer=1, act_type='attn')
# Keep labels on the same device as the activations rather than assuming CUDA is available.
y = torch.Tensor(DF['y'].values).to(A.device)

In [21]:
# Mean train/test accuracy of a 10-dimension probe over 20 random 60/40 splits, plus the
# top-10 dimensions (and their scores) from the last split. The splits are not seeded, so
# the numbers vary from run to run.
train_sparse_probe(A, y, k=10, num_repeats=20)

(0.93828125,
 0.8616279069767441,
 tensor([652,  63, 589, 691,   9, 753, 684, 115, 400, 644]),
 tensor([0.2445, 0.1681, 0.1606, 0.1319, 0.1201, 0.1186, 0.1130, 0.1075, 0.1064,
         0.1050]))

In [22]:
# Load the TinyStories validation text; stories are separated by an end-of-text marker.
# NOTE(review): nothing in this notebook downloads this file (the recorded run failed with
# FileNotFoundError), so place it in the working directory first.
STORIES_PATH = Path('TinyStories-valid.txt')
with open(STORIES_PATH, 'r', encoding='utf-8') as in_file:
    lines = in_file.readlines()

DELIMETER = '<|endoftext|>\n'
# Split on the marker itself, which also copes with a missing newline after the last one.
# Collapse every whitespace run (including line breaks) into one space so words on adjacent
# lines are not glued together. Drop empty chunks, such as the text after the final marker.
stories = [' '.join(chunk.split()) for chunk in ''.join(lines).split(DELIMETER.strip())]
stories = [s for s in stories if s]
for s in stories[:10]:
    print(s)

FileNotFoundError: [Errno 2] No such file or directory: 'TinyStories-valid.txt'

In [23]:
from transformers import pipeline

# Pin the model and revision that the unpinned default resolved to in the recorded run, so
# results don't silently change if the library's default sentiment model is updated.
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    revision="af0f99b",
)


No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


Downloading (…)lve/main/config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.


In [None]:
# Number of validation stories to classify with the off-the-shelf sentiment model.
N_STORIES = 100
data = stories[:N_STORIES]
# truncation=True clips unusually long stories to the classifier's maximum input length
# so they do not exceed it.
predictions = sentiment_pipeline(data, truncation=True)

In [None]:
for story, prediction in zip(data, predictions):
  # Predicted label and its score, then the story with a line break after every period.
  print(prediction['label'], prediction['score'])
  print(story.replace('.', '.\n'))

In [None]:
POSITIVE = ['Lily is a girl with curly hair. One day she receives a bouquet of colorful balloons and feels',
 'Oliver is a boy with glasses. One day he jumps high on a trampoline and feels',
 'Grace is a girl with a pink dress. One day she plays with her favorite teddy bear and feels',
 'Henry is a boy with freckles on his face. One day he discovers a shiny seashell at the beach and feels',
 'Ella is a girl with braided hair. One day she swings high on the playground swings and feels',
 'James is a boy with a big smile. One day he helps his dad plant flowers in the garden and feels',
 'Chloe is a girl with a sparkly tiara. One day she dances to her favorite song and feels',
 'Benjamin is a boy with a superhero cape. One day he saves a little bird with a broken wing and feels',
 'Scarlett is a girl with a purple backpack. One day she finds a pretty feather on a nature walk and feels',
 'Wyatt is a boy with a red cap. One day he builds a tall tower with his building blocks and feels',
 'Hannah is a girl with long braids. One day she jumps in a puddle and feels',
 'Leo is a boy with a big imagination. One day he pretends to be a pirate and finds a hidden treasure, feeling',
 'Stella is a girl with a bright smile. One day she gets a new set of colorful markers and feels',
 'Samuel is a boy with a curious mind. One day he discovers a ladybug in his backyard and feels',
 'Penelope is a girl with a flower in her hair. One day she receives a sweet note from her best friend and feels',
 'Owen is a boy with a playful spirit. One day he builds a sandcastle at the beach and feels',
 'Violet is a girl with a vivid imagination. One day she creates a beautiful drawing and feels',
 'Daniel is a boy with a joyful laugh. One day he finds a shiny penny on the sidewalk and feels',
 'Ava is a girl with twinkling eyes. One day she learns how to ride a bike without training wheels and feels',
 'Noah is a boy with a mischievous grin. One day he splashes in a pool and feels',
 'Isabella is a girl with a charming giggle. One day she bakes cookies with her grandma and feels',
 'Ethan is a boy with boundless energy. One day he plays soccer with his friends and scores a goal, feeling',
 'Sophia is a girl with a gentle heart. One day she helps a lost puppy find its way home and feels',
 'Liam is a boy with a contagious smile. One day he receives a surprise gift from his parents and feels',
 'Emma is a girl with a skip in her step. One day she jumps rope with her friends and feels',
 'Jackson is a boy with a curious mind. One day he discovers a caterpillar and watches it turn into a beautiful butterfly, feeling',
 'Olivia is a girl with a twinkle in her eye. One day she receives a big hug from her favorite teacher and feels',
 'Mason is a boy with a friendly nature. One day he shares his toy with a new friend and feels',
 'Ava is a girl with a radiant smile. One day she plants flowers in her garden and watches them bloom, feeling',
 'Noah is a boy with an adventurous spirit. One day he explores a forest and discovers a hidden treasure, feeling',
 'Emily is a girl with a melodic voice. One day she sings her favorite song and feels',
 'Ethan is a boy with a playful imagination. One day he builds a fort with blankets and pillows and feels',
 'Olivia is a girl with a creative mind. One day she paints a beautiful rainbow and feels',
 'Liam is a boy with a curious nature. One day he finds a caterpillar and watches it transform into a butterfly, feeling',
 'Emma is a girl with a kind heart. One day she helps her little brother tie his shoelaces and feels',
 'Jackson is a boy with a brave spirit. One day he climbs to the top of a tall tree and feels',
 'Ava is a girl with a contagious laugh. One day she plays tag with her friends and feels',
 'Noah is a boy with a big imagination. One day he pretends to be a superhero and saves the day, feeling',
 'Emily is a girl with a cheerful demeanor. One day she blows bubbles in the park and feels',
 'Ethan is a boy with a mischievous smile. One day he jumps in a pile of leaves and feels',
 'Olivia is a girl with a loving heart. One day she shares her toys with children in need and feels',
 'Liam is a boy with a bright personality. One day he receives a gold star for his good work and feels',
 'Emma is a girl with a lively spirit. One day she dances in the rain and feels',
 'Jackson is a boy with a clever mind. One day he solves a puzzle and feels',
 'Ava is a girl with a compassionate nature. One day she helps an elderly neighbor carry groceries and feels',
 'Noah is a boy with a playful nature. One day he builds a sandcastle on the beach and feels',
 'Emily is a girl with an adventurous soul. One day she explores a forest and discovers a hidden waterfall, feeling',
 'Ethan is a boy with an inquisitive mind. One day he observes a colorful butterfly in his garden and feels',
 'Olivia is a girl with a joyful laugh. One day she plays with her favorite stuffed animal and feels',
 'Liam is a boy with a heart full of curiosity. One day he finds a shiny rock and adds it to his collection, feeling',
 'Emma is a girl with a bright imagination. One day she pretends to be a princess and has a royal tea party, feeling',
 'Jackson is a boy with a creative spark. One day he paints a picture and receives praise from his teacher, feeling',
 'Ava is a girl with a gentle spirit. One day she feeds the ducks at the park and feels',
 'Noah is a boy with a playful personality. One day he jumps on a trampoline and feels',
 'Emily is a girl with a sunny smile. One day she plays hide-and-seek with her friends and feels',
 'Ethan is a boy with a caring heart. One day he helps his mom set the table for dinner and feels',
 'Olivia is a girl with a contagious giggle. One day she dances to her favorite song and feels',
 'Liam is a boy with a clever mind. One day he solves a puzzle and feels',
 'Emma is a girl with a helpful nature. One day she shares her snack with a friend and feels',
 "Jackson is a boy with a curious spirit. One day he discovers a bird's nest in his backyard and feels",
 'Ava is a girl with a loving heart. One day she hugs her pet dog tightly and feels',
 'Noah is a boy with an adventurous soul. One day he rides his bike down a steep hill and feels',
 'Emily is a girl with a melodic voice. One day she sings her favorite song and feels',
 'Ethan is a boy with a playful imagination. One day he builds a fort with blankets and pillows and feels',
 'Olivia is a girl with a creative mind. One day she paints a beautiful rainbow and feels',
 'Liam is a boy with a curious nature. One day he finds a caterpillar and watches it transform into a butterfly, feeling',
 'Emma is a girl with a kind heart. One day she helps her little brother tie his shoelaces and feels',
 'Jackson is a boy with a brave spirit. One day he climbs to the top of a tall tree and feels',
 'Ava is a girl with a contagious laugh. One day she plays tag with her friends and feels',
 'Noah is a boy with a big imagination. One day he pretends to be a superhero and saves the day, feeling',
 'Emily is a girl with a cheerful demeanor. One day she blows bubbles in the park and feels',
 'Ethan is a boy with a mischievous smile. One day he jumps in a pile of leaves and feels',
 'Olivia is a girl with a loving heart. One day she shares her toys with children in need and feels',
 'Liam is a boy with a bright personality. One day he receives a gold star for his good work and feels',
 'Emma is a girl with a lively spirit. One day she dances in the rain and feels',
 'Jackson is a boy with a clever mind. One day he solves a puzzle and feels',
 'Ava is a girl with a compassionate nature. One day she helps an elderly neighbor carry groceries and feels',
 'Noah is a boy with a playful nature. One day he builds a sandcastle on the beach and feels',
 'Emily is a girl with an adventurous soul. One day she explores a forest and discovers a hidden waterfall, feeling',
 'Ethan is a boy with an inquisitive mind. One day he observes a colorful butterfly in his garden and feels',
 'Olivia is a girl with a joyful laugh. One day she plays with her favorite stuffed animal and feels',
 'Liam is a boy with a heart full of curiosity. One day he finds a shiny rock and adds it to his collection, feeling',
 'Emma is a girl with a bright imagination. One day she pretends to be a princess and has a royal tea party, feeling',
 'Jackson is a boy with a creative spark. One day he paints a picture and receives praise from his teacher, feeling',
 'Ava is a girl with a gentle spirit. One day she feeds the ducks at the park and feels',
 'Noah is a boy with a playful personality. One day he jumps on a trampoline and feels',
 'Emily is a girl with a sunny smile. One day she plays hide-and-seek with her friends and feels',
 'Ethan is a boy with a caring heart. One day he helps his mom set the table for dinner and feels',
 'Olivia is a girl with a contagious giggle. One day she dances to her favorite song and feels',
 'Liam is a boy with a clever mind. One day he solves a puzzle and feels',
 'Emma is a girl with a helpful nature. One day she shares her snack with a friend and feels',
 "Jackson is a boy with a curious spirit. One day he discovers a bird's nest in his backyard and feels",
 'Ava is a girl with a loving heart. One day she hugs her pet dog tightly and feels',
 'Noah is a boy with an adventurous soul. One day he rides his bike down a steep hill and feels',
 'Emily is a girl with a melodic voice. One day she sings her favorite song and feels',
 'Ethan is a boy with a playful imagination. One day he builds a fort with blankets and pillows and feels',
 'Olivia is a girl with a creative mind. One day she paints a beautiful rainbow and feels',
 'Liam is a boy with a curious nature. One day he finds a caterpillar and watches it transform into a butterfly, feeling',
 'Emma is a girl with a kind heart. One day she helps her little brother tie his shoelaces and feels',
 'Jackson is a boy with a brave spirit. One day he climbs to the top of a tall tree and feels']

# Negative-sentiment prompts (label 0 in the probes and evaluations).
# NOTE(review): many entries are exact duplicates of earlier ones (e.g. the
# 'Scarlett ... sweet smile ... juice box' prompt appears twice). A random
# train/test split over POSITIVE + NEGATIVE can therefore put identical prompts on
# both sides and inflate held-out accuracy.
# The Ella 'making her feel' prompt has no ' and feels' suffix, so the
# ' but feels' augmentation does not rewrite it.
NEGATIVE = ['Lily is a girl with curly hair. One day she trips and falls down, scraping her knee, and feels',
 'Oliver is a boy with glasses. One day he loses his favorite book at the park and feels',
 'Grace is a girl with a pink dress. One day she accidentally spills juice on her new dress and feels',
 'Henry is a boy with freckles on his face. One day he realizes he forgot his lunch at home and feels',
 'Ella is a girl with braided hair. One day she gets caught in the rain and her drawing gets wet, making her feel',
 "James is a boy with a big smile. One day he can't find his favorite toy car and feels",
 'Chloe is a girl with a sparkly tiara. One day she breaks her beautiful necklace and feels',
 'Benjamin is a boy with a superhero cape. One day he loses his superhero mask and feels',
 "Scarlett is a girl with a purple backpack. One day she can't find her favorite stuffed animal and feels",
 'Wyatt is a boy with a red cap. One day he accidentally spills his ice cream cone and feels',
 'Lily is a girl with long braids. One day she loses her favorite hair clip and feels',
 'Oliver is a boy with a bright smile. One day he misses the bus to school and feels',
 'Grace is a girl with a twinkle in her eyes. One day she drops her ice cream cone on the ground and feels',
 'Henry is a boy with a playful spirit. One day he loses his favorite ball and feels',
 'Ella is a girl with a colorful backpack. One day she forgets her art project at home and feels',
 'James is a boy with a mischievous grin. One day he accidentally breaks his toy robot and feels',
 'Chloe is a girl with a cheerful laugh. One day she rips her favorite dress and feels',
 "Benjamin is a boy with a curious mind. One day he can't find his missing puzzle piece and feels",
 'Scarlett is a girl with a vibrant personality. One day she loses her special bracelet and feels',
 'Wyatt is a boy with a playful nature. One day he drops his ice cream cone on the ground and feels',
 'Lily is a girl with a big imagination. One day she loses her favorite storybook and feels',
 'Oliver is a boy with a kind heart. One day he accidentally breaks his toy train and feels',
 'Grace is a girl with a creative mind. One day she spills paint on her artwork and feels',
 'Henry is a boy with a friendly nature. One day he loses his favorite action figure and feels',
 'Ella is a girl with a gentle spirit. One day she forgets her lunch box at home and feels',
 'James is a boy with a bright imagination. One day he loses his special drawing and feels',
 'Chloe is a girl with a curious nature. One day she breaks her favorite toy car and feels',
 'Benjamin is a boy with a playful personality. One day he loses his soccer ball and feels',
 'Scarlett is a girl with a sweet smile. One day she accidentally spills her juice box and feels',
 'Wyatt is a boy with an adventurous spirit. One day he loses his favorite hat and feels',
 'Lily is a girl with a lively personality. One day she loses her favorite doll and feels',
 'Oliver is a boy with a mischievous grin. One day he tears his favorite book and feels',
 'Grace is a girl with a kind heart. One day she forgets her umbrella on a rainy day and feels',
 'Henry is a boy with a curious mind. One day he loses his toy dinosaur and feels',
 'Ella is a girl with a cheerful laugh. One day she accidentally breaks her toy tea set and feels',
 'James is a boy with a bright smile. One day he forgets his lunch at home and feels',
 'Chloe is a girl with a playful spirit. One day she loses her favorite puzzle piece and feels',
 'Benjamin is a boy with a friendly nature. One day he accidentally spills his milk and feels',
 'Scarlett is a girl with a creative mind. One day she loses her art supplies and feels',
 'Wyatt is a boy with a big imagination. One day he loses his favorite superhero cape and feels',
 'Lily is a girl with a curious nature. One day she breaks her favorite pair of glasses and feels',
 'Oliver is a boy with a bright personality. One day he forgets his homework at home and feels',
 'Grace is a girl with a gentle heart. One day she loses her special necklace and feels',
 'Henry is a boy with a playful spirit. One day he accidentally breaks his toy car and feels',
 'Ella is a girl with a caring nature. One day she forgets her lunch money and feels',
 'James is a boy with a mischievous grin. One day he loses his favorite stuffed animal and feels',
 'Chloe is a girl with a vibrant personality. One day she spills juice on her new dress and feels',
 'Benjamin is a boy with a creative mind. One day he loses his art project and feels',
 'Scarlett is a girl with a playful nature. One day she accidentally breaks her toy doll and feels',
 'Wyatt is a boy with a bright imagination. One day he loses his special drawing and feels',
 'Lily is a girl with a cheerful demeanor. One day she forgets her favorite teddy bear at the park and feels',
 'Oliver is a boy with a curious mind. One day he loses his toy spaceship and feels',
 'Grace is a girl with a playful spirit. One day she accidentally spills paint on her new shoes and feels',
 'Henry is a boy with a mischievous smile. One day he loses his soccer ball and feels',
 'Ella is a girl with a bright personality. One day she breaks her favorite hairband and feels',
 'James is a boy with a creative imagination. One day he loses his special art set and feels',
 'Chloe is a girl with a gentle heart. One day she accidentally rips her favorite coloring book and feels',
 'Benjamin is a boy with a curious nature. One day he loses his special toy train and feels',
 'Scarlett is a girl with a playful spirit. One day she spills juice on her favorite stuffed animal and feels',
 'Wyatt is a boy with a cheerful demeanor. One day he forgets his lunch box at home and feels',
 'Lily is a girl with a mischievous grin. One day she loses her favorite hair clip and feels',
 'Oliver is a boy with a creative mind. One day he accidentally breaks his favorite toy car and feels',
 'Grace is a girl with a playful personality. One day she spills paint on her artwork and feels',
 'Henry is a boy with a curious nature. One day he loses his special ball and feels',
 'Ella is a girl with a gentle spirit. One day she forgets her lunch at home and feels',
 'James is a boy with a bright imagination. One day he loses his special drawing and feels',
 'Chloe is a girl with a vibrant personality. One day she breaks her favorite toy doll and feels',
 'Benjamin is a boy with a playful nature. One day he loses his soccer ball and feels',
 'Scarlett is a girl with a sweet smile. One day she accidentally spills her juice box and feels',
 'Wyatt is a boy with an adventurous spirit. One day he loses his favorite hat and feels',
 'Lily is a girl with a lively personality. One day she loses her favorite doll and feels',
 'Oliver is a boy with a mischievous grin. One day he breaks his favorite toy plane and feels',
 'Grace is a girl with a kind heart. One day she forgets her umbrella on a rainy day and feels',
 'Henry is a boy with a curious mind. One day he loses his toy dinosaur and feels',
 'Ella is a girl with a cheerful laugh. One day she accidentally spills her juice on her new dress and feels',
 'James is a boy with a bright smile. One day he forgets his lunch at home and feels',
 'Chloe is a girl with a playful spirit. One day she loses her favorite puzzle piece and feels',
 'Benjamin is a boy with a friendly nature. One day he accidentally spills his milk and feels',
 'Scarlett is a girl with a creative mind. One day she loses her art supplies and feels',
 'Wyatt is a boy with a big imagination. One day he loses his favorite superhero cape and feels',
 'Lily is a girl with a curious nature. One day she breaks her favorite pair of glasses and feels',
 'Oliver is a boy with a bright personality. One day he forgets his homework at home and feels',
 'Grace is a girl with a gentle heart. One day she loses her special necklace and feels',
 'Henry is a boy with a playful spirit. One day he accidentally breaks his toy car and feels',
 'Ella is a girl with a caring nature. One day she forgets her lunch money and feels',
 'James is a boy with a mischievous grin. One day he loses his favorite stuffed animal and feels',
 'Chloe is a girl with a vibrant personality. One day she spills juice on her new dress and feels',
 'Benjamin is a boy with a creative mind. One day he loses his art project and feels',
 'Scarlett is a girl with a playful nature. One day she accidentally breaks her toy doll and feels',
 'Wyatt is a boy with a bright imagination. One day he loses his special drawing and feels',
 'Lily is a girl with a cheerful demeanor. One day she forgets her favorite teddy bear at the park and feels',
 'Oliver is a boy with a curious mind. One day he loses his toy spaceship and feels',
 'Grace is a girl with a playful spirit. One day she accidentally spills paint on her new shoes and feels',
 'Henry is a boy with a mischievous smile. One day he loses his soccer ball and feels',
 'Ella is a girl with a bright personality. One day she breaks her favorite hairband and feels',
 'James is a boy with a creative imagination. One day he loses his special art set and feels',
 'Chloe is a girl with a gentle heart. One day she accidentally rips her favorite coloring book and feels',
 'Benjamin is a boy with a curious nature. One day he loses his special toy train and feels',
 'Scarlett is a girl with a playful spirit. One day she spills juice on her favorite stuffed animal and feels',
 'Wyatt is a boy with a cheerful demeanor. One day he forgets his lunch box at home and feels']

In [None]:
# Contrastive augmentation: a prompt of one polarity that ends in " and feels" is
# rewritten to " but feels" and added to the OPPOSITE class (negatives -> NEW_POSITIVE,
# positives -> NEW_NEGATIVE). Prompts without " and feels" are not rewritten.
NEW_POSITIVE = [*POSITIVE, *(p.replace(' and feels', ' but feels') for p in NEGATIVE if ' and feels' in p)]
NEW_NEGATIVE = [*NEGATIVE, *(p.replace(' and feels', ' but feels') for p in POSITIVE if ' and feels' in p)]

In [None]:
tuple(map(len, (NEW_POSITIVE, NEW_NEGATIVE)))

(199, 182)

# Exploration

In [None]:
# TinyStories-33M loaded with TransformerLens weight processing enabled
# (LayerNorm folding, centering, refactored attention matrices).
model = HookedTransformer.from_pretrained(
    "roneneldan/TinyStories-33M",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer


In [None]:
def eval_sentiment_prediction(positive: List[str], negative: List[str],
                              pos_token: str = ' happy', neg_token: str = ' sad') -> torch.Tensor:
    """Accuracy of the model's "pos_token vs neg_token" next-token preference.

    Each prompt is run through the global ``model`` separately. A prompt is predicted
    positive when the final-position logit of ``pos_token`` exceeds that of ``neg_token``.
    Prompts in ``positive`` are labelled 1 and prompts in ``negative`` are labelled 0.

    Args:
        positive: prompts whose expected sentiment is positive.
        negative: prompts whose expected sentiment is negative.
        pos_token: single-token string for the positive answer.
        neg_token: single-token string for the negative answer.

    Returns:
        Scalar float tensor (on CPU) with the fraction of correct predictions.
        It is NaN when both lists are empty.
    """
    pos_idx = model.to_single_token(pos_token)
    neg_idx = model.to_single_token(neg_token)
    labels = [True] * len(positive) + [False] * len(negative)
    predictions = []
    # Evaluation only: skip building autograd graphs for every forward pass.
    with torch.no_grad():
        for example in positive + negative:
            final_logits = model([example])[0, -1, :]
            # Convert to a Python bool so the list holds plain values, not 0-d device tensors.
            predictions.append(bool(final_logits[pos_idx] > final_logits[neg_idx]))
    return (torch.tensor(predictions) == torch.tensor(labels)).float().mean()

In [None]:
eval_sentiment_prediction(positive=NEW_POSITIVE, negative=NEW_NEGATIVE)

tensor(0.7979)

In [None]:
# Run one prompt with activation caching and list the available hook names.
logits, cache = model.run_with_cache(POSITIVE[0])
print(cache.keys())

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.hook_mlp_in', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hoo

In [None]:
# Shape of the cached layer-3 MLP post-activation (batch, pos, neurons).
cache['blocks.3.mlp.hook_post'].size()

torch.Size([1, 22, 3072])

In [None]:
from typing import Tuple

def get_important_neurons(A_train: np.ndarray, y_train: np.ndarray, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rank activation dimensions by how far apart their class means are.

    Each column of ``A_train`` is scored by ``|mean(class 1) - mean(class 0)|``.
    The ``k`` highest-scoring columns are returned.

    Args:
        A_train: (n_examples, n_features) activations.
        y_train: (n_examples,) binary labels (0 or 1).
        k: number of dimensions to keep. It is clamped to the number of features.

    Returns:
        ``(top_k_indices, top_k_scores)``, ordered by descending score.

    Raises:
        ValueError: if class 0 or class 1 has no examples (its mean would be NaN).
    """
    acts = torch.as_tensor(A_train, dtype=torch.float32)
    labels = torch.as_tensor(y_train, dtype=torch.float32)
    is_class_1 = labels == 1
    is_class_0 = labels == 0
    if not bool(is_class_1.any()) or not bool(is_class_0.any()):
        raise ValueError("get_important_neurons needs examples of both class 0 and class 1")
    scores = (acts[is_class_1].mean(dim=0) - acts[is_class_0].mean(dim=0)).abs()
    # torch.topk raises if k exceeds the number of features, so clamp it.
    top_k_scores, top_k_indices = torch.topk(scores, min(k, scores.numel()))
    return top_k_indices, top_k_scores

def get_activations(examples: List[str], layer: int, hook: str = 'mlp.hook_post') -> torch.Tensor:
    """Collect final-token activations at one hook point for each prompt.

    Args:
        examples: prompts. Each one is run through the global ``model`` separately.
        layer: block index.
        hook: hook name relative to the block. The default reads
            ``blocks.{layer}.mlp.hook_post`` (the MLP post-activation).

    Returns:
        Tensor of shape ``(len(examples), d)``, where ``d`` is the hook's last
        dimension. It stacks the last-position activation of each prompt.

    Raises:
        ValueError: if ``examples`` is empty (there is nothing to stack).
    """
    if not examples:
        raise ValueError("get_activations needs at least one example")
    hook_name = f'blocks.{layer}.{hook}'
    activations = []
    with torch.no_grad():
        for example in examples:
            # Cache only the single hook we read, rather than every activation in the model.
            _, cache = model.run_with_cache([example], names_filter=hook_name)
            activations.append(cache[hook_name][0, -1, :])
    return torch.stack(activations, dim=0)

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

def train_sparse_probe(A: torch.Tensor, y: torch.Tensor, k: int,
                       random_state: Optional[int] = None) -> Tuple[float, float, torch.Tensor, torch.Tensor]:
    """Fit a logistic-regression probe on the ``k`` most class-separating dimensions.

    The data is split 80/20. Dimensions are selected with ``get_important_neurons``
    on the training split only, so the test split stays held out.

    Args:
        A: (n_examples, n_features) activations.
        y: (n_examples,) binary labels.
        k: number of dimensions the probe may use.
        random_state: seed for ``train_test_split``. ``None`` (default) keeps the
            previous unseeded behaviour.

    Returns:
        ``(train_acc, test_acc, top_k_indices, top_k_scores)``.
    """
    A_train, A_test, y_train, y_test = train_test_split(
        A.cpu().numpy(), y.cpu().numpy(), test_size=0.2, random_state=random_state
    )
    top_k_indices, top_k_scores = get_important_neurons(A_train, y_train, k)
    selected = top_k_indices.cpu().numpy().tolist()
    A_train_k = A_train[:, selected]
    A_test_k = A_test[:, selected]
    # Named `probe` rather than `model` so the global transformer is not shadowed.
    probe = LogisticRegression()
    probe.fit(A_train_k, y_train)
    test_acc = probe.score(A_test_k, y_test)
    train_acc = probe.score(A_train_k, y_train)
    return train_acc, test_acc, top_k_indices, top_k_scores


In [None]:
# Final-token MLP activations for every layer, with labels 1 = POSITIVE, 0 = NEGATIVE
# (rows are ordered POSITIVE first, then NEGATIVE).
A_per_layer = [get_activations(POSITIVE + NEGATIVE, layer=i) for i in range(model.cfg.n_layers)]
# Use the model's device instead of hard-coding .cuda(), so this also runs on CPU-only machines.
y = torch.tensor([1.0] * len(POSITIVE) + [0.0] * len(NEGATIVE), device=model.cfg.device)

In [None]:
# For each layer, keep the neurons whose class means differ most (the selection made
# inside train_sparse_probe). Train/test probe accuracies are kept in PROBE_ACCURACIES
# so the choice of TOP_K_NEURONS can be sanity-checked.
TOP_K_NEURONS = 200
IMPORTANT_NEURONS = []
PROBE_ACCURACIES = []
for layer_acts in A_per_layer:
    train_acc, test_acc, top_k_indices, top_k_score = train_sparse_probe(layer_acts, y, k=TOP_K_NEURONS)
    IMPORTANT_NEURONS.append(top_k_indices)
    PROBE_ACCURACIES.append((train_acc, test_acc))

In [None]:
def eval_sentiment_prediction_logits(logits: torch.Tensor, labels: torch.Tensor,
                                     pos_token: str = ' happy', neg_token: str = ' sad') -> torch.Tensor:
    """Accuracy of "pos_token beats neg_token" predictions from final-position logits.

    Args:
        logits: (batch, d_vocab) logits taken at position -1 only. Full
            (batch, pos, d_vocab) logits are rejected.
        labels: (batch,) with 1 for positive and 0 for negative. It may live on a
            different device than ``logits``.
        pos_token: single-token string for the positive answer.
        neg_token: single-token string for the negative answer.

    Returns:
        Scalar float tensor with the fraction of correct predictions.
    """
    assert logits.ndim == 2, (
        f"expected (batch, d_vocab) logits from position -1, got shape {tuple(logits.shape)}"
    )
    pos_idx = model.to_single_token(pos_token)
    neg_idx = model.to_single_token(neg_token)
    predictions = logits[:, pos_idx] > logits[:, neg_idx]
    # Align devices so CPU labels work with CUDA logits (and vice versa).
    return (predictions == labels.to(predictions.device)).float().mean()

# IMPORTANT_NEURONS = [
#     [1405, 3029,  594, 1994, 2384, 2899, 1140, 2593,  132, 1551],
#     [2695,  973, 2258,  215, 1616, 1018, 2557, 2439, 1276,   22],
#     [2785, 2049, 2892, 1937,  160, 1514, 1634, 2028, 2588, 2355],
#     [2341, 1845, 1049, 1463,  112,  596, 1203, 1200,  260, 2827]
# ]

def generate_random_neurons(per_layer: int, n_layers: int = 4, d_mlp: int = 3072,
                            seed: Optional[int] = None) -> List[List[int]]:
    """Sample random MLP neuron indices per layer, e.g. as a baseline for IMPORTANT_NEURONS.

    The defaults match the 4-layer / 3072-neuron MLP shapes used in this notebook.

    Args:
        per_layer: number of distinct neurons to draw in each layer.
        n_layers: number of layers.
        d_mlp: number of neurons per layer. Indices are drawn from ``range(d_mlp)``.
        seed: optional seed for reproducible draws.

    Returns:
        ``n_layers`` lists, each with ``per_layer`` distinct sorted indices.

    Raises:
        ValueError: if ``per_layer`` is negative or larger than ``d_mlp``.
    """
    rng = random.Random(seed)
    return [sorted(rng.sample(range(d_mlp), per_layer)) for _ in range(n_layers)]

# Per-layer mean final-token MLP activation for each class. Rows of each A_per_layer
# entry are POSITIVE prompts followed by NEGATIVE prompts, so split at len(POSITIVE)
# instead of a hard-coded 100.
MEAN_POSITIVE_ACTIVATIONS = [x[:len(POSITIVE)].mean(dim=0) for x in A_per_layer]
MEAN_NEGATIVE_ACTIVATIONS = [x[len(POSITIVE):].mean(dim=0) for x in A_per_layer]

def evaluate_with_hook():
    """Sentiment accuracy after overwriting the selected neurons with the opposite class mean.

    For every prompt in POSITIVE (label 1) and NEGATIVE (label 0), the final-position
    MLP post-activations at ``IMPORTANT_NEURONS[layer]`` are replaced in every layer.
    Positive prompts receive the values of MEAN_NEGATIVE_ACTIVATIONS and negative
    prompts receive MEAN_POSITIVE_ACTIVATIONS.

    Returns:
        Scalar tensor with the happy-vs-sad accuracy under this patch.
    """
    def generate_hook(layer: int, indices, label: int):
        # indices: neuron indices for this layer (a tensor or list).
        patching_act = MEAN_POSITIVE_ACTIVATIONS if label == 0 else MEAN_NEGATIVE_ACTIVATIONS

        # TransformerLens passes the hook point as the keyword `hook=`, so that
        # parameter name must stay as is.
        def patch_hook(value, hook):
            value[:, -1, indices] = patching_act[layer][indices]
            return value
        return patch_hook

    def final_logits(example: str, label: int) -> torch.Tensor:
        fwd_hooks = [
            (f'blocks.{i}.mlp.hook_post', generate_hook(i, neurons, label=label))
            for i, neurons in enumerate(IMPORTANT_NEURONS)
        ]
        return model.run_with_hooks(example, fwd_hooks=fwd_hooks)[0, -1, :]

    labelled = [(ex, 1) for ex in POSITIVE] + [(ex, 0) for ex in NEGATIVE]
    with torch.no_grad():
        all_logits = torch.stack([final_logits(ex, label) for ex, label in labelled], dim=0)
    # Put labels on the logits' device rather than assuming CUDA.
    labels = torch.tensor([float(label) for _, label in labelled], device=all_logits.device)
    return eval_sentiment_prediction_logits(all_logits, labels)

In [None]:
# Accuracy with each layer's selected MLP neurons patched to the opposite class mean.
evaluate_with_hook()

tensor(0.4800, device='cuda:0')



```
# This is formatted as code
```

# It's not about the MLPs!

In [None]:
# List the hook names in `cache`. This needs a cell that defines `cache` to have run
# first in the current kernel.
cache.keys()

NameError: ignored

In [None]:
# Dense logistic-regression probe on layer 0's hook_resid_mid (the residual stream
# between the attention and MLP sublayers) at the final token, using all dimensions.
probe_hook = 'blocks.0.hook_resid_mid'
acts = []
with torch.no_grad():
    for example in POSITIVE + NEGATIVE:
        # Cache only the hook being probed instead of every activation.
        _, cache = model.run_with_cache([example], names_filter=probe_hook)
        acts.append(cache[probe_hook][0, -1, :])
acts = torch.stack(acts, dim=0)
assert acts.shape == (len(POSITIVE) + len(NEGATIVE), model.cfg.d_model), acts.shape
# Device taken from the model rather than hard-coded .cuda(), so CPU-only runs work.
y = torch.tensor([1.0] * len(POSITIVE) + [0.0] * len(NEGATIVE), device=model.cfg.device)
A_train, A_test, y_train, y_test = train_test_split(acts.cpu().numpy(), y.cpu().numpy(), test_size=0.2)
lr = LogisticRegression()
lr.fit(A_train, y_train)
lr.score(A_train, y_train), lr.score(A_test, y_test)


(0.98125, 0.975)

In [None]:
# import numpy as np
# print(np.__version__)
# !pip uninstall numpy
# !pip install numpy==1.23.5
# print(np.__version__)

1.22.4
Found existing installation: numpy 1.24.3
Uninstalling numpy-1.24.3:
  Would remove:
    /usr/local/bin/f2py
    /usr/local/bin/f2py3
    /usr/local/bin/f2py3.10
    /usr/local/lib/python3.10/dist-packages/numpy-1.24.3.dist-info/*
    /usr/local/lib/python3.10/dist-packages/numpy.libs/libgfortran-040039e1.so.5.0.0
    /usr/local/lib/python3.10/dist-packages/numpy.libs/libopenblas64_p-r0-15028c96.3.21.so
    /usr/local/lib/python3.10/dist-packages/numpy.libs/libquadmath-96973f99.so.0.0.0
    /usr/local/lib/python3.10/dist-packages/numpy/*
Proceed (Y/n)? y
  Successfully uninstalled numpy-1.24.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m79.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
[31mERR

1.22.4


In [None]:
# Layer-3 final-token activations for the first ten positive prompts.
activations = get_activations(POSITIVE[:10], 3)

In [None]:
# get_activations already returns a stacked (n_examples, d) tensor; re-stacking it
# only made a copy with the same shape.
activations.shape

torch.Size([10, 3072])

In [None]:
# Print the tokenization and the model's top next-token predictions for one
# hand-written prompt against the expected answer. The commented-out prompts are
# alternative prompts.
#example_prompt = """A little boy named Ben and a small cat named Whiskers were playing. Ben found a rotten apple and was very upset. But Whiskers wanted to play with it and was sad. Then, Ben lost the marble and became sad, while Whiskers found a nest of cockroaches and was very"""
#example_prompt = "Bob has the cookie and is happy, but Alice doesn't have a cookie and is sad. Alice steals the cookie and is happy. Right now, Bob feels very"
#example_prompt = "Bob has the cookie and is happy, but Alice doesn't have a cookie and is sad. Alice steals the cookie and is happy. Right now, Alice feels very"
example_prompt = "Bob was a little boy living in an orchard. One day, Bob finds a tasty apple but feels"
example_answer = " happy"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Bob', ' was', ' a', ' little', ' boy', ' living', ' in', ' an', ' or', 'chard', '.', ' One', ' day', ',', ' Bob', ' finds', ' a', ' tasty', ' apple', ' but', ' feels']
Tokenized answer: [' happy']


Top 0th token. Logit: 18.43 Prob: 13.72% Token: | a|
Top 1th token. Logit: 18.42 Prob: 13.65% Token: | uncomfortable|
Top 2th token. Logit: 17.58 Prob:  5.87% Token: | bad|
Top 3th token. Logit: 17.58 Prob:  5.85% Token: | guilty|
Top 4th token. Logit: 17.55 Prob:  5.68% Token: | something|
Top 5th token. Logit: 17.38 Prob:  4.82% Token: | that|
Top 6th token. Logit: 17.36 Prob:  4.71% Token: | sad|
Top 7th token. Logit: 17.35 Prob:  4.67% Token: | very|
Top 8th token. Logit: 17.29 Prob:  4.41% Token: | scared|
Top 9th token. Logit: 16.86 Prob:  2.85% Token: | he|


In [None]:
# Clean/corrupted pairs for activation patching. Each even-indexed prompt is paired
# with the next one; the two differ only in the adjective ("tasty" vs "rotten").
# NOTE(review): the Alice prompts say "little boy" (probably meant "girl"). Both
# prompts of that pair share it, so the pair still differs only in the adjective.
prompts = [
    "Bob was a little boy living in an orchard. One day, Bob finds a tasty apple and feels",
    "Bob was a little boy living in an orchard. One day, Bob finds a rotten apple and feels",
    "Alice was a little boy living in an orchard. One day, Alice finds a tasty apple and feels",
    "Alice was a little boy living in an orchard. One day, Alice finds a rotten apple and feels",
    "Jim was a little boy living in an orchard. One day, Jim finds a tasty apple and feels",
    "Jim was a little boy living in an orchard. One day, Jim finds a rotten apple and feels",
    ]
assert len(prompts) % 2 == 0, "prompts must come in (clean, corrupted) pairs"
# (correct, incorrect) answer per prompt, derived from the prompt so they cannot drift apart.
answers = [(' happy', ' sad') if 'tasty' in p else (' sad', ' happy') for p in prompts]

clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair: i ^ 1 maps 0<->1, 2<->3, ...
corrupted_tokens = clean_tokens[[i ^ 1 for i in range(len(clean_tokens))]]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)
print("Answer token indices", answer_token_indices)

Clean string 0 <|endoftext|>Bob was a little boy living in an orchard. One day, Bob finds a tasty apple and feels
Corrupted string 0 <|endoftext|>Bob was a little boy living in an orchard. One day, Bob finds a rotten apple and feels
Answer token indices tensor([[3772, 6507],
        [6507, 3772],
        [3772, 6507],
        [6507, 3772],
        [3772, 6507],
        [6507, 3772]], device='cuda:0')


In [None]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    """Mean (correct - incorrect) logit difference over the batch.

    ``logits`` may be (batch, pos, d_vocab), in which case only the final position is
    used, or already (batch, d_vocab). ``answer_token_indices`` is (batch, 2), with
    the correct token id in column 0 and the incorrect id in column 1. Its default is
    bound when this function is defined.
    """
    final_logits = logits[:, -1, :] if logits.dim() == 3 else logits
    correct = final_logits.gather(1, answer_token_indices[:, :1])
    incorrect = final_logits.gather(1, answer_token_indices[:, 1:2])
    return (correct - incorrect).mean()

# Run both batches once, keeping the caches for activation patching later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 3.7397
Corrupted logit diff: -3.7397


In [None]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

def normalized_ld(logits, answer_token_indices=answer_token_indices):
    """Logit diff rescaled so the corrupted run maps to 0 and the clean run maps to 1."""
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    shifted = get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE
    return shifted / baseline_gap

print(f"Clean Baseline is 1: {normalized_ld(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {normalized_ld(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [None]:
from neel_plotly import line, imshow, scatter

# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

In [None]:
# Patch resid_pre activations from clean_cache into the corrupted run, one
# (layer, position) at a time, scoring each with normalized_ld.
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, normalized_ld)

  0%|          | 0/88 [00:00<?, ?it/s]

In [None]:
imshow(
    resid_pre_act_patch_results,
    title="resid_pre Activation Patching",
    xaxis="Position",
    yaxis="Layer",
    # Label each column with its token and position from the first clean prompt.
    x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
)

In [None]:
# Patch each attention head's output (all positions at once) from clean_cache into
# the corrupted run, scoring with normalized_ld.
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, normalized_ld)


  0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
imshow(
    attn_head_out_all_pos_act_patch_results,
    title="attn_head_out Activation Patching (All Pos)",
    xaxis="Head",
    yaxis="Layer",
)

In [None]:
# Residual-stream directions of each prompt's (correct, incorrect) answer tokens, and
# their difference: the per-prompt logit-difference direction.
answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0, :] - answer_residual_directions[:, 1, :]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([6, 2, 768])
Logit difference directions shape: torch.Size([6, 768])


In [None]:
# Sanity check: projecting the LayerNorm-scaled final residual stream (last position)
# onto each prompt's logit-difference direction, averaged over prompts, should match
# clean_logit_diff.
# Cache keys may be tuples (activation_name, layer_index); -1 selects the final layer.
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1]
# pos_slice=-1: the stack already holds only the final token of each prompt.
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([6, 22, 768])
Calculated average logit diff: 3.73970103263855
Original logit difference: 3.739701271057129


In [None]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
    directions: Optional[Float[torch.Tensor, "batch d_model"]] = None,
) -> torch.Tensor:
    """Batch-averaged logit-difference contribution of each component in a residual stack.

    Each component is scaled by ``cache``'s final LayerNorm at the last position. It is
    then projected onto the per-prompt logit-difference direction and averaged over
    the batch.

    Args:
        residual_stack: (..., batch, d_model) components at the final position.
        cache: cache whose final-LayerNorm scale is applied.
        directions: (batch, d_model) logit-difference directions. Defaults to the
            global ``logit_diff_directions`` at call time.

    Returns:
        Tensor with the leading (component) dimensions of ``residual_stack``. This is
        a tensor, not a float.
    """
    if directions is None:
        directions = logit_diff_directions
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Divide by the batch size of the directions (one direction per prompt).
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, directions) / directions.shape[0]


In [None]:
# Logit lens: logit-diff of the residual stream accumulated up to each point, including
# the mid-layer points (incl_mid=True), read out at the final position.
accumulated_residual, labels = clean_cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, clean_cache)
line(
    logit_lens_logit_diffs,
    title="Logit Difference From Accumulate Residual Stream",
    hover_name=labels,
    # x in units of layers, in half-layer steps: 0, 0.5, 1, ...
    x=np.arange(2 * model.cfg.n_layers + 1) / 2,
)

In [None]:
# Direct logit-diff contribution of each component of the final residual stream.
per_layer_residual, labels = clean_cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, clean_cache)
line(per_layer_logit_diffs, title="Logit Difference From Each Layer", hover_name=labels)

In [None]:
# Name of component 8 in the per-layer decomposition above.
labels[8]

'3_attn_out'

In [None]:
# Direct logit attribution per attention head, laid out as a (layer, head) grid.
per_head_residual, labels = clean_cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, clean_cache)
per_head_logit_diffs = per_head_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_logit_diffs, title="Logit Difference From Each Head", labels={"x": "Head", "y": "Layer"})

Tried to stack head results when they weren't cached. Computing head results now


In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Int[torch.Tensor, "heads"]],
    local_cache: Optional[ActivationCache]=None,
    local_tokens: Optional[torch.Tensor]=None,
    title: str="",
    batch_index: int=0):
    """Render the attention patterns of the given heads with PySvelte.

    Args:
        heads: flat head indices in ``[0, n_layers * n_heads)``, given as an int, a list,
            a tensor or a numpy array. Index ``h`` is layer ``h // n_heads``,
            head ``h % n_heads``.
        local_cache: cache to read patterns from. Defaults to ``clean_cache``.
        local_tokens: tokens used as labels. Defaults to ``clean_tokens[batch_index]``;
            they should be the tokens of prompt ``batch_index`` in ``local_cache``.
        title: heading displayed above the visualization.
        batch_index: which prompt of the batch to show (default: the first).

    Raises:
        ValueError: if ``heads`` is empty.
    """
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, (list, torch.Tensor)):
        heads = utils.to_numpy(heads)
    # Plain Python ints for layer indexing and labels.
    heads = [int(h) for h in heads]
    if not heads:
        raise ValueError("visualize_attention_patterns needs at least one head")
    if local_cache is None:
        local_cache = clean_cache
    if local_tokens is None:
        local_tokens = clean_tokens[batch_index]

    labels = []
    patterns = []
    for head in heads:
        layer, head_index = divmod(head, model.cfg.n_heads)
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    # Stack along a trailing head axis: [query_pos, key_pos, head]
    patterns = torch.stack(patterns, dim=-1)
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [None]:
# Heads with the largest positive and largest negative direct logit attribution.
top_k = 5
top_positive_logit_attr_heads = per_head_logit_diffs.flatten().topk(top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
top_negative_logit_attr_heads = (-per_head_logit_diffs.flatten()).topk(top_k).indices
visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")

pysvelte components appear to be unbuilt or stale
Running npm install...
Building pysvelte components with webpack...


In [None]:
# Patch each head's output / query / key / value / pattern (all positions at once)
# from clean_cache into the corrupted run, scoring with normalized_ld.
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, normalized_ld)


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
from neel_plotly import line, imshow, scatter

# One facet per patched head component; the color scale is fixed to [-1, 1].
imshow(
    every_head_all_pos_act_patch_result,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Activation Patching Per Head (All Pos)",
    xaxis="Head",
    yaxis="Layer",
    zmin=-1,
    zmax=1,
)
# We can also do by head *and* by position. This is a bit slow, but it can give useful + fine-grained detail

## Path Patching

### Setup

In [None]:
try:
    from google.colab import drive # type: ignore
    %pip install transformer_lens
    %pip install gdown
    # %pip install plotly
    # %pip install jaxtyping
    # %pip install einops
    # %pip install protobuf==3.20.*
    import os
    import sys
    from pathlib import Path
    import gdown
    if not Path("ioi_dataset.py").resolve().exists():
        urls = {
            "ioi_dataset.py": "https://drive.google.com/uc?id=19UjxFnb6kztuhvz6dGAXjA9oRZmd84kC",
            "path_patching.py": "https://drive.google.com/uc?id=1duF7B3IjG_E5nGcjT_BuoSrSkynUhZI5",
        }
        for filename, url in urls.items():
            output = str(Path(filename).resolve())
            gdown.download(url, output)
except:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

import torch as t
from torch import Tensor
from tqdm.notebook import tqdm
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from rich import print as rprint
from transformer_lens import utils, HookedTransformer, ActivationCache

import torch as t
from typing import List, Union
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import re
from transformer_lens import utils

# Disable autograd globally; the remaining cells only run forward passes.
t.set_grad_enabled(False)

from ioi_dataset import NAMES, IOIDataset

device = t.device("cuda" if t.cuda.is_available() else "cpu")

# Keyword names that imshow routes to fig.update_layout; all other kwargs go to px.imshow.
update_layout_set = {
    # axes: range, title, tick format / mode, visibility
    "xaxis_range", "yaxis_range",
    "xaxis_title", "yaxis_title",
    "xaxis_tickformat", "yaxis_tickformat",
    "xaxis_tickmode", "yaxis_tickmode",
    "xaxis_visible", "yaxis_visible",
    # grid lines
    "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor",
    "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    # colour
    "colorbar", "colorscale", "coloraxis",
    # title, legend and general layout
    "title_x", "title_y", "legend_title_text", "showlegend",
    "hovermode", "margin", "bargap", "bargroupgap",
}

def imshow(tensor, renderer=None, **kwargs):
    '''
    Plot `tensor` as a plotly heatmap and show it.

    Keyword arguments whose names are in `update_layout_set` are forwarded to
    `fig.update_layout(...)`. All others go to `px.imshow(...)`, except for two
    options handled here:
      * `facet_labels`: strings written, in order, into `fig.layout.annotations[i]['text']`
        (the facet titles when `facet_col` is used).
      * `border`: if truthy, draw a black mirrored line around every x/y axis.

    Defaults, each overridable by passing the keyword explicitly:
      * `color_continuous_scale="RdBu"`
      * `color_continuous_midpoint=0.0`
    An integer `margin` is expanded to the same value on all four sides (t, b, l, r).

    The figure is displayed with `fig.show(renderer=renderer)`; nothing is returned.
    '''
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    # Supplied through kwargs_pre (not as a literal argument next to **kwargs_pre) so a
    # caller can override it without a "multiple values for keyword argument" TypeError.
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)
    if isinstance(kwargs_post.get("margin"), int):
        kwargs_post["margin"] = dict.fromkeys("tblr", kwargs_post["margin"])
    fig = px.imshow(utils.to_numpy(tensor), **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
    fig.show(renderer=renderer)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting typeguard<4.0.0,>=3.0.2 (from transformer_lens)
  Using cached typeguard-3.0.2-py3-none-any.whl (30 kB)
Installing collected packages: typeguard
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.13.3
    Uninstalling typeguard-2.13.3:
      Successfully uninstalled typeguard-2.13.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pysvelte 1.0.0 requires typeguard~=2.0, but you have typeguard 3.0.2 which is incompatible.[0m[31m
[0mSuccessfully installed typeguard-3.0.2


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Downloading...
From: https://drive.google.com/uc?id=19UjxFnb6kztuhvz6dGAXjA9oRZmd84kC
To: /content/ioi_dataset.py
100%|██████████| 26.0k/26.0k [00:00<00:00, 60.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1duF7B3IjG_E5nGcjT_BuoSrSkynUhZI5
To: /content/path_patching.py
100%|██████████| 23.8k/23.8k [00:00<00:00, 55.3MB/s]


## Patching

In [None]:
from path_patching import Component, MultiComponent, path_patch

In [None]:
def normalized_ld_2(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = CLEAN_BASELINE,
    corrupted_logit_diff: float = CORRUPTED_BASELINE,
) -> float:
    '''
    Normalized logit-difference metric used as the path-patching score.

    Takes the logit difference of `logits` (via `get_logit_diff`) and rescales it
    against the two baselines so that:
      *  0 -> performance unharmed (same as on the IOI dataset),
      * -1 -> performance destroyed (same as on the ABC dataset).

    Note: the defaults are read from CLEAN_BASELINE / CORRUPTED_BASELINE once, when
    this cell runs. Re-run the cell if those globals are recomputed.
    NOTE(review): the result is whatever `get_logit_diff` returns rescaled, presumably
    a scalar tensor rather than a Python float. TODO confirm.
    '''
    baseline_gap = clean_logit_diff - corrupted_logit_diff
    drop_from_clean = get_logit_diff(logits) - clean_logit_diff
    return drop_from_clean / baseline_gap

In [None]:
# Path-patch from the outputs (z) of the name mover heads to resid_post of the final
# layer, scored with normalized_ld_2.
NAME_MOVER_HEADS = [(3, 2), (3, 5)]
# Index of the last layer. Derived from the model instead of hard-coded; the per-layer
# results later in this notebook show 4 layers, so this is 3.
FINAL_LAYER = model.cfg.n_layers - 1

# NOTE: the "L9H9" in this variable name is a leftover; the senders are NAME_MOVER_HEADS.
# The name is kept unchanged in case other cells refer to it.
head_L9H9_to_resid_post_final = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_components=[Component("z", layer, head=head) for layer, head in NAME_MOVER_HEADS],  # output of each name mover head
    receiver_components=Component("resid_post", FINAL_LAYER),  # resid_post of the final layer
    patching_metric=normalized_ld_2,
)

print(head_L9H9_to_resid_post_final)

tensor(-0.4132, device='cuda:0')


In [None]:
# Which token sits at position 20 of the first clean prompt? That position is used as
# the receiver position in the next cell.
clean_str_tokens = model.to_str_tokens(clean_tokens[0])
clean_str_tokens[20]

' and'

In [None]:
# Path-patch from every head's output (z, one sender per sequence position) into the
# value inputs of two layer-2 heads. The receivers are restricted to token position 20,
# which is " and" in clean_tokens[0].
LAYER_2_HEADS = [(2, 8), (2, 15)]
AND_TOKEN_POS = 20

all_head_outputs = MultiComponent("z", seq_pos="each")
layer_2_value_inputs = [
    Component("v", layer, head=head, seq_pos=[AND_TOKEN_POS])
    for layer, head in LAYER_2_HEADS
]

each_head_to_value_inputs_of_output_heads = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_components=all_head_outputs,
    receiver_components=layer_2_value_inputs,
    patching_metric=normalized_ld_2,
    verbose=True,
)

Patching over senders: z:   0%|          | 0/1408 [00:00<?, ?it/s]

Fixing receivers, iterating over senders
results[z].shape = (layer=4, seq_pos=22, head=16)


In [None]:

def imshow_2(tensor, renderer=None, **kwargs):
    '''
    Alias of `imshow` (defined earlier in this notebook), kept so existing calls still work.

    This used to be a line-for-line copy of `imshow`'s body. Delegating keeps the two from
    drifting apart. See `imshow` for the supported keyword arguments (`facet_labels`,
    `border`, layout keys in `update_layout_set`, and plotly-express options).
    Shows the figure and returns None, exactly like `imshow`.
    '''
    imshow(tensor, renderer=renderer, **kwargs)

In [None]:
# Shape of the "z" sender results: (layer, seq_pos, head), per the verbose log of the patching cell.
z_patch_results = each_head_to_value_inputs_of_output_heads["z"]
z_patch_results.shape

torch.Size([4, 22, 16])

'<|endoftext|>Bob was a little boy living in an orchard. One day, Bob finds a tasty apple and feels'

In [None]:
for seq_idx in range(22):
  words = model.to_str_tokens(clean_tokens[0])
  imshow_2(
      each_head_to_value_inputs_of_output_heads["z"][:2, seq_idx, :],
      title=f"{words[seq_idx]}",
      labels={"x": "Head", "y": "Layer", "color": "Logit diff<br>variation"},
      zmax=0.12, zmin=-0.12,
      border=True
  )

  LAYER_0_HEADS = [(0, 0), (0, 2), (0, 4), (0, 7)]