# Semantic ID Recommender - Training Test

This notebook tests the full training pipeline in Google Colab.

**Requirements:**
- Google Colab with GPU runtime (T4 is sufficient)
- ~10 minutes for full test

**What this tests:**
1. RQ-VAE training (semantic ID learning)
2. Training data generation
3. LLM fine-tuning with Unsloth
4. Inference test

## Setup

In [None]:
!nvidia-smi

In [None]:
# Install uv
!pip install uv

# Clone private repo (skip if already exists)
import os
if not os.path.exists("semantic_id_recommender"):
    !git clone https://github.com/charleslow/semantic_id_recommender.git
else:
    print("✓ Repo already exists")
    
%cd semantic_id_recommender
!git pull --ff-only

In [None]:
# Install dependencies (let uv resolve versions compatible with Colab's environment)
!uv pip install --system \
    torch torchvision torchaudio \
    transformers datasets accelerate \
    vector-quantize-pytorch sentence-transformers \
    lightning wandb omegaconf einops tqdm rich \
    pydantic unsloth psutil

# Fix protobuf version conflict
!uv pip install --system "protobuf>=3.20,<5"

# Note: We skip using uv.lock because Colab has pre-installed packages
# that may conflict with locked versions

## 1. Import from Package

In [None]:
from unsloth import FastLanguageModel

In [None]:
# Add repo root to Python path for imports
import sys
sys.path.insert(0, ".")

# Import from the cloned package
import torch
from src.rqvae.model import SemanticRQVAE, SemanticRQVAEConfig
from src.rqvae.trainer import RQVAETrainer
from src.rqvae.dataset import ItemEmbeddingDataset

print("✓ Package imports successful")

## 2. Inspect Dataset

In [None]:
import json

items = []
with open("data/ag_news_500.jsonl") as f:
    for line in f:
        items.append(json.loads(line))

print(f"Loaded {len(items)} items")
print(f"\nKeys: {items[0].keys()}")
print(f"\nSample items:")
for item in items[:5]:
    print(f"  [{item['category']}] {item['item_id']}: {item['title'][:60]}...")

## 3. Generate Embeddings

In [None]:
# Use the package's ItemEmbeddingDataset to generate embeddings
# Using Qwen3-Embedding-0.6B, a Qwen-based embedding model (1024-dim output)
dataset = ItemEmbeddingDataset.from_catalogue(
    catalogue_path="data/ag_news_500.jsonl",
    embedding_model="Qwen/Qwen3-Embedding-0.6B",
    cache_path="data/embeddings_qwen3_0.6b.pt",  # Cache file specific to this embedding model
)

print(f"\nDataset info:")
print(f"  Number of items: {len(dataset)}")
print(f"  Embedding dim: {dataset.embeddings.shape[1]}")
print(f"  Item IDs (first 5): {dataset.item_ids[:5]}")

# Keep embeddings tensor for later use
embeddings = dataset.embeddings

In [None]:
embeddings.shape

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Category distribution
categories = [item['category'] for item in items]
cat_counts = Counter(categories)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Bar chart of categories
ax1 = axes[0]
ax1.bar(cat_counts.keys(), cat_counts.values(), color=['#2ecc71', '#3498db', '#e74c3c', '#9b59b6'])
ax1.set_xlabel('Category')
ax1.set_ylabel('Count')
ax1.set_title('Distribution by Category')
for i, (cat, count) in enumerate(cat_counts.items()):
    ax1.text(i, count + 2, str(count), ha='center')

# Content length distribution
content_lengths = [len(item['content']) for item in items]
ax2 = axes[1]
ax2.hist(content_lengths, bins=30, color='#3498db', edgecolor='white')
ax2.set_xlabel('Content Length (chars)')
ax2.set_ylabel('Count')
ax2.set_title('Distribution of Article Lengths')
ax2.axvline(sum(content_lengths)/len(content_lengths), color='red', linestyle='--', label=f'Mean: {sum(content_lengths)//len(content_lengths)}')
ax2.legend()

plt.tight_layout()
plt.show()

print(f"\nSummary:")
print(f"  Total articles: {len(items)}")
print(f"  Categories: {dict(cat_counts)}")
print(f"  Avg content length: {sum(content_lengths)//len(content_lengths)} chars")

## 4. Train RQ-VAE

In [None]:

from lightning.pytorch.loggers import WandbLogger
import wandb

# Initialize wandb for the entire experiment
wandb.init(
    project="semantic-id-recommender",
    name="full-pipeline-test",
    config={
        "embedding_model": "Qwen/Qwen3-Embedding-0.6B",
        "num_items": len(items),
    }
)

In [None]:
import lightning as L
from torch.utils.data import DataLoader, random_split

config = SemanticRQVAEConfig(
    embedding_dim=1024,  # Match embedding model output dimension
    hidden_dim=256,      # Larger hidden dim for bigger embeddings
    codebook_size=32,
    num_quantizers=3,
    threshold_ema_dead_code=1,
)

# Log RQ-VAE config to wandb
wandb.config.update({
    "rqvae_embedding_dim": config.embedding_dim,
    "rqvae_hidden_dim": config.hidden_dim,
    "rqvae_codebook_size": config.codebook_size,
    "rqvae_num_quantizers": config.num_quantizers,
})

trainer_module = RQVAETrainer(config=config, learning_rate=1e-3)

# Split dataset into train/val (90/10)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

# Create WandbLogger for Lightning
wandb_logger = WandbLogger(experiment=wandb.run)

trainer = L.Trainer(
    max_epochs=50,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True,
    log_every_n_steps=1,
    logger=wandb_logger,  # Log to wandb
)
trainer.fit(trainer_module, train_loader, val_loader)
print("✓ RQ-VAE training complete")
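
For intuition about the discrete codes trained above, here is a minimal NumPy sketch of plain residual quantization: each level picks the codebook entry nearest to the current residual and subtracts it. It assumes fixed random codebooks and leaves out the encoder, decoder, and training that `SemanticRQVAE` presumably handles. `residual_quantize` is a hypothetical helper, not part of the package.

```python
import numpy as np

def residual_quantize(x, codebooks):
    """Greedy residual quantization of one vector.

    x: (dim,) array. codebooks: list of (codebook_size, dim) arrays, one per level.
    Returns the chosen code index per level and the reconstruction.
    """
    residual = x.astype(np.float64)
    reconstruction = np.zeros_like(residual)
    codes = []
    for codebook in codebooks:
        # Pick the entry closest to what is still unexplained.
        distances = np.linalg.norm(codebook - residual, axis=1)
        idx = int(distances.argmin())
        codes.append(idx)
        reconstruction += codebook[idx]
        residual = residual - codebook[idx]
    return codes, reconstruction

rng = np.random.default_rng(0)
# Illustrative sizes only: 3 levels of 32 codes (as in the config above), in 8 dims.
codebooks = [rng.normal(scale=1.0 / (level + 1), size=(32, 8)) for level in range(3)]
x = rng.normal(size=8)
codes, x_hat = residual_quantize(x, codebooks)
print(codes, float(np.linalg.norm(x - x_hat)))
```

The list `codes` plays the role of a semantic ID: one index per quantizer level, with earlier levels capturing coarser structure.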

In [None]:
# Check training diagnostics
print("=== Training Diagnostics ===\n")

# Get logged metrics from trainer
logged_metrics = trainer.logged_metrics
print("Final logged metrics:")
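for key, value in sorted(logged_metrics.items()):
    # Lightning stores logged metrics as 0-dim tensors, so convert before formatting
    if isinstance(value, (float, torch.Tensor)):
        print(f"  {key}: {float(value):.4f}")
    else:
        print(f"  {key}: {value}")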

# Compute codebook stats on full dataset
model = trainer_module.model
model.eval()
device = next(model.parameters()).device

with torch.no_grad():
    all_indices = model.get_semantic_ids(embeddings.to(device))
    stats = model.compute_codebook_stats(all_indices)

print(f"\n=== Codebook Statistics (full dataset) ===")
print(f"Average perplexity: {stats['avg_perplexity']:.2f} / {config.codebook_size} (max)")
print(f"Average usage: {stats['avg_usage']*100:.1f}%")
print(f"\nPer-level breakdown:")
for q in range(config.num_quantizers):
    perp = stats['perplexity_per_level'][q].item()
    usage = stats['usage_per_level'][q].item() * 100
    print(f"  Level {q}: perplexity={perp:.1f}, usage={usage:.1f}%")

# Check for semantic ID collisions
unique_ids = len(set(model.semantic_id_to_string(all_indices)))
print(f"\n=== Semantic ID Quality ===")
print(f"Total items: {len(embeddings)}")
print(f"Unique semantic IDs: {unique_ids}")
print(f"Collision rate: {(1 - unique_ids/len(embeddings))*100:.1f}%")
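
The perplexity and usage figures above can be read as follows, under a common definition: perplexity is the exponential of the entropy of the code usage distribution, and usage is the fraction of codebook entries assigned at least once. The sketch below assumes that definition, which may differ in detail from the package's `compute_codebook_stats`. `codebook_stats` is a hypothetical helper.

```python
import math
from collections import Counter

def codebook_stats(codes_per_item, codebook_size):
    """Per-level perplexity and usage from a list of code tuples."""
    num_levels = len(codes_per_item[0])
    stats = []
    for level in range(num_levels):
        counts = Counter(codes[level] for codes in codes_per_item)
        total = sum(counts.values())
        # Shannon entropy (in nats) of the empirical code distribution.
        entropy = -sum((n / total) * math.log(n / total) for n in counts.values())
        stats.append({
            "perplexity": math.exp(entropy),
            "usage": len(counts) / codebook_size,
        })
    return stats

# Toy data: level 0 uses all four codes evenly, level 1 always uses code 1.
toy_codes = [(0, 1), (1, 1), (2, 1), (3, 1)]
for level, s in enumerate(codebook_stats(toy_codes, codebook_size=4)):
    print(level, round(s["perplexity"], 2), s["usage"])
```

A perplexity close to the codebook size means the codes are used evenly. A value near 1 means the level has collapsed onto a single code.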

## 5. Generate Semantic IDs

In [None]:
# Generate semantic IDs for all items using the trained model
model = trainer_module.model
model.eval()

device = next(model.parameters()).device
with torch.no_grad():
    indices = model.get_semantic_ids(embeddings.to(device))
    semantic_strings = model.semantic_id_to_string(indices)

# Create mapping
item_to_semantic = {}
semantic_to_item = {}

for i, item in enumerate(items):
    item_id = item["item_id"]  # AG News uses item_id
    sem_id = semantic_strings[i]
    item_to_semantic[item_id] = {
        "codes": indices[i].cpu().tolist(),
        "semantic_id": sem_id,
    }
    semantic_to_item[sem_id] = item_id

# Save mapping
mapping = {
    "item_to_semantic": item_to_semantic,
    "semantic_to_item": semantic_to_item,
    "config": {"num_quantizers": config.num_quantizers, "codebook_size": config.codebook_size}
}

with open("data/semantic_ids.json", "w") as f:
    json.dump(mapping, f, indent=2)

# Show examples
print("\nExample semantic IDs:")
for i in range(5):
    print(f"  {items[i]['title'][:30]:30} -> {semantic_strings[i]}")

print(f"\n✓ Generated {len(semantic_to_item)} unique semantic IDs")
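
Note that `semantic_to_item` keeps only the last item written for each semantic ID, so colliding items silently overwrite each other. A minimal sketch for listing collisions, assuming the same `item_to_semantic` layout as above (`group_by_semantic_id` is a hypothetical helper, and the ID strings are toy placeholders):

```python
from collections import defaultdict

def group_by_semantic_id(item_to_semantic):
    """Map each semantic ID string to the list of item IDs that share it."""
    groups = defaultdict(list)
    for item_id, entry in item_to_semantic.items():
        groups[entry["semantic_id"]].append(item_id)
    return dict(groups)

# Toy mapping: items "a" and "c" collide on the same ID.
toy_mapping = {
    "a": {"codes": [1, 2, 3], "semantic_id": "id-1"},
    "b": {"codes": [4, 5, 6], "semantic_id": "id-2"},
    "c": {"codes": [1, 2, 3], "semantic_id": "id-1"},
}
groups = group_by_semantic_id(toy_mapping)
collisions = {sid: ids for sid, ids in groups.items() if len(ids) > 1}
print(collisions)
```

Running this on the real `item_to_semantic` shows which items would be indistinguishable to the fine-tuned LLM.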

In [None]:
# Examine embedding similarity vs semantic ID similarity
import torch.nn.functional as F

# Compute pairwise cosine similarities
normed = F.normalize(embeddings, dim=1)
similarities = normed @ normed.T

# Mask out diagonal and near-duplicates (sim > 0.99)
similarities_filtered = similarities.clone()
similarities_filtered.fill_diagonal_(-1)
mask = similarities_filtered > 0.99
similarities_filtered[mask] = -1  # Exclude near-duplicates

# Find most similar pair (excluding duplicates)
max_idx = similarities_filtered.argmax().item()
i_close, j_close = divmod(max_idx, len(items))  # Flat index -> (row, col)

# Find most dissimilar pair
similarities_for_min = similarities.clone()
similarities_for_min.fill_diagonal_(2)  # Exclude self
min_idx = similarities_for_min.argmin().item()
i_far, j_far = divmod(min_idx, len(items))  # Flat index -> (row, col)

print("=" * 70)
print("CLOSEST PAIR (high embedding similarity, excluding duplicates)")
print("=" * 70)
print(f"Cosine similarity: {similarities[i_close, j_close]:.4f}")
print(f"\nItem A [{items[i_close]['category']}]:")
print(f"  Title: {items[i_close]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[i_close]}")
print(f"\nItem B [{items[j_close]['category']}]:")
print(f"  Title: {items[j_close]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[j_close]}")

# Check how many codes match
codes_a = indices[i_close].tolist()
codes_b = indices[j_close].tolist()
matching = sum(a == b for a, b in zip(codes_a, codes_b))
print(f"\nCodes A: {codes_a}")
print(f"Codes B: {codes_b}")
print(f"Matching codes: {matching}/{len(codes_a)} ✓" if matching > 0 else f"Matching codes: {matching}/{len(codes_a)}")

print("\n" + "=" * 70)
print("FARTHEST PAIR (low embedding similarity)")
print("=" * 70)
print(f"Cosine similarity: {similarities[i_far, j_far]:.4f}")
print(f"\nItem A [{items[i_far]['category']}]:")
print(f"  Title: {items[i_far]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[i_far]}")
print(f"\nItem B [{items[j_far]['category']}]:")
print(f"  Title: {items[j_far]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[j_far]}")

codes_a = indices[i_far].tolist()
codes_b = indices[j_far].tolist()
matching = sum(a == b for a, b in zip(codes_a, codes_b))
print(f"\nCodes A: {codes_a}")
print(f"Codes B: {codes_b}")
print(f"Matching codes: {matching}/{len(codes_a)}")

# Check for duplicates in dataset
dup_count = (similarities > 0.999).sum().item() - len(items)  # Subtract diagonal
print(f"\n⚠️  Found {dup_count//2} duplicate pairs in dataset (cosine sim > 0.999)")
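
The comparison above counts position-wise matches. Because residual codes go from coarse to fine, the length of the shared prefix may be a more telling measure of how close two semantic IDs are. A small sketch (`shared_prefix_length` is a hypothetical helper):

```python
def shared_prefix_length(codes_a, codes_b):
    """Number of leading levels on which two code sequences agree."""
    length = 0
    for a, b in zip(codes_a, codes_b):
        if a != b:
            break
        length += 1
    return length

print(shared_prefix_length([4, 17, 2], [4, 17, 9]))  # 2
print(shared_prefix_length([4, 17, 2], [5, 17, 2]))  # 0
```

The second pair matches on two positions but shares no prefix, so the two items already diverge at the coarsest level.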

## 6. Generate Training Data

In [None]:
# Generate training data using the package utilities
from src.llm.data import format_for_chat, generate_training_examples

# Create items with semantic IDs for training
items_with_sem_ids = []
for item in items:
    item_copy = item.copy()
    item_copy["semantic_id"] = item_to_semantic[item["item_id"]]["semantic_id"]
    items_with_sem_ids.append(item_copy)

# Generate training examples
raw_examples = generate_training_examples(
    items_with_sem_ids, 
    num_examples_per_item=10,
    query_templates=[
        # Direct title queries
        "{title}",
        "Find me news about {title}",
        "I'm looking for articles on {title}",
        "Show me {title}",
        "Search for {title}",
        # Category-based queries
        "News about {category}: {title}",
        "{category} news: {title}",
        "Latest {category} article about {title}",
        # Question formats
        "What's the news about {title}?",
        "Any articles on {title}?",
        "Do you have news about {title}?",
        # Recommendation style
        "Recommend articles about {title}",
        "Suggest news on {title}",
        # Short/casual queries
        "news {title}",
        "article {title}",
        "{category} {title}",
    ],
)

# Format for chat and split
import random
formatted_examples = format_for_chat(raw_examples)
random.shuffle(formatted_examples)
split_idx = int(len(formatted_examples) * 0.9)
train_examples = formatted_examples[:split_idx]
val_examples = formatted_examples[split_idx:]

print(f"✓ Generated {len(train_examples)} train, {len(val_examples)} val examples")
print(f"\nExample:")
print(f"  Query: {train_examples[0]['messages'][1]['content'][:50]}...")
print(f"  Target: {train_examples[0]['messages'][2]['content']}")
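
To make the example format concrete, the sketch below fills one query template and wraps it in a three-message chat record (system, user, assistant). That layout is consistent with the `messages[1]` and `messages[2]` lookups above. This is an assumption about the output shape, not the package's actual implementation. `make_chat_example`, the system prompt, and the semantic ID string are all hypothetical.

```python
def make_chat_example(item, template, system_prompt):
    """Fill one query template and pair it with the item's semantic ID."""
    query = template.format(title=item["title"], category=item["category"])
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
            {"role": "assistant", "content": item["semantic_id"]},
        ]
    }

# Toy item; the semantic ID is a placeholder, not the package's real string format.
item = {
    "title": "Oil prices rise",
    "category": "Business",
    "semantic_id": "SEMANTIC_ID_PLACEHOLDER",
}
example = make_chat_example(
    item,
    template="{category} news: {title}",
    system_prompt="Return the semantic ID of the best matching item.",  # hypothetical prompt
)
print(example["messages"][1]["content"])
print(example["messages"][2]["content"])
```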

## 7. Fine-tune LLM with Unsloth

In [None]:
# Update wandb config for LLM training (wandb already initialized)
wandb.config.update({
    "base_model": "unsloth/SmolLM2-135M-Instruct",
    "num_train_examples": len(train_examples),
})

In [None]:
from unsloth import FastLanguageModel
from datasets import Dataset

# Use SmolLM2, a very small model (135M parameters)
BASE_MODEL = "unsloth/SmolLM2-135M-Instruct"
MAX_SEQ_LENGTH = 256

print(f"Loading {BASE_MODEL}...")
llm_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=True,
)

print("✓ Model loaded")

In [None]:
# Add semantic ID tokens using package utility
from src.llm.finetune import add_semantic_tokens

tokenizer = add_semantic_tokens(tokenizer, config.num_quantizers, config.codebook_size)
llm_model.resize_token_embeddings(len(tokenizer))

print(f"✓ Added {config.num_quantizers * config.codebook_size} semantic tokens")

In [None]:
# Add LoRA adapters
llm_model = FastLanguageModel.get_peft_model(
    llm_model,
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)

print("✓ LoRA adapters added")
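
As a rough guide to what `r=32` and `lora_alpha=32` imply: standard LoRA adds two low-rank matrices per adapted linear layer and scales their product by `lora_alpha / r`. The sketch below does that arithmetic for a single layer. The 512 x 512 layer size is purely illustrative and is not taken from the model's real configuration.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one (d_out x d_in) linear layer.

    A has shape (rank, d_in) and B has shape (d_out, rank).
    """
    return rank * (d_in + d_out)

def lora_scaling(lora_alpha, rank):
    """Factor applied to the low-rank update in standard LoRA."""
    return lora_alpha / rank

d_in, d_out = 512, 512  # illustrative layer size, not the real model dims
added = lora_param_count(d_in, d_out, rank=32)
print(added, d_in * d_out, added / (d_in * d_out))
print(lora_scaling(lora_alpha=32, rank=32))
```

With equal `r` and `lora_alpha`, the update is applied with a scaling of 1.0.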

In [None]:
# Prepare datasets with a local formatting helper (plain role markers, not the tokenizer's chat template)
def format_single_example(messages):
    """Format messages into prompt string."""
    text = ""
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role == "system":
            text += f"<|system|>\n{content}\n"
        elif role == "user":
            text += f"<|user|>\n{content}\n"
        elif role == "assistant":
            text += f"<|assistant|>\n{content}\n"
    return text

train_texts = [{"text": format_single_example(ex["messages"])} for ex in train_examples]
val_texts = [{"text": format_single_example(ex["messages"])} for ex in val_examples]

train_dataset = Dataset.from_list(train_texts)
val_dataset = Dataset.from_list(val_texts)

print(f"✓ Dataset prepared: {len(train_dataset)} train, {len(val_dataset)} val")
print(f"Sample: {train_texts[0]['text'][:100]}...")

In [None]:
# Fix: Unsloth's compiled cache needs psutil in builtins
import psutil
import builtins
builtins.psutil = psutil

from trl import SFTTrainer
from transformers import TrainingArguments

# Training arguments
training_args = TrainingArguments(
    output_dir="checkpoints/llm",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    logging_steps=10,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_strategy="no",
    eval_strategy="steps",      # Enable evaluation
    eval_steps=50,              # Evaluate every 50 steps
    fp16=True,
    optim="adamw_8bit",
    report_to="wandb",
)

trainer = SFTTrainer(
    model=llm_model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,   # Add validation dataset
    args=training_args,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
)

print("✓ Trainer configured")
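
The batch settings above combine into an effective batch size of `per_device_train_batch_size * gradient_accumulation_steps`. The sketch below estimates the number of optimizer steps, assuming 4,500 training examples (500 items x 10 examples each, with a 90% split). The true count depends on what `generate_training_examples` returns, and the trainer's own rounding may differ by a step or so per epoch.

```python
import math

def optimizer_steps(num_examples, per_device_batch, grad_accum, epochs, num_devices=1):
    """Approximate effective batch size and total optimizer steps."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs

effective_batch, total_steps = optimizer_steps(
    num_examples=4500,  # assumed, see the note above
    per_device_batch=4,
    grad_accum=2,
    epochs=5,
)
print(effective_batch, total_steps)
```

This gives a sense of how often `eval_steps=50` triggers an evaluation over the whole run.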

In [None]:
# Train!
print("Starting training...")
trainer.train()
print("✓ Training complete!")

## 8. Test Inference

In [None]:
FastLanguageModel.for_inference(llm_model)
from src.llm.finetune import SemanticIDGenerator

generator = SemanticIDGenerator(
    model=llm_model,
    tokenizer=tokenizer,
    num_quantizers=config.num_quantizers,
    codebook_size=config.codebook_size,
    use_constrained_decoding=True,
)

print("✓ Inference mode ready with constrained decoding")
print(f"  Semantic tokens: {config.num_quantizers * config.codebook_size} total")
print(f"  Output format: {config.num_quantizers} tokens (one per quantizer level)")
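
One way constrained decoding can work is to mask the logits at each step so that only the semantic tokens for the current quantizer level remain eligible. The sketch below illustrates that idea in plain Python. It assumes the semantic tokens occupy a contiguous ID range ordered by level, which may not match how `SemanticIDGenerator` is actually implemented. Both helpers are hypothetical.

```python
import math

def build_allowed_token_ids(first_semantic_id, num_quantizers, codebook_size):
    """Allowed token IDs per level, assuming a contiguous, level-ordered layout."""
    return [
        set(range(first_semantic_id + level * codebook_size,
                  first_semantic_id + (level + 1) * codebook_size))
        for level in range(num_quantizers)
    ]

def constrained_argmax(logits, allowed_ids):
    """Pick the best-scoring token among the allowed IDs only."""
    masked = [score if token_id in allowed_ids else -math.inf
              for token_id, score in enumerate(logits)]
    return max(range(len(masked)), key=masked.__getitem__)

# Toy vocabulary: 10 regular tokens followed by 2 levels x 3 semantic tokens.
allowed = build_allowed_token_ids(first_semantic_id=10, num_quantizers=2, codebook_size=3)
logits = [5.0] * 10 + [0.1, 0.9, 0.3, 0.2, 0.4, 0.8]
print(constrained_argmax(logits, allowed[0]))  # best level-0 semantic token
print(constrained_argmax(logits, allowed[1]))  # best level-1 semantic token
```

Even though the regular tokens score highest in the toy logits, masking ensures the output is always a well-formed sequence of semantic tokens, one per level.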

In [None]:
# Test queries with beam search - show top 10 candidates and valid matches
test_queries = [
    "News about stock market and business",
    "Sports football game results",
    "Technology and science discoveries",
    "World news international politics",
]

NUM_BEAMS = 10

print("\n" + "="*70)
print("INFERENCE TESTS (Beam Search)")
print("="*70)

results_table = []
for query in test_queries:
    print(f"\n{'─'*70}")
    print(f"Query: {query}")
    print(f"{'─'*70}")
    
    # Get top candidates via beam search
    candidates = generator.generate_beam(query, num_beams=NUM_BEAMS, num_return_sequences=NUM_BEAMS)
    
    print(f"\nTop {len(candidates)} candidates:")
    valid_count = 0
    for rank, (sem_id, score) in enumerate(candidates, 1):
        if sem_id in semantic_to_item:
            item_id = semantic_to_item[sem_id]
            item = next(i for i in items if i["item_id"] == item_id)
            print(f"  {rank}. ✓ {sem_id} (score: {score:.2f})")
            print(f"     [{item['category']}] {item['title'][:55]}...")
            valid_count += 1
            results_table.append([query, rank, sem_id, score, item['category'], item['title'][:50], True])
        else:
            print(f"  {rank}. ✗ {sem_id} (score: {score:.2f}) - not in catalogue")
            results_table.append([query, rank, sem_id, score, "", "", False])
    
    print(f"\n  Valid matches: {valid_count}/{len(candidates)}")

# Log results to wandb
wandb.log({
    "inference_results": wandb.Table(
        columns=["query", "rank", "semantic_id", "score", "category", "title", "valid"],
        data=results_table
    )
})

print("\n" + "="*70)
print("✓ All tests complete!")
print("="*70)

# Finish wandb run
wandb.finish()
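
To summarize the run in one number per query, the rows collected in `results_table` can be reduced to the share of beam candidates that map to a real catalogue item. A minimal sketch, assuming the row layout used above (`valid_rate_by_query` is a hypothetical helper, and the rows here are toy data):

```python
from collections import defaultdict

def valid_rate_by_query(rows):
    """Share of valid candidates per query.

    Each row follows the layout
    [query, rank, semantic_id, score, category, title, valid].
    """
    totals = defaultdict(int)
    valid = defaultdict(int)
    for row in rows:
        query, is_valid = row[0], row[-1]
        totals[query] += 1
        valid[query] += int(is_valid)
    return {query: valid[query] / totals[query] for query in totals}

# Toy rows with placeholder IDs and scores.
toy_rows = [
    ["sports", 1, "id-1", -0.2, "Sports", "Some title", True],
    ["sports", 2, "id-9", -0.7, "", "", False],
    ["business", 1, "id-2", -0.1, "Business", "Another title", True],
]
print(valid_rate_by_query(toy_rows))
```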

## Summary

This notebook tested:

1. **RQ-VAE Training** - ✓ Model learns to encode items as discrete codes
2. **Semantic ID Generation** - ✓ Each item is assigned a semantic ID (uniqueness depends on the collision rate reported in section 4)
3. **Training Data Creation** - ✓ Query → Semantic ID pairs generated
4. **LLM Fine-tuning** - ✓ Model learns to predict semantic IDs
5. **Inference** - ✓ Model generates semantic IDs for new queries

### Next Steps

For production use:
- Use larger model (Qwen3-4B or Ministral-3B)
- Re-tune the number of epochs (3-5) for the larger model and catalogue
- Use full catalogue
- Deploy to Modal for serverless inference