In [1]:
import os 

batch_size = 64

path = '/Volumes/Ayush_Drive/mnist/'

if os.path.exists(path):
    prefix = path
else:
    prefix = ''

print(prefix)

/Volumes/Ayush_Drive/mnist/


In [2]:
from structs.models import SimpleSAE
from structs.utils import fc1_config, encoder_config

# Measurements 

- Have the EMNIST ones 
- Need to load in each set of models and each set of embeddings 
- or run the embeddings again? 

In [3]:
data_name = 'EMNIST'

In [4]:
# load in the activations in a dictionary 
import torch
from tqdm import tqdm 

activations = {}
for depth in tqdm(range(1, 10), desc="Loading activations"): 
    filename = f'{prefix}embeddings/mnist_encoder_{data_name}_depth_{depth}.pth'
    activations[depth] = torch.load(filename)

  activations[depth] = torch.load(filename)
Loading activations: 100%|██████████| 9/9 [01:15<00:00,  8.39s/it]


In [5]:
activations[1].shape

torch.Size([112800, 2304])

## L0 Sparsity

- the average number of nonzero feature activations

In [6]:
# get it for one depth 
def calc_l0(activation_vector):
    return (activation_vector != 0).sum() / activation_vector.shape[0]

In [None]:
# calc the average l0 sparsity per depth 
l0_sparsity = {}
for depth in activations.keys():
    l0_sparsity[depth] = calc_l0(activations[depth])
    print(f"Depth {depth}: {l0_sparsity[depth]}")


In [None]:
# plot the l0 sparsity
import matplotlib.pyplot as plt
plt.plot(list(l0_sparsity.keys()), list(l0_sparsity.values()))
plt.xlabel('Depth')
plt.ylabel('L0 Sparsity')
plt.title('L0 Sparsity vs Depth')
plt.savefig(f"plots/saebench-metrics/{data_name}_l0_sparsity.png")

In [None]:
l0_sparsity

In [19]:
# save the dictionary as a json 

import json

l0_sparsity_serializable = {str(k): float(v) for k, v in l0_sparsity.items()}

data_object = {
    "data_name" : f"{data_name}",
    "metrics" : {
        "l0_sparsity" : l0_sparsity_serializable
    }
}

json.dump(data_object, open(f'{prefix}data/{data_name}.json', 'w'))

# Measuring impact on loss

- Get model activations at layer L
- Pass through SAE to get reconstruction
- Replace original activations with reconstruction
- Continue model forward pass



- For MNIST we are just measuring the sae, then we recursively measure essentially 

#### Basic Loss

In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split

torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 10

In [3]:
# MNIST Activations  --> colored SAE? 

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class ToRGB:
    def __call__(self, img):
        return img.repeat(3, 1, 1)  # Repeat the grayscale channel 3 times

# Updated transforms for colored MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    ToRGB(),  # Convert to RGB
    transforms.Normalize(mean=[0.1307, 0.1307, 0.1307],  # Same normalization for each channel
                       std=[0.3081, 0.3081, 0.3081])
])

