In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import os

# MNIST data location. Set the MNIST_ROOT environment variable to override this
# machine-specific default without editing the notebook.
path = os.environ.get('MNIST_ROOT', '/Volumes/Sid_Drive/mnist/')

# Fall back to the current directory ('.') when the path is missing. An empty
# prefix would make f'{prefix}/data' resolve to '/data' at the filesystem root.
prefix = path if os.path.exists(path) else '.'

In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Fix torch's global RNG so weight init and DataLoader shuffling repeat run-to-run
torch.manual_seed(42)

# Hyperparameters read by the classifier and SAE training cells below
batch_size = 64
learning_rate = 1e-3
epochs = 10

In [3]:
class LayerConfig:
    """Small record pairing a layer's cache name with an associated width."""

    def __init__(self, name, input_dim):
        # name: the key later passed as `layer_name` to get_cached_activations
        # input_dim: integer width recorded for the layer
        self.name, self.input_dim = name, input_dim


# One config per layer referenced later in the notebook.
# NOTE(review): despite the attribute name, values such as fc3=10 look like
# layer *output* widths — confirm against structs.models.
fc1_config = LayerConfig(name='fc1', input_dim=256)
fc2_config = LayerConfig(name='fc2', input_dim=128)
fc3_config = LayerConfig(name='fc3', input_dim=10)
encoder_config = LayerConfig(name='encoder', input_dim=2304)
decoder_config = LayerConfig(name='decoder', input_dim=128)

In [4]:
def get_activations(model, dataset, layer_name):
    """Run ``model`` over every batch of ``dataset`` and return its cached activations.

    The model's cache is cleared first, so the returned value reflects only this pass.

    Args:
        model: object exposing ``clear_cache()``, ``eval()``, ``__call__`` and
            ``get_cached_activations(name)``.
        dataset: iterable (e.g. a DataLoader) yielding ``(images, labels)`` pairs;
            labels are ignored.
        layer_name: cache key of the layer whose activations are returned.

    Returns:
        Whatever ``model.get_cached_activations(layer_name)`` returns.
    """
    model.clear_cache()
    model.eval()
    with torch.no_grad():
        for batch_images, _ in tqdm(dataset):
            # Output discarded; presumably the forward pass fills the cache — verify in structs.models
            model(batch_images)
    return model.get_cached_activations(layer_name)

def get_sae_activations(sae, dataset, layer_name):
    """Feed every batch of ``dataset`` through ``sae`` and return its cached ``layer_name`` output.

    Unlike ``get_activations``, each batch is used as-is (no ``(inputs, labels)`` unpacking),
    matching loaders built directly over activation tensors.

    Args:
        sae: object exposing ``clear_cache()``, ``eval()``, ``__call__`` and
            ``get_cached_activations(name)``.
        dataset: iterable yielding input batches.
        layer_name: cache key to read back (e.g. the encoder layer).
    """
    sae.clear_cache()
    sae.eval()
    with torch.no_grad():
        for batch in tqdm(dataset):
            # Result ignored; presumably the forward pass populates the cache
            sae(batch)
    return sae.get_cached_activations(layer_name)

# MNIST Training

In [14]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class ToRGB:
    """Transform that tiles an image tensor 3x along its leading (channel) axis.

    Applied after ``ToTensor``, which gives a (1, H, W) tensor for grayscale MNIST,
    so the result is a (3, H, W) pseudo-RGB image with identical channels.
    """

    def __call__(self, img):
        channel_tiling = (3, 1, 1)
        return img.repeat(channel_tiling)

import os

# Colored MNIST: replicate the grayscale channel 3x, then apply the same
# normalization (mean 0.1307, std 0.3081) to every channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    ToRGB(),
    transforms.Normalize(mean=[0.1307] * 3, std=[0.3081] * 3),
])

# os.path.join instead of f'{prefix}/data': with an empty prefix the f-string
# yields '/data' (filesystem root) rather than a directory next to the notebook.
data_root = os.path.join(prefix, 'data')
train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)

# Create data loaders (only the training split is shuffled)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Reference architecture — SpinalNet VGG, state of the art on EMNIST letters: https://github.com/dipuk0506/SpinalNet/blob/master/MNIST_VGG/EMNIST_letters_VGG_and%20_SpinalVGG.py

In [15]:
import torch.optim as optim
from structs.models import ColoredMNISTModel
# Initialize model, loss function and optimizer
mnist_model = ColoredMNISTModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mnist_model.parameters(), lr=learning_rate)

