# RQ-VAE Training

Train a Residual Quantized Variational AutoEncoder (RQ-VAE) to learn semantic IDs for items in a catalogue.

**Environment:** RunPod Jupyter with GPU

**What this does:**
1. Load item catalogue (JSONL format)
2. Generate embeddings using a pretrained embedding model
3. Train RQ-VAE to compress embeddings into discrete codes
4. Generate semantic IDs for all items
5. Save trained model and semantic ID mappings

**Outputs:**
- `models/rqvae_model.pt` - Trained RQ-VAE model
- `data/semantic_ids.json` - Item ID to Semantic ID mappings

## 1. Setup

In [None]:
# Check GPU availability.
# The leading "!" is an IPython shell escape: the command runs in the system shell, not in Python.
!nvidia-smi

In [None]:
# Navigate to project root if needed
import sys
import os

# Compare the last path component exactly. A plain suffix test on the whole
# path would also match unrelated directories such as ".../old_notebooks".
if os.path.basename(os.path.abspath(".")) == "notebooks":
    repo_root = os.path.abspath("..")
    print(f"Current directory: {os.getcwd()}, changing to {repo_root}")
    # Put the repo root on sys.path so the `src.*` imports in later cells resolve.
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    os.chdir(repo_root)
else:
    print(f"Already in project root: {os.getcwd()}")

In [None]:
from unsloth import FastLanguageModel # Always import unsloth first

In [None]:
import torch
import json
from pathlib import Path
import lightning as L
from torch.utils.data import DataLoader, random_split

from src.rqvae import (
    SemanticRQVAE,
    SemanticRQVAEConfig,
    RQVAETrainer,
    RqvaeTrainConfig,
    WandbArtifactCallback,
    ItemEmbeddingDataset,
    save_model_for_hub,
    upload_to_hub,
    download_model_files,
)

# Report the runtime the notebook is using (library version and GPU visibility).
cuda_ready = torch.cuda.is_available()
print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")

## 2. Configuration

In [None]:
# Paths
CATALOGUE_PATH = "data/mcf_articles.jsonl"  # Input catalogue
EMBEDDINGS_CACHE = "data/embeddings_mcf.pt"  # Cache for embeddings

# Embedding model
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"  # 1024-dim embeddings

# Training configuration - all training-related parameters in one place
train_config = RqvaeTrainConfig(
    # Model architecture
    embedding_dim=1024,       # Match embedding model output
    hidden_dim=256,           # Encoder/decoder hidden dimension
    codebook_size=64,         # Number of codes per quantizer
    num_quantizers=3,         # Number of quantization levels

    # Training hyperparameters
    learning_rate=1e-3,
    max_epochs=500,
    batch_size=512,
    train_split=0.95,

    # W&B artifact logging
    log_wandb_artifacts=False,
    artifact_name="rqvae-model",

    # Output paths
    model_save_path="models/rqvae_model.pt",
    semantic_ids_path="data/semantic_ids.json",
)

# Create output directories.
# They are derived from the configured paths (rather than hard-coded) so that
# pointing any output at a different folder above still yields a writable
# location. With the values above this creates "data/" and "models/".
for output_path in (
    EMBEDDINGS_CACHE,
    train_config.model_save_path,
    train_config.semantic_ids_path,
):
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

print("Configuration:")
print(f"  Catalogue: {CATALOGUE_PATH}")
print(f"  Embedding model: {EMBEDDING_MODEL}")
print(f"  Codebook: {train_config.num_quantizers} levels x {train_config.codebook_size} codes")
print(f"  Output model: {train_config.model_save_path}")
print(f"  Log W&B artifacts: {train_config.log_wandb_artifacts}")

## 3. Verify Environment Variables

In [None]:
# Check for required environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

print("=== Environment Variables ===")
# Only the token lengths are printed so the secrets never end up in the
# saved notebook output.
if HF_TOKEN:
    print(f"✓ HF_TOKEN is set (length: {len(HF_TOKEN)})")
