In [1]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from transformer_lens.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm



In [2]:
# NBVAL_IGNORE_OUTPUT
# Load pretrained GPT-2 small with the four weight-processing flags below switched on.
# NOTE(review): the class is imported from `easy_transformer`, yet the cell output reports
# a HookedTransformer and every other helper comes from `transformer_lens` — confirm both
# names resolve to the same implementation.
model = EasyTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

import transformer_lens.utils as utils
# Get the default device used
# `device` is read as a notebook-level global by the `complete` helpers defined further down.
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# Sanity check on a single prompt: prints the tokenisation of prompt and answer followed
# by the model's top next-token predictions (see the output below).
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


In [4]:
# Show the module tree (submodules and hook points) of the loaded model.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [5]:
# Vocabulary as (token string, token id) pairs ordered by id, then three 20-entry
# windows of it, each followed by a blank line.
sorted_vocab = sorted(model.tokenizer.vocab.items(), key=lambda pair: pair[1])
for window_start in (0, 250, 990):
    print(sorted_vocab[window_start:window_start + 20])
    print()

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



In [6]:
# Tokenise a reference sentence and inspect the result three ways: raw ids, tensor shape,
# and the string form of every token.
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
# The output below starts with id 50256 ('<|endoftext|>'), i.e. a BOS token is prepended.
tokens = model.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(model.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]], device='mps:0')
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


In [7]:
import torch

def complete(reference_text, max_tokens=100, T=0.7):
    """Sample a continuation of `reference_text` from the notebook-level `model`.

    One token is sampled per forward pass from the temperature-scaled distribution at
    the last position and appended to the running sequence.

    Note: a later cell of this notebook defines another `complete` with a different
    signature, which replaces this one once that cell has run.

    Args:
        reference_text: prompt string, tokenised with `model.to_tokens`.
        max_tokens: number of tokens to sample.
        T: softmax temperature; the logits are divided by it, so it must be non-zero.

    Returns:
        (text, cache): the decoded sequence (prompt plus sampled tokens) and the
        activation cache of the final forward pass. `cache` is None when
        `max_tokens <= 0`, because no forward pass is run in that case.
    """
    tokens = model.to_tokens(reference_text)
    # Defined up front so the return statement is valid even if the loop never runs.
    cache = None
    for _ in range(max_tokens):
        tokens = tokens.to(device)
        logits, cache = model.run_with_cache(tokens)

        # Only the last position is sampled, so scale and normalise just that row
        # instead of running softmax over every position of the sequence.
        next_token_probs = torch.nn.functional.softmax(logits[0, -1] / T, dim=-1)

        # Sample from the probability distribution
        next_token = torch.multinomial(next_token_probs, num_samples=1)

        # Concatenate the new token to the existing sequence
        tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=-1)

    # Decode the tokens to text
    return model.tokenizer.decode(tokens[0]), cache

# Sample 20 tokens at temperature 0.5; `cache` is inspected in the next cell.
text, cache = complete(reference_text, max_tokens=20, T=0.5)

In [8]:
# Print name and shape of every cached activation, skipping the ones that belong to a
# block other than block 0.
for activation_name, activation in cache.cache_dict.items():
    if "blocks" in activation_name and ".0." not in activation_name:
        continue
    print(activation_name, activation.shape)

hook_embed torch.Size([1, 54, 768])
hook_pos_embed torch.Size([1, 54, 768])
blocks.0.hook_resid_pre torch.Size([1, 54, 768])
blocks.0.ln1.hook_scale torch.Size([1, 54, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 54, 768])
blocks.0.attn.hook_q torch.Size([1, 54, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 54, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 54, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 54, 54])
blocks.0.attn.hook_pattern torch.Size([1, 12, 54, 54])
blocks.0.attn.hook_z torch.Size([1, 54, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 54, 768])
blocks.0.hook_resid_mid torch.Size([1, 54, 768])
blocks.0.ln2.hook_scale torch.Size([1, 54, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 54, 768])
blocks.0.mlp.hook_pre torch.Size([1, 54, 3072])
blocks.0.mlp.hook_post torch.Size([1, 54, 3072])
blocks.0.hook_mlp_out torch.Size([1, 54, 768])
blocks.0.hook_resid_post torch.Size([1, 54, 768])
ln_final.hook_scale torch.Size([1, 54, 1])
ln_final.hook_normalized torc

