In [1]:
#----- imports --------

import tqdm
import torch
from torch import nn
from torch.nn import functional as F
import wandb
import os
import tokenizers
from matplotlib import pyplot as plt
import numpy as np
import json
import random
import tqdm


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"

config = {
    "learning_rate": 1e-3,
    "sae_learning_rate": 5e-5,
    "model_embedding_layer": 6,
    "eval_interval": 500,
    "max_iters": 60000, 
    "H": 32, # hidden dimension size
    "B": 64,
    "T": 256,
    "C": 256,
    "feedforward_factor": 3,
    "n_heads": 8,
    "n_layers": 12,
    "tokenizer_vocab_size": 2**13,
    "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

# initial
for k,v in config.items():
    locals ()[k] = v


#wandb.init(
#    project = "tinystories",
#    config = config,
#)

In [2]:
class SparseAutoEncoder(nn.Module):
    def __init__(self, activations_dim, sparse_dim):
        super().__init__()
        self.activations_dim = activations_dim
        encoder_weight = torch.randn(activations_dim, sparse_dim)
        decoder_weight = torch.randn(sparse_dim, activations_dim)
        self.encoder_bias = nn.Parameter(torch.zeros(sparse_dim))
        self.decoder_bias = nn.Parameter(torch.zeros(activations_dim))
        self.sparse_dim = sparse_dim
        self.sparsity_penalty = 30

        # set the encoder_weight to have the activations dim to be normalized to have l2 norm randomly between 0.05 and 1
        direction_lengths = torch.rand(sparse_dim) * 0.95 + 0.05
        # normalize the encoder_weight along columns (dim -2) to have l2 norm of 1
        encoder_weight = F.normalize(encoder_weight, p=2, dim=0)
        # multiply the column norms by the direction_lengths
        encoder_weight = encoder_weight * direction_lengths
        # make the decoder weight be the transpose of the encoder_weight
        decoder_weight = torch.transpose(encoder_weight, 0, 1)

        self.encoder_weight = nn.Parameter(encoder_weight)
        self.decoder_weight = nn.Parameter(decoder_weight)



    def forward(self, x):
        # preprocessing normalization
        # now on average any embedding has euclidian length 1

        encoded = F.relu(x @ self.encoder_weight + self.encoder_bias) # all act. are positive
        decoded = encoded @ self.decoder_weight + self.decoder_bias

        reconstruction_l2_loss = F.mse_loss(x, decoded)

        # every row in the tall decoder matrix
        # is the "sum" of the total influence of a feature on the output
        # the l2 norm of that row is the "influence" of that feature on that output
        # calculate that, store as row
        decoder_l2 = torch.linalg.norm(self.decoder_weight, dim=-1)
        # the feature activation is the sparse activation * it's influence on output
        feature_activations = (encoded) * decoder_l2
        # sum of feature activations
        # divide by the batch size * sequence length
        # should work if there is no batch dimension
        if x.ndim == 3:
            batch_dim, sequence_dim, _ = x.shape
        elif x.ndim == 2:
            batch_dim = 1
            sequence_dim, _ = x.shape
        elif x.ndim == 1:
            batch_dim = 1
            sequence_dim = 1
        else:
            raise ValueError(f"x has {x.ndim} dimensions, but it should have 1, 2, or 3")
        
        sparsity_loss = torch.sum(feature_activations) * self.sparsity_penalty / (batch_dim * sequence_dim * self.sparse_dim)

        total_loss = reconstruction_l2_loss + sparsity_loss

        return {"encoded": encoded, "decoded": decoded, 'feature_activations': feature_activations, "reconstruction_loss": reconstruction_l2_loss, "sparsity_loss": sparsity_loss, "total_loss": total_loss}




# def test_sparse_autoencoder_sequence_independence():
#     sae = SparseAutoEncoder(100, 300)
#     input_embedding = torch.randn(2, 100)

#     input_embedding_modified = input_embedding.clone()
#     # modify the first in the sequence to be zeros
#     input_embedding_modified[0, :] = torch.zeros(100)
#     original_output = sae.forward(input_embedding)['decoded']
#     modified_output = sae.forward(input_embedding_modified)['decoded']

#     # make sure the last embedding in the sequence is the same, and the first is different
#     assert torch.all(original_output[-1, :] == modified_output[-1, :])
#     assert not torch.all(original_output[0, :] == modified_output[0, :])

# test_sparse_autoencoder_sequence_independence()

# def count_num_params_in_sae(sparse_dim_size):
#     sae = SparseAutoEncoder(C, sparse_dim_size)
#     num_params = sum(p.numel() for p in sae.parameters())
#     return num_params

# print("number of parameters:", count_num_params_in_sae(2**14))
# print(2**14)

# def reconstruction_training_run(embedding_size, sparse_dim_size, iters, eval_frequency):
#     def generate_tensor():
#         return torch.randn(embedding_size)
    
#     sae = SparseAutoEncoder(embedding_size, sparse_dim_size)
    
#     optimizer = torch.optim.Adam(sae.parameters(), lr=0.0005)
    
#     for i in range(iters):
#         optimizer.zero_grad()
#         input_tensor = generate_tensor()
#         output = sae.forward(input_tensor)
#         loss = output['total_loss']
#         loss.backward()
        
#         # Clip the gradient norm to 1
#         torch.nn.utils.clip_grad_norm_(sae.parameters(), max_norm=1)
        
#         optimizer.step()
        
#         if i % eval_frequency == 0:
#             avg_loss = 0
#             num_samples = 100  # Number of samples to average over
#             for _ in range(num_samples):
#                 input_tensor = generate_tensor()
#                 output = sae.forward(input_tensor)
#                 avg_loss += output['total_loss'].item()
#             avg_loss /= num_samples
#             print(f"Iteration {i}, Average Total Loss: {avg_loss}")
#     return sae

# reconstruction_training_run(C, 2**14, 10_000, 500)





In [3]:
sae = SparseAutoEncoder(C, 2**14)
optimizer = torch.optim.Adam(sae.parameters(), lr=sae_learning_rate)


In [4]:
def load_tensor(filepath):
    # load the .pt tensor
    tensor = torch.load(filepath)
    tensor = torch.cat(tensor, dim=0)
    tensor = tensor.to(device)
    return tensor
    

In [18]:
@torch.no_grad()
def estimate_sae_loss(eval_iters, tensor):
    sae_loss = 0
    sae_sparsity_loss = 0
    sae_reconstruction_loss = 0
    count = 0
    for i in range(0, eval_iters, B):
        count += 1
        start = i
        end = i+B
        assert tensor.shape[0] >= end, f"too many eval_iters"
        sample = tensor[start:end]
        sae_output = sae(sample)
        sae_loss += sae_output['total_loss'].item()
        sae_sparsity_loss += sae_output['sparsity_loss'].item()
        sae_reconstruction_loss += sae_output['reconstruction_loss'].item()
    avg_loss = sae_loss/count
    avg_sparsity_loss = sae_sparsity_loss/count
    avg_reconstruction_loss = sae_reconstruction_loss/count
    return {"reconstruction_loss": avg_reconstruction_loss, "sparsity_loss": avg_sparsity_loss, "total_loss": avg_loss}
    


estimate_sae_loss(100, load_tensor("residuals/residuals_train_1.pt"))

{'reconstruction_loss': 0.16089461743831635,
 'sparsity_loss': 0.32641203701496124,
 'total_loss': 0.4873066544532776}

In [6]:
train_filepaths = []
val_filepaths = []
for file in os.listdir(f'residuals'):
    if file.startswith(f"residuals_train"):
        train_filepaths.append(f"residuals/{file}")
    elif file.startswith(f"residuals_val"):
        val_filepaths.append(f"residuals/{file}")



In [9]:
optimizer = torch.optim.Adam(sae.parameters(), lr=sae_learning_rate)
num_epochs = 1
logging_interval = 50000

for epoch in range(num_epochs):
    for filepath in train_filepaths:
        val_residuals_tensor = load_tensor(random.choice(val_filepaths))
        print(f"val loss on next datafile")
        print(estimate_sae_loss(1000, val_residuals_tensor))
        del val_residuals_tensor
        residuals_tensor = load_tensor(filepath)
        print(f"train loss on next datafile")
        print(estimate_sae_loss(1000, residuals_tensor))
        print(f"training on {filepath}")

        for i in tqdm.tqdm(range(0, residuals_tensor.shape[0]-B, B)):
            start = i
            end = i+B
            assert residuals_tensor.shape[0] >= end, f"too many train samples"
            sample = residuals_tensor[start:end]
            optimizer.zero_grad()
            sae_output = sae(sample)
            sae_reconstruction_loss = sae_output['reconstruction_loss']
            sae_sparsity_loss = sae_output['sparsity_loss']
            total_loss = sae_reconstruction_loss + sae_sparsity_loss
            total_loss.backward()
            optimizer.step()
            if i % logging_interval == 0:
                print(f"reconstruction loss: {sae_reconstruction_loss}, sparsity loss: {sae_sparsity_loss}, total loss: {total_loss}")
            
            


val loss on next datafile
{'reconstruction_loss': 3.5548341926187274e-05, 'sparsity_loss': 0.00010388851584866643, 'total_loss': 0.0001394368577748537}
train loss on next datafile
{'reconstruction_loss': 3.75913274474442e-05, 'sparsity_loss': 9.729796694591642e-05, 'total_loss': 0.00013488929532468318}
training on residuals/residuals_train_5.pt


  0%|          | 51/51187 [00:00<01:40, 508.82it/s]

reconstruction loss: 0.1483488380908966, sparsity loss: 0.3810155689716339, total loss: 0.5293644070625305


  6%|▋         | 3282/51187 [00:06<00:35, 1336.69it/s]

reconstruction loss: 0.14869186282157898, sparsity loss: 0.3963930606842041, total loss: 0.5450849533081055


 12%|█▏        | 6282/51187 [00:11<02:31, 295.46it/s] 

reconstruction loss: 0.11764596402645111, sparsity loss: 0.3542681634426117, total loss: 0.4719141125679016


 18%|█▊        | 9416/51187 [00:22<02:20, 298.04it/s]

reconstruction loss: 0.1268046796321869, sparsity loss: 0.35661062598228455, total loss: 0.48341530561447144


 24%|██▍       | 12535/51187 [00:33<02:12, 291.76it/s]

reconstruction loss: 0.1402653157711029, sparsity loss: 0.39713966846466064, total loss: 0.5374050140380859


 31%|███       | 15668/51187 [00:43<02:02, 290.59it/s]

reconstruction loss: 0.15164753794670105, sparsity loss: 0.3817576467990875, total loss: 0.5334051847457886


 37%|███▋      | 18894/51187 [00:54<00:35, 904.85it/s]

reconstruction loss: 0.15155085921287537, sparsity loss: 0.3948303461074829, total loss: 0.5463812351226807


 43%|████▎     | 21919/51187 [00:56<00:21, 1368.44it/s]

reconstruction loss: 0.1438024640083313, sparsity loss: 0.3988403081893921, total loss: 0.5426427721977234


 49%|████▉     | 25207/51187 [00:59<00:19, 1349.16it/s]

reconstruction loss: 0.138739675283432, sparsity loss: 0.38116219639778137, total loss: 0.5199018716812134


 55%|█████▌    | 28387/51187 [01:01<00:16, 1379.67it/s]

reconstruction loss: 0.1387862265110016, sparsity loss: 0.370105504989624, total loss: 0.5088917016983032


 61%|██████▏   | 31436/51187 [01:03<00:14, 1393.92it/s]

reconstruction loss: 0.16090913116931915, sparsity loss: 0.4195156395435333, total loss: 0.5804247856140137


 68%|██████▊   | 34565/51187 [01:06<00:12, 1286.84it/s]

reconstruction loss: 0.14397752285003662, sparsity loss: 0.3625220060348511, total loss: 0.5064995288848877


 74%|███████▎  | 37710/51187 [01:09<00:09, 1361.98it/s]

reconstruction loss: 0.16432525217533112, sparsity loss: 0.3723389804363251, total loss: 0.5366642475128174


 80%|███████▉  | 40796/51187 [01:12<00:07, 1364.04it/s]

reconstruction loss: 0.13558253645896912, sparsity loss: 0.35248318314552307, total loss: 0.4880657196044922


 86%|████████▌ | 43950/51187 [01:15<00:06, 1165.20it/s]

reconstruction loss: 0.13328345119953156, sparsity loss: 0.370769739151001, total loss: 0.5040531754493713


 92%|█████████▏| 47128/51187 [01:17<00:02, 1377.11it/s]

reconstruction loss: 0.16803371906280518, sparsity loss: 0.38334259390830994, total loss: 0.5513763427734375


 98%|█████████▊| 50158/51187 [01:20<00:00, 1366.44it/s]

reconstruction loss: 0.14581726491451263, sparsity loss: 0.3812530040740967, total loss: 0.5270702838897705


100%|██████████| 51187/51187 [01:20<00:00, 632.98it/s] 


val loss on next datafile
{'reconstruction_loss': 3.59435398131609e-05, 'sparsity_loss': 9.587590536102653e-05, 'total_loss': 0.0001318194456398487}
train loss on next datafile
{'reconstruction_loss': 3.81002991925925e-05, 'sparsity_loss': 8.824215782806277e-05, 'total_loss': 0.0001263424572534859}
training on residuals/residuals_train_11.pt


  0%|          | 135/51187 [00:00<00:37, 1347.70it/s]

reconstruction loss: 0.14387914538383484, sparsity loss: 0.3431031405925751, total loss: 0.4869822859764099


  6%|▌         | 3176/51187 [00:05<02:43, 292.90it/s] 

reconstruction loss: 0.1388176679611206, sparsity loss: 0.3594816327095032, total loss: 0.4982993006706238


 12%|█▏        | 6286/51187 [00:16<02:38, 282.79it/s]

reconstruction loss: 0.14521798491477966, sparsity loss: 0.3552238345146179, total loss: 0.5004417896270752


 19%|█▊        | 9557/51187 [00:26<00:47, 876.89it/s]

reconstruction loss: 0.1348564177751541, sparsity loss: 0.3888508975505829, total loss: 0.5237073302268982


 25%|██▍       | 12726/51187 [00:29<00:28, 1353.87it/s]

reconstruction loss: 0.1430855095386505, sparsity loss: 0.3606286346912384, total loss: 0.5037141442298889


 31%|███       | 15801/51187 [00:32<00:25, 1386.19it/s]

reconstruction loss: 0.14745020866394043, sparsity loss: 0.3647986948490143, total loss: 0.5122488737106323


 37%|███▋      | 18915/51187 [00:34<00:23, 1373.07it/s]

reconstruction loss: 0.14994826912879944, sparsity loss: 0.3760119676589966, total loss: 0.5259602069854736


 43%|████▎     | 22087/51187 [00:36<00:21, 1372.39it/s]

reconstruction loss: 0.14737321436405182, sparsity loss: 0.3662495017051697, total loss: 0.5136227011680603


 49%|████▉     | 25270/51187 [00:39<00:18, 1377.73it/s]

reconstruction loss: 0.14832377433776855, sparsity loss: 0.34452712535858154, total loss: 0.4928508996963501


 55%|█████▌    | 28308/51187 [00:41<00:16, 1378.23it/s]

reconstruction loss: 0.13512852787971497, sparsity loss: 0.3594658374786377, total loss: 0.49459436535835266


 61%|██████▏   | 31432/51187 [00:43<00:14, 1409.12it/s]

reconstruction loss: 0.14751499891281128, sparsity loss: 0.37893250584602356, total loss: 0.5264475345611572


 68%|██████▊   | 34620/51187 [00:45<00:12, 1370.34it/s]

reconstruction loss: 0.16775670647621155, sparsity loss: 0.36535874009132385, total loss: 0.5331154465675354


 73%|███████▎  | 37541/51187 [00:50<00:31, 429.82it/s] 

reconstruction loss: 0.1458468735218048, sparsity loss: 0.3511880040168762, total loss: 0.49703487753868103


 80%|███████▉  | 40874/51187 [00:54<00:07, 1372.87it/s]

reconstruction loss: 0.1629129946231842, sparsity loss: 0.3662186563014984, total loss: 0.5291316509246826


 86%|████████▌ | 43913/51187 [00:56<00:05, 1374.25it/s]

reconstruction loss: 0.1598433405160904, sparsity loss: 0.34953150153160095, total loss: 0.5093748569488525


 92%|█████████▏| 47098/51187 [00:59<00:02, 1370.26it/s]

reconstruction loss: 0.1475600004196167, sparsity loss: 0.34575679898262024, total loss: 0.49331679940223694


 98%|█████████▊| 50141/51187 [01:01<00:00, 1369.05it/s]

reconstruction loss: 0.15373921394348145, sparsity loss: 0.3415603041648865, total loss: 0.4952995181083679


100%|██████████| 51187/51187 [01:02<00:00, 824.67it/s] 


val loss on next datafile
{'reconstruction_loss': 3.649495798163116e-05, 'sparsity_loss': 8.998134545981884e-05, 'total_loss': 0.0001264763022772968}
train loss on next datafile
{'reconstruction_loss': 4.1004138998687264e-05, 'sparsity_loss': 8.189393719658256e-05, 'total_loss': 0.00012289807479828597}
training on residuals/residuals_train_6.pt


  0%|          | 33/51187 [00:00<02:36, 326.38it/s]

reconstruction loss: 0.17818717658519745, sparsity loss: 0.33161661028862, total loss: 0.5098037719726562


  6%|▋         | 3292/51187 [00:02<00:35, 1343.66it/s]

reconstruction loss: 0.15853366255760193, sparsity loss: 0.36408475041389465, total loss: 0.5226184129714966


 13%|█▎        | 6464/51187 [00:05<00:32, 1374.35it/s]

reconstruction loss: 0.14737968146800995, sparsity loss: 0.37307772040367126, total loss: 0.52045738697052


 19%|█▉        | 9645/51187 [00:07<00:30, 1374.59it/s]

reconstruction loss: 0.15848320722579956, sparsity loss: 0.3870643079280853, total loss: 0.5455474853515625


 25%|██▍       | 12686/51187 [00:09<00:28, 1359.55it/s]

reconstruction loss: 0.14647486805915833, sparsity loss: 0.3491286337375641, total loss: 0.4956035017967224


 31%|███       | 15864/51187 [00:11<00:25, 1370.85it/s]

reconstruction loss: 0.13699713349342346, sparsity loss: 0.36083534359931946, total loss: 0.4978324770927429


 37%|███▋      | 18907/51187 [00:14<00:23, 1370.35it/s]

reconstruction loss: 0.13261285424232483, sparsity loss: 0.34070077538490295, total loss: 0.4733136296272278


 43%|████▎     | 22085/51187 [00:16<00:21, 1369.07it/s]

reconstruction loss: 0.14038366079330444, sparsity loss: 0.3364056944847107, total loss: 0.47678935527801514


 49%|████▉     | 25263/51187 [00:18<00:18, 1372.22it/s]

reconstruction loss: 0.1521209180355072, sparsity loss: 0.33793044090270996, total loss: 0.49005135893821716


 55%|█████▌    | 28300/51187 [00:21<00:16, 1379.30it/s]

reconstruction loss: 0.1327940821647644, sparsity loss: 0.34352192282676697, total loss: 0.47631600499153137


 62%|██████▏   | 31493/51187 [00:23<00:14, 1384.99it/s]

reconstruction loss: 0.15039102733135223, sparsity loss: 0.3533141613006592, total loss: 0.5037052035331726


 67%|██████▋   | 34541/51187 [00:25<00:12, 1364.43it/s]

reconstruction loss: 0.142342671751976, sparsity loss: 0.37024596333503723, total loss: 0.512588620185852


 74%|███████▎  | 37733/51187 [00:27<00:09, 1376.83it/s]

reconstruction loss: 0.14462900161743164, sparsity loss: 0.34055134654045105, total loss: 0.4851803481578827


 80%|███████▉  | 40780/51187 [00:30<00:07, 1371.49it/s]

reconstruction loss: 0.1412576287984848, sparsity loss: 0.361362099647522, total loss: 0.502619743347168


 86%|████████▌ | 43958/51187 [00:32<00:05, 1376.43it/s]

reconstruction loss: 0.1347624957561493, sparsity loss: 0.3497219383716583, total loss: 0.4844844341278076


 92%|█████████▏| 47147/51187 [00:34<00:02, 1379.65it/s]

reconstruction loss: 0.15231099724769592, sparsity loss: 0.36663371324539185, total loss: 0.5189447402954102


 98%|█████████▊| 50199/51187 [00:36<00:00, 1379.89it/s]

reconstruction loss: 0.17181840538978577, sparsity loss: 0.3624331057071686, total loss: 0.5342515110969543


100%|██████████| 51187/51187 [00:37<00:00, 1358.29it/s]


val loss on next datafile
{'reconstruction_loss': 3.498175134882331e-05, 'sparsity_loss': 9.09958751872182e-05, 'total_loss': 0.00012597762700170277}
train loss on next datafile
{'reconstruction_loss': 3.395669593010098e-05, 'sparsity_loss': 8.327080076560378e-05, 'total_loss': 0.00011722749657928944}
training on residuals/residuals_train_7.pt


  0%|          | 119/51187 [00:00<00:43, 1182.19it/s]

reconstruction loss: 0.1279817521572113, sparsity loss: 0.3415292203426361, total loss: 0.4695109724998474


  6%|▋         | 3298/51187 [00:02<00:34, 1379.50it/s]

reconstruction loss: 0.15336166322231293, sparsity loss: 0.3722476661205292, total loss: 0.5256093144416809


 13%|█▎        | 6475/51187 [00:04<00:37, 1198.77it/s]

reconstruction loss: 0.12372072041034698, sparsity loss: 0.3400476574897766, total loss: 0.4637683629989624


 19%|█▉        | 9630/51187 [00:07<00:30, 1360.01it/s]

reconstruction loss: 0.14565613865852356, sparsity loss: 0.3206835389137268, total loss: 0.46633967757225037


 25%|██▍       | 12674/51187 [00:09<00:28, 1369.88it/s]

reconstruction loss: 0.1500013768672943, sparsity loss: 0.344875305891037, total loss: 0.4948766827583313


 31%|███       | 15770/51187 [00:12<00:28, 1261.90it/s]

reconstruction loss: 0.1361503005027771, sparsity loss: 0.33416083455085754, total loss: 0.47031113505363464


 37%|███▋      | 18958/51187 [00:15<00:23, 1377.14it/s]

reconstruction loss: 0.12478259205818176, sparsity loss: 0.32257863879203796, total loss: 0.4473612308502197


 43%|████▎     | 22141/51187 [00:17<00:21, 1382.45it/s]

reconstruction loss: 0.14267472922801971, sparsity loss: 0.3518860936164856, total loss: 0.4945608377456665


 49%|████▉     | 24987/51187 [00:19<00:19, 1368.76it/s]

reconstruction loss: 0.12018192559480667, sparsity loss: 0.32071566581726074, total loss: 0.4408975839614868


 55%|█████▌    | 28284/51187 [00:22<00:16, 1420.68it/s]

reconstruction loss: 0.13604354858398438, sparsity loss: 0.3276062309741974, total loss: 0.46364977955818176


 61%|██████    | 31300/51187 [00:29<01:08, 288.85it/s] 

reconstruction loss: 0.13039447367191315, sparsity loss: 0.334159255027771, total loss: 0.46455371379852295


 67%|██████▋   | 34407/51187 [00:39<00:57, 293.17it/s]

reconstruction loss: 0.12055759131908417, sparsity loss: 0.33675867319107056, total loss: 0.4573162794113159


 74%|███████▎  | 37728/51187 [00:45<00:09, 1375.16it/s]

reconstruction loss: 0.14971186220645905, sparsity loss: 0.36435094475746155, total loss: 0.5140628218650818


 80%|███████▉  | 40773/51187 [00:47<00:07, 1379.12it/s]

reconstruction loss: 0.1414232850074768, sparsity loss: 0.3421993851661682, total loss: 0.483622670173645


 86%|████████▌ | 43954/51187 [00:50<00:05, 1355.58it/s]

reconstruction loss: 0.14146575331687927, sparsity loss: 0.3397005498409271, total loss: 0.4811663031578064


 92%|█████████▏| 47122/51187 [00:52<00:02, 1373.46it/s]

reconstruction loss: 0.12310123443603516, sparsity loss: 0.3216387629508972, total loss: 0.4447399973869324


 98%|█████████▊| 50223/51187 [00:54<00:00, 1422.30it/s]

reconstruction loss: 0.13954311609268188, sparsity loss: 0.33438268303871155, total loss: 0.47392579913139343


100%|██████████| 51187/51187 [00:55<00:00, 923.54it/s] 


val loss on next datafile
{'reconstruction_loss': 3.4632514580152926e-05, 'sparsity_loss': 8.695177920162678e-05, 'total_loss': 0.00012158429436385632}
train loss on next datafile
{'reconstruction_loss': 3.3269589184783396e-05, 'sparsity_loss': 8.499325020238757e-05, 'total_loss': 0.00011826283950358629}
training on residuals/residuals_train_4.pt


  0%|          | 36/51187 [00:00<02:25, 351.12it/s]

reconstruction loss: 0.13829226791858673, sparsity loss: 0.360306054353714, total loss: 0.4985983371734619


  6%|▋         | 3271/51187 [00:05<00:34, 1373.33it/s]

reconstruction loss: 0.12740430235862732, sparsity loss: 0.3445141613483429, total loss: 0.4719184637069702


 13%|█▎        | 6405/51187 [00:08<00:33, 1331.70it/s]

reconstruction loss: 0.141026109457016, sparsity loss: 0.35459694266319275, total loss: 0.49562305212020874


 18%|█▊        | 9417/51187 [00:18<02:23, 290.73it/s] 

reconstruction loss: 0.1508101373910904, sparsity loss: 0.3584941625595093, total loss: 0.5093042850494385


 24%|██▍       | 12532/51187 [00:28<02:11, 294.46it/s]

reconstruction loss: 0.12971973419189453, sparsity loss: 0.33490443229675293, total loss: 0.46462416648864746


 31%|███       | 15661/51187 [00:39<02:01, 293.27it/s]

reconstruction loss: 0.15751338005065918, sparsity loss: 0.35167360305786133, total loss: 0.5091869831085205


 37%|███▋      | 18797/51187 [00:50<01:50, 293.06it/s]

reconstruction loss: 0.12926270067691803, sparsity loss: 0.34265902638435364, total loss: 0.47192174196243286


 43%|████▎     | 21910/51187 [00:59<01:40, 291.65it/s] 

reconstruction loss: 0.1420777440071106, sparsity loss: 0.3384134769439697, total loss: 0.4804912209510803


 49%|████▉     | 25245/51187 [01:08<00:20, 1253.15it/s]

reconstruction loss: 0.12326431274414062, sparsity loss: 0.33641403913497925, total loss: 0.4596783518791199


 55%|█████▌    | 28257/51187 [01:10<00:23, 979.10it/s] 

reconstruction loss: 0.13435223698616028, sparsity loss: 0.3456655442714691, total loss: 0.4800177812576294


 61%|██████▏   | 31423/51187 [01:13<00:14, 1372.66it/s]

reconstruction loss: 0.15095120668411255, sparsity loss: 0.3321076035499573, total loss: 0.4830588102340698


 68%|██████▊   | 34582/51187 [01:15<00:12, 1362.18it/s]

reconstruction loss: 0.14064162969589233, sparsity loss: 0.3361400067806244, total loss: 0.4767816364765167


 74%|███████▎  | 37744/51187 [01:17<00:10, 1344.24it/s]

reconstruction loss: 0.12161773443222046, sparsity loss: 0.32999956607818604, total loss: 0.4516173005104065


 80%|███████▉  | 40785/51187 [01:20<00:07, 1377.19it/s]

reconstruction loss: 0.1345936357975006, sparsity loss: 0.33800143003463745, total loss: 0.47259506583213806


 86%|████████▌ | 43937/51187 [01:22<00:05, 1216.07it/s]

reconstruction loss: 0.13094866275787354, sparsity loss: 0.3254133462905884, total loss: 0.4563620090484619


 92%|█████████▏| 46916/51187 [01:32<00:14, 292.61it/s] 

reconstruction loss: 0.13722673058509827, sparsity loss: 0.3278186321258545, total loss: 0.46504536271095276


 98%|█████████▊| 50136/51187 [01:41<00:01, 880.11it/s] 

reconstruction loss: 0.14363601803779602, sparsity loss: 0.33517172932624817, total loss: 0.4788077473640442


100%|██████████| 51187/51187 [01:42<00:00, 500.97it/s] 


val loss on next datafile
{'reconstruction_loss': 3.3008070080541074e-05, 'sparsity_loss': 8.602474816143512e-05, 'total_loss': 0.00011903281789273024}
train loss on next datafile
{'reconstruction_loss': 3.3393425401300195e-05, 'sparsity_loss': 8.30373433418572e-05, 'total_loss': 0.0001164307682774961}
training on residuals/residuals_train_10.pt


  0%|          | 33/51187 [00:00<02:36, 327.29it/s]

reconstruction loss: 0.13821884989738464, sparsity loss: 0.3286048173904419, total loss: 0.46682366728782654


  6%|▋         | 3320/51187 [00:03<00:34, 1375.28it/s]

reconstruction loss: 0.14310956001281738, sparsity loss: 0.32850590348243713, total loss: 0.4716154634952545


 13%|█▎        | 6497/51187 [00:05<00:32, 1370.33it/s]

reconstruction loss: 0.14334961771965027, sparsity loss: 0.3341493308544159, total loss: 0.47749894857406616


 18%|█▊        | 9414/51187 [00:14<02:24, 289.19it/s] 

reconstruction loss: 0.14107871055603027, sparsity loss: 0.3159550428390503, total loss: 0.45703375339508057


 25%|██▍       | 12554/51187 [00:25<02:15, 284.51it/s]

reconstruction loss: 0.13539479672908783, sparsity loss: 0.3246653974056244, total loss: 0.460060179233551


 31%|███       | 15685/51187 [00:35<02:01, 292.49it/s]

reconstruction loss: 0.12216970324516296, sparsity loss: 0.311266154050827, total loss: 0.43343585729599


 37%|███▋      | 19010/51187 [00:39<00:23, 1372.10it/s]

reconstruction loss: 0.1311698704957962, sparsity loss: 0.32242605090141296, total loss: 0.45359593629837036


 43%|████▎     | 22048/51187 [00:41<00:21, 1372.48it/s]

reconstruction loss: 0.11995998024940491, sparsity loss: 0.31174448132514954, total loss: 0.43170446157455444


 49%|████▉     | 25219/51187 [00:43<00:18, 1373.28it/s]

reconstruction loss: 0.14098869264125824, sparsity loss: 0.3323368728160858, total loss: 0.47332555055618286


 55%|█████▌    | 28366/51187 [00:46<00:23, 990.33it/s] 

reconstruction loss: 0.12119892984628677, sparsity loss: 0.29330939054489136, total loss: 0.41450831294059753


 62%|██████▏   | 31519/51187 [00:48<00:15, 1272.09it/s]

reconstruction loss: 0.14803537726402283, sparsity loss: 0.3347192406654358, total loss: 0.4827546179294586


 68%|██████▊   | 34560/51187 [00:50<00:12, 1370.17it/s]

reconstruction loss: 0.14716945588588715, sparsity loss: 0.3161550462245941, total loss: 0.46332448720932007


 74%|███████▎  | 37737/51187 [00:53<00:09, 1359.11it/s]

reconstruction loss: 0.13778488337993622, sparsity loss: 0.3260190784931183, total loss: 0.4638039469718933


 80%|███████▉  | 40776/51187 [00:55<00:07, 1373.83it/s]

reconstruction loss: 0.14692120254039764, sparsity loss: 0.3598550856113434, total loss: 0.5067762732505798


 86%|████████▌ | 43926/51187 [01:00<00:06, 1173.29it/s]

reconstruction loss: 0.12950505316257477, sparsity loss: 0.3236548900604248, total loss: 0.4531599283218384


 92%|█████████▏| 47095/51187 [01:03<00:02, 1376.56it/s]

reconstruction loss: 0.1332034468650818, sparsity loss: 0.30902379751205444, total loss: 0.44222724437713623


 98%|█████████▊| 50253/51187 [01:12<00:00, 1032.99it/s]

reconstruction loss: 0.12087619304656982, sparsity loss: 0.3349033296108246, total loss: 0.4557795226573944


100%|██████████| 51187/51187 [01:13<00:00, 696.85it/s] 


val loss on next datafile
{'reconstruction_loss': 3.246698062866926e-05, 'sparsity_loss': 8.518736017867922e-05, 'total_loss': 0.00011765434080734848}
train loss on next datafile
{'reconstruction_loss': 3.246698062866926e-05, 'sparsity_loss': 8.518736017867922e-05, 'total_loss': 0.00011765434080734848}
training on residuals/residuals_train_12.pt


  0%|          | 32/44865 [00:00<02:20, 318.67it/s]

reconstruction loss: 0.12912636995315552, sparsity loss: 0.3334084749221802, total loss: 0.4625348448753357


  7%|▋         | 3333/44865 [00:04<00:36, 1150.23it/s]

reconstruction loss: 0.13320450484752655, sparsity loss: 0.3087809383869171, total loss: 0.44198542833328247


 15%|█▍        | 6508/44865 [00:07<00:27, 1372.78it/s]

reconstruction loss: 0.13343475759029388, sparsity loss: 0.3293408453464508, total loss: 0.4627755880355835


 21%|██▏       | 9615/44865 [00:11<00:28, 1243.04it/s]

reconstruction loss: 0.1273324191570282, sparsity loss: 0.3095618188381195, total loss: 0.4368942379951477


 28%|██▊       | 12544/44865 [00:15<01:49, 294.39it/s] 

reconstruction loss: 0.14521904289722443, sparsity loss: 0.31004098057746887, total loss: 0.4552600383758545


 35%|███▍      | 15670/44865 [00:26<01:40, 291.34it/s]

reconstruction loss: 0.1363087296485901, sparsity loss: 0.3143322765827179, total loss: 0.450641006231308


 42%|████▏     | 18962/44865 [00:36<00:22, 1133.33it/s]

reconstruction loss: 0.14996011555194855, sparsity loss: 0.341407835483551, total loss: 0.4913679361343384


 49%|████▉     | 21928/44865 [00:43<01:19, 289.11it/s] 

reconstruction loss: 0.12872245907783508, sparsity loss: 0.3141150176525116, total loss: 0.4428374767303467


 56%|█████▌    | 25054/44865 [00:54<01:08, 290.49it/s]

reconstruction loss: 0.13522443175315857, sparsity loss: 0.31479644775390625, total loss: 0.4500208795070648


 63%|██████▎   | 28174/44865 [01:05<00:57, 291.16it/s]

reconstruction loss: 0.1523103266954422, sparsity loss: 0.3552396297454834, total loss: 0.5075499415397644


 70%|██████▉   | 31402/44865 [01:12<00:13, 990.63it/s] 

reconstruction loss: 0.13108006119728088, sparsity loss: 0.33067798614501953, total loss: 0.4617580473423004


 77%|███████▋  | 34573/44865 [01:14<00:07, 1372.60it/s]

reconstruction loss: 0.14206930994987488, sparsity loss: 0.3327482342720032, total loss: 0.47481754422187805


 84%|████████▍ | 37657/44865 [01:17<00:05, 1366.15it/s]

reconstruction loss: 0.142462819814682, sparsity loss: 0.3115233778953552, total loss: 0.45398619771003723


 91%|█████████ | 40838/44865 [01:19<00:02, 1375.51it/s]

reconstruction loss: 0.1392812877893448, sparsity loss: 0.31516537070274353, total loss: 0.4544466733932495


 98%|█████████▊| 44007/44865 [01:22<00:00, 1365.93it/s]

reconstruction loss: 0.14589771628379822, sparsity loss: 0.34139484167099, total loss: 0.4872925579547882


100%|██████████| 44865/44865 [01:22<00:00, 540.99it/s] 


val loss on next datafile
{'reconstruction_loss': 3.3031229046173395e-05, 'sparsity_loss': 8.419029554352164e-05, 'total_loss': 0.00011722152354195714}
train loss on next datafile
{'reconstruction_loss': 3.2651821966283025e-05, 'sparsity_loss': 7.768556522205472e-05, 'total_loss': 0.00011033738730475306}
training on residuals/residuals_train_3.pt


  0%|          | 32/51187 [00:00<02:40, 319.60it/s]

reconstruction loss: 0.12244781106710434, sparsity loss: 0.3053560256958008, total loss: 0.4278038442134857


  6%|▋         | 3268/51187 [00:02<00:35, 1368.79it/s]

reconstruction loss: 0.1353038251399994, sparsity loss: 0.3063626289367676, total loss: 0.44166645407676697


 12%|█▏        | 6394/51187 [00:07<00:52, 845.63it/s] 

reconstruction loss: 0.15860971808433533, sparsity loss: 0.3284890651702881, total loss: 0.4870987832546234


 18%|█▊        | 9423/51187 [00:12<02:23, 291.96it/s] 

reconstruction loss: 0.14981552958488464, sparsity loss: 0.33641302585601807, total loss: 0.4862285554409027


 25%|██▍       | 12712/51187 [00:16<00:35, 1070.14it/s]

reconstruction loss: 0.14386573433876038, sparsity loss: 0.3303297758102417, total loss: 0.4741955101490021


 31%|███       | 15884/51187 [00:18<00:25, 1367.12it/s]

reconstruction loss: 0.13635480403900146, sparsity loss: 0.34028032422065735, total loss: 0.4766351282596588


 37%|███▋      | 19020/51187 [00:21<00:23, 1368.12it/s]

reconstruction loss: 0.14498265087604523, sparsity loss: 0.3052769601345062, total loss: 0.45025962591171265


 43%|████▎     | 21909/51187 [00:23<00:44, 654.93it/s] 

reconstruction loss: 0.12260135263204575, sparsity loss: 0.3241141140460968, total loss: 0.44671547412872314


 49%|████▉     | 25212/51187 [00:27<00:19, 1362.43it/s]

reconstruction loss: 0.13876332342624664, sparsity loss: 0.3250272274017334, total loss: 0.46379053592681885


 55%|█████▌    | 28156/51187 [00:31<01:16, 299.93it/s] 

reconstruction loss: 0.1378978192806244, sparsity loss: 0.3277952969074249, total loss: 0.4656931161880493


 61%|██████    | 31301/51187 [00:39<01:07, 292.46it/s] 

reconstruction loss: 0.14673778414726257, sparsity loss: 0.33688801527023315, total loss: 0.4836257994174957


 67%|██████▋   | 34400/51187 [00:48<00:28, 586.27it/s] 

reconstruction loss: 0.14918039739131927, sparsity loss: 0.32226014137268066, total loss: 0.47144055366516113


 74%|███████▎  | 37744/51187 [00:53<00:09, 1399.39it/s]

reconstruction loss: 0.1369977593421936, sparsity loss: 0.3069651424884796, total loss: 0.4439629018306732


 79%|███████▉  | 40662/51187 [01:01<00:35, 297.13it/s] 

reconstruction loss: 0.13825643062591553, sparsity loss: 0.31393444538116455, total loss: 0.4521908760070801


 86%|████████▌ | 43810/51187 [01:12<00:25, 292.70it/s]

reconstruction loss: 0.11847741901874542, sparsity loss: 0.30235859751701355, total loss: 0.42083603143692017


 92%|█████████▏| 47136/51187 [01:17<00:02, 1358.34it/s]

reconstruction loss: 0.13737265765666962, sparsity loss: 0.3250424861907959, total loss: 0.4624151587486267


 98%|█████████▊| 50160/51187 [01:19<00:00, 1361.26it/s]

reconstruction loss: 0.12432487308979034, sparsity loss: 0.29498764872550964, total loss: 0.4193125367164612


100%|██████████| 51187/51187 [01:22<00:00, 618.71it/s] 


val loss on next datafile
{'reconstruction_loss': 3.475122980307788e-05, 'sparsity_loss': 8.267601765692234e-05, 'total_loss': 0.00011742724711075425}
train loss on next datafile
{'reconstruction_loss': 3.445353731513023e-05, 'sparsity_loss': 8.202111115679145e-05, 'total_loss': 0.00011647464707493782}
training on residuals/residuals_train_2.pt


  0%|          | 62/51187 [00:00<01:23, 615.74it/s]

reconstruction loss: 0.1379564255475998, sparsity loss: 0.3303033411502838, total loss: 0.4682597517967224


  7%|▋         | 3357/51187 [00:02<00:35, 1358.41it/s]

reconstruction loss: 0.13871322572231293, sparsity loss: 0.3117443323135376, total loss: 0.4504575729370117


 12%|█▏        | 6390/51187 [00:04<00:32, 1375.27it/s]

reconstruction loss: 0.11868011206388474, sparsity loss: 0.3224032521247864, total loss: 0.4410833716392517


 19%|█▊        | 9571/51187 [00:07<00:30, 1371.90it/s]

reconstruction loss: 0.13444605469703674, sparsity loss: 0.3145444095134735, total loss: 0.44899046421051025


 25%|██▍       | 12689/51187 [00:09<00:27, 1406.05it/s]

reconstruction loss: 0.12339460104703903, sparsity loss: 0.2901325225830078, total loss: 0.41352713108062744


 31%|███       | 15869/51187 [00:11<00:25, 1363.96it/s]

reconstruction loss: 0.14419607818126678, sparsity loss: 0.3289312720298767, total loss: 0.4731273651123047


 37%|███▋      | 18867/51187 [00:14<00:40, 804.37it/s] 

reconstruction loss: 0.15836769342422485, sparsity loss: 0.31977689266204834, total loss: 0.4781445860862732


 43%|████▎     | 22031/51187 [00:16<00:21, 1367.84it/s]

reconstruction loss: 0.13976706564426422, sparsity loss: 0.3141571283340454, total loss: 0.45392417907714844


 49%|████▉     | 25105/51187 [00:18<00:21, 1193.17it/s]

reconstruction loss: 0.12381681799888611, sparsity loss: 0.32066959142684937, total loss: 0.4444864094257355


 55%|█████▌    | 28358/51187 [00:21<00:16, 1345.89it/s]

reconstruction loss: 0.12922999262809753, sparsity loss: 0.3114686608314514, total loss: 0.44069865345954895


 62%|██████▏   | 31525/51187 [00:23<00:14, 1376.86it/s]

reconstruction loss: 0.13253596425056458, sparsity loss: 0.3073825538158417, total loss: 0.43991851806640625


 68%|██████▊   | 34562/51187 [00:25<00:12, 1360.04it/s]

reconstruction loss: 0.13707831501960754, sparsity loss: 0.3196668326854706, total loss: 0.4567451477050781


 74%|███████▎  | 37735/51187 [00:28<00:09, 1366.46it/s]

reconstruction loss: 0.13928449153900146, sparsity loss: 0.2994925081729889, total loss: 0.43877699971199036


 80%|███████▉  | 40768/51187 [00:30<00:07, 1367.28it/s]

reconstruction loss: 0.13628607988357544, sparsity loss: 0.3175940215587616, total loss: 0.45388010144233704


 86%|████████▌ | 43981/51187 [00:32<00:05, 1370.62it/s]

reconstruction loss: 0.13342788815498352, sparsity loss: 0.3121159076690674, total loss: 0.4455437958240509


 92%|█████████▏| 47149/51187 [00:35<00:02, 1364.42it/s]

reconstruction loss: 0.12980934977531433, sparsity loss: 0.3260285258293152, total loss: 0.4558378756046295


 98%|█████████▊| 50054/51187 [00:37<00:00, 1373.34it/s]

reconstruction loss: 0.13535046577453613, sparsity loss: 0.29970428347587585, total loss: 0.435054749250412


100%|██████████| 51187/51187 [00:39<00:00, 1310.28it/s]


val loss on next datafile
{'reconstruction_loss': 3.430781071074307e-05, 'sparsity_loss': 8.133528102189302e-05, 'total_loss': 0.00011564309149980544}
train loss on next datafile
{'reconstruction_loss': 3.4142603049986066e-05, 'sparsity_loss': 8.156498987227678e-05, 'total_loss': 0.0001157075921073556}
training on residuals/residuals_train_1.pt


  0%|          | 116/51187 [00:00<00:44, 1157.38it/s]

reconstruction loss: 0.15103298425674438, sparsity loss: 0.33418211340904236, total loss: 0.48521509766578674


  6%|▋         | 3294/51187 [00:02<00:35, 1360.85it/s]

reconstruction loss: 0.13114579021930695, sparsity loss: 0.295004278421402, total loss: 0.4261500835418701


 13%|█▎        | 6473/51187 [00:04<00:32, 1372.45it/s]

reconstruction loss: 0.1337544023990631, sparsity loss: 0.30830854177474976, total loss: 0.44206294417381287


 19%|█▊        | 9547/51187 [00:06<00:29, 1395.51it/s]

reconstruction loss: 0.14747709035873413, sparsity loss: 0.3237672746181488, total loss: 0.47124436497688293


 25%|██▍       | 12733/51187 [00:09<00:27, 1380.00it/s]

reconstruction loss: 0.14098510146141052, sparsity loss: 0.3142784535884857, total loss: 0.45526355504989624


 31%|███       | 15775/51187 [00:12<00:26, 1326.83it/s]

reconstruction loss: 0.14087757468223572, sparsity loss: 0.31573501229286194, total loss: 0.45661258697509766


 37%|███▋      | 18969/51187 [00:14<00:23, 1376.58it/s]

reconstruction loss: 0.1417439728975296, sparsity loss: 0.32161250710487366, total loss: 0.46335649490356445


 43%|████▎     | 22133/51187 [00:17<00:21, 1360.03it/s]

reconstruction loss: 0.1393691599369049, sparsity loss: 0.29517653584480286, total loss: 0.43454569578170776


 49%|████▉     | 25169/51187 [00:19<00:18, 1378.08it/s]

reconstruction loss: 0.1440589427947998, sparsity loss: 0.31492185592651367, total loss: 0.4589807987213135


 55%|█████▌    | 28356/51187 [00:21<00:16, 1381.85it/s]

reconstruction loss: 0.13877293467521667, sparsity loss: 0.33481112122535706, total loss: 0.47358405590057373


 61%|██████    | 31301/51187 [00:25<01:03, 312.69it/s] 

reconstruction loss: 0.1254165768623352, sparsity loss: 0.3056807219982147, total loss: 0.4310972988605499


 68%|██████▊   | 34626/51187 [00:28<00:12, 1379.54it/s]

reconstruction loss: 0.1434980034828186, sparsity loss: 0.3131033778190613, total loss: 0.4566013813018799


 73%|███████▎  | 37509/51187 [00:31<00:21, 640.10it/s] 

reconstruction loss: 0.13773036003112793, sparsity loss: 0.3080960214138031, total loss: 0.44582638144493103


 79%|███████▉  | 40666/51187 [00:42<00:36, 285.17it/s]

reconstruction loss: 0.13039928674697876, sparsity loss: 0.307926744222641, total loss: 0.43832603096961975


 86%|████████▌ | 43962/51187 [00:47<00:05, 1373.45it/s]

reconstruction loss: 0.13522563874721527, sparsity loss: 0.31080785393714905, total loss: 0.4460334777832031


 92%|█████████▏| 47148/51187 [00:49<00:02, 1381.96it/s]

reconstruction loss: 0.11482377350330353, sparsity loss: 0.2973267436027527, total loss: 0.412150502204895


 98%|█████████▊| 50162/51187 [00:52<00:00, 1101.90it/s]

reconstruction loss: 0.14618201553821564, sparsity loss: 0.3348972797393799, total loss: 0.4810792803764343


100%|██████████| 51187/51187 [00:53<00:00, 963.42it/s] 


val loss on next datafile
{'reconstruction_loss': 3.6146028200164435e-05, 'sparsity_loss': 7.983175432309508e-05, 'total_loss': 0.0001159777818247676}
train loss on next datafile
{'reconstruction_loss': 3.61798289231956e-05, 'sparsity_loss': 7.866473682224751e-05, 'total_loss': 0.0001148445657454431}
training on residuals/residuals_train_8.pt


  0%|          | 132/51187 [00:00<00:38, 1318.17it/s]

reconstruction loss: 0.16853877902030945, sparsity loss: 0.33714160323143005, total loss: 0.5056803822517395


  6%|▋         | 3295/51187 [00:02<00:35, 1340.32it/s]

reconstruction loss: 0.14234957098960876, sparsity loss: 0.3132641613483429, total loss: 0.45561373233795166


 13%|█▎        | 6449/51187 [00:05<00:38, 1170.95it/s]

reconstruction loss: 0.1380549520254135, sparsity loss: 0.32066959142684937, total loss: 0.4587245583534241


 19%|█▉        | 9666/51187 [00:07<00:28, 1477.29it/s]

reconstruction loss: 0.1286112368106842, sparsity loss: 0.2910202443599701, total loss: 0.4196314811706543


 25%|██▍       | 12735/51187 [00:09<00:27, 1375.11it/s]

reconstruction loss: 0.15132972598075867, sparsity loss: 0.3224027454853058, total loss: 0.47373247146606445


 31%|███       | 15774/51187 [00:11<00:25, 1362.58it/s]

reconstruction loss: 0.14015008509159088, sparsity loss: 0.32276514172554016, total loss: 0.46291524171829224


 37%|███▋      | 18956/51187 [00:14<00:23, 1372.67it/s]

reconstruction loss: 0.1508789360523224, sparsity loss: 0.33824577927589417, total loss: 0.48912471532821655


 43%|████▎     | 22124/51187 [00:16<00:21, 1367.05it/s]

reconstruction loss: 0.12804138660430908, sparsity loss: 0.31359612941741943, total loss: 0.4416375160217285


 49%|████▉     | 25207/51187 [00:18<00:18, 1373.28it/s]

reconstruction loss: 0.14115357398986816, sparsity loss: 0.3147660493850708, total loss: 0.45591962337493896


 55%|█████▌    | 28376/51187 [00:21<00:16, 1356.88it/s]

reconstruction loss: 0.14442141354084015, sparsity loss: 0.3061245381832123, total loss: 0.4505459666252136


 61%|██████▏   | 31409/51187 [00:23<00:14, 1371.00it/s]

reconstruction loss: 0.15437325835227966, sparsity loss: 0.3066624402999878, total loss: 0.46103569865226746


 68%|██████▊   | 34582/51187 [00:25<00:12, 1359.99it/s]

reconstruction loss: 0.14028510451316833, sparsity loss: 0.3142143785953522, total loss: 0.4544994831085205


 74%|███████▎  | 37669/51187 [00:27<00:09, 1446.59it/s]

reconstruction loss: 0.13257372379302979, sparsity loss: 0.30270659923553467, total loss: 0.43528032302856445


 80%|███████▉  | 40873/51187 [00:30<00:07, 1377.31it/s]

reconstruction loss: 0.14289000630378723, sparsity loss: 0.2833559215068817, total loss: 0.42624592781066895


 86%|████████▌ | 43905/51187 [00:32<00:05, 1373.89it/s]

reconstruction loss: 0.1435595005750656, sparsity loss: 0.32171711325645447, total loss: 0.4652765989303589


 92%|█████████▏| 47113/51187 [00:35<00:02, 1364.85it/s]

reconstruction loss: 0.14681151509284973, sparsity loss: 0.3422892987728119, total loss: 0.4891008138656616


 98%|█████████▊| 50141/51187 [00:38<00:00, 1369.17it/s]

reconstruction loss: 0.13037773966789246, sparsity loss: 0.31415313482284546, total loss: 0.4445308744907379


100%|██████████| 51187/51187 [00:38<00:00, 1318.85it/s]


val loss on next datafile
{'reconstruction_loss': 3.5885677440091965e-05, 'sparsity_loss': 7.910593971610069e-05, 'total_loss': 0.00011499161645770073}
train loss on next datafile
{'reconstruction_loss': 3.7881667027249934e-05, 'sparsity_loss': 7.582402788102626e-05, 'total_loss': 0.00011370569560676813}
training on residuals/residuals_train_0.pt


  0%|          | 31/51187 [00:00<02:50, 300.16it/s]

reconstruction loss: 0.14549310505390167, sparsity loss: 0.3106359839439392, total loss: 0.4561290740966797


  6%|▋         | 3278/51187 [00:02<00:33, 1413.42it/s]

reconstruction loss: 0.15515325963497162, sparsity loss: 0.31161531805992126, total loss: 0.4667685627937317


 13%|█▎        | 6469/51187 [00:05<00:32, 1382.58it/s]

reconstruction loss: 0.1428813934326172, sparsity loss: 0.30393823981285095, total loss: 0.44681963324546814


 18%|█▊        | 9406/51187 [00:07<00:32, 1294.00it/s]

reconstruction loss: 0.13864558935165405, sparsity loss: 0.31387749314308167, total loss: 0.4525230824947357


 25%|██▍       | 12668/51187 [00:15<00:36, 1064.34it/s]

reconstruction loss: 0.1259119212627411, sparsity loss: 0.31885194778442383, total loss: 0.4447638690471649


 31%|███       | 15791/51187 [00:17<00:26, 1313.52it/s]

reconstruction loss: 0.12516085803508759, sparsity loss: 0.3014317452907562, total loss: 0.4265925884246826


 37%|███▋      | 18973/51187 [00:20<00:24, 1317.71it/s]

reconstruction loss: 0.13952922821044922, sparsity loss: 0.2983294725418091, total loss: 0.4378587007522583


 43%|████▎     | 22146/51187 [00:22<00:21, 1376.78it/s]

reconstruction loss: 0.15112179517745972, sparsity loss: 0.3141702711582184, total loss: 0.4652920663356781


 49%|████▉     | 25191/51187 [00:25<00:18, 1376.73it/s]

reconstruction loss: 0.15373697876930237, sparsity loss: 0.32707127928733826, total loss: 0.4808082580566406


 55%|█████▌    | 28370/51187 [00:27<00:16, 1361.78it/s]

reconstruction loss: 0.15470445156097412, sparsity loss: 0.3290603756904602, total loss: 0.4837648272514343


 61%|██████    | 31291/51187 [00:33<01:08, 291.16it/s] 

reconstruction loss: 0.13831284642219543, sparsity loss: 0.3016214072704315, total loss: 0.43993425369262695


 67%|██████▋   | 34408/51187 [00:44<00:57, 291.48it/s]

reconstruction loss: 0.13166821002960205, sparsity loss: 0.3200798034667969, total loss: 0.4517480134963989


 74%|███████▎  | 37728/51187 [00:49<00:18, 722.16it/s] 

reconstruction loss: 0.1407952904701233, sparsity loss: 0.315878689289093, total loss: 0.4566739797592163


 80%|███████▉  | 40898/51187 [00:52<00:07, 1371.99it/s]

reconstruction loss: 0.12280377000570297, sparsity loss: 0.28467097878456116, total loss: 0.4074747562408447


 86%|████████▌ | 43972/51187 [00:54<00:05, 1369.74it/s]

reconstruction loss: 0.13646620512008667, sparsity loss: 0.3084108531475067, total loss: 0.4448770582675934


 92%|█████████▏| 47148/51187 [00:56<00:02, 1375.54it/s]

reconstruction loss: 0.13878975808620453, sparsity loss: 0.32174068689346313, total loss: 0.46053045988082886


 98%|█████████▊| 50195/51187 [00:58<00:00, 1379.16it/s]

reconstruction loss: 0.1343926638364792, sparsity loss: 0.30524253845214844, total loss: 0.4396352171897888


100%|██████████| 51187/51187 [00:59<00:00, 859.50it/s] 


val loss on next datafile
{'reconstruction_loss': 3.5493850824423136e-05, 'sparsity_loss': 7.946036662906408e-05, 'total_loss': 0.00011495421780273319}
train loss on next datafile
{'reconstruction_loss': 3.5142360255122185e-05, 'sparsity_loss': 7.917671324685215e-05, 'total_loss': 0.00011431907396763563}
training on residuals/residuals_train_9.pt


  0%|          | 133/51187 [00:00<00:38, 1321.13it/s]

reconstruction loss: 0.14096271991729736, sparsity loss: 0.2956383228302002, total loss: 0.43660104274749756


  6%|▋         | 3294/51187 [00:02<00:35, 1353.27it/s]

reconstruction loss: 0.13769683241844177, sparsity loss: 0.30020350217819214, total loss: 0.4379003345966339


 13%|█▎        | 6431/51187 [00:05<00:33, 1355.63it/s]

reconstruction loss: 0.15452216565608978, sparsity loss: 0.314414918422699, total loss: 0.46893709897994995


 19%|█▉        | 9602/51187 [00:08<00:30, 1349.72it/s]

reconstruction loss: 0.14784660935401917, sparsity loss: 0.310860276222229, total loss: 0.45870688557624817


 25%|██▍       | 12558/51187 [00:11<01:13, 523.51it/s] 

reconstruction loss: 0.13326236605644226, sparsity loss: 0.28845179080963135, total loss: 0.4217141568660736


 31%|███       | 15896/51187 [00:14<00:27, 1280.17it/s]

reconstruction loss: 0.14545215666294098, sparsity loss: 0.31411343812942505, total loss: 0.45956557989120483


 37%|███▋      | 18928/51187 [00:16<00:22, 1444.39it/s]

reconstruction loss: 0.14991217851638794, sparsity loss: 0.31763675808906555, total loss: 0.4675489366054535


 43%|████▎     | 22130/51187 [00:18<00:21, 1372.08it/s]

reconstruction loss: 0.1410212516784668, sparsity loss: 0.3140268921852112, total loss: 0.455048143863678


 49%|████▉     | 25053/51187 [00:28<01:15, 345.85it/s] 

reconstruction loss: 0.14893311262130737, sparsity loss: 0.30684977769851685, total loss: 0.4557828903198242


 55%|█████▌    | 28390/51187 [00:32<00:16, 1370.86it/s]

reconstruction loss: 0.14933320879936218, sparsity loss: 0.31526127457618713, total loss: 0.4645944833755493


 61%|██████▏   | 31431/51187 [00:34<00:14, 1364.39it/s]

reconstruction loss: 0.1618022620677948, sparsity loss: 0.30282026529312134, total loss: 0.46462252736091614


 68%|██████▊   | 34610/51187 [00:36<00:12, 1375.51it/s]

reconstruction loss: 0.13614891469478607, sparsity loss: 0.28661471605300903, total loss: 0.4227636456489563


 74%|███████▎  | 37649/51187 [00:38<00:09, 1369.05it/s]

reconstruction loss: 0.13501882553100586, sparsity loss: 0.3105904459953308, total loss: 0.44560927152633667


 80%|███████▉  | 40769/51187 [00:41<00:07, 1378.84it/s]

reconstruction loss: 0.14330117404460907, sparsity loss: 0.32285189628601074, total loss: 0.466153085231781


 86%|████████▌ | 44009/51187 [00:43<00:05, 1422.34it/s]

reconstruction loss: 0.14814700186252594, sparsity loss: 0.3118644654750824, total loss: 0.46001148223876953


 92%|█████████▏| 47061/51187 [00:45<00:03, 1363.48it/s]

reconstruction loss: 0.16866004467010498, sparsity loss: 0.3251166343688965, total loss: 0.49377667903900146


 98%|█████████▊| 50245/51187 [00:47<00:00, 1373.90it/s]

reconstruction loss: 0.14468492567539215, sparsity loss: 0.287036657333374, total loss: 0.431721568107605


100%|██████████| 51187/51187 [00:48<00:00, 1052.30it/s]
  0%|          | 0/60000 [00:00<?, ?it/s]


NameError: name 'get_batch' is not defined

In [10]:
# Save the model weights
torch.save(sae.state_dict(), 'sae_model_weights.pth')
