In [1]:
import torch 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
# from standard_metrics import run_with_model_intervention, perplexity_under_reconstruction, mean_nonzero_activations
# Command-line argument parser, kept for reference only (disabled in this notebook;
# the settings are assigned to `cfg` directly below).
# parser = argparse.ArgumentParser()
# parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-70m-deduped")
# parser.add_argument("--layer", type=int, default=4)
# parser.add_argument("--setting", type=str, default="residual")
# parser.add_argument("--l1_alpha", type=float, default=3e-3)
# parser.add_argument("--num_epochs", type=int, default=10)
# parser.add_argument("--model_batch_size", type=int, default=4)
# parser.add_argument("--lr", type=float, default=1e-3)
# parser.add_argument("--kl", type=bool, default=False)
# parser.add_argument("--reconstruction", type=bool, default=False)
# parser.add_argument("--dataset_name", type=str, default="NeelNanda/pile-10k")
# parser.add_argument("--device", type=str, default="cuda:4")

# args = parser.parse_args()
# Experiment configuration. Edit the values here; every later cell reads them from `cfg`.
cfg = dotdict()

# --- Model and the tensors to capture ---
cfg.model_name = "EleutherAI/pythia-70m-deduped"
# cfg.model_name = "usvsnsp/pythia-6.9b-rm-full-hh-rlhf"
# cfg.model_name = "reciprocate/dahoas-gptj-rm-static"
cfg.layers = [1, 2]
cfg.setting = "residual"
# Module-name template; "{layer}" is substituted with each entry of cfg.layers.
cfg.tensor_name = "gpt_neox.layers.{layer}"
# cfg.tensor_name = "transformer.h.{layer}"

# --- Loss and optimisation ---
cfg.l1_alpha = 8e-4  # weight of the L1 term in the total loss
cfg.sparsity = None
cfg.num_epochs = 10
cfg.model_batch_size = 4
cfg.lr = 1e-3
cfg.kl = False
cfg.reconstruction = False

# --- Data and hardware ---
cfg.dataset_name = "NeelNanda/pile-10k"
# cfg.dataset_name = "Skylion007/openwebtext"
# cfg.dataset_name = "Elriggs/openwebtext-100k"
cfg.device = "cuda:0"
# cfg.device = "cpu"
cfg.ratio = 4  # dictionary components per activation dimension

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Expand the module-name template once for every configured layer.
name_template = cfg.tensor_name
tensor_names = [name_template.format(layer=layer_idx) for layer_idx in cfg.layers]

In [3]:
# Load the tokenizer and the causal language model named in the config,
# then move the model to the configured device.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
model = AutoModelForCausalLM.from_pretrained(cfg.model_name)
model = model.to(cfg.device)
model  # last expression of the cell: display the module tree

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attention): GPTNeoXAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=512, out_features=50304, bias=False)
)

In [4]:
# Build the token loader for the configured dataset and report its size in tokens.
# TODO: grab the dataset iteratively?
cfg.max_length = 256
cfg.model_batch_size = 4
token_loader = setup_token_data(cfg, tokenizer, model)

# Token count = sequence length * sequences per batch * number of batches.
tokens_per_batch = cfg.max_length * cfg.model_batch_size
num_tokens = tokens_per_batch * len(token_loader)
print(f"Number of tokens: {num_tokens}")

