# RQ-VAE Training

Train a Residual Quantized Variational AutoEncoder (RQ-VAE) to learn semantic IDs for items in a catalogue.

**Environment:** RunPod Jupyter with GPU

**What this does:**
1. Load item catalogue (JSONL format)
2. Generate embeddings using a pretrained embedding model
3. Train RQ-VAE to compress embeddings into discrete codes
4. Save the trained model locally (and optionally log it as a W&B artifact)

**Outputs:**
- `models/rqvae_model.pt` - Trained RQ-VAE model
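
To make "compress embeddings into discrete codes" concrete, here is a minimal pure-Python sketch of the residual quantization step at the heart of an RQ-VAE. It is an illustration only: the function names and the toy codebooks are hypothetical, and the real `SemanticRQVAE` in `src.rqvae` also has a learned encoder and decoder that are omitted here.

```python
def nearest_code(vector, codebook):
    """Return the index of the codebook entry closest to `vector` (squared L2)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda i: sq_dist(vector, codebook[i]))


def residual_quantize(vector, codebooks):
    """Quantize `vector` level by level, passing the leftover residual down."""
    residual = list(vector)
    codes = []
    for codebook in codebooks:
        idx = nearest_code(residual, codebook)
        codes.append(idx)
        # Subtract the chosen code so the next level only models what is left
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return codes, residual


# Toy example: 2 levels, 2 codes per level, 2-dimensional vectors (made-up values)
toy_codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],      # level 0: coarse
    [[0.0, 0.0], [0.25, -0.25]],   # level 1: refinement
]
codes, residual = residual_quantize([1.2, 0.8], toy_codebooks)
print(codes)
```

The list of per-level indices plays the role of the semantic ID: earlier levels capture coarse structure and later levels refine it.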

## 1. Setup

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Navigate to project root if needed
import sys
import os

if os.path.abspath(".").endswith("notebooks"):
    repo_root = os.path.abspath("..")
    print(f"Current directory: {os.getcwd()}, changing to {repo_root}")
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    os.chdir(repo_root)
else:
    print(f"Already in project root: {os.getcwd()}")

In [None]:
from unsloth import FastLanguageModel # Always import unsloth first

In [None]:
import torch
from pathlib import Path

from src.rqvae import (
    train,
    RqvaeTrainConfig,
    SemanticRQVAE,
    SemanticRQVAEConfig,
)

print("Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Configuration

All training parameters are consolidated into `RqvaeTrainConfig`.
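
The sketch below shows the `Config(**config_dict)` pattern used in the next cell, with a stand-in dataclass. `DemoTrainConfig`, its three fields, and all values are hypothetical; the real `RqvaeTrainConfig` has more fields and may not be a dataclass, in which case its handling of unknown keys could differ.

```python
from dataclasses import dataclass


@dataclass
class DemoTrainConfig:
    """Hypothetical stand-in for RqvaeTrainConfig with a subset of fields."""
    catalogue_path: str
    num_quantizers: int
    codebook_size: int


# What yaml.safe_load would return for a small YAML mapping (illustrative values)
config_dict = {
    "catalogue_path": "data/catalogue.jsonl",
    "num_quantizers": 3,
    "codebook_size": 256,
}
demo_config = DemoTrainConfig(**config_dict)

# Each level picks one of `codebook_size` codes, so the ID space is size ** levels
id_space = demo_config.codebook_size ** demo_config.num_quantizers
print(f"{id_space:,} unique IDs")

# With a dataclass, a misspelled key fails loudly instead of being ignored
try:
    DemoTrainConfig(**config_dict, codebok_size=512)
except TypeError as err:
    typo_error = str(err)
    print(f"Rejected: {typo_error}")
```

Unpacking the YAML mapping straight into the config class keeps the YAML file as the single source of truth and, for dataclass-style configs, surfaces typos at load time.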

In [None]:
import yaml

# Load configuration from YAML
with open("notebooks/rqvae_config.yaml") as f:
    config_dict = yaml.safe_load(f)

config = RqvaeTrainConfig(**config_dict)

# Create output directories
Path(config.model_save_path).parent.mkdir(parents=True, exist_ok=True)

print("Configuration:")
print(f"  Catalogue: {config.catalogue_path}")
print(f"  Embedding model: {config.embedding_model}")
print(f"  Codebook: {config.num_quantizers} levels x {config.codebook_size} codes")
print(f"  Total semantic ID space: {config.codebook_size ** config.num_quantizers:,} unique IDs")
print(f"  Output model: {config.model_save_path}")
print(f"  W&B project: {config.wandb_project}")
print(f"  Log W&B artifacts: {config.log_wandb_artifacts}")

## 3. Train RQ-VAE

The `train()` function handles the complete training lifecycle:
1. Initialize W&B (if a project is provided)
2. Load catalogue and generate/cache embeddings
3. Split dataset into train/val
4. Train the model with Lightning
5. Evaluate and compute final metrics
6. Log summary metrics to W&B
7. Clean up W&B run
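
As an illustration of step 3, a reproducible train/val split can be sketched as follows. The validation fraction, the seed, and the function name are assumptions made for this example; the split actually performed inside `train()` may differ.

```python
import random


def split_indices(num_items, val_fraction=0.1, seed=42):
    """Shuffle item indices with a fixed seed and hold out a validation slice."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)  # local RNG, leaves global state alone
    num_val = max(1, round(num_items * val_fraction))
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = split_indices(100)
print(f"train={len(train_idx)}, val={len(val_idx)}")
```

Fixing the seed keeps the validation set stable across runs, so metrics from different runs are computed on the same held-out items.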

In [None]:
# Run end-to-end training
result = train(config)

print("\n" + "="*50)
print("Training Complete!")
print("="*50)

## 4. Review Results

The `train()` function returns a `TrainResult` containing:
- `model`: Trained SemanticRQVAE model
- `config`: Model configuration
- `metrics`: Dictionary of final evaluation metrics
- `semantic_ids`: Mapping of item_id -> semantic_id string
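
For reading the metrics, it helps to know how codebook perplexity and usage are commonly defined. The sketch below uses the standard definitions (perplexity as the exponential of the entropy of code assignments, usage as the fraction of codes assigned at least once). Whether `src.rqvae` computes them exactly this way is an assumption, and `codebook_stats` is a hypothetical helper.

```python
import math
from collections import Counter


def codebook_stats(codes, codebook_size):
    """Return (perplexity, usage) for a list of code indices at one level."""
    counts = Counter(codes)
    total = len(codes)
    # Shannon entropy (in nats) of the empirical code distribution
    entropy = -sum((n / total) * math.log(n / total) for n in counts.values())
    perplexity = math.exp(entropy)         # equals codebook_size when uniform
    usage = len(counts) / codebook_size    # fraction of codes used at least once
    return perplexity, usage


uniform_perplexity, uniform_usage = codebook_stats([0, 1, 2, 3], codebook_size=4)
skewed_perplexity, skewed_usage = codebook_stats([0, 0, 0, 1], codebook_size=4)
print(f"uniform: perplexity={uniform_perplexity:.2f}, usage={uniform_usage:.0%}")
print(f"skewed:  perplexity={skewed_perplexity:.2f}, usage={skewed_usage:.0%}")
```

Under these definitions, a perplexity close to the codebook size means codes are used evenly, while a low value suggests the codebook is collapsing onto a few codes.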

In [None]:
# Display final metrics
print("=== Final Metrics ===")
print(f"Average perplexity: {result.metrics['avg_perplexity']:.2f} / {config.codebook_size} (max)")
print(f"Average usage: {result.metrics['avg_usage']*100:.1f}%")
print(f"Total items: {result.metrics['total_items']}")
print(f"Unique semantic IDs: {result.metrics['unique_semantic_ids']}")
print(f"Collision rate: {result.metrics['collision_rate']*100:.2f}%")

print("\nPer-level breakdown:")
for q in range(config.num_quantizers):
    perp = result.metrics[f'level_{q}_perplexity']
    usage = result.metrics[f'level_{q}_usage'] * 100
    print(f"  Level {q}: perplexity={perp:.1f}, usage={usage:.1f}%")

In [None]:
# Show sample semantic IDs
print("=== Sample Semantic IDs ===")
for item_id, sem_id in list(result.semantic_ids.items())[:5]:
    print(f"  {item_id}: {sem_id}")
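
When the collision rate is non-zero, it can be useful to see which items share a semantic ID. The sketch below groups a mapping shaped like `result.semantic_ids` by ID. The helper names, the demo data, and the ID string format are hypothetical, and the collision-rate formula shown (one minus unique IDs over total items) is one plausible definition rather than necessarily the one `train()` reports.

```python
from collections import defaultdict


def find_collisions(semantic_ids):
    """Map each semantic ID shared by 2+ items to the list of those items."""
    items_by_id = defaultdict(list)
    for item_id, sem_id in semantic_ids.items():
        items_by_id[sem_id].append(item_id)
    return {s: items for s, items in items_by_id.items() if len(items) > 1}


def collision_rate(semantic_ids):
    """Fraction of items that do not get an ID of their own."""
    if not semantic_ids:
        return 0.0
    unique = len(set(semantic_ids.values()))
    return 1.0 - unique / len(semantic_ids)


# Made-up data in the same item_id -> semantic_id shape
demo_ids = {
    "item_a": "12-7-3",
    "item_b": "12-7-3",
    "item_c": "4-0-9",
    "item_d": "1-1-1",
}
collisions = find_collisions(demo_ids)
rate = collision_rate(demo_ids)
print(collisions)
print(f"Collision rate: {rate:.2%}")
```

On a real run, passing `result.semantic_ids` instead of `demo_ids` would list the colliding items directly.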

## 5. Test Loading from W&B Artifact

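Verify that the trained model can be downloaded from the W&B artifact, loaded, and used to reproduce the same semantic IDs as the in-memory model. Both cells are skipped when `log_wandb_artifacts` is false.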

In [None]:
import wandb

if config.log_wandb_artifacts:
    # Initialize a new wandb run to download the artifact
    wandb.init(
        project=config.wandb_project,
        job_type="artifact-verification",
    )
    
    # Download the artifact we just logged
    artifact = wandb.use_artifact(f"{config.artifact_name}:latest")
    artifact_dir = artifact.download()
    
    print(f"Downloaded artifact to: {artifact_dir}")
    print(f"Artifact metadata: {artifact.metadata}")
    
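    # Load the model from the artifact; map to CPU so this also works if the
    # checkpoint was saved from a GPU and no GPU is available here
    checkpoint = torch.load(f"{artifact_dir}/rqvae_model.pt", map_location="cpu")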
    artifact_config = SemanticRQVAEConfig(**checkpoint["config"])
    artifact_model = SemanticRQVAE(artifact_config)
    artifact_model.load_state_dict(checkpoint["model_state_dict"])
    artifact_model.eval()
    
    print(f"\nLoaded model from W&B artifact: {config.artifact_name}:latest")
    print(f"  Config: {checkpoint['config']}")
    
    wandb.finish()
else:
    print("Skipping artifact verification (log_wandb_artifacts=False)")

In [None]:
if config.log_wandb_artifacts:
    # Verify artifact model produces same outputs as trained model
    from src.rqvae import ItemEmbeddingDataset
    
    # Load embeddings to test
    dataset = ItemEmbeddingDataset.from_embeddings_file(config.embeddings_cache_path)
    test_embeddings = dataset.embeddings[:5]
    
    # Get device of trained model
    device = next(result.model.parameters()).device
    
    with torch.no_grad():
        # Get IDs from trained model
        original_ids = result.model.semantic_id_to_string(
            result.model.get_semantic_ids(test_embeddings.to(device))
        )
        
        # Get IDs from artifact model (on CPU)
        artifact_ids = artifact_model.semantic_id_to_string(
            artifact_model.get_semantic_ids(test_embeddings)
        )
    
    print("=== Artifact Verification ===")
    all_match = True
    for orig, art in zip(original_ids, artifact_ids):
        match = orig == art
        all_match = all_match and match
        status = "OK" if match else "MISMATCH"
        # On a mismatch, show both IDs so the difference is visible
        detail = orig if match else f"{orig} != {art}"
        print(f"  [{status}] {detail}")
    
    if all_match:
        print("\nArtifact model produces identical outputs to trained model")
    else:
        print("\nWARNING: Artifact model outputs differ from trained model!")
else:
    print("Skipping artifact verification (log_wandb_artifacts=False)")