# Mood Binding Circuit

## Setup

In [1]:
!ls

README.md		    miniconda.sh  quick_start_pytorch.ipynb   wandb
eliciting-latent-sentiment  miniconda3	  quick_start_pytorch_images


In [2]:
%cd eliciting-latent-sentiment

/notebooks/eliciting-latent-sentiment


In [3]:
#!source activate circuits/bin/activate

In [7]:
#!pip install git+https://github.com/neelnanda-io/TransformerLens.git
!pip install git+https://github.com/glerzing/TransformerLens.git@stable_lm
!pip install circuitsvis
!pip install jaxtyping==0.2.13
!pip install einops
!pip install protobuf==3.20.*
!pip install plotly
!pip install torchtyping
!pip install git+https://github.com/neelnanda-io/neel-plotly.git

Collecting git+https://github.com/glerzing/TransformerLens.git@stable_lm
  Cloning https://github.com/glerzing/TransformerLens.git (to revision stable_lm) to /tmp/pip-req-build-0yipfvzo
  Running command git clone --filter=blob:none --quiet https://github.com/glerzing/TransformerLens.git /tmp/pip-req-build-0yipfvzo
  Running command git checkout -b stable_lm --track origin/stable_lm
  Switched to a new branch 'stable_lm'
  Branch 'stable_lm' set up to track remote branch 'stable_lm' from 'origin'.
  Resolved https://github.com/glerzing/TransformerLens.git to commit 049f56f810292ad05be77633898472c700ab8e27
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0mCollecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-b6fadbo6
  Running command git clone --filter=blob:none --quiet https://git

In [8]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [69]:
import os
import pathlib
from typing import List, Optional, Union

import torch
import numpy as np
import yaml
import pickle
import einops
from fancy_einsum import einsum


import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching

from torch import Tensor
from tqdm.notebook import tqdm
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from rich import print as rprint

from typing import List, Union
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import re

from functools import partial

from torchtyping import TensorType as TT

from path_patching import Node, IterNode, path_patch, act_patch
from neel_plotly import imshow as imshow_n

from utils.visualization import get_attn_head_patterns, imshow_p, plot_attention_heads, scatter_attention_and_contribution_simple

from utils.prompts import get_dataset
from utils.circuit_analysis import get_logit_diff, logit_diff_denoising, logit_diff_noising

In [70]:
import torch
torch.set_grad_enabled(False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
#import plotly
#plotly.offline.init_notebook_mode()

In [71]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Analysis


### TransformerLens Version

In [13]:
from transformers import AutoModelForCausalLM

def load_model(model_name):
    if model_name == "EleutherAI/pythia-6.9b" or model_name == "StabilityAI/stablelm-tuned-alpha-7b":
        source_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="model_cache").to('cpu').bfloat16()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = HookedTransformer.from_pretrained(
            model_name,
            center_unembed=True,
            center_writing_weights=True,
            fold_ln=True,
            refactor_factored_attn_matrices=False,
            tokenizer=tokenizer,
            hf_model=source_model,
        )
    else:
        model = HookedTransformer.from_pretrained(
            model_name,
            center_unembed=True,
            center_writing_weights=True,
            fold_ln=True,
            refactor_factored_attn_matrices=False
        )
    return model


In [14]:
model = load_model("EleutherAI/pythia-2.8b")
model.set_use_hook_mlp_in(True)

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


### Initial Examination

In [16]:
example_prompt =    ("Jack, the cheerful monkey, and Sam, the happy parrot, found a treasure box in the jungle. They opened it to find shiny, golden bananas. Their eyes sparkled with joy as they danced and cheered. They spent the day sharing and eating their golden bananas. At the end of the day, Jack was feeling extremely")

print(example_prompt)
example_answer = " happy"

res = utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=10)

Jack, the cheerful monkey, and Sam, the happy parrot, found a treasure box in the jungle. They opened it to find shiny, golden bananas. Their eyes sparkled with joy as they danced and cheered. They spent the day sharing and eating their golden bananas. At the end of the day, Jack was feeling extremely
Tokenized prompt: ['<|endoftext|>', 'Jack', ',', ' the', ' cheerful', ' monkey', ',', ' and', ' Sam', ',', ' the', ' happy', ' par', 'rot', ',', ' found', ' a', ' treasure', ' box', ' in', ' the', ' jungle', '.', ' They', ' opened', ' it', ' to', ' find', ' shiny', ',', ' golden', ' ban', 'anas', '.', ' Their', ' eyes', ' spark', 'led', ' with', ' joy', ' as', ' they', ' danced', ' and', ' che', 'ered', '.', ' They', ' spent', ' the', ' day', ' sharing', ' and', ' eating', ' their', ' golden', ' ban', 'anas', '.', ' At', ' the', ' end', ' of', ' the', ' day', ',', ' Jack', ' was', ' feeling', ' extremely']
Tokenized answer: [' happy']


Top 0th token. Logit: 18.42 Prob: 20.29% Token: | happy|
Top 1th token. Logit: 18.12 Prob: 15.07% Token: | tired|
Top 2th token. Logit: 18.09 Prob: 14.51% Token: | full|
Top 3th token. Logit: 17.72 Prob: 10.06% Token: | hungry|
Top 4th token. Logit: 17.06 Prob:  5.20% Token: | fat|
Top 5th token. Logit: 16.52 Prob:  3.03% Token: | sleepy|
Top 6th token. Logit: 16.31 Prob:  2.46% Token: | good|
Top 7th token. Logit: 16.30 Prob:  2.45% Token: | satisfied|
Top 8th token. Logit: 15.93 Prob:  1.69% Token: | thirst|
Top 9th token. Logit: 15.78 Prob:  1.45% Token: | sad|


In [14]:
example_prompt =    ("Joe, the excited puppy, and Fred, the lazy cat, discovered a new toy. Joe wagged his tail and jumped around the toy with joy. Fred, however, looked at the toy and yawned, unimpressed. They spent the day with Joe playing energetically and Fred snoozing beside him. At the end of the day, Joe felt")
print(example_prompt)
example_answer = " happy"

utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=15)

Joe, the excited puppy, and Fred, the lazy cat, discovered a new toy. Joe wagged his tail and jumped around the toy with joy. Fred, however, looked at the toy and yawned, unimpressed. They spent the day with Joe playing energetically and Fred snoozing beside him. At the end of the day, Joe felt
Tokenized prompt: ['<|endoftext|>', 'Joe', ',', ' the', ' excited', ' puppy', ',', ' and', ' Fred', ',', ' the', ' lazy', ' cat', ',', ' discovered', ' a', ' new', ' toy', '.', ' Joe', ' w', 'agged', ' his', ' tail', ' and', ' jumped', ' around', ' the', ' toy', ' with', ' joy', '.', ' Fred', ',', ' however', ',', ' looked', ' at', ' the', ' toy', ' and', ' ya', 'wn', 'ed', ',', ' un', 'imp', 'ressed', '.', ' They', ' spent', ' the', ' day', ' with', ' Joe', ' playing', ' energet', 'ically', ' and', ' Fred', ' s', 'no', 'oz', 'ing', ' beside', ' him', '.', ' At', ' the', ' end', ' of', ' the', ' day', ',', ' Joe', ' felt']
Tokenized answer: [' happy']


Top 0th token. Logit: 16.22 Prob: 12.25% Token: | a|
Top 1th token. Logit: 15.69 Prob:  7.19% Token: | so|
Top 2th token. Logit: 15.62 Prob:  6.72% Token: | that|
Top 3th token. Logit: 15.47 Prob:  5.78% Token: | exhausted|
Top 4th token. Logit: 15.29 Prob:  4.83% Token: | he|
Top 5th token. Logit: 15.13 Prob:  4.10% Token: | the|
Top 6th token. Logit: 15.11 Prob:  4.05% Token: | tired|
Top 7th token. Logit: 15.04 Prob:  3.77% Token: | very|
Top 8th token. Logit: 14.91 Prob:  3.30% Token: | like|
Top 9th token. Logit: 14.54 Prob:  2.28% Token: | guilty|
Top 10th token. Logit: 14.26 Prob:  1.72% Token: | his|
Top 11th token. Logit: 14.25 Prob:  1.70% Token: | good|
Top 12th token. Logit: 14.09 Prob:  1.46% Token: | satisfied|
Top 13th token. Logit: 14.08 Prob:  1.44% Token: | happy|
Top 14th token. Logit: 14.07 Prob:  1.43% Token: | sad|


In [32]:
# Define the list of stories
stories_test = [
    {
        "original": "Jack, the cheerful monkey, and Sam, the happy parrot, found a treasure box in the jungle. They opened it to find shiny, golden bananas. Their eyes sparkled with joy as they danced and cheered. They spent the day sharing and eating their golden bananas. At the end of the day, ",
        "reversed": "Jack, the cheerful monkey, and Sam, the happy parrot, found a treasure box in the jungle. They opened it to find scary, creepy spiders. Their eyes widened in shock as they screamed and ran away. They spent the day hiding and recovering from the shock. At the end of the day, ",
        "characters": ["Jack", "Sam"]
    },
    # Continue with the other stories here...
]