Found cached dataset parquet (/root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-bf32bd9ecc05cbbb_*_of_00008.arrow


Number of tokens: 15360000


In [5]:
# Push a single tiny input through the model and read the width of the traced
# activation, so the size does not have to be looked up per architecture in a config file.
from baukit import Trace

text = "1"
tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
# Only the first traced tensor is probed here.
with torch.no_grad(), Trace(model, tensor_names[0]) as ret:
    _ = model(tokens)
    representation = ret.output
    # The traced output may be a tuple; in that case its first element is used.
    if isinstance(representation, tuple):
        representation = representation[0]
    activation_size = representation.shape[-1]
print(f"Activation size: {activation_size}")

Activation size: 512


In [6]:
from torch import nn
from torchtyping import TensorType


class TiedSAE(nn.Module):
    """Sparse autoencoder whose encoder and decoder share one weight matrix.

    Parameters
    ----------
    activation_size:
        Width of the activation vectors being reconstructed.
    n_dict_components:
        Number of dictionary rows (features) to learn.

    Attributes
    ----------
    encoder:
        ``nn.Parameter`` of shape ``(n_dict_components, activation_size)``,
        Xavier-uniform initialised.
    encoder_bias:
        ``nn.Parameter`` of shape ``(n_dict_components,)``, initialised to zero.

    Notes
    -----
    The shape annotations below are quoted on purpose: they act as documentation
    only and are never evaluated when the class is defined.
    """

    def __init__(self, activation_size, n_dict_components):
        super().__init__()
        self.encoder = nn.Parameter(torch.empty((n_dict_components, activation_size)))
        nn.init.xavier_uniform_(self.encoder)
        self.encoder_bias = nn.Parameter(torch.zeros((n_dict_components,)))

    def get_learned_dict(self):
        """Return the encoder rows scaled to unit L2 norm.

        The norm is clamped from below at 1e-8 so an all-zero row does not
        cause a division by zero.
        """
        norms = torch.norm(self.encoder, 2, dim=-1)
        return self.encoder / torch.clamp(norms, 1e-8)[:, None]

    def to_device(self, device):
        """Move all parameters to ``device`` in place and return ``self``.

        This delegates to ``nn.Module.to``. Re-assigning ``self.encoder`` with
        the result of ``Parameter.to(device)`` does not work: when the device
        really changes that result is a plain ``Tensor``, and ``nn.Module``
        refuses to store a non-``Parameter`` under a registered parameter name.
        """
        self.to(device)
        return self

    def encode(self, batch: "TensorType['_batch_size', '_activation_size']") -> "TensorType['_batch_size', '_n_dict_components']":
        """Project ``batch`` onto the (un-normalised) encoder rows, add the bias, clamp at zero."""
        c = torch.einsum("nd,bd->bn", self.encoder, batch)
        c = c + self.encoder_bias
        c = torch.clamp(c, min=0.0)
        return c

    def decode(self, code: "TensorType['_batch_size', '_n_dict_components']") -> "TensorType['_batch_size', '_activation_size']":
        """Reconstruct activations as a combination of the unit-norm dictionary rows."""
        learned_dict = self.get_learned_dict()
        x_hat = torch.einsum("nd,bn->bd", learned_dict, code)
        return x_hat

    def forward(self, batch: "TensorType['_batch_size', '_activation_size']") -> "Tuple[TensorType['_batch_size', '_activation_size'], TensorType['_batch_size', '_n_dict_components']]":
        """Encode then decode ``batch``; returns ``(x_hat, c)``: reconstruction and code."""
        c = self.encode(batch)
        x_hat = self.decode(c)
        return x_hat, c

    def n_dict_components(self):
        """Number of dictionary rows, read from the parameter shape directly."""
        return self.encoder.shape[0]

# One tied SAE per traced tensor; each dictionary holds cfg.ratio components
# per activation dimension.
n_dict_components = cfg.ratio * activation_size
all_autoencoders = [
    TiedSAE(activation_size, n_dict_components).to(cfg.device)
    for _ in tensor_names
]

In [36]:
# One Adam optimizer per autoencoder.
# The config cell stores the learning rate as `cfg.lr`; `cfg.learning_rate` is never set.
optimizers = [
    torch.optim.Adam(sae.parameters(), lr=cfg.lr)
    for sae in all_autoencoders
]

In [30]:
# Inspect the parameters registered on the second autoencoder.
[*all_autoencoders[1].parameters()]

[]

In [14]:
# # Initialize new autoencoder (this whole cell is disabled and kept for reference only)
# from autoencoders.learned_dict import TiedSAE
# from torch import nn
# params = dict()
# n_dict_components = activation_size*cfg.ratio
# params["encoder"] = torch.empty((n_dict_components, activation_size), device=cfg.device)
# nn.init.xavier_uniform_(params["encoder"])

# params["encoder_bias"] = torch.empty((n_dict_components,), device=cfg.device)
# nn.init.zeros_(params["encoder_bias"])

# autoencoder = TiedSAE(
#     # n_feats = n_dict_components, 
#     # activation_size=activation_size,
#     encoder=params["encoder"],
#     encoder_bias=params["encoder_bias"],
# )
# autoencoder.to_device(cfg.device)

In [16]:
print("WARNING: Only works on tied SAE")
# Train the autoencoder at index 0: the training cell reads its activations from
# tensor_names[0], so this keeps the dictionary and the traced tensor in step.
# No other live cell defines `autoencoder`, so binding it here is what lets the
# notebook run on a fresh kernel.
autoencoder = all_autoencoders[0]
# Set gradient to true for encoder & bias
autoencoder.encoder.requires_grad = True
autoencoder.encoder_bias.requires_grad = True
# Use the configured learning rate rather than a second hard-coded copy of it.
optimizer = torch.optim.Adam([autoencoder.encoder, autoencoder.encoder_bias], lr=cfg.lr)
# Wandb setup
# Read the API key from a local file; the context manager closes the handle promptly.
with open("secrets.json") as secrets_file:
    secrets = json.load(secrets_file)
wandb.login(key=secrets["wandb_key"])
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
wandb_run_name = f"{cfg.model_name}_{start_time[4:]}_{cfg.sparsity}"  # trim year
print(f"wandb_run_name: {wandb_run_name}")
wandb.init(project="sparse coding", config=dict(cfg), name=wandb_run_name, entity="sparse_coding")



wandb_run_name: reciprocate/dahoas-gptj-rm-static_1023-171456_409


0,1
Dead Features,▁▁▅▅▅▅▅▅█
L1 Loss,█▂▁▁▁▁▁▁▁
Reconstruction Loss,█▁▂▁▁▁▁▁▁
Self Similarity,█▁▇▇▇▇▇▇█
Sparsity,█▁▁▁▁▁▁▁▁
Tokens,▁▂▃▄▅▆▇█▁
Total Loss,█▁▂▁▁▁▁▁▁
l1_alpha,▁▁▁▁▁▁▁▁▁

0,1
Dead Features,9693.0
L1 Loss,0.10886
Reconstruction Loss,0.60398
Self Similarity,1.0
Sparsity,40.69141
Tokens,0.0
Total Loss,0.71283
l1_alpha,0.0008


In [17]:
import numpy as np
import os

# Training loop for a single tied SAE on the activations of tensor_names[0].
# Every 200 batches it logs loss, sparsity, dead-feature count and encoder drift
# to stdout and to wandb.

# Make directory trained_models if it doesn't exist
# (exist_ok avoids the separate existence check and the race that comes with it).
os.makedirs("trained_models", exist_ok=True)
model_save_name = cfg.model_name.split("/")[-1]

num_batch = len(token_loader)
log_space = np.logspace(0, np.log10(num_batch), 11)  # 11 to get 10 intervals
save_batches = [int(x) for x in log_space[1:]]  # Skip the first (0th) interval

# Running per-feature activation mass since the last log step (kept on the CPU).
dead_features = torch.zeros(autoencoder.encoder.shape[0])
# max_num_tokens = 100000000
# Freeze model parameters 
model.eval()
model.requires_grad_(False)
last_encoder = autoencoder.encoder.clone().detach()
for i, batch in enumerate(token_loader):
    tokens = batch["input_ids"].to(cfg.device)
    with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
        with Trace(model, tensor_names[0]) as ret:
            _ = model(tokens)
            representation = ret.output
            if(isinstance(representation, tuple)):
                representation = representation[0]
    # Flatten batch and sequence so every token position is one training row.
    layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")

    c = autoencoder.encode(layer_activations)
    x_hat = autoencoder.decode(c)
    
    reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
    l1_loss = torch.norm(c, 1, dim=-1).mean()
    total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss

    # Detach before accumulating: `c` requires grad, and adding it in place would pull
    # this bookkeeping tensor into the autograd graph and chain graph nodes across steps.
    dead_features += c.detach().sum(dim=0).cpu()
    if (i % 200 == 0): # Check here so first check is model w/o change
        num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
        with torch.no_grad():
            # Cosine similarity between the current encoder rows and those at the last log step.
            self_similarity = torch.cosine_similarity(autoencoder.encoder, last_encoder, dim=-1).mean().cpu().item()
            # Per-feature firing frequency summed over features (taken from the current batch only).
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            # Count number of dead_features are zero
            num_dead_features = (dead_features == 0).sum().item()
        last_encoder = autoencoder.encoder.clone().detach()
        print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss:.2f} | Reconstruction Loss: {reconstruction_loss:.2f} | L1 Loss: {cfg.l1_alpha*l1_loss:.2f} | l1_alpha: {cfg.l1_alpha:.2e} | Tokens: {num_tokens_so_far} | Self Similarity: {self_similarity:.2f}")
        wandb.log({
            'Sparsity': sparsity,
            'Dead Features': num_dead_features,
            'Total Loss': total_loss.item(),
            'Reconstruction Loss': reconstruction_loss.item(),
            'L1 Loss': (cfg.l1_alpha*l1_loss).item(),
            'l1_alpha': cfg.l1_alpha,
            'Tokens': num_tokens_so_far,
            'Self Similarity': self_similarity
        })
        # Start a fresh dead-feature window for the next 200 batches.
        dead_features = torch.zeros(autoencoder.encoder.shape[0])
    # if i in save_batches:
    #     save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}_{i}"  # trim year
    #     torch.save(autoencoder, f"trained_models/{save_name}.pt")
    #     print(f"Saved model to trained_models/{save_name}")
        

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # # Running sparsity check
    # if(num_tokens_so_far > 5000000):
    #     if(i % 200 == 0):
    #         with torch.no_grad():
    #             sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
    #         if sparsity > target_upper_sparsity:
    #             cfg.l1_alpha *= (1 + adjustment_factor)
    #         elif sparsity < target_lower_sparsity:
    #             cfg.l1_alpha *= (1 - adjustment_factor)
            # print(f"Sparsity: {sparsity:.1f} | l1_alpha: {cfg.l1_alpha:.2e}")

