# 2. VQ-VAE Training (Stage 1)

**Objective:** Train the VQ-VAE model from `src.model.vae` on the full `P+C+S` sequences. We will run a small training loop here, plot the losses, and save the final model weights to `experiments/vqvae_stage1.pth`.

In [None]:
%pip install datasets transformers torch tqdm matplotlib

In [None]:
import sys
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

# Make the parent of the current working directory importable, so the
# `src` package used by the imports below can be found.
parent_dir = os.path.join(os.getcwd(), os.pardir)
sys.path.append(os.path.abspath(parent_dir))

from src.utils import (
    get_llm_tokenizer, MAX_SEQ_LEN, PATH_VQVAE_MODEL,
    VQ_CODEBOOK_SIZE
)
from src.dataset import VQVAE_Dataset
from src.model.vae import VQVAEModel

# --- Configuration ---
# Optimisation settings used by the DataLoader, optimizer and training loop.
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3 # Increase for a real run

# Passed to VQVAEModel as `d_model`.
D_MODEL = 256

# Use the GPU when torch reports one is available, otherwise the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

## 2.1 Load Tokenizer and Dataset

In [None]:
# The tokenizer comes from the project helper; its length is used as the
# model's vocabulary size when the VQ-VAE is built.
tokenizer = get_llm_tokenizer()
vocab_size = len(tokenizer)

print(f"Tokenizer vocabulary size (including new tokens): {vocab_size}")

# Load the GSM8K "main" configuration and keep only its training split.
gsm8k_splits = load_dataset("gsm8k", "main")
raw_dataset = gsm8k_splits['train']

# Wrap the raw split in the project dataset class and batch it with
# shuffling enabled.
train_dataset = VQVAE_Dataset(tokenizer, raw_dataset, max_length=MAX_SEQ_LEN)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

print(f"Loaded {len(train_dataset)} samples for VQ-VAE training.")

## 2.2 Initialize Model and Optimizer

In [None]:
# Constructor arguments for the VQ-VAE: vocabulary size from the tokenizer,
# model width from the notebook config, codebook size and maximum sequence
# length from src.utils.
vqvae_kwargs = {
    "vocab_size": vocab_size,
    "d_model": D_MODEL,
    "num_embeddings": VQ_CODEBOOK_SIZE,
    "max_seq_len": MAX_SEQ_LEN,
}
model = VQVAEModel(**vqvae_kwargs)
model = model.to(device)

# Adam over all model parameters with the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Report the total parameter count in millions.
print(f"VQ-VAE Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

## 2.3 Training Loop

We'll run the training loop directly in the notebook to monitor its progress and plot the losses.

In [None]:
# Put the model in training mode and keep one averaged value per epoch for
# each of the three loss terms returned by the model.
model.train()
losses, recon_losses, vq_losses = [], [], []

for epoch in range(NUM_EPOCHS):
    print(f"--- EPOCH {epoch+1}/{NUM_EPOCHS} ---")

    # Running sums of the per-batch scalar losses for this epoch.
    running = {"total": 0, "recon": 0, "vq": 0}

    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)

        # One optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        total_loss, recon_loss, vq_loss = model(input_ids)
        total_loss.backward()
        optimizer.step()

        # .item() converts each loss to a Python number before accumulating.
        running["total"] += total_loss.item()
        running["recon"] += recon_loss.item()
        running["vq"] += vq_loss.item()

    # Average each running sum over the number of batches in the loader.
    num_batches = len(train_loader)
    avg_loss = running["total"] / num_batches
    avg_recon = running["recon"] / num_batches
    avg_vq = running["vq"] / num_batches

    losses.append(avg_loss)
    recon_losses.append(avg_recon)
    vq_losses.append(avg_vq)

    print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | VQ: {avg_vq:.4f}")

## 2.4 Visualize Losses

In [None]:
# Plot the per-epoch average losses recorded during training.
#
# The x values are passed explicitly as 1..N so the axis labelled "Epoch"
# matches the "Epoch k" numbers printed by the training loop; without them
# plt.plot would place the first epoch at x=0.
epoch_axis = range(1, len(losses) + 1)

plt.figure(figsize=(12, 6))
plt.plot(epoch_axis, losses, label='Total Loss')
plt.plot(epoch_axis, recon_losses, label='Reconstruction Loss', linestyle='--')
plt.plot(epoch_axis, vq_losses, label='VQ Loss', linestyle=':')
plt.title('VQ-VAE Training Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

## 2.5 Save the Model

Finally, we save the trained VQ-VAE weights (the model's `state_dict`) to `PATH_VQVAE_MODEL`. These will be loaded by the next notebook to create the assorted dataset.

In [None]:
# Save the model's state_dict to PATH_VQVAE_MODEL, creating the target
# directory first if needed.
print(f"Saving VQ-VAE model to {PATH_VQVAE_MODEL}")

# os.path.dirname() returns "" for a path with no directory component, and
# os.makedirs("") raises FileNotFoundError, so only create the directory when
# there is one to create.
checkpoint_dir = os.path.dirname(PATH_VQVAE_MODEL)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

torch.save(model.state_dict(), PATH_VQVAE_MODEL)
print("Model saved.")