In [17]:
import torch 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt
# from standard_metrics import run_with_model_intervention, perplexity_under_reconstruction, mean_nonzero_activations
# Configuration. This notebook appears to have started as a script driven by
# the argparse parser below (kept commented out for reference); all settings
# now live in `cfg`.
# # make an argument parser directly below
# parser = argparse.ArgumentParser()
# parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-70m-deduped")
# parser.add_argument("--layer", type=int, default=4)
# parser.add_argument("--setting", type=str, default="residual")
# parser.add_argument("--l1_alpha", type=float, default=3e-3)
# parser.add_argument("--num_epochs", type=int, default=10)
# parser.add_argument("--model_batch_size", type=int, default=4)
# parser.add_argument("--lr", type=float, default=1e-3)
# parser.add_argument("--kl", type=bool, default=False)
# parser.add_argument("--reconstruction", type=bool, default=False)
# parser.add_argument("--dataset_name", type=str, default="NeelNanda/pile-10k")
# parser.add_argument("--device", type=str, default="cuda:4")

# args = parser.parse_args()
cfg = dotdict()
# Alternatives: "EleutherAI/pythia-70m-deduped", "usvsnsp/pythia-6.9b-sft"
cfg.model_name="EleutherAI/pythia-70m-deduped"
cfg.target_name="EleutherAI/pythia-70m-deduped"
cfg.layers=[4]
cfg.setting="residual"
cfg.tensor_name="gpt_neox.layers.{layer}"  # formatted with each entry of cfg.layers
cfg.l1_alpha=1e-3  # weight of the L1 penalty in the training loss
cfg.sparsity=None  # target active features per token; None -> 10% of the activation size
cfg.num_epochs=10
cfg.model_batch_size=4
cfg.lr=1e-3
cfg.kl=False
cfg.reconstruction=False
# cfg.dataset_name="NeelNanda/pile-10k"
cfg.dataset_name="Elriggs/openwebtext-100k"
# Fall back to CPU so the notebook still runs on a machine without CUDA.
cfg.device = "cuda:0" if torch.cuda.is_available() else "cpu"
cfg.ratio = 4  # dictionary size = ratio * activation size

# Seed torch's RNG (CPU and CUDA) so weight init and resampling draws are reproducible.
torch.manual_seed(0)

In [18]:
# Resolve the hooked module name for every layer listed in cfg.layers.
tensor_names = list(map(lambda layer_idx: cfg.tensor_name.format(layer=layer_idx), cfg.layers))

In [19]:
# Load in the model
from transformers import AutoModelForCausalLM, AutoTokenizer
# Tokenizer and causal LM come from the same checkpoint; the model is moved to cfg.device.
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
model = AutoModelForCausalLM.from_pretrained(cfg.model_name).to(cfg.device)

In [20]:
# Download (or load from the HF cache) the dataset and build the token loader.
# TODO iteratively grab dataset?
cfg.max_length = 256  # sequence length passed to setup_token_data via cfg
cfg.model_batch_size = 4  # sequences per batch
token_loader = setup_token_data(cfg, tokenizer, model)
# Token count assumes every batch is full (max_length * model_batch_size tokens).
tokens_per_batch = cfg.max_length * cfg.model_batch_size
num_tokens = tokens_per_batch * len(token_loader)
print(f"Number of tokens: {num_tokens}")

Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow


Number of tokens: 112749568


In [21]:
# Run 1 datapoint on model to get the activation size
from baukit import Trace

# Push a one-token probe through the model and read the width of the hooked
# layer's output; the autoencoder must match this activation size.
text = "1"
tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
with torch.no_grad(), Trace(model, tensor_names[0]) as ret:
    _ = model(tokens)
    representation = ret.output
    # The hooked module may return a tuple; its first element holds the hidden states.
    representation = representation[0] if isinstance(representation, tuple) else representation
    activation_size = representation.shape[-1]
print(f"Activation size: {activation_size}")

Activation size: 512


In [30]:
# Initialize New autoencoder
from autoencoders.learned_dict import TiedSAE, UntiedSAE, AnthropicSAE
from torch import nn
# Dictionary with `cfg.ratio` times as many features as the activation width.
n_dict_components = activation_size * cfg.ratio
params = {
    "encoder": torch.empty((n_dict_components, activation_size), device=cfg.device),
    "decoder": torch.empty((n_dict_components, activation_size), device=cfg.device),
    "encoder_bias": torch.empty((n_dict_components,), device=cfg.device),
    "shift_bias": torch.empty((activation_size,), device=cfg.device),
}
# Xavier-init the weight matrices (encoder first, then decoder) and zero the biases.
for weight_key in ("encoder", "decoder"):
    nn.init.xavier_uniform_(params[weight_key])