# Load datasets
train_dataset = datasets.MNIST(root=f'{prefix}/data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=f'{prefix}/data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from structs.models import MNISTModel, ColoredMNISTModel
# load in trained mnist model 
model = ColoredMNISTModel()
model.load_state_dict(torch.load(f'{prefix}models/colored_mnist_model.pth'))

#### Baseline Metrics
- Validation accuracy: 0.79%
- Average loss: 0.01

#### Depth 1 Metrics 
- Validation accuracy: 0.18%
- Average loss: 0.08

In [None]:
model.clear_cache()
model.eval()
total_correct = 0
total_loss = 0
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()

accuracy = total_correct / len(test_loader.dataset)
average_loss = total_loss / len(test_loader.dataset)
print(f'Validation accuracy: {accuracy:.2f}%')
print(f'Average loss: {average_loss:.2f}')

In [35]:
# HOOK Testing 
# baseline is there , now need to replace the activations in the model internally
def basic_hook(module, input, output):
    # output here is after the fc1 layer but before ReLU
    # Any changes you return will be what gets ReLU'd and passed to fc2
    print(f"input: {input[0].shape}")
    print(f"output: {output.shape}")
    print(f"module: {module}")
    return output # Pass through unchanged

# Replace the hook
basic_handle = model.fc1.register_forward_hook(basic_hook)
basic_handle.remove()

In [6]:
def create_hook_with_sae(sae_model, depth):
    def hook(module, input, output):
        encoded, reconstructed = sae_model(output)
        return reconstructed
    return hook

In [None]:
from structs.models import SimpleSAE
from structs.utils import fc1_config, encoder_config
sae_model = SimpleSAE(input_dim=fc1_config.input_dim, hidden_dim=encoder_config.input_dim) 
sae_model.load_state_dict(torch.load(f'{prefix}models/mnist-colored_sae_MNIST_depth_1.pth'))
sae_model.clear_cache()

handle = model.fc1.register_forward_hook(create_hook_with_sae(sae_model))

In [None]:
model.clear_cache()
model.eval()

total_correct = 0
total_loss = 0

criterion = nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()

accuracy = total_correct / len(test_loader.dataset)
average_loss = total_loss / len(test_loader.dataset)
print(f'Validation accuracy: {accuracy:.2f}%')
print(f'Average loss: {average_loss:.2f}')

handle.remove()

#### Load in SAE Models

In [7]:
# sae at depth 1 , 2, 3, 4, 5, 6, 7, 8, 9
sae_models = {}
for depth in range(1, 10):
    if depth == 1:
        sae_model = SimpleSAE(input_dim=fc1_config.input_dim, hidden_dim=encoder_config.input_dim) 
    else:
        sae_model = SimpleSAE(input_dim=encoder_config.input_dim, hidden_dim=encoder_config.input_dim)
    sae_model.load_state_dict(torch.load(f'{prefix}models/mnist-colored_sae_MNIST_depth_{depth}.pth'))
    sae_model.clear_cache()
    sae_models[depth] = sae_model

  sae_model.load_state_dict(torch.load(f'{prefix}models/mnist-colored_sae_MNIST_depth_{depth}.pth'))


#### Look at the impact of each successive sae model on the loss of the previous 

- establish baseline loss for each model 
- then establish the loss difference with the reconstruction just below 
- then establish the loss difference through successive layers

**This is using EMNIST Embeddings instead of MNIST** 

In [8]:
# I need to construct some hook that gets the reconstruction from depth 1 but through the entire network
# first for each sae lets just get the loss difference from base and when substituted with a forward pass through the sae 
# then we can do the same for the entire network

def sae_pass_hook(sae_model): # for some model at some depth 
    def hook(module, input, output):
        encoded, reconstructed = sae_model(output)
        return reconstructed
    return hook

In [9]:
import torch.nn as nn

#pass in an sae and the test_loader contains the activations of the sae above it (that it is fed)

def measure_loss(model, test_loader, criterion=nn.CrossEntropyLoss()):
    total_loss = 0
    total_correct = 0
    with torch.no_grad():
        for activations, labels in test_loader:
            encoded, decoded = model(activations)
            loss = criterion(encoded, activations)
            total_loss += loss.item()

    average_loss = total_loss / len(test_loader.dataset)

    return average_loss

In [10]:
dataset = activations[1]
dataset = torch.utils.data.TensorDataset(dataset, dataset)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

measure_loss(sae_models[2], test_loader)

0.024787288249886415

In [11]:
import json 

try:
    loss = json.load(open(f'{prefix}metrics/{data_name}_loss.json'))
    start_depth = int(max(loss.keys())) + 1
except (FileNotFoundError, ValueError):
    loss = {}
    start_depth = 2


for depth in tqdm(range(start_depth, 9), desc="Measuring Loss"):
    sae_obj = sae_models[depth]

    dataset = activations[depth - 1]
    dataset = torch.utils.data.TensorDataset(dataset, dataset)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    before_loss = measure_loss(sae_obj, test_loader); 

    next_sae_obj = sae_models[depth + 1]
    handle = sae_obj.encoder.register_forward_hook(sae_pass_hook(next_sae_obj))

    after_loss = measure_loss(sae_obj, test_loader); 

    loss[str(depth)] = {  # Convert depth to string for JSON
        "before_loss": before_loss,  # Convert tensor to float
        "after_loss": after_loss,
        "loss_diff": (after_loss - before_loss)
    }
    
    # Save after each iteration
    handle.remove()

    json.dump(loss, open(f'{prefix}metrics/{data_name}_loss.json', 'w'))

    print(f"Depth {depth} Loss: {loss[str(depth)]}")


Measuring Loss: 100%|██████████| 1/1 [00:18<00:00, 18.07s/it]

Depth 8 Loss: {'before_loss': 0.0, 'after_loss': 0.0, 'loss_diff': 0.0}



