In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import wandb
import os
import glob
from topk_sae import FastAutoencoder, loss_fn, unit_norm_decoder_grad_adjustment_, unit_norm_decoder_, init_from_data_


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(ae, train_loader, optimizer, epochs, k, auxk_coef, clip_grad=None, save_dir="checkpoints", model_name=""):
    os.makedirs(save_dir, exist_ok=True)
    step = 0
    num_batches = len(train_loader)
    for epoch in range(epochs):
        ae.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            x = batch[0].to(device)
            recons, info = ae(x)
            loss, recons_loss, auxk_loss = loss_fn(ae, x, recons, info, auxk_coef)
            loss.backward()
            step += 1
            
            # calculate proportion of dead latents (not fired in last num_batches = 1 epoch)
            dead_latents_prop = (ae.stats_last_nonzero > num_batches).float().mean().item()
            
            unit_norm_decoder_grad_adjustment_(ae)
            
            if clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(ae.parameters(), clip_grad)
            
            optimizer.step()
            unit_norm_decoder_(ae)

            topk_indices = torch.cat((info["topk_indices"], info["auxk_indices"]), dim = -1)
            selected_grad = torch.abs(ae.encoder.weight.grad[topk_indices, :]).mean()
            print('encoder', selected_grad)

            mask = torch.ones_like(ae.encoder.weight.grad)
            mask[topk_indices, :] = 0
            unselected_grad = ae.encoder.weight.grad[mask == 1].mean()
            print(unselected_grad)

            selected_dgrad = torch.abs(ae.decoder.weight.grad[:, topk_indices]).mean()
            print('decoder', selected_dgrad)
            unselected_dgrad = ae.decoder.weight.grad[mask.T == 1].mean()
            print(unselected_dgrad)
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Delete previous model saves for this configuration
        for old_model in glob.glob(os.path.join(save_dir, f"{model_name}_epoch_*.pth")):
            os.remove(old_model)

        # Save new model
        save_path = os.path.join(save_dir, f"{model_name}_epoch_{epoch+1}.pth")
        torch.save(ae.state_dict(), save_path)
        print(f"Model saved to {save_path}")

d_model = 1536
n_dirs = d_model * 6
k = 32
auxk = 64 #256
batch_size = 1024
lr = 1e-4
epochs = 1
auxk_coef = 1/32
clip_grad = 1.0

# Create model name
model_name = f"{k}_{n_dirs}_{auxk}_final"

data = np.load("../data/vector_store/abstract_embeddings.npy")
data_tensor = torch.from_numpy(data).float()
dataset = TensorDataset(data_tensor)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

ae = FastAutoencoder(n_dirs, d_model, k, auxk).to(device)
init_from_data_(ae, data_tensor[:10000].to(device))

optimizer = optim.Adam(ae.parameters(), lr=lr)

In [53]:
# print abs gradient for topk/auxk and everything else -- encoder and decoder

train(ae, train_loader, optimizer, epochs, k, auxk_coef, clip_grad, model_name=model_name)

Epoch 1/1:   0%|          | 0/266 [00:00<?, ?it/s]

encoder tensor(2.4457e-06)
tensor(0.)


Epoch 1/1:   0%|          | 1/266 [00:03<17:34,  3.98s/it]

decoder tensor(8.5545e-06)
tensor(0.)
encoder tensor(2.2086e-06)
tensor(0.)


Epoch 1/1:   1%|          | 2/266 [00:07<15:19,  3.48s/it]

decoder tensor(8.2827e-06)
tensor(0.)
encoder tensor(2.0640e-06)
tensor(0.)


Epoch 1/1:   1%|          | 3/266 [00:10<14:13,  3.25s/it]

decoder tensor(8.2802e-06)
tensor(0.)


Epoch 1/1:   1%|          | 3/266 [00:11<17:08,  3.91s/it]


KeyboardInterrupt: 

In [1]:
import autointerp
from pathlib import Path

%load_ext autoreload
%autoreload 2

In [7]:
!ls ../data/vector_store/

abstract_embeddings.npy documents.pkl           keyword_index.json
abstract_texts.json     embeddings_matrix.npy   metadata.json
document_index.pkl      index_mapping.pkl


In [2]:
CONFIG_PATH = Path("../config.yaml")
DATA_DIR = Path("../data")
SAE_DATA_DIR = Path("sae_data")
feature_index = 1000
num_samples = 5 

analyzer = autointerp.NeuronAnalyzer(CONFIG_PATH, feature_index, num_samples)

In [31]:
top_abstracts, zero_abstracts = analyzer.get_feature_activations(num_samples)

(1536,)
(268065,)
[(0, 0.3231629149511315), (1, 0.3884177835290795), (2, 0.37229435697601637), (3, 0.4051744935201531), (4, 0.43787066220692483), (5, 0.40065434325844407), (6, 0.3293621663897055), (7, 0.30941542151710216), (8, 0.46078678381987526), (9, 0.49524984518679194)]


In [12]:
feature_index = 100
num_samples = 5

analyzer = autointerp.NeuronAnalyzer(CONFIG_PATH, feature_index, num_samples)

top_abstracts, zero_abstracts = analyzer.get_feature_activations(num_samples)
interpretation = analyzer.generate_interpretation(top_abstracts, zero_abstracts)
print(f"Interpretation: {interpretation}")

divider = 3
test_abstracts = [abstract for _, abstract, _ in top_abstracts[num_samples//divider:] + zero_abstracts[num_samples//divider:]]
# ground_truth = [1] * (num_samples//divider) + [0] * (num_samples//divider)
ground_truth = [1] * len(top_abstracts[num_samples//divider:]) + [0] * len(zero_abstracts[num_samples//divider:])

predictions = analyzer.predict_activations(interpretation, test_abstracts)
correlation, f1 = analyzer.evaluate_predictions(ground_truth, predictions)

print(f"Pearson correlation: {correlation}")
print(f"F1 score: {f1}")

KeyboardInterrupt: 