In [1]:
import torch 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt

# Experiment configuration. Every setting lives on one `dotdict` (imported from utils);
# the whole object is later passed to setup_token_data and, as dict(cfg), to wandb.init.
cfg = dotdict()
# models: "EleutherAI/pythia-70m-deduped", "usvsnsp/pythia-6.9b-ppo", "lomahony/eleuther-pythia6.9b-hh-sft"
# cfg.model_name="lomahony/eleuther-pythia6.9b-hh-sft"
cfg.model_name="EleutherAI/pythia-6.9b-deduped"  # Hugging Face id used to load both the model and the tokenizer
cfg.target_name="usvsnsp/pythia-6.9b-ppo"  # NOTE(review): not read directly in the visible cells
cfg.layers=[10]  # layer indices substituted into cfg.tensor_name
cfg.setting="residual"  # NOTE(review): not read directly in the visible cells
cfg.tensor_name="gpt_neox.layers.{layer}"  # template for the module name handed to baukit's Trace
original_l1_alpha = 8e-4  # starting L1 coefficient, kept separately because cfg.l1_alpha is rescaled during training
cfg.l1_alpha=original_l1_alpha  # weight of the L1 term in the training loss
cfg.sparsity=None  # target sparsity; None means "derive it from activation_size" in a later cell
cfg.num_epochs=10  # NOTE(review): not read directly in the visible cells
cfg.model_batch_size=8  # used when estimating token counts; presumably also the loader batch size - confirm in setup_token_data
cfg.lr=1e-3  # learning rate given to the Adam optimizer
cfg.kl=False  # NOTE(review): not read directly in the visible cells
cfg.reconstruction=False  # NOTE(review): not read directly in the visible cells
# cfg.dataset_name="NeelNanda/pile-10k"
cfg.dataset_name="Elriggs/openwebtext-100k"  # presumably consumed by setup_token_data - confirm
cfg.device="cuda:0"  # device for the model, the probe tokens and the autoencoder tensors
cfg.ratio = 4  # dictionary size is activation_size * ratio
cfg.seed = 0  # forwarded to setup_token_data as its seed
# cfg.device="cpu"

In [45]:
# One module name per configured layer, produced by filling the {layer} slot of cfg.tensor_name.
tensor_names = list(map(lambda layer_idx: cfg.tensor_name.format(layer=layer_idx), cfg.layers))

In [2]:
# Fetch the causal LM named in cfg.model_name, place it on cfg.device,
# and load the tokenizer published under the same name.
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(cfg.model_name).to(cfg.device)
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)

In [None]:
# Build the token loader (the recorded output shows the dataset being served from the local cache)
# TODO iteratively grab dataset?
cfg.max_length = 256
token_loader = setup_token_data(cfg, tokenizer, model, seed=cfg.seed)
# Estimate only: assumes every batch holds model_batch_size sequences of max_length tokens.
num_tokens = len(token_loader) * cfg.model_batch_size * cfg.max_length
print("Number of tokens: {}".format(num_tokens))

Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-b71791e158b5e518_*_of_00008.arrow


Number of tokens: 112750592


In [None]:
# Push a one-token prompt through the model while tracing the first hooked module,
# purely to read off the width (last dimension) of its output.
from baukit import Trace

text = "1"
tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
with torch.no_grad(), Trace(model, tensor_names[0]) as ret:
    _ = model(tokens)
    representation = ret.output
    # When the traced output is a tuple, its first element is used as the activation tensor.
    representation = representation[0] if isinstance(representation, tuple) else representation
    activation_size = representation.shape[-1]
print("Activation size: {}".format(activation_size))

Activation size: 4096


In [None]:
# Initialize a new sparse autoencoder and its optimizer.
from autoencoders.learned_dict import TiedSAE, UntiedSAE, AnthropicSAE
from torch import nn

# Seed torch's RNG from the config so the Xavier draws below are repeatable across runs;
# previously cfg.seed only reached the data loader and the initial dictionary was unseeded.
torch.manual_seed(cfg.seed)

# The dictionary has cfg.ratio times as many components as the activation has dimensions.
n_dict_components = activation_size*cfg.ratio

params = dict()
# Encoder and decoder are both (n_dict_components, activation_size), Xavier-uniform initialised.
params["encoder"] = torch.empty((n_dict_components, activation_size), device=cfg.device)
nn.init.xavier_uniform_(params["encoder"])

params["decoder"] = torch.empty((n_dict_components, activation_size), device=cfg.device)
nn.init.xavier_uniform_(params["decoder"])

# Both bias vectors start at zero.
params["encoder_bias"] = torch.zeros((n_dict_components,), device=cfg.device)
params["shift_bias"] = torch.zeros((activation_size,), device=cfg.device)

autoencoder = AnthropicSAE(  # TiedSAE, UntiedSAE, AnthropicSAE
    encoder=params["encoder"],
    encoder_bias=params["encoder_bias"],
    decoder=params["decoder"],
    shift_bias=params["shift_bias"],
)
autoencoder.to_device(cfg.device)
# NOTE(review): set_grad() is defined in autoencoders.learned_dict; presumably it turns on
# requires_grad for the four tensors optimised below - confirm.
autoencoder.set_grad()

# Adam over the four autoencoder tensors; the language model itself is not optimised.
optimizer = torch.optim.Adam(
    [
        autoencoder.encoder,
        autoencoder.encoder_bias,
        autoencoder.decoder,
        autoencoder.shift_bias,
    ], lr=cfg.lr)

In [None]:
# When the config leaves the target sparsity unset, default it to 5% of activation_size
if cfg.sparsity is None:
    cfg.sparsity = int(0.05 * activation_size)
    print("Target sparsity: {}".format(cfg.sparsity))

# Step size used when cfg.l1_alpha is rescaled in the training loop
adjustment_factor = 0.1  # You can set this to whatever you like
# Tolerance band of +/-10% around the target; l1_alpha is only adjusted outside it
target_lower_sparsity, target_upper_sparsity = cfg.sparsity * 0.9, cfg.sparsity * 1.1

Target sparsity: 204


In [None]:
# Detached copy of the encoder bias as it stands before training starts
original_bias = autoencoder.encoder_bias.clone().detach()
# Wandb setup
# Read the API key from a local secrets.json; the context manager closes the file handle,
# which the previous json.load(open(...)) form left open.
with open("secrets.json") as secrets_file:
    secrets = json.load(secrets_file)
wandb.login(key=secrets["wandb_key"])
# Run name: model id, timestamp without the four-digit year, and the target sparsity
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
wandb_run_name = f"{cfg.model_name}_{start_time[4:]}_{cfg.sparsity}"  # trim year
print(f"wandb_run_name: {wandb_run_name}")



wandb_run_name: lomahony/eleuther-pythia6.9b-hh-sft_1012-231100_204


In [None]:
# Open the Weights & Biases run, recording a plain-dict copy of the config
wandb.init(
    project="sparse coding",
    name=wandb_run_name,
    config=dict(cfg),
)

0,1
Dead Features,▂▁▇█████▇█▇▁
L1 Loss,▁▃▂▁▁▁▂▂▂▂▂█
Reconstruction Loss,▁▂▁▁▁▁▁▁▁▁▁█
Self Similarity,█▄▆███▇████▁
Sparsity,▆▅▃▂▂▁▂▂▁▁▁█
Tokens,▁▂▂▃▄▄▅▅▆▇▇█
Total Loss,▁▂▁▁▁▁▁▁▁▁▁█
l1_alpha,▁▁▁▁▁▁▄▄▄▄▄█

0,1
Dead Features,834.0
L1 Loss,40.15882
Reconstruction Loss,258842.82812
Self Similarity,0.63484
Sparsity,10146.43848
Tokens,450560.0
Total Loss,258882.98438
l1_alpha,0.00012


