# Setup

In [None]:
%pip install -r requirements.txt

In [None]:
import torch 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt
import os
from baukit import Trace, TraceDict



def init_cfg():
    cfg = dotdict()
    # models: "EleutherAI/pythia-6.9b", "lomahony/eleuther-pythia6.9b-hh-sft", "usvsnsp/pythia-6.9b-ppo", "Dahoas/gptj-rm-static", "reciprocate/dahoas-gptj-rm-static"
    # cfg.model_name="lomahony/eleuther-pythia6.9b-hh-sft"
    # "EleutherAI/pythia-70m", "lomahony/pythia-70m-helpful-sft", "lomahony/eleuther-pythia70m-hh-sft"
    cfg.model_name="EleutherAI/pythia-70m-deduped"
    cfg.layers=[0, 1] # Change this to run multiple layers
    cfg.setting="residual"
    # cfg.tensor_name="gpt_neox.layers.{layer}" or "transformer.h.{layer}"
    cfg.tensor_name="gpt_neox.layers.{layer}.mlp"
    original_l1_alpha = 8e-4
    cfg.l1_alpha=original_l1_alpha
    cfg.l1_alphas=[8e-5, 1e-4, 2e-4, 4e-4, 8e-4, 1e-3, 2e-3, 4e-3, 8e-3]
    cfg.sparsity=None
    cfg.num_epochs=2
    cfg.model_batch_size=8
    cfg.lr=1e-3 # ORIGINAL: 1e-3
    cfg.kl=False
    cfg.reconstruction=False
    #cfg.dataset_name="NeelNanda/pile-10k"
    cfg.dataset_name="Elriggs/openwebtext-100k"
    cfg.device="cuda:0"
    cfg.ratio = 8
    cfg.seed = 0
    # cfg.device="cpu"

    return cfg

# Main Code

In [None]:
def setup_execute_training(model_name,
                          dataset_name,
                          ratio,
                          layers,
                          seed,
                          wandb_log,
                          split,
                          epoches):
    cfg = init_cfg()
    cfg.num_epoches = epoches
    cfg.model_name = model_name
    cfg.dataset_name = dataset_name
    cfg.ratio = ratio
    cfg.layers = layers
    cfg.seed = seed
    cfg.wandb_log = wandb_log

    model, tokenizer = load_model(cfg)
    get_activation_size(cfg, model, tokenizer)

    # naming
    start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    wandb_run_name = f"{cfg.model_name}_{cfg.dataset_name}_s{cfg.seed}_dim{cfg.ratio*cfg.activation_size}_{start_time[4:]}"
    model_name_path = cfg.model_name.replace("/", "_")
    dataset_name_path = cfg.dataset_name.split("/")[-1]
    storage_path = f"{model_name_path}/{dataset_name_path}_s{cfg.seed}"
    filename = f"{cfg.ratio*cfg.activation_size}_{start_time[4:]}"
    token_loader = init_dataloader(cfg, model, tokenizer, split)
    autoencoders, optimizers = init_autoencoder(cfg)
    
    if wandb_log:
        setup_wandb(cfg, wandb_run_name)
    
    training_run(cfg, model, optimizers, autoencoders, token_loader)

    for layer in range(len(cfg.layers)):
        model_save(cfg, autoencoders[layer], storage_path, filename, cfg.layers[layer])

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-160m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 32
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

Activation size: 768


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-f0c48232d0cb60cb_*_of_00008.arrow


Number of tokens: 112750592


  0%|          | 0/55054 [00:00<?, ?it/s]

Sparsity: 12321.4 | Dead Features: 24576 | Total Loss: 0.7348 | Reconstruction Loss: 0.1742 | L1 Loss: 0.5607 | l1_alpha: 8.0000e-04 | Tokens: 0 | Self Similarity: 1.0000
Sparsity: 12315.4 | Dead Features: 24576 | Total Loss: 0.3584 | Reconstruction Loss: 0.0555 | L1 Loss: 0.3029 | l1_alpha: 8.0000e-04 | Tokens: 0 | Self Similarity: -0.0002


  0%|          | 1/55054 [00:00<3:52:44,  3.94it/s]

