In [2]:
# Reload edited project modules (e.g. structs.models) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [11]:
import sys
sys.path.append('../../')

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from structs.models import CIFAR100Model


import os 
# Data root. Override with the SAE_DATA_ROOT environment variable; the default is
# the original author's external drive.
path = os.environ.get('SAE_DATA_ROOT', '/Volumes/Ayush_Drive/mnist/')

# Fall back to paths relative to the working directory when the data root is
# not mounted.
prefix = path if os.path.exists(path) else ''

torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 50
# Training is pinned to the CPU. Switch to 'mps' when
# torch.backends.mps.is_available() to use an Apple GPU.
device = 'cpu'
print(f"Using device: {device}")

Using device: cpu


# Train SAEs

In [4]:
from tqdm import tqdm
from structs.models import EnhancedSAE, SimpleSAE
import torch.optim as optim

def train_sae(train_loader, input_dim, hidden_dim, device, lr=None, num_epochs=None,
              checkpoint_name='sae_depth_model'):
    """Train a ``SimpleSAE`` on the batches yielded by ``train_loader``.

    Args:
        train_loader: iterable of input tensor batches; ``len()`` must be defined.
        input_dim: feature width of each input row.
        hidden_dim: width of the SAE hidden layer.
        device: torch device the model and each batch are moved to.
        lr: Adam learning rate; defaults to the notebook-level ``learning_rate``.
        num_epochs: passes over ``train_loader``; defaults to the notebook-level
            ``epochs``.
        checkpoint_name: filename stem for the checkpoints written every 5 epochs
            to ``<prefix>/embeddings/``. The default reproduces the historical
            filename. Pass a distinct value (e.g. including the depth) when
            training several SAEs; otherwise later runs overwrite earlier
            checkpoints.

    Returns:
        The trained ``SimpleSAE``.
    """
    lr = learning_rate if lr is None else lr
    num_epochs = epochs if num_epochs is None else num_epochs

    model = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_dir = os.path.join(prefix, 'embeddings')
    n_batches = len(train_loader)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for activations in tqdm(train_loader, desc="Processing Activations"):
            activations = activations.to(device)
            optimizer.zero_grad()
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Mean loss over batches. An empty loader yields NaN instead of crashing.
        epoch_loss = total_loss / n_batches if n_batches else float('nan')

        if epoch % 5 == 0:  # Save every 5 epochs (0, 5, 10, ...)
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f'{checkpoint_name}_ckpt_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float epoch mean rather than the last batch's loss tensor.
                'loss': epoch_loss,
            }, checkpoint_path)

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    return model

def test_sae(model, test_loader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for activations in tqdm(test_loader):
            activations = activations.to(device)
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            total_loss += loss.item()
    
    print(f"Test Loss: {total_loss/len(test_loader):.4f}")
    
    return total_loss/len(test_loader)

In [17]:
data_name = 'laion'


def infinite_saes(model, depth, custom_depth=10, dataset=None):
    """Train a chain of SAEs, each on the decoder weights of the previous one.

    Starting at ``depth``, each iteration:
      * trains a SAE via ``train_sae`` on an 80/20 row split of the current dataset;
      * evaluates it with ``test_sae``;
      * appends the test loss to ``log_weights.txt``;
      * saves its ``state_dict`` to
        ``<prefix>/embeddings/<data_name>/sae_depth_<depth>_decoder.pth``.

    The next depth's dataset is the transposed decoder weight of the SAE just
    trained. Training stops once ``depth`` exceeds ``custom_depth``.

    Args:
        model: model whose ``decoder.weight.T`` seeds the first depth. It is
            not used for training when ``dataset`` is given (it may then be None).
        depth: depth label of the first SAE to train.
        custom_depth: last depth to train (inclusive).
        dataset: optional 2-D tensor used for the first depth instead of
            ``model``'s decoder weights. Rows are split into train/test; the
            column count is the SAE input width.

    Returns:
        The SAE trained at ``custom_depth``, or ``model`` unchanged if
        ``depth > custom_depth`` on entry.
    """
    # Iterative rather than recursive. With recursion, every depth's frame kept
    # its SAE, dataset and DataLoaders alive until the whole chain finished.
    while depth <= custom_depth:
        if dataset is None:
            print('Getting the model weights')
            dataset = model.decoder.weight.T.detach().clone()
        else:
            print('Getting the dataset weights')

        # Train / test split over rows.
        train_weights, test_weights = train_test_split(dataset, test_size=0.2, random_state=42)
        train_loader = DataLoader(train_weights, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_weights, batch_size=batch_size, shuffle=False)

        # Hidden width is half the row count, but never below 2304.
        input_dim = dataset.shape[1]
        hidden_dim = max(int(dataset.shape[0] / 2), 2304)

        print('Creating SAE with input dim:', input_dim, 'and hidden dim:', hidden_dim)
        sae_model = train_sae(train_loader, input_dim, hidden_dim, device)
        loss = test_sae(sae_model, test_loader, device)

        with open('log_weights.txt', 'a') as f:
            f.write(f"Data: {data_name} Depth: {depth}, Loss: {loss}\n")

        # torch.save does not create missing parent directories.
        model_dir = os.path.join(prefix, 'embeddings', data_name)
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f'sae_depth_{depth}_decoder.pth')
        torch.save(sae_model.state_dict(), model_path)
        print(f"SAE model saved to {model_path}")

        # The freshly trained SAE seeds the next depth. Dropping the previous
        # model and dataset lets them be freed.
        model, dataset = sae_model, None
        depth += 1

    # After training, the loop exits with depth == custom_depth + 1, so
    # custom_depth is the last depth trained.
    print('Finished Training up to depth ', custom_depth)
    return model