In [None]:
# Train the sparse autoencoder on activations traced from the frozen language model.
#
# Per-feature bookkeeping (one entry per encoder row, kept on the CPU):
#   time_since_activation - consecutive batches in which the feature's summed code was zero
#   total_activations     - running sum of each feature's code values
time_since_activation = torch.zeros(autoencoder.encoder.shape[0])
total_activations = torch.zeros(autoencoder.encoder.shape[0])
max_num_tokens = 30_000_000
save_every = 5_000_000  # NOTE(review): not read anywhere in this cell
num_saved_so_far = 0
# Freeze model parameters 
model.eval()
model.requires_grad_(False)
model.to(cfg.device)
# Encoder snapshot used to measure how far the weights move between log points
last_encoder = autoencoder.encoder.clone().detach()
for i, batch in enumerate(tqdm(token_loader)):
    tokens = batch["input_ids"].to(cfg.device)
    with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
        with Trace(model, tensor_names[0]) as ret:
            _ = model(tokens)
            representation = ret.output
            if(isinstance(representation, tuple)):
                representation = representation[0]
    # Flatten batch and sequence so every row is one token position's activation
    layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")
    # activation_saver.save_batch(layer_activations.clone().cpu().detach())

    c = autoencoder.encode(layer_activations)
    x_hat = autoencoder.decode(c)
    
    # Loss = mean squared reconstruction error + l1_alpha * mean per-row L1 norm of the codes
    reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
    l1_loss = torch.norm(c, 1, dim=-1).mean()
    total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss

    # Per-feature sums for the bookkeeping tensors. They are taken from the detached codes:
    # adding a grad-tracking tensor in place would make total_activations part of the
    # autograd graph and chain every step's graph nodes together. Computing the sum once
    # also avoids a second GPU->CPU transfer.
    feature_totals = c.detach().sum(dim=0).cpu()
    time_since_activation += 1
    # Reset the counter to zero for every feature whose summed code is non-zero in this batch
    time_since_activation = time_since_activation * (feature_totals == 0)
    total_activations += feature_totals
    if ((i+1) % 10 == 0): # Log before this step's optimizer update, so metrics match the weights that produced them
        # self_similarity = torch.cosine_similarity(c, last_encoder, dim=-1).mean().cpu().item()
        # Above is wrong, should be similarity between encoder and last encoder
        self_similarity = torch.cosine_similarity(autoencoder.encoder, last_encoder, dim=-1).mean().cpu().item()
        last_encoder = autoencoder.encoder.clone().detach()
        # Counts the i batches before this one; assumes full batches of max_length tokens
        num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
        with torch.no_grad():
            # Fraction of rows in which each feature is non-zero, summed over features
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            # Features whose inactivity counter has reached min(i, 200) batches count as dead
            num_dead_features = (time_since_activation >= min(i, 200)).sum().item()
        print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss:.2f} | Reconstruction Loss: {reconstruction_loss:.2f} | L1 Loss: {cfg.l1_alpha*l1_loss:.2f} | l1_alpha: {cfg.l1_alpha:.2e} | Tokens: {num_tokens_so_far} | Self Similarity: {self_similarity:.2f}")
        wandb.log({
            'Sparsity': sparsity,
            'Dead Features': num_dead_features,
            'Total Loss': total_loss.item(),
            'Reconstruction Loss': reconstruction_loss.item(),
            'L1 Loss': (cfg.l1_alpha*l1_loss).item(),
            'l1_alpha': cfg.l1_alpha,
            'Tokens': num_tokens_so_far,
            'Self Similarity': self_similarity
        })
        
        # NOTE(review): not read anywhere in this cell; kept in case a later cell uses it
        dead_features = torch.zeros(autoencoder.encoder.shape[0])
        
        # The token budget is only checked at log points; stop before updating on this batch
        if(num_tokens_so_far > max_num_tokens):
            print(f"Reached max number of tokens: {max_num_tokens}")
            break
        
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    # resample_period = 10000
    # if (i % resample_period == 0):
    #     # RESAMPLING
    #     with torch.no_grad():
    #         # Count number of dead_features are zero
    #         num_dead_features = (total_activations == 0).sum().item()
    #         print(f"Dead Features: {num_dead_features}")
            
    #     if num_dead_features > 0:
    #         print("Resampling!")
    #         # hyperparams:
    #         max_resample_tokens = 1000 # the number of token activations that we consider for inserting into the dictionary
    #         # compute loss of model on random subset of inputs
    #         resample_loader = setup_token_data(cfg, tokenizer, model, seed=i)
    #         num_resample_data = 0

    #         resample_activations = torch.empty(0, activation_size)
    #         resample_losses = torch.empty(0)

    #         for resample_batch in resample_loader:
    #             resample_tokens = resample_batch["input_ids"].to(cfg.device)
    #             with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
    #                 with Trace(model, tensor_names[0]) as ret:
    #                     _ = model(resample_tokens)
    #                     representation = ret.output
    #                     if(isinstance(representation, tuple)):
    #                         representation = representation[0]
    #             layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")
    #             resample_activations = torch.cat((resample_activations, layer_activations.detach().cpu()), dim=0)

    #             c = autoencoder.encode(layer_activations)
    #             x_hat = autoencoder.decode(c)
                
    #             reconstruction_loss = (x_hat - layer_activations).pow(2).mean(dim=-1)
    #             l1_loss = torch.norm(c, 1, dim=-1)
    #             temp_loss = reconstruction_loss + cfg.l1_alpha*l1_loss
                
    #             resample_losses = torch.cat((resample_losses, temp_loss.detach().cpu()), dim=0)
                
    #             num_resample_data +=layer_activations.shape[0]
    #             if num_resample_data > max_resample_tokens:
    #                 break

                
    #         # sample num_dead_features vectors of input activations
    #         probabilities = resample_losses**2
    #         probabilities /= probabilities.sum()
    #         sampled_indices = torch.multinomial(probabilities, num_dead_features, replacement=True)
    #         new_vectors = resample_activations[sampled_indices]

    #         # calculate average encoder norm of alive neurons
    #         alive_neurons = list((total_activations!=0))
    #         modified_columns = total_activations==0
    #         avg_norm = autoencoder.encoder.data[alive_neurons].norm(dim=-1).mean()

    #         # replace dictionary and encoder weights with vectors
    #         new_vectors = new_vectors / new_vectors.norm(dim=1, keepdim=True)
            
    #         params_to_modify = [autoencoder.encoder, autoencoder.encoder_bias]

    #         current_weights = autoencoder.encoder.data
    #         current_weights[modified_columns] = (new_vectors.to(cfg.device) * avg_norm * 0.02)
    #         autoencoder.encoder.data = current_weights

    #         current_weights = autoencoder.encoder_bias.data
    #         current_weights[modified_columns] = 0
    #         autoencoder.encoder_bias.data = current_weights
            
    #         if hasattr(autoencoder, 'decoder'):
    #             current_weights = autoencoder.decoder.data
    #             current_weights[modified_columns] = new_vectors.to(cfg.device)
    #             autoencoder.decoder.data = current_weights
    #             params_to_modify += [autoencoder.decoder]

    #         for param_group in optimizer.param_groups:
    #             for param in param_group['params']:
    #                 if any(param is d_ for d_ in params_to_modify):
    #                     # Extract the corresponding rows from m and v
    #                     m = optimizer.state[param]['exp_avg']
    #                     v = optimizer.state[param]['exp_avg_sq']
                        
    #                     # Update the m and v values for the modified columns
    #                     m[modified_columns] = 0  # Reset moving average for modified columns
    #                     v[modified_columns] = 0  # Reset squared moving average for modified columns
        
    #     total_activations = torch.zeros(autoencoder.encoder.shape[0])

    # Periodic checkpoint. (i+2) fires two steps ahead of each multiple of 10_000, the period
    # used by the commented-out resampling code, so the save lands before any such change.
    if ((i+2) % 10_000==0):
        import os  # local import: os is not among the imports in the notebook's first cell

        model_save_name = cfg.model_name.split("/")[-1]
        # File name: model, target sparsity, dictionary ratio, traced module, checkpoint index
        save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}_ckpt{num_saved_so_far}"

        # Create trained_models/ if it doesn't exist; exist_ok replaces the separate exists() check
        os.makedirs("trained_models", exist_ok=True)
        # Save model (the whole autoencoder object is serialised)
        torch.save(autoencoder, f"trained_models/{save_name}.pt")
        
        num_saved_so_far += 1

    # Running sparsity check: after the first 200000 tokens, look at the current batch every
    # 100 steps and move l1_alpha by adjustment_factor whenever sparsity leaves the target band.
    num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
    if num_tokens_so_far > 200000 and i % 100 == 0:
        with torch.no_grad():
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
        if target_upper_sparsity < sparsity:
            # Too many active features: strengthen the L1 penalty
            cfg.l1_alpha *= (1 + adjustment_factor)
        elif target_lower_sparsity > sparsity:
            # Too few active features: relax the L1 penalty
            cfg.l1_alpha *= (1 - adjustment_factor)

  0%|          | 0/55054 [00:00<?, ?it/s]

