# Load the model and tokenizer

In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification


# Download the model and tokenizer.
# Fall back to CPU when no GPU is visible so the notebook still runs
# (more slowly) on a CPU-only kernel instead of failing on `.to("cuda:0")`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "EleutherAI/Pythia-70M-deduped"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm


# Load in Sparse AEs

In [7]:
from autoencoders import *
# Alternative SAE repo ids (currently unused):
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
model_id = "jbrinkma/Pythia-70M-deduped-SAEs"
# Each SAE checkpoint paired with the model submodule whose output it encodes.
# Later cells pair `autoencoders[i]` with `cache_names[i]` by index, so the
# two lists are derived from one table to keep them aligned.
sae_checkpoints = [
    ("Pythia-70M-deduped-1.pt", "gpt_neox.layers.1"),
    ("Pythia-70M-deduped-mlp-2.pt", "gpt_neox.layers.2.mlp"),
]
filename = [checkpoint for checkpoint, _ in sae_checkpoints]
cache_names = [module_name for _, module_name in sae_checkpoints]

autoencoders = []
for filen in filename:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    # The checkpoint is a pickled autoencoder object (it has `.to_device`),
    # not a state dict, so it needs weights_only=False; that argument
    # defaults to True from PyTorch 2.6 on. Unpickling can execute arbitrary
    # code, so only load checkpoints from a trusted repo.
    # map_location places the tensors on `device` even if they were saved on
    # a different GPU, or when only a CPU is available.
    autoencoder = torch.load(ae_download_location, map_location=device, weights_only=False)
    autoencoder.to_device(device)
    autoencoders.append(autoencoder)

# Load training data

In [5]:
max_seq_length = 30  # number of tokens kept per data point
from datasets import load_dataset
dataset_name = "NeelNanda/pile-10k"


def _truncate_tokens(example):
    """Cut the token-level columns of one example to `max_seq_length` tokens.

    `attention_mask` is truncated together with `input_ids` so the two
    columns always have the same length.
    """
    return {
        column: example[column][:max_seq_length]
        for column in ("input_ids", "attention_mask")
        if column in example
    }


# Tokenize every document, keep those with at least `max_seq_length` tokens,
# then truncate so every row has exactly `max_seq_length` tokens. Equal-length
# rows let the DataLoader stack them into one tensor without padding.
dataset = (
    load_dataset(dataset_name, split="train")
    .map(lambda x: tokenizer(x["text"]), batched=True)
    .filter(lambda x: len(x["input_ids"]) >= max_seq_length)
    .map(_truncate_tokens)
)

Found cached dataset parquet (/root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-7888eac6661e5b23.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-a67af88cba1e56b6.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-5872d54a9717e874.arrow


## Grab Intermediate Layer Activations

In [8]:
from baukit import TraceDict
# SAE dictionary shape, read from an explicit SAE rather than whichever
# `autoencoder` variable happens to be left in the kernel. This is the last
# SAE loaded, which is what that leftover variable pointed to.
# NOTE(review): confirm every SAE in `autoencoders` shares this shape.
num_features, d_model = autoencoders[-1].encoder.shape
datapoints = dataset.num_rows
batch_size = 32
# Number of batches to encode. 1 encodes only the first batch (quick debug
# run); set to None to process the whole dataset.
max_batches = 1

# SAE codes per traced module, moved to the CPU batch by batch so GPU memory
# does not grow with the number of batches processed.
codes_per_batch = {cache_name: [] for cache_name in cache_names}
with torch.no_grad(), dataset.formatted_as("pt"):
    dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
    for i, batch in enumerate(tqdm(dl)):
        if max_batches is not None and i >= max_batches:
            break
        batch = batch.to(device)
        # Get LLM intermediate activations for every module in `cache_names`.
        with TraceDict(model, cache_names) as ret:
            _ = model(batch)
        # Get SAE intermediate codes; autoencoders[k] encodes cache_names[k].
        for ae_ind, cache_name in enumerate(cache_names):
            autoencoder = autoencoders[ae_ind]
            internal_activations = ret[cache_name].output
            # A whole layer's output is a tuple; its first element is the
            # hidden-state tensor the SAE expects.
            if isinstance(internal_activations, tuple):
                internal_activations = internal_activations[0]

            # Flatten batch and sequence so every token is one SAE input row.
            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n")
            batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
            # assumes `encode` returns a torch tensor — TODO confirm
            codes_per_batch[cache_name].append(batched_dictionary_activations.cpu())

# One matrix of SAE codes per traced module, rows ordered by (batch, position).
dictionary_activations = {
    cache_name: torch.cat(chunks)
    for cache_name, chunks in codes_per_batch.items()
    if chunks
}

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]


In [11]:
# Show the MLP of layer 2, looked up by the same dotted path that is traced
# via `cache_names` ("gpt_neox.layers.2.mlp").
model.get_submodule("gpt_neox.layers.2.mlp")

GPTNeoXMLP(
  (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
  (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
  (act): GELUActivation()
)