stories_1 = [
    {
        "original": "Sam, the dog, and Bella, the cat, found a shiny toy. Sam wagged his tail and played with the toy. Bella, however, hissed at it and climbed a nearby tree. Sam spent the afternoon joyfully tossing the toy around, while Bella watched warily from above.",
        "reversed": "Sam, the dog, and Bella, the cat, stumbled upon a big, scary vacuum cleaner. Sam whined and hid behind the couch. Bella, however, seemed unbothered and walked towards it curiously. Sam spent the afternoon trembling in fear, while Bella confidently explored the unfamiliar machine.",
        "characters": ["Sam", "Bella"]
    },
    {
        "original": "Paul, the bunny, and Spikes, the porcupine, came across a field full of carrots. Paul hopped with joy and started nibbling on the carrots. Spikes, however, couldn't eat them and looked disappointed. Paul had a feast, while Spikes had to look for berries.",
        "reversed": "Paul, the bunny, and Spikes, the porcupine, found a bush full of berries. Paul made a face and hopped away. Spikes, however, was delighted and started to enjoy the berries. Paul had to look for something else, while Spikes had a feast.",
        "characters": ["Paul", "Spikes"]
    },
    {
        "original": "Larry, the bird, and Moe, the elephant, found a tree with ripe fruits. Larry chirped happily and flew up to eat. Moe, however, couldn't reach the fruits and felt upset. Larry had a wonderful meal, while Moe had to search for food on the ground.",
        "reversed": "Larry, the bird, and Moe, the elephant, came across a ground full of peanuts. Larry looked down and felt disappointed. Moe, however, was thrilled and started munching on the peanuts. Larry had to search for food in the trees, while Moe had a wonderful meal.",
        "characters": ["Larry", "Moe"]
    },
    {
        "original": "Finn, the fox, and Benny, the bear, found a hidden chicken coop. Finn licked his lips and snuck in to steal some eggs. Benny, however, got scared and stayed back. Finn enjoyed a delicious meal, while Benny went away hungry.",
        "reversed": "Finn, the fox, and Benny, the bear, encountered a beehive full of honey. Finn ran away to avoid getting stung. Benny, however, happily dug in to enjoy the honey. Finn went away hungry, while Benny enjoyed a delicious meal.",
        "characters": ["Finn", "Benny"]
    },
    {
        "original": "Tim, the mouse, and Jerry, the turtle, stumbled upon a block of cheese. Tim squeaked with delight and took a big bite. Jerry, however, couldn't eat it and felt sad. Tim had a great time eating, while Jerry slowly walked away.",
        "reversed": "Tim, the mouse, and Jerry, the turtle, discovered a pile of lettuce. Tim turned up his nose and walked away. Jerry, however, enjoyed the lettuce happily. Tim walked away, while Jerry had a great time eating.",
        "characters": ["Tim", "Jerry"]
    }
]

stories_2 = [
    {
        "original": "Lucy, the squirrel, and Max, the monkey, found a cozy nest on a tall tree. Lucy was delighted as it was a perfect spot for her to rest. Max, however, was scared of heights and stayed on the ground. Lucy had a warm, comfortable afternoon while Max looked around restlessly.",
        "reversed": "Lucy, the squirrel, and Max, the monkey, stumbled upon a dark, damp cave. Lucy was frightened by the darkness and stayed outside. Max, however, found it a great place to escape the heat and settled in comfortably. Lucy spent her afternoon anxiously, while Max had a cool, restful time.",
        "characters": ["Lucy", "Max"]
    },
    {
        "original": "Fred, the duck, and Spike, the cat, came across a beautiful pond. Fred splashed happily in the water, enjoying the sun. Spike, however, disliked water and sat grumpily on the side. Fred spent the day swimming joyfully, while Spike watched from a distance.",
        "reversed": "Fred, the duck, and Spike, the cat, found a patch of warm, soft sand. Fred was uneasy and missed his pond, but Spike rolled around, purring happily. Fred waddled around restlessly, while Spike had a wonderful time sunbathing.",
        "characters": ["Fred", "Spike"]
    },
    {
        "original": "Betty, the deer, and George, the wolf, encountered a lush, green meadow. Betty was thrilled and began to graze peacefully. George, however, was a carnivore and couldn't eat the grass, making him feel left out. Betty had a feast, while George had to continue his search for food.",
        "reversed": "Betty, the deer, and George, the wolf, found a site with leftover camp food. Betty was uneasy around human things, but George found some meat scraps and was delighted. Betty spent her time nervously exploring, while George enjoyed a hearty meal.",
        "characters": ["Betty", "George"]
    },
    {
        "original": "Milly, the goat, and Steve, the chicken, discovered a mountain of hay. Milly was excited and began to munch on the hay. Steve, however, was unable to eat hay and felt disappointed. Milly enjoyed her meal, while Steve pecked around fruitlessly.",
        "reversed": "Milly, the goat, and Steve, the chicken, found a box full of grains. Milly sniffed at it but didn't find it appetizing. Steve, on the other hand, was delighted and started pecking enthusiastically. Milly moved away to find something else, while Steve enjoyed his food.",
        "characters": ["Milly", "Steve"]
    },
    {
        "original": "Tommy, the rabbit, and Frankie, the lion, came across a huge burrow. Tommy was happy and quickly jumped into the burrow. Frankie, however, was too large to fit in the burrow and felt left out. Tommy had a great time exploring, while Frankie lay outside feeling lonely.",
        "reversed": "Tommy, the rabbit, and Frankie, the lion, found a spacious cave. Tommy was terrified of the dark, open space. Frankie, however, felt at home and comfortably walked in. Tommy stayed outside, nervously watching, while Frankie relaxed in the cool shade.",
        "characters": ["Tommy", "Frankie"]
    }
]

stories_3 = [
    {
        "original": "Ellie, the elephant, and Max, the mouse, discovered a large lake. Ellie was delighted and sprayed water with her trunk. Max, however, was afraid of the water and stayed away. Ellie had a fun-filled afternoon, while Max watched from a safe distance.",
        "reversed": "Ellie, the elephant, and Max, the mouse, stumbled upon a tiny hole. Ellie was disappointed as she couldn't fit, but Max was excited and quickly ran inside. Ellie spent her day feeling left out, while Max had a fun-filled afternoon exploring.",
        "characters": ["Ellie", "Max"]
    },
    {
        "original": "Bobby, the bird, and Terry, the tortoise, found a high tree. Bobby was thrilled and flew to the top, enjoying the view. Terry, however, couldn't climb and stayed on the ground. Bobby had an exciting day, while Terry felt a little disappointed.",
        "reversed": "Bobby, the bird, and Terry, the tortoise, discovered a large shell. Bobby tried to get in but couldn't, while Terry happily crawled inside. Bobby felt left out, while Terry had a relaxing, peaceful day.",
        "characters": ["Bobby", "Terry"]
    },
    {
        "original": "Sandy, the squirrel, and Peter, the peacock, found a shiny mirror. Sandy was curious and came closer to look, while Peter got scared of his reflection and ran away. Sandy had an amusing day, while Peter spent his time hiding.",
        "reversed": "Sandy, the squirrel, and Peter, the peacock, stumbled upon a large predator. Sandy was terrified and ran away, while Peter puffed up his feathers and scared the predator off. Sandy spent her day in fear, while Peter felt brave and victorious.",
        "characters": ["Sandy", "Peter"]
    },
    {
        "original": "Daisy, the dog, and Milo, the monkey, came across a busy playground. Daisy was ecstatic and ran around playing with the children. Milo, however, was overwhelmed by the noise and climbed a tree to escape. Daisy had a fantastic day, while Milo watched from a distance, feeling lonely.",
        "reversed": "Daisy, the dog, and Milo, the monkey, discovered a quiet, empty park. Daisy looked around for playmates but was disappointed, while Milo was relieved and began to play in the peaceful surroundings. Daisy spent her day feeling lonely, while Milo had a fantastic day playing by himself.",
        "characters": ["Daisy", "Milo"]
    },
    {
        "original": "Finn, the fish, and Tommy, the tiger, found a deep, clear pond. Finn was delighted and dived in, enjoying the cool water. Tommy, however, was afraid of water and stayed on the bank. Finn had an enjoyable time swimming, while Tommy just watched from the side.",
        "reversed": "Finn, the fish, and Tommy, the tiger, came across a large, dry savanna. Finn felt out of place and missed his pond, while Tommy felt right at home and started exploring. Finn spent his day wishing for water, while Tommy enjoyed his adventure.",
        "characters": ["Finn", "Tommy"]
    }
]