Sparsity: 8141.0 | Dead Features: 1050 | Total Loss: 127.88 | Reconstruction Loss: 100.38 | L1 Loss: 27.50 | l1_alpha: 1.00e-03 | Tokens: 0 | Self Similarity: 1.00
Resampling!


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-b71791e158b5e518_*_of_00008.arrow
  0%|          | 21/55054 [00:38<26:33:23,  1.74s/it]

Sparsity: 6981.5 | Dead Features: 537 | Total Loss: 40316.38 | Reconstruction Loss: 40188.09 | L1 Loss: 128.29 | l1_alpha: 1.00e-03 | Tokens: 40960 | Self Similarity: 0.79


  0%|          | 41/55054 [01:13<26:53:29,  1.76s/it]

Sparsity: 4907.9 | Dead Features: 6059 | Total Loss: 2266.19 | Reconstruction Loss: 2212.78 | L1 Loss: 53.41 | l1_alpha: 1.00e-03 | Tokens: 81920 | Self Similarity: 0.91


  0%|          | 61/55054 [01:49<27:02:40,  1.77s/it]

Sparsity: 4136.7 | Dead Features: 6458 | Total Loss: 701.96 | Reconstruction Loss: 663.84 | L1 Loss: 38.12 | l1_alpha: 1.00e-03 | Tokens: 122880 | Self Similarity: 0.99


  0%|          | 81/55054 [02:24<27:13:16,  1.78s/it]

Sparsity: 4039.4 | Dead Features: 6696 | Total Loss: 399.27 | Reconstruction Loss: 363.15 | L1 Loss: 36.11 | l1_alpha: 1.00e-03 | Tokens: 163840 | Self Similarity: 1.00


  0%|          | 101/55054 [03:00<27:43:29,  1.82s/it]

Sparsity: 3885.1 | Dead Features: 6826 | Total Loss: 242.77 | Reconstruction Loss: 208.32 | L1 Loss: 34.44 | l1_alpha: 1.00e-03 | Tokens: 204800 | Self Similarity: 1.00


  0%|          | 121/55054 [03:35<27:19:15,  1.79s/it]

Sparsity: 3763.8 | Dead Features: 6947 | Total Loss: 206.72 | Reconstruction Loss: 168.37 | L1 Loss: 38.35 | l1_alpha: 1.10e-03 | Tokens: 245760 | Self Similarity: 1.00


  0%|          | 141/55054 [04:11<27:20:10,  1.79s/it]

Sparsity: 3814.2 | Dead Features: 6953 | Total Loss: 173.20 | Reconstruction Loss: 136.76 | L1 Loss: 36.44 | l1_alpha: 1.10e-03 | Tokens: 286720 | Self Similarity: 1.00


  0%|          | 161/55054 [04:47<27:20:37,  1.79s/it]

Sparsity: 3657.7 | Dead Features: 6651 | Total Loss: 144.73 | Reconstruction Loss: 109.40 | L1 Loss: 35.33 | l1_alpha: 1.10e-03 | Tokens: 327680 | Self Similarity: 1.00


  0%|          | 181/55054 [05:23<27:22:13,  1.80s/it]

Sparsity: 3294.5 | Dead Features: 6794 | Total Loss: 107.94 | Reconstruction Loss: 75.26 | L1 Loss: 32.69 | l1_alpha: 1.10e-03 | Tokens: 368640 | Self Similarity: 1.00


  0%|          | 201/55054 [05:59<27:43:54,  1.82s/it]

Sparsity: 3311.4 | Dead Features: 6470 | Total Loss: 107.22 | Reconstruction Loss: 74.61 | L1 Loss: 32.62 | l1_alpha: 1.10e-03 | Tokens: 409600 | Self Similarity: 1.00


  0%|          | 221/55054 [06:35<27:21:41,  1.80s/it]

Sparsity: 3818.1 | Dead Features: 6968 | Total Loss: 112.52 | Reconstruction Loss: 69.79 | L1 Loss: 42.74 | l1_alpha: 1.21e-03 | Tokens: 450560 | Self Similarity: 1.00


  0%|          | 241/55054 [07:11<27:20:45,  1.80s/it]

Sparsity: 3393.1 | Dead Features: 6909 | Total Loss: 99.42 | Reconstruction Loss: 63.28 | L1 Loss: 36.15 | l1_alpha: 1.21e-03 | Tokens: 491520 | Self Similarity: 1.00


  0%|          | 261/55054 [07:46<27:20:27,  1.80s/it]

Sparsity: 3519.1 | Dead Features: 6834 | Total Loss: 91.68 | Reconstruction Loss: 52.23 | L1 Loss: 39.46 | l1_alpha: 1.21e-03 | Tokens: 532480 | Self Similarity: 1.00


  1%|          | 281/55054 [08:22<27:18:51,  1.80s/it]

Sparsity: 3241.7 | Dead Features: 6838 | Total Loss: 87.69 | Reconstruction Loss: 53.67 | L1 Loss: 34.01 | l1_alpha: 1.21e-03 | Tokens: 573440 | Self Similarity: 1.00


  1%|          | 301/55054 [08:58<27:42:25,  1.82s/it]

Sparsity: 3343.0 | Dead Features: 6778 | Total Loss: 76.05 | Reconstruction Loss: 40.24 | L1 Loss: 35.81 | l1_alpha: 1.21e-03 | Tokens: 614400 | Self Similarity: 1.00


  1%|          | 321/55054 [09:34<27:19:06,  1.80s/it]

Sparsity: 3173.2 | Dead Features: 6554 | Total Loss: 79.95 | Reconstruction Loss: 41.45 | L1 Loss: 38.50 | l1_alpha: 1.33e-03 | Tokens: 655360 | Self Similarity: 1.00


  1%|          | 341/55054 [10:10<27:18:50,  1.80s/it]

Sparsity: 3347.9 | Dead Features: 6800 | Total Loss: 75.65 | Reconstruction Loss: 38.20 | L1 Loss: 37.46 | l1_alpha: 1.33e-03 | Tokens: 696320 | Self Similarity: 1.00


  1%|          | 361/55054 [10:46<27:18:59,  1.80s/it]

Sparsity: 3130.9 | Dead Features: 6812 | Total Loss: 73.99 | Reconstruction Loss: 37.96 | L1 Loss: 36.03 | l1_alpha: 1.33e-03 | Tokens: 737280 | Self Similarity: 1.00


  1%|          | 381/55054 [11:22<27:18:53,  1.80s/it]