else:
    print("⚠️  HF_TOKEN not set - HuggingFace uploads will fail")
    print("   Set it in RunPod Settings → Environment Variables")

if WANDB_API_KEY:
    print(f"✓ WANDB_API_KEY is set (length: {len(WANDB_API_KEY)})")
else:
    print("⚠️  WANDB_API_KEY not set - Weights & Biases logging disabled")
    print("   Set it in RunPod Settings → Environment Variables")
    print("   Or run: wandb login")

# Verify HF_TOKEN is set (required for hub upload).
# An explicit raise is used instead of `assert`: assertions are removed when
# Python runs with -O / PYTHONOPTIMIZE, which would let the notebook continue
# and only fail at the upload step after the whole training run.
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN must be set as an environment variable")
print("\n✓ All required environment variables are set")

## 4. Initialize Weights & Biases

In [None]:
import wandb
from lightning.pytorch.loggers import WandbLogger

# Initialize wandb using environment variable.
# Without a key there is nothing to set up: leave the logger unset.
if not WANDB_API_KEY:
    wandb_logger = None
    print("⚠️  Wandb logging disabled (WANDB_API_KEY not set)")
else:
    # Hyperparameters recorded on the run, read straight from train_config.
    tracked_fields = (
        "embedding_dim",
        "hidden_dim",
        "codebook_size",
        "num_quantizers",
        "learning_rate",
        "max_epochs",
        "batch_size",
        "train_split",
    )
    run_config = {"embedding_model": EMBEDDING_MODEL}
    run_config.update({field: getattr(train_config, field) for field in tracked_fields})
    run_name = f"rqvae_{train_config.codebook_size}x{train_config.num_quantizers}"

    wandb.init(
        project="semantic-id-recommender",
        name=run_name,
        config=run_config,
    )

    # Create wandb logger for Lightning
    wandb_logger = WandbLogger(project="semantic-id-recommender", log_model=False)

    active_run = wandb.run
    print("✓ Weights & Biases initialized")
    print(f"  Project: semantic-id-recommender")
    print(f"  Run: {active_run.name}")
    print(f"  URL: {active_run.url}")

## 5. Load Catalogue and Generate Embeddings

In [None]:
# Load catalogue (JSONL: one JSON object per line).
# - The encoding is pinned to UTF-8 so parsing does not depend on the
#   machine's locale default.
# - Whitespace-only lines (e.g. a trailing newline at the end of the file)
#   are skipped instead of crashing json.loads.
with open(CATALOGUE_PATH, encoding="utf-8") as f:
    items = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(items)} items from catalogue")
print(f"\nSample items:")
for item in items[:3]:
    print(f"  [{item.get('category', 'N/A')}] {item['item_id']}: {item['title'][:60]}...")

In [None]:
# Generate embeddings (or load from cache).
# The listed catalogue fields are handed to the dataset builder as the text to embed.
text_fields = ["title", "slug", "introduction", "content"]
dataset = ItemEmbeddingDataset.from_catalogue(
    catalogue_path=CATALOGUE_PATH,
    embedding_model=EMBEDDING_MODEL,
    cache_path=EMBEDDINGS_CACHE,
    batch_size=32,
    fields=text_fields,
)
embeddings = dataset.embeddings

embedding_shape = embeddings.shape
print(f"\nDataset info:")
print(f"  Number of items: {len(dataset)}")
print(f"  Embedding dimension: {embedding_shape[1]}")
print(f"  Embedding shape: {embedding_shape}")

## 6. Create and Train RQ-VAE Model

In [None]:
# Create RQ-VAE model using train_config: derive the model config, then wrap
# it in the Lightning training module.
config = train_config.to_model_config()
trainer_module = RQVAETrainer(config=config, learning_rate=train_config.learning_rate)