- layer3 activations.shape = torch.Size([60000, 16384])
- layer4 activations.shape = TODO (not yet recorded)

In [15]:
# Depth-1 SAE weights. Resolved through `prefix`, like the training code, so the
# cell also works when the external drive is not mounted.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
sae_weights = torch.load(os.path.join(prefix, 'laion', 'sae_depth_1_weights.pt'), map_location=device)
sae_weights.shape

torch.Size([65536, 1024])

In [18]:
# Load the depth-2 SAE state_dict. The name `depth_1` is historical and later
# cells refer to it. Its 'decoder.weight' was recorded as (1024, 16384), so the
# transpose has 16384 rows of width 1024.
# Resolved through `prefix`, like the training code, so the cell also works
# when the external drive is not mounted.
depth_1 = torch.load(os.path.join(prefix, 'embeddings', 'laion', 'sae_depth_2_decoder.pth'), map_location=device)
depth_1['decoder.weight'].shape

torch.Size([1024, 16384])

In [19]:
# Continue the SAE chain at depth 3, training on the rows of the transposed
# depth-2 decoder weight (no seed model needed since a dataset is supplied).
infinite_saes(
    model=None,
    depth=3,
    custom_depth=10,
    dataset=depth_1['decoder.weight'].T.detach().clone(),
)

Getting the dataset weights
Creating SAE with input dim: 1024 and hidden dim: 8192


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.13it/s]


Epoch [1/50], Loss: 18.7885


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 32.02it/s]


Epoch [2/50], Loss: 0.0526


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 33.52it/s]


Epoch [3/50], Loss: 0.0134


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 33.27it/s]


Epoch [4/50], Loss: 0.0070


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 33.33it/s]


Epoch [5/50], Loss: 0.0044


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 33.66it/s]


Epoch [6/50], Loss: 0.0012


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 32.17it/s]


Epoch [7/50], Loss: 0.0005


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.75it/s]


Epoch [8/50], Loss: 0.0002


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 27.74it/s]


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 26.50it/s]


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 30.57it/s]


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 32.48it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.47it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.69it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 25.73it/s]


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:08<00:00, 24.71it/s]


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.11it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.90it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.64it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:08<00:00, 24.21it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 30.78it/s]


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 30.34it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.29it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 26.56it/s]


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.58it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.10it/s]


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 25.69it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 32.57it/s]


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 30.15it/s]


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.33it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.36it/s]


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.38it/s]


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.57it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 30.62it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.56it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.06it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.32it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 28.86it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 33.69it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 31.33it/s]


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 29.28it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 32.63it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:05<00:00, 34.64it/s]


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.47it/s]


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:07<00:00, 27.59it/s]


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.69it/s]


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 32.91it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 29.34it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 30.74it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 205/205 [00:06<00:00, 33.28it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 52/52 [00:00<00:00, 194.76it/s]


Test Loss: 0.0021
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_3_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 4096


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 66.09it/s]


Epoch [1/50], Loss: 21.7827


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 67.72it/s]


Epoch [2/50], Loss: 0.0102


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 69.02it/s]


Epoch [3/50], Loss: 0.0010


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 74.21it/s]


Epoch [4/50], Loss: 0.0001


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 69.06it/s]


Epoch [5/50], Loss: 0.0001


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 71.14it/s]