Sparsity: 3152.9 | Dead Features: 6843 | Total Loss: 75.01 | Reconstruction Loss: 37.04 | L1 Loss: 37.97 | l1_alpha: 1.33e-03 | Tokens: 778240 | Self Similarity: 1.00


  1%|          | 401/55054 [11:58<27:40:43,  1.82s/it]

Sparsity: 2957.2 | Dead Features: 6734 | Total Loss: 66.34 | Reconstruction Loss: 30.07 | L1 Loss: 36.27 | l1_alpha: 1.33e-03 | Tokens: 819200 | Self Similarity: 1.00


  1%|          | 421/55054 [12:34<27:15:21,  1.80s/it]

Sparsity: 3211.1 | Dead Features: 6752 | Total Loss: 73.91 | Reconstruction Loss: 32.35 | L1 Loss: 41.55 | l1_alpha: 1.46e-03 | Tokens: 860160 | Self Similarity: 1.00


  1%|          | 441/55054 [13:09<27:15:35,  1.80s/it]

Sparsity: 3054.1 | Dead Features: 6763 | Total Loss: 126.96 | Reconstruction Loss: 86.20 | L1 Loss: 40.76 | l1_alpha: 1.46e-03 | Tokens: 901120 | Self Similarity: 1.00


  1%|          | 461/55054 [13:45<27:13:56,  1.80s/it]

Sparsity: 3604.3 | Dead Features: 6789 | Total Loss: 85.78 | Reconstruction Loss: 38.85 | L1 Loss: 46.93 | l1_alpha: 1.46e-03 | Tokens: 942080 | Self Similarity: 1.00


  1%|          | 481/55054 [14:21<27:14:50,  1.80s/it]

Sparsity: 2823.3 | Dead Features: 6811 | Total Loss: 73.82 | Reconstruction Loss: 36.68 | L1 Loss: 37.14 | l1_alpha: 1.46e-03 | Tokens: 983040 | Self Similarity: 1.00


  1%|          | 501/55054 [14:57<27:38:24,  1.82s/it]

Sparsity: 2941.0 | Dead Features: 6573 | Total Loss: 76.75 | Reconstruction Loss: 38.04 | L1 Loss: 38.71 | l1_alpha: 1.46e-03 | Tokens: 1024000 | Self Similarity: 1.00


  1%|          | 521/55054 [15:33<27:12:40,  1.80s/it]

Sparsity: 2874.4 | Dead Features: 6879 | Total Loss: 86.41 | Reconstruction Loss: 46.03 | L1 Loss: 40.38 | l1_alpha: 1.61e-03 | Tokens: 1064960 | Self Similarity: 1.00


  1%|          | 541/55054 [16:09<27:12:13,  1.80s/it]

Sparsity: 3302.0 | Dead Features: 6815 | Total Loss: 88.59 | Reconstruction Loss: 42.33 | L1 Loss: 46.26 | l1_alpha: 1.61e-03 | Tokens: 1105920 | Self Similarity: 1.00


  1%|          | 561/55054 [16:45<27:12:13,  1.80s/it]

Sparsity: 3196.2 | Dead Features: 6863 | Total Loss: 86.50 | Reconstruction Loss: 42.62 | L1 Loss: 43.87 | l1_alpha: 1.61e-03 | Tokens: 1146880 | Self Similarity: 1.00


  1%|          | 581/55054 [17:21<27:11:58,  1.80s/it]

Sparsity: 2275.2 | Dead Features: 6790 | Total Loss: 63.26 | Reconstruction Loss: 29.62 | L1 Loss: 33.64 | l1_alpha: 1.61e-03 | Tokens: 1187840 | Self Similarity: 1.00


  1%|          | 601/55054 [17:57<27:34:51,  1.82s/it]

Sparsity: 2759.2 | Dead Features: 6850 | Total Loss: 71.02 | Reconstruction Loss: 33.88 | L1 Loss: 37.14 | l1_alpha: 1.61e-03 | Tokens: 1228800 | Self Similarity: 1.00


  1%|          | 621/55054 [18:32<27:10:49,  1.80s/it]

Sparsity: 2445.1 | Dead Features: 6892 | Total Loss: 61.29 | Reconstruction Loss: 23.78 | L1 Loss: 37.51 | l1_alpha: 1.77e-03 | Tokens: 1269760 | Self Similarity: 1.00


  1%|          | 641/55054 [19:08<27:10:27,  1.80s/it]

Sparsity: 2412.6 | Dead Features: 6759 | Total Loss: 63.94 | Reconstruction Loss: 24.12 | L1 Loss: 39.83 | l1_alpha: 1.77e-03 | Tokens: 1310720 | Self Similarity: 1.00


  1%|          | 661/55054 [19:44<27:08:13,  1.80s/it]

Sparsity: 2347.0 | Dead Features: 6455 | Total Loss: 63.42 | Reconstruction Loss: 24.82 | L1 Loss: 38.60 | l1_alpha: 1.77e-03 | Tokens: 1351680 | Self Similarity: 1.00


  1%|          | 681/55054 [20:20<27:08:07,  1.80s/it]

Sparsity: 2418.3 | Dead Features: 6390 | Total Loss: 63.34 | Reconstruction Loss: 24.85 | L1 Loss: 38.49 | l1_alpha: 1.77e-03 | Tokens: 1392640 | Self Similarity: 1.00


  1%|▏         | 701/55054 [20:56<27:30:28,  1.82s/it]

Sparsity: 2200.9 | Dead Features: 6991 | Total Loss: 80.72 | Reconstruction Loss: 44.18 | L1 Loss: 36.53 | l1_alpha: 1.77e-03 | Tokens: 1433600 | Self Similarity: 1.00


  1%|▏         | 721/55054 [21:32<27:07:55,  1.80s/it]

Sparsity: 2014.4 | Dead Features: 7051 | Total Loss: 75.19 | Reconstruction Loss: 37.82 | L1 Loss: 37.38 | l1_alpha: 1.95e-03 | Tokens: 1474560 | Self Similarity: 1.00


  1%|▏         | 741/55054 [22:08<27:06:15,  1.80s/it]

Sparsity: 2407.4 | Dead Features: 6940 | Total Loss: 74.70 | Reconstruction Loss: 32.78 | L1 Loss: 41.92 | l1_alpha: 1.95e-03 | Tokens: 1515520 | Self Similarity: 1.00


  1%|▏         | 761/55054 [22:44<27:06:28,  1.80s/it]

Sparsity: 2123.8 | Dead Features: 6538 | Total Loss: 61.44 | Reconstruction Loss: 21.85 | L1 Loss: 39.59 | l1_alpha: 1.95e-03 | Tokens: 1556480 | Self Similarity: 1.00


  1%|▏         | 781/55054 [23:20<27:04:37,  1.80s/it]

Sparsity: 2195.0 | Dead Features: 7110 | Total Loss: 72.26 | Reconstruction Loss: 34.62 | L1 Loss: 37.64 | l1_alpha: 1.95e-03 | Tokens: 1597440 | Self Similarity: 1.00


  1%|▏         | 801/55054 [23:56<27:29:07,  1.82s/it]

Sparsity: 2415.5 | Dead Features: 6854 | Total Loss: 61.08 | Reconstruction Loss: 21.85 | L1 Loss: 39.23 | l1_alpha: 1.95e-03 | Tokens: 1638400 | Self Similarity: 1.00


  1%|▏         | 821/55054 [24:31<27:03:16,  1.80s/it]