In [9]:
# Print name and shape of every parameter, skipping the ones that belong to a block
# other than block 0.
for name, param in model.named_parameters():
    if "blocks" in name and ".0." not in name:
        continue
    print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


In [10]:
# Second name for the same pretrained model object; later cells use it as the reference
# implementation (tokeniser and weights).
reference_gpt2 = model

# As a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': False,
 'use_attn_result': Fals

In [11]:

# Hyperparameters for the from-scratch transformer defined in the following cell.
# The size defaults match the values printed for the reference config above.
@dataclass
class Config:
    d_model: int = 768  # width of the residual stream
    debug: bool = True  # when True, modules print tensor shapes in forward()
    layer_norm_eps: float = 1e-5  # added to the variance inside LayerNorm
    d_vocab: int = 50257  # number of token ids
    init_range: float = 0.02  # std of the normal distribution used to initialise weights
    n_ctx: int = 1024  # number of rows in the positional embedding table
    d_head: int = 64  # per-head query/key/value width
    d_mlp: int = 3072  # hidden width of the MLP
    n_heads: int = 12  # attention heads per block
    n_layers: int = 12  # number of transformer blocks

# Notebook-level default config (debug printing enabled).
cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [14]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and shift.

    Args:
        cfg: configuration object; `d_model`, `layer_norm_eps` and `debug` are read.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        """Normalise `residual` along its last axis.

        Args:
            residual: tensor whose last axis has size d_model, e.g.
                [batch, position, d_model].

        Returns:
            Tensor of the same shape: centred, divided by sqrt(variance + eps), then
            multiplied by `w` and shifted by `b`.
        """
        # residual: [batch, position, d_model]
        if self.cfg.debug: print("Residual:", residual.shape)
        centered = residual - residual.mean(dim=-1, keepdim=True)
        # Calculate the variance, square root it. Add in an epsilon to prevent divide by zero.
        # The epsilon comes from this module's own config, not from a notebook-level
        # `cfg` that may hold different values (or not exist at all).
        scale = (centered.pow(2).mean(dim=-1, keepdim=True) + self.cfg.layer_norm_eps).sqrt()
        normalized = centered / scale
        normalized = normalized * self.w + self.b
        if self.cfg.debug: print("Normalized:", normalized.shape)
        return normalized

class Embed(nn.Module):
    """Token embedding: a [d_vocab, d_model] table indexed by token id.

    Args:
        cfg: configuration object; `d_vocab`, `d_model`, `init_range` and `debug`
            are read.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        embedding_table = torch.empty((cfg.d_vocab, cfg.d_model))
        self.W_E = nn.Parameter(embedding_table)
        # Weights are drawn from a zero-mean normal with std `init_range`.
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        """Return the rows of `W_E` selected by `tokens`.

        Args:
            tokens: integer tensor of token ids, [batch, position].

        Returns:
            Tensor [batch, position, d_model].
        """
        verbose = self.cfg.debug
        if verbose:
            print("Tokens:", tokens.shape)
        looked_up = self.W_E[tokens]  # [batch, position, d_model]
        if verbose:
            print("Embeddings:", looked_up.shape)
        return looked_up

class PosEmbed(nn.Module):
    """Learned positional embedding: a [n_ctx, d_model] table indexed by position.

    Args:
        cfg: configuration object; `n_ctx`, `d_model`, `init_range` and `debug`
            are read.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        position_table = torch.empty((cfg.n_ctx, cfg.d_model))
        self.W_pos = nn.Parameter(position_table)
        # Weights are drawn from a zero-mean normal with std `init_range`.
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        """Return the positional embeddings for every position of `tokens`.

        Only the shape of `tokens` ([batch, position]) is used; the token ids
        themselves are not read.

        Returns:
            Tensor [batch, position, d_model]: the first `position` rows of `W_pos`,
            repeated along a new leading batch axis.
        """
        verbose = self.cfg.debug
        if verbose:
            print("Tokens:", tokens.shape)
        n_batch, n_positions = tokens.size(0), tokens.size(1)
        rows = self.W_pos[:n_positions, :]  # [position, d_model]
        pos_embed = einops.repeat(rows, "position d_model -> batch position d_model", batch=n_batch)
        if verbose:
            print("pos_embed:", pos_embed.shape)
        return pos_embed

