In [46]:
import torch 
from transformer_lens import HookedTransformer
import numpy as np 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
from standard_metrics import run_with_model_intervention, perplexity_under_reconstruction, mean_nonzero_activations
# Experiment configuration.
# The argparse block below is kept (commented out) for reference only; in this notebook the
# settings are assigned directly on `cfg`. Note its defaults differ from the values used here.
# # make an argument parser directly below
# parser = argparse.ArgumentParser()
# parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-70m-deduped")
# parser.add_argument("--layer", type=int, default=4)
# parser.add_argument("--setting", type=str, default="residual")
# parser.add_argument("--l1_alpha", type=float, default=3e-3)
# parser.add_argument("--num_epochs", type=int, default=10)
# parser.add_argument("--model_batch_size", type=int, default=4)
# parser.add_argument("--lr", type=float, default=1e-3)
# parser.add_argument("--kl", type=bool, default=False)
# parser.add_argument("--reconstruction", type=bool, default=False)
# parser.add_argument("--dataset_name", type=str, default="NeelNanda/pile-10k")
# parser.add_argument("--device", type=str, default="cuda:4")

# args = parser.parse_args()
cfg = dotdict()  # project helper from utils; used here as an attribute-style settings container
cfg.model_name="EleutherAI/pythia-70m-deduped"  # passed to from_pretrained() for both model and tokenizer
cfg.layers=[4]  # each entry is substituted into cfg.tensor_name
cfg.setting="residual"  # not read by the cells visible in this notebook
cfg.tensor_name="gpt_neox.layers.{layer}"  # template for the module name handed to baukit's Trace
cfg.l1_alpha=1e-3  # weight of the L1 term in the loss; the training loop rescales it in place
cfg.sparsity=None  # None -> filled in later as int(activation_size*0.1)
cfg.num_epochs=10  # not read by the cells visible in this notebook
cfg.model_batch_size=4  # assigned again right before setup_token_data is called
cfg.lr=1e-3  # intended optimizer learning rate
cfg.kl=False  # not read by the cells visible in this notebook
cfg.reconstruction=False  # not read by the cells visible in this notebook
cfg.dataset_name="NeelNanda/pile-10k"  # presumably consumed by setup_token_data via cfg — TODO confirm
# cfg.device="cuda:0"
cfg.device="cpu"  # device for the model, the input tokens and the autoencoder tensors

In [2]:
# One hook-point name per configured layer, built from the cfg.tensor_name template
tensor_names = [
    cfg.tensor_name.format(layer=layer_idx)
    for layer_idx in cfg.layers
]

In [7]:
# Fetch the pretrained causal LM and its tokenizer, then place the model on the configured device
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(cfg.model_name)
model = model.to(cfg.device)
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)

In [12]:
# Download the dataset
# TODO iteratively grab dataset?
# Both settings are written onto cfg before the call because setup_token_data receives cfg.
cfg.max_length = 256  # presumably the per-sample token length used by setup_token_data — TODO confirm
cfg.model_batch_size = 4  # same value as in the config cell; re-assigned here
# token_loader is iterated by the training loop, which indexes each batch with "input_ids".
token_loader = setup_token_data(cfg, tokenizer, model)

Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-bf32bd9ecc05cbbb_*_of_00008.arrow


In [38]:
# Push a single dummy input through the model and read the traced module's output
# to find out how wide the activations are.
from baukit import Trace

text = "1"
tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
with torch.no_grad(), Trace(model, tensor_names[0]) as ret:
    _ = model(tokens)
    representation = ret.output
    # The traced output may be a tuple; in that case the first element is used
    representation = representation[0] if isinstance(representation, tuple) else representation
    activation_size = representation.shape[-1]
print(f"Activation size: {activation_size}")

Activation size: 512


In [42]:
# Create a new TiedSAE whose dictionary has `ratio` times as many rows as the activation width
from autoencoders.learned_dict import TiedSAE
import torch.nn as nn

ratio = 4
n_dict_components = ratio * activation_size

# The in-place initialisers return the tensor they fill, so the dict can be built in one go:
# Xavier-uniform encoder weights and an all-zero encoder bias.
params = {
    "encoder": nn.init.xavier_uniform_(
        torch.empty((n_dict_components, activation_size), device=cfg.device)
    ),
    "encoder_bias": nn.init.zeros_(
        torch.empty((n_dict_components,), device=cfg.device)
    ),
}

autoencoder = TiedSAE(encoder=params["encoder"], encoder_bias=params["encoder_bias"])
autoencoder.to_device(cfg.device)


In [47]:
# An unset target sparsity falls back to 10% of the activation width; an explicit value is kept
cfg.sparsity = int(0.1 * activation_size) if cfg.sparsity is None else cfg.sparsity

In [55]:
# Adam over the autoencoder's encoder weight and bias.
# The learning rate is taken from cfg.lr so the config (which is also what gets logged to wandb)
# is the single source of truth instead of a second hard-coded 1e-3.
optimizer = torch.optim.Adam([autoencoder.encoder, autoencoder.encoder_bias], lr=cfg.lr)

# Wandb setup
# The API key is read from a local secrets.json; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open("secrets.json") as secrets_file:
    secrets = json.load(secrets_file)
wandb.login(key=secrets["wandb_key"])
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
wandb_run_name = f"{cfg.model_name}_{start_time[4:]}_{cfg.sparsity}"  # trim year
print(f"wandb_run_name: {wandb_run_name}")
wandb.init(project="sparse coding", config=dict(cfg), name=wandb_run_name, entity="sparse_coding")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/mchorse/.netrc