# Number of distinct code tuples the quantizer stack can express.
id_space = config.codebook_size ** config.num_quantizers
summary_lines = [
    f"RQ-VAE Configuration:",
    f"  Embedding dim: {config.embedding_dim}",
    f"  Hidden dim: {config.hidden_dim}",
    f"  Codebook: {config.num_quantizers} levels x {config.codebook_size} codes",
    f"  Total semantic ID space: {id_space:,} unique IDs",
]
for summary_line in summary_lines:
    print(summary_line)

In [None]:
# Split dataset into train/val
train_size = int(train_config.train_split * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same items land in the validation set on every run;
# an unseeded random_split makes validation metrics incomparable between
# runs and notebook restarts.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=train_config.batch_size, num_workers=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=train_config.batch_size, num_workers=8, shuffle=False)

print(f"Data split:")
print(f"  Train: {len(train_dataset)} items")
print(f"  Validation: {len(val_dataset)} items")

In [None]:
# Build callbacks list: the artifact callback is attached only when artifact
# logging is switched on in train_config.
callbacks = (
    [WandbArtifactCallback(train_config=train_config, embedding_model=EMBEDDING_MODEL)]
    if train_config.log_wandb_artifacts
    else []
)

# Train the model
trainer_options = {
    "max_epochs": train_config.max_epochs,
    "accelerator": "auto",
    "devices": 1,
    "enable_progress_bar": True,
    "log_every_n_steps": 5,
    "logger": wandb_logger,
    "callbacks": callbacks,
}
trainer = L.Trainer(**trainer_options)

print("Starting training...")
trainer.fit(trainer_module, train_loader, val_loader)
print("✓ Training complete!")

## 7. Evaluate Model

In [None]:
# Compute codebook statistics on full dataset
model = trainer_module.model
model.eval()
# Run inference on whichever device the trained weights currently live on.
device = next(model.parameters()).device

with torch.no_grad():
    all_indices = model.get_semantic_ids(embeddings.to(device))
    stats = model.compute_codebook_stats(all_indices)

avg_usage_pct = stats['avg_usage']*100
print(f"\n=== Codebook Statistics (Full Dataset) ===")
print(f"Average perplexity: {stats['avg_perplexity']:.2f} / {config.codebook_size} (max)")
print(f"Average usage: {avg_usage_pct:.1f}%")
print(f"\nPer-level breakdown:")
level_perplexity = stats['perplexity_per_level']
level_usage = stats['usage_per_level']
for level in range(config.num_quantizers):
    perp = level_perplexity[level].item()
    usage = level_usage[level].item() * 100
    print(f"  Level {level}: perplexity={perp:.1f}, usage={usage:.1f}%")

In [None]:
# Check for semantic ID collisions (several items mapped to the same ID string).
total_items = len(embeddings)
unique_ids = len(set(model.semantic_id_to_string(all_indices)))
# Computed once and reused for both the printout and the logged metric;
# guarded so an empty dataset reports 0 instead of dividing by zero.
collision_rate = (1 - unique_ids / total_items) if total_items else 0.0

print(f"\n=== Semantic ID Quality ===")
print(f"Total items: {total_items}")
print(f"Unique semantic IDs: {unique_ids}")
print(f"Collision rate: {collision_rate*100:.1f}%")

if unique_ids < total_items:
    print(f"\n⚠️  Warning: {total_items - unique_ids} collisions detected")
    print("   Consider increasing codebook size or number of quantizers")
else:
    print("\n✓ No collisions - all items have unique semantic IDs")

# Log evaluation metrics to wandb (if initialized)
if WANDB_API_KEY and wandb.run is not None:
    eval_metrics = {
        "eval/avg_perplexity": stats['avg_perplexity'].item(),
        "eval/avg_usage": stats['avg_usage'].item(),
        "eval/unique_semantic_ids": unique_ids,
        "eval/total_items": total_items,
        "eval/collision_rate": collision_rate,
    }

    # Per-level statistics go into the same dict: every wandb.log call
    # advances the run's step, so logging each level separately would spread
    # this single evaluation over several steps.
    for q in range(config.num_quantizers):
        eval_metrics[f"eval/level_{q}_perplexity"] = stats['perplexity_per_level'][q].item()
        eval_metrics[f"eval/level_{q}_usage"] = stats['usage_per_level'][q].item()

    wandb.log(eval_metrics)

    print("\n✓ Evaluation metrics logged to wandb")

## 8. Generate and Save Semantic IDs

In [None]:
# Generate semantic IDs for all items
with torch.no_grad():
    indices = model.get_semantic_ids(embeddings.to(device))
    semantic_strings = model.semantic_id_to_string(indices)

# The mappings pair items[i] with the i-th semantic ID, so both sequences must
# have the same length. NOTE(review): this also assumes the embedding dataset
# keeps the catalogue's line order - confirm against ItemEmbeddingDataset.
if len(items) != len(semantic_strings):
    raise ValueError(
        f"Catalogue has {len(items)} items but {len(semantic_strings)} semantic IDs "
        "were generated; items and embeddings are out of sync"
    )

# One device-to-host transfer for the whole batch instead of one per item.
all_codes = indices.cpu().tolist()

# Create bidirectional mappings
item_to_semantic = {}
semantic_to_item = {}
overwritten = 0  # items displaced from semantic_to_item by a later collision

for item, codes, sem_id in zip(items, all_codes, semantic_strings):
    item_id = item["item_id"]
    item_to_semantic[item_id] = {
        "codes": codes,
        "semantic_id": sem_id,
    }
    if sem_id in semantic_to_item:
        overwritten += 1
    # On a collision the most recent item wins (unchanged behavior); the
    # count above makes that loss visible instead of silent.
    semantic_to_item[sem_id] = item_id

print("\nExample semantic IDs:")
for item, sem_id in zip(items[:5], semantic_strings):
    title = item['title'][:40]
    print(f"  {title:40} -> {sem_id}")

print(f"\n✓ Generated {len(semantic_to_item)} unique semantic IDs")
if overwritten:
    print(f"⚠️  {overwritten} items share a semantic ID with a later item and are "
          "not reachable through semantic_to_item")

In [None]:
# Save semantic ID mappings together with the quantizer settings that produced them.
mapping_config_fields = ("num_quantizers", "codebook_size", "embedding_dim", "hidden_dim")
mapping = {
    "item_to_semantic": item_to_semantic,
    "semantic_to_item": semantic_to_item,
    "config": {field: getattr(config, field) for field in mapping_config_fields},
}

serialized_mapping = json.dumps(mapping, indent=2)
with open(train_config.semantic_ids_path, "w") as f:
    f.write(serialized_mapping)

print(f"✓ Saved semantic ID mappings to {train_config.semantic_ids_path}")

## 9. Save Trained Model

In [None]:
# Save the trained model (artifact logging is handled by WandbArtifactCallback if enabled).
# The checkpoint bundles the weights, the config needed to rebuild the model,
# and a few summary numbers from the evaluation cells.
checkpoint_config_fields = (
    "embedding_dim",
    "hidden_dim",
    "codebook_size",
    "num_quantizers",
    "threshold_ema_dead_code",
)
training_summary = {
    "num_items": len(items),
    "unique_semantic_ids": unique_ids,
    "codebook_usage": float(stats['avg_usage']),
    "perplexity": float(stats['avg_perplexity']),
}
checkpoint = {
    "model_state_dict": model.state_dict(),
    "config": {field: getattr(config, field) for field in checkpoint_config_fields},
    "training_info": training_summary,
}

torch.save(checkpoint, train_config.model_save_path)
saved_bytes = Path(train_config.model_save_path).stat().st_size
print(f"✓ Saved trained model to {train_config.model_save_path}")
print(f"  Model size: {saved_bytes / 1024 / 1024:.2f} MB")

if train_config.log_wandb_artifacts:
    print(f"\n✓ Model artifact was logged to W&B by WandbArtifactCallback")

## 10. Test Model Loading

In [None]:
# Test that we can reload the model.
# map_location places the stored tensors on `device`, so the load does not
# depend on the device the weights were saved from being available.
checkpoint = torch.load(train_config.model_save_path, map_location=device)

# Recreate model from saved config
loaded_config = SemanticRQVAEConfig(**checkpoint["config"])
loaded_model = SemanticRQVAE(loaded_config)
loaded_model.load_state_dict(checkpoint["model_state_dict"])
# A newly constructed module is on the CPU by default. Move it to `device`,
# because the inputs below are moved there too; otherwise the comparison
# fails with a device mismatch whenever `device` is a GPU.
loaded_model.to(device)
loaded_model.eval()

print("✓ Successfully reloaded model from checkpoint")
print(f"\nCheckpoint info:")
print(f"  Config: {checkpoint['config']}")
if "training_info" in checkpoint:
    print(f"  Training info: {checkpoint['training_info']}")

# Verify it produces same outputs
with torch.no_grad():
    test_embedding = embeddings[:5]
    test_batch = test_embedding.to(device)
    original_ids = model.semantic_id_to_string(model.get_semantic_ids(test_batch))
    loaded_ids = loaded_model.semantic_id_to_string(loaded_model.get_semantic_ids(test_batch))

if original_ids == loaded_ids:
    print("\n✓ Loaded model produces identical outputs")
else:
    print("\n⚠️  Warning: Loaded model outputs differ")

## 10a. Load Model from W&B Artifacts (Optional)

If you've logged models as W&B artifacts, you can load them in future sessions.

In [None]:
# Example: Load model from W&B artifact
# Uncomment and modify the following to load a model from W&B.
# The snippet follows the same steps as the local reload test: fetch the
# files, rebuild SemanticRQVAE from the saved config, load the weights.
# NOTE(review): assumes the artifact contains rqvae_model.pt and
# semantic_ids.json - confirm against what WandbArtifactCallback uploads.

# # Initialize wandb (if not already)
# wandb.init(project="semantic-id-recommender", job_type="inference")
# 
# # Download artifact - use "latest", "best", or a specific version like "v0"
# artifact = wandb.use_artifact("rqvae-model:latest")
# artifact_dir = artifact.download()
# 
# # Load the model
# checkpoint = torch.load(f"{artifact_dir}/rqvae_model.pt")
# artifact_config = SemanticRQVAEConfig(**checkpoint["config"])
# artifact_model = SemanticRQVAE(artifact_config)
# artifact_model.load_state_dict(checkpoint["model_state_dict"])
# artifact_model.eval()
# 
# # Load semantic ID mappings
# with open(f"{artifact_dir}/semantic_ids.json") as f:
#     artifact_mapping = json.load(f)
# 
# print(f"Loaded model from artifact")
# print(f"  Config: {checkpoint['config']}")
# print(f"  Items: {len(artifact_mapping['item_to_semantic'])}")
# 
# # Check artifact metadata
# print(f"\nArtifact metadata:")
# for k, v in artifact.metadata.items():
#     print(f"  {k}: {v}")

print("To load a model from W&B artifacts, uncomment the code above.")

## 11. Upload to HuggingFace Hub

In [None]:
# Prepare model files for upload: write the model, the semantic ID mappings
# and a short training summary into one local folder.
local_dir = "models/rqvae_hub"

hub_training_info = {
    "num_items": len(items),
    "unique_semantic_ids": unique_ids,
    "codebook_usage": float(stats['avg_usage']),
    "perplexity": float(stats['avg_perplexity']),
}
save_model_for_hub(
    model=model,
    local_dir=local_dir,
    semantic_ids_path=train_config.semantic_ids_path,
    training_info=hub_training_info,
)

print(f"\n✓ Model prepared for HuggingFace Hub in {local_dir}")

In [None]:
# Upload to HuggingFace Hub
#
# IMPORTANT:
# 1. Create a private repository first at https://huggingface.co/new
# 2. Update REPO_ID below with your username and repo name
# 3. The HF_TOKEN read from the environment earlier is passed as the token

REPO_ID = "your-username/semantic-rqvae"  # Change this to your repo!

# Show what is about to be uploaded, without revealing the token itself.
token_status = '✓ Set' if HF_TOKEN else '✗ Not set'
upload_summary = (
    "Upload Configuration:",
    f"  Repo ID: {REPO_ID}",
    f"  Model directory: {local_dir}",
    f"  Token: {token_status}",
)
for summary_line in upload_summary:
    print(summary_line)
print()

url = upload_to_hub(
    local_dir=local_dir,
    repo_id=REPO_ID,
    token=HF_TOKEN,
    commit_message="Upload RQ-VAE semantic ID model",
)
print(f"\n✓ Model available at: {url}")

In [None]:
# Download model files from HuggingFace Hub
from src.rqvae.hub import download_model_files

# Example: Download model files locally
# Update REPO_ID to match your uploaded model

REPO_ID = "your-username/semantic-rqvae"
LOCAL_DIR = "models/downloaded_from_hub"

# The token comes from the environment (required for private repos).
download_options = {
    "repo_id": REPO_ID,
    "local_dir": LOCAL_DIR,
    "token": HF_TOKEN,
}
download_model_files(**download_options)

print(f"\n✓ Model files downloaded to {LOCAL_DIR}")
print("  You can now load the model from local files:")
print(f"    checkpoint = torch.load('{LOCAL_DIR}/rqvae_model.pt')")

# Test the downloaded model
print("\n=== Testing Downloaded Model ===")

# Load checkpoint from downloaded files.
# map_location places the stored tensors on `device`, so the load does not
# depend on the device the weights were saved from being available.
downloaded_checkpoint = torch.load(f"{LOCAL_DIR}/rqvae_model.pt", map_location=device)

# Recreate model from downloaded config
downloaded_config = SemanticRQVAEConfig(**downloaded_checkpoint["config"])
downloaded_model = SemanticRQVAE(downloaded_config)
downloaded_model.load_state_dict(downloaded_checkpoint["model_state_dict"])
# A newly constructed module is on the CPU by default. Move it to `device`,
# because the comparison that follows feeds it inputs moved to `device`.
downloaded_model.to(device)
downloaded_model.eval()

print("✓ Successfully loaded model from downloaded files")
print(f"\nDownloaded model config:")
print(f"  Embedding dim: {downloaded_config.embedding_dim}")
print(f"  Hidden dim: {downloaded_config.hidden_dim}")
print(f"  Codebook: {downloaded_config.num_quantizers} levels x {downloaded_config.codebook_size} codes")

if "training_info" in downloaded_checkpoint:
    print(f"\nTraining info:")
    for key, value in downloaded_checkpoint["training_info"].items():
        print(f"  {key}: {value}")

# Verify downloaded model produces same outputs as original:
# run both models on the same handful of embeddings and compare the ID strings.
test_embedding = embeddings[:5]
with torch.no_grad():
    probe_batch = test_embedding.to(device)
    original_ids = model.semantic_id_to_string(model.get_semantic_ids(probe_batch))
    downloaded_ids = downloaded_model.semantic_id_to_string(downloaded_model.get_semantic_ids(probe_batch))

print(f"\nComparing outputs on {len(test_embedding)} test embeddings:")
id_pairs = list(zip(original_ids, downloaded_ids))
if original_ids != downloaded_ids:
    print("⚠️  Warning: Downloaded model outputs differ from original")
    print("\nOriginal vs Downloaded:")
    for idx, (orig, down) in enumerate(id_pairs):
        marker = "✓" if orig == down else "✗"
        print(f"  Item {idx}: {orig} vs {down} {marker}")
else:
    print("✓ Downloaded model produces identical outputs to original")
    print("\nExample outputs:")
    for idx, (orig, down) in enumerate(id_pairs):
        print(f"  Item {idx}: {orig} (match: {orig == down})")