for bias_key in ("encoder_bias", "shift_bias"):
    nn.init.zeros_(params[bias_key])

# Tied SAE (no decoder bias). params["decoder"] and params["shift_bias"] are
# allocated but not passed in; presumably they are only needed for an untied variant.
autoencoder = TiedSAE(encoder=params["encoder"], encoder_bias=params["encoder_bias"])
autoencoder.to_device(cfg.device)

# Only the encoder weights and bias are trained.
trainable_params = [autoencoder.encoder, autoencoder.encoder_bias]
for trainable in trainable_params:
    trainable.requires_grad = True
optimizer = torch.optim.Adam(trainable_params, lr=cfg.lr)

In [31]:
# Default target: 10% of the activation width active per token.
if cfg.sparsity is None:
    cfg.sparsity = int(activation_size*0.1)
    print(f"Target sparsity: {cfg.sparsity}")

# +/- 5 band around the target, referenced by the l1_alpha controller
# (currently commented out) in the training loop.
sparsity_tolerance = 5.0
target_lower_sparsity, target_upper_sparsity = cfg.sparsity - sparsity_tolerance, cfg.sparsity + sparsity_tolerance
adjustment_factor = 0.1  # multiplicative step applied to l1_alpha; You can set this to whatever you like

In [32]:
# Snapshot of the initial encoder bias (not referenced in the visible cells).
original_bias = autoencoder.encoder_bias.clone().detach()
# Wandb setup: read the API key, closing secrets.json once it is parsed.
with open("secrets.json") as secrets_file:
    secrets = json.load(secrets_file)
wandb.login(key=secrets["wandb_key"])
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
# start_time[4:] drops the 4-digit year, leaving MMDD-HHMMSS.
wandb_run_name = f"{cfg.model_name}_{start_time[4:]}_{cfg.sparsity}"
print(f"wandb_run_name: {wandb_run_name}")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


wandb_run_name: EleutherAI/pythia-70m-deduped_1010-191352_51


In [33]:
# Start the wandb run, logging a plain-dict snapshot of the hyperparameters.
run_config = dict(cfg)
wandb.init(name=wandb_run_name, project="sparse coding", config=run_config)

In [35]:
# Train the tied sparse autoencoder on the hooked layer's activations.
# Every `log_period` batches: log metrics, then reset `dead_features`.
# Every `resample_period` batches (after that batch's optimizer step):
# re-initialise features that never fired since the last resample, then reset
# `total_activations`.
log_period = 500
resample_period = 500
max_resample_tokens = 1000  # the number of token activations that we consider for inserting into the dictionary
max_num_tokens = 100_000_000
n_feats = autoencoder.encoder.shape[0]

# Per-feature activation sums. They stay on the training device (no GPU->CPU
# copy every step) and only receive detached values, so they never join the
# autograd graph.
dead_features = torch.zeros(n_feats, device=cfg.device)
total_activations = torch.zeros(n_feats, device=cfg.device)
saved_activations = []


def _layer_activations(input_ids):
    """Run `model` on `input_ids` and return the hooked layer output as "(b seq) d_model"."""
    with torch.no_grad():  # As long as not doing KL divergence, don't need gradients for model
        with Trace(model, tensor_names[0]) as ret:
            _ = model(input_ids)
            representation = ret.output
            if isinstance(representation, tuple):
                representation = representation[0]
    return rearrange(representation, "b seq d_model -> (b seq) d_model")


