In [1]:
from transformers import AutoTokenizer
from src.data.huggingface import HFDatasetLoader
from src.data.text import TextDatasetLoader
from src.interfaces.lens_backend import Variant
from src.interfaces.neuronpedia_api import NeuronpediaClient
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from sae_lens import SAE

#### Load in Models and SAEs

In [2]:
# TODO: look for a way to manually hook into Deepseek-r1-distill-llama-8b.

# Active variant: GPT-2 small with the "gpt2-small-res-jb" SAE release,
# using the SAE identified by "blocks.11.hook_resid_pre".
gpt = Variant(model_id="gpt2-small", sae_release="gpt2-small-res-jb", sae_id="blocks.11.hook_resid_pre")

# Alternative variants, currently disabled:
# gemma = Variant(model_id="google/gemma-2b-it", sae_release="gemma-2b-it-res-jb", sae_id="blocks.12.hook_resid_post")
# llama = Variant(model_id="meta-llama/Llama-3.1-8B-Instruct", sae_release="llama-3-8b-it-res-jh", sae_id="blocks.25.hook_resid_post")

# get_components() yields four objects, unpacked in this order.
(
    model,
    sae,
    cfg,
    tokenizer,
) = gpt.get_components()

using HookedSAETransformer
Loaded pretrained model gpt2-small into HookedTransformer


#### Load in Data

In [3]:
# Source 1: a Hugging Face dataset ("train" split).
# The SAE is handed to the loader too; per the original author's note this is
# so cfg.metadata.context_size and prepend_bos can be taken from it
# (confirm in HFDatasetLoader).
hf_loader = HFDatasetLoader(
    split="train",
    sae=sae,
    tokenizer=tokenizer,
    hf_link="NeelNanda/pile-10k",
)

# Source 2: an in-memory list of 50 placeholder strings,
# "Sentence 0" through "Sentence 49".
sentences = [f"Sentence {i}" for i in range(50)]
text_loader = TextDatasetLoader(sae=sae, tokenizer=tokenizer, list_str=sentences)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Map (num_proc=10):   0%|          | 0/50 [00:00<?, ? examples/s]

In [8]:
# Decode the "tokens" field of the first five rows of hf_loader.tokens back
# into text, dropping special tokens.
hf_sentences = [
    model.tokenizer.decode(hf_loader.tokens[row]["tokens"], skip_special_tokens=True)
    for row in range(5)
]

#### Connect Neuronpedia Client

In [9]:
# Neuronpedia client for the same model / SAE family configured earlier.
# NOTE(review): these identifiers are hard-coded separately from the Variant
# settings; presumably they must be updated by hand if the variant changes.
npedia_client = NeuronpediaClient(model_id="gpt2-small", source_set="res-jb", sae_layer="11-res-jb")



In [10]:
# Display the first decoded sample (bare last expression -> cell output).
hf_sentences[0]

'It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playing on the web works, but you have to simulate multi-touch for table moving and that can be a bit confusing.\n\nThere’s a lot I’d like to talk about. I’ll go through every topic, insted of making the typical what went right/wrong list.\n\nConcept\n\nWorking over the theme was probably one of the hardest tasks I had to face.\n\nOriginally, I had an idea of what kind of'

In [11]:
# Free-form query text sent to the Neuronpedia client in the next cell.
test_query = "Rain pitter patters against the window, leaving the man with only a remembrance of what was once in his embrace"

In [12]:
# Ask the Neuronpedia client for features associated with the query text.
# The meaning of ignore_bos / density_threshold / num_results is defined by
# NeuronpediaClient.all_text_feat; the values are simply forwarded here.
search_options = {
    "ignore_bos": True,
    "density_threshold": 0.05,
    "num_results": 100,
}
feat_activations = npedia_client.all_text_feat(query=test_query, **search_options)

# Disabled: per-feature lookup on the first decoded dataset sample.
# specific_feat = npedia_client.feat_specific_act(index=1683, text=hf_sentences[0])

In [13]:
# Print the "index" field of every entry under the "result" key, one per line.
for feature_index in (entry["index"] for entry in feat_activations["result"]):
    print(feature_index)

21800
19832
23543
24098
18258
4181
9894


In [None]:
from tqdm.auto import tqdm
import torch

# Encode a slice of the tokenized dataset into SAE feature activations.
# Result: `feature_acts`, the per-batch outputs of sae.encode concatenated
# along dim 0.

num_sequences = 100  # how many rows of hf_loader.tokens to encode
batch_size = 5

sae.eval()
feature_acts_list = []

# Clamp the selection so a dataset shorter than num_sequences does not make
# select() fail on out-of-range indices.
token_dataset = hf_loader.tokens.select(
    range(min(num_sequences, len(hf_loader.tokens)))
)

# Loop-invariant: the hook point the SAE reads from, and the layer number
# parsed from it (a name of the form "blocks.<N>.<hook>" yields N).
# NOTE(review): newer sae_lens releases may expose this as
# sae.cfg.metadata.hook_name — confirm against the installed version.
hook_name = sae.cfg.hook_name
hook_layer = int(hook_name.split('.')[1])

with torch.no_grad():
    for i in tqdm(range(0, len(token_dataset), batch_size)):

        batch = token_dataset[i:i + batch_size]
        # as_tensor accepts both nested lists and existing tensors, and places
        # the result on the model's device in one step.
        batch_tokens = torch.as_tensor(batch['tokens'], device=model.cfg.device)

        # The input is already token ids, so no `prepend_bos` is passed:
        # that flag only affects string inputs. BOS handling is presumably
        # done by the dataset loader — TODO confirm.
        # names_filter keeps only the activation the SAE needs instead of
        # caching every hook up to the stop layer.
        _, cache = model.run_with_cache(
            batch_tokens,
            stop_at_layer=hook_layer + 1,
            names_filter=hook_name,
        )

        features = sae.encode(cache[hook_name])
        feature_acts_list.append(features)

        # Release the cached activations before the next batch.
        del cache

feature_acts = torch.cat(feature_acts_list, dim=0)