In [1]:
%load_ext autoreload
%autoreload 2

import torch
torch.manual_seed(42)

import os 
path = '/Volumes/Sid_Drive/mnist/'

if os.path.exists(path):
    prefix = path
else:
    prefix = ''

In [3]:
import torch 
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from structs.models import SimpleSAE, EnhancedSAE
from structs.utils import fc1_config, encoder_config

data_name = 'EMNIST'

# Load initial activations
current_activations = torch.load(f'embeddings/mnist_fc1_{data_name}.pth')

# Process through each depth
for depth in range(1, 10):
    print(f"Processing depth {depth}")
    
    # Create dataset and loader from current activations
    dataset = TensorDataset(current_activations)
    data_loader = DataLoader(dataset, batch_size=64, shuffle=False)
    
    # Initialize appropriate model for this depth
    if depth == 1:
        model = SimpleSAE(input_dim=fc1_config.input_dim, hidden_dim=2304, l1_coeff=0.01)
    else:
        model = SimpleSAE(input_dim=encoder_config.input_dim, hidden_dim=encoder_config.input_dim, l1_coeff=0.01)
    
    # Load saved weights and prepare model
    model.load_state_dict(torch.load(f'models/mnist-colored_sae_MNIST_depth_{depth}.pth'))
    model.clear_cache()
    model.eval()
    
    # Process batches
    encoded_activations = []
    with torch.no_grad():
        for (batch,) in tqdm(data_loader):  # Note the comma to unpack the single tensor
            encoded, _ = model(batch)
            encoded_activations.append(encoded)
    
    # Concatenate all batches
    current_activations = torch.cat(encoded_activations, dim=0)
    
    # Save the results
    torch.save(current_activations, f'{prefix}embeddings/mnist_encoder_{data_name}_depth_{depth}.pth')
    
    print(f"Depth {depth} shape: {current_activations.shape}")

  current_activations = torch.load(f'embeddings/mnist_fc1_{data_name}.pth')


Processing depth 1


  model.load_state_dict(torch.load(f'models/mnist-colored_sae_MNIST_depth_{depth}.pth'))
100%|██████████| 1763/1763 [00:01<00:00, 1018.12it/s]


Depth 1 shape: torch.Size([112800, 2304])
Processing depth 2


100%|██████████| 1763/1763 [00:05<00:00, 330.42it/s]


Depth 2 shape: torch.Size([112800, 2304])
Processing depth 3


100%|██████████| 1763/1763 [00:05<00:00, 296.81it/s]


Depth 3 shape: torch.Size([112800, 2304])
Processing depth 4


100%|██████████| 1763/1763 [00:05<00:00, 320.78it/s]


Depth 4 shape: torch.Size([112800, 2304])
Processing depth 5


100%|██████████| 1763/1763 [00:04<00:00, 354.29it/s]


Depth 5 shape: torch.Size([112800, 2304])
Processing depth 6


100%|██████████| 1763/1763 [00:04<00:00, 394.51it/s]


Depth 6 shape: torch.Size([112800, 2304])
Processing depth 7


100%|██████████| 1763/1763 [00:04<00:00, 369.02it/s]


Depth 7 shape: torch.Size([112800, 2304])
Processing depth 8


100%|██████████| 1763/1763 [00:11<00:00, 150.80it/s]


Depth 8 shape: torch.Size([112800, 2304])
Processing depth 9


100%|██████████| 1763/1763 [00:04<00:00, 366.39it/s]


Depth 9 shape: torch.Size([112800, 2304])
