In [7]:
import torch as t
import itertools as it
import polars as pl
import altair as alt
from sae_lens import SAE
from transformer_lens import HookedTransformer
from datasets import load_dataset
from torch.utils.data import DataLoader
from collections import defaultdict


In [9]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if t.cuda.is_available() else "cpu"
dataset_name = "stas/openwebtext-10k"      # dataset passed to load_dataset
transformer_model_name = "gpt2-small"       # TransformerLens model to probe
sae_model_name = "gpt2-small-res-jb"        # SAE release passed to SAE.from_pretrained

chunk_size = 20   # tokens per text chunk
batch_size = 16   # chunks per DataLoader batch

In [10]:
model = HookedTransformer.from_pretrained(transformer_model_name, device=device)

# Only forward passes are run: freeze every weight in one call
# (trailing ';' stops Jupyter from rendering the returned module).
model.requires_grad_(False);




Loaded pretrained model gpt2-small into HookedTransformer


In [11]:
# Load the training split and keep only each record's "text" field, as a list of strings.
dataset = [
    record["text"]
    for record in load_dataset(dataset_name, split='train', trust_remote_code=True).to_list()
]

In [12]:
def chunker(iterable, n):
    """Yield consecutive slices of length ``n`` from ``iterable``.

    ``iterable`` must support ``len()`` and slicing (list, str, ...); the final
    slice is shorter when the length is not a multiple of ``n``.
    """
    yield from (iterable[start:start + n] for start in range(0, len(iterable), n))

# Split the first three documents into `chunk_size`-token pieces and decode each piece
# back to text (the scan cell below re-tokenizes these strings in batches).
# NOTE(review): with truncation=True and no max_length the tokenizer cuts each document to
# the model's max length before chunking, and padding=True does nothing for a single
# string — confirm the truncation is intended.
# NOTE(review): a piece boundary may split a multi-byte character, so re-tokenizing a
# decoded piece is not guaranteed to give back exactly `chunk_size` tokens.
dataset_chunked = [
    model.tokenizer.decode(piece)
    for doc_ids in (
        model.tokenizer(doc, padding=True, truncation=True)["input_ids"]
        for doc in dataset[:3]
    )
    for piece in chunker(doc_ids, chunk_size)
]

len(dataset_chunked)

128

In [23]:
dataloader = DataLoader(dataset_chunked, batch_size=batch_size)
eos_token_id = model.tokenizer.eos_token_id  # same for every layer, so looked up once


def find_divergent_pairs(token_ids, sae_feats, sae_out, eos_token_id, max_out_dist=10.0, min_ratio=2.0):
    """Yield position pairs whose SAE codes differ much more than their reconstructions.

    For every pair of rows ``i < j`` in the batch and every pair of positions ``(m, n)``,
    measures the L2 distance between ``sae_feats[i, m]`` and ``sae_feats[j, n]``
    (``feat_dist``) and between ``sae_out[i, m]`` and ``sae_out[j, n]`` (``out_dist``).
    Yields ``(i, j, m, n, feat_dist, out_dist)`` — distances as Python floats — whenever
    ``0 < out_dist < max_out_dist`` and ``feat_dist > min_ratio * out_dist``, ordered by
    ``(i, j, m, n)``.

    Positions range over the actual sequence length of the batch rather than
    ``chunk_size``: re-tokenized chunks can come out shorter or longer than that.

    A position is skipped when an EOS token occurs anywhere up to and including it in its
    row; EOS tokens (presumably the padding added by ``padding=True`` — TODO confirm)
    produced a lot of false positives.

    Assumes ``token_ids`` is ``(batch, seq)`` and ``sae_feats`` / ``sae_out`` are
    ``(batch, seq, dim)`` — TODO confirm against the SAE's encode/decode outputs.
    """
    # clean[b, p] is True iff token_ids[b, :p+1] contains no EOS; computed once per batch
    # instead of re-scanning the prefix for every position pair.
    clean = ((token_ids == eos_token_id).cumsum(dim=1) == 0).to(sae_feats.device)

    for i, j in it.combinations(range(sae_feats.size(0)), 2):
        # (seq, seq) distance matrices between every position of row i and every position of row j.
        feat_dist = (sae_feats[i, :, None] - sae_feats[j, None, :]).norm(dim=-1)
        out_dist = (sae_out[i, :, None] - sae_out[j, None, :]).norm(dim=-1)

        hits = (
            clean[i, :, None] & clean[j, None, :]
            & (out_dist > 0) & (out_dist < max_out_dist)
            & (feat_dist > min_ratio * out_dist)
        )
        # nonzero() returns indices in C order, i.e. (m, n) lexicographically.
        for m, n in hits.nonzero().tolist():
            yield i, j, m, n, feat_dist[m, n].item(), out_dist[m, n].item()


for layer in range(0, 6):
    sae_block_hook = f"blocks.{layer}.hook_resid_pre"

    print(f"Printing examples for block {sae_block_hook}")

    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_model_name,
        sae_id=sae_block_hook,
        device=device)

    for params in sae.parameters():
        params.requires_grad_(False)

    for batch in dataloader:
        token_ids = model.tokenizer(batch, return_tensors="pt", padding=True, truncation=True).input_ids

        # Cache only the activation the SAE reads instead of every hook point.
        _, cache = model.run_with_cache(token_ids, names_filter=sae_block_hook)
        sae_feats = sae.encode(cache[sae_block_hook])
        sae_out = sae.decode(sae_feats)

        for i, j, m, n, norm_feat, norm_out in find_divergent_pairs(token_ids, sae_feats, sae_out, eos_token_id):
            str_i = model.tokenizer.decode(token_ids[i, :m+1])
            str_j = model.tokenizer.decode(token_ids[j, :n+1])

            print()
            print(f"{layer = }, {i = }, {j = }")
            print(f"norm_feats   = {norm_feat}")
            print(f"norm_out     = {norm_out}")
            print(f"str_i        = {repr(str_i)}")
            print(f"str_j        = {repr(str_j)}")
            print(f"argsort(f_i) = {sae_feats[i, m].argsort()[-5:]}")
            print(f"argsort(f_j) = {sae_feats[j, n].argsort()[-5:]}")

    # Drop every reference to this layer's GPU tensors (the last batch's cache and SAE
    # outputs would otherwise stay alive) so empty_cache() can actually release them.
    del sae
    cache = sae_feats = sae_out = None
    t.cuda.empty_cache()


Printing examples for block blocks.0.hook_resid_pre

layer = 0, i = 0, j = 10
norm_feats   = 3.140711784362793
norm_out     = 1.4815704822540283
str_i        = 'A magazine supplement with an image of Adolf Hitler and the'
str_j        = ' tight lid on Hitler’s writings has become a'
argsort(f_i) = tensor([24182, 18458,  7261, 15603, 17652], device='cuda:0')
argsort(f_j) = tensor([18458, 24182, 15603, 19064, 11233], device='cuda:0')

layer = 0, i = 2, j = 6
norm_feats   = 1.9613784551620483
norm_out     = 0.9013238549232483
str_i        = ' of Bavaria,'
str_j        = 't have,'
argsort(f_i) = tensor([20998, 16482, 15033, 22351, 11916], device='cuda:0')
argsort(f_j) = tensor([19462, 15033, 22351, 11916,  7292], device='cuda:0')

layer = 0, i = 3, j = 10
norm_feats   = 3.2950387001037598
norm_out     = 1.5878008604049683
str_i        = ')\n\nThe city that was the'
str_j        = ' tight lid on Hitler’s writings has become a'
argsort(f_i) = tensor([19064, 15603, 24182, 20998, 17652], devic