Epoch [6/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 71.43it/s]


Epoch [7/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 71.52it/s]


Epoch [8/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 70.41it/s]


Epoch [9/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 72.45it/s]


Epoch [10/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 65.04it/s]


Epoch [11/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 67.59it/s]


Epoch [12/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 69.25it/s]


Epoch [13/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 75.10it/s]


Epoch [14/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 72.31it/s]


Epoch [15/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 70.43it/s]


Epoch [16/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 61.39it/s]


Epoch [17/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 71.72it/s]


Epoch [18/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 69.18it/s]


Epoch [19/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 70.65it/s]


Epoch [20/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 59.66it/s]


Epoch [21/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 62.42it/s]


Epoch [22/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 62.58it/s]


Epoch [23/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 62.74it/s]


Epoch [24/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 65.35it/s]


Epoch [25/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 67.91it/s]


Epoch [26/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 63.97it/s]


Epoch [27/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 58.01it/s]


Epoch [28/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 58.62it/s]


Epoch [29/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 66.53it/s]


Epoch [30/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 65.08it/s]


Epoch [31/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 63.22it/s]


Epoch [32/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 59.91it/s]


Epoch [33/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 57.99it/s]


Epoch [34/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 65.57it/s]


Epoch [35/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:02<00:00, 44.80it/s]


Epoch [36/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 54.13it/s]


Epoch [37/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 65.08it/s]


Epoch [38/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 60.58it/s]


Epoch [39/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 60.93it/s]


Epoch [40/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 66.52it/s]


Epoch [41/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 60.20it/s]


Epoch [42/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 72.43it/s]


Epoch [43/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 71.46it/s]


Epoch [44/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 68.19it/s]


Epoch [45/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 69.96it/s]


Epoch [46/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 68.35it/s]


Epoch [47/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 70.21it/s]


Epoch [48/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 70.47it/s]


Epoch [49/50], Loss: 0.0000


Processing Activations: 100%|██████████| 103/103 [00:01<00:00, 69.60it/s]


Epoch [50/50], Loss: 0.0000


100%|██████████| 26/26 [00:00<00:00, 359.08it/s]


Test Loss: 0.0133
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_4_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 2304


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 129.39it/s]


Epoch [1/50], Loss: 26.8916


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 102.94it/s]


Epoch [2/50], Loss: 0.0194


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 111.68it/s]


Epoch [3/50], Loss: 0.0032


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 119.84it/s]


Epoch [4/50], Loss: 0.0005


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 121.06it/s]


Epoch [5/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 110.16it/s]


Epoch [6/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 113.07it/s]


Epoch [7/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 129.46it/s]


Epoch [8/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 90.50it/s] 


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 92.41it/s] 


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 120.98it/s]


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 125.72it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 121.45it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 125.36it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 122.44it/s]


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 113.04it/s]


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 121.25it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 129.62it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 131.77it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 111.65it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 130.14it/s]


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 110.42it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 108.63it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 115.93it/s]


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 126.43it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 128.26it/s]


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 126.87it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 132.30it/s]


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 100.32it/s]


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 115.14it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 133.16it/s]


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 121.21it/s]


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 123.32it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 131.64it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 132.63it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 120.71it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 120.67it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 122.00it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 130.34it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 101.83it/s]


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 115.15it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 116.13it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 109.70it/s]


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 100.67it/s]


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 111.31it/s]


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 99.52it/s] 


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 117.02it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 118.24it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 106.45it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 52/52 [00:00<00:00, 110.73it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 13/13 [00:00<00:00, 339.07it/s]


Test Loss: 0.0321
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_5_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 2304


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.69it/s]


Epoch [1/50], Loss: 35.5852


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 114.01it/s]


Epoch [2/50], Loss: 1.4510


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 120.12it/s]


Epoch [3/50], Loss: 0.1792


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 124.46it/s]


Epoch [4/50], Loss: 0.0404


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 124.39it/s]


Epoch [5/50], Loss: 0.0097


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 131.13it/s]


Epoch [6/50], Loss: 0.0018


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 104.87it/s]


Epoch [7/50], Loss: 0.0003


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 100.41it/s]