Sparsity: 2061.2 | Dead Features: 6842 | Total Loss: 60.63 | Reconstruction Loss: 21.91 | L1 Loss: 38.72 | l1_alpha: 2.14e-03 | Tokens: 1679360 | Self Similarity: 1.00


  2%|▏         | 841/55054 [25:07<27:03:44,  1.80s/it]

Sparsity: 2174.2 | Dead Features: 7131 | Total Loss: 68.75 | Reconstruction Loss: 27.93 | L1 Loss: 40.82 | l1_alpha: 2.14e-03 | Tokens: 1720320 | Self Similarity: 1.00


  2%|▏         | 861/55054 [25:43<27:03:07,  1.80s/it]

Sparsity: 1742.5 | Dead Features: 7081 | Total Loss: 53.16 | Reconstruction Loss: 18.24 | L1 Loss: 34.92 | l1_alpha: 2.14e-03 | Tokens: 1761280 | Self Similarity: 1.00


  2%|▏         | 881/55054 [26:19<27:02:23,  1.80s/it]

Sparsity: 2030.7 | Dead Features: 6913 | Total Loss: 64.56 | Reconstruction Loss: 24.82 | L1 Loss: 39.74 | l1_alpha: 2.14e-03 | Tokens: 1802240 | Self Similarity: 1.00


  2%|▏         | 901/55054 [26:55<27:24:30,  1.82s/it]

Sparsity: 1465.5 | Dead Features: 6979 | Total Loss: 47.02 | Reconstruction Loss: 16.06 | L1 Loss: 30.96 | l1_alpha: 2.14e-03 | Tokens: 1843200 | Self Similarity: 1.00


  2%|▏         | 921/55054 [27:31<27:01:23,  1.80s/it]

Sparsity: 1481.2 | Dead Features: 7062 | Total Loss: 52.79 | Reconstruction Loss: 18.45 | L1 Loss: 34.34 | l1_alpha: 2.36e-03 | Tokens: 1884160 | Self Similarity: 1.00


  2%|▏         | 941/55054 [28:07<27:00:09,  1.80s/it]

Sparsity: 1345.0 | Dead Features: 6986 | Total Loss: 48.53 | Reconstruction Loss: 13.09 | L1 Loss: 35.44 | l1_alpha: 2.36e-03 | Tokens: 1925120 | Self Similarity: 1.00


  2%|▏         | 961/55054 [28:43<26:58:47,  1.80s/it]

Sparsity: 1391.4 | Dead Features: 7052 | Total Loss: 48.94 | Reconstruction Loss: 14.03 | L1 Loss: 34.91 | l1_alpha: 2.36e-03 | Tokens: 1966080 | Self Similarity: 1.00


  2%|▏         | 981/55054 [29:18<26:58:20,  1.80s/it]

Sparsity: 1676.1 | Dead Features: 6818 | Total Loss: 50.48 | Reconstruction Loss: 15.44 | L1 Loss: 35.04 | l1_alpha: 2.36e-03 | Tokens: 2007040 | Self Similarity: 1.00


  2%|▏         | 1000/55054 [29:52<26:55:24,  1.79s/it]

Sparsity: 1557.6 | Dead Features: 7068 | Total Loss: 52.93 | Reconstruction Loss: 17.47 | L1 Loss: 35.46 | l1_alpha: 2.36e-03 | Tokens: 2048000 | Self Similarity: 1.00
Resampling!


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-b71791e158b5e518_*_of_00008.arrow
  2%|▏         | 1021/55054 [30:33<26:56:45,  1.80s/it]

Sparsity: 962.9 | Dead Features: 6526 | Total Loss: 7648.69 | Reconstruction Loss: 7595.11 | L1 Loss: 53.59 | l1_alpha: 2.59e-03 | Tokens: 2088960 | Self Similarity: 0.89


  2%|▏         | 1041/55054 [31:09<26:57:31,  1.80s/it]

Sparsity: 497.6 | Dead Features: 6886 | Total Loss: 693.27 | Reconstruction Loss: 664.87 | L1 Loss: 28.40 | l1_alpha: 2.59e-03 | Tokens: 2129920 | Self Similarity: 0.99


  2%|▏         | 1061/55054 [31:45<26:56:37,  1.80s/it]

Sparsity: 436.6 | Dead Features: 6952 | Total Loss: 261.01 | Reconstruction Loss: 237.93 | L1 Loss: 23.08 | l1_alpha: 2.59e-03 | Tokens: 2170880 | Self Similarity: 1.00


  2%|▏         | 1081/55054 [32:21<26:55:34,  1.80s/it]

Sparsity: 428.7 | Dead Features: 7288 | Total Loss: 160.30 | Reconstruction Loss: 140.11 | L1 Loss: 20.18 | l1_alpha: 2.59e-03 | Tokens: 2211840 | Self Similarity: 1.00


  2%|▏         | 1101/55054 [32:57<27:18:50,  1.82s/it]

Sparsity: 428.6 | Dead Features: 6663 | Total Loss: 113.02 | Reconstruction Loss: 91.14 | L1 Loss: 21.88 | l1_alpha: 2.59e-03 | Tokens: 2252800 | Self Similarity: 1.00


  2%|▏         | 1121/55054 [33:32<26:54:27,  1.80s/it]

Sparsity: 401.0 | Dead Features: 7322 | Total Loss: 92.44 | Reconstruction Loss: 69.09 | L1 Loss: 23.35 | l1_alpha: 2.85e-03 | Tokens: 2293760 | Self Similarity: 1.00


  2%|▏         | 1141/55054 [34:08<26:54:43,  1.80s/it]

Sparsity: 310.8 | Dead Features: 7325 | Total Loss: 53.07 | Reconstruction Loss: 33.52 | L1 Loss: 19.55 | l1_alpha: 2.85e-03 | Tokens: 2334720 | Self Similarity: 1.00


  2%|▏         | 1161/55054 [34:44<26:54:06,  1.80s/it]

Sparsity: 365.6 | Dead Features: 7204 | Total Loss: 55.61 | Reconstruction Loss: 33.43 | L1 Loss: 22.17 | l1_alpha: 2.85e-03 | Tokens: 2375680 | Self Similarity: 1.00


  2%|▏         | 1181/55054 [35:20<26:53:27,  1.80s/it]

Sparsity: 314.9 | Dead Features: 7079 | Total Loss: 40.39 | Reconstruction Loss: 19.94 | L1 Loss: 20.45 | l1_alpha: 2.85e-03 | Tokens: 2416640 | Self Similarity: 1.00


  2%|▏         | 1201/55054 [35:56<27:13:41,  1.82s/it]

Sparsity: 319.5 | Dead Features: 7220 | Total Loss: 42.20 | Reconstruction Loss: 23.23 | L1 Loss: 18.96 | l1_alpha: 2.85e-03 | Tokens: 2457600 | Self Similarity: 1.00


  2%|▏         | 1221/55054 [36:32<26:51:09,  1.80s/it]

Sparsity: 323.8 | Dead Features: 7301 | Total Loss: 36.75 | Reconstruction Loss: 16.34 | L1 Loss: 20.41 | l1_alpha: 3.14e-03 | Tokens: 2498560 | Self Similarity: 1.00


  2%|▏         | 1241/55054 [37:08<26:50:16,  1.80s/it]

Sparsity: 316.1 | Dead Features: 7213 | Total Loss: 37.97 | Reconstruction Loss: 16.46 | L1 Loss: 21.52 | l1_alpha: 3.14e-03 | Tokens: 2539520 | Self Similarity: 1.00


  2%|▏         | 1261/55054 [37:44<26:48:14,  1.79s/it]

Sparsity: 324.4 | Dead Features: 7240 | Total Loss: 38.71 | Reconstruction Loss: 17.32 | L1 Loss: 21.39 | l1_alpha: 3.14e-03 | Tokens: 2580480 | Self Similarity: 1.00


  2%|▏         | 1281/55054 [38:19<26:48:34,  1.79s/it]

