# Pairing with Keith on tying specific features to a behavior/ability

1. Find a task that gelu-2l can do.
2. Identify features on that task. 
3. Understand them / prove with causal intervention. 

In [1]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch as t
import numpy as np
import gradio as gr

import torch as t
#from google.colab import drive

# This will prompt for authorization.
#drive.mount('/content/drive')

import einops
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import tqdm
from functools import partial
from datasets import load_dataset
from IPython.display import display

In [2]:
# %%
import os
os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
# %%
from neel.imports import *
from neel_plotly import *

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(False)

model = HookedTransformer.from_pretrained("gelu-2l")

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab
# %%
evals.sanity_check(model)

In IPython
In IPython
Set autoreload
Imported everything!
Loaded pretrained model gelu-2l into HookedTransformer


tensor(3.9483, device='cuda:0')

In [18]:
from transformer_lens.utils import test_prompt


prompt = "I love my cat. I love my cat. I love my "
answer = " cat"
utils.test_prompt(prompt, answer,model) # Pass, 0


Tokenized prompt: ['<|BOS|>', 'I', ' love', ' my', ' cat', '.', ' I', ' love', ' my', ' cat', '.', ' I', ' love', ' my', ' ']
Tokenized answer: [' cat']


Top 0th token. Logit: 13.91 Prob: 14.68% Token: |icky|
Top 1th token. Logit: 13.51 Prob:  9.85% Token: |iced|
Top 2th token. Logit: 12.91 Prob:  5.39% Token: |icing|
Top 3th token. Logit: 12.15 Prob:  2.53% Token: |xt|
Top 4th token. Logit: 11.96 Prob:  2.10% Token: |ump|
Top 5th token. Logit: 11.85 Prob:  1.88% Token: |ire|
Top 6th token. Logit: 11.83 Prob:  1.85% Token: |ute|
Top 7th token. Logit: 11.51 Prob:  1.34% Token: |irc|
Top 8th token. Logit: 11.46 Prob:  1.27% Token: |ia|
Top 9th token. Logit: 11.24 Prob:  1.02% Token: |urns|


In [23]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# import circuitsvis as cv

def current_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be current-token heads
    '''
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of diagonal elements
            score = attention_pattern.diagonal().mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads

def prev_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be prev-token heads
    '''
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of sub-diagonal elements
            score = attention_pattern.diagonal(-1).mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads

def first_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be first-token heads
    '''
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of 0th elements
            score = attention_pattern[:, 0].mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads


def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> Int[Tensor, "batch full_seq_len"]:
    '''
    Generates a sequence of repeated random tokens

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    prefix = (t.ones(batch, 1) * model.tokenizer.bos_token_id).long()
    # SOLUTION
    rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)
    rep_tokens = t.cat([prefix, rep_tokens_half, rep_tokens_half], dim=-1).to(device)
    return rep_tokens

def run_and_cache_model_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> Tuple[t.Tensor, t.Tensor, ActivationCache]:
    '''
    Generates a sequence of repeated random tokens, and runs the model on it, returning logits, tokens and cache

    Should use the `generate_repeated_tokens` function above

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    '''
    # SOLUTION
    rep_tokens = generate_repeated_tokens(model, seq_len, batch)
    rep_logits, rep_cache = model.run_with_cache(rep_tokens)
    return rep_tokens, rep_logits, rep_cache

In [32]:
device = "cuda"
def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> Int[Tensor, "batch full_seq_len"]:
    '''
    Generates a sequence of repeated random tokens

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    prefix = (t.ones(batch, 1) * model.tokenizer.bos_token_id).long()
    # SOLUTION
    rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)
    rep_tokens = t.cat([prefix, rep_tokens_half, rep_tokens_half], dim=-1).to(device)
    return rep_tokens

def run_and_cache_model_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> Tuple[t.Tensor, t.Tensor, ActivationCache]:
    '''
    Generates a sequence of repeated random tokens, and runs the model on it, returning logits, tokens and cache

    Should use the `generate_repeated_tokens` function above

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    '''
    # SOLUTION
    rep_tokens = generate_repeated_tokens(model, seq_len, batch)
    rep_logits, rep_cache = model.run_with_cache(rep_tokens)
    return rep_tokens, rep_logits, rep_cache

seq_len = 128
batch = 32


rep_tokens = generate_repeated_tokens(model, seq_len, batch)
logits, cache = model.run_with_cache(rep_tokens)
per_token_loss = model.loss_fn(logits, rep_tokens, True)
# per_token_loss.shape
px.line(per_token_loss.mean(0).detach().cpu().numpy(), title="Per-token loss")

In [54]:
px.imshow(cache["pattern", 1][0,0].detach().cpu())#.shape

In [57]:
def induction_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be induction heads

    Remember - the tokens used to generate rep_cache are (bos_token, *rand_tokens, *rand_tokens)
    '''
    # SOLUTION
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of (-seq_len+1)-offset elements
            seq_len = (attention_pattern.shape[-1] - 1) // 2
            score = attention_pattern.diagonal(-seq_len+1).sum().mean()
            print(score)
            if score > 5:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads


