In [1]:
# Automatically reload edited local modules (e.g. structs.models) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
# Seed PyTorch's global RNG so random transforms / DataLoader seeding are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x116987850>

### Configuration

In [3]:
# Model / data-loading hyperparameters.
# In this notebook only batch_size is used (DataLoaders); learning_rate and epochs
# are not referenced by any later cell.
batch_size, learning_rate, epochs = 64, 0.001, 5

In [4]:
class LayerConfig:
    """Name and input dimensionality of a layer whose activations are analysed.

    `name` is the key passed to `get_cached_activations` and used in saved file
    names; `input_dim` is used as the SAE input size.
    """

    def __init__(self, name, input_dim):
        self.name, self.input_dim = name, input_dim


# One config per layer of interest, built from (name, input_dim) pairs.
fc1_config, fc2_config, fc3_config, encoder_config, decoder_config = (
    LayerConfig(layer_name, dim)
    for layer_name, dim in (
        ('fc1', 256),
        ('fc2', 128),
        ('fc3', 10),
        ('encoder', 2304),
        ('decoder', 128),
    )
)

### Load in Data

In [5]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# EMNIST images are single-channel, but the Normalize statistics below are 3-channel
# (CIFAR-10 per-channel mean/std). torchvision's Normalize on a (1, H, W) tensor with
# 3-element stats raises a broadcast shape-mismatch error, so replicate the grey
# channel into 3 channels first.
# NOTE(review): assumes ColoredMNISTModel expects 3-channel input (as its name
# suggests) — confirm against structs.models.
to_three_channels = transforms.Grayscale(num_output_channels=3)