Sparsity: 251.5 | Dead Features: 7169 | Total Loss: 31.39 | Reconstruction Loss: 12.62 | L1 Loss: 18.77 | l1_alpha: 3.14e-03 | Tokens: 2621440 | Self Similarity: 1.00


  2%|▏         | 1301/55054 [38:55<27:10:51,  1.82s/it]

Sparsity: 273.5 | Dead Features: 7154 | Total Loss: 31.76 | Reconstruction Loss: 12.06 | L1 Loss: 19.70 | l1_alpha: 3.14e-03 | Tokens: 2662400 | Self Similarity: 1.00


  2%|▏         | 1321/55054 [39:31<26:45:01,  1.79s/it]

Sparsity: 342.4 | Dead Features: 6980 | Total Loss: 38.65 | Reconstruction Loss: 15.72 | L1 Loss: 22.93 | l1_alpha: 3.45e-03 | Tokens: 2703360 | Self Similarity: 1.00


  2%|▏         | 1341/55054 [40:07<26:43:28,  1.79s/it]

Sparsity: 203.4 | Dead Features: 6935 | Total Loss: 28.04 | Reconstruction Loss: 10.37 | L1 Loss: 17.67 | l1_alpha: 3.45e-03 | Tokens: 2744320 | Self Similarity: 1.00


  2%|▏         | 1361/55054 [40:43<26:41:37,  1.79s/it]

Sparsity: 291.5 | Dead Features: 7450 | Total Loss: 33.06 | Reconstruction Loss: 11.85 | L1 Loss: 21.21 | l1_alpha: 3.45e-03 | Tokens: 2785280 | Self Similarity: 1.00


  3%|▎         | 1381/55054 [41:18<26:40:22,  1.79s/it]

Sparsity: 255.7 | Dead Features: 7314 | Total Loss: 30.22 | Reconstruction Loss: 11.48 | L1 Loss: 18.74 | l1_alpha: 3.45e-03 | Tokens: 2826240 | Self Similarity: 1.00


  3%|▎         | 1401/55054 [41:54<27:02:56,  1.81s/it]

Sparsity: 281.9 | Dead Features: 7042 | Total Loss: 33.43 | Reconstruction Loss: 13.54 | L1 Loss: 19.90 | l1_alpha: 3.45e-03 | Tokens: 2867200 | Self Similarity: 1.00


  3%|▎         | 1421/55054 [42:30<26:39:12,  1.79s/it]

Sparsity: 250.3 | Dead Features: 7427 | Total Loss: 38.99 | Reconstruction Loss: 19.34 | L1 Loss: 19.65 | l1_alpha: 3.80e-03 | Tokens: 2908160 | Self Similarity: 1.00


  3%|▎         | 1441/55054 [43:06<26:37:25,  1.79s/it]

Sparsity: 232.6 | Dead Features: 7411 | Total Loss: 31.07 | Reconstruction Loss: 9.58 | L1 Loss: 21.49 | l1_alpha: 3.80e-03 | Tokens: 2949120 | Self Similarity: 1.00


  3%|▎         | 1461/55054 [43:41<26:38:03,  1.79s/it]

Sparsity: 270.0 | Dead Features: 7281 | Total Loss: 32.61 | Reconstruction Loss: 10.77 | L1 Loss: 21.84 | l1_alpha: 3.80e-03 | Tokens: 2990080 | Self Similarity: 1.00


  3%|▎         | 1481/55054 [44:17<26:36:11,  1.79s/it]

Sparsity: 290.7 | Dead Features: 7225 | Total Loss: 34.47 | Reconstruction Loss: 12.43 | L1 Loss: 22.04 | l1_alpha: 3.80e-03 | Tokens: 3031040 | Self Similarity: 1.00


  3%|▎         | 1501/55054 [44:53<26:58:05,  1.81s/it]

Sparsity: 240.0 | Dead Features: 7379 | Total Loss: 29.42 | Reconstruction Loss: 9.08 | L1 Loss: 20.34 | l1_alpha: 3.80e-03 | Tokens: 3072000 | Self Similarity: 1.00


  3%|▎         | 1521/55054 [45:28<26:34:32,  1.79s/it]

Sparsity: 249.4 | Dead Features: 7139 | Total Loss: 36.71 | Reconstruction Loss: 14.64 | L1 Loss: 22.07 | l1_alpha: 4.18e-03 | Tokens: 3112960 | Self Similarity: 1.00


  3%|▎         | 1541/55054 [46:04<26:37:37,  1.79s/it]

Sparsity: 214.0 | Dead Features: 7401 | Total Loss: 28.53 | Reconstruction Loss: 8.49 | L1 Loss: 20.04 | l1_alpha: 4.18e-03 | Tokens: 3153920 | Self Similarity: 1.00


  3%|▎         | 1561/55054 [46:40<26:38:39,  1.79s/it]

Sparsity: 235.1 | Dead Features: 7395 | Total Loss: 29.17 | Reconstruction Loss: 7.89 | L1 Loss: 21.29 | l1_alpha: 4.18e-03 | Tokens: 3194880 | Self Similarity: 1.00


  3%|▎         | 1581/55054 [47:16<26:38:29,  1.79s/it]

Sparsity: 229.2 | Dead Features: 7503 | Total Loss: 29.74 | Reconstruction Loss: 8.58 | L1 Loss: 21.16 | l1_alpha: 4.18e-03 | Tokens: 3235840 | Self Similarity: 1.00


  3%|▎         | 1601/55054 [47:52<27:00:37,  1.82s/it]

Sparsity: 248.9 | Dead Features: 7260 | Total Loss: 31.16 | Reconstruction Loss: 11.52 | L1 Loss: 19.64 | l1_alpha: 4.18e-03 | Tokens: 3276800 | Self Similarity: 1.00


  3%|▎         | 1621/55054 [48:27<26:37:15,  1.79s/it]

Sparsity: 216.7 | Dead Features: 7304 | Total Loss: 31.90 | Reconstruction Loss: 9.90 | L1 Loss: 22.01 | l1_alpha: 4.59e-03 | Tokens: 3317760 | Self Similarity: 1.00


  3%|▎         | 1641/55054 [49:03<26:33:26,  1.79s/it]

Sparsity: 213.6 | Dead Features: 7616 | Total Loss: 32.86 | Reconstruction Loss: 9.71 | L1 Loss: 23.15 | l1_alpha: 4.59e-03 | Tokens: 3358720 | Self Similarity: 1.00


  3%|▎         | 1661/55054 [49:39<26:32:45,  1.79s/it]

Sparsity: 188.5 | Dead Features: 7100 | Total Loss: 30.02 | Reconstruction Loss: 10.08 | L1 Loss: 19.94 | l1_alpha: 4.59e-03 | Tokens: 3399680 | Self Similarity: 1.00


  3%|▎         | 1681/55054 [50:15<26:30:40,  1.79s/it]

Sparsity: 210.5 | Dead Features: 7273 | Total Loss: 31.05 | Reconstruction Loss: 8.79 | L1 Loss: 22.27 | l1_alpha: 4.59e-03 | Tokens: 3440640 | Self Similarity: 1.00


  3%|▎         | 1701/55054 [50:50<26:53:49,  1.81s/it]

Sparsity: 172.3 | Dead Features: 7225 | Total Loss: 26.30 | Reconstruction Loss: 7.14 | L1 Loss: 19.16 | l1_alpha: 4.59e-03 | Tokens: 3481600 | Self Similarity: 1.00


  3%|▎         | 1721/55054 [51:26<26:28:45,  1.79s/it]

