# RQ-VAE Training

Train a Residual Quantized Variational AutoEncoder (RQ-VAE) to learn semantic IDs for items in a catalogue.

**Environment:** RunPod Jupyter with GPU

**What this does:**
1. Load item catalogue (JSONL format)
2. Generate embeddings using a pretrained embedding model
3. Train RQ-VAE to compress embeddings into discrete codes
4. Save trained model (optionally to W&B artifacts)

**Outputs:**
- `models/rqvae_model.pt` - Trained RQ-VAE model

## 1. Setup

In [None]:
# Check GPU availability (the leading `!` is the IPython shell escape, so this
# runs the nvidia-smi command-line tool rather than Python code)
!nvidia-smi

In [None]:
# Navigate to project root if needed
import sys
import os

# Compare the last path component exactly: a plain suffix test would also
# treat directories such as "old_notebooks" as the notebooks folder.
start_dir = os.path.abspath(".")
if os.path.basename(start_dir) == "notebooks":
    repo_root = os.path.abspath("..")
    print(f"Current directory: {os.getcwd()}, changing to {repo_root}")
    os.chdir(repo_root)
else:
    repo_root = start_dir
    print(f"Already in project root: {os.getcwd()}")

# Register the root in both cases so that `src.*` imports resolve no matter
# which directory the kernel was started from.
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

In [None]:
from unsloth import FastLanguageModel # Always import unsloth first

In [None]:
import torch
import json
from pathlib import Path
import lightning as L
from torch.utils.data import DataLoader, random_split

from src.rqvae import (
    SemanticRQVAE,
    SemanticRQVAEConfig,
    RQVAETrainer,
    RqvaeTrainConfig,
    WandbArtifactCallback,
    ItemEmbeddingDataset,
)

# Report the runtime that was just imported: torch version and whether CUDA is usable
print(
    "Imports successful",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## 2. Configuration

In [None]:
# Paths
CATALOGUE_PATH = "data/mcf_articles.jsonl"  # Input catalogue (one JSON object per line)
EMBEDDINGS_CACHE = "data/embeddings_mcf.pt"  # Cache for embeddings

# Embedding model
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"  # 1024-dim embeddings

# Training configuration - all training-related parameters in one place
train_config = RqvaeTrainConfig(
    # Model architecture
    embedding_dim=1024,       # Match embedding model output
    hidden_dim=256,           # Encoder/decoder hidden dimension
    codebook_size=64,         # Number of codes per quantizer
    num_quantizers=3,         # Number of quantization levels

    # Training hyperparameters
    learning_rate=1e-3,
    max_epochs=500,
    batch_size=512,
    train_split=0.95,

    # W&B artifact logging
    log_wandb_artifacts=False,
    artifact_name="rqvae-model",

    # Output paths
    model_save_path="models/rqvae_model.pt",
)

# Create output directories. The directory is derived from model_save_path
# (instead of a hard-coded "models") so that editing the path above cannot
# leave the target directory missing; parents=True covers nested paths.
Path(train_config.model_save_path).parent.mkdir(parents=True, exist_ok=True)

print("Configuration:")
print(f"  Catalogue: {CATALOGUE_PATH}")
print(f"  Embedding model: {EMBEDDING_MODEL}")
print(f"  Codebook: {train_config.num_quantizers} levels x {train_config.codebook_size} codes")
print(f"  Output model: {train_config.model_save_path}")
print(f"  Log W&B artifacts: {train_config.log_wandb_artifacts}")

## 3. Verify Environment Variables

In [None]:
# Read the W&B credential from the environment; an unset or empty value is
# reported as "logging disabled".
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

print("=== Environment Variables ===")
if not WANDB_API_KEY:
    print("WANDB_API_KEY not set - Weights & Biases logging disabled")
    print("   Set it in RunPod Settings -> Environment Variables")
    print("   Or run: wandb login")
else:
    print(f"WANDB_API_KEY is set (length: {len(WANDB_API_KEY)})")

## 4. Initialize Weights & Biases

In [None]:
import wandb
from lightning.pytorch.loggers import WandbLogger

# W&B is opt-in: without an API key no run is started and wandb_logger is None.
if not WANDB_API_KEY:
    wandb_logger = None
    print("⚠️  Wandb logging disabled (WANDB_API_KEY not set)")
else:
    # Hyperparameters stored alongside the run
    run_config = {
        "embedding_model": EMBEDDING_MODEL,
        "embedding_dim": train_config.embedding_dim,
        "hidden_dim": train_config.hidden_dim,
        "codebook_size": train_config.codebook_size,
        "num_quantizers": train_config.num_quantizers,
        "learning_rate": train_config.learning_rate,
        "max_epochs": train_config.max_epochs,
        "batch_size": train_config.batch_size,
        "train_split": train_config.train_split,
    }
    run_name = f"rqvae_{train_config.codebook_size}x{train_config.num_quantizers}"

    # Start the run first, using the environment-provided credentials
    wandb.init(
        project="semantic-id-recommender",
        name=run_name,
        config=run_config,
    )

    # Logger object handed to the Lightning trainer in the training cell
    wandb_logger = WandbLogger(project="semantic-id-recommender", log_model=False)

    print("✓ Weights & Biases initialized")
    print(f"  Project: semantic-id-recommender")
    print(f"  Run: {wandb.run.name}")
    print(f"  URL: {wandb.run.url}")

## 5. Load Catalogue and Generate Embeddings

In [None]:
# Load catalogue (JSONL: one JSON object per line).
# The file is read explicitly as UTF-8 so non-ASCII text decodes the same way
# on every platform, and blank lines (e.g. an empty line at the end of the
# file) are skipped because json.loads would reject them.
with open(CATALOGUE_PATH, encoding="utf-8") as f:
    items = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(items)} items from catalogue")
print("\nSample items:")
# Preview the first three records; 'category' is optional, while 'item_id'
# and 'title' are accessed directly and must be present.
for item in items[:3]:
    print(f"  [{item.get('category', 'N/A')}] {item['item_id']}: {item['title'][:60]}...")

In [None]:
# Generate embeddings (or load from cache)
dataset = ItemEmbeddingDataset.from_catalogue(
    catalogue_path=CATALOGUE_PATH,
    embedding_model=EMBEDDING_MODEL,
    cache_path=EMBEDDINGS_CACHE,
    batch_size=32,
    fields=["title", "slug", "introduction", "content"],
)
embeddings = dataset.embeddings

# Fail fast when the embedding width disagrees with the configured RQ-VAE
# input width (train_config.embedding_dim is meant to match the embedding
# model output), rather than hitting a shape error later during training.
if embeddings.shape[1] != train_config.embedding_dim:
    raise ValueError(
        f"Embedding dimension {embeddings.shape[1]} does not match "
        f"train_config.embedding_dim={train_config.embedding_dim}. "
        f"Update the config, or remove {EMBEDDINGS_CACHE} if it was built "
        f"with a different embedding model."
    )

print("\nDataset info:")
print(f"  Number of items: {len(dataset)}")
print(f"  Embedding dimension: {embeddings.shape[1]}")
print(f"  Embedding shape: {embeddings.shape}")

## 6. Create and Train RQ-VAE Model

In [None]:
# Derive the model config from train_config and build the module that is
# later passed to trainer.fit
config = train_config.to_model_config()
trainer_module = RQVAETrainer(
    config=config,
    learning_rate=train_config.learning_rate,
)

# Summarise the architecture; the ID space is codebook_size ** num_quantizers
print(
    f"RQ-VAE Configuration:",
    f"  Embedding dim: {config.embedding_dim}",
    f"  Hidden dim: {config.hidden_dim}",
    f"  Codebook: {config.num_quantizers} levels x {config.codebook_size} codes",
    f"  Total semantic ID space: {config.codebook_size ** config.num_quantizers:,} unique IDs",
    sep="\n",
)

In [None]:
# Split dataset into train/val
train_size = int(train_config.train_split * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)

# Seed the split so the same items land in train/val on every run; with an
# unseeded split, validation metrics are not comparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=split_generator,
)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=train_config.batch_size, num_workers=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=train_config.batch_size, num_workers=8, shuffle=False)

print("Data split:")
print(f"  Train: {len(train_dataset)} items")
print(f"  Validation: {len(val_dataset)} items")

In [None]:
# Optional callbacks: the artifact callback is attached only when artifact
# logging is switched on in train_config.
callbacks = (
    [WandbArtifactCallback(train_config=train_config, embedding_model=EMBEDDING_MODEL)]
    if train_config.log_wandb_artifacts
    else []
)
# NOTE(review): no visible cell writes train_config.model_save_path itself —
# presumably the trainer module or the callback does; confirm the model is
# still saved when log_wandb_artifacts is False.

# Trainer settings collected in one place; logger is None when W&B is disabled.
trainer_options = {
    "max_epochs": train_config.max_epochs,
    "accelerator": "auto",
    "devices": 1,
    "enable_progress_bar": True,
    "log_every_n_steps": 5,
    "logger": wandb_logger,
    "callbacks": callbacks,
}
trainer = L.Trainer(**trainer_options)

print("Starting training...")
trainer.fit(trainer_module, train_loader, val_loader)
print("✓ Training complete!")

## 7. Evaluate Model

In [None]:
# Codebook statistics over every item (train and validation together)
model = trainer_module.model
model.eval()
# Use whichever device the trained parameters currently live on
device = next(model.parameters()).device

with torch.no_grad():
    all_indices = model.get_semantic_ids(embeddings.to(device))
    stats = model.compute_codebook_stats(all_indices)

print(f"\n=== Codebook Statistics (Full Dataset) ===")
print(f"Average perplexity: {stats['avg_perplexity']:.2f} / {config.codebook_size} (max)")
print(f"Average usage: {stats['avg_usage']*100:.1f}%")
print(f"\nPer-level breakdown:")

# Per-level values are indexed by quantizer level; usage is shown as a percentage
level_perplexities = stats['perplexity_per_level']
level_usages = stats['usage_per_level']
for q in range(config.num_quantizers):
    perp = level_perplexities[q].item()
    usage = level_usages[q].item() * 100
    print(f"  Level {q}: perplexity={perp:.1f}, usage={usage:.1f}%")

In [None]:
# Check for semantic ID collisions: two items collide when they map to the
# same semantic ID string.
total_items = len(embeddings)
unique_ids = len(set(model.semantic_id_to_string(all_indices)))
num_collisions = total_items - unique_ids
# Computed once and reused for both the printout and the logged metric
collision_rate = 1 - unique_ids / total_items

print("\n=== Semantic ID Quality ===")
print(f"Total items: {total_items}")
print(f"Unique semantic IDs: {unique_ids}")
print(f"Collision rate: {collision_rate*100:.1f}%")

if num_collisions > 0:
    print(f"\n⚠️  Warning: {num_collisions} collisions detected")
    print("   Consider increasing codebook size or number of quantizers")
else:
    print("\n✓ No collisions - all items have unique semantic IDs")

# Log evaluation metrics to wandb (if initialized)
if WANDB_API_KEY and wandb.run:
    eval_metrics = {
        "eval/avg_perplexity": stats['avg_perplexity'].item(),
        "eval/avg_usage": stats['avg_usage'].item(),
        "eval/unique_semantic_ids": unique_ids,
        "eval/total_items": total_items,
        "eval/collision_rate": collision_rate,
    }

    # Per-level statistics go into the same dict
    for q in range(config.num_quantizers):
        eval_metrics[f"eval/level_{q}_perplexity"] = stats['perplexity_per_level'][q].item()
        eval_metrics[f"eval/level_{q}_usage"] = stats['usage_per_level'][q].item()

    # One wandb.log call keeps every evaluation metric on the same history
    # step; logging each level separately would advance the step each time.
    wandb.log(eval_metrics)

    print("\n✓ Evaluation metrics logged to wandb")

## 8. Test Loading from W&B Artifact (Optional)

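If `log_wandb_artifacts=True` and a W&B run is active, download the logged artifact, reload the model from it, and check that it produces the same semantic IDs as the trained model. Otherwise this cell is skipped.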

In [None]:
if train_config.log_wandb_artifacts and wandb.run:
    # Download the artifact we just logged
    artifact = wandb.use_artifact(f"{train_config.artifact_name}:latest")
    artifact_dir = artifact.download()

    # Load the model from artifact. map_location="cpu" because the reloaded
    # model is evaluated on CPU below, so the checkpoint must load even when
    # its tensors were saved from a GPU.
    checkpoint = torch.load(f"{artifact_dir}/rqvae_model.pt", map_location="cpu")
    artifact_config = SemanticRQVAEConfig(**checkpoint["config"])
    artifact_model = SemanticRQVAE(artifact_config)
    artifact_model.load_state_dict(checkpoint["model_state_dict"])
    artifact_model.eval()

    print(f"Loaded model from W&B artifact: {train_config.artifact_name}:latest")
    print(f"  Config: {checkpoint['config']}")

    # Verify it produces same outputs as the trained model on the first 5
    # items. The trained model runs on `device`, the reloaded one on CPU.
    with torch.no_grad():
        test_embeddings = embeddings[:5].to(device)
        original_ids = model.semantic_id_to_string(model.get_semantic_ids(test_embeddings))
        artifact_ids = artifact_model.semantic_id_to_string(artifact_model.get_semantic_ids(test_embeddings.cpu()))

    if original_ids == artifact_ids:
        print("Artifact model produces identical outputs to trained model")
    else:
        print("Warning: Artifact model outputs differ from trained model")
        # Show each pair side by side so the differing items are easy to spot
        for orig, art in zip(original_ids, artifact_ids):
            match = "==" if orig == art else "!="
            print(f"  {orig} {match} {art}")
else:
    print("Skipping artifact test (log_wandb_artifacts=False or no wandb run)")