class Attention(nn.Module):
    """Multi-head causal self-attention with per-head weight tensors.

    Parameters are stored per head: W_Q / W_K / W_V are [n_heads, d_model, d_head],
    W_O is [n_heads, d_head, d_model]; the biases start at zero.

    Args:
        cfg: configuration object; `n_heads`, `d_model`, `d_head`, `init_range` and
            `debug` are read.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        # Fill value for masked (future) positions. It is registered without an explicit
        # device so that building the module does not depend on a notebook-level
        # `device` variable; as a buffer it follows the module on `.to(...)`.
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        """Apply causal self-attention.

        Args:
            normalized_resid_pre: tensor [batch, position, d_model].

        Returns:
            Tensor [batch, position, d_model]: per-head outputs projected through W_O,
            summed over heads, plus b_O.
        """
        # normalized_resid_pre: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_pre:", normalized_resid_pre.shape)

        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K

        # Scaled dot-product scores; future positions are masked before the softmax.
        attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_mask(attn_scores)

        pattern = attn_scores.softmax(dim=-1) # [batch, n_head, query_pos, key_pos]

        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V

        # Attention-weighted average of the values for every query position and head.
        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", pattern, v)

        attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores):
        """Overwrite scores of future keys (key_pos > query_pos) with `IGNORE`.

        Args:
            attn_scores: tensor [batch, n_heads, query_pos, key_pos]; modified in place.

        Returns:
            The same tensor object, after masking.
        """
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        query_len, key_len = attn_scores.size(-2), attn_scores.size(-1)
        # True strictly above the diagonal, i.e. wherever the key lies in the future.
        mask = torch.triu(torch.ones(query_len, key_len, device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

class MLP(nn.Module):
    """Two-layer feed-forward network: d_model -> d_mlp -> d_model with `gelu_new`.

    Args:
        cfg: configuration object; `d_model`, `d_mlp`, `init_range` and `debug`
            are read.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        init_std = self.cfg.init_range
        # Input projection and bias.
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=init_std)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        # Output projection and bias.
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=init_std)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        """Apply the feed-forward network position-wise.

        Args:
            normalized_resid_mid: tensor [batch, position, d_model].

        Returns:
            Tensor [batch, position, d_model].
        """
        if self.cfg.debug:
            print("Normalized_resid_mid:", normalized_resid_mid.shape)
        hidden_pre = einsum(
            "batch position d_model, d_model d_mlp -> batch position d_mlp",
            normalized_resid_mid,
            self.W_in,
        ) + self.b_in
        hidden_post = gelu_new(hidden_pre)
        return einsum(
            "batch position d_mlp, d_mlp d_model -> batch position d_model",
            hidden_post,
            self.W_out,
        ) + self.b_out
    