Sparsity: 12254.1 | Dead Features: 24576 | Total Loss: 0.6150 | Reconstruction Loss: 0.1848 | L1 Loss: 0.4302 | l1_alpha: 8.0000e-04 | Tokens: 0 | Self Similarity: -0.0001
Sparsity: 12276.1 | Dead Features: 24576 | Total Loss: 1.5557 | Reconstruction Loss: 1.0040 | L1 Loss: 0.5517 | l1_alpha: 8.0000e-04 | Tokens: 0 | Self Similarity: 0.0003
Sparsity: 12315.4 | Dead Features: 24576 | Total Loss: 0.6045 | Reconstruction Loss: 0.1241 | L1 Loss: 0.4804 | l1_alpha: 8.0000e-04 | Tokens: 0 | Self Similarity: -0.0002
Sparsity: 12296.9 | Dead Features: 24576 | Total Loss: 0.7907 | Reconstruction Loss: 0.1979 | L1 Loss: 0.5928 | l1_alpha: 8.0000e-04 | Tokens: 0 | Self Similarity: -0.0001


  0%|          | 101/55054 [00:15<2:23:24,  6.39it/s]

Sparsity: 16.9 | Dead Features: 0 | Total Loss: 0.0582 | Reconstruction Loss: 0.0534 | L1 Loss: 0.0048 | l1_alpha: 8.0000e-04 | Tokens: 204800 | Self Similarity: -0.0000
Sparsity: 4.4 | Dead Features: 0 | Total Loss: 0.0201 | Reconstruction Loss: 0.0181 | L1 Loss: 0.0020 | l1_alpha: 8.0000e-04 | Tokens: 204800 | Self Similarity: -0.0002
Sparsity: 22.9 | Dead Features: 0 | Total Loss: 0.0426 | Reconstruction Loss: 0.0393 | L1 Loss: 0.0033 | l1_alpha: 8.0000e-04 | Tokens: 204800 | Self Similarity: 0.0230
Sparsity: 78.9 | Dead Features: 0 | Total Loss: 0.0591 | Reconstruction Loss: 0.0462 | L1 Loss: 0.0129 | l1_alpha: 8.0000e-04 | Tokens: 204800 | Self Similarity: -0.0756
Sparsity: 8.8 | Dead Features: 0 | Total Loss: 0.0521 | Reconstruction Loss: 0.0487 | L1 Loss: 0.0034 | l1_alpha: 8.0000e-04 | Tokens: 204800 | Self Similarity: 0.0290
Sparsity: 43.2 | Dead Features: 0 | Total Loss: 0.0606 | Reconstruction Loss: 0.0529 | L1 Loss: 0.0077 | l1_alpha: 8.0000e-04 | Tokens: 204800 | Self Simi

  0%|          | 201/55054 [00:30<2:28:38,  6.15it/s]

Sparsity: 18.1 | Dead Features: 0 | Total Loss: 0.0474 | Reconstruction Loss: 0.0414 | L1 Loss: 0.0060 | l1_alpha: 8.0000e-04 | Tokens: 409600 | Self Similarity: 0.0031
Sparsity: 4.5 | Dead Features: 0 | Total Loss: 0.0188 | Reconstruction Loss: 0.0166 | L1 Loss: 0.0021 | l1_alpha: 8.0000e-04 | Tokens: 409600 | Self Similarity: -0.0082
Sparsity: 13.4 | Dead Features: 0 | Total Loss: 0.0375 | Reconstruction Loss: 0.0343 | L1 Loss: 0.0033 | l1_alpha: 8.0000e-04 | Tokens: 409600 | Self Similarity: 0.0244
Sparsity: 25.5 | Dead Features: 0 | Total Loss: 0.0465 | Reconstruction Loss: 0.0386 | L1 Loss: 0.0079 | l1_alpha: 8.0000e-04 | Tokens: 409600 | Self Similarity: -0.0752
Sparsity: 11.1 | Dead Features: 0 | Total Loss: 0.0482 | Reconstruction Loss: 0.0438 | L1 Loss: 0.0044 | l1_alpha: 8.0000e-04 | Tokens: 409600 | Self Similarity: 0.0297
Sparsity: 24.4 | Dead Features: 0 | Total Loss: 0.0543 | Reconstruction Loss: 0.0467 | L1 Loss: 0.0076 | l1_alpha: 8.0000e-04 | Tokens: 409600 | Self Simi

  1%|          | 301/55054 [00:46<2:21:20,  6.46it/s]