stories_4 = [
    {
        "original": "Penny, the puppy, and Whiskers, the cat, stumbled upon a bouncing ball. Penny was thrilled and started chasing the ball. Whiskers, however, was startled by the ball's movement and climbed a tree. Penny had a fun-filled day playing, while Whiskers watched warily from above.",
        "reversed": "Penny, the puppy, and Whiskers, the cat, found a quiet corner. Penny was bored without anything to play with, but Whiskers curled up and started purring. Penny wandered around restlessly, while Whiskers had a peaceful nap.",
        "characters": ["Penny", "Whiskers"]
    },
    {
        "original": "Daisy, the cow, and Benny, the bird, found a large grassy field. Daisy was delighted and started grazing, but Benny couldn't find any trees and felt upset. Daisy spent her day munching happily, while Benny flew around, searching for a perch.",
        "reversed": "Daisy, the cow, and Benny, the bird, came across a tall tree. Daisy couldn't climb and felt disappointed, but Benny happily nestled among the branches. Daisy spent her day wishing for grass, while Benny enjoyed the view from the treetop.",
        "characters": ["Daisy", "Benny"]
    },
    {
        "original": "Molly, the mole, and Frankie, the falcon, discovered a dark tunnel. Molly was excited and quickly burrowed inside, while Frankie, who preferred the sky, felt out of place. Molly had an adventurous day, while Frankie circled overhead, feeling lonely.",
        "reversed": "Molly, the mole, and Frankie, the falcon, found an open sky. Molly felt exposed and dug a hole for protection, but Frankie soared high, loving the freedom. Molly spent her day in hiding, while Frankie enjoyed the thrill of flying.",
        "characters": ["Molly", "Frankie"]
    },
    {
        "original": "Nina, the nightingale, and Sammy, the snake, found a tall tree full of ripe fruits. Nina sang in delight and pecked at the fruits, but Sammy, who ate only meat, watched with indifference. Nina enjoyed her sweet meal, while Sammy slithered away, still hungry.",
        "reversed": "Nina, the nightingale, and Sammy, the snake, discovered a small rodent. Nina flew away in fright, but Sammy quickly caught it and was contented. Nina spent her time searching for berries, while Sammy enjoyed his catch.",
        "characters": ["Nina", "Sammy"]
    },
    {
        "original": "Timmy, the tortoise, and Ricky, the rabbit, came across a slow-moving stream. Timmy was pleased and moved towards the water, but Ricky, who disliked water, hopped away. Timmy had a good time soaking, while Ricky nibbled on grass, eyeing the stream warily.",
        "reversed": "Timmy, the tortoise, and Ricky, the rabbit, found a dry, grassy meadow. Timmy missed the water and stayed put, but Ricky was delighted and started hopping around. Timmy spent his day longing for a pond, while Ricky had a delightful time exploring the meadow.",
        "characters": ["Timmy", "Ricky"]
    },
    # WEAK LOGIT DIFFS
    # {
    #     "original": "Barney, the bear, and Oliver, the owl, found a cozy cave. Barney was pleased and settled down for a nap, but Oliver preferred the treetops and flew away. Barney had a relaxing afternoon, while Oliver searched for a suitable tree.",
    #     "reversed": "Barney, the bear, and Oliver, the owl, came across a tall tree. Barney couldn't climb and felt left out, while Oliver was happy and flew to a high branch. Barney sat beneath the tree feeling lonely, while Oliver had a restful nap.",
    #     "characters": ["Barney", "Oliver"]
    # },
    # {
    #     "original": "Sally, the snail, and Jerry, the jackrabbit, discovered a dense, green garden. Sally was thrilled and began to explore slowly, but Jerry, who preferred open fields, felt constrained. Sally spent a delightful day among the leaves, while Jerry hopped around, feeling restless.",
    #     "reversed": "Sally, the snail, and Jerry, the jackrabbit, found an open, sandy path. Sally felt exposed and missed the greenery, but Jerry was delighted and began to run around energetically. Sally hid under a leaf, while Jerry had a fun day racing around.",
    #     "characters": ["Sally", "Jerry"]
    # },
    # INCORRECT LOGIT DIFFS
    # {
    #     "original": "Danny, the dolphin, and Larry, the leopard, came across a vast ocean. Danny was overjoyed and dived in, performing tricks, but Larry couldn't swim and stayed on the shore. Danny had a wonderful time in the water, while Larry watched from the sandy beach.",
    #     "reversed": "Danny, the dolphin, and Larry, the leopard, stumbled upon a dense jungle. Danny felt out of place and missed the sea, but Larry was in his element and started exploring. Danny stayed near a river, while Larry had an adventurous time in the forest.",
    #     "characters": ["Danny", "Larry"]
    # },
    # {
    #     "original": "Bella, the butterfly, and Gracie, the grasshopper, found a field full of flowers. Bella was ecstatic and fluttered from flower to flower, but Gracie preferred green leaves and felt disappointed. Bella had a vibrant day among the blossoms, while Gracie hopped around, feeling left out.",
    #     "reversed": "Bella, the butterfly, and Gracie, the grasshopper, discovered a leafy bush. Bella missed the colorful flowers and flew around restlessly, but Gracie was satisfied and began munching on the leaves. Bella spent her day looking for flowers, while Gracie enjoyed her leafy meal.",
    #     "characters": ["Bella", "Gracie"]
    # },
    # {
    #     "original": "Rosie, the raccoon, and Oscar, the otter, came across a shiny trash can. Rosie was fascinated and started rummaging through it, but Oscar preferred clean rivers and slipped away. Rosie had an interesting day exploring, while Oscar swam in the river, feeling relieved.",
    #     "reversed": "Rosie, the raccoon, and Oscar, the otter, found a clean, sparkling river. Rosie felt unsure and stayed on the bank, but Oscar was delighted and began swimming joyously. Rosie spent her day near the river, feeling cautious, while Oscar had a refreshing day in the water.",
    #     "characters": ["Rosie", "Oscar"]
    # }
]



story_counter = 0
# Iterate through each story
for story in stories_4:
    print("======================================================================================================")
    print(f"Story {story_counter}")
    print("======================================================================================================\n\n")
    
    print(f"Original Story:")
    print("======================================================================================================")
    print(f"{story['original']}\n")
    for character in story["characters"]:
        # Print the character's name
        print("---------------------------------------------------------")
        print(f"Character: {character}")
        print("---------------------------------------------------------")
        # Get the end of the original story
        original_end = story["original"] + character.capitalize() + " felt very"
        # Get the LLM prediction for the original story
        print(f"Mood Prediction:")
        original_prediction = utils.test_prompt(original_end, "", model, prepend_bos=True, top_k=5)
    
    print(f"Reversed Story:")
    print("======================================================================================================")
    print(f"{story['reversed']}\n")
    for character in story["characters"]:
        print("---------------------------------------------------------")
        print(f"Character: {character}")
        print("---------------------------------------------------------")
        print(f"Mood Prediction:")
        reversed_end = story["reversed"] + character.capitalize() + " felt very"
        # Get the LLM prediction for the reversed story
        reversed_prediction = utils.test_prompt(reversed_end, "", model, prepend_bos=True, top_k=5)
    story_counter += 1


Story 0


Original Story:
Penny, the puppy, and Whiskers, the cat, stumbled upon a bouncing ball. Penny was thrilled and started chasing the ball. Whiskers, however, was startled by the ball's movement and climbed a tree. Penny had a fun-filled day playing, while Whiskers watched warily from above.

---------------------------------------------------------
Character: Penny
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'P', 'enny', ',', ' the', ' puppy', ',', ' and', ' Wh', 'isk', 'ers', ',', ' the', ' cat', ',', ' stumbled', ' upon', ' a', ' bouncing', ' ball', '.', ' Penny', ' was', ' thrilled', ' and', ' started', ' chasing', ' the', ' ball', '.', ' Wh', 'isk', 'ers', ',', ' however', ',', ' was', ' startled', ' by', ' the', ' ball', "'s", ' movement', ' and', ' climbed', ' a', ' tree', '.', ' Penny', ' had', ' a', ' fun', '-', 'filled', ' day', ' playing', ',', ' while', ' Wh', 'isk', 'ers', ' watched', ' war', 'ily', 

Top 0th token. Logit: 18.07 Prob: 21.51% Token: | proud|
Top 1th token. Logit: 17.48 Prob: 11.98% Token: | happy|
Top 2th token. Logit: 16.75 Prob:  5.75% Token: | good|
Top 3th token. Logit: 16.72 Prob:  5.60% Token: | sorry|
Top 4th token. Logit: 16.57 Prob:  4.80% Token: | sad|


---------------------------------------------------------
Character: Whiskers
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'P', 'enny', ',', ' the', ' puppy', ',', ' and', ' Wh', 'isk', 'ers', ',', ' the', ' cat', ',', ' stumbled', ' upon', ' a', ' bouncing', ' ball', '.', ' Penny', ' was', ' thrilled', ' and', ' started', ' chasing', ' the', ' ball', '.', ' Wh', 'isk', 'ers', ',', ' however', ',', ' was', ' startled', ' by', ' the', ' ball', "'s", ' movement', ' and', ' climbed', ' a', ' tree', '.', ' Penny', ' had', ' a', ' fun', '-', 'filled', ' day', ' playing', ',', ' while', ' Wh', 'isk', 'ers', ' watched', ' war', 'ily', ' from', ' above', '.', 'Wh', 'isk', 'ers', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.10 Prob: 10.25% Token: | lonely|
Top 1th token. Logit: 17.10 Prob: 10.16% Token: | sad|
Top 2th token. Logit: 17.07 Prob:  9.93% Token: | sorry|
Top 3th token. Logit: 16.45 Prob:  5.31% Token: | jealous|
Top 4th token. Logit: 16.26 Prob:  4.42% Token: | bad|


Reversed Story:
Penny, the puppy, and Whiskers, the cat, found a quiet corner. Penny was bored without anything to play with, but Whiskers curled up and started purring. Penny wandered around restlessly, while Whiskers had a peaceful nap.

---------------------------------------------------------
Character: Penny
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'P', 'enny', ',', ' the', ' puppy', ',', ' and', ' Wh', 'isk', 'ers', ',', ' the', ' cat', ',', ' found', ' a', ' quiet', ' corner', '.', ' Penny', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Wh', 'isk', 'ers', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Penny', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Wh', 'isk', 'ers', ' had', ' a', ' peaceful', ' nap', '.', 'P', 'enny', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.41 Prob: 12.96% Token: | lonely|
Top 1th token. Logit: 16.96 Prob:  8.25% Token: | sad|
Top 2th token. Logit: 16.84 Prob:  7.30% Token: | sorry|
Top 3th token. Logit: 16.24 Prob:  4.01% Token: | happy|
Top 4th token. Logit: 16.23 Prob:  3.96% Token: | guilty|


---------------------------------------------------------
Character: Whiskers
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'P', 'enny', ',', ' the', ' puppy', ',', ' and', ' Wh', 'isk', 'ers', ',', ' the', ' cat', ',', ' found', ' a', ' quiet', ' corner', '.', ' Penny', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Wh', 'isk', 'ers', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Penny', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Wh', 'isk', 'ers', ' had', ' a', ' peaceful', ' nap', '.', 'Wh', 'isk', 'ers', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.19 Prob: 10.44% Token: | comfortable|
Top 1th token. Logit: 16.97 Prob:  8.42% Token: | happy|
Top 2th token. Logit: 16.88 Prob:  7.68% Token: | safe|
Top 3th token. Logit: 16.54 Prob:  5.49% Token: | content|
Top 4th token. Logit: 16.45 Prob:  4.97% Token: | warm|