# Training loop: print the mean per-batch cross-entropy after each epoch
for epoch in range(epochs):
    mnist_model.train()
    total_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(mnist_model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")

# Top-1 accuracy on the MNIST test split.
# argmax(dim=1) replaces torch.max(outputs.data, 1): same first-max index, without
# reaching through the legacy `.data` attribute.
mnist_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        predicted = mnist_model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on test set: {100 * correct / total:.2f}%')

Epoch [1/10], Loss: 0.2111


KeyboardInterrupt: 

In [11]:
import os

# Save to the same path the reload cell reads ('models/mnist_colored.pth'); the
# previous 'models/colored_mnist_model.pth' meant Run-All reloaded a different file.
os.makedirs('models', exist_ok=True)
torch.save(mnist_model.state_dict(), 'models/mnist_colored.pth')

In [16]:
mnist_model = ColoredMNISTModel()
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, instead of executing arbitrary pickled objects.
mnist_model.load_state_dict(torch.load('models/mnist_colored.pth', weights_only=True))

  mnist_model.load_state_dict(torch.load('models/mnist_colored.pth'))


<All keys matched successfully>

In [17]:
# Top-1 accuracy of the loaded classifier on the MNIST test split.
# argmax(dim=1) gives the same first-max index as torch.max(outputs.data, 1)
# without the legacy `.data` attribute.
mnist_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        predicted = mnist_model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on test set: {100 * correct / total:.2f}%')

Accuracy on test set: 97.52%


In [18]:
# Stack the MNIST train and test splits into one dataset. No shuffling, so the
# activation rows presumably follow dataset order (train first, then test).
full_dataset = torch.utils.data.ConcatDataset((train_dataset, test_dataset))
full_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)

# One pass of the classifier over everything, reading back the cached fc2 layer
fc2_activations = get_activations(mnist_model, full_loader, layer_name=fc2_config.name)
fc2_activations.shape

100%|██████████| 1094/1094 [00:11<00:00, 96.75it/s]


torch.Size([70000, 128])

In [19]:
import os

# torch.save does not create parent directories; make sure the target exists.
os.makedirs('embeddings', exist_ok=True)
torch.save(fc2_activations, 'embeddings/mnist_colored_fc2_activations.pth')

# SAE Training

In [20]:
from structs.models import EnhancedSAE, SimpleSAE
# Epoch count for SAE training. Kept separate from the classifier's `epochs`:
# reassigning the global `epochs` here would silently shorten the MNIST
# training cell if it were re-run afterwards.
sae_epochs = 1


def train_sae(train_loader, input_dim=fc1_config.input_dim, hidden_dim=encoder_config.input_dim,
              epochs=None, lr=None):
    """Train a fresh SimpleSAE on batches of activations.

    Args:
        train_loader: iterable of activation batches; each batch is passed
            directly to the SAE and to its ``compute_loss``.
        input_dim: width of each activation vector. NOTE: the default is
            ``fc1_config.input_dim`` (256), bound when this cell runs; pass it
            explicitly for other layers (the fc2 activations are 128 wide).
        hidden_dim: SAE latent width (default ``encoder_config.input_dim``).
        epochs: number of passes over ``train_loader``; None uses ``sae_epochs``.
        lr: Adam learning rate; None uses the notebook-level ``learning_rate``.

    Returns:
        The trained SimpleSAE.
    """
    num_epochs = sae_epochs if epochs is None else epochs
    step_size = learning_rate if lr is None else lr

    model = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim)
    optimizer = optim.Adam(model.parameters(), lr=step_size)
    # Guard the per-epoch mean against an empty loader (ZeroDivisionError otherwise)
    num_batches = max(len(train_loader), 1)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for activations in train_loader:
            optimizer.zero_grad()
            encoded, decoded = model(activations)
            # Argument order (inputs, reconstruction, codes) follows SimpleSAE.compute_loss
            loss = model.compute_loss(activations, decoded, encoded)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/num_batches:.4f}")

    return model