def _resample_dead_features(dead_mask, seed):
    """Re-initialise the features selected by the boolean `dead_mask`.

    Candidate inputs come from a loader built with `seed`. Each token is weighted
    by its squared autoencoder loss (reconstruction MSE + l1_alpha * L1 of its
    code). The sampled inputs are unit-normalised and written into the dead
    encoder rows, scaled to 0.2x the mean norm of the alive rows. The dead biases
    are zeroed, and a `decoder` (if the autoencoder has one) receives the unit
    vectors. Adam's moment estimates for every modified row are reset.
    """
    num_dead = int(dead_mask.sum().item())
    resample_loader = setup_token_data(cfg, tokenizer, model, seed=seed)
    candidate_acts, candidate_losses = [], []
    num_resample_data = 0
    with torch.no_grad():  # scoring candidates and editing weights needs no autograd graph
        for resample_batch in resample_loader:
            acts = _layer_activations(resample_batch["input_ids"].to(cfg.device))
            codes = autoencoder.encode(acts)
            recon = autoencoder.decode(codes)
            per_token_loss = (recon - acts).pow(2).mean(dim=-1) + cfg.l1_alpha * torch.norm(codes, 1, dim=-1)
            candidate_acts.append(acts)
            candidate_losses.append(per_token_loss)
            num_resample_data += acts.shape[0]
            if num_resample_data > max_resample_tokens:
                break

        resample_activations = torch.cat(candidate_acts, dim=0)
        probabilities = torch.cat(candidate_losses, dim=0) ** 2
        # Sampling without replacement needs at least `num_dead` nonzero weights.
        need_replacement = num_dead > int((probabilities > 0).sum().item())
        sampled_indices = torch.multinomial(probabilities, num_dead, replacement=need_replacement)
        new_vectors = torch.nn.functional.normalize(resample_activations[sampled_indices], dim=1)
        new_vectors = new_vectors.to(dtype=autoencoder.encoder.dtype)

        # Average encoder norm of alive features (unit norm if none are alive).
        alive_mask = ~dead_mask
        if alive_mask.any():
            avg_norm = autoencoder.encoder.data[alive_mask].norm(dim=-1).mean()
        else:
            avg_norm = 1.0

        autoencoder.encoder.data[dead_mask] = new_vectors * avg_norm * 0.2
        autoencoder.encoder_bias.data[dead_mask] = 0
        params_to_modify = [autoencoder.encoder, autoencoder.encoder_bias]
        if hasattr(autoencoder, 'decoder'):
            autoencoder.decoder.data[dead_mask] = new_vectors
            params_to_modify.append(autoencoder.decoder)

        # Reset Adam's moving averages for the modified rows. Params the optimizer
        # does not manage, or has not stepped yet, have no state and are skipped.
        for param in params_to_modify:
            state = optimizer.state.get(param)
            if state and "exp_avg" in state:
                state["exp_avg"][dead_mask] = 0  # Reset moving average for modified rows
                state["exp_avg_sq"][dead_mask] = 0  # Reset squared moving average for modified rows


# Freeze model parameters
model.eval()
model.requires_grad_(False)
model.to(cfg.device)
last_encoder = autoencoder.encoder.clone().detach()
for i, batch in enumerate(tqdm(token_loader)):
    tokens = batch["input_ids"].to(cfg.device)
    layer_activations = _layer_activations(tokens)
    # saved_activations.append(layer_activations.clone().cpu().detach())

    c = autoencoder.encode(layer_activations)
    x_hat = autoencoder.decode(c)

    reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
    l1_loss = torch.norm(c, 1, dim=-1).mean()
    total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss

    batch_feature_sums = c.detach().sum(dim=0)
    dead_features += batch_feature_sums
    total_activations += batch_feature_sums

    if i % log_period == 0:  # Checked before the step, so the first check is the model w/o change
        num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
        with torch.no_grad():
            # Mean cosine similarity between each encoder row and its value at the previous log.
            self_similarity = torch.cosine_similarity(autoencoder.encoder, last_encoder, dim=-1).mean().item()
            # Average number of nonzero features per token in this batch.
            sparsity = (c != 0).float().mean(dim=0).sum().item()
            # Features whose activations summed to zero since the previous log.
            num_dead_features = (dead_features == 0).sum().item()
        last_encoder = autoencoder.encoder.clone().detach()
        total_loss_value = total_loss.item()
        reconstruction_loss_value = reconstruction_loss.item()
        scaled_l1_value = (cfg.l1_alpha*l1_loss).item()
        print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss_value:.2f} | Reconstruction Loss: {reconstruction_loss_value:.2f} | L1 Loss: {scaled_l1_value:.2f} | l1_alpha: {cfg.l1_alpha:.2e} | Tokens: {num_tokens_so_far} | Self Similarity: {self_similarity:.2f}")
        wandb.log({
            'Sparsity': sparsity,
            'Dead Features': num_dead_features,
            'Total Loss': total_loss_value,
            'Reconstruction Loss': reconstruction_loss_value,
            'L1 Loss': scaled_l1_value,
            'l1_alpha': cfg.l1_alpha,
            'Tokens': num_tokens_so_far,
            'Self Similarity': self_similarity
        })

        dead_features = torch.zeros(n_feats, device=cfg.device)

        if num_tokens_so_far > max_num_tokens:
            print(f"Reached max number of tokens: {max_num_tokens}")
            break

    # Step BEFORE resampling. Editing weights in place between forward and
    # backward would make this batch's gradients inconsistent with the weights.
    # Stepping first also guarantees Adam state exists when it is reset below.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if i % resample_period == 0:
        dead_mask = total_activations == 0
        if dead_mask.any():
            _resample_dead_features(dead_mask, seed=i)
        total_activations = torch.zeros(n_feats, device=cfg.device)

    # # Running sparsity check
    # if(num_tokens_so_far > 500000):
    #     if(i % 200 == 0):
    #         with torch.no_grad():
    #             sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
    #         if sparsity > target_upper_sparsity:
    #             cfg.l1_alpha *= (1 + adjustment_factor)
    #         elif sparsity < target_lower_sparsity:
    #             cfg.l1_alpha *= (1 - adjustment_factor)
    #         # print(f"Sparsity: {sparsity:.1f} | l1_alpha: {cfg.l1_alpha:.2e}")

  0%|          | 0/110107 [00:00<?, ?it/s]

