# Initial Exploratory Analysis

## Setup

In [1]:
try:
    from google.colab import drive # type: ignore
    %pip install transformer_lens
    %pip install gdown
    # %pip install plotly
    # %pip install jaxtyping
    # %pip install einops
    # %pip install protobuf==3.20.*
    import os
    import sys
    from pathlib import Path
    import gdown
    if not Path("ioi_dataset.py").resolve().exists():
        urls = {
            "ioi_dataset.py": "https://drive.google.com/uc?id=19UjxFnb6kztuhvz6dGAXjA9oRZmd84kC",
            "path_patching.py": "https://drive.google.com/uc?id=1duF7B3IjG_E5nGcjT_BuoSrSkynUhZI5",
        }
        for filename, url in urls.items():
            output = str(Path(filename).resolve())
            gdown.download(url, output)
except:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [2]:
import os
import pathlib
from typing import List, Optional, Union

import torch
import numpy as np
import yaml

import einops
from fancy_einsum import einsum

from datasets import load_dataset
#from transformers import pipeline

import transformers
import circuitsvis as cv
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from torch import Tensor
from tqdm.notebook import tqdm
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from rich import print as rprint

from typing import List, Union
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import re

from functools import partial

from torchtyping import TensorType as TT

In [3]:
torch.set_grad_enabled(False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor", "showlegend", "xaxis_tickmode", "yaxis_tickmode", "margin", "xaxis_visible", "yaxis_visible", "bargap", "bargroupgap"}

# def imshow(tensor, renderer=None, **kwargs):
#     kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
#     kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
#     facet_labels = kwargs_pre.pop("facet_labels", None)
#     border = kwargs_pre.pop("border", False)
#     if "color_continuous_scale" not in kwargs_pre:
#         kwargs_pre["color_continuous_scale"] = "RdBu"
#     if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
#         kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
#     fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, **kwargs_pre).update_layout(**kwargs_post)
#     if facet_labels:
#         for i, label in enumerate(facet_labels):
#             fig.layout.annotations[i]['text'] = label
#     if border:
#         fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
#         fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
#     fig.show(renderer=renderer)

In [4]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [14]:
def get_logit_diff(logits, answer_token_indices, per_prompt=False):
    """Gets the difference between the logits of the provided tokens (e.g., the correct and incorrect tokens in IOI)

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the tokens to compare.

    Returns:
        torch.Tensor: Difference between the logits of the provided tokens.
    """
    if len(logits.shape) == 3:
        # Get final logits only
        logits = logits[:, -1, :]
    left_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    right_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    if per_prompt:
        print(left_logits - right_logits)

    return (left_logits - right_logits).mean()

## Exploratory Analysis


In [6]:
#source_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")
#rlhf_model = AutoModelForCausalLM.from_pretrained("curt-tigges/gpt2-negative-movie-reviews")

#hooked_source_model = HookedTransformer.from_pretrained(model_name="gpt2", hf_model=source_model)
model = HookedTransformer.from_pretrained(model_name="EleutherAI/pythia-410m")
#model = HookedTransformer.from_pretrained(model_name="gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-410m into HookedTransformer


### Initial Examination

In [7]:
example_prompt = "This movie was a terrible viewing experience, and the plot and acting were awful. Overall, the film was"
example_answer = " great"

In [8]:
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=10)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' was', ' a', ' terrible', ' viewing', ' experience', ',', ' and', ' the', ' plot', ' and', ' acting', ' were', ' awful', '.', ' Overall', ',', ' the', ' film', ' was']
Tokenized answer: [' great']


Top 0th token. Logit: 16.85 Prob: 19.07% Token: | a|
Top 1th token. Logit: 15.35 Prob:  4.25% Token: | not|
Top 2th token. Logit: 15.22 Prob:  3.73% Token: | pretty|
Top 3th token. Logit: 15.20 Prob:  3.67% Token: | very|
Top 4th token. Logit: 15.13 Prob:  3.41% Token: | awful|
Top 5th token. Logit: 15.05 Prob:  3.15% Token: | terrible|
Top 6th token. Logit: 14.86 Prob:  2.59% Token: | bad|
Top 7th token. Logit: 14.69 Prob:  2.20% Token: | just|
Top 8th token. Logit: 14.47 Prob:  1.77% Token: | an|
Top 9th token. Logit: 14.43 Prob:  1.69% Token: | horrible|


In [9]:
example_prompt = "This movie was a fantastic viewing experience, and the plot and acting were superb. Overall, the film was"
example_answer = " awful"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=10)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' was', ' a', ' fantastic', ' viewing', ' experience', ',', ' and', ' the', ' plot', ' and', ' acting', ' were', ' superb', '.', ' Overall', ',', ' the', ' film', ' was']
Tokenized answer: [' awful']


Top 0th token. Logit: 17.73 Prob: 25.36% Token: | a|
Top 1th token. Logit: 16.73 Prob:  9.29% Token: | very|
Top 2th token. Logit: 16.58 Prob:  8.04% Token: | well|
Top 3th token. Logit: 15.68 Prob:  3.27% Token: | an|
Top 4th token. Logit: 15.55 Prob:  2.87% Token: | quite|
Top 5th token. Logit: 15.52 Prob:  2.79% Token: | great|
Top 6th token. Logit: 15.49 Prob:  2.70% Token: | entertaining|
Top 7th token. Logit: 15.30 Prob:  2.23% Token: | enjoyable|
Top 8th token. Logit: 15.24 Prob:  2.10% Token: | good|
Top 9th token. Logit: 14.96 Prob:  1.59% Token: | not|


### Dataset Construction