def test_sae(model, test_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for activations in test_loader:
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            total_loss += loss.item()
    
    print(f"Test Loss: {total_loss/len(test_loader):.4f}")
    
    return total_loss/len(test_loader) # return the average loss
    

In [21]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the fc2 activations for evaluating a single-level SAE.
train_activations, test_activations = train_test_split(fc2_activations, test_size=0.2, random_state=42)

# SAE-specific loader names: rebinding `train_loader` / `test_loader` here would
# overwrite the MNIST image loaders, so re-running the classifier training or
# accuracy cells afterwards would iterate over bare activation tensors and fail.
sae_train_loader = DataLoader(dataset=train_activations, batch_size=batch_size, shuffle=True)
sae_test_loader = DataLoader(dataset=test_activations, batch_size=batch_size, shuffle=False)

# Single-level SAE run, superseded by `infinite_sae` below (kept for reference).
# The fc2 activations are 128 wide, so input_dim must be passed explicitly
# (train_sae's default is fc1's width).
# sae_model = train_sae(sae_train_loader, input_dim=fc2_activations.shape[1])
# test_sae(sae_model, sae_test_loader)
# torch.save(sae_model.state_dict(), 'models/sae_model.pth')
#
# # cache the activations for the encoder layer
# full_loader = DataLoader(dataset=fc2_activations, batch_size=batch_size, shuffle=False)
# encoder_activations = get_sae_activations(sae_model, full_loader, encoder_config.name)
# torch.save(encoder_activations, 'embeddings/mnist_colored_encoder_activations.pth')

In [26]:
import os
import warnings

data_name = 'MNIST_fc2'  # tag used in log lines and checkpoint/embedding filenames


def infinite_sae(data, depth, custom_depth=None):
    """Train a stack of SAEs, each level trained on the previous level's encoder activations.

    Each level:
      1. splits the rows of ``data`` 80/20 and trains a SimpleSAE via ``train_sae``;
      2. evaluates it with ``test_sae`` and appends the loss to ``log.txt``;
      3. saves the model's state_dict under ``models/``;
      4. if the last level has not been reached, runs the SAE over *all* rows,
         saves the encoder activations under ``embeddings/``, and uses them as
         the next level's ``data``.

    Args:
        data: activation matrix for the first level (the fc2 activations of
            the MNIST classifier). ``data.shape[1]`` is the SAE input width —
            presumably (num_samples, dim); confirm for new callers.
        depth: index of the first level; used in log lines and filenames.
        custom_depth: index of the last level to train; None means 10.

    Returns:
        The SAE trained at the last level.
    """
    max_depth = 10 if custom_depth is None else custom_depth
    os.makedirs('models', exist_ok=True)
    os.makedirs('embeddings', exist_ok=True)

    # Iterative form of the original recursion: identical level sequence, no growing call stack.
    while True:
        input_dim = data.shape[1]

        train_activations, test_activations = train_test_split(data, test_size=0.2, random_state=42)
        train_loader = DataLoader(dataset=train_activations, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(dataset=test_activations, batch_size=batch_size, shuffle=False)

        model = train_sae(train_loader, input_dim=input_dim, hidden_dim=encoder_config.input_dim)

        loss = test_sae(model, test_loader)
        with open('log.txt', 'a') as f:
            f.write(f"Data: {data_name} Depth: {depth}, Loss: {loss}\n")

        # Persist every level, including the last one (previously only returned, never saved).
        torch.save(model.state_dict(), f'models/mnist-colored_sae_{data_name}_depth_{depth}.pth')

        if depth >= max_depth:
            return model

        full_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=False)
        activations = get_sae_activations(model, full_loader, encoder_config.name)
        torch.save(activations, f'embeddings/mnist-colored_sae_{data_name}_depth_{depth}.pth')

        # If every latent is dead, deeper levels train on constant input — one
        # plausible cause of the Test Loss 0.0000 seen at later depths in the recorded run.
        if not activations.any():
            warnings.warn(
                f"SAE at depth {depth} produced all-zero activations; "
                "deeper levels will train on constant input."
            )

        data = activations
        depth += 1

In [28]:
# Levels 1 through 10, starting from the classifier's fc2 activations
sae = infinite_sae(fc2_activations, depth=1, custom_depth=10)

Epoch [1/1], Loss: 156.0956
Test Loss: 20.4235


100%|██████████| 1094/1094 [00:02<00:00, 425.39it/s]


Epoch [1/1], Loss: 0.6699
Test Loss: 0.0190


100%|██████████| 1094/1094 [00:11<00:00, 96.06it/s] 


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0004


100%|██████████| 1094/1094 [00:14<00:00, 76.33it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:07<00:00, 140.80it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:03<00:00, 317.50it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:03<00:00, 353.89it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:06<00:00, 179.86it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:03<00:00, 320.05it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 399.15it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000
