# RQ-VAE Training

Train a Residual-Quantized Variational Autoencoder (RQ-VAE) that maps each item in a catalogue to a short sequence of discrete codes, its *semantic ID*.

**Environment:** RunPod Jupyter with GPU

**What this does:**
1. Load item catalogue (JSONL format)
2. Generate embeddings using a pretrained embedding model
3. Train RQ-VAE to compress embeddings into discrete codes
4. Generate semantic IDs for all items
5. Save trained model and semantic ID mappings

**Outputs:**
- `models/rqvae_model.pt` - Trained RQ-VAE model
- `data/semantic_ids.json` - Item ID to Semantic ID mappings
- `data/embeddings.pt` - Cached item embeddings (reused instead of re-embedding when present)

## 1. Setup

In [None]:
# Check GPU availability.
# `!` is an IPython shell escape: this runs the `nvidia-smi` CLI and lists the
# GPUs visible to this machine (fails if no NVIDIA driver is installed).
!nvidia-smi

In [None]:
# Make relative paths (data/, models/) and `src.*` imports work no matter
# whether the notebook kernel was started in the repo root or in notebooks/.
import sys
import os

if not os.path.abspath(".").endswith("notebooks"):
    print(f"Already in project root: {os.getcwd()}")
else:
    # Launched from notebooks/: step up one level and expose it for imports
    repo_root = os.path.abspath("..")
    print(f"Current directory: {os.getcwd()}, changing to {repo_root}")
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    os.chdir(repo_root)

In [None]:
from unsloth import FastLanguageModel # Always import unsloth first

In [None]:
import torch
import json
from pathlib import Path
import lightning as L
from torch.utils.data import DataLoader, random_split

from src.rqvae.model import SemanticRQVAE, SemanticRQVAEConfig
from src.rqvae.trainer import RQVAETrainer
from src.rqvae.dataset import ItemEmbeddingDataset

# Quick sanity report: imports worked and which torch/CUDA setup we are on
for status_line in (
    "✓ Imports successful",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
):
    print(status_line)

## 2. Configuration

In [None]:
# Paths
CATALOGUE_PATH = "data/ag_news_500.jsonl"  # Input catalogue
EMBEDDINGS_CACHE = "data/embeddings.pt"  # Cache for embeddings
SEMANTIC_IDS_PATH = "data/semantic_ids.json"  # Output mappings
MODEL_SAVE_PATH = "models/rqvae_model.pt"  # Output trained model

# Embedding model
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"  # 1024-dim embeddings

# RQ-VAE hyperparameters
EMBEDDING_DIM = 1024      # Must match the embedding model's output width
HIDDEN_DIM = 256          # Encoder/decoder hidden dimension
CODEBOOK_SIZE = 32        # Number of codes per quantizer
NUM_QUANTIZERS = 3        # Number of quantization levels
LEARNING_RATE = 1e-3
MAX_EPOCHS = 50
BATCH_SIZE = 64
TRAIN_SPLIT = 0.9         # 90% train, 10% validation

# Create output directories.
# They are derived from the configured paths rather than hard-coded, so that
# pointing an output elsewhere (e.g. "outputs/run1/model.pt") does not make
# torch.save / json.dump fail later on a missing directory.
for output_path in (EMBEDDINGS_CACHE, SEMANTIC_IDS_PATH, MODEL_SAVE_PATH):
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

print("Configuration:")
print(f"  Catalogue: {CATALOGUE_PATH}")
print(f"  Embedding model: {EMBEDDING_MODEL}")
print(f"  Codebook: {NUM_QUANTIZERS} levels x {CODEBOOK_SIZE} codes")
print(f"  Output model: {MODEL_SAVE_PATH}")

## 2.1. Initialize Weights & Biases

Set up wandb logging to track training metrics, hyperparameters, and model performance.

In [None]:
import wandb
from lightning.pytorch.loggers import WandbLogger

# Single source of truth for the W&B project name. wandb.init() and
# WandbLogger used to be given different projects. Because a run is already
# active when WandbLogger is created, the logger attaches to that run and its
# own `project` is ignored, so the printed project name was wrong.
WANDB_PROJECT = "semantic-id-recommender"

# Initialize wandb
run = wandb.init(
    project=WANDB_PROJECT,
    name=f"rqvae_{CODEBOOK_SIZE}x{NUM_QUANTIZERS}",  # Run name
    config={
        "embedding_model": EMBEDDING_MODEL,
        "embedding_dim": EMBEDDING_DIM,
        "hidden_dim": HIDDEN_DIM,
        "codebook_size": CODEBOOK_SIZE,
        "num_quantizers": NUM_QUANTIZERS,
        "learning_rate": LEARNING_RATE,
        "max_epochs": MAX_EPOCHS,
        "batch_size": BATCH_SIZE,
        "train_split": TRAIN_SPLIT,
    }
)

# Create the Lightning logger on top of the run created above. Passing it
# explicitly makes the reuse intentional instead of relying on implicit
# pickup of the active run.
wandb_logger = WandbLogger(
    experiment=run,
    project=WANDB_PROJECT,
    log_model=False,  # We'll manually log the model at the end
)

print("✓ Weights & Biases initialized")
print(f"  Project: {WANDB_PROJECT}")
print(f"  Run: {wandb.run.name}")
print(f"  URL: {wandb.run.url}")

## 3. Load Catalogue and Generate Embeddings

In [None]:
# Load catalogue: one JSON object per line (JSONL).
# Blank lines (e.g. an extra trailing newline) are skipped instead of making
# json.loads raise JSONDecodeError.
items = []
with open(CATALOGUE_PATH, encoding="utf-8") as f:
    for line in f:
        if line.strip():
            items.append(json.loads(line))

print(f"Loaded {len(items)} items from catalogue")
print("\nSample items:")
for item in items[:3]:
    print(f"  [{item.get('category', 'N/A')}] {item['item_id']}: {item['title'][:60]}...")

In [None]:
# Generate embeddings (or load from cache at EMBEDDINGS_CACHE)
dataset = ItemEmbeddingDataset.from_catalogue(
    catalogue_path=CATALOGUE_PATH,
    embedding_model=EMBEDDING_MODEL,
    cache_path=EMBEDDINGS_CACHE,
)

embeddings = dataset.embeddings

# Fail fast on mismatches that would otherwise surface much later:
# - semantic IDs are paired with `items` by position when building the
#   mappings (section 6), so the two must have the same length. A stale
#   embeddings cache from a different catalogue is one possible cause.
# - the RQ-VAE is configured with embedding_dim=EMBEDDING_DIM.
if len(dataset) != len(items):
    raise ValueError(
        f"Embedding dataset has {len(dataset)} rows but the catalogue has "
        f"{len(items)} items; they must match because semantic IDs are assigned "
        f"by position (is the cache at {EMBEDDINGS_CACHE} stale?)"
    )
if embeddings.shape[1] != EMBEDDING_DIM:
    raise ValueError(
        f"Embeddings have dimension {embeddings.shape[1]} but EMBEDDING_DIM is "
        f"{EMBEDDING_DIM}; update EMBEDDING_DIM to match {EMBEDDING_MODEL}"
    )

print("\nDataset info:")
print(f"  Number of items: {len(dataset)}")
print(f"  Embedding dimension: {embeddings.shape[1]}")
print(f"  Embedding shape: {embeddings.shape}")

## 4. Train RQ-VAE

In [None]:
# Build the RQ-VAE configuration and wrap it in the Lightning training module
config = SemanticRQVAEConfig(
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    codebook_size=CODEBOOK_SIZE,
    num_quantizers=NUM_QUANTIZERS,
    # NOTE(review): EMA dead-code threshold; exact semantics are defined in
    # src.rqvae.model - confirm there before tuning
    threshold_ema_dead_code=1,
)
trainer_module = RQVAETrainer(config=config, learning_rate=LEARNING_RATE)

# Every level picks one of codebook_size codes -> codebook_size ** levels IDs
id_space = config.codebook_size ** config.num_quantizers
report_lines = [
    "RQ-VAE Configuration:",
    f"  Embedding dim: {config.embedding_dim}",
    f"  Hidden dim: {config.hidden_dim}",
    f"  Codebook: {config.num_quantizers} levels x {config.codebook_size} codes",
    f"  Total semantic ID space: {id_space:,} unique IDs",
]
for report_line in report_lines:
    print(report_line)

In [None]:
# Split dataset into train/val.
# A fixed generator seed makes the split reproducible across kernel restarts,
# so validation metrics from different runs are comparable.
SPLIT_SEED = 42

train_size = int(TRAIN_SPLIT * len(dataset))
val_size = len(dataset) - train_size
if train_size == 0 or val_size == 0:
    raise ValueError(
        f"TRAIN_SPLIT={TRAIN_SPLIT} on {len(dataset)} items gives "
        f"{train_size} train / {val_size} validation items; both must be non-empty"
    )

train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("Data split:")
print(f"  Train: {len(train_dataset)} items")
print(f"  Validation: {len(val_dataset)} items")

In [None]:
# Fit the RQ-VAE with Lightning, streaming metrics to W&B via wandb_logger
trainer_options = dict(
    max_epochs=MAX_EPOCHS,
    accelerator="auto",      # let Lightning pick GPU if present, else CPU
    devices=1,
    enable_progress_bar=True,
    # Log every step: a small catalogue yields only a few batches per epoch,
    # fewer than Lightning's default logging interval
    log_every_n_steps=1,
    logger=wandb_logger,
)
trainer = L.Trainer(**trainer_options)

print("Starting training...")
trainer.fit(
    trainer_module,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
print("✓ Training complete!")

## 5. Evaluate Model

In [None]:
# Check training diagnostics.
# Lightning stores logged values as tensors, so single-element tensors are
# unwrapped to Python numbers first. Otherwise the float branch never matches
# and raw reprs like "tensor(0.1234)" are printed.
print("=== Final Training Metrics ===")
logged_metrics = trainer.logged_metrics
for key in sorted(logged_metrics):
    value = logged_metrics[key]
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        value = value.item()
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

In [None]:
# Compute codebook utilisation on the full dataset (train + validation)
model = trainer_module.model
model.eval()
device = next(model.parameters()).device  # wherever Lightning left the weights

with torch.no_grad():
    all_indices = model.get_semantic_ids(embeddings.to(device))
    stats = model.compute_codebook_stats(all_indices)

level_perplexity = stats['perplexity_per_level']
level_usage = stats['usage_per_level']

print(f"\n=== Codebook Statistics (Full Dataset) ===")
print(f"Average perplexity: {stats['avg_perplexity']:.2f} / {config.codebook_size} (max)")
print(f"Average usage: {stats['avg_usage']*100:.1f}%")
print(f"\nPer-level breakdown:")
for level in range(config.num_quantizers):
    print(
        f"  Level {level}: "
        f"perplexity={level_perplexity[level].item():.1f}, "
        f"usage={level_usage[level].item() * 100:.1f}%"
    )

In [None]:
# Check for semantic ID collisions: items that receive the same code sequence
# cannot be distinguished by their semantic ID.
total_items = len(embeddings)
unique_ids = len(set(model.semantic_id_to_string(all_indices)))
collision_rate = 1 - unique_ids / total_items

print(f"\n=== Semantic ID Quality ===")
print(f"Total items: {total_items}")
print(f"Unique semantic IDs: {unique_ids}")
print(f"Collision rate: {collision_rate*100:.1f}%")

if unique_ids < total_items:
    print(f"\n⚠️  Warning: {total_items - unique_ids} collisions detected")
    print("   Consider increasing codebook size or number of quantizers")
else:
    print("\n✓ No collisions - all items have unique semantic IDs")

# Collect every evaluation metric (global and per-level) into one dict.
eval_metrics = {
    "eval/avg_perplexity": stats['avg_perplexity'].item(),
    "eval/avg_usage": stats['avg_usage'].item(),
    "eval/unique_semantic_ids": unique_ids,
    "eval/total_items": total_items,
    "eval/collision_rate": collision_rate,
}
for q in range(config.num_quantizers):
    eval_metrics[f"eval/level_{q}_perplexity"] = stats['perplexity_per_level'][q].item()
    eval_metrics[f"eval/level_{q}_usage"] = stats['usage_per_level'][q].item()

# A single wandb.log call records all evaluation metrics at one step. Separate
# calls would advance the W&B step once per call and scatter them across
# num_quantizers + 1 steps.
wandb.log(eval_metrics)

print("\n✓ Evaluation metrics logged to wandb")

## 6. Generate and Save Semantic IDs

In [None]:
# Generate semantic IDs for all items
model.eval()
device = next(model.parameters()).device

with torch.no_grad():
    indices = model.get_semantic_ids(embeddings.to(device))
    semantic_strings = model.semantic_id_to_string(indices)

# IDs are paired with catalogue items by position, so the counts must agree.
if len(semantic_strings) != len(items):
    raise ValueError(
        f"Got {len(semantic_strings)} semantic IDs for {len(items)} catalogue items"
    )

# Copy all codes to the host once instead of one device->host transfer per item
codes_per_item = indices.cpu().tolist()

# Create bidirectional mappings
item_to_semantic = {}
semantic_to_item = {}
# (semantic_id, dropped_item_id, kept_item_id) for every collision. The
# semantic_to_item entry of the earlier item is overwritten by the later one.
overwritten = []

for item, sem_id, codes in zip(items, semantic_strings, codes_per_item):
    item_id = item["item_id"]
    item_to_semantic[item_id] = {
        "codes": codes,
        "semantic_id": sem_id,
    }
    if sem_id in semantic_to_item:
        overwritten.append((sem_id, semantic_to_item[sem_id], item_id))
    semantic_to_item[sem_id] = item_id

print("\nExample semantic IDs:")
for i in range(min(5, len(items))):
    title = items[i]['title'][:40]
    print(f"  {title:40} -> {semantic_strings[i]}")

print(f"\n✓ Generated {len(semantic_to_item)} unique semantic IDs")

# Surface lossy reverse mappings instead of silently dropping items
if overwritten:
    print(
        f"\n⚠️  {len(overwritten)} items share a semantic ID with a later item and "
        "are unreachable via semantic_to_item (the later item is kept), e.g.:"
    )
    for sem_id, dropped_id, kept_id in overwritten[:5]:
        print(f"    {sem_id}: {dropped_id} (dropped) vs {kept_id} (kept)")

In [None]:
# Persist both lookup directions together with the model hyperparameters
# needed to interpret the codes, as pretty-printed JSON.
mapping_config_fields = ("num_quantizers", "codebook_size", "embedding_dim", "hidden_dim")
mapping = {
    "item_to_semantic": item_to_semantic,
    "semantic_to_item": semantic_to_item,
    "config": {field: getattr(config, field) for field in mapping_config_fields},
}

with Path(SEMANTIC_IDS_PATH).open("w") as mapping_file:
    json.dump(mapping, mapping_file, indent=2)

print(f"✓ Saved semantic ID mappings to {SEMANTIC_IDS_PATH}")

## 7. Save Trained Model

In [None]:
# Save the trained model


def _metric_to_float(value):
    """Return a logged metric (scalar tensor or number) as a plain float; keep None.

    Lightning's logged_metrics holds tensors. Storing plain floats keeps the
    checkpoint's training_info readable and device-independent, and matches
    the float conversion used when preparing the Hub upload.
    """
    return None if value is None else float(value)


# Make sure the target directory exists even if MODEL_SAVE_PATH was changed
Path(MODEL_SAVE_PATH).parent.mkdir(parents=True, exist_ok=True)

checkpoint = {
    "model_state_dict": model.state_dict(),
    "config": {
        "embedding_dim": config.embedding_dim,
        "hidden_dim": config.hidden_dim,
        "codebook_size": config.codebook_size,
        "num_quantizers": config.num_quantizers,
        "threshold_ema_dead_code": config.threshold_ema_dead_code,
    },
    "training_info": {
        "final_train_loss": _metric_to_float(logged_metrics.get("train/loss")),
        "final_val_loss": _metric_to_float(logged_metrics.get("val/loss")),
        "num_items": len(items),
        "unique_semantic_ids": unique_ids,
    }
}

torch.save(checkpoint, MODEL_SAVE_PATH)
print(f"✓ Saved trained model to {MODEL_SAVE_PATH}")
print(f"  Model size: {Path(MODEL_SAVE_PATH).stat().st_size / 1024 / 1024:.2f} MB")

## 8. Test Model Loading

In [None]:
# Test that we can reload the model.
# map_location keeps loading working even if the checkpoint was written from
# a different device than the current one.
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)

# Recreate model from saved config
loaded_config = SemanticRQVAEConfig(**checkpoint["config"])
loaded_model = SemanticRQVAE(loaded_config)
loaded_model.load_state_dict(checkpoint["model_state_dict"])
# A freshly constructed module lives on the CPU. Move it to the same device as
# the trained model, because the comparison below feeds it tensors on
# `device`; otherwise this raises a device-mismatch error on GPU machines.
loaded_model.to(device)
loaded_model.eval()

print("✓ Successfully reloaded model from checkpoint")
print(f"\nCheckpoint info:")
print(f"  Config: {checkpoint['config']}")
if "training_info" in checkpoint:
    print(f"  Training info: {checkpoint['training_info']}")

# Verify it produces same outputs
with torch.no_grad():
    test_embedding = embeddings[:5].to(device)
    original_ids = model.semantic_id_to_string(model.get_semantic_ids(test_embedding))
    loaded_ids = loaded_model.semantic_id_to_string(loaded_model.get_semantic_ids(test_embedding))

if original_ids == loaded_ids:
    print("\n✓ Loaded model produces identical outputs")
else:
    print("\n⚠️  Warning: Loaded model outputs differ")

## 9. Upload to HuggingFace Hub (Optional)

This section demonstrates how to upload your trained model to HuggingFace Hub for easy sharing and deployment.

**Prerequisites:**
- Create a HuggingFace account at https://huggingface.co
- Generate an access token: https://huggingface.co/settings/tokens (with write permissions)
- Install huggingface_hub: `pip install huggingface_hub`

In [None]:
# Import HuggingFace Hub utilities from the project module src.rqvae.hub:
# save_model_for_hub stages the model files locally, upload_to_hub pushes them.
from src.rqvae.hub import save_model_for_hub, upload_to_hub

print("✓ HuggingFace Hub utilities imported")

In [None]:
# Step 1: Stage the model files locally in the layout expected for the Hub
HUB_DIR = "models/rqvae_hub"  # Temporary directory for hub files

# Summary metadata shipped alongside the weights.
# NOTE(review): a missing loss metric is reported as 0 here, which reads like
# a perfect loss; consider None if save_model_for_hub accepts it.
hub_training_info = {
    "final_train_loss": float(logged_metrics.get("train/loss", 0)),
    "final_val_loss": float(logged_metrics.get("val/loss", 0)),
    "num_items": len(items),
    "unique_semantic_ids": unique_ids,
    "codebook_usage": float(stats['avg_usage']),
    "perplexity": float(stats['avg_perplexity']),
}

save_model_for_hub(
    model=model,
    save_dir=HUB_DIR,
    semantic_ids_path=SEMANTIC_IDS_PATH,
    training_info=hub_training_info,
)

print(f"\n✓ Model prepared for HuggingFace Hub in {HUB_DIR}")

In [None]:
# Step 2: Upload to HuggingFace Hub
#
# IMPORTANT: Update these settings before running!
#
# To run this cell:
# 1. Set your HuggingFace username and desired repo name
# 2. Either set HF_TOKEN environment variable or pass token directly
# 3. Run this cell to upload
#
# While REPO_ID still holds the placeholder, the upload is skipped rather
# than attempted. Pushing to a namespace you don't own fails, and under
# "Run All" that failure would abort every later cell, including the W&B
# artifact/finish cell.

REPO_ID = "your-username/semantic-rqvae"  # Change this!
PRIVATE = False  # Set to True for private repo

print("Upload Configuration:")
print(f"  Repo ID: {REPO_ID}")
print(f"  Private: {PRIVATE}")
print(f"  Model directory: {HUB_DIR}")
print("\nTo upload:")
print("  1. Update REPO_ID with your HuggingFace username")
print("  2. Set HF_TOKEN environment variable or pass token parameter")
print()

if REPO_ID.startswith("your-username/"):
    print("⚠️  Skipping upload: REPO_ID is still the placeholder value.")
else:
    url = upload_to_hub(
        model_dir=HUB_DIR,
        repo_id=REPO_ID,
        private=PRIVATE,
        token=None,  # Uses HF_TOKEN env var, or pass your token here
        commit_message="Upload RQ-VAE semantic ID model",
    )
    print(f"\n✓ Model available at: {url}")

## 10. Download from HuggingFace Hub (Optional)

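This section demonstrates two ways to download and use a model from HuggingFace Hub. Method 1 (below) loads the model directly into memory. Method 2, which downloads the files to a local directory, is the final cell of the notebook (after section 11).
This is useful for deployment on Modal or other cloud platforms.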

In [None]:
# Method 1: Load model directly into memory
from src.rqvae.hub import load_from_hub

# Example: Load from HuggingFace Hub
# Update REPO_ID to match your uploaded model

REPO_ID = "your-username/semantic-rqvae"

if REPO_ID.startswith("your-username/"):
    # Skip instead of failing on the placeholder repo so "Run All" continues
    print("⚠️  Skipping Hub download: set REPO_ID to your uploaded model first.")
else:
    loaded_model, semantic_mappings = load_from_hub(
        repo_id=REPO_ID,
        token=None,  # Only needed for private repos
    )
    # The device load_from_hub places the model on isn't visible here. Align
    # it with `device`, because the test inputs below are moved there;
    # otherwise a CPU model fed CUDA tensors raises a device-mismatch error.
    loaded_model.to(device)
    loaded_model.eval()

    print("\n✓ Model loaded from HuggingFace Hub")
    print(f"  Semantic mappings available: {semantic_mappings is not None}")

    # Test the loaded model
    with torch.no_grad():
        test_ids = loaded_model.get_semantic_ids(embeddings[:3].to(device))
        test_strings = loaded_model.semantic_id_to_string(test_ids)
        print("\nTest semantic IDs from loaded model:")
        for i, sem_id in enumerate(test_strings):
            print(f"  Item {i}: {sem_id}")

## 11. Finish Wandb Run

In [None]:
# Save model artifact to wandb (optional)
artifact = wandb.Artifact(
    name="rqvae-model",
    type="model",
    description=f"RQ-VAE model with {NUM_QUANTIZERS}x{CODEBOOK_SIZE} codebook",
    metadata={
        "embedding_dim": EMBEDDING_DIM,
        "hidden_dim": HIDDEN_DIM,
        "codebook_size": CODEBOOK_SIZE,
        "num_quantizers": NUM_QUANTIZERS,
        "unique_semantic_ids": unique_ids,
        "collision_rate": (1 - unique_ids/len(embeddings)),
    }
)

# Add model files to artifact
artifact.add_file(MODEL_SAVE_PATH)
artifact.add_file(SEMANTIC_IDS_PATH)

# Log the artifact
wandb.log_artifact(artifact)

print("✓ Model artifact logged to wandb")

# Capture the run URL *before* finishing. wandb.finish() clears wandb.run, so
# reading it afterwards always fell through to 'Run finished'.
run_url = wandb.run.url if wandb.run else None

# Finish the wandb run
wandb.finish()
print("\n✓ Wandb run completed")
print(f"  View results at: {run_url or 'Run finished'}")

In [None]:
# Method 2: Download all files to a local directory
# (second download method of section 10 "Download from HuggingFace Hub")
from src.rqvae.hub import download_model_files

# Example: Download model files locally
# Update REPO_ID to match your uploaded model

REPO_ID = "your-username/semantic-rqvae"
LOCAL_DIR = "models/downloaded_from_hub"

if REPO_ID.startswith("your-username/"):
    # Skip instead of failing on the placeholder repo name
    print("⚠️  Skipping Hub download: set REPO_ID to your uploaded model first.")
else:
    download_model_files(
        repo_id=REPO_ID,
        local_dir=LOCAL_DIR,
        token=None,  # Only needed for private repos
    )

    print(f"\n✓ Model files downloaded to {LOCAL_DIR}")
    print("  You can now load the model from local files:")
    print(f"    checkpoint = torch.load('{LOCAL_DIR}/rqvae_model.pt')")