In [1]:
import torch

SEED = 42  # seed for torch's global RNG so the notebook run is reproducible
torch.manual_seed(SEED)

<torch._C.Generator at 0x13059aed0>

## MNIST

In [2]:
import torch
import torch.nn as nn

class MNISTModel(nn.Module):
    """Three-layer MLP classifier for 28x28 MNIST images with optional activation caching.

    When ``forward`` is called with ``cache_activations=True``, a detached copy of each
    layer's output (post-ReLU for fc1/fc2, raw logits for fc3) is appended to
    ``self.activation_cache[layer_name]`` so it can later be concatenated, e.g. to
    train sparse autoencoders on it.
    """

    # Layers whose outputs can be cached; these are the keys of ``activation_cache``.
    CACHED_LAYERS = ('fc1', 'fc2', 'fc3')

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # Input layer (28x28 pixels)
        self.fc2 = nn.Linear(128, 64)       # Hidden layer
        self.fc3 = nn.Linear(64, 10)        # Output layer (10 classes)

        # layer name -> list of per-batch activation tensors
        self.clear_cache()

    def forward(self, x, cache_activations=False):
        """Return class logits of shape (batch, 10).

        Args:
            x: images with 28*28 values per sample, e.g. (batch, 1, 28, 28).
            cache_activations: if True, append detached copies of every layer's
                output to ``self.activation_cache``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 28 * 28)

        fc1_out = torch.relu(self.fc1(x))
        fc2_out = torch.relu(self.fc2(fc1_out))
        fc3_out = self.fc3(fc2_out)

        if cache_activations:
            # detach + clone so the cache holds no autograd graph and no shared storage.
            for name, out in (('fc1', fc1_out), ('fc2', fc2_out), ('fc3', fc3_out)):
                self.activation_cache[name].append(out.detach().clone())

        return fc3_out

    def get_cached_activations(self, layer_name):
        """Concatenate all cached batches for ``layer_name`` along dim 0.

        Returns None if ``layer_name`` is not a cached layer or nothing has been
        cached for it yet (``torch.cat`` would raise on an empty list).
        """
        batches = self.activation_cache.get(layer_name)
        if not batches:
            return None
        return torch.cat(batches)

    def clear_cache(self):
        """Drop all cached activations, starting a fresh empty list per layer."""
        self.activation_cache = {name: [] for name in self.CACHED_LAYERS}

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- Hyperparameters ---
batch_size = 64
learning_rate = 0.001
epochs = 5

# --- Data ---
# Convert images to tensors, then standardize with the MNIST mean / std deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training set is shuffled; the test loader iterates test_dataset in order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# --- Model, loss, optimizer ---
model = MNISTModel()
criterion = nn.CrossEntropyLoss()  # expects raw logits + integer class labels
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# NOTE(review): not read by any cell of this notebook so far.
all_activations = []

# --- Training loop ---
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)  # logits only; activations are not cached here
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Reported loss = mean of the per-batch mean losses for this epoch.
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")


Epoch [1/5], Loss: 0.2645
Epoch [2/5], Loss: 0.1110
Epoch [3/5], Loss: 0.0776
Epoch [4/5], Loss: 0.0612
Epoch [5/5], Loss: 0.0484


### Test Loop

In [4]:
# Classification accuracy of the trained MNIST model on the held-out test set.
model.eval()
with torch.no_grad():
    correct, total = 0, 0
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the largest logit = predicted digit
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    print(f"Test accuracy: {100 * correct / total}%")

Test accuracy: 97.58%


### Cache Code

In [5]:
# Caching pass (no accuracy computed): store every layer's activations on the test set.
# test_loader was built with shuffle=False, so cached row i corresponds to test_dataset[i].
model.clear_cache()
model.eval()
with torch.no_grad():
    for images, _labels in test_loader:
        model(images, cache_activations=True)

# Sanity check: exactly one cached row per test image.
assert model.get_cached_activations('fc1').shape[0] == len(test_loader.dataset)

In [6]:
tuple(model.get_cached_activations(layer).shape for layer in ('fc1', 'fc2', 'fc3'))

(torch.Size([10000, 128]), torch.Size([10000, 64]), torch.Size([10000, 10]))

In [7]:
test_dataset.targets.shape  # one label per cached activation row (test_loader wraps test_dataset)

torch.Size([10000])

## SAE

*Activations Shapes:* 
- fc1: torch.Size([10000, 128])
- fc2: torch.Size([10000, 64])
- fc3: torch.Size([10000, 10])


In [8]:
class LayerConfig:
    """A cached layer's name plus its activation width (used as the SAE input dimension)."""

    def __init__(self, name, input_dim):
        self.name, self.input_dim = name, input_dim