Epoch [8/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 88.86it/s]


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 109.89it/s]


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 98.16it/s] 


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.61it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 117.53it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 114.31it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 92.09it/s] 


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 109.52it/s]


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 107.74it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 86.27it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.56it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.47it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 82.09it/s] 


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.62it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.89it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 108.06it/s]


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 118.14it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.38it/s]


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.70it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 96.69it/s] 


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.23it/s]


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.96it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.05it/s] 


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 98.04it/s] 


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.56it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.76it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 120.21it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 104.32it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.95it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 73.32it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 91.36it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 109.63it/s]


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 88.44it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 96.76it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.23it/s] 


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 106.44it/s]


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.35it/s] 


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 94.48it/s]


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 108.97it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 104.43it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.65it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 85.18it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 8/8 [00:00<00:00, 777.14it/s]


Test Loss: 0.3446
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_6_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 2304


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 80.52it/s]


Epoch [1/50], Loss: 43.2468


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 86.37it/s]


Epoch [2/50], Loss: 0.8703


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 87.06it/s]


Epoch [3/50], Loss: 0.0858


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 84.88it/s] 


Epoch [4/50], Loss: 0.0200


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 103.17it/s]


Epoch [5/50], Loss: 0.0045


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 100.72it/s]


Epoch [6/50], Loss: 0.0008


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 106.85it/s]


Epoch [7/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.36it/s]


Epoch [8/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 99.04it/s] 


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 58.08it/s]


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 118.76it/s]


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.47it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.83it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 108.91it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 107.28it/s]


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 124.43it/s]


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 75.94it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 75.79it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 80.23it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 78.68it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 73.07it/s]


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 107.40it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 78.80it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.28it/s] 


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 106.97it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.18it/s] 


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 96.23it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 90.71it/s]


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 85.14it/s] 


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 63.81it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 70.32it/s]


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 80.94it/s]


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 104.71it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 109.14it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 105.22it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 120.91it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 107.80it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.48it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.12it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 99.71it/s] 


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 100.15it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 103.72it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 106.93it/s]


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 104.66it/s]


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 105.28it/s]


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 121.14it/s]


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 120.37it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 114.27it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 130.84it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.91it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 8/8 [00:00<00:00, 581.31it/s]


Test Loss: 0.3399
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_7_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 2304


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.56it/s]


Epoch [1/50], Loss: 35.5371


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 117.86it/s]


Epoch [2/50], Loss: 1.3452


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 111.99it/s]


Epoch [3/50], Loss: 0.1959


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 134.92it/s]


Epoch [4/50], Loss: 0.0467


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 130.34it/s]


Epoch [5/50], Loss: 0.0118


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 106.76it/s]


Epoch [6/50], Loss: 0.0025


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 127.90it/s]


Epoch [7/50], Loss: 0.0005


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.24it/s]


Epoch [8/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.72it/s]


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 139.52it/s]


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 85.93it/s] 


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 148.55it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.21it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 151.70it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 147.15it/s]


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 150.89it/s]


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 131.85it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 141.35it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 144.92it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.68it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 130.92it/s]


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 109.13it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.07it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 138.51it/s]


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 121.14it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 123.70it/s]


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 117.32it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 116.55it/s]


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 109.49it/s]


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 124.80it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.62it/s]


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.77it/s]


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 114.88it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 125.60it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.46it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 116.59it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 126.41it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 127.03it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.72it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.57it/s]


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 114.33it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 125.46it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 126.27it/s]


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.61it/s]


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.20it/s]


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.17it/s]


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 127.90it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 130.90it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 126.82it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 134.51it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 8/8 [00:00<00:00, 691.74it/s]


Test Loss: 0.4120
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_8_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 2304


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.27it/s]


Epoch [1/50], Loss: 42.0404


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.69it/s]


Epoch [2/50], Loss: 0.8878


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.74it/s]


Epoch [3/50], Loss: 0.0939


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.85it/s]


Epoch [4/50], Loss: 0.0216


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.21it/s]


Epoch [5/50], Loss: 0.0048


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 117.63it/s]


Epoch [6/50], Loss: 0.0008


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.50it/s]


Epoch [7/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.37it/s]


Epoch [8/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.25it/s]


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.69it/s]


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 112.09it/s]


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.31it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 123.66it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 132.05it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 134.79it/s]


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 132.75it/s]


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.98it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 131.39it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 136.05it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 116.68it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.12it/s]


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.13it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.47it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 122.95it/s]


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.78it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.71it/s]


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.96it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.02it/s]


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.60it/s]


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 101.45it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 103.41it/s]


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 125.73it/s]


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.48it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 118.80it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.36it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 117.70it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.39it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 130.11it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 127.11it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 128.13it/s]


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 115.05it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 122.31it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 138.26it/s]


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.47it/s] 


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 132.40it/s]


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 118.32it/s]


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 124.10it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 122.20it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 129.09it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.23it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 8/8 [00:00<00:00, 739.02it/s]