Story 1


Original Story:
Daisy, the cow, and Benny, the bird, found a large grassy field. Daisy was delighted and started grazing, but Benny couldn't find any trees and felt upset. Daisy spent her day munching happily, while Benny flew around, searching for a perch.

---------------------------------------------------------
Character: Daisy
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'D', 'aisy', ',', ' the', ' cow', ',', ' and', ' B', 'enny', ',', ' the', ' bird', ',', ' found', ' a', ' large', ' grass', 'y', ' field', '.', ' Daisy', ' was', ' delighted', ' and', ' started', ' grazing', ',', ' but', ' B', 'enny', ' couldn', "'t", ' find', ' any', ' trees', ' and', ' felt', ' upset', '.', ' Daisy', ' spent', ' her', ' day', ' m', 'unch', 'ing', ' happily', ',', ' while', ' B', 'enny', ' flew', ' around', ',', ' searching', ' for', ' a', ' per', 'ch', '.', 'D', 'aisy', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 18.41 Prob: 26.78% Token: | happy|
Top 1th token. Logit: 17.20 Prob:  8.01% Token: | sad|
Top 2th token. Logit: 16.85 Prob:  5.63% Token: | sorry|
Top 3th token. Logit: 16.62 Prob:  4.51% Token: | comfortable|
Top 4th token. Logit: 16.58 Prob:  4.33% Token: | lonely|


---------------------------------------------------------
Character: Benny
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'D', 'aisy', ',', ' the', ' cow', ',', ' and', ' B', 'enny', ',', ' the', ' bird', ',', ' found', ' a', ' large', ' grass', 'y', ' field', '.', ' Daisy', ' was', ' delighted', ' and', ' started', ' grazing', ',', ' but', ' B', 'enny', ' couldn', "'t", ' find', ' any', ' trees', ' and', ' felt', ' upset', '.', ' Daisy', ' spent', ' her', ' day', ' m', 'unch', 'ing', ' happily', ',', ' while', ' B', 'enny', ' flew', ' around', ',', ' searching', ' for', ' a', ' per', 'ch', '.', 'B', 'enny', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 18.32 Prob: 20.96% Token: | sad|
Top 1th token. Logit: 17.99 Prob: 15.12% Token: | lonely|
Top 2th token. Logit: 17.42 Prob:  8.51% Token: | unhappy|
Top 3th token. Logit: 16.60 Prob:  3.75% Token: | hungry|
Top 4th token. Logit: 16.52 Prob:  3.48% Token: | discouraged|


Reversed Story:
Daisy, the cow, and Benny, the bird, came across a tall tree. Daisy couldn't climb and felt disappointed, but Benny happily nestled among the branches. Daisy spent her day wishing for grass, while Benny enjoyed the view from the treetop.

---------------------------------------------------------
Character: Daisy
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'D', 'aisy', ',', ' the', ' cow', ',', ' and', ' B', 'enny', ',', ' the', ' bird', ',', ' came', ' across', ' a', ' tall', ' tree', '.', ' Daisy', ' couldn', "'t", ' climb', ' and', ' felt', ' disappointed', ',', ' but', ' B', 'enny', ' happily', ' nest', 'led', ' among', ' the', ' branches', '.', ' Daisy', ' spent', ' her', ' day', ' wishing', ' for', ' grass', ',', ' while', ' B', 'enny', ' enjoyed', ' the', ' view', ' from', ' the', ' tre', 'et', 'op', '.', 'D', 'aisy', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 18.67 Prob: 32.25% Token: | sad|
Top 1th token. Logit: 17.85 Prob: 14.34% Token: | lonely|
Top 2th token. Logit: 16.84 Prob:  5.18% Token: | sorry|
Top 3th token. Logit: 16.57 Prob:  3.95% Token: | unhappy|
Top 4th token. Logit: 16.51 Prob:  3.73% Token: | disappointed|


---------------------------------------------------------
Character: Benny
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'D', 'aisy', ',', ' the', ' cow', ',', ' and', ' B', 'enny', ',', ' the', ' bird', ',', ' came', ' across', ' a', ' tall', ' tree', '.', ' Daisy', ' couldn', "'t", ' climb', ' and', ' felt', ' disappointed', ',', ' but', ' B', 'enny', ' happily', ' nest', 'led', ' among', ' the', ' branches', '.', ' Daisy', ' spent', ' her', ' day', ' wishing', ' for', ' grass', ',', ' while', ' B', 'enny', ' enjoyed', ' the', ' view', ' from', ' the', ' tre', 'et', 'op', '.', 'B', 'enny', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 18.61 Prob: 30.84% Token: | happy|
Top 1th token. Logit: 17.35 Prob:  8.74% Token: | lonely|
Top 2th token. Logit: 17.25 Prob:  7.92% Token: | proud|
Top 3th token. Logit: 17.03 Prob:  6.34% Token: | sad|
Top 4th token. Logit: 16.68 Prob:  4.49% Token: | content|


Story 2


Original Story:
Molly, the mole, and Frankie, the falcon, discovered a dark tunnel. Molly was excited and quickly burrowed inside, while Frankie, who preferred the sky, felt out of place. Molly had an adventurous day, while Frankie circled overhead, feeling lonely.

---------------------------------------------------------
Character: Molly
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'M', 'olly', ',', ' the', ' mole', ',', ' and', ' Frank', 'ie', ',', ' the', ' fal', 'con', ',', ' discovered', ' a', ' dark', ' tunnel', '.', ' Molly', ' was', ' excited', ' and', ' quickly', ' bur', 'row', 'ed', ' inside', ',', ' while', ' Frank', 'ie', ',', ' who', ' preferred', ' the', ' sky', ',', ' felt', ' out', ' of', ' place', '.', ' Molly', ' had', ' an', ' advent', 'urous', ' day', ',', ' while', ' Frank', 'ie', ' cir', 'cled', ' overhead', ',', ' feeling', ' lonely', '.', 'M', 'olly', ' felt', ' very']
Tokenized answer:

Top 0th token. Logit: 16.14 Prob:  7.48% Token: | safe|
Top 1th token. Logit: 16.06 Prob:  6.92% Token: | happy|
Top 2th token. Logit: 16.04 Prob:  6.76% Token: | lonely|
Top 3th token. Logit: 15.71 Prob:  4.85% Token: | comfortable|
Top 4th token. Logit: 15.48 Prob:  3.88% Token: | sad|


---------------------------------------------------------
Character: Frankie
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'M', 'olly', ',', ' the', ' mole', ',', ' and', ' Frank', 'ie', ',', ' the', ' fal', 'con', ',', ' discovered', ' a', ' dark', ' tunnel', '.', ' Molly', ' was', ' excited', ' and', ' quickly', ' bur', 'row', 'ed', ' inside', ',', ' while', ' Frank', 'ie', ',', ' who', ' preferred', ' the', ' sky', ',', ' felt', ' out', ' of', ' place', '.', ' Molly', ' had', ' an', ' advent', 'urous', ' day', ',', ' while', ' Frank', 'ie', ' cir', 'cled', ' overhead', ',', ' feeling', ' lonely', '.', 'Frank', 'ie', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 16.80 Prob: 16.65% Token: | sad|
Top 1th token. Logit: 16.58 Prob: 13.32% Token: | lonely|
Top 2th token. Logit: 15.84 Prob:  6.36% Token: | alone|
Top 3th token. Logit: 15.54 Prob:  4.71% Token: | sorry|
Top 4th token. Logit: 15.17 Prob:  3.27% Token: | much|


Reversed Story:
Molly, the mole, and Frankie, the falcon, found an open sky. Molly felt exposed and dug a hole for protection, but Frankie soared high, loving the freedom. Molly spent her day in hiding, while Frankie enjoyed the thrill of flying.

---------------------------------------------------------
Character: Molly
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'M', 'olly', ',', ' the', ' mole', ',', ' and', ' Frank', 'ie', ',', ' the', ' fal', 'con', ',', ' found', ' an', ' open', ' sky', '.', ' Molly', ' felt', ' exposed', ' and', ' dug', ' a', ' hole', ' for', ' protection', ',', ' but', ' Frank', 'ie', ' so', 'ared', ' high', ',', ' loving', ' the', ' freedom', '.', ' Molly', ' spent', ' her', ' day', ' in', ' hiding', ',', ' while', ' Frank', 'ie', ' enjoyed', ' the', ' thrill', ' of', ' flying', '.', 'M', 'olly', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.01 Prob: 14.47% Token: | lonely|
Top 1th token. Logit: 16.88 Prob: 12.77% Token: | alone|
Top 2th token. Logit: 16.59 Prob:  9.58% Token: | sad|
Top 3th token. Logit: 15.49 Prob:  3.17% Token: | safe|
Top 4th token. Logit: 15.44 Prob:  3.03% Token: | exposed|