# Data augmentation and normalization for training (kept for training use; the
# activation analysis below uses the deterministic transform_val).
transform_train = transforms.Compose([
    to_three_channels,
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Deterministic preprocessing (no random crop/flip) for validation and analysis.
transform_val = transforms.Compose([
    to_three_channels,
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

data_name = 'EMNIST'

# Load CIFAR10 dataset
# train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
# val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)

# Load CIFAR100 dataset
# train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
# val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_val)

# Load EMNIST dataset.
# train_dataset is only used to collect activations and to map top-activating sample
# indices back to raw images (train_dataset.data[idx]). Random crops/flips would make
# the activations correspond to altered images (and differ between runs), so it uses
# the deterministic transform.
train_dataset = datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=transform_val)
val_dataset = datasets.EMNIST(root='./data', split='balanced', train=False, download=True, transform=transform_val)

# shuffle=False keeps batch order == dataset index order, so activation row i is sample i.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

### Display the Image

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Display a normalized (C, H, W) image with matplotlib.

    Reverses ``transforms.Normalize(mean, std)`` channel-wise, clips to [0, 1] and
    shows the result without axes. The defaults are the Normalize statistics used by
    this notebook's transforms, so existing ``imshow(img)`` calls are unchanged.

    Args:
        img: normalized image as a CPU torch tensor or array of shape (C, H, W).
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.
    """
    mean = np.asarray(mean)
    std = np.asarray(std)

    # (C, H, W) -> (H, W, C), the layout matplotlib expects.
    img = np.asarray(img).transpose((1, 2, 0))

    # Undo normalization: x = z * std + mean (broadcasts over the trailing channel axis).
    img = std * img + mean

    # Float RGB images must lie in [0, 1] for imshow.
    img = np.clip(img, 0, 1)

    plt.imshow(img)
    plt.axis('off')

# Grab the first batch and preview its first image. train_loader does not shuffle,
# so this is always the first sample of the dataset.
dataiter = iter(train_loader)
images, labels = next(dataiter)

classes = train_dataset.classes  # class names, indexed by integer label
plt.figure(figsize=(4, 4))
imshow(images[0])
plt.title('Class: {}'.format(classes[labels[0].item()]))
plt.show()

In [None]:
from structs.models import MNISTModel, ColoredMNISTModel

# Load the trained classifier whose layer activations are analysed below.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model stays on CPU throughout this notebook.
model = ColoredMNISTModel()
model.load_state_dict(torch.load('models/mnist_colored.pth', map_location='cpu'))

### Get Activations

In [9]:
# Layer whose activations are collected, saved, and fed to the first SAE below.
selected_layer_config = fc1_config

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

# Run the whole (unshuffled) training set through the model so it records its
# intermediate activations; the forward outputs themselves are not used.
# NOTE(review): presumably the model caches activations during forward, managed via
# clear_cache/get_cached_activations — confirm in structs.models.
model.clear_cache()
model.eval()
with torch.no_grad():
    for images, labels in tqdm(train_loader):
        outputs = model(images)

In [None]:
# Sanity check: the number of cached activation rows should presumably equal this.
num_images = len(train_loader.dataset)
print('Number of images in the train_loader: ' + str(num_images))

In [None]:
# Activations recorded for the selected layer during the forward pass above.
analysis_activations = model.get_cached_activations(
    selected_layer_config.name
)
analysis_activations.shape  # presumably (num_samples, layer_dim) — confirm

In [13]:
# Persist the layer activations so the SAE section can reload them without re-running
# the model. torch.save fails if the target directory is missing, so create it first.
os.makedirs('embeddings', exist_ok=True)
torch.save(analysis_activations, f'embeddings/mnist_{selected_layer_config.name}_{data_name}.pth')

### SAE Analysis

In [None]:
from structs.models import SimpleSAE, EnhancedSAE

# First-level SAE over the selected layer's activations.
input_dim = selected_layer_config.input_dim
hidden_dim = 2304  # SAE dictionary size; must match the saved checkpoint
sae = EnhancedSAE(input_dim=input_dim, hidden_dim=hidden_dim, l1_coeff=0.01)
# map_location='cpu' so a GPU-saved checkpoint also loads on a CPU-only machine.
sae.load_state_dict(torch.load('models/mnist_sae_colored.pth', map_location='cpu'))

In [None]:
# Reload the activations saved above. Build the path from the same config variable
# as the save cell, so changing `selected_layer_config` keeps save and load in sync.
analysis_activations = torch.load(f'embeddings/mnist_{selected_layer_config.name}_{data_name}.pth')
# do not shuffle: row order must stay aligned with dataset indices
analysis_loader = DataLoader(analysis_activations, batch_size=batch_size, shuffle=False)

analysis_activations.shape

In [None]:
# Encode every cached layer activation with the SAE. With cache_activations=True the
# SAE presumably records its encoder outputs for get_cached_activations below.
sae.clear_cache()
sae.eval()
with torch.no_grad():
    for activations in tqdm(analysis_loader):
        encoded, decoded = sae(activations, cache_activations=True)

In [None]:
# SAE encoder activations; saved to disk because they are the meta-SAE's input below.
sae_activations = sae.get_cached_activations('encoder')
torch.save(
    sae_activations,
    f'embeddings/mnist_encoder_{data_name}.pth',
)
sae_activations.shape

#### Second SAE (meta-SAE trained on the first SAE's encoder activations)

In [None]:
from structs.models import SimpleSAE, EnhancedSAE

# The meta-SAE consumes the first SAE's encoder activations. Use a dedicated config
# variable instead of reassigning `selected_layer_config`, so re-running the earlier
# cells that read `selected_layer_config` still targets the model layer.
meta_layer_config = encoder_config

input_dim = meta_layer_config.input_dim
hidden_dim = 2304  # meta-SAE dictionary size; must match the saved checkpoint
meta_sae = EnhancedSAE(input_dim=input_dim, hidden_dim=hidden_dim, l1_coeff=0.01)
# map_location='cpu' so a GPU-saved checkpoint also loads on a CPU-only machine.
meta_sae.load_state_dict(torch.load('models/mnist_meta_sae_colored.pth', map_location='cpu'))

In [None]:
# Reload the first SAE's encoder activations (saved above) as meta-SAE input.
meta_analysis_activations = torch.load(f'embeddings/mnist_encoder_{data_name}.pth')
meta_analysis_loader = DataLoader(
    meta_analysis_activations,
    batch_size=batch_size,
    shuffle=False,  # do not shuffle: row order must stay aligned with dataset indices
)

meta_analysis_activations.shape

In [None]:
# Encode every first-SAE activation with the meta-SAE. With cache_activations=True the
# meta-SAE presumably records its encoder outputs for get_cached_activations below.
meta_sae.clear_cache()
meta_sae.eval()
with torch.no_grad():
    for activations in tqdm(meta_analysis_loader):
        encoded, decoded = meta_sae(activations, cache_activations=True)

In [None]:
# Meta-SAE encoder outputs recorded during the pass above.
meta_sae_activations = meta_sae.get_cached_activations(
    'encoder'
)
meta_sae_activations.shape

In [13]:
# Persist the meta-SAE encoder activations for the top-10 analysis below.
torch.save(meta_sae_activations, f'embeddings/mnist_meta_encoder_{data_name}.pth')

### Top-10 activating images per neuron

In [None]:
# Meta-SAE encoder activations reloaded from disk. NOTE: this rebinds `sae_activations`
# (earlier it held the first SAE's activations); the cells below analyse meta-SAE neurons.
sae_activations = torch.load(f'embeddings/mnist_meta_encoder_{data_name}.pth')

In [None]:
# Index by neuron instead of by sample: assuming a 2-D (samples, neurons) tensor,
# .T gives (neurons, samples), so row j holds neuron j's activation for every sample.
transposed = sae_activations.T
transposed.shape

In [16]:
# Top-10 samples per neuron (row). torch.topk returns values and indices directly,
# in descending order, instead of fully argsort-ing every row and then gathering;
# this avoids allocating an int64 index tensor as large as `transposed`.
# k is clamped because topk (unlike slicing an argsort) errors when k > row length.
top_k = min(10, transposed.shape[1])
max_10_indices_per_neuron_value, max_10_indices_per_neuron = torch.topk(transposed, k=top_k, dim=1)

### Analyze Image

In [None]:
# Un-transformed (raw PIL) copy of the analysed dataset for visual inspection. It must
# be the same dataset/split the activations were computed on (EMNIST 'balanced',
# train) so that sample indices line up.
original_dataset = datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=None)

In [32]:
from matplotlib import pyplot as plt

def plot_indices_save(indices, filepath, filename, neuron_idx, activations=None, dataset=None):
    """Save a 2x5 grid of the dataset images that most strongly activate one neuron.

    Images whose activation for this neuron is exactly zero are skipped; if all are
    skipped, no file is written.

    Args:
        indices: sample indices to plot (e.g. one row of max_10_indices_per_neuron);
            at most the first 10 are used, since the grid has 10 cells.
        filepath: existing directory to write into.
        filename: file name; the number of images actually plotted is prefixed,
            giving '{count}_{filename}'.
        neuron_idx: row of `activations` holding this neuron's per-sample values.
        activations: (neurons, samples) activation matrix; defaults to the global
            `transposed` for backward compatibility.
        dataset: dataset whose `.data[idx]` is the raw image; defaults to the global
            `train_dataset`.
    """
    if activations is None:
        activations = transposed
    if dataset is None:
        dataset = train_dataset

    nrows, ncols = 2, 5
    fig = plt.figure(figsize=(15, 4))  # Adjusted figure size for 10 images
    count = 0
    # Subplot positions follow the index order, so skipped images leave gaps.
    for i, idx in enumerate(indices[: nrows * ncols]):
        activation_value = activations[neuron_idx][idx]
        if activation_value == 0:
            continue
        count += 1

        img = dataset.data[idx]
        ax = fig.add_subplot(nrows, ncols, i + 1)
        # 2-D (H, W) images such as EMNIST would otherwise be drawn with the default colormap.
        ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        ax.axis('off')
        ax.set_title(f'Index: {idx} - Value: {activation_value:.2f}')

    if count > 0:
        fig.tight_layout()
        fig.savefig(f'{filepath}/{count}_{filename}')  # Save the figure to a file
    plt.close(fig)  # Close the figure to free memory

In [None]:
from tqdm import tqdm
import os 

filepath = f'docs/neuron_{data_name}_meta'
os.makedirs(filepath, exist_ok=True)

# One PNG per neuron showing its top-10 activating training images.
for i, indices in enumerate(tqdm(max_10_indices_per_neuron, desc="Plotting neurons")):
    plot_indices_save(indices, filepath, f'neuron_{i}_plots.png', i)