Sparsity: 239.2 | Dead Features: 262 | Total Loss: 0.12 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 0 | Self Similarity: 1.00


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  0%|          | 506/110107 [00:12<38:24, 47.55it/s] 

Sparsity: 161.6 | Dead Features: 0 | Total Loss: 0.13 | Reconstruction Loss: 0.06 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 512000 | Self Similarity: 0.76


  1%|          | 996/110107 [00:22<37:50, 48.05it/s]

Sparsity: 163.9 | Dead Features: 19 | Total Loss: 0.12 | Reconstruction Loss: 0.06 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 1024000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  1%|▏         | 1496/110107 [00:34<37:44, 47.96it/s]  

Sparsity: 157.8 | Dead Features: 30 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 1536000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  2%|▏         | 1996/110107 [00:46<37:27, 48.11it/s]  

Sparsity: 157.3 | Dead Features: 18 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 2048000 | Self Similarity: 0.96


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  2%|▏         | 2496/110107 [00:58<37:12, 48.19it/s]  

Sparsity: 165.3 | Dead Features: 12 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 2560000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  3%|▎         | 2996/110107 [01:11<37:07, 48.10it/s]  

Sparsity: 150.6 | Dead Features: 18 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 3072000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  3%|▎         | 3496/110107 [01:23<36:53, 48.17it/s]  

Sparsity: 164.5 | Dead Features: 23 | Total Loss: 0.12 | Reconstruction Loss: 0.06 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 3584000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  4%|▎         | 3996/110107 [01:35<36:35, 48.33it/s]  

Sparsity: 145.2 | Dead Features: 16 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 4096000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  4%|▍         | 4496/110107 [01:47<36:26, 48.31it/s]  

Sparsity: 153.5 | Dead Features: 19 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 4608000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  5%|▍         | 4996/110107 [01:59<36:17, 48.27it/s]  

Sparsity: 145.7 | Dead Features: 17 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 5120000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  5%|▍         | 5496/110107 [02:11<36:07, 48.26it/s]  

Sparsity: 149.6 | Dead Features: 15 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 5632000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  5%|▌         | 5996/110107 [02:24<36:08, 48.01it/s]  

Sparsity: 147.9 | Dead Features: 13 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 6144000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  6%|▌         | 6496/110107 [02:36<35:55, 48.08it/s]  

Sparsity: 146.1 | Dead Features: 14 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 6656000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  6%|▋         | 6996/110107 [02:48<35:49, 47.98it/s]  

Sparsity: 144.0 | Dead Features: 12 | Total Loss: 0.11 | Reconstruction Loss: 0.06 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 7168000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  7%|▋         | 7496/110107 [03:00<35:32, 48.11it/s]  

Sparsity: 143.0 | Dead Features: 12 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 7680000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  7%|▋         | 7996/110107 [03:12<35:22, 48.11it/s]  