---------------------------------------------------------
Character: Frankie
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'M', 'olly', ',', ' the', ' mole', ',', ' and', ' Frank', 'ie', ',', ' the', ' fal', 'con', ',', ' found', ' an', ' open', ' sky', '.', ' Molly', ' felt', ' exposed', ' and', ' dug', ' a', ' hole', ' for', ' protection', ',', ' but', ' Frank', 'ie', ' so', 'ared', ' high', ',', ' loving', ' the', ' freedom', '.', ' Molly', ' spent', ' her', ' day', ' in', ' hiding', ',', ' while', ' Frank', 'ie', ' enjoyed', ' the', ' thrill', ' of', ' flying', '.', 'Frank', 'ie', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 16.29 Prob:  9.24% Token: | happy|
Top 1th token. Logit: 16.18 Prob:  8.27% Token: | proud|
Top 2th token. Logit: 16.16 Prob:  8.14% Token: | safe|
Top 3th token. Logit: 16.02 Prob:  7.09% Token: | lonely|
Top 4th token. Logit: 15.47 Prob:  4.08% Token: | comfortable|


Story 3


Original Story:
Nina, the nightingale, and Sammy, the snake, found a tall tree full of ripe fruits. Nina sang in delight and pecked at the fruits, but Sammy, who ate only meat, watched with indifference. Nina enjoyed her sweet meal, while Sammy slithered away, still hungry.

---------------------------------------------------------
Character: Nina
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'N', 'ina', ',', ' the', ' night', 'ingale', ',', ' and', ' Sam', 'my', ',', ' the', ' snake', ',', ' found', ' a', ' tall', ' tree', ' full', ' of', ' ripe', ' fruits', '.', ' Nina', ' sang', ' in', ' delight', ' and', ' pe', 'ck', 'ed', ' at', ' the', ' fruits', ',', ' but', ' Sam', 'my', ',', ' who', ' ate', ' only', ' meat', ',', ' watched', ' with', ' indifference', '.', ' Nina', ' enjoyed', ' her', ' sweet', ' meal', ',', ' while', ' Sam', 'my', ' slit', 'hered', ' away', ',', ' still', ' hungry', '.', 'N', 'ina', ' f

Top 0th token. Logit: 18.38 Prob: 24.45% Token: | happy|
Top 1th token. Logit: 17.15 Prob:  7.18% Token: | sad|
Top 2th token. Logit: 17.01 Prob:  6.18% Token: | sorry|
Top 3th token. Logit: 16.77 Prob:  4.87% Token: | proud|
Top 4th token. Logit: 16.76 Prob:  4.86% Token: | hungry|


---------------------------------------------------------
Character: Sammy
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'N', 'ina', ',', ' the', ' night', 'ingale', ',', ' and', ' Sam', 'my', ',', ' the', ' snake', ',', ' found', ' a', ' tall', ' tree', ' full', ' of', ' ripe', ' fruits', '.', ' Nina', ' sang', ' in', ' delight', ' and', ' pe', 'ck', 'ed', ' at', ' the', ' fruits', ',', ' but', ' Sam', 'my', ',', ' who', ' ate', ' only', ' meat', ',', ' watched', ' with', ' indifference', '.', ' Nina', ' enjoyed', ' her', ' sweet', ' meal', ',', ' while', ' Sam', 'my', ' slit', 'hered', ' away', ',', ' still', ' hungry', '.', 'Sam', 'my', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.98 Prob: 21.10% Token: | hungry|
Top 1th token. Logit: 16.66 Prob:  5.64% Token: | sorry|
Top 2th token. Logit: 16.45 Prob:  4.54% Token: | sad|
Top 3th token. Logit: 16.37 Prob:  4.20% Token: | lonely|
Top 4th token. Logit: 16.35 Prob:  4.14% Token: | bad|


Reversed Story:
Nina, the nightingale, and Sammy, the snake, discovered a small rodent. Nina flew away in fright, but Sammy quickly caught it and was contented. Nina spent her time searching for berries, while Sammy enjoyed his catch.

---------------------------------------------------------
Character: Nina
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'N', 'ina', ',', ' the', ' night', 'ingale', ',', ' and', ' Sam', 'my', ',', ' the', ' snake', ',', ' discovered', ' a', ' small', ' rodent', '.', ' Nina', ' flew', ' away', ' in', ' fright', ',', ' but', ' Sam', 'my', ' quickly', ' caught', ' it', ' and', ' was', ' content', 'ed', '.', ' Nina', ' spent', ' her', ' time', ' searching', ' for', ' berries', ',', ' while', ' Sam', 'my', ' enjoyed', ' his', ' catch', '.', 'N', 'ina', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.50 Prob: 12.35% Token: | lonely|
Top 1th token. Logit: 17.33 Prob: 10.45% Token: | happy|
Top 2th token. Logit: 16.99 Prob:  7.47% Token: | sad|
Top 3th token. Logit: 16.91 Prob:  6.88% Token: | sorry|
Top 4th token. Logit: 16.78 Prob:  6.04% Token: | hungry|


---------------------------------------------------------
Character: Sammy
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'N', 'ina', ',', ' the', ' night', 'ingale', ',', ' and', ' Sam', 'my', ',', ' the', ' snake', ',', ' discovered', ' a', ' small', ' rodent', '.', ' Nina', ' flew', ' away', ' in', ' fright', ',', ' but', ' Sam', 'my', ' quickly', ' caught', ' it', ' and', ' was', ' content', 'ed', '.', ' Nina', ' spent', ' her', ' time', ' searching', ' for', ' berries', ',', ' while', ' Sam', 'my', ' enjoyed', ' his', ' catch', '.', 'Sam', 'my', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 18.13 Prob: 22.15% Token: | happy|
Top 1th token. Logit: 17.61 Prob: 13.13% Token: | proud|
Top 2th token. Logit: 16.58 Prob:  4.69% Token: | content|
Top 3th token. Logit: 16.52 Prob:  4.41% Token: | pleased|
Top 4th token. Logit: 16.51 Prob:  4.40% Token: | good|


Story 4


Original Story:
Timmy, the tortoise, and Ricky, the rabbit, came across a slow-moving stream. Timmy was pleased and moved towards the water, but Ricky, who disliked water, hopped away. Timmy had a good time soaking, while Ricky nibbled on grass, eyeing the stream warily.

---------------------------------------------------------
Character: Timmy
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Tim', 'my', ',', ' the', ' tort', 'oise', ',', ' and', ' Ricky', ',', ' the', ' rabbit', ',', ' came', ' across', ' a', ' slow', '-', 'moving', ' stream', '.', ' Tim', 'my', ' was', ' pleased', ' and', ' moved', ' towards', ' the', ' water', ',', ' but', ' Ricky', ',', ' who', ' dis', 'liked', ' water', ',', ' ho', 'pped', ' away', '.', ' Tim', 'my', ' had', ' a', ' good', ' time', ' so', 'aking', ',', ' while', ' Ricky', ' nib', 'bled', ' on', ' grass', ',', ' eye', 'ing', ' the', ' stream', ' war', 'ily', '.', 'Tim', 'my',

Top 0th token. Logit: 17.72 Prob: 23.13% Token: | happy|
Top 1th token. Logit: 16.34 Prob:  5.78% Token: | warm|
Top 2th token. Logit: 16.08 Prob:  4.45% Token: | comfortable|
Top 3th token. Logit: 16.04 Prob:  4.31% Token: | good|
Top 4th token. Logit: 15.96 Prob:  3.98% Token: | pleased|


---------------------------------------------------------
Character: Ricky
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Tim', 'my', ',', ' the', ' tort', 'oise', ',', ' and', ' Ricky', ',', ' the', ' rabbit', ',', ' came', ' across', ' a', ' slow', '-', 'moving', ' stream', '.', ' Tim', 'my', ' was', ' pleased', ' and', ' moved', ' towards', ' the', ' water', ',', ' but', ' Ricky', ',', ' who', ' dis', 'liked', ' water', ',', ' ho', 'pped', ' away', '.', ' Tim', 'my', ' had', ' a', ' good', ' time', ' so', 'aking', ',', ' while', ' Ricky', ' nib', 'bled', ' on', ' grass', ',', ' eye', 'ing', ' the', ' stream', ' war', 'ily', '.', 'R', 'icky', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 16.79 Prob:  8.84% Token: | lonely|
Top 1th token. Logit: 16.52 Prob:  6.71% Token: | sad|
Top 2th token. Logit: 16.40 Prob:  5.93% Token: | sorry|
Top 3th token. Logit: 16.40 Prob:  5.93% Token: | uncomfortable|
Top 4th token. Logit: 16.21 Prob:  4.92% Token: | thirst|


Reversed Story:
Timmy, the tortoise, and Ricky, the rabbit, found a dry, grassy meadow. Timmy missed the water and stayed put, but Ricky was delighted and started hopping around. Timmy spent his day longing for a pond, while Ricky had a delightful time exploring the meadow.

---------------------------------------------------------
Character: Timmy
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Tim', 'my', ',', ' the', ' tort', 'oise', ',', ' and', ' Ricky', ',', ' the', ' rabbit', ',', ' found', ' a', ' dry', ',', ' grass', 'y', ' me', 'adow', '.', ' Tim', 'my', ' missed', ' the', ' water', ' and', ' stayed', ' put', ',', ' but', ' Ricky', ' was', ' delighted', ' and', ' started', ' hopping', ' around', '.', ' Tim', 'my', ' spent', ' his', ' day', ' longing', ' for', ' a', ' pond', ',', ' while', ' Ricky', ' had', ' a', ' delightful', ' time', ' exploring', ' the', ' me', 'adow', '.', 'Tim', 'my', ' felt', ' very']
Token

Top 0th token. Logit: 18.50 Prob: 25.56% Token: | sad|
Top 1th token. Logit: 18.33 Prob: 21.57% Token: | lonely|
Top 2th token. Logit: 17.20 Prob:  6.99% Token: | sorry|
Top 3th token. Logit: 16.58 Prob:  3.77% Token: | alone|
Top 4th token. Logit: 16.22 Prob:  2.61% Token: | unhappy|