# One config per MNISTModel layer; widths equal each layer's output features.
fc1_config, fc2_config, fc3_config = (
    LayerConfig(layer_name, width)
    for layer_name, width in (('fc1', 128), ('fc2', 64), ('fc3', 10))
)

### Simple SAE

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

class SimpleSAE(nn.Module):
    """Sparse autoencoder: ReLU encoder + linear decoder with optional activation caching.

    Trained with MSE reconstruction loss plus an L1 penalty on the encoder output.

    Args:
        input_dim: width of the activations being reconstructed.
        hidden_dim: number of learned features (encoder output width).
        l1_coeff: weight of the L1 sparsity term.
        seed: passed to ``torch.manual_seed`` before the layers are created so the
            initial weights are reproducible. NOTE: this reseeds torch's global RNG.
    """

    # Stages whose outputs can be cached; these are the keys of ``activation_cache``.
    CACHED_LAYERS = ('encoder', 'decoder')

    def __init__(self, input_dim=64, hidden_dim=32, l1_coeff=0.1, seed=42):
        super(SimpleSAE, self).__init__()
        torch.manual_seed(seed)
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)
        self.l1_coeff = l1_coeff  # L1 regularization coefficient for sparsity

        # stage name -> list of per-batch activation tensors
        self.clear_cache()

    def forward(self, x, cache_activations=False):
        """Return ``(encoded, decoded)`` for a batch ``x`` of width ``input_dim``.

        If ``cache_activations`` is True, detached copies of both outputs are
        appended to ``self.activation_cache``.
        """
        encoded = torch.relu(self.encoder(x))   # non-negative feature activations
        decoded = self.decoder(encoded)         # reconstruction of x
        if cache_activations:
            self.activation_cache['encoder'].append(encoded.detach().clone())
            self.activation_cache['decoder'].append(decoded.detach().clone())
        return encoded, decoded

    def compute_loss(self, x, decoded, encoded):
        """Return MSE(decoded, x) + l1_coeff * sum(|encoded|).

        NOTE(review): the L1 term is summed over the batch AND feature dims, while the
        MSE term is a mean, so the penalty (and the total loss) scales with batch size;
        a smaller final batch therefore reports a smaller loss.
        """
        recon_loss = nn.functional.mse_loss(decoded, x)
        l1_loss = self.l1_coeff * encoded.abs().sum()
        return recon_loss + l1_loss

    def get_cached_activations(self, layer_name):
        """Concatenate all cached batches for ``layer_name`` along dim 0.

        Returns None if ``layer_name`` is unknown or nothing has been cached for it
        yet (``torch.cat`` would raise on an empty list).
        """
        batches = self.activation_cache.get(layer_name)
        if not batches:
            return None
        return torch.cat(batches)

    def clear_cache(self):
        """Drop all cached activations, starting a fresh empty list per stage."""
        self.activation_cache = {name: [] for name in self.CACHED_LAYERS}


In [10]:
import math 
import torch.nn.functional as F