Sparsity: 19.1 | Dead Features: 0 | Total Loss: 0.0420 | Reconstruction Loss: 0.0351 | L1 Loss: 0.0069 | l1_alpha: 8.0000e-04 | Tokens: 614400 | Self Similarity: 0.0038
Sparsity: 9.3 | Dead Features: 0 | Total Loss: 0.0198 | Reconstruction Loss: 0.0174 | L1 Loss: 0.0024 | l1_alpha: 8.0000e-04 | Tokens: 614400 | Self Similarity: -0.0087
Sparsity: 16.3 | Dead Features: 0 | Total Loss: 0.0372 | Reconstruction Loss: 0.0334 | L1 Loss: 0.0038 | l1_alpha: 8.0000e-04 | Tokens: 614400 | Self Similarity: 0.0267
Sparsity: 29.4 | Dead Features: 0 | Total Loss: 0.0435 | Reconstruction Loss: 0.0374 | L1 Loss: 0.0061 | l1_alpha: 8.0000e-04 | Tokens: 614400 | Self Similarity: -0.0745
Sparsity: 14.7 | Dead Features: 0 | Total Loss: 0.0472 | Reconstruction Loss: 0.0422 | L1 Loss: 0.0050 | l1_alpha: 8.0000e-04 | Tokens: 614400 | Self Similarity: 0.0308
Sparsity: 23.5 | Dead Features: 0 | Total Loss: 0.0525 | Reconstruction Loss: 0.0444 | L1 Loss: 0.0081 | l1_alpha: 8.0000e-04 | Tokens: 614400 | Self Simi

  1%|          | 401/55054 [01:01<2:23:20,  6.35it/s]

Sparsity: 21.2 | Dead Features: 0 | Total Loss: 0.0386 | Reconstruction Loss: 0.0311 | L1 Loss: 0.0075 | l1_alpha: 8.0000e-04 | Tokens: 819200 | Self Similarity: 0.0034
Sparsity: 8.0 | Dead Features: 0 | Total Loss: 0.0186 | Reconstruction Loss: 0.0160 | L1 Loss: 0.0026 | l1_alpha: 8.0000e-04 | Tokens: 819200 | Self Similarity: -0.0097
Sparsity: 18.0 | Dead Features: 0 | Total Loss: 0.0359 | Reconstruction Loss: 0.0316 | L1 Loss: 0.0043 | l1_alpha: 8.0000e-04 | Tokens: 819200 | Self Similarity: 0.0260
Sparsity: 40.4 | Dead Features: 0 | Total Loss: 0.0429 | Reconstruction Loss: 0.0364 | L1 Loss: 0.0065 | l1_alpha: 8.0000e-04 | Tokens: 819200 | Self Similarity: -0.0745
Sparsity: 16.0 | Dead Features: 0 | Total Loss: 0.0471 | Reconstruction Loss: 0.0416 | L1 Loss: 0.0055 | l1_alpha: 8.0000e-04 | Tokens: 819200 | Self Similarity: 0.0315
Sparsity: 25.0 | Dead Features: 0 | Total Loss: 0.0522 | Reconstruction Loss: 0.0433 | L1 Loss: 0.0089 | l1_alpha: 8.0000e-04 | Tokens: 819200 | Self Simi

  1%|          | 501/55054 [01:17<2:21:44,  6.41it/s]

Sparsity: 27.6 | Dead Features: 0 | Total Loss: 0.0363 | Reconstruction Loss: 0.0281 | L1 Loss: 0.0082 | l1_alpha: 8.0000e-04 | Tokens: 1024000 | Self Similarity: 0.0028
Sparsity: 10.7 | Dead Features: 0 | Total Loss: 0.0181 | Reconstruction Loss: 0.0154 | L1 Loss: 0.0027 | l1_alpha: 8.0000e-04 | Tokens: 1024000 | Self Similarity: -0.0109
Sparsity: 20.7 | Dead Features: 0 | Total Loss: 0.0347 | Reconstruction Loss: 0.0298 | L1 Loss: 0.0048 | l1_alpha: 8.0000e-04 | Tokens: 1024000 | Self Similarity: 0.0111
Sparsity: 28.5 | Dead Features: 0 | Total Loss: 0.0407 | Reconstruction Loss: 0.0342 | L1 Loss: 0.0065 | l1_alpha: 8.0000e-04 | Tokens: 1024000 | Self Similarity: -0.0834
Sparsity: 20.1 | Dead Features: 0 | Total Loss: 0.0459 | Reconstruction Loss: 0.0396 | L1 Loss: 0.0063 | l1_alpha: 8.0000e-04 | Tokens: 1024000 | Self Similarity: 0.0376
Sparsity: 25.7 | Dead Features: 0 | Total Loss: 0.0521 | Reconstruction Loss: 0.0429 | L1 Loss: 0.0092 | l1_alpha: 8.0000e-04 | Tokens: 1024000 | Se

  1%|          | 601/55054 [01:32<2:21:46,  6.40it/s]