---------------------------------------------------------
Character: Ricky
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Tim', 'my', ',', ' the', ' tort', 'oise', ',', ' and', ' Ricky', ',', ' the', ' rabbit', ',', ' found', ' a', ' dry', ',', ' grass', 'y', ' me', 'adow', '.', ' Tim', 'my', ' missed', ' the', ' water', ' and', ' stayed', ' put', ',', ' but', ' Ricky', ' was', ' delighted', ' and', ' started', ' hopping', ' around', '.', ' Tim', 'my', ' spent', ' his', ' day', ' longing', ' for', ' a', ' pond', ',', ' while', ' Ricky', ' had', ' a', ' delightful', ' time', ' exploring', ' the', ' me', 'adow', '.', 'R', 'icky', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.51 Prob: 17.49% Token: | happy|
Top 1th token. Logit: 16.82 Prob:  8.76% Token: | sad|
Top 2th token. Logit: 16.60 Prob:  7.05% Token: | proud|
Top 3th token. Logit: 16.40 Prob:  5.74% Token: | sorry|
Top 4th token. Logit: 16.19 Prob:  4.67% Token: | lonely|


In [13]:
def print_lists(list1, list2):
    # Find the maximum length
    max_len = max(len(list1), len(list2))

    # Print the items and their indices
    for i in range(max_len):
        item1 = list1[i] if i < len(list1) else None
        item2 = list2[i] if i < len(list2) else None
        print(f"Index: {i}, '{item1}', '{item2}'")


In [15]:
story = {
        "original": "Spot, the puppy, and Max, the cat, stumbled upon a bouncing ball. Spot was thrilled and started chasing the ball. Max, however, was startled by the ball's movement and climbed a tree. Spot had a fun-filled day playing, while Max watched warily from above.",
        "reversed": "Max, the cat, and Spot, the puppy, stumbled upon a bouncing ball. Max was thrilled and started chasing the ball. Spot, however, was startled by the ball's movement and climbed a tree. Max had a fun-filled day playing, while Spot watched warily from above.",
        "characters": ["Spot", "Max"]
    }
story = {
    "original": "Spot, the puppy, and Max, the cat, found a quiet corner. Spot was bored without anything to play with, but Max curled up and started purring. Spot wandered around restlessly, while Max had a peaceful nap.",
    "reversed": "Max, the cat, and Spot, the puppy, found a quiet corner. Max was bored without anything to play with, but Spot curled up and started purring. Max wandered around restlessly, while Spot had a peaceful nap.",
    "characters": ["Spot", "Max"]
}
print(model.to_str_tokens(story["original"]))
print(model.to_str_tokens(story["reversed"]))
print_lists(model.to_str_tokens(story["original"]), model.to_str_tokens(story["reversed"]))

for character in story["characters"]:
    # Print the character's name
    print("---------------------------------------------------------")
    print(f"Character: {character}")
    print("---------------------------------------------------------")
    # Get the end of the original story
    original_end = story["original"] + " " + character.capitalize() + " felt very"
    # Get the LLM prediction for the original story
    print(f"Mood Prediction:")
    original_prediction = utils.test_prompt(original_end, "", model, prepend_bos=True, top_k=5)
    
print(f"Reversed Story:")
print("======================================================================================================")
print(f"{story['reversed']}\n")
for character in story["characters"]:
    print("---------------------------------------------------------")
    print(f"Character: {character}")
    print("---------------------------------------------------------")
    print(f"Mood Prediction:")
    reversed_end = story["reversed"] + " " + character.capitalize() + " felt very"
    # Get the LLM prediction for the reversed story
    reversed_prediction = utils.test_prompt(reversed_end, "", model, prepend_bos=True, top_k=5)

