# Semantic ID Recommender - Training Test

This notebook tests the full training pipeline in Google Colab.

**Requirements:**
- Google Colab with GPU runtime (T4 is sufficient)
- ~10 minutes for full test

**What this tests:**
1. RQ-VAE training (semantic ID learning)
2. Training data generation
3. LLM fine-tuning with Unsloth
4. Inference test

## Setup

In [None]:
!nvidia-smi

In [None]:
# Install uv
!pip install uv

# Clone private repo (skip if already exists)
import os
if not os.path.exists("semantic_id_recommender"):
    !git clone https://github.com/charleslow/semantic_id_recommender.git
else:
    print("✓ Repo already exists")
    
%cd semantic_id_recommender
!git pull --ff-only

In [None]:
# Install dependencies (let uv resolve versions compatible with Colab's environment)
# unsloth and trl are required by the LLM fine-tuning cells in section 7
!uv pip install --system \
    torch torchvision torchaudio \
    transformers datasets accelerate \
    vector-quantize-pytorch sentence-transformers \
    lightning wandb omegaconf einops tqdm rich \
    pydantic unsloth trl

# Fix protobuf version conflict
!uv pip install --system "protobuf>=3.20,<5"

# Note: We skip using uv.lock because Colab has pre-installed packages
# that may conflict with locked versions

## 1. Import from Package

In [None]:
# Add repo root to Python path for imports
import sys
sys.path.insert(0, ".")

# Import from the cloned package
import torch
from src.rqvae.model import SemanticRQVAE, SemanticRQVAEConfig
from src.rqvae.trainer import RQVAETrainer
from src.rqvae.dataset import ItemEmbeddingDataset

print("✓ Package imports successful")

## 2. Inspect Dataset

In [None]:
import json

items = []
with open("data/ag_news_500.jsonl") as f:
    for line in f:
        items.append(json.loads(line))

print(f"Loaded {len(items)} items")
print(f"\nKeys: {items[0].keys()}")
print(f"\nSample items:")
for item in items[:5]:
    print(f"  [{item['category']}] {item['item_id']}: {item['title'][:60]}...")

## 3. Generate Embeddings

In [None]:
# Use the package's ItemEmbeddingDataset to generate embeddings
dataset = ItemEmbeddingDataset.from_catalogue(
    catalogue_path="data/ag_news_500.jsonl",
    cache_path="data/embeddings.pt",  # Cache for faster reruns
)

print(f"\nDataset info:")
print(f"  Number of items: {len(dataset)}")
print(f"  Embedding dim: {dataset.embeddings.shape[1]}")
print(f"  Item IDs (first 5): {dataset.item_ids[:5]}")

# Keep embeddings tensor for later use
embeddings = dataset.embeddings

In [None]:
embeddings.shape

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Category distribution
categories = [item['category'] for item in items]
cat_counts = Counter(categories)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Bar chart of categories
ax1 = axes[0]
ax1.bar(cat_counts.keys(), cat_counts.values(), color=['#2ecc71', '#3498db', '#e74c3c', '#9b59b6'])
ax1.set_xlabel('Category')
ax1.set_ylabel('Count')
ax1.set_title('Distribution by Category')
for i, (cat, count) in enumerate(cat_counts.items()):
    ax1.text(i, count + 2, str(count), ha='center')

# Content length distribution
content_lengths = [len(item['content']) for item in items]
ax2 = axes[1]
ax2.hist(content_lengths, bins=30, color='#3498db', edgecolor='white')
ax2.set_xlabel('Content Length (chars)')
ax2.set_ylabel('Count')
ax2.set_title('Distribution of Article Lengths')
ax2.axvline(sum(content_lengths)/len(content_lengths), color='red', linestyle='--', label=f'Mean: {sum(content_lengths)//len(content_lengths)}')
ax2.legend()

plt.tight_layout()
plt.show()

print(f"\nSummary:")
print(f"  Total articles: {len(items)}")
print(f"  Categories: {dict(cat_counts)}")
print(f"  Avg content length: {sum(content_lengths)//len(content_lengths)} chars")

## 4. Train RQ-VAE

In [None]:
import lightning as L
from torch.utils.data import DataLoader, random_split

config = SemanticRQVAEConfig(
    embedding_dim=embeddings.shape[1],  # Must match the embedding dimension printed above
    hidden_dim=128,
    codebook_size=16,
    num_quantizers=4,
    threshold_ema_dead_code=1,
)
trainer_module = RQVAETrainer(config=config, learning_rate=1e-3)

# Split dataset into train/val (90/10)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True,
    log_every_n_steps=1,
)
trainer.fit(trainer_module, train_loader, val_loader)
print("✓ RQ-VAE training complete")
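
The RQ-VAE assigns each item one code per quantizer level by quantizing a residual: level 0 picks the codeword nearest to the encoded vector, level 1 quantizes what is left over, and so on. The sketch below shows only that greedy residual-quantization step, in plain Python with tiny hand-made codebooks. It omits the encoder, decoder and codebook learning, and it is not the `vector-quantize-pytorch` code the package presumably wraps.

```python
from typing import List, Sequence

Vector = List[float]

def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def residual_quantize(x: Sequence[float], codebooks: Sequence[Sequence[Vector]]) -> List[int]:
    """Return one code index per level by greedily quantizing the running residual."""
    residual = list(x)
    codes = []
    for codebook in codebooks:
        # Pick the codeword closest to what is still unexplained
        idx = min(range(len(codebook)), key=lambda k: sq_dist(residual, codebook[k]))
        codes.append(idx)
        # Subtract it; the next level only has to encode the remainder
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return codes

# Two levels on 2-D vectors: level 0 is coarse, level 1 refines
codebooks = [
    [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]],
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]],
]
print(residual_quantize([10.9, 0.2], codebooks))  # [1, 1]
print(residual_quantize([-0.8, 9.9], codebooks))  # [2, 3]
```

Because later levels only refine earlier ones, items that share leading codes tend to be close in embedding space, which is why the order of codes within a semantic ID matters.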

In [None]:
# Check training diagnostics
print("=== Training Diagnostics ===\n")

# Get logged metrics from trainer
logged_metrics = trainer.logged_metrics
print("Final logged metrics:")
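for key, value in sorted(logged_metrics.items()):
    # Lightning stores logged metrics as tensors; convert scalars so they format as numbers
    if torch.is_tensor(value) and value.numel() == 1:
        value = value.item()
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")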

# Compute codebook stats on full dataset
model = trainer_module.model
model.eval()
device = next(model.parameters()).device

with torch.no_grad():
    all_indices = model.get_semantic_ids(embeddings.to(device))
    stats = model.compute_codebook_stats(all_indices)

print(f"\n=== Codebook Statistics (full dataset) ===")
print(f"Average perplexity: {stats['avg_perplexity']:.2f} / {config.codebook_size} (max)")
print(f"Average usage: {stats['avg_usage']*100:.1f}%")
print(f"\nPer-level breakdown:")
for q in range(config.num_quantizers):
    perp = stats['perplexity_per_level'][q].item()
    usage = stats['usage_per_level'][q].item() * 100
    print(f"  Level {q}: perplexity={perp:.1f}, usage={usage:.1f}%")

# Check for semantic ID collisions
unique_ids = len(set(model.semantic_id_to_string(all_indices)))
print(f"\n=== Semantic ID Quality ===")
print(f"Total items: {len(embeddings)}")
print(f"Unique semantic IDs: {unique_ids}")
print(f"Collision rate: {(1 - unique_ids/len(embeddings))*100:.1f}%")
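
`compute_codebook_stats` lives in the package and its exact definitions are not shown here. A common convention, assumed in the sketch below, is that per-level perplexity is the exponential of the entropy of the code-usage distribution (equal to `codebook_size` when every code is used equally often, and 1 when a single code is used), and usage is the fraction of codebook entries used at least once. The collision rate matches the formula in the cell above.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence

def level_stats(codes: Sequence[int], codebook_size: int) -> Dict[str, float]:
    """Perplexity (exp of entropy) and usage fraction for one quantizer level."""
    counts = Counter(codes)
    n = len(codes)
    entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
    return {"perplexity": math.exp(entropy), "usage": len(counts) / codebook_size}

def collision_rate(semantic_ids: List[str]) -> float:
    """Fraction of items whose semantic ID duplicates another item's ID."""
    return 1 - len(set(semantic_ids)) / len(semantic_ids)

# A level using 2 of 4 codes equally often: perplexity ~2, usage 0.5
print(level_stats([0, 0, 1, 1], codebook_size=4))
print(collision_rate(["a", "b", "b", "c"]))  # 0.25
```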

## 5. Generate Semantic IDs

In [None]:
# Generate semantic IDs for all items using the trained model
model = trainer_module.model
model.eval()

device = next(model.parameters()).device
with torch.no_grad():
    indices = model.get_semantic_ids(embeddings.to(device))
    semantic_strings = model.semantic_id_to_string(indices)

# Create mapping
item_to_semantic = {}
semantic_to_item = {}

for i, item in enumerate(items):
    item_id = item["item_id"]  # AG News uses item_id
    sem_id = semantic_strings[i]
    item_to_semantic[item_id] = {
        "codes": indices[i].cpu().tolist(),
        "semantic_id": sem_id,
    }
    # On a collision, the later item overwrites the earlier one in this reverse map
    semantic_to_item[sem_id] = item_id

# Save mapping
mapping = {
    "item_to_semantic": item_to_semantic,
    "semantic_to_item": semantic_to_item,
    "config": {"num_quantizers": config.num_quantizers, "codebook_size": config.codebook_size}
}

with open("data/semantic_ids.json", "w") as f:
    json.dump(mapping, f, indent=2)

# Show examples
print("\nExample semantic IDs:")
for i in range(5):
    print(f"  {items[i]['title'][:30]:30} -> {semantic_strings[i]}")

print(f"\n✓ Generated {len(semantic_to_item)} unique semantic IDs")
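
Because `semantic_to_item` is a plain dict, colliding items are silently dropped from the reverse map, so the count printed above can be lower than the number of items. A hypothetical alternative is to keep every item per ID, as sketched below (the helper name is illustrative, not part of the package). Another common option is to append an extra disambiguating code so that every ID becomes unique. In this notebook it could be called as `group_by_semantic_id([it["item_id"] for it in items], semantic_strings)`; the `[SEM_level_code]` strings in the example assume the format matched by the regex in section 8.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

def group_by_semantic_id(item_ids: Sequence[str], semantic_ids: Sequence[str]) -> Dict[str, List[str]]:
    """Map each semantic ID to every item that received it (no silent overwrites)."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for item_id, sem_id in zip(item_ids, semantic_ids):
        groups[sem_id].append(item_id)
    return dict(groups)

groups = group_by_semantic_id(
    ["a", "b", "c"],
    ["[SEM_0_1][SEM_1_2]", "[SEM_0_1][SEM_1_2]", "[SEM_0_3][SEM_1_0]"],
)
collided = {sem_id: ids for sem_id, ids in groups.items() if len(ids) > 1}
print(collided)  # {'[SEM_0_1][SEM_1_2]': ['a', 'b']}
```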

In [None]:
# Examine embedding similarity vs semantic ID similarity
import torch.nn.functional as F

# Compute pairwise cosine similarities
normed = F.normalize(embeddings, dim=1)
similarities = normed @ normed.T

# Mask out diagonal and near-duplicates (sim > 0.99)
similarities_filtered = similarities.clone()
similarities_filtered.fill_diagonal_(-1)
mask = similarities_filtered > 0.99
similarities_filtered[mask] = -1  # Exclude near-duplicates

# Find most similar pair (excluding duplicates); argmax indexes the flattened matrix
max_idx = similarities_filtered.argmax().item()
i_close, j_close = divmod(max_idx, len(items))

# Find most dissimilar pair
similarities_for_min = similarities.clone()
similarities_for_min.fill_diagonal_(2)  # Exclude self
min_idx = similarities_for_min.argmin().item()
i_far, j_far = divmod(min_idx, len(items))

print("=" * 70)
print("CLOSEST PAIR (high embedding similarity, excluding duplicates)")
print("=" * 70)
print(f"Cosine similarity: {similarities[i_close, j_close]:.4f}")
print(f"\nItem A [{items[i_close]['category']}]:")
print(f"  Title: {items[i_close]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[i_close]}")
print(f"\nItem B [{items[j_close]['category']}]:")
print(f"  Title: {items[j_close]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[j_close]}")

# Check how many codes match
codes_a = indices[i_close].tolist()
codes_b = indices[j_close].tolist()
matching = sum(a == b for a, b in zip(codes_a, codes_b))
print(f"\nCodes A: {codes_a}")
print(f"Codes B: {codes_b}")
print(f"Matching codes: {matching}/{len(codes_a)} ✓" if matching > 0 else f"Matching codes: {matching}/{len(codes_a)}")

print("\n" + "=" * 70)
print("FARTHEST PAIR (low embedding similarity)")
print("=" * 70)
print(f"Cosine similarity: {similarities[i_far, j_far]:.4f}")
print(f"\nItem A [{items[i_far]['category']}]:")
print(f"  Title: {items[i_far]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[i_far]}")
print(f"\nItem B [{items[j_far]['category']}]:")
print(f"  Title: {items[j_far]['title'][:80]}")
print(f"  Semantic ID: {semantic_strings[j_far]}")

codes_a = indices[i_far].tolist()
codes_b = indices[j_far].tolist()
matching = sum(a == b for a, b in zip(codes_a, codes_b))
print(f"\nCodes A: {codes_a}")
print(f"Codes B: {codes_b}")
print(f"Matching codes: {matching}/{len(codes_a)}")

# Check for duplicates in dataset
dup_count = (similarities > 0.999).sum().item() - len(items)  # Subtract diagonal
print(f"\n⚠️  Found {dup_count//2} duplicate pairs in dataset (cosine sim > 0.999)")
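
The cell above counts codes that match at the same position. Since residual quantization is coarse-to-fine, a stricter measure of hierarchical closeness is the length of the shared prefix: two items that agree on level 0 but differ on level 1 sit in the same coarse cluster, whereas a match only at level 3 says little on its own. A minimal helper, assuming the codes are per-level lists like those produced by `indices[i].tolist()`:

```python
from typing import Sequence

def shared_prefix_len(codes_a: Sequence[int], codes_b: Sequence[int]) -> int:
    """Number of leading levels on which two semantic IDs agree."""
    n = 0
    for a, b in zip(codes_a, codes_b):
        if a != b:
            break
        n += 1
    return n

print(shared_prefix_len([3, 7, 1, 0], [3, 7, 2, 0]))  # 2: same coarse cluster, diverge at level 2
print(shared_prefix_len([3, 7, 1, 0], [5, 7, 1, 0]))  # 0: three positional matches, no shared prefix
```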

## 6. Generate Training Data

In [None]:
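import random

random.seed(42)  # Reproducible template sampling and shuffling

# Generate query -> semantic ID training pairs (templates phrased for news articles)
query_templates = [
    "Find me an article about {title}",
    "I'm looking for news on {title}",
    "Recommend an article like {title}",
    "{title}",
    "{content}",
]

training_examples = []
system_prompt = "You are a recommendation system. Given a user query, output the semantic ID of the most relevant item."

for item in items:
    item_id = item["item_id"]  # AG News items are keyed by "item_id"
    sem_id = item_to_semantic[item_id]["semantic_id"]
    
    for template in random.sample(query_templates, 2):
        # AG News items have "content" (not "description"); truncate it to keep prompts short
        query = template.format(title=item["title"], content=item["content"][:200])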
        
        training_examples.append({
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query},
                {"role": "assistant", "content": sem_id},
            ],
            "item_id": item_id,
        })

# Shuffle and split
random.shuffle(training_examples)
split_idx = int(len(training_examples) * 0.9)
train_examples = training_examples[:split_idx]
val_examples = training_examples[split_idx:]

# Save
with open("data/train.jsonl", "w") as f:
    for ex in train_examples:
        f.write(json.dumps(ex) + "\n")

with open("data/val.jsonl", "w") as f:
    for ex in val_examples:
        f.write(json.dumps(ex) + "\n")

print(f"✓ Generated {len(train_examples)} train, {len(val_examples)} val examples")
print(f"\nExample:")
print(f"  Query: {train_examples[0]['messages'][1]['content'][:50]}...")
print(f"  Target: {train_examples[0]['messages'][2]['content']}")
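
The 90/10 split above is over examples, not items, so an item whose two generated queries both land in the validation split never has its semantic ID shown during training, and a generative retriever cannot be expected to produce an ID it has never seen. A quick check, written against the `item_id` field each example already carries (the helper is illustrative):

```python
from typing import Dict, List, Set

def unseen_val_items(train_examples: List[Dict], val_examples: List[Dict]) -> Set[str]:
    """Item IDs that appear in validation examples but in no training example."""
    train_items = {ex["item_id"] for ex in train_examples}
    return {ex["item_id"] for ex in val_examples} - train_items

train = [{"item_id": "a"}, {"item_id": "b"}]
val = [{"item_id": "b"}, {"item_id": "c"}]
print(unseen_val_items(train, val))  # {'c'}
```

In the notebook, `len(unseen_val_items(train_examples, val_examples))` gives the number of items whose IDs the model only meets at evaluation time.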

## 7. Fine-tune LLM with Unsloth

In [None]:
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments
import os

# Disable wandb for this test
os.environ["WANDB_DISABLED"] = "true"

# Use a small model for testing
BASE_MODEL = "unsloth/Qwen2.5-0.5B"  # Very small for quick test
MAX_SEQ_LENGTH = 256

print(f"Loading {BASE_MODEL}...")
llm_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=True,
)

print("✓ Model loaded")

In [None]:
# Add semantic ID tokens
semantic_tokens = []
for q in range(config.num_quantizers):
    for c in range(config.codebook_size):
        semantic_tokens.append(f"[SEM_{q}_{c}]")

tokenizer.add_special_tokens({"additional_special_tokens": semantic_tokens})
llm_model.resize_token_embeddings(len(tokenizer))

print(f"✓ Added {len(semantic_tokens)} semantic tokens")
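
The added vocabulary has `num_quantizers × codebook_size` tokens (4 × 16 = 64 with the config above), one per (level, code) pair. The inference cell recovers IDs with the pattern `\[SEM_\d+_\d+\]`, so the sketch below assumes `model.semantic_id_to_string` emits the same `[SEM_{level}_{code}]` format. It converts codes to that string and back, rejecting strings that do not have exactly one token per level in level order, which is a cheap validity check on generated output.

```python
import re
from typing import List, Optional

TOKEN_RE = re.compile(r"\[SEM_(\d+)_(\d+)\]")

def codes_to_string(codes: List[int]) -> str:
    """Encode per-level codes as concatenated [SEM_level_code] tokens (assumed format)."""
    return "".join(f"[SEM_{q}_{c}]" for q, c in enumerate(codes))

def string_to_codes(text: str, num_levels: int, codebook_size: int) -> Optional[List[int]]:
    """Parse tokens back to codes; None unless levels are exactly 0..num_levels-1 in order."""
    pairs = [(int(q), int(c)) for q, c in TOKEN_RE.findall(text)]
    if [q for q, _ in pairs] != list(range(num_levels)):
        return None
    if any(c >= codebook_size for _, c in pairs):
        return None
    return [c for _, c in pairs]

s = codes_to_string([3, 7, 1, 0])
print(s)  # [SEM_0_3][SEM_1_7][SEM_2_1][SEM_3_0]
print(string_to_codes(s, num_levels=4, codebook_size=16))  # [3, 7, 1, 0]
print(string_to_codes("[SEM_0_3][SEM_2_1]", 4, 16))  # None (level 1 missing)
```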

In [None]:
# Add LoRA adapters
llm_model = FastLanguageModel.get_peft_model(
    llm_model,
    r=8,  # Small for quick test
    lora_alpha=16,
    lora_dropout=0.05,
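    # Only these modules get adapters: the new [SEM_*] embedding rows stay at their initial
    # values unless "embed_tokens"/"lm_head" are also made trainable.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],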
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)

print("✓ LoRA adapters added")

In [None]:
# Prepare dataset
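def format_messages(messages):
    """Render one conversation in the plain-text chat format also used at inference."""
    text = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role == "system":
            text += f"<|system|>\n{content}\n"
        elif role == "user":
            text += f"<|user|>\n{content}\n"
        elif role == "assistant":
            text += f"<|assistant|>\n{content}\n"
    # Append EOS so the model learns to stop after emitting the semantic ID
    return text + tokenizer.eos_token

def formatting_func(examples):
    # Older TRL versions pass a batch (dict of lists); newer ones pass a single example
    messages = examples["messages"]
    if messages and isinstance(messages[0], dict):
        return format_messages(messages)
    return [format_messages(m) for m in messages]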

train_dataset = Dataset.from_list(train_examples)
val_dataset = Dataset.from_list(val_examples)

print(f"✓ Dataset prepared: {len(train_dataset)} train, {len(val_dataset)} val")

In [None]:
# Training arguments (quick test settings)
training_args = TrainingArguments(
    output_dir="checkpoints/llm",
    num_train_epochs=1,  # Just 1 epoch for test
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    logging_steps=10,
    save_strategy="no",  # Don't save for test
    eval_strategy="no",
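    fp16=not torch.cuda.is_bf16_supported(),  # T4 has no bf16 support, so fp16 is used there
    bf16=torch.cuda.is_bf16_supported(),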
    optim="adamw_8bit",
    report_to="none",
)

trainer = SFTTrainer(
    model=llm_model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    args=training_args,
    formatting_func=formatting_func,
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
)

print("✓ Trainer configured")
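
Assuming the catalogue holds 500 items (as the file name suggests), sampling two templates per item gives 1,000 examples, 900 of them for training. The effective batch size is `per_device_train_batch_size × gradient_accumulation_steps` = 8 on one GPU. The arithmetic below is a rough estimate of optimizer and warmup steps for one epoch, assuming a single device and no packing; the exact count can differ by one depending on how the installed `transformers` version rounds a partial accumulation step.

```python
import math
from typing import Tuple

def estimate_optimizer_steps(num_examples: int, batch_size: int, grad_accum: int,
                             epochs: int = 1, warmup_ratio: float = 0.0) -> Tuple[int, int]:
    """Rough (optimizer steps, warmup steps) for a single-device run."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    steps = max(batches_per_epoch // grad_accum, 1) * epochs
    warmup = math.ceil(steps * warmup_ratio)
    return steps, warmup

print(estimate_optimizer_steps(900, batch_size=4, grad_accum=2, warmup_ratio=0.03))  # (112, 4)
```

With `logging_steps=10`, that is roughly eleven loss log lines for the whole run.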

In [None]:
# Train!
print("Starting training...")
trainer.train()
print("✓ Training complete!")

## 8. Test Inference

In [None]:
# Set to inference mode
FastLanguageModel.for_inference(llm_model)

def generate_recommendation(query: str) -> str:
    system_prompt = "You are a recommendation system. Given a user query, output the semantic ID of the most relevant item."
    prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{query}\n<|assistant|>\n"
    
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    
    # Greedy decoding: we want the single most likely ID, not a sampled one
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    
    # Decode only the newly generated tokens; keep special tokens so [SEM_*] survives
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()

print("✓ Inference mode ready")

In [None]:
import re

# Test queries (news-style, to match the AG News catalogue)
test_queries = [
    "Find me an article about oil prices",
    "I'm looking for news on the stock market",
    "Recommend an article about a football match",
]

print("\n" + "="*60)
print("INFERENCE TESTS")
print("="*60)

for query in test_queries:
    result = generate_recommendation(query)
    
    # Try to find the item
    sem_id_match = re.search(r"(\[SEM_\d+_\d+\])+", result)
    
    print(f"\nQuery: {query}")
    print(f"Generated: {result[:100]}")
    
    if sem_id_match:
        sem_id = sem_id_match.group(0)
        if sem_id in semantic_to_item:
            item_id = semantic_to_item[sem_id]
            item = next(i for i in items if i["item_id"] == item_id)
            print(f"Matched: {item['title']}")
        else:
            print(f"Semantic ID not in catalogue: {sem_id}")
    else:
        print("No valid semantic ID found")

print("\n" + "="*60)
print("✓ All tests complete!")
print("="*60)
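
When the model emits an ID that is not in the catalogue, one remedy is constrained decoding: at each step, allow only tokens that continue some valid semantic ID. Hugging Face `generate` accepts a `prefix_allowed_tokens_fn(batch_id, input_ids)` callback that returns the allowed token IDs for the next step. The trie below is a minimal, hypothetical sketch of the lookup such a callback could use; it works on token-ID sequences (for example `tokenizer.convert_tokens_to_ids` applied to each ID's tokens) and is not part of the package.

```python
from typing import Dict, List, Sequence

class SemanticIdTrie:
    """Prefix tree over token-ID sequences of valid semantic IDs."""

    def __init__(self) -> None:
        self.root: Dict[int, dict] = {}

    def insert(self, token_ids: Sequence[int]) -> None:
        node = self.root
        for t in token_ids:
            node = node.setdefault(t, {})

    def allowed_next(self, prefix: Sequence[int]) -> List[int]:
        """Token IDs that extend `prefix` toward a valid ID; [] if invalid or already complete."""
        node = self.root
        for t in prefix:
            if t not in node:
                return []
            node = node[t]
        return sorted(node.keys())

trie = SemanticIdTrie()
for ids in ([100, 200, 300], [100, 201, 301], [101, 202, 302]):
    trie.insert(ids)
print(trie.allowed_next([]))          # [100, 101]
print(trie.allowed_next([100]))       # [200, 201]
print(trie.allowed_next([100, 201]))  # [301]
print(trie.allowed_next([102]))       # []
```

Inside the callback, the prefix would be the tokens generated after the prompt. An empty list leaves no valid token, so a real callback should fall back to something such as `tokenizer.eos_token_id` once an ID is complete.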

## Summary

This notebook tested:

1. **RQ-VAE Training** - ✓ Model learns to encode items as discrete codes
2. **Semantic ID Generation** - ✓ Each item is mapped to a semantic ID (collisions are reported in the section 4 diagnostics)
3. **Training Data Creation** - ✓ Query → Semantic ID pairs generated
4. **LLM Fine-tuning** - ✓ Model learns to predict semantic IDs
5. **Inference** - ✓ Model generates semantic IDs for new queries

### Next Steps

For production use:
- Use larger model (Qwen3-4B or Ministral-3B)
- Train for more epochs (3-5)
- Use full catalogue
- Deploy to Modal for serverless inference