Sparsity: 138.8 | Dead Features: 12 | Total Loss: 0.10 | Reconstruction Loss: 0.05 | L1 Loss: 0.05 | l1_alpha: 1.00e-03 | Tokens: 8192000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  8%|▊         | 8496/110107 [03:24<35:01, 48.35it/s]  

Sparsity: 137.6 | Dead Features: 20 | Total Loss: 0.10 | Reconstruction Loss: 0.05 | L1 Loss: 0.05 | l1_alpha: 1.00e-03 | Tokens: 8704000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  8%|▊         | 8996/110107 [03:36<34:53, 48.29it/s]  

Sparsity: 135.4 | Dead Features: 2 | Total Loss: 0.10 | Reconstruction Loss: 0.05 | L1 Loss: 0.05 | l1_alpha: 1.00e-03 | Tokens: 9216000 | Self Similarity: 0.96


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  9%|▊         | 9496/110107 [03:48<34:41, 48.34it/s]  

Sparsity: 145.8 | Dead Features: 6 | Total Loss: 0.11 | Reconstruction Loss: 0.06 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 9728000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
  9%|▉         | 9996/110107 [04:00<34:31, 48.33it/s]  

Sparsity: 132.6 | Dead Features: 9 | Total Loss: 0.10 | Reconstruction Loss: 0.05 | L1 Loss: 0.05 | l1_alpha: 1.00e-03 | Tokens: 10240000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
 10%|▉         | 10496/110107 [04:13<34:21, 48.32it/s]  

Sparsity: 141.5 | Dead Features: 4 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 10752000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
 10%|▉         | 10996/110107 [04:25<34:13, 48.27it/s]  

Sparsity: 137.6 | Dead Features: 6 | Total Loss: 0.10 | Reconstruction Loss: 0.05 | L1 Loss: 0.05 | l1_alpha: 1.00e-03 | Tokens: 11264000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
 10%|█         | 11496/110107 [04:37<34:06, 48.19it/s]  

Sparsity: 139.3 | Dead Features: 5 | Total Loss: 0.11 | Reconstruction Loss: 0.05 | L1 Loss: 0.06 | l1_alpha: 1.00e-03 | Tokens: 11776000 | Self Similarity: 0.97


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8401ec7d4dbd84d2_*_of_00008.arrow
 11%|█         | 11980/110107 [04:49<39:30, 41.40it/s]  


KeyboardInterrupt: 

In [27]:
import os

# File name encodes model, target sparsity, dictionary ratio and hooked module,
# e.g. pythia-70m-deduped_sp51_r4_gpt_neox.layers.4
model_save_name = cfg.model_name.split("/")[-1]
save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}"

# Make directory trained_models if it doesn't exist, then save the whole
# autoencoder object with torch.save.
os.makedirs("trained_models", exist_ok=True)
torch.save(autoencoder, f"trained_models/{save_name}.pt")

# if not os.path.exists("activations"):
#     os.makedirs("activations")
# # Save activations
# torch.save(saved_activations[:-1], f"activations/{save_name}.pt")

In [28]:
# Mark the current wandb run as finished (its run history/summary is shown below).
wandb.finish()

0,1
Dead Features,▁▅█▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
L1 Loss,█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▁▃▂▂▃▂▂▂▂▂▂▂▁▁▂▁▂
Reconstruction Loss,█▂▁▂▂▁▂▄▂▃▂▂▂▂▂▃▃▂▁▂▁▂▂▂▂▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂
Self Similarity,▁▆▇▇███▆████████▇▇██████████████████████
Sparsity,█▅▄▃▂▂▂▂▁▂▂▂▂▂▂▂▁▁▂▂▂▂▁▁▂▁▁▂▂▂▁▁▁▂▂▁▁▁▁▂
Tokens,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Total Loss,█▃▂▂▁▁▁▃▁▂▂▁▁▂▂▂▂▁▁▂▁▂▁▁▂▂▂▂▂▁▁▁▂▂▁▁▁▁▁▂
l1_alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Dead Features,0.0
L1 Loss,0.05351
Reconstruction Loss,0.04791
Self Similarity,0.97119
Sparsity,128.99707
Tokens,100352000.0
Total Loss,0.10142
l1_alpha,0.001