seq_len = 128
batch = 1

rep_tokens = generate_repeated_tokens(model, seq_len, batch)
logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
per_token_loss = model.loss_fn(logits, rep_tokens, True)
# per_token_loss.shape
# px.line(per_token_loss.mean(0).detach().cpu().numpy(), title="Per-token loss")

induction_attn_detector(cache)

tensor(0.3096, device='cuda:0')
tensor(0.7334, device='cuda:0')
tensor(0.2204, device='cuda:0')
tensor(0.0723, device='cuda:0')
tensor(0.6952, device='cuda:0')
tensor(0.0286, device='cuda:0')
tensor(0.4088, device='cuda:0')
tensor(0.9894, device='cuda:0')
tensor(0.0522, device='cuda:0')
tensor(1.1419, device='cuda:0')
tensor(0.0727, device='cuda:0')
tensor(0.3593, device='cuda:0')
tensor(37.0073, device='cuda:0')
tensor(0.8489, device='cuda:0')
tensor(96.2563, device='cuda:0')
tensor(0.3753, device='cuda:0')


['1.4', '1.6']

In [79]:
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])
    
    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss
    
    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed
    
    def get_version(self):
        version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)
    
    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads the saved autoencoder from HuggingFace. 
        
        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47
        
        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self
encoder0 = AutoEncoder.load_from_hf("gelu-2l_L0_16384_mlp_out_51", "cuda")
encoder1 = AutoEncoder.load_from_hf("gelu-2l_L1_16384_mlp_out_50", "cuda")


gelu-2l_L0_16384_mlp_out_51_cfg.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

{'act_name': 'blocks.0.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'device': 'cuda',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 0,
 'lr': 0.0001,
 'model_batch_size': 512,
 'model_name': 'gelu-2l',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 51,
 'seq_len': 128,
 'site': 'mlp_out'}


gelu-2l_L0_16384_mlp_out_51.pt:   0%|          | 0.00/67.2M [00:00<?, ?B/s]

{'act_name': 'blocks.1.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'device': 'cuda',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 1,
 'lr': 0.0001,
 'model_batch_size': 512,
 'model_name': 'gelu-2l',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 50,
 'seq_len': 128,
 'site': 'mlp_out'}


In [65]:

seq_len = 128
batch = 32

rep_tokens = generate_repeated_tokens(model, seq_len, batch)
logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
per_token_loss = model.loss_fn(logits, rep_tokens, True)
original_mlp_out = cache["mlp_out", 1]
loss, reconstr_mlp_out, hidden_acts, l2_loss, l1_loss = encoder1(original_mlp_out)




In [68]:
from torch.nn.functional import mse_loss

mse_loss(original_mlp_out[:128], reconstr_mlp_out[:128])

tensor(0.1503, device='cuda:0')

In [69]:
mse_loss(original_mlp_out[128:], reconstr_mlp_out[128:])

tensor(0.7503, device='cuda:0')

In [71]:


def reconstr_hook(mlp_out, hook, new_mlp_out):
    return new_mlp_out
def zero_abl_hook(mlp_out, hook):
    return torch.zeros_like(mlp_out)
print("reconstr", model.run_with_hooks(rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", 1), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out))], return_type="loss"))
print("Orig", model(rep_tokens, return_type="loss"))
print("Zero", model.run_with_hooks(rep_tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("mlp_out", 1), zero_abl_hook)]))

reconstr tensor(9.2783, device='cuda:0')
Orig tensor(7.7070, device='cuda:0')
Zero tensor(9.8091, device='cuda:0')