class TransformerBlock(nn.Module):
    """One pre-norm transformer block: LayerNorm -> attention, then LayerNorm -> MLP,
    each added back onto the residual stream.

    Args:
        cfg: configuration object handed to every submodule.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Attention half of the block.
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        # MLP half of the block.
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        """Run the block on the residual stream.

        Args:
            resid_pre: tensor [batch, position, d_model].

        Returns:
            Tensor [batch, position, d_model]: the residual stream after both the
            attention and the MLP contributions have been added.
        """
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        return resid_mid + self.mlp(self.ln2(resid_mid))

class Unembed(nn.Module):
    """Maps the final residual stream to vocabulary logits.

    Args:
        cfg: configuration object; `d_model`, `d_vocab`, `init_range` and `debug`
            are read.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # `requires_grad=False` has to be given to nn.Parameter: nn.Parameter defaults
        # to requires_grad=True regardless of the flag on the wrapped tensor, so passing
        # it to torch.zeros left the bias trainable.
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab)), requires_grad=False)

    def forward(self, normalized_resid_final):
        """Project onto the vocabulary.

        Args:
            normalized_resid_final: tensor [batch, position, d_model].

        Returns:
            Logits tensor [batch, position, d_vocab].
        """
        # normalized_resid_final [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
        return logits


