In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split

torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 10

In [8]:
class LayerConfig:
    def __init__(self, name, input_dim):
        self.name = name
        self.input_dim = input_dim

# Create instances for each layer
fc1_config = LayerConfig('fc1', 256)
fc2_config = LayerConfig('fc2', 128)
fc3_config = LayerConfig('fc3', 10)
encoder_config = LayerConfig('encoder', 2304)
decoder_config = LayerConfig('decoder', 128)

In [9]:
# given a model and a dataset run it and save the activations as well as returning them. 
def get_activations(model, dataset, layer_name):
    model.clear_cache()
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(dataset):
            model(images)
    
    return model.get_cached_activations(layer_name)

# get the activations for the encoder layer
def get_sae_activations(sae, dataset, layer_name):
    sae.clear_cache()
    sae.eval()
    with torch.no_grad():
        for images in tqdm(dataset):
            sae(images)
    
    return sae.get_cached_activations(layer_name)

# MNIST Training

In [12]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Updated transforms for colored EMNIST
transform_train = transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load datasets
train_dataset = datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=transform_train)
test_dataset = datasets.EMNIST(root='./data', split='balanced', train=False, download=True, transform=transform_val)

# Create data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


# SOTA EMNIST MODEL : https://github.com/dipuk0506/SpinalNet/blob/master/MNIST_VGG/EMNIST_letters_VGG_and%20_SpinalVGG.py

In [13]:
import torch.optim as optim
from structs.models import ColoredMNISTModel
# Initialize model, loss function and optimizer
mnist_model = ColoredMNISTModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mnist_model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    mnist_model.train()
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = mnist_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")

# Optional: Test the model

mnist_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = mnist_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on test set: {100 * correct / total:.2f}%')

IndexError: Target 23 is out of bounds.

In [11]:
torch.save(mnist_model.state_dict(), 'models/colored_mnist_model.pth')

In [12]:
mnist_model = ColoredMNISTModel()
mnist_model.load_state_dict(torch.load('models/colored_mnist_model.pth'))

<All keys matched successfully>

In [13]:
# concatenate the train and val dataset together 
full_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])
full_loader = DataLoader(dataset=full_dataset, batch_size=batch_size, shuffle=False)

# get the activations for the full dataset
fc1_activations = get_activations(mnist_model, full_loader, fc1_config.name)
fc1_activations.shape

100%|██████████| 1094/1094 [00:06<00:00, 179.69it/s]


torch.Size([70000, 256])

In [14]:
torch.save(fc1_activations, 'embeddings/mnist_colored_fc1_activations.pth')

# SAE Training

In [29]:
from structs.models import EnhancedSAE, SimpleSAE
epochs = 1

def train_sae(train_loader, input_dim=fc1_config.input_dim, hidden_dim=encoder_config.input_dim):
    model = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for activations in train_loader:
            optimizer.zero_grad()
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")
    
    return model

def test_sae(model, test_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for activations in test_loader:
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            total_loss += loss.item()
    
    print(f"Test Loss: {total_loss/len(test_loader):.4f}")
    
    return total_loss/len(test_loader) # return the average loss
    

In [18]:
from sklearn.model_selection import train_test_split

# Load the activations

# Split the dataset
train_activations, test_activations = train_test_split(fc1_activations, test_size=0.2, random_state=42)

# Create data loaders
train_loader = DataLoader(dataset=train_activations, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_activations, batch_size=batch_size, shuffle=False)

# Train the SAE
# sae_model = train_sae(train_loader)

# # Test the SAE
# test_sae(sae_model, test_loader)

# # Save the model
# torch.save(sae_model.state_dict(), 'models/sae_model.pth')

# # cache the activations for the encoder layer
# full_loader = DataLoader(dataset=fc1_activations, batch_size=batch_size, shuffle=False)
# encoder_activations = get_sae_activations(sae_model, full_loader, encoder_config.name)

# torch.save(encoder_activations, 'embeddings/mnist_colored_encoder_activations.pth')

In [37]:
data_name = 'MNIST'
def infinite_sae(data, depth, custom_depth=None): #the first bit of data is the mnist activations     
    input_dim = data.shape[1]

    # Split the dataset
    train_activations, test_activations = train_test_split(data, test_size=0.2, random_state=42)

    # Create data loaders
    train_loader = DataLoader(dataset=train_activations, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_activations, batch_size=batch_size, shuffle=False)

    # Train the SAE
    model = train_sae(train_loader, input_dim=input_dim, hidden_dim=encoder_config.input_dim)

    # Test the SAE
    loss = test_sae(model, test_loader)
    with open('log.txt', 'a') as f:
        f.write(f"Data: {data_name} Depth: {depth}, Loss: {loss}\n")

    if depth < (custom_depth if custom_depth is not None else 10):  # Use custom depth if provided
        full_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=False)
        activations = get_sae_activations(model, full_loader, encoder_config.name)
        torch.save(activations, f'embeddings/mnist-colored_sae_{data_name}_depth_{depth}.pth')
        torch.save(model.state_dict(), f'models/mnist-colored_sae_{data_name}_depth_{depth}.pth')
        return infinite_sae(activations, depth + 1, custom_depth)  # Reverse the depth increment
    
    return model

In [38]:
sae = infinite_sae(fc1_activations, 1, 10)

Epoch [1/1], Loss: 170.4939
Test Loss: 34.0440


100%|██████████| 1094/1094 [00:00<00:00, 1745.29it/s]


Epoch [1/1], Loss: 0.6468
Test Loss: 0.0178


100%|██████████| 1094/1094 [00:02<00:00, 461.78it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0002


100%|██████████| 1094/1094 [00:02<00:00, 459.07it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 469.06it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 465.00it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 435.17it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 482.18it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 464.78it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000


100%|██████████| 1094/1094 [00:02<00:00, 449.37it/s]


Epoch [1/1], Loss: 0.6290
Test Loss: 0.0000