Sparsity: 30.5 | Dead Features: 0 | Total Loss: 0.0368 | Reconstruction Loss: 0.0281 | L1 Loss: 0.0088 | l1_alpha: 8.0000e-04 | Tokens: 1228800 | Self Similarity: 0.0028
Sparsity: 12.6 | Dead Features: 0 | Total Loss: 0.0179 | Reconstruction Loss: 0.0152 | L1 Loss: 0.0028 | l1_alpha: 8.0000e-04 | Tokens: 1228800 | Self Similarity: -0.0119
Sparsity: 22.6 | Dead Features: 0 | Total Loss: 0.0355 | Reconstruction Loss: 0.0303 | L1 Loss: 0.0052 | l1_alpha: 8.0000e-04 | Tokens: 1228800 | Self Similarity: 0.0108
Sparsity: 31.6 | Dead Features: 0 | Total Loss: 0.0403 | Reconstruction Loss: 0.0336 | L1 Loss: 0.0067 | l1_alpha: 8.0000e-04 | Tokens: 1228800 | Self Similarity: -0.0843
Sparsity: 23.2 | Dead Features: 0 | Total Loss: 0.0455 | Reconstruction Loss: 0.0384 | L1 Loss: 0.0070 | l1_alpha: 8.0000e-04 | Tokens: 1228800 | Self Similarity: 0.0376
Sparsity: 27.9 | Dead Features: 0 | Total Loss: 0.0529 | Reconstruction Loss: 0.0435 | L1 Loss: 0.0094 | l1_alpha: 8.0000e-04 | Tokens: 1228800 | Se

  1%|          | 618/55054 [01:35<2:18:40,  6.54it/s]

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-160m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 16
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-160m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 2
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-70m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 2
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-70m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 8
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
with torch.no_grad():
    torch.cuda.empty_cache()

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-70m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 4
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train[:50000]"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-70m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 4
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train[50000:]"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-160m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 4
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train[:50000]"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-160m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 4
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train[50000:]"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

In [None]:
# Code that actually starts a full training run!

model_name = "EleutherAI/pythia-410m"
dataset_name = "Elriggs/openwebtext-100k" # "Elriggs/openwebtext-100k"
ratio = 2
layers = [0, 1, 2, 3, 4, 5]
wandb_log = False
seed = 0
split = "train"
epoches = 1

setup_execute_training(model_name,
                       dataset_name,
                       ratio,
                       layers,
                       seed,
                       wandb_log=wandb_log,
                       split=split,
                      epoches=epoches)

# Model + Data

In [None]:
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, GPTJForSequenceClassification


# Load in the model
def load_model(cfg):
    model = AutoModelForCausalLM.from_pretrained(cfg.model_name)
    model = model.to(cfg.device)
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    return model, tokenizer

In [None]:
# Download the dataset
# TODO iteratively grab dataset?
def init_dataloader(cfg, model, tokenizer, split="train"):
    cfg.max_length = 256
    token_loader = setup_token_data(cfg, tokenizer, model, seed=cfg.seed, split=split)
    num_tokens = cfg.max_length*cfg.model_batch_size*len(token_loader)
    print(f"Number of tokens: {num_tokens}")
    cfg.total_tokens = num_tokens
    
    return token_loader

In [None]:
# Run 1 datapoint on model to get the activation size

def get_activation_size(cfg, model, tokenizer):
    text = "1"
    tensor_names = [cfg.tensor_name.format(layer=layer) for layer in cfg.layers]
    tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
    # Your activation name will be different. In the next cells, we will show you how to find it.
    with torch.no_grad():
        with Trace(model, tensor_names[0]) as ret:
            _ = model(tokens)
            representation = ret.output
            # check if instance tuple
            if(isinstance(representation, tuple)):
                representation = representation[0]
            activation_size = representation.shape[-1]
    print(f"Activation size: {activation_size}")
    cfg.activation_size = activation_size
    return activation_size