Sparsity: 8249.7 | Dead Features: 2 | Total Loss: 16.97 | Reconstruction Loss: 13.40 | L1 Loss: 3.57 | l1_alpha: 8.00e-04 | Tokens: 0 | Self Similarity: 1.00
Saved model to trained_models/dahoas-gptj-rm-static_sp409_r4_transformer.h.4_3
Saved model to trained_models/dahoas-gptj-rm-static_sp409_r4_transformer.h.4_10
Saved model to trained_models/dahoas-gptj-rm-static_sp409_r4_transformer.h.4_32
Saved model to trained_models/dahoas-gptj-rm-static_sp409_r4_transformer.h.4_104
Sparsity: 288.7 | Dead Features: 1 | Total Loss: 0.57 | Reconstruction Loss: 0.30 | L1 Loss: 0.27 | l1_alpha: 8.00e-04 | Tokens: 204800 | Self Similarity: 0.89
Saved model to trained_models/dahoas-gptj-rm-static_sp409_r4_transformer.h.4_332
Sparsity: 144.7 | Dead Features: 5774 | Total Loss: 0.37 | Reconstruction Loss: 0.23 | L1 Loss: 0.13 | l1_alpha: 8.00e-04 | Tokens: 409600 | Self Similarity: 0.99
Sparsity: 103.8 | Dead Features: 6558 | Total Loss: 0.52 | Reconstruction Loss: 0.37 | L1 Loss: 0.15 | l1_alpha: 8.00e

In [1]:
# Finish the Weights & Biases run that was opened by the earlier `wandb.init(...)` call.
wandb.finish()

NameError: name 'wandb' is not defined

In [2]:
# File name encodes the model, the sparsity setting, the dictionary ratio and the traced tensor.
save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}"
save_path = f"trained_models/{save_name}.pt"

# Save model
# NOTE(review): this pickles the whole module object, so loading it later needs the
# TiedSAE class to be importable; consider saving `autoencoder.state_dict()` instead.
torch.save(autoencoder, save_path)

NameError: name 'model_save_name' is not defined