In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import sys
sys.path.append('../../')

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from structs.models import CIFAR100Model
import pandas as pd

import os 
path = '/Volumes/Ayush_Drive/mnist/'

if os.path.exists(path):
    prefix = path
else:
    prefix = ''

torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 30
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: mps


In [5]:
concept_names = pd.read_csv('concept_names.csv')
concept_names.head()

Unnamed: 0,neuron,concept
0,1,merriam
1,2,speakers
2,3,benches
3,4,particles
4,5,rugby


In [7]:
sae_weights = torch.load('/Volumes/Ayush_Drive/mnist/embeddings/decoder_weight_depth_1.pt')
sae_weights.T.shape

torch.Size([4096, 512])

# Process to analyze meta-sae 

- train meta-sae on weights
- pull activations 
- see max activation stuffs

In [19]:
from tqdm import tqdm
from structs.models import EnhancedSAE, SimpleSAE
import torch.optim as optim

def train_sae(train_loader, input_dim, hidden_dim, device):
    model = SimpleSAE(input_dim=input_dim, hidden_dim=hidden_dim).to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for activations in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            activations = activations.to(device)
            optimizer.zero_grad()
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")
    
    return model

def test_sae(model, test_loader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for activations in tqdm(test_loader, desc="Testing", colour="red"):
            activations = activations.to(device)
            encoded, decoded = model(activations)
            loss = model.compute_loss(activations, decoded, encoded)
            total_loss += loss.item()
    
    print(f"Test Loss: {total_loss/len(test_loader):.4f}")
    
    return total_loss/len(test_loader)

In [20]:
dataset = sae_weights.T.detach().numpy()
dataset = torch.tensor(dataset, dtype=torch.float32)
dataset.shape

torch.Size([4096, 512])

In [21]:
train_weights, test_weights = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_weights, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_weights, batch_size=batch_size, shuffle=False)

In [38]:
meta_sae = train_sae(train_loader, input_dim=dataset.shape[1], hidden_dim=1024, device=device)
test_sae(meta_sae, test_loader, device)

Epoch 1: 100%|██████████| 52/52 [00:01<00:00, 45.89it/s] 


Epoch [1/30], Loss: 38.1918


Epoch 2: 100%|██████████| 52/52 [00:00<00:00, 164.13it/s]


Epoch [2/30], Loss: 1.7544


Epoch 3: 100%|██████████| 52/52 [00:00<00:00, 117.66it/s]


Epoch [3/30], Loss: 0.1573


Epoch 4: 100%|██████████| 52/52 [00:00<00:00, 185.76it/s]


Epoch [4/30], Loss: 0.0310


Epoch 5: 100%|██████████| 52/52 [00:00<00:00, 147.60it/s]


Epoch [5/30], Loss: 0.0074


Epoch 6: 100%|██████████| 52/52 [00:00<00:00, 165.48it/s]


Epoch [6/30], Loss: 0.0030


Epoch 7: 100%|██████████| 52/52 [00:00<00:00, 169.45it/s]


Epoch [7/30], Loss: 0.0021


Epoch 8: 100%|██████████| 52/52 [00:00<00:00, 136.62it/s]


Epoch [8/30], Loss: 0.0019


Epoch 9: 100%|██████████| 52/52 [00:00<00:00, 144.07it/s]


Epoch [9/30], Loss: 0.0019


Epoch 10: 100%|██████████| 52/52 [00:00<00:00, 214.33it/s]


Epoch [10/30], Loss: 0.0019


Epoch 11: 100%|██████████| 52/52 [00:00<00:00, 124.44it/s]


Epoch [11/30], Loss: 0.0019


Epoch 12: 100%|██████████| 52/52 [00:00<00:00, 202.58it/s]


Epoch [12/30], Loss: 0.0019


Epoch 13: 100%|██████████| 52/52 [00:00<00:00, 120.71it/s]


Epoch [13/30], Loss: 0.0019


Epoch 14: 100%|██████████| 52/52 [00:00<00:00, 171.22it/s]


Epoch [14/30], Loss: 0.0019


Epoch 15: 100%|██████████| 52/52 [00:00<00:00, 174.33it/s]


Epoch [15/30], Loss: 0.0019


Epoch 16: 100%|██████████| 52/52 [00:00<00:00, 155.77it/s]


Epoch [16/30], Loss: 0.0019


Epoch 17: 100%|██████████| 52/52 [00:00<00:00, 220.23it/s]


Epoch [17/30], Loss: 0.0019


Epoch 18: 100%|██████████| 52/52 [00:00<00:00, 137.44it/s]


Epoch [18/30], Loss: 0.0019


Epoch 19: 100%|██████████| 52/52 [00:00<00:00, 194.45it/s]


Epoch [19/30], Loss: 0.0019


Epoch 20: 100%|██████████| 52/52 [00:00<00:00, 167.22it/s]


Epoch [20/30], Loss: 0.0019


Epoch 21: 100%|██████████| 52/52 [00:00<00:00, 172.70it/s]


Epoch [21/30], Loss: 0.0019


Epoch 22: 100%|██████████| 52/52 [00:00<00:00, 137.61it/s]


Epoch [22/30], Loss: 0.0019


Epoch 23: 100%|██████████| 52/52 [00:00<00:00, 224.49it/s]