In [6]:
# text = "1"
# tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
# # Your activation name will be different. In the next cells, we will show you how to find it.
# with torch.no_grad():
#     with Trace(model, 'gpt_neox.layers.0.mlp') as ret:
#         _ = model(tokens)
#         representation = ret.output
#         # check if instance tuple
#         if(isinstance(representation, tuple)):
#             representation = representation[0]
#         activation_size = representation.shape[-1]
# print(f"Activation size: {activation_size}")
# cfg.activation_size = activation_size


NameError: name 'tokenizer' is not defined

In [None]:
# # Set target sparsity to 10% of activation_size if not set

# # NOT USED
# if cfg.sparsity is None:
#     cfg.sparsity = int(activation_size*0.05)
#     print(f"Target sparsity: {cfg.sparsity}")

# target_lower_sparsity = cfg.sparsity * 0.9
# target_upper_sparsity = cfg.sparsity * 1.1
# adjustment_factor = 0.1  # You can set this to whatever you like

# Sparse Autoencocer init

In [None]:
def init_autoencoder(cfg):
    autoencoders = []
    optimizers = []
    for layer in range(len(cfg.layers)):
        params = dict()
        n_dict_components = cfg.activation_size*cfg.ratio # Sparse Autoencoder Size
        params["encoder"] = torch.empty((n_dict_components, cfg.activation_size), device=cfg.device)
        nn.init.xavier_uniform_(params["encoder"])
    
        params["decoder"] = torch.empty((n_dict_components, cfg.activation_size), device=cfg.device)
        nn.init.xavier_uniform_(params["decoder"])
    
        params["encoder_bias"] = torch.empty((n_dict_components,), device=cfg.device)
        nn.init.zeros_(params["encoder_bias"])
    
        params["shift_bias"] = torch.empty((cfg.activation_size,), device=cfg.device)
        nn.init.zeros_(params["shift_bias"])
    
        autoencoder = AnthropicSAE(  # TiedSAE, UntiedSAE, AnthropicSAE
            # n_feats = n_dict_components, 
            # activation_size=cfg.activation_size,
            encoder=params["encoder"],
            encoder_bias=params["encoder_bias"],
            decoder=params["decoder"],
            shift_bias=params["shift_bias"],
        )
        autoencoder.to_device(cfg.device)
        autoencoder.set_grad()
    
        optimizer = torch.optim.Adam(
            [
                autoencoder.encoder, 
                autoencoder.encoder_bias,
                autoencoder.decoder,
                autoencoder.shift_bias,
            ], lr=cfg.lr)
        autoencoders.append(autoencoder)
        optimizers.append(optimizer)
    return autoencoders, optimizers

# Training Run

In [None]:
# tensor_names = ['gpt_neox.layers.0.mlp', 'gpt_neox.layers.5.mlp']

In [None]:
# original_bias = autoencoder.encoder_bias.clone().detach()
# Wandb setup
def setup_wandb(cfg, wandb_run_name):
    secrets = json.load(open("secrets.json"))
    wandb.login(key=secrets["wandb_key"])
    wandb.init(project="Sparse Coding >70m", config=dict(cfg), name=wandb_run_name)
    return wandb_run_name

