In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import os
from tqdm import tqdm
from model import *
from activation_dataset import *
from eval import *

# Setup logging: training progress is appended to training.log at INFO level.
# force=True is needed in a notebook: if the root logger already has a handler
# (installed by the kernel or by an imported library), basicConfig would
# otherwise silently do nothing and training.log would never be written.
# It also lets re-running this cell actually apply changed settings.
logging.basicConfig(
    filename="training.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)
# Disable HuggingFace tokenizers' internal parallelism (avoids its
# fork-related warnings once DataLoader / other code forks the process).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hyperparameters
seed = 0                     # RNG seed: autoencoder init and DataLoader shuffling
input_dim = 3072             # width of each activation vector; presumably the LM hidden size — confirm against model.config.hidden_size
hidden_dim = 2 * input_dim   # latent width passed to SparseAutoencoder (2x expansion)
num_epochs = 100             # each "epoch" draws a fresh set of prompts in the training cell
lr = 1e-5
l1_reg = 0.01                # weight of the L1 sparsity term in the loss
batch_size = 512
num_samples_per_prompt = 256
num_prompts_to_load = 50     # prompts drawn per epoch (earlier runs used 512*4)

# Seed torch (CPU and CUDA generators) so training runs are reproducible.
torch.manual_seed(seed)

# Device: prefer the GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset iterator. streaming=True yields examples lazily. data_iter is
# handed to load_new_samples every epoch (presumably consumed via next()), so
# re-running the training cell without re-running this one continues from
# wherever the previous run stopped in the stream.
dataset = load_dataset("Anthropic/hh-rlhf", split='train', streaming=True)
data_iter = iter(dataset)

# Model, Tokenizer, & Dataset
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
use_cuda = device.type == "cuda"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Load straight onto the selected device instead of hard-coding "cuda",
    # so the CPU fallback chosen above works; no follow-up .to(device) needed.
    device_map=str(device),
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # FlashAttention 2 only runs on CUDA; use the eager implementation otherwise.
    attn_implementation="flash_attention_2" if use_cuda else "eager",
).eval()

# Autoencoder, Loss, & Optimizer: only the autoencoder's parameters are optimized;
# the LM above is used purely as an activation source.
autoencoder = SparseAutoencoder(input_dim, hidden_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=lr)

In [None]:
#### Training loop
# Each "epoch" draws a fresh set of prompts from data_iter, collects LM
# activations for them via ActivationDataset, and makes one shuffled pass of
# SAE training over those activations. `iteration` counts optimizer steps
# across all epochs.
# NOTE: re-running only this cell resets `iteration` to 0 but keeps the
# current autoencoder/optimizer state and the partially consumed data_iter.
log_every = 100    # batches (within an epoch) between training.log entries
eval_every = 5000  # global steps between evaluate_sae calls; step 0 is an untrained baseline

# Cast activations to the SAE's parameter dtype. The LM runs in bf16, so the
# collected activations may be bf16 (TODO confirm what ActivationDataset
# returns); feeding bf16 into fp32 layers would raise a dtype mismatch.
sae_dtype = next(autoencoder.parameters()).dtype

iteration = 0
autoencoder.train()
for epoch in range(num_epochs):
    # Build a new activation dataset / dataloader for each epoch.
    prompts = load_new_samples(data_iter, num_prompts_to_load)
    activation_dataset = ActivationDataset(prompts=prompts, tokenizer=tokenizer, model=model, num_samples_per_prompt=num_samples_per_prompt)
    activation_dataset.process_prompts()
    dataloader = DataLoader(activation_dataset, batch_size=batch_size, shuffle=True)

    for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        inputs = batch['activations'].to(device=device, dtype=sae_dtype)
        reconstructed = autoencoder(inputs)

        # Loss: reconstruction MSE + L1 (sum of absolute values) of the first
        # encoder layer's weights; identical to torch.norm(..., p=1) used before.
        # NOTE(review): SAEs usually penalise the latent *activations* rather than
        # weights; kept as-is since the forward pass only returns the reconstruction.
        mse_loss = criterion(reconstructed, inputs)
        l1_loss = l1_reg * autoencoder.encoder[0].weight.abs().sum()
        loss = mse_loss + l1_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Renormalize every decoder column (dim=0) to unit L2 norm. The clamp
        # prevents an all-zero column from turning into NaNs via 0/0.
        with torch.no_grad():
            decoder_weight = autoencoder.decoder.weight
            decoder_weight.div_(decoder_weight.norm(dim=0, keepdim=True).clamp_min(1e-8))

        if (batch_idx + 1) % log_every == 0:
            logging.info(f"Epoch [{epoch+1}/{num_epochs}], Iteration [{batch_idx+1}], Loss: {loss.item():.4f}")

        # Periodic evaluation on the current epoch's activation dataset.
        # NOTE(review): this is the same data the SAE is training on.
        if iteration % eval_every == 0:
            eval_dataloader = DataLoader(activation_dataset, batch_size=4, shuffle=True)
            evaluate_sae(autoencoder, eval_dataloader, tokenizer, iteration=iteration)
            del eval_dataloader
            # Restore training mode in case evaluate_sae switched the model to eval().
            autoencoder.train()
        iteration += 1