In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
from buffer import MultiModelActivationBuffer

from datasets import load_dataset
import torch as t

from nnsight import LanguageModel
from buffer import MultiModelActivationBuffer
from trainers.top_k import TopKTrainer, AutoEncoderTopK
from training import trainSAE

In [None]:
# --- Experiment configuration ------------------------------------------------
# Index of the transformer block whose MLP output is captured. The same value
# is recorded in the trainer config, so it must match the submodule hooked below.
layer = 5
# Dictionary size expressed as a multiple of the activation dimension.
expansion = 16
# Token budget; divided by `out_batch_size` to obtain the number of training steps.
num_tokens = int(1e6)
# Batch size passed to the activation buffer as `out_batch_size`.
out_batch_size = 4096

# --- Models: one LanguageModel per training checkpoint -----------------------
submodule_list = []
model_list = []
for step in [1, 2, 4, 8]:
    model = LanguageModel("EleutherAI/pythia-70m", revision=f"step{step}", trust_remote_code=False)
    model_list.append(model)
    # Hook the MLP of the configured `layer` (previously hard-coded to index 1,
    # which disagreed with the `layer` value stored in the trainer config).
    submodule_list.append(model.gpt_neox.layers[layer].mlp)

device = "cuda:0"

# Passed to the buffer as `d_submodule`.
# NOTE(review): assumed to equal the MLP output width of pythia-70m — confirm.
activation_dim = 512
dictionary_size = expansion * activation_dim

# --- Text source -------------------------------------------------------------
# Streamed so the corpus is not downloaded in full up front.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run locally; keep only if that script is trusted.
dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True,
                                trust_remote_code=True)

class CustomData:
    """Iterator adapter that yields only the ``'text'`` field of each record.

    Wraps any iterable of mapping-like records and exposes it as an iterator
    of the values stored under the ``'text'`` key.
    """

    def __init__(self, dataset):
        # Keep a single underlying iterator so consumption state is shared.
        self.data = iter(dataset)

    def __iter__(self):
        # The object is its own iterator.
        return self

    def __next__(self):
        # StopIteration from the underlying iterator propagates unchanged.
        record = next(self.data)
        return record['text']

# Wrap the streamed dataset so the buffer receives plain text strings.
data = CustomData(dataset)


# Activation buffer fed by all loaded checkpoints and their hooked submodules.
buffer = MultiModelActivationBuffer(
    data=data,
    model_list=model_list,
    submodule_list=submodule_list,
    d_submodule=activation_dim, # output dimension of the model component
    n_ctxs=128,  # you can set this higher or lower depending on your available memory
    device=device,
    refresh_batch_size=128,
    out_batch_size=out_batch_size,
)  # buffer yields batches of activation tensors. NOTE(review): the trainer config uses activation_dim * len(model_list), so batches are presumably concatenated across models — confirm

In [None]:
# Configuration for a single TopK sparse-autoencoder trainer.
trainer_cfg = dict(
    trainer=TopKTrainer,
    dict_class=AutoEncoderTopK,
    # Input width is the per-model activation width times the number of models.
    activation_dim=activation_dim * len(model_list),
    dict_size=dictionary_size,
    lr=1e-3,
    device=device,
    # One step per output batch until the token budget is used up.
    steps=num_tokens // out_batch_size,
    k=128,
    layer=layer,
    lm_name="blah",
    warmup_steps=0,
)

# Train the sparse autoencoder (SAE). Per the original note, another batch
# source (e.g. a PyTorch dataloader) could be passed as `data` instead.
ae = trainSAE(
    data=buffer,
    trainer_configs=[trainer_cfg],
    steps=num_tokens // out_batch_size,
    autocast_dtype=t.bfloat16,
)