Test Loss: 0.3616
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_9_decoder.pth
Getting the model weights
Creating SAE with input dim: 1024 and hidden dim: 2304


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 135.68it/s]


Epoch [1/50], Loss: 35.8391


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 120.45it/s]


Epoch [2/50], Loss: 1.3075


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 136.76it/s]


Epoch [3/50], Loss: 0.1908


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 92.87it/s] 


Epoch [4/50], Loss: 0.0456


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.66it/s]


Epoch [5/50], Loss: 0.0115


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 132.79it/s]


Epoch [6/50], Loss: 0.0024


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 135.24it/s]


Epoch [7/50], Loss: 0.0005


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 98.79it/s] 


Epoch [8/50], Loss: 0.0002


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 38.71it/s]


Epoch [9/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 56.08it/s]


Epoch [10/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 94.69it/s]


Epoch [11/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 102.10it/s]


Epoch [12/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 77.12it/s]


Epoch [13/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.07it/s]


Epoch [14/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 106.12it/s]


Epoch [15/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 89.75it/s] 


Epoch [16/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 81.57it/s]


Epoch [17/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 107.87it/s]


Epoch [18/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 104.86it/s]


Epoch [19/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.84it/s]


Epoch [20/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 99.80it/s]


Epoch [21/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 47.80it/s]


Epoch [22/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 84.57it/s]


Epoch [23/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 65.76it/s]


Epoch [24/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 97.95it/s]


Epoch [25/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.09it/s]


Epoch [26/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 125.41it/s]


Epoch [27/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 103.24it/s]


Epoch [28/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.25it/s]


Epoch [29/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 81.66it/s]


Epoch [30/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 99.50it/s] 


Epoch [31/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 102.51it/s]


Epoch [32/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.44it/s]


Epoch [33/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 93.05it/s]


Epoch [34/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 113.30it/s]


Epoch [35/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 82.81it/s]


Epoch [36/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 101.82it/s]


Epoch [37/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 119.18it/s]


Epoch [38/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 131.28it/s]


Epoch [39/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 120.68it/s]


Epoch [40/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 122.25it/s]


Epoch [41/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 133.38it/s]


Epoch [42/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 86.23it/s] 


Epoch [43/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 100.78it/s]


Epoch [44/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 100.91it/s]


Epoch [45/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 88.73it/s]


Epoch [46/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 102.20it/s]


Epoch [47/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 110.03it/s]


Epoch [48/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 95.90it/s]


Epoch [49/50], Loss: 0.0001


Processing Activations: 100%|██████████| 29/29 [00:00<00:00, 114.46it/s]


Epoch [50/50], Loss: 0.0001


100%|██████████| 8/8 [00:00<00:00, 581.47it/s]


Test Loss: 0.4180
SAE model saved to /Volumes/Ayush_Drive/mnist/embeddings/laion/sae_depth_10_decoder.pth
Finished Training up to depth  11


SimpleSAE(
  (encoder): Linear(in_features=1024, out_features=2304, bias=True)
  (decoder): Linear(in_features=2304, out_features=1024, bias=True)
)

In [None]:
# # load the initial sae model
# NOTE(review): this whole cell is commented-out (dead) code. Before re-enabling it:
#   - it hard-codes map_location=torch.device('mps'), while the config cell sets
#     device = 'cpu'; prefer map_location=device;
#   - it uses input_dim = 16384, whereas the SAEs trained above report
#     "input dim: 1024" — confirm which model this checkpoint belongs to.
# Otherwise consider deleting the cell to keep the notebook narrative clean.
# input_dim = 16384
# hidden_dim = 8192  # Adjust as needed
# sae_model = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim)
# sae_model.load_state_dict(torch.load(f'{prefix}/embeddings/cifar100/sae_layer3_depth_1_{hidden_dim}.pth', map_location=torch.device('mps')))

In [None]:
# Load the saved depth-1 SAE weights onto the configured device.
# Build the path from `prefix` (resolved in the config cell: the external drive if it
# is mounted, otherwise '' for a relative path) instead of the raw `path`, so this cell
# honours the same fallback used when checkpoints are written in `train_sae`.
# NOTE(review): the SAE decoders trained above were saved under
# <prefix>/embeddings/laion/ — confirm this weights file really lives in <prefix>/laion/.
sae_weights = torch.load(os.path.join(prefix, 'laion', 'sae_depth_1_weights.pt'), map_location=device)
sae_weights.shape