In [73]:
logits = model.run_with_hooks(
    rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", 1), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out))])
# per_token_loss = model.loss_fn(logits, example_tokens, True)
per_token_loss = model.loss_fn(logits, rep_tokens, True)
per_token_loss.shape

torch.Size([32, 256])

In [100]:
logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
per_token_loss = model.loss_fn(logits, rep_tokens, True)

layer = 0
original_mlp_out = cache["mlp_out", layer]
loss, reconstr_mlp_out_0, hidden_acts, l2_loss, l1_loss = encoder0(original_mlp_out)

def repeat_hook(mlp_out, hook):
    mlp_out_half = mlp_out[:, 1:129]
    print(mlp_out_half.shape)
    return t.cat([mlp_out[:,:129], mlp_out_half], dim=-1).to(device)

def reconstr_hook(mlp_out, hook, new_mlp_out):
    return new_mlp_out

def zero_abl_hook(mlp_out, hook):
    return torch.zeros_like(mlp_out)

logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
per_token_loss_original = model.loss_fn(logits, rep_tokens, True).mean(0).detach().cpu().numpy()

logits = model.run_with_hooks(
    rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", layer), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out_0))])
# per_token_loss = model.loss_fn(logits, example_tokens, True)
per_token_loss = model.loss_fn(logits, rep_tokens, True)
per_token_loss_with_reconstruction = per_token_loss.mean(0).detach().cpu().numpy()
# px.line(per_token_loss.mean(0).detach().cpu().numpy())

logits = model.run_with_hooks(rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", layer), zero_abl_hook)])
per_token_loss = model.loss_fn(logits, rep_tokens, True)
per_token_loss_zero_ablation = per_token_loss.mean(0).detach().cpu().numpy()

logits = model.run_with_hooks(rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", layer), repeat_hook)])
per_token_loss = model.loss_fn(logits, rep_tokens, True)
per_token_loss_replacement = per_token_loss.mean(0).detach().cpu().numpy()

df = pd.DataFrame({
    "per_token_loss_original": per_token_loss_original,
    "per_token_loss_with_reconstruction": per_token_loss_with_reconstruction,
    "per_token_loss_zero_ablation": per_token_loss_zero_ablation,
    "per_token_loss_replacement"   : per_token_loss_replacement
})

px.line(df, title="Per-token loss with and without reconstruction")

torch.Size([32, 128, 512])


In [90]:
logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
per_token_loss = model.loss_fn(logits, rep_tokens, True)

layer = 0
original_mlp_out = cache["mlp_out", layer]
loss, reconstr_mlp_out_0, hidden_acts, l2_loss, l1_loss = encoder0(original_mlp_out)


first_half = original_mlp_out[1:129]
second_half = original_mlp_out[129:]
similarity = torch.cosine_similarity(first_half, second_half, dim=-1)
px.histogram(similarity.detach().cpu().numpy(), title="Cosine similarity between first and second half of original mlp_out").show()

first_half = reconstr_mlp_out_0[1:129]
second_half = reconstr_mlp_out_0[129:]
similarity = torch.cosine_similarity(first_half, second_half, dim=-1)
px.histogram(similarity.detach().cpu().numpy(), title="Cosine similarity between first and second half of reconstructed mlp_out").show()

In [89]:
px.imshow(cache["pattern",0].detach().cpu(), facet_col=0)

In [None]:
# cache["z"]

# cache all the L0H1 outs, pass them through ML0 and cache that. 
# compare this to the differences between the ground truth MLP0 Out and the Reconstruction. 
# hypothesis that the difference matches the patter of the L0H1. 


In [78]:
# logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
# per_token_loss = model.loss_fn(logits, rep_tokens, True)

# layer = 1
# original_mlp_out = cache["mlp_out", layer]
# loss, reconstr_mlp_out_1, hidden_acts, l2_loss, l1_loss = encoder1(original_mlp_out)

# def reconstr_hook(mlp_out, hook, new_mlp_out):
#     return new_mlp_out
# def zero_abl_hook(mlp_out, hook):
#     return torch.zeros_like(mlp_out)

# logits, cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)
# per_token_loss_original = model.loss_fn(logits, rep_tokens, True).mean(0).detach().cpu().numpy()

# logits = model.run_with_hooks(
#     rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", layer), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out_1))])
# # per_token_loss = model.loss_fn(logits, example_tokens, True)
# per_token_loss = model.loss_fn(logits, rep_tokens, True)
# per_token_loss_with_reconstruction = per_token_loss.mean(0).detach().cpu().numpy()
# # px.line(per_token_loss.mean(0).detach().cpu().numpy())

# logits = model.run_with_hooks(rep_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", layer), zero_abl_hook)])
# per_token_loss = model.loss_fn(logits, rep_tokens, True)
# per_token_loss_zero_ablation = per_token_loss.mean(0).detach().cpu().numpy()

# df = pd.DataFrame({
#     "per_token_loss_original": per_token_loss_original,
#     "per_token_loss_with_reconstruction": per_token_loss_with_reconstruction,
#     "per_token_loss_zero_ablation": per_token_loss_zero_ablation
# })

# px.line(df, title="Per-token loss with and without reconstruction")