Sparsity: 189.9 | Dead Features: 7670 | Total Loss: 25.18 | Reconstruction Loss: 7.20 | L1 Loss: 17.98 | l1_alpha: 4.14e-03 | Tokens: 3522560 | Self Similarity: 1.00


  3%|▎         | 1741/55054 [52:02<26:29:00,  1.79s/it]

Sparsity: 300.4 | Dead Features: 7259 | Total Loss: 41.74 | Reconstruction Loss: 21.07 | L1 Loss: 20.68 | l1_alpha: 4.14e-03 | Tokens: 3563520 | Self Similarity: 1.00


  3%|▎         | 1761/55054 [52:37<26:30:01,  1.79s/it]

Sparsity: 173.8 | Dead Features: 7191 | Total Loss: 25.51 | Reconstruction Loss: 7.35 | L1 Loss: 18.17 | l1_alpha: 4.14e-03 | Tokens: 3604480 | Self Similarity: 1.00


  3%|▎         | 1781/55054 [53:13<26:26:21,  1.79s/it]

Sparsity: 209.7 | Dead Features: 7634 | Total Loss: 27.66 | Reconstruction Loss: 7.68 | L1 Loss: 19.98 | l1_alpha: 4.14e-03 | Tokens: 3645440 | Self Similarity: 1.00


  3%|▎         | 1801/55054 [53:49<26:49:48,  1.81s/it]

Sparsity: 184.3 | Dead Features: 7175 | Total Loss: 25.04 | Reconstruction Loss: 7.21 | L1 Loss: 17.83 | l1_alpha: 4.14e-03 | Tokens: 3686400 | Self Similarity: 1.00


  3%|▎         | 1821/55054 [54:25<26:28:11,  1.79s/it]

Sparsity: 218.9 | Dead Features: 7521 | Total Loss: 26.63 | Reconstruction Loss: 7.84 | L1 Loss: 18.80 | l1_alpha: 4.14e-03 | Tokens: 3727360 | Self Similarity: 1.00


  3%|▎         | 1841/55054 [55:00<26:25:50,  1.79s/it]

Sparsity: 174.8 | Dead Features: 7371 | Total Loss: 23.93 | Reconstruction Loss: 7.22 | L1 Loss: 16.71 | l1_alpha: 4.14e-03 | Tokens: 3768320 | Self Similarity: 1.00


  3%|▎         | 1861/55054 [55:36<26:23:48,  1.79s/it]

Sparsity: 236.5 | Dead Features: 7253 | Total Loss: 29.44 | Reconstruction Loss: 10.52 | L1 Loss: 18.91 | l1_alpha: 4.14e-03 | Tokens: 3809280 | Self Similarity: 1.00


  3%|▎         | 1881/55054 [56:12<26:23:43,  1.79s/it]

Sparsity: 155.4 | Dead Features: 7478 | Total Loss: 22.75 | Reconstruction Loss: 6.84 | L1 Loss: 15.91 | l1_alpha: 4.14e-03 | Tokens: 3850240 | Self Similarity: 1.00


  3%|▎         | 1901/55054 [56:48<26:45:42,  1.81s/it]

Sparsity: 169.6 | Dead Features: 7636 | Total Loss: 24.37 | Reconstruction Loss: 7.36 | L1 Loss: 17.01 | l1_alpha: 4.14e-03 | Tokens: 3891200 | Self Similarity: 1.00


  3%|▎         | 1921/55054 [57:23<26:20:23,  1.78s/it]

Sparsity: 183.0 | Dead Features: 7612 | Total Loss: 23.64 | Reconstruction Loss: 7.15 | L1 Loss: 16.48 | l1_alpha: 3.72e-03 | Tokens: 3932160 | Self Similarity: 1.00


  4%|▎         | 1941/55054 [57:59<26:20:08,  1.79s/it]

Sparsity: 151.7 | Dead Features: 7530 | Total Loss: 22.38 | Reconstruction Loss: 7.69 | L1 Loss: 14.69 | l1_alpha: 3.72e-03 | Tokens: 3973120 | Self Similarity: 1.00


  4%|▎         | 1961/55054 [58:34<26:19:41,  1.79s/it]

Sparsity: 165.2 | Dead Features: 7575 | Total Loss: 21.45 | Reconstruction Loss: 7.69 | L1 Loss: 13.76 | l1_alpha: 3.72e-03 | Tokens: 4014080 | Self Similarity: 1.00


  4%|▎         | 1981/55054 [59:10<26:19:39,  1.79s/it]

Sparsity: 213.4 | Dead Features: 7156 | Total Loss: 23.83 | Reconstruction Loss: 8.22 | L1 Loss: 15.62 | l1_alpha: 3.72e-03 | Tokens: 4055040 | Self Similarity: 1.00


  4%|▎         | 2000/55054 [59:44<26:16:23,  1.78s/it]

Sparsity: 153.8 | Dead Features: 7650 | Total Loss: 20.80 | Reconstruction Loss: 7.65 | L1 Loss: 13.15 | l1_alpha: 3.72e-03 | Tokens: 4096000 | Self Similarity: 1.00
Resampling!