In [10]:
positive_adjectives = [
    ' perfect', ' fantastic',' delightful',' cheerful',' marvelous',' good',' remarkable',' wonderful',
    ' fabulous',' outstanding',' awesome',' exceptional',' incredible',' extraordinary',
    ' amazing',' lovely',' brilliant',' charming',' terrific',' superb',' spectacular',' great',' splendid',
    ' beautiful',' joyful',' positive',' excellent'
    ]

negative_adjectives = [
    ' dreadful',' bad',' dull',' depressing',' miserable',' tragic',' nasty',' inferior',' horrific',' terrible',
    ' ugly',' disgusting',' disastrous',' horrendous',' annoying',' boring',' offensive',' frustrating',' wretched',' dire',
    ' awful',' unpleasant',' horrible',' mediocre',' disappointing',' inadequate'
    ]

len(positive_adjectives), len(negative_adjectives)


(27, 26)

In [11]:
all_prompts = []

pos_prompts = [
    f"This movie was a {positive_adjectives[i]} viewing experience, and the plot and acting were {positive_adjectives[i+1]}. Overall, the film was" for i in range(len(positive_adjectives)-1)
]
neg_prompts = [
    f"This movie was a {negative_adjectives[i]} viewing experience, and the plot and acting were {negative_adjectives[i+1]}. Overall, the film was" for i in range(len(negative_adjectives)-1)
]
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
pos_answer_tokens = []
neg_answer_tokens = []
for i in range(len(pos_prompts)-1):

    all_prompts.append(pos_prompts[i])
    all_prompts.append(neg_prompts[i])
    
    answer_tokens.append(
        (
            model.to_single_token(" great"),
            model.to_single_token(" awful"),
        )
    )

    pos_answer_tokens.append(
        (
            model.to_single_token(" great"),
            model.to_single_token(" awful"),
        )
    )

    answer_tokens.append(
        (
            model.to_single_token(" great"),
            model.to_single_token(" awful"),
        )
    )

    neg_answer_tokens.append(
        (
            model.to_single_token(" awful"),
            model.to_single_token(" great"),
        )
    )
answer_tokens = torch.tensor(answer_tokens).to(device)
pos_answer_tokens = torch.tensor(pos_answer_tokens).to(device)
neg_answer_tokens = torch.tensor(neg_answer_tokens).to(device)

In [12]:
all_prompts_tokens.shape

In [45]:
all_prompts_tokens = model.to_tokens(all_prompts, prepend_bos=True)
all_prompts_tokens = all_prompts_tokens.to(device)

In [48]:
clean_tokens = model.to_tokens(all_prompts[:2])
clean_logits, clean_cache = model.run_with_cache(all_prompts[0])
clean_logit_diff = get_logit_diff(clean_logits, answer_tokens[0].unsqueeze(0), per_prompt=True)
clean_logit_diff

tensor([[7.1921]], device='cuda:0')


tensor(7.1921, device='cuda:0')

In [40]:
clean_tokens = model.to_tokens(all_prompts[:2])
clean_logits, clean_cache = model.run_with_cache(clean_tokens[0:2])
clean_logit_diff = get_logit_diff(clean_logits, answer_tokens[0:2], per_prompt=True)
clean_logit_diff

tensor([[-0.1994],
        [ 0.1468]], device='cuda:0')


tensor(-0.0263, device='cuda:0')

In [39]:
clean_logits[0, -1, :][answer_tokens[0][0]], clean_logits[0, -1, :][answer_tokens[0][1]]

(tensor(-1.4972, device='cuda:0'), tensor(-1.2978, device='cuda:0'))

In [23]:
pos_tokens = model.to_tokens(pos_prompts, prepend_bos=True)
neg_tokens = model.to_tokens(neg_prompts, prepend_bos=True)

# Run the models and cache all activations
pos_logits, pos_cache = model.run_with_cache(pos_tokens)
neg_logits, neg_cache = model.run_with_cache(neg_tokens)


In [24]:
pos_answer_tokens

tensor([[ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580],
        [ 1270, 18580]], device='cuda:0')

In [25]:
pos_logit_diff = get_logit_diff(pos_logits, pos_answer_tokens, per_prompt=True)
pos_logit_diff

tensor([[-1.0402],
        [-0.2136],
        [ 6.1793],
        [ 6.6180],
        [-0.7280],
        [-1.1039],
        [-0.2653],
        [-0.2278],
        [-0.7180],
        [-0.7481],
        [-0.7208],
        [-0.2508],
        [-0.2212],
        [-0.7085],
        [-0.6933],
        [-0.2373],
        [-0.2187],
        [-0.7417],
        [-0.7083],
        [-0.6966],
        [-1.0318],
        [-0.7247],
        [-0.2148],
        [-0.7170],
        [-1.0457]], device='cuda:0')


tensor(-0.0471, device='cuda:0')

In [412]:
# 
neg_logit_diff = get_logit_diff(neg_logits, answer_tokens.flip(dims=(0,)))
neg_logit_diff

### Direct Logit Attribution

In [390]:
# Here we get the unembedding vectors for the answer tokens
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([27, 2, 768])
Logit difference directions shape: torch.Size([27, 768])


In [391]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = pos_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = pos_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(pos_prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",pos_logit_diff.item())

Final residual stream shape: torch.Size([27, 9, 768])
Calculated average logit diff: 3.4623138904571533
Original logit difference: 4.853545665740967


#### Logit Lens

In [392]:
def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = pos_cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(pos_prompts)

In [393]:
accumulated_residual, labels = pos_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, pos_cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulate Residual Stream")

#### Layer Attribution

In [394]:
per_layer_residual, labels = pos_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, pos_cache)

line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

#### Head Attribution

In [398]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

per_head_residual, labels = pos_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, pos_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")