class DemoTransformer(nn.Module):
    """GPT-2 style decoder-only transformer assembled from the modules above.

    Args:
        cfg: configuration object handed to every submodule; `n_layers` sets the
            number of blocks.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        layers = []
        for _ in range(cfg.n_layers):
            layers.append(TransformerBlock(cfg))
        self.blocks = nn.ModuleList(layers)
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        """Compute logits for every position.

        Args:
            tokens: integer tensor [batch, position].

        Returns:
            Logits tensor [batch, position, d_vocab].
        """
        # Token and positional embeddings are summed to start the residual stream.
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for layer in self.blocks:
            residual = layer(residual)
        return self.unembed(self.ln_final(residual))



In [15]:
# Build the from-scratch model with debug printing switched off.
demo_gpt2 = DemoTransformer(Config(debug=False))
# strict=False because the two state dicts do not share an identical key set: the
# reference parameter listing printed earlier contains no LayerNorm weights, so the
# ln*.w / ln*.b parameters here keep their initial ones / zeros.
# NOTE(review): the missing/unexpected-key report returned by load_state_dict is
# discarded, so a genuinely misnamed weight would go unnoticed.
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
demo_gpt2.to(device)

# Expect one block per configured layer.
len(demo_gpt2.blocks)

12

In [20]:

def permuted(self, permutation):
    """Build a forward function that runs the blocks of `self` in a chosen order.

    Args:
        self: a model exposing `embed`, `pos_embed`, `blocks`, `ln_final` and
            `unembed`.
        permutation: iterable of indices into `self.blocks`, visited in order on every
            call. Indices may be repeated or left out; an empty iterable skips all
            blocks.

    Returns:
        A callable taking `tokens` ([batch, position]) and returning the logits.
    """
    def run_blocks_in_order(tokens):
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for layer in (self.blocks[idx] for idx in permutation):
            residual = layer(residual)
        # logits have shape [batch, position, logits]
        return self.unembed(self.ln_final(residual))
    return run_blocks_in_order


def complete(model, reference_text, max_tokens=100, T=1e-3):
    """Sample a continuation of `reference_text` from an arbitrary logits function.

    Tokenisation and decoding use the notebook-level `reference_gpt2`; `model` only
    has to map a token tensor [batch, position] to logits
    [batch, position, d_vocab].

    Args:
        model: callable returning logits for a token tensor.
        reference_text: prompt string.
        max_tokens: number of tokens to sample.
        T: softmax temperature; the logits are divided by it, so it must be non-zero.
            The very small default makes sampling close to greedy decoding.

    Returns:
        (text, None): the decoded sequence (prompt plus sampled tokens). The second
        element is kept so that callers can still unpack two values; this function
        records no activation cache. It previously returned a `cache` name that is
        never assigned in this function, i.e. whatever notebook-level `cache` happened
        to exist, or a NameError on a fresh kernel.
    """
    tokens = reference_gpt2.to_tokens(reference_text)
    # Only decoded text leaves this function, so no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max_tokens):
            tokens = tokens.to(device)
            logits = model(tokens)

            # Only the last position is sampled, so scale and normalise just that row.
            next_token_probs = torch.nn.functional.softmax(logits[0, -1] / T, dim=-1)

            # Sample from the probability distribution
            next_token = torch.multinomial(next_token_probs, num_samples=1)

            # Concatenate the new token to the existing sequence
            tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=-1)

    # Decode the tokens to text
    return reference_gpt2.tokenizer.decode(tokens[0]), None

# Sanity check: running all 12 blocks in their natural order through `permuted` should
# reproduce the plain model; the two printed completions below are identical.
permuted_model = permuted(demo_gpt2, list(range(12)))
text, _ = complete(demo_gpt2, reference_text)
print(text)
text, _ = complete(permuted_model, reference_text)
print(text)

<|endoftext|>The distance between the Colosseum and the Eiffel is approximately 1,000 miles. The distance between the Eiffel and the Colosseum is approximately 1,000 miles.

The distance between the Colosseum and the Eiffel is approximately 1,000 miles.

The distance between the Colosseum and the Eiffel is approximately 1,000 miles.

The distance between the Colosseum and the Eiffel is approximately 1,000 miles.

The distance between the Colosse
<|endoftext|>The distance between the Colosseum and the Eiffel is approximately 1,000 miles. The distance between the Eiffel and the Colosseum is approximately 1,000 miles.

The distance between the Colosseum and the Eiffel is approximately 1,000 miles.

The distance between the Colosseum and the Eiffel is approximately 1,000 miles.

The distance between the Colosseum and the Eiffel is approximately 1,000 miles.

The distance between the Colosse


In [22]:
# Complete the reference prompt with models that keep only the first i blocks.
# range(1, 12) covers i = 1..11, so the full 12-block model is not part of this sweep.
for i in range(1, 12):
    permuted_model = permuted(demo_gpt2, list(range(i)))
    text, _ = complete(permuted_model, reference_text)
    print(f"Layer {i}: {text}")

Layer 1: <|endoftext|>The distance between the Colosseum and the Eiffel is approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately a

Observations (n=1):
- layer 1 only repeats the token
- layer 3 is the first to bring up a number-ish completion (20th century)
- layer 4 is the first to bring up 'distance' related words
- layer 7 is the first to bring up an actual distance (200 miles)
- layer 8 is the first to produce long grammatically correct phrases / sentences

The distance between Rome and Paris is approximately
--------------------------------------------------------------------------------
1 + 3 = 4 
4 + 5 = 9
2 + 3 =
--------------------------------------------------------------------------------
[{"name": "James", "age": 34, "skills": ["Python", "git"]}, {"name": "Alan", "age": 28, "skills": ["MS Office", "
--------------------------------------------------------------------------------
Obama was elected in the year
--------------------------------------------------------------------------------


In [28]:
# Complete every prompt with models that keep only the first i blocks (i = 1..12).
# NOTE(review): `texts` is assigned in a cell that appears further down this notebook, so
# this cell fails on a fresh top-to-bottom run; the loop variable also overwrites the
# notebook-level `reference_text`.
for reference_text in texts:
    for i in range(1, 13):
        permuted_model = permuted(demo_gpt2, list(range(i)))
        text, _ = complete(permuted_model, reference_text)
        print(f"Layer {i}: {text}")

Layer 1: <|endoftext|>The distance between Rome and Paris is approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately a

# Basic questions
- what happens if we ablate the last n layers?
  - [ ] systematic evaluation
- what happens if we ablate the first n layers?

# Evaluation
- if we ablate something we want to save this info for later
  - (model, prompt, deleted compute nodes, logit diffs)


# Synthetic Dataset
- we want to quickly form hypotheses of the form: "(circuit) is important for (task)"
- we start with circuit = (layer), transcribe the logit diffs
- we let gpt-4 guess:
  - what is a task in which this circuit is relevant
  - what other layers might it interact strongly with?
- generate a task
- 

# Inspecting layers
- which token logits differ the most if we ablate layer i?
- for each layer, we can collect a dataset:
  - (prompt, logit_diffs: token -> float)


# Random thoughts
- we potentially don't want to look at single layers / attention heads, but ablate arbitrary parts (spanning multiple layers / only parts of layers)
-

In [44]:

import os
import pickle
import glob

class ExperimentData():
    """In-memory list of experiment records (dicts) with versioned pickle snapshots.

    Snapshots live in `save_dir` as `<version>.pkl` files (written zero-padded to five
    digits). On construction the snapshot with the highest version number is loaded;
    `to_disk` writes the current list as the next version.

    Args:
        save_dir: directory for the snapshots; created if it does not exist.
    """
    def __init__(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.items = self._load()

    def _snapshots(self):
        """Return {version: path} for every `<integer>.pkl` file in `save_dir`."""
        snapshots = {}
        pattern = os.path.join(glob.escape(self.save_dir), '*.pkl')
        for path in glob.glob(pattern):
            stem = os.path.basename(path).split('.')[0]
            try:
                snapshots[int(stem)] = path
            except ValueError:
                # Not one of our snapshots (non-numeric name); leave it alone.
                continue
        return snapshots

    def _load(self):
        """Load the newest snapshot, or return an empty list when there is none."""
        snapshots = self._snapshots()
        if not snapshots:
            return []
        latest_path = snapshots[max(snapshots)]
        # NOTE: pickle.load can execute arbitrary code; only point `save_dir` at
        # directories whose contents you trust.
        with open(latest_path, 'rb') as f:
            return pickle.load(f)

    def __call__(self, item):
        """Append one record (a dict) to the in-memory list."""
        self.items += [item]

    def filter(self, **filters):
        """Return the records matching every keyword filter.

        Each keyword names a record key; its value is the condition for that key:
        - a plain value matches when it equals the stored value,
        - a callable is called with the stored value and matches when the result is
          truthy,
        - a dict is applied recursively to a stored dict (only the listed keys are
          checked).
        A key missing from the record is treated as the stored value None.
        """
        return [item for item in self.items if self._check_filter(item, filters)]

    def _check_filter(self, item, filters):
        """True when `item` (a dict) satisfies every condition in `filters`."""
        for key, condition in filters.items():
            if not self._matches(item.get(key), condition):
                return False
        return True

    def _matches(self, value, condition):
        """Check a single stored value against a single condition."""
        if value == condition:
            return True
        # Predicate conditions: any callable, not only plain functions or methods.
        if callable(condition):
            return bool(condition(value))
        # Nested conditions: recurse with the stored dict as the item and the condition
        # dict as its filters.
        if isinstance(condition, dict) and isinstance(value, dict):
            return self._check_filter(value, condition)
        return False

    def to_disk(self):
        """Write the current records as the next snapshot version."""
        snapshots = self._snapshots()
        next_version = max(snapshots) + 1 if snapshots else 0
        # Save the new version
        with open(os.path.join(self.save_dir, f'{next_version:05d}.pkl'), 'wb') as f:
            pickle.dump(self.items, f)


# Snapshot directory: taken from the EXPERIMENT_DIR environment variable when it is set;
# the fallback is the original hard-coded location, so existing snapshots are still found.
store = ExperimentData(os.environ.get(
    'EXPERIMENT_DIR',
    '/Users/nielswarncke/Documents/code/TransformerLens/experiments',
))

# Smoke test of add + filter. Note that these two records are appended to `store.items`
# every time the cell runs, on top of whatever was loaded from disk.
store({
    'model': 'test-model',
    'prompt': 'hello world',
    'completion': 'yo yo yo'
})

store({
    'model': 'test-model',
    'prompt': 'hello',
    'completion': 'yo yo yo'
})

store.filter(model='test-model', prompt='hello')

[{'model': 'test-model', 'prompt': 'hello', 'completion': 'yo yo yo'}]

In [37]:
# Check the type reported for a bound method. ExperimentData as defined in this notebook
# has no `add` attribute, so an existing method is inspected instead.
type(store.filter)

method

In [38]:
# Check the type reported for a plain lambda.
type(lambda i: i)

function

In [45]:
import json
# `reference_text` is reassigned here but not used by the loop below.
reference_text = "The distance between the Colosseum and the Eiffel is approximately"
# Prompts for the sweep: free text, arithmetic, a JSON prefix, and a factual question.
texts = [
    "The distance between Rome and Paris is approximately",
    "1 + 3 = 4 \n4 + 5 = 9\n2 + 3 =",
    # Serialise a list containing a '<cut>' marker and keep only what precedes the marker,
    # which yields a JSON prefix that stops inside the last string.
    json.dumps([{'name': 'James', 'age': 34, 'skills': ['Python', 'git']}, {'name': 'Alan', 'age': 28, 'skills': ['MS Office', '<cut>']}]).split('<cut>')[0],
    "Obama was elected in the year",
]
# For every prompt, complete with models that keep only the first i blocks. i = 0 keeps no
# block at all (embeddings -> final LayerNorm -> unembed); i = 12 is the full model.
for prompt in texts:
    for i in range(0, 13):
        keep_layer_ids =  list(range(i))
        permuted_model = permuted(demo_gpt2, keep_layer_ids)
        text, _ = complete(permuted_model, prompt)
        print(f"Layer {i}: {text}")
        # Record the completion together with the block ids that were kept.
        # NOTE(review): re-running this cell appends the same records again.
        store({
            'model': 'gpt-2-small',
            'prompt': prompt,
            'completion': text,
            'ablation': {
                'layers': keep_layer_ids
            }
        })

Layer 0: <|endoftext|>The distance between Rome and Paris is approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately approximately a

In [None]:
# Same sweep as the previous cell, but starting at i = 1, so the variant without any block
# is skipped.
# NOTE(review): the loop body is a copy of the previous cell's; records for i = 1..12 are
# appended to `store` a second time if both cells are run.
for prompt in texts:
    for i in range(1, 13):
        keep_layer_ids =  list(range(0, i))
        permuted_model = permuted(demo_gpt2, keep_layer_ids)
        text, _ = complete(permuted_model, prompt)
        print(f"Layer {i}: {text}")
        # Record the completion together with the block ids that were kept.
        store({
            'model': 'gpt-2-small',
            'prompt': prompt,
            'completion': text,
            'ablation': {
                'layers': keep_layer_ids
            }
        })

# Looking at logit changes
We now want to generate objects like this one:
```
{
    'model': 'gpt-2-small',
    'prompt': prompt,
    'token': pos_id,
    'logit_diffs': [
        {' hello': -10},
        ...
    ],
    'ablation': ablation
}
```

In [None]:
import json
# NOTE(review): this cell is a verbatim copy of the earlier completion sweep. It stores
# completions only; the 'token' and 'logit_diffs' fields sketched in the notes above are
# not produced yet.
reference_text = "The distance between the Colosseum and the Eiffel is approximately"
# Prompts for the sweep: free text, arithmetic, a JSON prefix, and a factual question.
texts = [
    "The distance between Rome and Paris is approximately",
    "1 + 3 = 4 \n4 + 5 = 9\n2 + 3 =",
    # Serialise a list containing a '<cut>' marker and keep only what precedes the marker.
    json.dumps([{'name': 'James', 'age': 34, 'skills': ['Python', 'git']}, {'name': 'Alan', 'age': 28, 'skills': ['MS Office', '<cut>']}]).split('<cut>')[0],
    "Obama was elected in the year",
]
# For every prompt, complete with models that keep only the first i blocks (i = 0..12).
for prompt in texts:
    for i in range(0, 13):
        keep_layer_ids =  list(range(i))
        permuted_model = permuted(demo_gpt2, keep_layer_ids)
        text, _ = complete(permuted_model, prompt)
        print(f"Layer {i}: {text}")
        # Record the completion together with the block ids that were kept.
        store({
            'model': 'gpt-2-small',
            'prompt': prompt,
            'completion': text,
            'ablation': {
                'layers': keep_layer_ids
            }
        })