## MNIST

In [2]:
import torch
import torch.nn as nn

class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # Input layer (28x28 pixels)
        self.fc2 = nn.Linear(128, 64)       # Hidden layer
        self.fc3 = nn.Linear(64, 10)        # Output layer (10 classes)

        # Initialize a cache to store lists of activations
        self.activation_cache = {
            'fc1': [],
            'fc2': [],
            'fc3': []
        }

    def forward(self, x, cache_activations=False):
        x = x.view(-1, 28 * 28)  # Flatten the input image
        
        # Pass through fc1 and cache activations if needed
        fc1_out = torch.relu(self.fc1(x))
        if cache_activations:
            self.activation_cache['fc1'].append(fc1_out.detach().clone())  # Append fc1 activations
        
        # Pass through fc2 and cache activations if needed
        fc2_out = torch.relu(self.fc2(fc1_out))
        if cache_activations:
            self.activation_cache['fc2'].append(fc2_out.detach().clone())  # Append fc2 activations
        
        # Pass through fc3 and cache activations if needed
        fc3_out = self.fc3(fc2_out)
        if cache_activations:
            self.activation_cache['fc3'].append(fc3_out.detach().clone())  # Append fc3 activations

        return fc3_out

    # Method to retrieve cached activations for a specified layer
    def get_cached_activations(self, layer_name):
        return torch.cat(self.activation_cache[layer_name]) if layer_name in self.activation_cache else None

    # Method to clear the cache
    def clear_cache(self):
        self.activation_cache = {
            'fc1': [],
            'fc2': [],
            'fc3': []
        }

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 5

# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Mean and std deviation of MNIST dataset
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss function and optimizer
model = MNISTModel()
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for classification
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
all_activations = []

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()  # Zero out previous gradients
        outputs = model(images)  # Get output and activations from fc2
        loss = criterion(outputs, labels)
        loss.backward()  # Backpropagation
        optimizer.step()  # Gradient descent
        total_loss += loss.item()

        # Store activations for the sparse autoencoder

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")


Epoch [1/5], Loss: 0.2696
Epoch [2/5], Loss: 0.1127
Epoch [3/5], Loss: 0.0793
Epoch [4/5], Loss: 0.0599
Epoch [5/5], Loss: 0.0454


### Test Loop

In [5]:
#* Test Accuracy Loop 
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Test accuracy: {100 * correct / total}%")

Test accuracy: 97.02%


### Cache Code

In [6]:
#* Test Accuracy Loop 
model.clear_cache()
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images, cache_activations=True)
        _, predicted = torch.max(outputs.data, 1)

In [7]:
model.get_cached_activations('fc1').shape, model.get_cached_activations('fc2').shape, model.get_cached_activations('fc3').shape

(torch.Size([10000, 128]), torch.Size([10000, 64]), torch.Size([10000, 10]))

In [8]:
test_loader.dataset.targets.shape

torch.Size([10000])

## SAE

*Activations Shapes:* 
- fc1: torch.Size([10000, 128])
- fc2: torch.Size([10000, 64])
- fc3: torch.Size([10000, 10]))


### Simple SAE

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class SimpleSAE(nn.Module):
    def __init__(self, input_dim=64, hidden_dim=32, l1_coeff=0.1):
        super(SimpleSAE, self).__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)
        self.l1_coeff = l1_coeff  # L1 regularization coefficient for sparsity

    def forward(self, x):
        # Encoder: Reduce the dimensionality
        encoded = torch.relu(self.encoder(x))

        # Decoder: Reconstruct the original input
        decoded = self.decoder(encoded)

        return encoded, decoded

    def compute_loss(self, x, decoded, encoded):
        # Reconstruction Loss (MSE)
        recon_loss = nn.MSELoss()(decoded, x)

        # L1 Sparsity Loss (L1 regularization on encoded activations)
        l1_loss = self.l1_coeff * torch.sum(torch.abs(encoded))

        # Combine losses: L = MSE + λ * L1
        loss = recon_loss + l1_loss

        return loss

# Initialize the SAE
input_dim = 64
hidden_dim = 32
sae = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim, l1_coeff=0.1)

# Optimizer
optimizer_sae = optim.Adam(sae.parameters(), lr=learning_rate)

train_activations = model.get_cached_activations('fc2')

# DataLoader for activations
train_activations_loader = DataLoader(train_activations, batch_size=batch_size, shuffle=True)

# Training Loop for SAE
epochs = 5  # You can adjust this based on your preference
for epoch in range(epochs):
    sae.train()
    total_loss = 0
    for batch in train_activations_loader:
        optimizer_sae.zero_grad()  # Zero out previous gradients

        # Forward pass through the SAE
        encoded, decoded = sae(batch)

        # Compute loss (MSE + sparsity penalty)
        loss = sae.compute_loss(batch, decoded, encoded)

        loss.backward()  # Backprop for SAE
        optimizer_sae.step()  # Optimizer step

        total_loss += loss.item()

    print(f"SAE Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_activations_loader):.4f}")


SAE Epoch [1/5], Loss: 29.5675
SAE Epoch [2/5], Loss: 13.4935
SAE Epoch [3/5], Loss: 12.8894
SAE Epoch [4/5], Loss: 12.3574
SAE Epoch [5/5], Loss: 11.9019


In [19]:
# DataLoader for test activations
test_activations_loader = DataLoader(test_activations, batch_size=batch_size, shuffle=False)

# Validation Loop for SAE
sae.eval()
total_val_loss = 0
with torch.no_grad():
    for batch in test_activations_loader:
        # Forward pass through the SAE
        encoded, decoded = sae(batch)

        # Compute loss (MSE + sparsity penalty)
        loss = sae.compute_loss(batch, decoded, encoded)

        total_val_loss += loss.item()

print(f"SAE Validation Loss: {total_val_loss/len(test_activations_loader):.4f}")

SAE Validation Loss: 9.0200