Found cached dataset parquet (/root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-b71791e158b5e518_*_of_00008.arrow
  4%|▎         | 2021/55054 [1:00:24<26:17:27,  1.78s/it]

Sparsity: 1222.4 | Dead Features: 1880 | Total Loss: 68389.82 | Reconstruction Loss: 68081.30 | L1 Loss: 308.52 | l1_alpha: 3.35e-03 | Tokens: 4136960 | Self Similarity: 0.60


  4%|▎         | 2041/55054 [1:01:00<26:19:48,  1.79s/it]

Sparsity: 683.3 | Dead Features: 2073 | Total Loss: 11818.96 | Reconstruction Loss: 11585.87 | L1 Loss: 233.09 | l1_alpha: 3.35e-03 | Tokens: 4177920 | Self Similarity: 0.99


  4%|▎         | 2061/55054 [1:01:36<26:18:52,  1.79s/it]

Sparsity: 656.9 | Dead Features: 2336 | Total Loss: 3978.73 | Reconstruction Loss: 3748.73 | L1 Loss: 229.99 | l1_alpha: 3.35e-03 | Tokens: 4218880 | Self Similarity: 1.00


  4%|▍         | 2081/55054 [1:02:11<26:18:10,  1.79s/it]

Sparsity: 692.7 | Dead Features: 2353 | Total Loss: 1495.34 | Reconstruction Loss: 1259.14 | L1 Loss: 236.20 | l1_alpha: 3.35e-03 | Tokens: 4259840 | Self Similarity: 1.00


  4%|▍         | 2101/55054 [1:02:47<26:39:47,  1.81s/it]

Sparsity: 649.1 | Dead Features: 2412 | Total Loss: 861.87 | Reconstruction Loss: 623.89 | L1 Loss: 237.98 | l1_alpha: 3.35e-03 | Tokens: 4300800 | Self Similarity: 1.00


  4%|▍         | 2121/55054 [1:03:23<26:15:33,  1.79s/it]

Sparsity: 590.9 | Dead Features: 2247 | Total Loss: 542.31 | Reconstruction Loss: 305.02 | L1 Loss: 237.29 | l1_alpha: 3.68e-03 | Tokens: 4341760 | Self Similarity: 1.00


  4%|▍         | 2141/55054 [1:03:58<26:17:45,  1.79s/it]

Sparsity: 618.9 | Dead Features: 2119 | Total Loss: 581.16 | Reconstruction Loss: 342.47 | L1 Loss: 238.69 | l1_alpha: 3.68e-03 | Tokens: 4382720 | Self Similarity: 1.00


  4%|▍         | 2161/55054 [1:04:34<26:13:57,  1.79s/it]

Sparsity: 629.7 | Dead Features: 2187 | Total Loss: 539.79 | Reconstruction Loss: 297.09 | L1 Loss: 242.70 | l1_alpha: 3.68e-03 | Tokens: 4423680 | Self Similarity: 1.00


  4%|▍         | 2181/55054 [1:05:10<26:14:05,  1.79s/it]

Sparsity: 637.8 | Dead Features: 1829 | Total Loss: 461.66 | Reconstruction Loss: 237.66 | L1 Loss: 224.00 | l1_alpha: 3.68e-03 | Tokens: 4464640 | Self Similarity: 1.00


  4%|▍         | 2201/55054 [1:05:46<26:36:28,  1.81s/it]

Sparsity: 611.7 | Dead Features: 1982 | Total Loss: 428.45 | Reconstruction Loss: 172.66 | L1 Loss: 255.78 | l1_alpha: 3.68e-03 | Tokens: 4505600 | Self Similarity: 1.00


  4%|▍         | 2221/55054 [1:06:21<26:11:43,  1.78s/it]

Sparsity: 614.5 | Dead Features: 2333 | Total Loss: 428.60 | Reconstruction Loss: 144.72 | L1 Loss: 283.87 | l1_alpha: 4.05e-03 | Tokens: 4546560 | Self Similarity: 1.00


  4%|▍         | 2241/55054 [1:06:57<26:13:05,  1.79s/it]

Sparsity: 613.9 | Dead Features: 2153 | Total Loss: 389.20 | Reconstruction Loss: 115.80 | L1 Loss: 273.40 | l1_alpha: 4.05e-03 | Tokens: 4587520 | Self Similarity: 1.00


  4%|▍         | 2261/55054 [1:07:33<26:15:40,  1.79s/it]

Sparsity: 625.9 | Dead Features: 2462 | Total Loss: 470.61 | Reconstruction Loss: 178.82 | L1 Loss: 291.79 | l1_alpha: 4.05e-03 | Tokens: 4628480 | Self Similarity: 1.00


  4%|▍         | 2281/55054 [1:08:08<26:14:44,  1.79s/it]

Sparsity: 635.9 | Dead Features: 2383 | Total Loss: 423.92 | Reconstruction Loss: 141.29 | L1 Loss: 282.63 | l1_alpha: 4.05e-03 | Tokens: 4669440 | Self Similarity: 1.00


  4%|▍         | 2301/55054 [1:08:44<26:36:38,  1.82s/it]

Sparsity: 587.7 | Dead Features: 2246 | Total Loss: 365.01 | Reconstruction Loss: 95.43 | L1 Loss: 269.59 | l1_alpha: 4.05e-03 | Tokens: 4710400 | Self Similarity: 1.00


  4%|▍         | 2321/55054 [1:09:20<26:14:19,  1.79s/it]

Sparsity: 593.2 | Dead Features: 2147 | Total Loss: 382.01 | Reconstruction Loss: 97.32 | L1 Loss: 284.69 | l1_alpha: 4.46e-03 | Tokens: 4751360 | Self Similarity: 1.00


  4%|▍         | 2341/55054 [1:09:56<26:12:36,  1.79s/it]

Sparsity: 598.4 | Dead Features: 2238 | Total Loss: 367.81 | Reconstruction Loss: 77.45 | L1 Loss: 290.36 | l1_alpha: 4.46e-03 | Tokens: 4792320 | Self Similarity: 1.00


  4%|▍         | 2361/55054 [1:10:31<26:10:06,  1.79s/it]

Sparsity: 610.0 | Dead Features: 2246 | Total Loss: 387.74 | Reconstruction Loss: 93.20 | L1 Loss: 294.54 | l1_alpha: 4.46e-03 | Tokens: 4833280 | Self Similarity: 1.00


  4%|▍         | 2381/55054 [1:11:07<26:09:54,  1.79s/it]

Sparsity: 608.8 | Dead Features: 2317 | Total Loss: 412.15 | Reconstruction Loss: 113.16 | L1 Loss: 298.99 | l1_alpha: 4.46e-03 | Tokens: 4874240 | Self Similarity: 1.00


  4%|▍         | 2401/55054 [1:11:43<26:32:09,  1.81s/it]

Sparsity: 595.9 | Dead Features: 2530 | Total Loss: 371.84 | Reconstruction Loss: 74.75 | L1 Loss: 297.09 | l1_alpha: 4.46e-03 | Tokens: 4915200 | Self Similarity: 1.00


  4%|▍         | 2421/55054 [1:12:18<26:09:36,  1.79s/it]

Sparsity: 602.2 | Dead Features: 2091 | Total Loss: 396.61 | Reconstruction Loss: 61.72 | L1 Loss: 334.88 | l1_alpha: 4.90e-03 | Tokens: 4956160 | Self Similarity: 1.00


  4%|▍         | 2441/55054 [1:12:54<26:08:54,  1.79s/it]

Sparsity: 610.7 | Dead Features: 2184 | Total Loss: 402.14 | Reconstruction Loss: 73.28 | L1 Loss: 328.86 | l1_alpha: 4.90e-03 | Tokens: 4997120 | Self Similarity: 1.00


  4%|▍         | 2460/55054 [1:13:28<26:07:46,  1.79s/it]

Sparsity: 675.9 | Dead Features: 2171 | Total Loss: 512.79 | Reconstruction Loss: 191.09 | L1 Loss: 321.69 | l1_alpha: 4.90e-03 | Tokens: 5038080 | Self Similarity: 1.00


  5%|▍         | 2480/55054 [1:14:26<41:45:18,  2.86s/it]

Sparsity: 590.3 | Dead Features: 2176 | Total Loss: 385.50 | Reconstruction Loss: 51.91 | L1 Loss: 333.60 | l1_alpha: 4.90e-03 | Tokens: 5079040 | Self Similarity: 1.00


  5%|▍         | 2500/55054 [1:15:28<47:06:25,  3.23s/it]

Sparsity: 570.3 | Dead Features: 2337 | Total Loss: 367.45 | Reconstruction Loss: 48.65 | L1 Loss: 318.80 | l1_alpha: 4.90e-03 | Tokens: 5120000 | Self Similarity: 1.00


  5%|▍         | 2503/55054 [1:15:38<26:28:06,  1.81s/it]


KeyboardInterrupt: 

In [24]:
# Build the checkpoint file name from the model's short name (the part of
# cfg.model_name after the last "/"), the sparsity and ratio settings, and the
# first hooked tensor name.
# NOTE(review): cfg.sparsity is set to None in the config cell, so the name
# will contain "spNone" unless it is reassigned before this cell runs.
model_save_name = cfg.model_name.split("/")[-1]
save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}"

# Make the trained_models directory if it doesn't exist. exist_ok=True replaces
# the separate os.path.exists() check, so there is no gap between checking for
# the directory and creating it.
import os
os.makedirs("trained_models", exist_ok=True)

# Save model.
# NOTE(review): this serializes the `autoencoder` object itself rather than a
# state_dict; loading it back presumably needs the same class definitions to be
# importable — confirm that is intended.
torch.save(autoencoder, f"trained_models/{save_name}.pt")

In [18]:
# Finish the current Weights & Biases run. The run history and run summary
# tables shown in this cell's output are printed by this call.
wandb.finish()

0,1
Dead Features,▁▁▁▁▁▂▃▂█▅▃▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
L1 Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Reconstruction Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Self Similarity,█▁▃▃▃▃▂▄▂▄▅▄▄▅▅▅▅▄▂▅▅▅▅▄▄▅▄▅▅▅▄▅▅▅▅▄▅▅▄▅
Sparsity,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Total Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1_alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Dead Features,0.0
L1 Loss,0.04268
Reconstruction Loss,0.04672
Self Similarity,0.97426
Sparsity,139.97461
Tokens,50688000.0
Total Loss,0.08939
l1_alpha,0.001
