In [6]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys

# Put the parent directory on the import path (presumably where the local
# `buffer`, `trainers` and `training` modules live -- confirm against the
# repository layout). The membership check keeps the cell idempotent:
# re-running it no longer appends another duplicate '..' entry each time.
if '..' not in sys.path:
    sys.path.append('..')

from buffer import AllActivationBuffer
from trainers.scae import TrainerConfig
from training import train_scae_suite

from datasets import load_dataset
import torch as t
from nnsight import LanguageModel
from collections import defaultdict

# Model weights are loaded in bfloat16, and the same dtype is passed on to
# the activation buffer and the training call.
DTYPE = t.bfloat16

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# GPT-2 wrapped in an nnsight LanguageModel, placed on the chosen device.
model = LanguageModel("gpt2", torch_dtype=DTYPE, device_map=device)

# OpenWebText training split, opened in streaming mode.
dataset = load_dataset(
    'Skylion007/openwebtext',
    streaming=True,
    split='train',
    trust_remote_code=True,
)

class CustomData():
    '''Iterator adapter that yields only the 'text' field of each record.

    Wraps any iterable of mapping-like records; each call to ``next`` pulls
    one record from the underlying iterator and returns ``record['text']``.
    ``StopIteration`` from the underlying iterator propagates unchanged.
    '''

    def __init__(self, dataset):
        # Keep a single shared iterator so progress persists across calls.
        self.data = iter(dataset)

    def __iter__(self):
        # The object is its own iterator.
        return self

    def __next__(self):
        record = next(self.data)
        return record['text']

# Wrap the streamed dataset so that iterating it yields each record's 'text' field.
data = CustomData(dataset)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Run hyperparameters.
k = 128          # stored under 'k' in every pretrained autoencoder config
expansion = 16   # num_features is this multiple of the model's n_embd
C = 10           # second dimension of each random connection index tensor

# Sizes derived from the loaded model's config.
n_layer = model.config.n_layer
num_features = expansion * model.config.n_embd

In [8]:
# Per-module specs describing where to fetch each pretrained autoencoder,
# keyed by '<module>_<layer>' (e.g. 'attn_0', 'mlp_3').
pretrained_configs = {}

# connections[down_name][up_name] -> LongTensor of shape (num_features, C)
# holding indices in [0, num_features).
connections = defaultdict(dict)

# Draw the test connections from a dedicated, seeded generator so that a
# fresh "Restart & Run All" reproduces exactly the same connection tensors
# instead of depending on the state of the global RNG.
CONNECTION_SEED = 0
connection_rng = t.Generator().manual_seed(CONNECTION_SEED)

for down_layer in range(n_layer):
    for module in ['attn', 'mlp']:
        down_name = f'{module}_{down_layer}'
        pretrained_configs[down_name] = {
            'repo_id': 'jacobcd52/scae',
            'filename': f'ae_{module}_{down_layer}.pt',
            'k': k,
        }

        # Use random connections for testing.
        # Only 'mlp' modules receive connections, each from every 'mlp'
        # module in a strictly earlier layer (so 'mlp_0' gets none).
        if module == 'mlp':
            for up_layer in range(down_layer):
                up_name = f'{module}_{up_layer}'
                connections[down_name][up_name] = t.randint(
                    0,
                    num_features,
                    (num_features, C),
                    dtype=t.long,
                    generator=connection_rng,
                )

In [9]:
# Map '<module>_<layer>' -> (nnsight module, io mode string).
# MLPs are registered with "in_and_out" and attention modules with "out";
# presumably this selects which activations the buffer records -- confirm
# against AllActivationBuffer.
submodules = {}
for layer in range(n_layer):
    block = model.transformer.h[layer]
    submodules[f"mlp_{layer}"] = (block.mlp, "in_and_out")
    submodules[f"attn_{layer}"] = (block.attn, "out")

# Activation buffer fed by the text iterator, on the same device and dtype
# as the model.
buffer = AllActivationBuffer(
    data=data,
    model=model,
    submodules=submodules,
    d_submodule=model.config.n_embd,
    device=device,
    dtype=DTYPE,
    n_ctxs=128,
    refresh_batch_size=256,
    out_batch_size=128,
)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [10]:
# Trainer settings for a short 100-step run.
trainer_cfg = TrainerConfig(steps=100, connection_sparsity_coeff=0.1)

# Launch training over all modules described in `pretrained_configs`, using
# the connection index tensors built in the earlier cell. The step count is
# passed again here and matches the value given to TrainerConfig.
trainer = train_scae_suite(
    buffer,
    module_specs=pretrained_configs,
    connections=connections,
    trainer_config=trainer_cfg,
    steps=100,
    save_steps=10,
    device=device,
    dtype=DTYPE,
    # Further optional keyword arguments, not set here:
    #   save_dir (Optional[str]), log_steps (Optional[int]),
    #   use_wandb (bool), hf_repo_id (Optional[str]), seed (Optional[int])
)

  0%|          | 0/100 [00:00<?, ?it/s]