In [None]:
def training_run(cfg, model, optimizers, autoencoders, token_loader):

    time_since_activation = torch.zeros(autoencoders[0].encoder.shape[0])
    total_activations = torch.zeros(autoencoders[0].encoder.shape[0])
    tensor_names = [cfg.tensor_name.format(layer=layer) for layer in cfg.layers]
    max_num_tokens = cfg.total_tokens # 100_000_000
    save_every = 30_000
    num_saved_so_far = 0

    # Freeze model parameters 
    model.eval()
    model.requires_grad_(False)
    model.to(cfg.device)
    
    last_encoder = autoencoders[0].encoder.clone().detach()
    assert len(cfg.layers) == len(tensor_names), "layers and tensor_names have different lengths"
    for epoch in range(cfg.num_epochs):
        for i, batch in enumerate(tqdm(token_loader)): #,total=int(max_num_tokens/(cfg.max_length*cfg.model_batch_size)))):
            tokens = batch["input_ids"].to(cfg.device)
            # print(f"tokens shape: {tokens.shape}")
            
            with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
                
                #print(tensor_names)
                representations = []
                with TraceDict(model, tensor_names) as ret:
                    _ = model(tokens)
                    for tensor_name in tensor_names:
                        representations.append(ret[tensor_name].output)
                    assert not isinstance(representations[0], tuple), "representations is type tuple"
                    # print(len(representations), representations[0].shape)
                    # if(isinstance(representation, tuple)):
                    #     representation = representation[0]
            #print(f"representation is: {representation}")
            #print(f"representation shape is: {representation.shape}")
            
        
            # activation_saver.save_batch(layer_activations.clone().cpu().detach())
            for layer in range(len(cfg.layers)):
                representation = representations[layer]
                layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")
                autoencoder = autoencoders[layer]
                optimizer = optimizers[layer]
                
                c = autoencoder.encode(layer_activations)
                x_hat = autoencoder.decode(c)
                
                reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
                l1_loss = torch.norm(c, 1, dim=-1).mean()
                total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss
            
                time_since_activation += 1
                time_since_activation = time_since_activation * (c.sum(dim=0).cpu()==0)
                # total_activations += c.sum(dim=0).cpu()
                if ((i) % 100 == 0): # Check here so first check is model w/o change
                    # self_similarity = torch.cosine_similarity(c, last_encoder, dim=-1).mean().cpu().item()
                    # Above is wrong, should be similarity between encoder and last encoder
                    self_similarity = torch.cosine_similarity(autoencoder.encoder, last_encoder, dim=-1).mean().cpu().item()
                    last_encoder = autoencoder.encoder.clone().detach()
            
                    num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
                    with torch.no_grad():
                        sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
                        # Count number of dead_features are zero
                        num_dead_features = (time_since_activation >= min(i, 200)).sum().item()
                    print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss:.4f} | Reconstruction Loss: {reconstruction_loss:.4f} | L1 Loss: {cfg.l1_alpha*l1_loss:.4f} | l1_alpha: {cfg.l1_alpha:.4e} | Tokens: {num_tokens_so_far} | Self Similarity: {self_similarity:.4f}")
                    
                    if cfg.wandb_log:
                        wandb.log({
                            'Sparsity': sparsity,
                            'Dead Features': num_dead_features,
                            'Total Loss': total_loss.item(),
                            'Reconstruction Loss': reconstruction_loss.item(),
                            'L1 Loss': (cfg.l1_alpha*l1_loss).item(),
                            'l1_alpha': cfg.l1_alpha,
                            'Tokens': num_tokens_so_far,
                            'Self Similarity': self_similarity
                        })
                    
                    dead_features = torch.zeros(autoencoder.encoder.shape[0])
                    
                    # if(num_tokens_so_far > max_num_tokens):
                    #     print(f"Reached max number of tokens: {max_num_tokens}")
                    #     break
                    
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
    wandb.finish()
        # resample_period = 10000
        # if (i % resample_period == 0):
        #     # RESAMPLING
        #     with torch.no_grad():
        #         # Count number of dead_features are zero
        #         num_dead_features = (total_activations == 0).sum().item()
        #         print(f"Dead Features: {num_dead_features}")
                
        #     if num_dead_features > 0:
        #         print("Resampling!")
        #         # hyperparams:
        #         max_resample_tokens = 1000 # the number of token activations that we consider for inserting into the dictionary
        #         # compute loss of model on random subset of inputs
        #         resample_loader = setup_token_data(cfg, tokenizer, model, seed=i)
        #         num_resample_data = 0
    
        #         resample_activations = torch.empty(0, activation_size)
        #         resample_losses = torch.empty(0)
    
        #         for resample_batch in resample_loader:
        #             resample_tokens = resample_batch["input_ids"].to(cfg.device)
        #             with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
        #                 with Trace(model, tensor_names[0]) as ret:
        #                     _ = model(resample_tokens)
        #                     representation = ret.output
        #                     if(isinstance(representation, tuple)):
        #                         representation = representation[0]
        #             layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")
        #             resample_activations = torch.cat((resample_activations, layer_activations.detach().cpu()), dim=0)
    
        #             c = autoencoder.encode(layer_activations)
        #             x_hat = autoencoder.decode(c)
                    
        #             reconstruction_loss = (x_hat - layer_activations).pow(2).mean(dim=-1)
        #             l1_loss = torch.norm(c, 1, dim=-1)
        #             temp_loss = reconstruction_loss + cfg.l1_alpha*l1_loss
                    
        #             resample_losses = torch.cat((resample_losses, temp_loss.detach().cpu()), dim=0)
                    
        #             num_resample_data +=layer_activations.shape[0]
        #             if num_resample_data > max_resample_tokens:
        #                 break
    
                    
        #         # sample num_dead_features vectors of input activations
        #         probabilities = resample_losses**2
        #         probabilities /= probabilities.sum()
        #         sampled_indices = torch.multinomial(probabilities, num_dead_features, replacement=True)
        #         new_vectors = resample_activations[sampled_indices]
    
        #         # calculate average encoder norm of alive neurons
        #         alive_neurons = list((total_activations!=0))
        #         modified_columns = total_activations==0
        #         avg_norm = autoencoder.encoder.data[alive_neurons].norm(dim=-1).mean()
    
        #         # replace dictionary and encoder weights with vectors
        #         new_vectors = new_vectors / new_vectors.norm(dim=1, keepdim=True)
                
        #         params_to_modify = [autoencoder.encoder, autoencoder.encoder_bias]
    
        #         current_weights = autoencoder.encoder.data
        #         current_weights[modified_columns] = (new_vectors.to(cfg.device) * avg_norm * 0.02)
        #         autoencoder.encoder.data = current_weights
    
        #         current_weights = autoencoder.encoder_bias.data
        #         current_weights[modified_columns] = 0
        #         autoencoder.encoder_bias.data = current_weights
                
        #         if hasattr(autoencoder, 'decoder'):
        #             current_weights = autoencoder.decoder.data
        #             current_weights[modified_columns] = new_vectors.to(cfg.device)
        #             autoencoder.decoder.data = current_weights
        #             params_to_modify += [autoencoder.decoder]
    
        #         for param_group in optimizer.param_groups:
        #             for param in param_group['params']:
        #                 if any(param is d_ for d_ in params_to_modify):
        #                     # Extract the corresponding rows from m and v
        #                     m = optimizer.state[param]['exp_avg']
        #                     v = optimizer.state[param]['exp_avg_sq']
                            
        #                     # Update the m and v values for the modified columns
        #                     m[modified_columns] = 0  # Reset moving average for modified columns
        #                     v[modified_columns] = 0  # Reset squared moving average for modified columns
            
        #     total_activations = torch.zeros(autoencoder.encoder.shape[0])
        
        
        
    
        # if ((i+2) % save_every ==0): # save periodically but before big changes
        #     model_save_name = cfg.model_name.split("/")[-1]
        #     save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}_ckpt{num_saved_so_far}"  # trim year
    
        #     # Make directory traiend_models if it doesn't exist
        #     import os
        #     if not os.path.exists("trained_models"):
        #         os.makedirs("trained_models")
        #     # Save model
        #     torch.save(autoencoder, f"trained_models/{save_name}.pt")
            
        #     num_saved_so_far += 1
    
        # # Running sparsity check
        # num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
        # if(num_tokens_so_far > 200000):
        #     if(i % 100 == 0):
        #         with torch.no_grad():
        #             sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
        #         if sparsity > target_upper_sparsity:
        #             cfg.l1_alpha *= (1 + adjustment_factor)
        #         elif sparsity < target_lower_sparsity:
        #             cfg.l1_alpha *= (1 - adjustment_factor)
        #         # print(f"Sparsity: {sparsity:.1f} | l1_alpha: {cfg.l1_alpha:.2e}")

## Duplicated training run

# Saving

In [None]:
def model_save(cfg, autoencoder, storage_path, filename, layer):
    model_save_name = cfg.model_name.split("/")[-1]

    # start_time = datetime.now().strftime("%Y%m%d-%H%M%S")

    # save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}_{start_time}"  # trim year
    storage_path_name = "trained_models/" + storage_path + f"/layer_{layer}"
    # Make directory traiend_models if it doesn't exist
    if not os.path.exists(storage_path_name):
        os.makedirs(storage_path_name)
    # Save model
    filename = f"L{layer}_{filename}"
    
    torch.save(autoencoder, f"{storage_path_name}/{filename}.pt")
    print(f"Saved file at: {storage_path_name}/{filename}.pt")