class EnhancedSAE(nn.Module):
    """Sparse autoencoder whose encoder input is centered by the decoder bias.

    The decoder bias doubles as a learned offset: it is subtracted from the input
    before encoding and added back by the decoder. Both biases start at zero.
    Loss: MSE reconstruction + ``l1_coeff`` * L1 norm of the hidden code.
    """

    def __init__(self, input_dim=64, hidden_dim=32, l1_coeff=0.1, seed=42):
        super(EnhancedSAE, self).__init__()
        self.l1_coeff = l1_coeff  # weight of the L1 sparsity term
        torch.manual_seed(seed)   # reseeds torch's global RNG for reproducible init

        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

        # Re-draw both weight matrices with kaiming_uniform_(a=sqrt(5)), encoder first.
        # NOTE(review): this appears to match nn.Linear's own default init, so it may be redundant.
        for layer in (self.encoder, self.decoder):
            nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
        # Start both biases at zero.
        for layer in (self.encoder, self.decoder):
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch ``x`` of width ``input_dim``."""
        # Shift the input by the decoder bias before encoding.
        encoded = F.relu(self.encoder(x - self.decoder.bias))
        return encoded, self.decoder(encoded)

    def compute_loss(self, x, decoded, encoded):
        """Return MSE(decoded, x) + l1_coeff * sum(|encoded|) (L1 summed over all elements)."""
        reconstruction = F.mse_loss(decoded, x)
        sparsity = self.l1_coeff * encoded.abs().sum()
        return reconstruction + sparsity


In [11]:
# Which MNIST layer's cached activations to train the SAE on.
selected_layer_config = fc1_config  # swap for fc2_config / fc3_config as needed

input_dim = selected_layer_config.input_dim
hidden_dim = 32
sae = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim, l1_coeff=0.1)
optimizer_sae = optim.Adam(sae.parameters(), lr=learning_rate)

# Hold out 20% of the cached activations for evaluation (fixed split seed).
layer_activations = model.get_cached_activations(selected_layer_config.name)
train_activations, test_activations = train_test_split(
    layer_activations, test_size=0.2, random_state=42
)

train_activations_loader = DataLoader(train_activations, batch_size=batch_size, shuffle=True)
test_activations_loader = DataLoader(test_activations, batch_size=batch_size, shuffle=False)

# --- SAE training loop ---
epochs = 5  # overrides the MNIST training value
for epoch in range(epochs):
    sae.train()
    total_loss = 0
    for batch in train_activations_loader:
        optimizer_sae.zero_grad()
        encoded, decoded = sae(batch)
        loss = sae.compute_loss(batch, decoded, encoded)  # MSE + L1 sparsity penalty
        loss.backward()
        optimizer_sae.step()
        total_loss += loss.item()

    print(f"SAE Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_activations_loader):.4f}")


SAE Epoch [1/5], Loss: 15.6593
SAE Epoch [2/5], Loss: 8.7535
SAE Epoch [3/5], Loss: 8.4562
SAE Epoch [4/5], Loss: 8.1989
SAE Epoch [5/5], Loss: 7.9738


In [12]:
# Held-out SAE loss: the mean of per-batch (MSE + L1) losses on the 20% split.
# NOTE(review): this averages per-batch losses, so the smaller final batch counts like a
# full one even though its summed L1 term is smaller.
sae.eval()
batch_losses = []
with torch.no_grad():
    for batch in test_activations_loader:
        encoded, decoded = sae(batch)
        loss = sae.compute_loss(batch, decoded, encoded)
        batch_losses.append(loss.item())
total_test_loss = sum(batch_losses)

print(f"SAE Test Loss: {total_test_loss/len(test_activations_loader):.4f}")


SAE Test Loss: 7.9110


## What have I done so far? 

- Trained an MNIST classifier: ~97% test accuracy
- Trained two SAEs on the MNIST model's activations
    - SimpleSAE (just an encoder and a decoder)
    - ComplexSAE (`EnhancedSAE`; based on Neel Nanda's SAE from his replication of the monosemanticity paper)
- SimpleSAE seems to perform slightly better

*Set a fixed random seed*
 
Next Steps: 
- Cache all activations when testing the SAE
- Take the SAE's middle layer (the encoder output) and feed those activations into another SimpleSAE

Confusions: 
- Should the hidden layer of an SAE be larger than the input?
    - *What is a hidden layer?* 

### NEED TO TRACK MAX ACTIVATIONS

## Meta-SAE Time

In [13]:
# Derive widths from the trained SAE instead of hard-coding 32 / 128, so they stay correct
# if `selected_layer_config` or `hidden_dim` in the SAE training cell are changed.
encoder_config = LayerConfig('encoder', sae.encoder.out_features)  # SAE hidden width
decoder_config = LayerConfig('decoder', sae.decoder.out_features)  # reconstructed input width

Cache activations

In [14]:
# Caching pass (not an evaluation): store the SAE's encoder/decoder outputs on its
# held-out split; these become the meta-SAE's training data.
# NOTE(review): train_test_split appears to shuffle rows by default, so cached row i does not
# line up with MNIST test image i; keep the split indices if per-digit analysis is needed.
sae.clear_cache()
sae.eval()
with torch.no_grad():
    for batch in test_activations_loader:
        sae(batch, cache_activations=True)

# Sanity check: one cached row per held-out activation.
assert sae.get_cached_activations('encoder').shape[0] == len(test_activations_loader.dataset)

In [15]:
# Shapes of the cached SAE activations: (rows, hidden width) and (rows, input width).
tuple(sae.get_cached_activations(stage).shape for stage in ('encoder', 'decoder'))

(torch.Size([2000, 32]), torch.Size([2000, 128]))

In [16]:
# Meta-SAE: a second SimpleSAE trained on the first SAE's encoder activations.
selected_layer_config = encoder_config
hidden_dim = 16
input_dim = selected_layer_config.input_dim
meta_sae = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim, l1_coeff=0.1)

# Reuses the `optimizer_sae` name; the first SAE's optimizer is replaced.
optimizer_sae = optim.Adam(meta_sae.parameters(), lr=learning_rate)

train_activations = sae.get_cached_activations(selected_layer_config.name)

In [17]:

# 80/20 split of the SAE's cached encoder activations. Read straight from the cache
# (instead of re-splitting `train_activations` in place) so re-running this cell does
# not shrink the data by another 20% each time.
train_activations, test_activations = train_test_split(
    sae.get_cached_activations(selected_layer_config.name), test_size=0.2, random_state=42
)

train_activations_loader = DataLoader(train_activations, batch_size=batch_size, shuffle=True)
test_activations_loader = DataLoader(test_activations, batch_size=batch_size, shuffle=False)

# NOTE: re-running this cell continues training the existing `meta_sae` and its optimizer;
# re-run the setup cell above to start from a fresh model.
epochs = 5
for epoch in range(epochs):
    meta_sae.train()
    total_loss = 0
    for batch in train_activations_loader:
        optimizer_sae.zero_grad()
        encoded, decoded = meta_sae(batch)
        loss = meta_sae.compute_loss(batch, decoded, encoded)  # MSE + L1 sparsity penalty
        loss.backward()
        optimizer_sae.step()
        total_loss += loss.item()

    print(f"SAE Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_activations_loader):.4f}")

SAE Epoch [1/5], Loss: 4.4836
SAE Epoch [2/5], Loss: 3.2285
SAE Epoch [3/5], Loss: 2.1484
SAE Epoch [4/5], Loss: 1.2984
SAE Epoch [5/5], Loss: 0.5460


### Notes

- The meta-SAE has a really low loss — why???
- What's top-k, and why use it?

## Evaluation

How does ViT Prisma do their emoji thing? → I want to track the change in the model's understanding of the number over time.

- Max Activation Evaluation

(1) MNIST: Run the train set and cache activations
--- Store the output labels
(2) Using the MNIST test labels, run the corresponding activations through the SAE and track the max activations of the encoder output layer.
(3) Take the activations of the encoder output and track the max activations of the encoder output layer.

(2) and (3) can use a helper method in the SimpleSAE class for tracking the SAE's max activations.

### MNIST Max Activations

Save the top 