wandb_run_name: EleutherAI/pythia-70m-deduped_0929-205548_51


In [56]:
# Train the autoencoder on activations captured from the traced module.
# Per batch: loss = mean squared reconstruction error + cfg.l1_alpha * mean L1 norm of the codes.
# Every 500 batches the metrics are printed and logged to wandb; every 10 batches cfg.l1_alpha
# is nudged up or down depending on the measured sparsity.

# The language model is run under no_grad below, so backward() can only reach the autoencoder
# if its tensors are marked as requiring grad. They are created with plain torch.empty and
# TiedSAE is defined outside this notebook, so enable it explicitly (a no-op if already set).
autoencoder.encoder.requires_grad_(True)
autoencoder.encoder_bias.requires_grad_(True)

# Per-feature sum of codes since the last log; a zero entry is counted as a "dead" feature.
# NOTE(review): summing assumes the codes are non-negative so values cannot cancel — confirm
# against TiedSAE.encode.
dead_features = torch.zeros(autoencoder.encoder.shape[0])
for i, batch in enumerate(token_loader):
    if i > 250000:
        break
    tokens = batch["input_ids"].to(cfg.device)
    # Only the autoencoder is optimised, so no autograd graph is needed through the language
    # model (same as the activation-size probe cell). This avoids backpropagating through the
    # LM on every step.
    with torch.no_grad():
        with Trace(model, tensor_names[0]) as ret:
            _ = model(tokens)
            representation = ret.output
            # The traced output may be a tuple; the first element is used in that case
            if isinstance(representation, tuple):
                representation = representation[0]
    layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")

    c = autoencoder.encode(layer_activations)
    x_hat = autoencoder.decode(c)

    reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
    l1_loss = torch.norm(c, 1, dim=-1).mean()
    total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss

    # Detach before accumulating: otherwise the running statistic itself becomes part of the
    # autograd graph and keeps every batch's graph nodes alive until the next reset.
    dead_features += c.detach().sum(dim=0).cpu()
    if i % 500 == 0:  # Check here so first check is model w/o change
        with torch.no_grad():
            # Fraction of rows on which each feature is non-zero, summed over features
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            # Count the features whose accumulated sum is exactly zero
            num_dead_features = (dead_features == 0).sum().item()
        print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss:.2f} | Reconstruction Loss: {reconstruction_loss:.2f} | L1 Loss: {cfg.l1_alpha*l1_loss:.2f}")
        wandb.log({
            'Sparsity': sparsity,
            'Dead Features': num_dead_features,
            'Total Loss': total_loss.item(),
            'Reconstruction Loss': reconstruction_loss.item(),
            'L1 Loss': (cfg.l1_alpha*l1_loss).item(),
        })
        # Start a fresh accumulation window after each log
        dead_features = torch.zeros(autoencoder.encoder.shape[0])

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Running sparsity check
    if i % 10 == 0:
        with torch.no_grad():
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
        if i == 0:
            # NOTE(review): the target band is centred on the sparsity measured on the first
            # batch, not on cfg.sparsity — confirm this is intended.
            target_lower_sparsity = sparsity - 5.0
            target_upper_sparsity = sparsity + 5.0
            adjustment_factor = 0.1  # You can set this to whatever you like
        # Too many active features -> strengthen the L1 penalty; too few -> relax it.
        # This mutates cfg.l1_alpha in place, so re-running the cell continues from the last value.
        if sparsity > target_upper_sparsity:
            cfg.l1_alpha *= (1 + adjustment_factor)
        elif sparsity < target_lower_sparsity:
            cfg.l1_alpha *= (1 - adjustment_factor)
        # print(f"Sparsity: {sparsity:.1f} | l1_alpha: {cfg.l1_alpha:.2e}")


Sparsity: 1029.8 | Dead Features: 0 | Total Loss: 0.93 | Reconstruction Loss: 0.56 | L1 Loss: 0.37


KeyboardInterrupt: 

In [None]:
# Finish the active wandb run (the one started by wandb.init in the setup cell)
wandb.finish()


In [53]:
# Save the trained autoencoder to trained_models/<save_name>.pt
import os

model_save_name = cfg.model_name.split("/")[-1]  # keep only the part after the last "/"
# File name encodes model, target sparsity, dictionary ratio and the traced module name
save_name = f"{model_save_name}_sp{cfg.sparsity}_r{ratio}_{tensor_names[0]}"

# Make directory trained_models if it doesn't exist.
# exist_ok=True replaces the separate exists() check, so there is no check-then-create race.
os.makedirs("trained_models", exist_ok=True)
# Save model.
# NOTE: this serialises the whole autoencoder object (torch.save uses pickle), not a state dict,
# so loading the file needs the autoencoder's class to be importable.
torch.save(autoencoder, f"trained_models/{save_name}.pt")

In [50]:
# Display the current configuration as the cell output
cfg

{'model_name': 'EleutherAI/pythia-70m-deduped',
 'layers': [4],
 'setting': 'residual',
 'tensor_name': 'gpt_neox.layers.{layer}',
 'l1_alpha': 0.001,
 'sparsity': 51,
 'num_epochs': 10,
 'model_batch_size': 4,
 'lr': 0.001,
 'kl': False,
 'reconstruction': False,
 'dataset_name': 'NeelNanda/pile-10k',
 'device': 'cpu'}