Epoch [23/30], Loss: 0.0019


Epoch 24: 100%|██████████| 52/52 [00:00<00:00, 161.08it/s]


Epoch [24/30], Loss: 0.0019


Epoch 25: 100%|██████████| 52/52 [00:00<00:00, 131.82it/s]


Epoch [25/30], Loss: 0.0019


Epoch 26: 100%|██████████| 52/52 [00:00<00:00, 151.81it/s]


Epoch [26/30], Loss: 0.0019


Epoch 27: 100%|██████████| 52/52 [00:00<00:00, 169.11it/s]


Epoch [27/30], Loss: 0.0019


Epoch 28: 100%|██████████| 52/52 [00:00<00:00, 189.56it/s]


Epoch [28/30], Loss: 0.0019


Epoch 29: 100%|██████████| 52/52 [00:00<00:00, 138.08it/s]


Epoch [29/30], Loss: 0.0019


Epoch 30: 100%|██████████| 52/52 [00:00<00:00, 146.16it/s]


Epoch [30/30], Loss: 0.0019


Testing: 100%|[31m██████████[0m| 13/13 [00:00<00:00, 86.24it/s]

Test Loss: 0.5369





0.5369405608910781

In [60]:
meta_sae

SimpleSAE(
  (encoder): Linear(in_features=512, out_features=1024, bias=True)
  (decoder): Linear(in_features=1024, out_features=512, bias=True)
)

In [39]:
def get_sae_activations(sae, dataset, layer_name):
    sae.clear_cache()
    sae.eval()
    with torch.no_grad():
        for images in tqdm(dataset):
            images = images.to(device)
            sae(images)
    
    return sae.get_cached_activations(layer_name)

In [None]:
_ = get_sae_activations(meta_sae, dataset, 'encoder')

In [41]:
encoder_acts = torch.stack(meta_sae.activation_cache['encoder'])
decoder_acts = torch.stack(meta_sae.activation_cache['decoder'])
encoder_acts.shape, decoder_acts.shape

(torch.Size([4096, 1024]), torch.Size([4096, 512]))

In [42]:
encoder_acts.mean(), encoder_acts.std(), encoder_acts.min(), encoder_acts.max()

(tensor(1.6576e-05, device='mps:0'),
 tensor(0.0006, device='mps:0'),
 tensor(0., device='mps:0'),
 tensor(0.1628, device='mps:0'))

In [43]:
decoder_acts.mean(), decoder_acts.std(), decoder_acts.min(), decoder_acts.max()

(tensor(0.0004, device='mps:0'),
 tensor(0.0017, device='mps:0'),
 tensor(-0.0290, device='mps:0'),
 tensor(0.0307, device='mps:0'))

In [70]:
import numpy as np
import pandas as pd
# the analysis function given the dataset (dont even need the model tbh, pull the activations from elsewhere)
def find_max_indices(activations, topk=10): 
    # columnwise top 10 
    top_k = min(topk, activations.shape[0])
    top_values, top_indices = torch.topk(activations, top_k, dim=1)

    # convert into dataframe 
    values_np = top_values.cpu().numpy()
    indices_np = top_indices.cpu().numpy()
    
    print(activations.shape)
    neuron_ids = np.arange(activations.shape[0])

    df = pd.DataFrame({
        'neuron_id': neuron_ids,
        'top_values': list(values_np),
        'top_indices': list(indices_np)
    })

    return df

In [71]:
decoder_acts.shape

torch.Size([4096, 512])

In [72]:
df = find_max_indices(activations=decoder_acts, topk=10)
df.head()

torch.Size([4096, 512])


Unnamed: 0,neuron_id,top_values,top_indices
0,0,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]"
1,1,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]"
2,2,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]"
3,3,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]"
4,4,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]"


In [73]:
# create a new column with an array of names that are pulled by finding the concept names for the top indices per row 
def get_concept_names(row):
    indices = row['top_indices']
    names = concept_names.iloc[indices]['concept'].values
    return names

df['concept_names'] = df.apply(get_concept_names, axis=1)
df['concept_names'] = df['concept_names'].apply(lambda x: ', '.join(x))
df.head()

Unnamed: 0,neuron_id,top_values,top_indices,concept_names
0,0,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]","sanchez, refused, turquoise, watering, vests, ..."
1,1,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]","sanchez, refused, turquoise, watering, vests, ..."
2,2,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]","sanchez, refused, turquoise, watering, vests, ..."
3,3,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]","sanchez, refused, turquoise, watering, vests, ..."
4,4,"[0.0057347855, 0.0056616445, 0.004677073, 0.00...","[107, 182, 349, 49, 53, 256, 374, 128, 228, 427]","sanchez, refused, turquoise, watering, vests, ..."


In [75]:
df.to_csv('vit_depth2_labels.csv', index=False)

In [81]:
df['top_values'].explode().mode()

0    0.003955
1    0.003992
2    0.004079
3    0.004105
4    0.004146
5    0.004208
6    0.004313
7    0.004677
8    0.005662
9    0.005735
Name: top_values, dtype: object

In [6]:
df = pd.read_csv('vit_depth2_labels.csv')

In [9]:
df['concept_names'].to_csv('vit_depth2_labels.csv', index=False, header=False)