In [None]:
# Transfer configuration: point cfg at the SFT (base) / PPO (target) pair.
# NOTE(review): `autoencoder` was trained on the earlier cfg.model_name's
# activations; confirm the target model's layer width matches before transferring.
# cfg.model_name="EleutherAI/pythia-70m-deduped"
cfg.model_name="usvsnsp/pythia-6.9b-sft"
cfg.target_name="usvsnsp/pythia-6.9b-ppo"
cfg.layers=[4]
cfg.setting="residual"
cfg.tensor_name="gpt_neox.layers.{layer}"
cfg.l1_alpha=1e-3
# cfg.sparsity is intentionally NOT reset to None. Nothing after this cell
# re-resolves it, so a reset would make the transfer run name end in "_None"
# instead of reporting the trained dictionary's target sparsity.
cfg.num_epochs=10
cfg.model_batch_size=4
cfg.lr=1e-3
cfg.kl=False
cfg.reconstruction=False
# cfg.dataset_name="NeelNanda/pile-10k"
cfg.dataset_name="Elriggs/openwebtext-100k"
# Fall back to CPU so the notebook still runs on a machine without CUDA.
cfg.device = "cuda:0" if torch.cuda.is_available() else "cpu"
cfg.ratio = 4

# Move the base model to CPU before loading the target model onto cfg.device.
model = model.cpu()
target_model = AutoModelForCausalLM.from_pretrained(cfg.target_name)
target_model = target_model.to(cfg.device)

In [None]:
# Initialize the transfer autoencoder: reuse the trained encoder (frozen) and learn a new
# decoder that maps the base codes onto the target model's activations.
from autoencoders.learned_dict import TiedSAE, UntiedSAE
from torch import nn

# detach().clone() produces independent leaf tensors:
#  - encoder/bias no longer alias `autoencoder`'s tensors, so toggling requires_grad below
#    does not silently freeze the original autoencoder as well;
#  - the decoder is a leaf, which is required both to set requires_grad on it and to pass
#    it to the optimizer (a plain .clone() of a tensor that requires grad is a non-leaf).
transfer_autoencoder = UntiedSAE(
    encoder=autoencoder.encoder.detach().clone(),
    encoder_bias=autoencoder.encoder_bias.detach().clone(),
    # Start the decoder from the encoder weights (tied-style initialization).
    decoder=autoencoder.encoder.detach().clone(),
)
transfer_autoencoder.to_device(cfg.device)

# Only the decoder is trained on the transfer task.
transfer_autoencoder.encoder.requires_grad_(False)
transfer_autoencoder.encoder_bias.requires_grad_(False)
transfer_autoencoder.decoder.requires_grad_(True)
optimizer = torch.optim.Adam([transfer_autoencoder.decoder], lr=cfg.lr)


In [None]:
# Wandb setup for the transfer run.
# Read the API key with a context manager so the file handle is closed.
with open("secrets.json") as secrets_file:
    secrets = json.load(secrets_file)
wandb.login(key=secrets["wandb_key"])
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
wandb_run_name = f"{cfg.target_name}_transfer_{start_time[4:]}_{cfg.sparsity}"  # trim year
print(f"wandb_run_name: {wandb_run_name}")
# The previous run was closed with wandb.finish(), so start a new run before the training
# loop calls wandb.log().
# NOTE(review): pass project=... to log into the same project as the base run.
wandb.init(name=wandb_run_name, config=dict(cfg))

