This example trains a classification probe to investigate the information present in the residual stream — specifically, what a single attention head writes into it.

# Setup

In [1]:
from transformer_lens import HookedTransformer


# Pythia 2.8B (deduplicated training data, v0) wrapped as a HookedTransformer, which lets later
# cells cache intermediate activations with run_with_cache.
PYTHIA_MODEL_NAME = "EleutherAI/pythia-2.8b-deduped-v0"
llm = HookedTransformer.from_pretrained(PYTHIA_MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model EleutherAI/pythia-2.8b-deduped-v0 into HookedTransformer


# Look for attention heads that move information from nouns to the word 'it'

In [2]:
# Candidate nouns; each becomes one probe class, and its position in this list is its class index.
nouns = [
    "cat", "dog", "car", "tree", "house", "book", "river", "mountain", "computer", "phone",
    "table", "chair", "window", "door", "city", "road", "flower", "bird", "fish", "apple",
    "banana", "train", "plane", "boat", "shoe", "shirt", "hat", "cup", "plate", "fork",
    "spoon", "knife", "bed", "pillow", "blanket", "clock", "watch", "bag", "box", "key",
    "pen", "pencil", "paper", "bottle", "glass", "lamp", "mirror", "painting", "camera", "television",
    "radio", "guitar", "piano", "drum", "violin", "ball", "bat", "glove", "bike", "bus",
    "truck", "bridge", "tower", "statue", "park", "garden", "forest", "desert", "island", "beach",
    "lake", "ocean", "cloud", "star", "moon", "sun", "planet", "ring", "necklace", "note",
    "wallet", "coin", "ticket", "passport", "map", "letter", "envelope", "stamp", "magazine", "bean",
    "calendar", "notebook", "folder", "file", "mouse", "board", "monitor",
]

In [3]:
# Each noun is one probe class, and later cells locate the noun as a single token
# (token_finder.find_first(noun)), so report any noun the tokenizer splits into several tokens.
# Count the tokens the noun adds to a fixed prefix rather than comparing against a hard-coded
# total: a hard-coded total is only correct if we know whether the tokenizer adds special tokens
# (such as BOS), and guessing wrong silently hides two-token nouns.
noun_check_prefix = "There is a"
prefix_token_count = len(llm.tokenizer.tokenize(noun_check_prefix, add_special_tokens=True))

for noun in nouns:
    tokens = llm.tokenizer.tokenize(f"{noun_check_prefix} {noun}", add_special_tokens=True)
    if len(tokens) - prefix_token_count > 1:
        print(noun)
        print(tokens)

In [4]:
# One template per line. "_noun_" is the placeholder each noun is substituted into; in every
# template the noun appears before the first "it", which refers back to it.
sentence_templates = """
I saw a _noun_, it was big.
Look at that _noun_, I saw it yesterday too.
A _noun_ is coming, it will be here soon.
After the _noun_ arrived, it looked at it.
""".strip().splitlines()

In [5]:
import re


def fill_template(template: str, noun: str) -> str:
    """Return `template` with its "_noun_" placeholder replaced by `noun`.

    When the noun starts with a vowel letter, an indefinite article directly before the
    placeholder ("a _noun_" / "A _noun_") becomes "an" / "An". The sentences then read
    "I saw an apple" rather than the ungrammatical "I saw a apple". The rule is spelling-based,
    which suits this noun list but would be wrong for words like "hour" or "unicorn".
    """
    # A tuple (not the string "aeiou") so an empty noun does not count as vowel-initial
    if noun[:1].lower() in ("a", "e", "i", "o", "u"):
        # \b ensures only a standalone article matches (e.g. not the "at" in "that _noun_")
        template = re.sub(r"\b([Aa]) _noun_", r"\1n _noun_", template)
    return template.replace("_noun_", noun)


# One entry per (noun, template) pair; "noun" is kept so later cells can locate it and label it
sentence_data = [
    {"sentence": fill_template(sentence, noun), "noun": noun} for noun in nouns for sentence in sentence_templates
]

print(f"{len(sentence_data)} sentences")

388 sentences


In [7]:
from tqdm.auto import tqdm
from llm_token_finder import ActivationAnalyzer, TokenFinder, TokenDisplayer
from llm_token_finder.activation_analyser import AttentionHead
import random


# How many sentences to search, and a fixed seed so that the sample (and hence the heads
# found) is the same every time the notebook is run
HEAD_SEARCH_SAMPLE_SIZE = 50
HEAD_SEARCH_SEED = 0

all_it_heads = []

# Get a reproducible random sample to find attention heads. A local Random instance neither
# depends on nor changes the global random state; min() avoids a ValueError if the dataset
# is ever smaller than the requested sample.
sample_rng = random.Random(HEAD_SEARCH_SEED)
sample_data = sample_rng.sample(sentence_data, min(HEAD_SEARCH_SAMPLE_SIZE, len(sentence_data)))

for item in tqdm(sample_data):
    token_finder = TokenFinder.create_from_tokenizer(item["sentence"], llm.tokenizer)
    noun_token = token_finder.find_first(item["noun"], allow_space_prefix=True)
    # In every template the first "it" is the one that refers back to the noun
    it_token = token_finder.find_first("it", allow_space_prefix=True)
    activation_analyzer = ActivationAnalyzer.from_forward_pass(llm, item["sentence"])
    # Heads the analyzer reports as having the 'it' query look at the noun token, for this sentence
    item_heads = activation_analyzer.find_heads_where_query_looks_at_value(it_token, noun_token, ignore_bos=True)
    all_it_heads.append(item_heads)

# Presumably keeps only the heads found for every sampled sentence
it_heads = AttentionHead.intersection(all_it_heads)

print(f"Found {len(it_heads)} attention heads moving information from noun tokens to 'it' tokens: {it_heads}")

100%|██████████| 50/50 [00:44<00:00,  1.13it/s]

Found 1 attention heads moving information from noun tokens to 'it' tokens: [5.0]





# Generate activation dataset for attention head output

We have found attention heads that move information from nouns to 'it'. It seems likely that these heads move information about the meaning of the noun, but the above doesn't confirm that — it only shows that they move *something*, which could be some other information. To investigate what is being moved, we will train a classifier on the output of one of these attention heads.

In [8]:
# Probe the first of the heads found above. Fail with a clear message rather than an IndexError
# if the search found none (e.g. with a different sample or model).
if not it_heads:
    raise RuntimeError("No attention heads found that move information from nouns to 'it'; cannot choose a head to probe.")
head = it_heads[0]

In [9]:
from typing import Generator
from activation_probing.activation_dataset_generator import ActivationDatasetGenerator, ActivationGeneratorInput


# The generator function below yields the ActivationGeneratorInput objects that are run through the
# LLM to build the head-output dataset: one per sentence, read at the first "it" token and labelled
# with the index of the sentence's noun in `nouns`.

def input_generator() -> Generator[ActivationGeneratorInput, None, None]:
    """Yield one ActivationGeneratorInput for each entry of sentence_data."""
    for entry in sentence_data:
        text = entry["sentence"]
        finder = TokenFinder.create_from_tokenizer(text, llm.tokenizer)
        first_it = finder.find_first("it", allow_space_prefix=True)
        position = first_it.index
        label = nouns.index(entry["noun"])
        yield ActivationGeneratorInput(text=text, token_position=position, label_class_index=label)

In [10]:
# Custom metadata that is saved with the dataset alongside the generator's own metadata
generator_meta_data = {"experiment": "noun-it test"}

# Generator for a dataset of `head`'s output on the inputs yielded by input_generator,
# with the nouns as class labels
activation_dataset_generator = ActivationDatasetGenerator.create_attention_head_output_generator(
    llm,
    input_generator,
    head=head,
    class_labels=nouns,
    meta_data=generator_meta_data,
)

In [11]:
from activation_probing.activation_dataset import ActivationDataset
import tempfile


# ActivationDatasetGenerator saves the dataset to a file, which can be loaded later. For this example, we'll just save and load immediately from a temporary file

with tempfile.TemporaryFile(mode="r+", encoding="utf-8") as file:
    activation_dataset_generator.generate_and_save_to(file)
    # After writing, the file position is presumably at the end of the data. Rewind explicitly so
    # loading reads from the start instead of relying on the loader (or writer) to seek.
    file.seek(0)
    activation_dataset = ActivationDataset.load_from_file(file, device=llm.cfg.device)

In [13]:
# ActivationDataset subclasses PyTorch's TensorDataset and also stores metadata about the data,
# including the custom meta_data that was passed to the generator
size_line = f"Size of activation dataset: {len(activation_dataset)}"
metadata_line = f"Metadata: {activation_dataset.meta_data}"
print(size_line, metadata_line, sep="\n")

Size of activation dataset: 388
Metadata: {'layer': 5, 'head': 0, 'class_labels': ['cat', 'dog', 'car', 'tree', 'house', 'book', 'river', 'mountain', 'computer', 'phone', 'table', 'chair', 'window', 'door', 'city', 'road', 'flower', 'bird', 'fish', 'apple', 'banana', 'train', 'plane', 'boat', 'shoe', 'shirt', 'hat', 'cup', 'plate', 'fork', 'spoon', 'knife', 'bed', 'pillow', 'blanket', 'clock', 'watch', 'bag', 'box', 'key', 'pen', 'pencil', 'paper', 'bottle', 'glass', 'lamp', 'mirror', 'painting', 'camera', 'television', 'radio', 'guitar', 'piano', 'drum', 'violin', 'ball', 'bat', 'glove', 'bike', 'bus', 'truck', 'bridge', 'tower', 'statue', 'park', 'garden', 'forest', 'desert', 'island', 'beach', 'lake', 'ocean', 'cloud', 'star', 'moon', 'sun', 'planet', 'ring', 'necklace', 'note', 'wallet', 'coin', 'ticket', 'passport', 'map', 'letter', 'envelope', 'stamp', 'magazine', 'bean', 'calendar', 'notebook', 'folder', 'file', 'mouse', 'board', 'monitor'], 'llm': 'pythia-2.8b-deduped-v0', 'tim

# Use head-output activation dataset to train a classification probe

In [14]:
# Train a probe on the head-output dataset and report the last recorded accuracies
probe, _, _, history = activation_dataset.train_probe(num_epochs=50, learning_rate=0.01, return_history=True, device="cpu")

# Pull the values out first: reusing double quotes inside a double-quoted f-string is a
# SyntaxError before Python 3.12 (PEP 701)
final_training_accuracy = history["training_accuracy"][-1]
final_testing_accuracy = history["testing_accuracy"][-1]
print(f"Probe training accuracy: {final_training_accuracy}; test accuracy {final_testing_accuracy}")

Probe training accuracy: 1.0; test accuracy 0.9358974358974359


The probe has high accuracy — chance level for 97 equally represented nouns is about 1% — meaning it has learnt to classify nouns from the output of this attention head **at the 'it' token position**.

In [19]:
# Show an example that isn't in the dataset

test_sentence = "There is a big, square box in the corner of the room; I wonder what is in it."

token_finder = TokenFinder.create_from_tokenizer(test_sentence, llm.tokenizer)

noun_token = token_finder.find_first("box", allow_space_prefix=True)
it_token = token_finder.find_first("it", allow_space_prefix=True)

hook_name = f"blocks.{head.layer}.attn.hook_result"

# TransformerLens only runs the per-head output hook (hook_result) when use_attn_result is
# enabled. Enable it explicitly rather than relying on an earlier cell (e.g. the dataset
# generator) having done so as a side effect.
llm.set_use_attn_result(True)

input_ids = llm.tokenizer.encode(test_sentence, add_special_tokens=True, return_tensors="pt")
_, cache = llm.run_with_cache(input_ids, names_filter=lambda name: name == hook_name)

In [28]:
import torch


hook_name = f"blocks.{head.layer}.attn.hook_result"
# Indexed [batch, position, head, ...]: the first (only) sentence, the 'it' position, the chosen head
head_output_at_it_position = cache[hook_name][0][it_token.index][head.head]

# The probe was trained with device="cpu", so move the activation there; it comes from the model's
# device and would otherwise cause a device mismatch when the model runs on a GPU.
# Gradients are not needed for inference.
with torch.no_grad():
    probe_logits = probe.forward(head_output_at_it_position.to("cpu"))
probe_predicted_index = int(torch.argmax(probe_logits).item())
probe_prediction = nouns[probe_predicted_index]

print(f"Probe prediction (should be 'box'): {probe_prediction}")

Probe prediction (should be 'box'): box
