In [None]:
# import numpy as np
# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
# from mistral_common.protocol.instruct.request import ChatCompletionRequest
# from mistral_common.protocol.instruct.messages import UserMessage

### SAE Implementation

https://transformer-circuits.pub/2024/april-update/index.html#training-saes


19token -> B*19*4096 B*19 4096

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, lambda_reg):
        super(SparseAutoencoder, self).__init__()
        self.relu = nn.ReLU()
        # Encoder and Decoder weights with initializations as described
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

        self.encoder.bias.data.fill_(0)
        self.decoder.bias.data.fill_(0)

        # Initialize encoder and decoder weights
        with torch.no_grad():
            for i in range(hidden_dim):
                # Create a random vector for each column
                random_vector = torch.randn(input_dim)
                # Set L2 norm to a random value between 0.05 and 1
                norm = torch.FloatTensor(1).uniform_(0.05, 1.0)
                # Normalize the vector and scale it by the chosen norm
                self.decoder.weight[:, i] = (random_vector / random_vector.norm()) * norm

        # Initialize W_e as W_d^T
        self.encoder.weight.data = self.decoder.weight.data.T.clone()

        self.lambda_reg = lambda_reg

    def forward(self, x):
        # Forward pass
        hidden = self.relu(self.encoder(x)) #f(x) Shape :
        reconstructed = self.decoder(hidden) #
        return reconstructed, hidden

    def compute_loss(self, x, reconstructed, hidden):
        # Reconstruction loss
        reconstruction_loss = torch.mean((x - reconstructed) ** 2)

        # Sparsity penalty
        sparsity_loss = self.lambda_reg * torch.sum(
            torch.abs(hidden) @ torch.norm(self.decoder.weight, dim=0)
        )


        return reconstruction_loss + sparsity_loss


In [3]:
data = torch.load("./dataset/residual_data_batch_1.pt")

  data = torch.load("./dataset/residual_data_batch_1.pt")


In [6]:
data[0]['embedding'].shape

torch.Size([3072])

In [8]:
# input data is form of (Batch, hidden dim)

random_data = torch.randn(100, 3072)

In [10]:
import torch
from torch.utils.data import Dataset, DataLoader
import glob

class EmbeddingDataset(Dataset):
    def __init__(self, path, file_pattern):
        # Load all `.pt` files based on the pattern
        self.file_path = path
        self.files = sorted(glob.glob(path+file_pattern))
        print(f"Num of files found : {len(self.files)}")
        self.data = []

        # Read and store all embeddings from all files
        for file in self.files:
            batch_data = torch.load(file)
            # Extract embeddings and flatten them into a list
            self.data.extend([item["embedding"] for item in batch_data])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Define your dataset
file_pattern = "residual_data_batch_*.pt"  # Adjust the path if needed
path = "./dataset/"
embedding_dataset = EmbeddingDataset(path, file_pattern)

# # Example usage
# for batch in embedding_loader:
#     print(batch.shape)  # Prints: [batch_size, embedding_dim]
#     break

Num of files found : 36


  batch_data = torch.load(file)


torch.Size([32, 3072])


In [13]:
batch_size = 2048  # Adjust as needed
embedding_loader = DataLoader(embedding_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Hyperparameters

# 1,048,576 (~1M), 4,194,304 (~4M), and 33,554,432 (~34M)

input_dim = 3072  # Input and output dimensions
hidden_dim = 3072*10  # Hidden layer dimension
# hidden_dims = [128, 512, 4096]
final_lambda = 5  # Final regularization strength after 5% of training steps
learning_rate = 5e-5

num_epochs = 100 #200000  # as per scaling laws
lambda_increase_steps = int(num_epochs * 0.05)
# 200k

# Dataset scaling
# X is data tensor (embeddings?)
# X = X * (input_dim ** 0.5) / torch.norm(X, dim=1, keepdim=True).mean()  # Scaling dataset

# Model, optimizer
model = SparseAutoencoder(input_dim, hidden_dim, 0) # initial lambda is 0
model = model.cuda()
# Adam optimizer beta1=0.9, beta2=0.999 and no weight decay
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

loss_values = []

# Training loop
for epoch in range(num_epochs):
    print(f"Starting Epoch [{epoch}/{num_epochs}]")
    avg_loss = 0
    done_vals = 0
    # Linearly increase λ over the first 5% of steps
    if epoch < lambda_increase_steps:
        model.lambda_reg = final_lambda * (epoch / lambda_increase_steps)
    else:
        model.lambda_reg = final_lambda

    # Decay learning rate linearly over the last 20% of training
    if epoch > num_epochs * 0.8:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 1 - (epoch - num_epochs * 0.8) / (num_epochs * 0.2)

    # Change the Batch sampling 
    # indices = torch.randperm(len(X))[:batch_size]
    # batch = X[indices] 
    for batch_num, batch in enumerate(embedding_loader):
        batch = batch.cuda().to(torch.float32)
        # Forward pass
        reconstructed, hidden = model(batch)

        #print # of activated neurons

        # Compute loss
        loss = model.compute_loss(batch, reconstructed, hidden)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1)  # Gradient clipping
        optimizer.step()
        avg_loss += loss.item()
        done_vals += 1
        if batch_num % 200 == 0:
            print(f"\t Batch {batch_num} Loss:{avg_loss/done_vals}")
        
    loss_values.append(avg_loss/done_vals)
    print(f"Epoch [{epoch}/{num_epochs}], Loss: {loss_values}")

# Conceptually a feature’s activation is now f i ∣ ∣ W d , i ∣ ∣ 2 f i ​ ∣∣W d,i ​ ∣∣ 2 ​ instead of f i f i ​ .
# Normalize W_d and adjust encoder and bias after training
with torch.no_grad():
    W_d_norm = model.decoder.weight.norm(dim=0, keepdim=True)
    model.decoder.weight /= W_d_norm
    model.encoder.weight *= W_d_norm
    model.encoder.bias /= W_d_norm


Starting Epoch [0/100]
	 Batch 0 Loss:212.5830535888672
	 Batch 200 Loss:7.379077021310579


In [17]:
batch.dtype

torch.bfloat16