**SOURCE NOTEBOOK:** [TransformerLens Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)

## Setup


In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the TransformerLens code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-ev0s_lsm
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-ev0s_lsm
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 218ebd6f491f47f5e2f64e4c4327548b60a093eb
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting typeguard<4.0.0,>=3.0.2 (from transformer-lens==0.0.0)
  Using cached typeguard-3.0.2-py3-none-any.whl (30 kB)
Installing collected packages: typeguard
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.13.3
    Uninstalling typeguard-2.13.3:
      Successfully uninstalled typeguard-2.13.3
ERROR: pip's dependency resolver does not currently take 

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off to save GPU memory, since this notebook focuses on model inference, not model training.

In [None]:
torch.set_grad_enabled(False)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [None]:
line(np.arange(5))

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{DEVICE = }")

In [None]:
solu_model = HookedTransformer.from_pretrained(
    "solu-1l",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=DEVICE
)
# gelu_model = HookedTransformer.from_pretrained(
#     "gelu-1l",
#     center_unembed=True,
#     center_writing_weights=True,
#     fold_ln=True,
#     refactor_factored_attn_matrices=True,
#     device=DEVICE
# )

Loaded pretrained model solu-1l into HookedTransformer


### Prompts

In [None]:
# Written by ChatGPT (3) (from Alana's notebook)
prompts = [
    "I am happy.",
    "The sun shines.",
    "Cats meow.",
    "Dogs bark.",
    "Birds fly.",
    "Time flies.",
    "Love conquers.",
    "Dreams inspire.",
    "Music heals.",
    "Laughter echoes.",  # without this comma, Python would merge it with the next string
    "The sky is blue.",
    "I love pizza.",
    "She walks to work.",
    "He plays guitar.",
    "We went to the beach.",
    "They are watching a movie.",
    "I need some coffee.",
    "The cat is sleeping.",
    "He ran to catch the bus.",
    "She smiled and waved goodbye.",
    "The book is on the table.",
    "They laughed at the joke.",
    "I want to learn coding.",
    "We had a great time.",
    "He asked me a question.",
    "She sings beautifully.",
    "The sun sets in the west.",
    "They went hiking in the mountains.",
    "I forgot my keys at home.",
    "He bought a new car.",
    "She enjoys playing tennis.",
    "We had dinner at a fancy restaurant.",
    "They are planning a trip to Europe.",
    "I saw a shooting star last night.",
    "She wrote a letter to her friend.",
    "The dog chased its tail."
]

## Using prompts to compute scores

Here we're trying to tackle [problem 4.39c](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/o6ptPu7arZrqRCxyz#Problems): "look for tokens where the direct logit attribution of the MLP layer is high, but no single neuron is high". Two per-position quantities:

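- ablation loss - how much zero-ablating the MLP neuron activations (`blocks.0.mlp.hook_mid`) at that position on that prompt increases the loss
    - if I'm getting it right, in this particular 1-layer case the MLP output at a position only feeds into the logits at that same position, so this should roughly track the direct logit attribution (not exactly, since the final LayerNorm and the softmax in the loss are nonlinear)
    - supposed to be a proxy for "the direct logit attribution of the MLP layer is high"
- inverse outlier score - also per position: `1 / (maximum_activation - mean_activation)`, taken over the MLP neurons
    - supposed to be a proxy for "no single neuron is high"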

So multiplying these two element-wise should give us a reasonable proxy (?; :crossed_fingers:)
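To make the combination concrete, here is a minimal sketch on made-up numbers (not model activations): two positions with the same loss increase, one dominated by a single neuron and one with spread-out activations. `combined_scores` is a hypothetical helper, not a function from the notebook or from TransformerLens.

```python
import torch


def combined_scores(
    mlp_acts: torch.Tensor,  # [pos, d_mlp]: MLP neuron activations at each position
    ablation_loss_diffs: torch.Tensor,  # [pos]: loss increase when ablating each position
) -> torch.Tensor:
    # Outlier score per position: how far the top neuron sits above the mean neuron
    outlier_scores = mlp_acts.max(dim=-1).values - mlp_acts.mean(dim=-1)
    # Multiplying by the inverse outlier score is the same as dividing by the outlier score
    return ablation_loss_diffs / outlier_scores


acts = torch.tensor([
    [0.0, 0.0, 0.0, 4.0],  # one dominant neuron: outlier score 4 - 1 = 3
    [1.0, 1.0, 1.0, 2.0],  # spread-out activations: outlier score 2 - 1.25 = 0.75
])
diffs = torch.tensor([0.3, 0.3])  # both positions matter equally for the loss
demo_scores = combined_scores(acts, diffs)
print(demo_scores)  # the spread-out position (index 1) gets the higher score
```

With equal loss increases, the score is driven entirely by the outlier score, which is exactly the "important, but no single neuron is high" preference described above.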

In [None]:
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
import math
from pprint import pprint
from typing import DefaultDict


@dataclass(slots=True)
class PromptResults:
    """To store ~everything (potentially) important for the prompt"""

    index: int
    text: str
    tokens: torch.Tensor
    str_tokens: list[str]

    original_loss: float = math.inf

    # these are indexed by token position
    outlier_scores: torch.Tensor = field(default_factory=lambda: torch.tensor([]))
    ablation_losses: torch.Tensor = field(default_factory=lambda: torch.tensor([]))

    @classmethod
    def make(
        cls, model: HookedTransformer, index: int, prompt: str
    ) -> "PromptResults":
        tokens = model.to_tokens(prompt)
        str_tokens = model.to_str_tokens(prompt)
        original_loss = model(tokens, return_type="loss").item()  # plain float, matching the annotation
        return cls(
            index=index,
            text=prompt,
            tokens=tokens,
            str_tokens=str_tokens,
            original_loss=original_loss
        )

    @property
    def ablation_loss_diffs(self) -> torch.Tensor:
        return self.ablation_losses - self.original_loss

    @property
    def inv_outlier_scores(self) -> torch.Tensor:
        return 1 / self.outlier_scores

    @property
    def length(self) -> int:
        return len(self.str_tokens)

# indexed by prompt index in the list of prompts
prompt_result_dict: dict[int, PromptResults] = {
    index: PromptResults.make(solu_model, index, prompt)
    for index, prompt in enumerate(prompts)
}


def compute_outlier_scores(
    x: torch.Tensor,
    dim: int | None = None,
    *,
    keepdim: bool = False
) -> Float[torch.Tensor, "batch pos"]:
    return (
        x.max(dim=dim, keepdim=keepdim).values
        - x.mean(dim=dim, keepdim=keepdim)
    )



def compute_outlier_score_per_pos_hook(
    index: int,
    x: Float[torch.Tensor, "batch pos d_mlp"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    # Record max-minus-mean over the MLP neurons at each position; activations pass through unchanged
    outlier_scores = compute_outlier_scores(x, dim=-1)
    prompt_result_dict[index].outlier_scores = outlier_scores.squeeze(0)
    return x


def post_mlp_pre_ln_per_pos_ablation_hook(
    pos: int,
    x: Float[torch.Tensor, "batch pos d_mlp"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    # Zero-ablate all MLP neuron activations at a single position
    x[:, pos, :] = 0
    return x



def compute_ablation_losses(
    model: HookedTransformer,
    tokens: torch.Tensor,
) -> torch.Tensor:
    p_len = tokens.size(1)
    layer_name = "blocks.0.mlp.hook_mid"

    ablation_losses = []

    for pos in range(p_len):
        ablation_hook = partial(post_mlp_pre_ln_per_pos_ablation_hook, pos)
        ablation_loss = model.run_with_hooks(
            tokens,
            return_type="loss",
            fwd_hooks=[(layer_name, ablation_hook)]
        )
        ablation_losses.append(ablation_loss)
    # Each entry is a 0-d loss tensor already on the model's device, so stack them directly
    return torch.stack(ablation_losses)


for index, pr in tqdm.tqdm(prompt_result_dict.items()):
    outlier_score_hook = partial(compute_outlier_score_per_pos_hook, index)
    layer_name = "blocks.0.mlp.hook_mid"
    solu_model.run_with_hooks(
        pr.tokens,
        fwd_hooks=[(layer_name, outlier_score_hook)]
    )
    pr.ablation_losses = compute_ablation_losses(solu_model, pr.tokens)



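In TransformerLens, `run_with_hooks` attaches the hooks for a single forward pass and removes them afterwards. As a rough plain-PyTorch analogue (not how TransformerLens implements it), the sketch below does the same per-position zero-ablation with `register_forward_hook` on a toy `nn.Linear`, which, like the MLP, acts on each position independently. `zero_position_hook` is a hypothetical name; `partial` binds the position first, just like `post_mlp_pre_ln_per_pos_ablation_hook` above.

```python
from functools import partial

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy position-wise layer standing in for the MLP: input shape [batch, pos, d]
layer = nn.Linear(8, 8)
x = torch.randn(1, 5, 8)


def zero_position_hook(pos, module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output
    output = output.clone()
    output[:, pos, :] = 0.0
    return output


with torch.no_grad():
    clean = layer(x)
    handle = layer.register_forward_hook(partial(zero_position_hook, 2))
    ablated = layer(x)
    handle.remove()  # detach the hook so later runs are clean

# Which positions differ between the clean and ablated runs
changed = (clean != ablated).any(dim=-1).squeeze(0)
print(changed)  # only position 2 differs
```

Because a position-wise layer's output at one position doesn't depend on the others, only the ablated position changes.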

### Computing scores for important positions with low outlier-ish-ness

In [None]:
all_ablation_loss_diffs = torch.cat([pr.ablation_loss_diffs for pr in prompt_result_dict.values()], dim=-1)
all_inv_outlier_scores = torch.cat([pr.inv_outlier_scores for pr in prompt_result_dict.values()], dim=-1)

assert len(all_ablation_loss_diffs) == len(all_inv_outlier_scores) == sum(pr.length for pr in prompt_result_dict.values())

scores = all_ablation_loss_diffs * all_inv_outlier_scores

Split the flat `scores` tensor back into per-prompt chunks, using each prompt's token count, so that every score can be traced back to its prompt and position.

In [None]:
def parse_scores(scores: torch.Tensor, prompt_lens: list[int]) -> dict[int, torch.Tensor]:
    parsed_scores = {}
    parsed_len = 0
    for i, prompt_len in enumerate(prompt_lens):
        parsed_scores[i] = scores[parsed_len:parsed_len+prompt_len]
        parsed_len += prompt_len
    return parsed_scores

prompt_lens = [pr.length for pr in prompt_result_dict.values()]
parsed_scores = parse_scores(scores, prompt_lens)

assert sum(len(s) for s in parsed_scores.values()) == len(all_ablation_loss_diffs)
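The same split can be written with `torch.split`, which cuts a tensor into consecutive chunks of the given sizes and raises a `RuntimeError` if the sizes don't add up to the tensor's length. A minimal sketch on toy data; `parse_scores_split` is a hypothetical name.

```python
import torch


def parse_scores_split(scores: torch.Tensor, prompt_lens: list[int]) -> dict[int, torch.Tensor]:
    # torch.split with a list of sizes returns consecutive chunks of exactly those lengths
    return dict(enumerate(torch.split(scores, prompt_lens)))


flat = torch.arange(7.0)
parsed = parse_scores_split(flat, [3, 2, 2])
print(parsed)

# A length mismatch is caught instead of silently producing wrong chunks
mismatch_caught = False
try:
    parse_scores_split(flat, [3, 2])  # sizes sum to 5, not 7
except RuntimeError as err:
    mismatch_caught = True
    print("length mismatch caught:", err)
```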

In [None]:
# pprint(parsed_scores)

Now get the prompt-position pairs with the highest scores, i.e. positions that seem important but where, roughly, "no single neuron is high".

In [None]:
def get_top_k_inds(x: torch.Tensor, k: int = 10) -> torch.Tensor:
    return x.sort(descending=True).indices[:k]

top_k_inds = get_top_k_inds(scores)
top_k_inds, scores[top_k_inds], scores.max()

(tensor([ 18,  38, 152,   8,  23, 151,  48,  33,  29,  70]),
 tensor([606.1760, 396.3642, 324.9307, 301.1727, 239.6782, 215.7159, 198.5753,
         191.8652, 191.6696, 186.2972]),
 tensor(606.1760))
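`torch.topk` returns the `k` largest values and their indices (sorted in descending order by default) without sorting the whole tensor, so it can stand in for `sort(...)[:k]`. A minimal sketch; `get_top_k_inds_topk` is a hypothetical name, and tied values may come back in a different order than with `sort`.

```python
import torch


def get_top_k_inds_topk(x: torch.Tensor, k: int = 10) -> torch.Tensor:
    # topk raises if k exceeds the number of elements, so clamp it first
    k = min(k, x.numel())
    return torch.topk(x, k).indices


demo = torch.tensor([0.5, 3.0, -1.0, 2.0])
print(get_top_k_inds_topk(demo, k=2))  # tensor([1, 3])
```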

In [None]:
def get_top_k_from_parsed_scores(parsed_scores: dict[int, torch.Tensor], top_k_inds: torch.Tensor) -> list[tuple[int, int]]:
    # Map flat indices (into the concatenated scores) back to (prompt index, position) pairs.
    # Results come out ordered by prompt index, not by score.
    offset = 0  # flat index at which the current prompt's scores start
    results: list[tuple[int, int]] = []
    for index, prompt_scores in parsed_scores.items():
        length = len(prompt_scores)
        for i in top_k_inds:
            if offset <= i < offset + length:
                results.append((index, i - offset))
        offset += length
    return results

top_k = get_top_k_from_parsed_scores(parsed_scores, top_k_inds)
top_k

[(1, tensor(3)),
 (3, tensor(1)),
 (4, tensor(1)),
 (5, tensor(2)),
 (6, tensor(2)),
 (7, tensor(1)),
 (9, tensor(1)),
 (12, tensor(2)),
 (24, tensor(2)),
 (24, tensor(1))]
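The loop above returns pairs ordered by prompt index, with positions as 0-d tensors when given tensor indices. A vectorized alternative, sketched below under the assumption that every prompt has at least one token, computes each prompt's starting offset with `torch.cumsum` and locates each flat index with `torch.searchsorted`; it keeps the input order (e.g. rank order from the top-k) and returns plain ints. `flat_to_prompt_pos` is a hypothetical name.

```python
import torch


def flat_to_prompt_pos(flat_inds: torch.Tensor, prompt_lens: list[int]) -> list[tuple[int, int]]:
    lens = torch.tensor(prompt_lens)
    # starts[i] is the flat offset at which prompt i begins
    starts = torch.cumsum(lens, dim=0) - lens
    # right=True counts the starts <= each index; minus 1 gives the prompt containing it
    prompt_ids = torch.searchsorted(starts, flat_inds, right=True) - 1
    positions = flat_inds - starts[prompt_ids]
    # Keeps the order of flat_inds, converted to plain Python ints
    return list(zip(prompt_ids.tolist(), positions.tolist()))


pairs = flat_to_prompt_pos(torch.tensor([4, 0, 6]), [3, 2, 2])
print(pairs)  # [(1, 1), (0, 0), (2, 1)]
```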

In [None]:
for i, pos in top_k:
    print(parsed_scores[i][pos])

tensor(301.1727)
tensor(606.1760)
tensor(239.6782)
tensor(191.6696)
tensor(191.8652)
tensor(396.3642)
tensor(198.5753)
tensor(186.2972)
tensor(324.9307)
tensor(215.7159)


In [None]:
scores[top_k_inds]

tensor([606.1760, 396.3642, 324.9307, 301.1727, 239.6782, 215.7159, 198.5753,
        191.8652, 191.6696, 186.2972])

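Yay, it works: looking the top-k positions up per prompt gives back the same ten scores, just ordered by prompt index rather than by score.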

In [None]:
def compare_outlier_scores(ind_pos: list[tuple[int, int]]) -> None:
    layer_name_pre_ln = "blocks.0.mlp.hook_mid"
    layer_name_post_ln = "blocks.0.mlp.ln.hook_normalized"
    for i, pos in ind_pos:
        pr = prompt_result_dict[i]
        _, cache = solu_model.run_with_cache(pr.tokens)  # only the cache is needed
        act_pre = cache[layer_name_pre_ln].squeeze(0)[pos]
        act_post = cache[layer_name_post_ln].squeeze(0)[pos]
        outlier_score_pre = compute_outlier_scores(act_pre, dim=-1).item()
        outlier_score_post = compute_outlier_scores(act_post, dim=-1).item()
        post_pre_ratio = outlier_score_post / outlier_score_pre
        print(f"[{i} : {pos}] Pre: {outlier_score_pre:.5f}; Post: {outlier_score_post:.5f}; Ratio: {post_pre_ratio:.5f}")


In [None]:
print("Increases in outlier score in allegedly most important positions")
compare_outlier_scores(top_k)

Increases in outlier score in allegedly most important positions
[1 : 3] Pre: 0.00775; Post: 2.41059; Ratio: 311.22069
[3 : 1] Pre: 0.00352; Post: 1.10702; Ratio: 314.65047
[4 : 1] Pre: 0.00441; Post: 1.38654; Ratio: 314.24857
[5 : 2] Pre: 0.00316; Post: 0.99321; Ratio: 314.14181
[6 : 2] Pre: 0.00473; Post: 1.48471; Ratio: 313.65597
[7 : 1] Pre: 0.00352; Post: 1.10702; Ratio: 314.65047
[9 : 1] Pre: 0.00255; Post: 0.80200; Ratio: 314.41707
[12 : 2] Pre: 0.00858; Post: 2.68574; Ratio: 313.03522
[24 : 2] Pre: 0.00277; Post: 0.86979; Ratio: 314.53183
[24 : 1] Pre: 0.00424; Post: 1.33068; Ratio: 314.00618


In [None]:
rand_inds = random.sample(range(len(scores)), 10)
rand_ind_pos = get_top_k_from_parsed_scores(parsed_scores, rand_inds)
print("Increases in outlier score in random positions")
compare_outlier_scores(rand_ind_pos)

Increases in outlier score in random positions
[2 : 1] Pre: 0.00396; Post: 1.24349; Ratio: 314.24359
[12 : 0] Pre: 0.03588; Post: 10.66425; Ratio: 297.22357
[14 : 2] Pre: 0.00577; Post: 1.80612; Ratio: 313.11740
[18 : 5] Pre: 0.00798; Post: 2.49091; Ratio: 312.15110
[26 : 6] Pre: 0.00778; Post: 2.42271; Ratio: 311.49457
[28 : 3] Pre: 0.01071; Post: 3.34791; Ratio: 312.66995
[31 : 4] Pre: 0.00985; Post: 3.08533; Ratio: 313.19182
[32 : 6] Pre: 0.02101; Post: 6.51769; Ratio: 310.24370
[33 : 2] Pre: 0.00416; Post: 1.30706; Ratio: 314.09935
[34 : 2] Pre: 0.00338; Post: 1.06090; Ratio: 313.90099


So the post/pre ratio looks about the same (~314) for the allegedly most important positions (those with the greatest ablation loss increase) as for random positions.
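A likely explanation, assuming `blocks.0.mlp.ln` here is a weightless LayerNorm (with `fold_ln=True` its weights should have been folded into the neighbouring layers) with TransformerLens's default $\epsilon = 10^{-5}$: `hook_normalized` is then $\hat{x}_i = (x_i - \mu)/\sqrt{\sigma^2 + \epsilon}$, whose mean is exactly 0, so

$$
\frac{\max_i \hat{x}_i - \operatorname{mean}_i \hat{x}_i}{\max_i x_i - \mu} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \le \frac{1}{\sqrt{10^{-5}}} \approx 316.2 .
$$

The ratio is therefore just the inverse LayerNorm scale at that position, and it sits near its cap whenever the activation variance $\sigma^2$ is small compared to $\epsilon$, regardless of how important the position is. The sketch below checks the identity numerically with `F.layer_norm` (no affine weights) on random activations; the scale `4e-4` is an illustrative assumption chosen so that $\sigma^2 \ll \epsilon$, not a value measured from the model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5  # assumed LayerNorm epsilon (TransformerLens default)

# Illustrative small activations over 2048 "neurons" (scale is an assumption)
x = 4e-4 * torch.randn(2048)
# (x - mean) / sqrt(var + eps) with biased variance and no weight/bias
x_hat = F.layer_norm(x, x.shape, eps=eps)


def outlier_score(v: torch.Tensor) -> float:
    return (v.max() - v.mean()).item()


ratio = outlier_score(x_hat) / outlier_score(x)
var = ((x - x.mean()) ** 2).mean().item()  # biased variance, as LayerNorm uses
predicted = 1 / (var + eps) ** 0.5
print(f"ratio={ratio:.3f}, predicted={predicted:.3f}, cap={1 / eps ** 0.5:.3f}")
```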