In [None]:
# Training transfer autoencoder.
# The encoder and its bias are frozen; only the decoder is optimized, so that codes computed
# from the saved base-model activations reconstruct the target model's activations.
n_feats = transfer_autoencoder.encoder.shape[0]
# Accumulate per-feature code mass on-device (avoids a GPU->CPU copy every step).
dead_features = torch.zeros(n_feats, device=cfg.device)
max_num_tokens = 100000000
# Freeze model parameters
model = model.cpu()
target_model = target_model.to(cfg.device)
target_model.eval()
target_model.requires_grad_(False)
# The encoder is frozen, so "Self Similarity" tracks drift of the trainable decoder instead
# (tracking the encoder would always report 1.0).
last_decoder = transfer_autoencoder.decoder.detach().clone()
# NOTE(review): token batches are paired with saved_activations purely by position. This is only
# valid if token_loader yields the same batches, in the same order, as when saved_activations
# was recorded (e.g. no shuffling) — confirm.
num_batches = min(len(token_loader), len(saved_activations))
for i, (batch, base_activations) in enumerate(tqdm(zip(token_loader, saved_activations), total=num_batches)):
    tokens = batch["input_ids"].to(cfg.device)
    with torch.no_grad():  # As long as not doing KL divergence, don't need gradients for model
        with Trace(target_model, tensor_names[0]) as ret:
            _ = target_model(tokens)
            representation = ret.output
            if isinstance(representation, tuple):
                representation = representation[0]
    # Reconstruction target: target-model activations flattened over (batch, seq).
    layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")

    c = transfer_autoencoder.encode(base_activations.to(cfg.device))
    x_hat = transfer_autoencoder.decode(c)

    reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
    l1_loss = torch.norm(c, 1, dim=-1).mean()
    total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss

    dead_features += c.detach().sum(dim=0)
    if i % 500 == 0:  # Check here so first check is model w/o change
        with torch.no_grad():
            # Mean per-feature cosine similarity between the current decoder and the decoder
            # at the previous check.
            self_similarity = torch.cosine_similarity(transfer_autoencoder.decoder, last_decoder, dim=-1).mean().cpu().item()
            last_decoder = transfer_autoencoder.decoder.detach().clone()
            # Mean number of non-zero features per row of c.
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            # Features whose summed code is zero since the last check
            # (i.e. never fired, assuming codes are non-negative).
            num_dead_features = (dead_features == 0).sum().item()
        num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
        print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss:.2f} | Reconstruction Loss: {reconstruction_loss:.2f} | L1 Loss: {cfg.l1_alpha*l1_loss:.2f} | l1_alpha: {cfg.l1_alpha:.2e} | Tokens: {num_tokens_so_far} | Self Similarity: {self_similarity:.2f}")
        wandb.log({
            'Sparsity': sparsity,
            'Dead Features': num_dead_features,
            'Total Loss': total_loss.item(),
            'Reconstruction Loss': reconstruction_loss.item(),
            'L1 Loss': (cfg.l1_alpha*l1_loss).item(),
            'l1_alpha': cfg.l1_alpha,
            'Tokens': num_tokens_so_far,
            'Self Similarity': self_similarity
        })
        dead_features = torch.zeros_like(dead_features)

        if num_tokens_so_far > max_num_tokens:
            print(f"Reached max number of tokens: {max_num_tokens}")
            break

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Running sparsity check (adaptive l1_alpha).
    # NOTE(review): target_upper_sparsity / target_lower_sparsity / adjustment_factor are defined
    # in an earlier cell. If encode() depends only on the frozen encoder and bias (presumably),
    # the L1 term has no gradient w.r.t. the decoder, so adjusting l1_alpha only changes the
    # logged Total Loss — confirm this controller is still wanted here.
    if num_tokens_so_far > 500000:
        if i % 200 == 0:
            with torch.no_grad():
                sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            if sparsity > target_upper_sparsity:
                cfg.l1_alpha *= (1 + adjustment_factor)
            elif sparsity < target_lower_sparsity:
                cfg.l1_alpha *= (1 - adjustment_factor)

In [None]:
# Persist the transfer autoencoder. The file name encodes the target model, the sparsity
# setting, the dictionary ratio and the hooked tensor name.
model_save_name = cfg.target_name.split("/")[-1]
save_name = f"{model_save_name}_transfer_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}"

import os
# exist_ok=True makes the cell idempotent (safe to re-run when the directory exists).
os.makedirs("trained_models", exist_ok=True)
# torch.save pickles the whole object, so loading it later needs the class importable.
torch.save(transfer_autoencoder, f"trained_models/{save_name}.pt")

In [None]:
# Close the wandb run used for the transfer training.
wandb.finish()

In [None]:
# Display the final config; note the training loop may have adapted cfg.l1_alpha in place.
cfg

{'model_name': 'EleutherAI/pythia-70m-deduped',
 'layers': [4],
 'setting': 'residual',
 'tensor_name': 'gpt_neox.layers.{layer}',
 'l1_alpha': 0.0020591228579666505,
 'sparsity': 51,
 'num_epochs': 10,
 'model_batch_size': 4,
 'lr': 0.001,
 'kl': False,
 'reconstruction': False,
 'dataset_name': 'NeelNanda/pile-10k',
 'device': 'cuda:0',
 'ratio': 4,
 'max_length': 256}