# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../../')

In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from structs.models import CIFAR100Model, SimpleSAE

In [4]:
import os 
# Root directory holding the SAE checkpoints and activations. It defaults to the
# author's external drive; set SAE_DATA_PATH to run the notebook on another machine.
path = os.environ.get('SAE_DATA_PATH', '/Volumes/Ayush_Drive/mnist/')
# NOTE(review): `embedding_path` is not read by any cell visible here.
embedding_path = 'embeddings/cifar100/'
# Use the drive as the dataset prefix only when it is present; otherwise fall back
# to paths relative to the working directory.
prefix = path if os.path.exists(path) else ''

In [5]:
# Seed torch's global RNG before anything stochastic runs.
torch.manual_seed(42)

# Hyperparameters
batch_size, learning_rate, epochs = 64, 0.001, 20

# Prefer the Metal (MPS) backend when this PyTorch build / machine supports it.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Using device: mps


# Data

In [6]:
# os.path.join keeps the root relative ('data') when `prefix` is empty. The
# string form f'{prefix}/data' would point at '/data' on the filesystem root.
root = os.path.join(prefix, 'data')

# Load CIFAR-100 dataset.
# ToTensor scales pixel values to [0, 1]; Normalize with mean = std = 0.5 then maps
# each channel to [-1, 1] (these are not CIFAR-100's own per-channel statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load training data
train_dataset = datasets.CIFAR100(
    root=root,
    train=True,
    download=True,
    transform=transform
)

# Load test data
test_dataset = datasets.CIFAR100(
    root=root,
    train=False,
    download=True,
    transform=transform
)

In [7]:
# Split sizes, displayed as (test, train).
(len(test_dataset), len(train_dataset))

(10000, 50000)

# Collect Activation Map

Activation collection was done in Colab.

Collected fields: `neuron_idx`, `depth`, `indices`, `activations`.

In [6]:
def get_sae_activations(sae, dataset):
    """Feed every item of ``dataset`` through ``sae`` and return its activation cache.

    The SAE's cache is cleared first and the model is switched to eval mode. The
    forward passes run under ``torch.no_grad()``. Items are fed one per call, in
    iteration order; for a 2-D tensor that means one row per call. Each item is
    moved to the module-level ``device`` first.

    Args:
        sae: model exposing ``clear_cache()``, ``eval()`` and an
            ``activation_cache`` attribute.
        dataset: iterable of tensors to run through ``sae``.

    Returns:
        ``sae.activation_cache`` after all forward passes. Callers that ignore the
        return value are unaffected.
    """
    sae.clear_cache()
    sae.eval()
    with torch.no_grad():
        for item in tqdm(dataset):
            sae(item.to(device))
    return sae.activation_cache

In [16]:
def collect_acts(weights, depth, max_depth=None, ckpt_dir=None, out_dir=None):
    """Run a chain of per-depth SAEs over weight rows and save their decoder activations.

    For each depth, starting at ``depth``:
      1. load ``sae_depth_{depth}_decoder.pth`` from ``ckpt_dir``;
      2. feed every row of the current input matrix through that SAE;
      3. stack the cached ``'decoder'`` activations and save them as
         ``sae_depth_{depth}_decoder_acts.pt`` in ``out_dir``.
    The transposed decoder weight of that SAE then becomes the input matrix for
    the next depth.

    Args:
        weights: 2-D tensor whose rows are the SAE inputs. ``weights.shape[1]``
            is used as the SAE input dim.
        depth: first depth to process. Its checkpoint must exist.
        max_depth: last depth to process (inclusive). ``None`` keeps going until
            the checkpoint for the next depth is missing.
        ckpt_dir: checkpoint directory. Defaults to ``<path>/embeddings/laion``.
        out_dir: directory for the saved activations (created if missing).
            Defaults to ``<path>/laion``.

    Returns:
        List of activation file paths written, in depth order.

    Raises:
        FileNotFoundError: if the checkpoint for the starting ``depth`` is missing.
    """
    if ckpt_dir is None:
        ckpt_dir = os.path.join(path, 'embeddings', 'laion')
    if out_dir is None:
        out_dir = os.path.join(path, 'laion')

    saved_paths = []
    inputs = weights
    # A loop rather than self-recursion gives the chain an explicit stop condition.
    # It also releases each depth's SAE and activations before the next depth loads.
    while max_depth is None or depth <= max_depth:
        ckpt_path = os.path.join(ckpt_dir, f'sae_depth_{depth}_decoder.pth')
        # After the first depth, a missing checkpoint marks the end of the chain.
        if saved_paths and not os.path.exists(ckpt_path):
            break
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=device)

        input_dim = inputs.shape[1]
        # Take the hidden width from the checkpoint so the model always matches it.
        # This assumes encoder.weight is (hidden_dim, input_dim), as the depth-2
        # shapes printed in this notebook suggest. Fall back to the row-count
        # heuristic if the key is absent.
        if 'encoder.weight' in state_dict:
            hidden_dim = state_dict['encoder.weight'].shape[0]
        else:
            hidden_dim = max(int(inputs.shape[0] / 2), 2304)

        sae = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim)
        print(f"Loading weights for depth {depth} with input dim {input_dim} and hidden dim {hidden_dim}")
        sae.load_state_dict(state_dict)
        sae.to(device)

        get_sae_activations(sae, inputs)
        sae_acts = torch.stack(sae.activation_cache['decoder'])
        print(f"Shape of activations for depth {depth}: {sae_acts.shape}")
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, f'sae_depth_{depth}_decoder_acts.pt')
        torch.save(sae_acts, out_path)
        saved_paths.append(out_path)

        # Rows of the transposed decoder weight (i.e. the decoder's columns) are
        # the inputs for the next depth.
        inputs = sae.decoder.weight.T.detach().clone()
        depth += 1
    return saved_paths


In [8]:
# Depth-1 SAE weight matrix, loaded from the drive onto the selected device.
sae_weights = torch.load(
    os.path.join(path, 'laion', 'sae_depth_1_weights.pt'),
    map_location=device,
)
sae_weights.shape

torch.Size([65536, 1024])

In [15]:
# Depth-2 SAE state dict; display its decoder and encoder weight shapes.
dicts = torch.load(
    os.path.join(path, 'embeddings', 'laion', 'sae_depth_2_decoder.pth'),
    map_location=device,
)
(dicts['decoder.weight'].shape, dicts['encoder.weight'].shape)

(torch.Size([1024, 16384]), torch.Size([16384, 1024]))

In [None]:
# Bind the result to its own name. Assigning it to `collect_acts` would replace the
# function, so re-running this cell would fail with "object is not callable".
# Detach/clone the transposed decoder weight before passing it in.
saved_act_paths = collect_acts(dicts['decoder.weight'].T.detach().clone(), depth=3)

In [21]:
# NOTE(review): `data_df` is not defined in any cell visible here. It seems to come
# from an earlier/deleted cell (or the Colab run), so this cell fails on a fresh
# kernel. The saved output shows only zeros in top_values/top_indices; check that
# the activations were collected correctly.
data_df.head()

Unnamed: 0,neuron_id,top_values,top_indices
0,0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3,3,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
4,4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [22]:
# NOTE(review): depends on `data_df`, which no cell visible here defines. The CSV is
# written to the current working directory, not under `path`.
data_df.to_csv('sae_acts.csv', index=False)