['<|endoftext|>', 'Spot', ',', ' the', ' puppy', ',', ' and', ' Max', ',', ' the', ' cat', ',', ' found', ' a', ' quiet', ' corner', '.', ' Spot', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Max', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Spot', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Max', ' had', ' a', ' peaceful', ' nap', '.']
['<|endoftext|>', 'Max', ',', ' the', ' cat', ',', ' and', ' Spot', ',', ' the', ' puppy', ',', ' found', ' a', ' quiet', ' corner', '.', ' Max', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Spot', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Max', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Spot', ' had', ' a', ' peaceful', ' nap', '.']
Index: 0, '<|endoftext|>', '<|endoftext|>'
Index: 1, 'Spot', 'Max'
Index: 2, ',', ','
Index: 3, ' the', ' the'
Index: 4, ' puppy', ' cat'
Index: 5, ',', ','
Index: 6, ' and', ' and

Top 0th token. Logit: 16.62 Prob:  7.16% Token: | sorry|
Top 1th token. Logit: 16.62 Prob:  7.14% Token: | lonely|
Top 2th token. Logit: 16.53 Prob:  6.57% Token: | happy|
Top 3th token. Logit: 16.27 Prob:  5.05% Token: | guilty|
Top 4th token. Logit: 16.25 Prob:  4.96% Token: | sad|


---------------------------------------------------------
Character: Max
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Spot', ',', ' the', ' puppy', ',', ' and', ' Max', ',', ' the', ' cat', ',', ' found', ' a', ' quiet', ' corner', '.', ' Spot', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Max', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Spot', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Max', ' had', ' a', ' peaceful', ' nap', '.', ' Max', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.64 Prob: 12.38% Token: | safe|
Top 1th token. Logit: 17.56 Prob: 11.34% Token: | comfortable|
Top 2th token. Logit: 17.51 Prob: 10.82% Token: | content|
Top 3th token. Logit: 17.23 Prob:  8.19% Token: | happy|
Top 4th token. Logit: 16.44 Prob:  3.70% Token: | warm|


Reversed Story:
Max, the cat, and Spot, the puppy, found a quiet corner. Max was bored without anything to play with, but Spot curled up and started purring. Max wandered around restlessly, while Spot had a peaceful nap.

---------------------------------------------------------
Character: Spot
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Max', ',', ' the', ' cat', ',', ' and', ' Spot', ',', ' the', ' puppy', ',', ' found', ' a', ' quiet', ' corner', '.', ' Max', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Spot', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Max', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Spot', ' had', ' a', ' peaceful', ' nap', '.', ' Spot', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.52 Prob: 11.70% Token: | comfortable|
Top 1th token. Logit: 17.31 Prob:  9.44% Token: | safe|
Top 2th token. Logit: 17.25 Prob:  8.88% Token: | happy|
Top 3th token. Logit: 17.24 Prob:  8.86% Token: | content|
Top 4th token. Logit: 16.38 Prob:  3.75% Token: | lucky|


---------------------------------------------------------
Character: Max
---------------------------------------------------------
Mood Prediction:
Tokenized prompt: ['<|endoftext|>', 'Max', ',', ' the', ' cat', ',', ' and', ' Spot', ',', ' the', ' puppy', ',', ' found', ' a', ' quiet', ' corner', '.', ' Max', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Spot', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Max', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Spot', ' had', ' a', ' peaceful', ' nap', '.', ' Max', ' felt', ' very']
Tokenized answer: [' ']


Top 0th token. Logit: 17.01 Prob: 10.85% Token: | lonely|
Top 1th token. Logit: 16.56 Prob:  6.89% Token: | sorry|
Top 2th token. Logit: 16.26 Prob:  5.15% Token: | sad|
Top 3th token. Logit: 16.16 Prob:  4.66% Token: | alone|
Top 4th token. Logit: 16.06 Prob:  4.21% Token: | happy|


### Dataset Construction

In [15]:
dog_names = " Spot, Rex, Max, Sam, Lucky, Rocky, Bear, Jake, Duke, Cody, Bailey, Jack, Murphy, Shelby, Winston, Tyson, Sam, Shadow, Gus, Hunter, Casey, Joey, Bruno, Beau, Dakota, Luke, Henry, Tucker, Oscar"
cat_names = " Tiger, Oliver, Shadow, Princess, Max, Angel, Lucy, Charlie, Chloe, Baby, Molly, Daisy, Sophie, Lily, Kitty, Lily, Simon, Salem, Oscar, George, Sebastian, Felix, Pepper, Rocky, Sebastian"
dog_names = [n[1:] for n in model.to_str_tokens(dog_names)[1::2]]
cat_names = [n[1:] for n in model.to_str_tokens(cat_names)[1::2]]

orig_prompts = [f"Story: {d_name}, the puppy, and {c_name}, the cat, found a quiet corner. {d_name} was bored without anything to play with, but {c_name} curled up and started purring. {d_name} wandered around restlessly, while {c_name} had a peaceful nap. {c_name} felt very" for d_name in dog_names for c_name in cat_names]
flip_prompts = [f"Story: {c_name}, the cat, and {d_name}, the puppy, found a quiet corner. {c_name} was bored without anything to play with, but {d_name} curled up and started purring. {c_name} wandered around restlessly, while {d_name} had a peaceful nap. {c_name} felt very" for d_name in dog_names for c_name in cat_names]

print(model.to_str_tokens(orig_prompts[0]))
print(model.to_str_tokens(flip_prompts[0]))

['<|endoftext|>', 'Story', ':', ' Spot', ',', ' the', ' puppy', ',', ' and', ' Tiger', ',', ' the', ' cat', ',', ' found', ' a', ' quiet', ' corner', '.', ' Spot', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Tiger', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Spot', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Tiger', ' had', ' a', ' peaceful', ' nap', '.', ' Tiger', ' felt', ' very']
['<|endoftext|>', 'Story', ':', ' Tiger', ',', ' the', ' cat', ',', ' and', ' Spot', ',', ' the', ' puppy', ',', ' found', ' a', ' quiet', ' corner', '.', ' Tiger', ' was', ' bored', ' without', ' anything', ' to', ' play', ' with', ',', ' but', ' Spot', ' curled', ' up', ' and', ' started', ' pur', 'ring', '.', ' Tiger', ' wandered', ' around', ' rest', 'lessly', ',', ' while', ' Spot', ' had', ' a', ' peaceful', ' nap', '.', ' Tiger', ' felt', ' very']


In [25]:
for n in cat_names:
    print(n)
    print(model.to_str_tokens(n))
    print(model.to_str_tokens(" " + n))

Tiger
['<|endoftext|>', 'T', 'iger']
['<|endoftext|>', ' Tiger']
Oliver
['<|endoftext|>', 'O', 'liver']
['<|endoftext|>', ' Oliver']
Shadow
['<|endoftext|>', 'Shadow']
['<|endoftext|>', ' Shadow']
Princess
['<|endoftext|>', 'Pr', 'incess']
['<|endoftext|>', ' Princess']
Max
['<|endoftext|>', 'Max']
['<|endoftext|>', ' Max']
Angel
['<|endoftext|>', 'Angel']
['<|endoftext|>', ' Angel']
Lucy
['<|endoftext|>', 'Lu', 'cy']
['<|endoftext|>', ' Lucy']
Charlie
['<|endoftext|>', 'Charlie']
['<|endoftext|>', ' Charlie']
Chloe
['<|endoftext|>', 'Ch', 'loe']
['<|endoftext|>', ' Chloe']
Baby
['<|endoftext|>', 'Baby']
['<|endoftext|>', ' Baby']
Molly
['<|endoftext|>', 'M', 'olly']
['<|endoftext|>', ' Molly']
Daisy
['<|endoftext|>', 'D', 'aisy']
['<|endoftext|>', ' Daisy']
Sophie
['<|endoftext|>', 'S', 'oph', 'ie']
['<|endoftext|>', ' Sophie']
Lily
['<|endoftext|>', 'L', 'ily']
['<|endoftext|>', ' Lily']
Kitty
['<|endoftext|>', 'Kit', 'ty']
['<|endoftext|>', ' Kitty']
Lily
['<|endoftext|>', 'L', 'ily

In [18]:
len(model.to_str_tokens(neg_answers[2]))

2

In [29]:
for p in orig_prompts+flip_prompts:
    if (len(model.to_str_tokens(p)))==53:
        continue
    else:
        print(p)

In [16]:
pos_answers = [" happy", " comfortable", " content"] #, " amazing", " good"]
neg_answers = [" sad", " lonely", " sorry"] #, " terrible", " bad"]
batch_size = 2 * len(orig_prompts)
n_pairs = 2

all_prompts = []
answer_tokens = torch.empty(
        (batch_size, n_pairs, 2), 
        device=device, 
        dtype=torch.long
    )
for i in range(len(orig_prompts)):
    all_prompts.append(orig_prompts[i])
    all_prompts.append(flip_prompts[i])

    for pair_idx in range(n_pairs):
            pos_token = model.to_single_token(pos_answers[pair_idx])
            neg_token = model.to_single_token(neg_answers[pair_idx])
            tokens_dict = {
                'positive': pos_token, 
                'negative': neg_token, 
            }
            answer_tokens[i * 2, pair_idx, 0] = tokens_dict['positive']
            answer_tokens[i * 2, pair_idx, 1] = tokens_dict['negative']
            answer_tokens[i * 2 + 1, pair_idx, 0] = tokens_dict['negative']
            answer_tokens[i * 2 + 1, pair_idx, 1] = tokens_dict['positive']

prompts_tokens: Float[Tensor, "batch pos"] = model.to_tokens(
        all_prompts, prepend_bos=True
    )
clean_tokens = prompts_tokens.to(device)
corrupted_tokens = model.to_tokens(
    all_prompts[1:] + [all_prompts[0]], prepend_bos=True
).to(device)


In [17]:
len(all_prompts), answer_tokens.shape, clean_tokens.shape, corrupted_tokens.shape

(1450,
 torch.Size([1450, 2, 2]),
 torch.Size([1450, 53]),
 torch.Size([1450, 53]))

In [18]:
for i in range(0, 20, 1):
    logits, _ = model.run_with_cache(all_prompts[i])
    log_diff = get_logit_diff(logits, answer_tokens[i].unsqueeze(0))
    #if log_diff < 0.1:
    print(all_prompts[i])
    print(model.to_str_tokens(answer_tokens[i][0]))
    print(log_diff, "\n")

Story: Spot, the puppy, and Tiger, the cat, found a quiet corner. Spot was bored without anything to play with, but Tiger curled up and started purring. Spot wandered around restlessly, while Tiger had a peaceful nap. Tiger felt very
[' happy', ' sad']
tensor(1.0655, device='cuda:0') 

Story: Tiger, the cat, and Spot, the puppy, found a quiet corner. Tiger was bored without anything to play with, but Spot curled up and started purring. Tiger wandered around restlessly, while Spot had a peaceful nap. Tiger felt very
[' sad', ' happy']
tensor(1.2824, device='cuda:0') 

Story: Spot, the puppy, and Oliver, the cat, found a quiet corner. Spot was bored without anything to play with, but Oliver curled up and started purring. Spot wandered around restlessly, while Oliver had a peaceful nap. Oliver felt very
[' happy', ' sad']
tensor(0.9091, device='cuda:0') 

Story: Oliver, the cat, and Spot, the puppy, found a quiet corner. Oliver was bored without anything to play with, but Spot curled up a

In [19]:
all_prompts = all_prompts[:12]
answer_tokens = answer_tokens[:12]
clean_tokens = clean_tokens[:12]
corrupted_tokens = corrupted_tokens[:12]

#### Logit Differences

In [79]:
pos_logits, pos_cache = model.run_with_cache(clean_tokens[0::2,:])
pos_logit_diff = get_logit_diff(pos_logits, answer_tokens[0::2,:])
pos_logit_diff

tensor(0.9553, device='cuda:0')

In [80]:
neg_logits, neg_cache = model.run_with_cache(clean_tokens[1::2,:])
neg_logit_diff = get_logit_diff(neg_logits, answer_tokens[1::2,:])
neg_logit_diff

tensor(1.4725, device='cuda:0')

In [81]:
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = get_logit_diff(clean_logits, answer_tokens, per_prompt=False)
clean_logit_diff

tensor(1.2139, device='cuda:0')

In [82]:
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_tokens, per_prompt=False)
corrupted_logit_diff

tensor(-1.2213, device='cuda:0')

In [83]:
def logit_diff_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch n_pairs 2"] = answer_tokens,
    flipped_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    return_tensor: bool = False,
) -> Float[Tensor, ""]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on flipped input, and 1 when performance is same as on clean input.
    '''
    patched_logit_diff = get_logit_diff(logits, answer_tokens)
    ld = ((patched_logit_diff - flipped_logit_diff) / (clean_logit_diff  - flipped_logit_diff))
    if return_tensor:
        return ld
    else:
        return ld.item()


def logit_diff_noising(
        logits: Float[Tensor, "batch seq d_vocab"],
        clean_logit_diff: float = clean_logit_diff,
        corrupted_logit_diff: float = corrupted_logit_diff,
        answer_tokens: Float[Tensor, "batch n_pairs 2"] = answer_tokens,
        return_tensor: bool = False,
    ) -> float:
        '''
        We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
        and -1 when performance has been destroyed (i.e. is same as ABC dataset).
        '''
        patched_logit_diff = get_logit_diff(logits, answer_tokens)
        ld = ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff))

        if return_tensor:
            return ld
        else:
            return ld.item()

logit_diff_denoising_tensor = partial(logit_diff_denoising, return_tensor=True)
logit_diff_noising_tensor = partial(logit_diff_noising, return_tensor=True)

### Direct Logit Attribution

In [25]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

# added for multi-answer support
answer_residual_directions = answer_residual_directions.mean(dim=1)

print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([12, 2, 2560])
Logit difference directions shape: torch.Size([12, 2560])


In [26]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(all_prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff.item())

Final residual stream shape: torch.Size([12, 53, 2560])
Calculated average logit diff: 1.2139166593551636
Original logit difference: 1.2139177322387695


#### Logit Lens

In [35]:
def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = clean_cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(all_prompts)

In [36]:
accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, clean_cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*1+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

#### Layer Attribution

In [37]:
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, clean_cache)

line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")


#### Head Attribution

In [38]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, clean_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
per_head_logit_diffs_pct = per_head_logit_diffs
imshow_p(
    per_head_logit_diffs_pct * 100, 
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    title="Logit Difference From Each Head", 
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600)

Tried to stack head results when they weren't cached. Computing head results now


### Activation Patching

#### Attention Heads

In [39]:
results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode("z"), # iterating over all heads' output in all layers
    patching_metric=logit_diff_denoising,
    verbose=True,
)

  0%|          | 0/1024 [00:00<?, ?it/s]

results['z'].shape = (layer=32, head=32)


In [40]:
imshow_p(
    results['z'] * 100,
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

#### Head Output by Component

In [44]:
# iterating over all heads' output in all layers
results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["z", "q", "k", "v", "pattern"]),
    patching_metric=logit_diff_denoising,
    verbose=True,
)

  0%|          | 0/5120 [00:00<?, ?it/s]

results['z'].shape = (layer=32, head=32)
results['q'].shape = (layer=32, head=32)
results['k'].shape = (layer=32, head=32)
results['v'].shape = (layer=32, head=32)
results['pattern'].shape = (layer=32, head=32)


In [45]:
with open("data/2_8b_mood_inference/act_patch_head_output_by_component.pkl", "wb") as f:
    pickle.dump(results, f)

In [46]:
with open("data/2_8b_mood_inference/act_patch_head_output_by_component.pkl", "rb") as f:
    act_patch_head_output_by_component = pickle.load(f)

assert act_patch_head_output_by_component.keys() == {"z", "q", "k", "v", "pattern"}
#assert all([r.shape == (12, 12) for r in results.values()])

imshow_p(
    torch.stack(tuple(act_patch_head_output_by_component.values())) * 100,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    margin={"r": 100, "l": 100}
)

#### Residual Stream & Layer Outputs

In [47]:
# patching at each (layer, sequence position) for each of (resid_pre, attn_out, mlp_out) in turn

results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos="each"),
    patching_metric=logit_diff_denoising,
    verbose=True,
)
with open("data/2_8b_mood_inference/act_patch_resid_layer_output.pkl", "wb") as f:
    pickle.dump(results, f)

  0%|          | 0/5088 [00:00<?, ?it/s]

results['resid_pre'].shape = (seq_pos=53, layer=32)
results['attn_out'].shape = (seq_pos=53, layer=32)
results['mlp_out'].shape = (seq_pos=53, layer=32)


In [57]:
with open("data/2_8b_mood_inference/act_patch_resid_layer_output.pkl", "rb") as f:
    act_patch_resid_layer_output = pickle.load(f)

assert act_patch_resid_layer_output.keys() == {"resid_pre", "attn_out", "mlp_out"}
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]
imshow_p(
    torch.stack([r.T for r in act_patch_resid_layer_output.values()]) * 100, # we transpose so layer is on the y-axis
    facet_col=0,
    facet_labels=["resid_pre", "attn_out", "mlp_out"],
    title="Patching at resid stream & layer outputs (corrupted -> clean)",
    labels={"x": "Sequence position", "y": "Layer", "color": "Logit diff variation"},
    x=labels,
    xaxis_tickangle=45,
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    height=600,
    zmin=-50,
    zmax=50,
    margin={"r": 100, "l": 100}
)

In [34]:
import transformer_lens.patching as patching
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]

attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, logit_diff_denoising_tensor)
attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")

  0%|          | 0/54272 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
from neel_plotly import imshow as imshow_n
imshow_n(attn_head_out_act_patch_results, 
        yaxis="Head Label", 
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        height=2000,
        width=1500,
        zmin=-0.1,
        zmax=0.1,
        title="attn_head_out Activation Patching By Pos")

In [101]:
with open("data/mood_circuit_act_patch_pos_heads.pkl", "wb") as f:
    pickle.dump(results, f)

### Circuit Analysis With Patch Patching & Attn Visualization

#### Heads Influencing Logit Diff

In [34]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 31), # This is resid_post at layer 11
    patching_metric=logit_diff_noising,
    verbose=True
)

  0%|          | 0/1024 [00:00<?, ?it/s]

results['z'].shape = (layer=32, head=32)


In [39]:
imshow_p(
    results['z'] * 100,
    title="Direct effect on logit diff (patch from head output -> final resid)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

In [47]:
plot_attention_heads(-results['z'].cuda(), top_n=15, range_x=[0, 0.7])

Total logit diff contribution above threshold: 0.85


In [56]:
top_k = 9
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, all_prompts[9], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [57]:
DE_HEADS = [(17, 19), (12, 2), (15, 26), (22, 5), (19, 6), (14,4), (14, 1), (19, 29), (24, 15)]
DEVC_HEADS = [(17, 19), (15, 26), (22, 5), (24, 15)]
DEQC_HEADS = [(12, 2), (19, 6), (14,4), (14, 1), (19, 29), (24, 15)]

In [59]:
# V-weighted version
from utils.visualization import get_attn_pattern, plot_attention
plot_attention(
    model, 
    all_prompts[9],
    DE_HEADS,
    clean_cache,
    weighted=True)

#### Direct Effect V-Composition Heads

##### Residual Stream by Position

In [36]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["resid_pre"], seq_pos="each"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DE_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1696 [00:00<?, ?it/s]

results['resid_pre'].shape = (seq_pos=53, layer=32)


In [42]:
res = einops.rearrange(results['resid_pre'], "seq layer -> layer seq")

In [47]:
imshow_n(
        res * 100,
        title=f"Direct effect on DE Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

##### Attention Out by Position

In [60]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["attn_out"], seq_pos="each"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DEVC_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


In [61]:
res = einops.rearrange(results['attn_out'], "seq layer -> layer seq")

In [62]:
imshow_n(
        res * 100,
        title=f"Direct effect on DE Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

##### Attention Out by Position for Individual Receiver Heads (V Composition)

In [63]:
for layer, head in DEVC_HEADS:
    res = path_patch(
        model,
        orig_input=clean_tokens,
        new_input=corrupted_tokens,
        sender_nodes=IterNode(node_names=["attn_out"], seq_pos="each"),
        receiver_nodes=[Node("v", layer, head=head)],
        patching_metric=logit_diff_noising,
        verbose=True,
    )
    res = einops.rearrange(res['attn_out'], "seq layer -> layer seq")
    imshow_n(
        res * 100,
        title=f"Direct effect on Layer {layer} Head {head}' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


##### MLP Out by Position

In [51]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["mlp_out"], seq_pos="each"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DE_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1696 [00:00<?, ?it/s]

results['mlp_out'].shape = (seq_pos=53, layer=32)


In [53]:
res = einops.rearrange(results['mlp_out'], "seq layer -> layer seq")

In [54]:
imshow_n(
        res * 100,
        title=f"Direct effect on DE Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

##### Patching by Attention Heads

In [64]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DEVC_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1024 [00:00<?, ?it/s]

results['z'].shape = (layer=32, head=32)


In [67]:
results['z']

tensor([[ 2.0070e-05, -1.3609e-05,  1.0045e-04,  ...,  3.5637e-05,
         -2.9660e-04,  1.6923e-04],
        [ 4.4155e-05, -1.6056e-05,  6.3148e-05,  ...,  4.6015e-06,
          1.3408e-04, -1.9317e-04],
        [ 2.5201e-04, -1.2449e-04, -4.1007e-04,  ...,  3.8379e-05,
          2.0217e-05, -5.5414e-05],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])

In [72]:
imshow_p(
        results["z"][:24] * 100,
        title=f"Direct effect on Intermediate AE Heads' values",
        labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
        coloraxis=dict(colorbar_ticksuffix = "%"),
        border=True,
        width=700,
        margin={"r": 100, "l": 100}
    )

In [73]:
plot_attention_heads(-results['z'].cuda(), top_n=30, range_x=[0, 0.5])

Total logit diff contribution above threshold: 0.51


In [77]:
top_k = 4
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, all_prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [76]:
IAM_HEADS = [(12, 2), (14, 1), (15, 16), (14, 4)]

# V-weighted version
from utils.visualization import get_attn_pattern, plot_attention
plot_attention(
    model, 
    all_prompts[9],
    IAM_HEADS,
    clean_cache,
    weighted=True)

#### Direct Effect Q-Composition Heads

##### Residual Stream by Position

In [None]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["resid_pre"], seq_pos="each"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DEQC_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1696 [00:00<?, ?it/s]

results['resid_pre'].shape = (seq_pos=53, layer=32)


In [None]:
res = einops.rearrange(results['resid_pre'], "seq layer -> layer seq")

In [None]:
imshow_n(
        res * 100,
        title=f"Direct effect on DE Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

##### Attention Out by Position

In [85]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["attn_out"], seq_pos="each"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in DEQC_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


In [86]:
res = einops.rearrange(results['attn_out'], "seq layer -> layer seq")

In [88]:
imshow_n(
        res * 100,
        title=f"Direct effect on DEQC Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

##### Attention Out by Position for Individual Receiver Heads (V Composition)

In [None]:
for layer, head in DEVC_HEADS:
    res = path_patch(
        model,
        orig_input=clean_tokens,
        new_input=corrupted_tokens,
        sender_nodes=IterNode(node_names=["attn_out"], seq_pos="each"),
        receiver_nodes=[Node("v", layer, head=head)],
        patching_metric=logit_diff_noising,
        verbose=True,
    )
    res = einops.rearrange(res['attn_out'], "seq layer -> layer seq")
    imshow_n(
        res * 100,
        title=f"Direct effect on Layer {layer} Head {head}' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


  0%|          | 0/1696 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=53, layer=32)


##### MLP Out by Position

In [None]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["mlp_out"], seq_pos="each"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DE_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1696 [00:00<?, ?it/s]

results['mlp_out'].shape = (seq_pos=53, layer=32)


In [None]:
res = einops.rearrange(results['mlp_out'], "seq layer -> layer seq")

In [None]:
imshow_n(
        res * 100,
        title=f"Direct effect on DE Heads' values",
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

##### Patching by Attention Heads

In [89]:
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in DEQC_HEADS],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/1024 [00:00<?, ?it/s]

results['z'].shape = (layer=32, head=32)


In [90]:
imshow_p(
        results["z"][:20] * 100,
        title=f"Direct effect on Intermediate AE Heads' values",
        labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
        coloraxis=dict(colorbar_ticksuffix = "%"),
        border=True,
        width=700,
        margin={"r": 100, "l": 100}
    )

In [91]:
plot_attention_heads(-results['z'].cuda(), top_n=30, range_x=[0, 0.5])

Total logit diff contribution above threshold: 0.85


In [92]:
top_k = 7
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, all_prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [93]:
IAMQ_HEADS = [(10, 14), (11, 12), (17, 6), (9, 19), (11, 0), (11, 22), (13, 20)]

# V-weighted version
from utils.visualization import get_attn_pattern, plot_attention
plot_attention(
    model, 
    all_prompts[9],
    IAMQ_HEADS,
    clean_cache,
    weighted=True)