# Semantic ID Recommender - Training Test

This notebook tests the full training pipeline in Google Colab.

**Requirements:**
- Google Colab with GPU runtime (T4 is sufficient)
- ~10 minutes for the full test

**What this tests:**
1. RQ-VAE training (semantic ID learning)
2. Training data generation
3. LLM fine-tuning with Unsloth
4. Inference test

## Setup

In [None]:
# Check GPU availability
# IPython shell escape: runs `nvidia-smi` and shows its output in the cell.
# The notebook's Requirements section expects a GPU runtime to be attached.
!nvidia-smi

In [None]:
# Install dependencies
# Each line is an IPython shell escape; `-q` keeps pip's output short.
!pip install -q torch torchvision torchaudio
!pip install -q transformers datasets accelerate
# vector-quantize-pytorch supplies ResidualVQ; sentence-transformers supplies
# the embedding model used to embed the catalogue.
!pip install -q vector-quantize-pytorch sentence-transformers
!pip install -q lightning wandb omegaconf einops tqdm rich
# Unsloth is installed from the GitHub default branch (unpinned).
# NOTE(review): `trl` is imported by the fine-tuning cell and the trainer uses
# optim="adamw_8bit", but neither `trl` nor `bitsandbytes` is installed
# explicitly here — presumably they arrive via the unsloth extra; confirm.
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install -q pydantic

In [None]:
# Create project structure
# `-p` makes the call idempotent (no error if a directory already exists).
# The cells in this notebook write to data/ and checkpoints/; the src/...
# directories are created but not referenced by any cell shown here.
!mkdir -p src/config src/rqvae src/llm data checkpoints

## 1. Define Models & Utilities

In [None]:
# RQ-VAE Model
from dataclasses import dataclass
import torch
import torch.nn as nn
from vector_quantize_pytorch import ResidualVQ

@dataclass
class SemanticRQVAEConfig:
    """Hyperparameters consumed by ``SemanticRQVAE``."""

    # Size of the input embedding; also the size of the decoder output.
    # (384 presumably matches the sentence-embedding model used later — confirm.)
    embedding_dim: int = 384
    # Width of the encoder/decoder hidden layers and of the quantised latent.
    hidden_dim: int = 512
    # Number of entries per codebook (passed to ResidualVQ).
    codebook_size: int = 256
    # Number of residual quantisation levels, i.e. codes per item.
    num_quantizers: int = 4
    # Passed to ResidualVQ as `commitment_weight`.
    commitment_weight: float = 0.25
    # Passed to ResidualVQ as `decay`.
    decay: float = 0.99

class SemanticRQVAE(nn.Module):
    """Residual-quantised autoencoder that turns embeddings into discrete codes.

    An MLP encoder projects each input embedding to ``hidden_dim``, a
    ``ResidualVQ`` stack quantises that latent into ``num_quantizers`` code
    indices, and an MLP decoder reconstructs the input embedding from the
    quantised latent.
    """

    def __init__(self, config: "SemanticRQVAEConfig"):
        """Build encoder, residual quantiser and decoder from ``config``."""
        super().__init__()
        self.config = config

        self.encoder = nn.Sequential(
            nn.Linear(config.embedding_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )

        self.rq = ResidualVQ(
            dim=config.hidden_dim,
            codebook_size=config.codebook_size,
            num_quantizers=config.num_quantizers,
            commitment_weight=config.commitment_weight,
            decay=config.decay,
            kmeans_init=True,
            threshold_ema_dead_code=2,
        )

        self.decoder = nn.Sequential(
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.embedding_dim),
        )

    def forward(self, x):
        """Encode, quantise and reconstruct a batch of embeddings.

        Args:
            x: Tensor whose last dimension is ``embedding_dim``; the code
                treats it as ``(batch, embedding_dim)``.

        Returns:
            Tuple ``(reconstructed, indices, recon_loss, commit_loss)`` where
            ``recon_loss`` is the MSE between ``reconstructed`` and ``x`` and
            ``commit_loss`` is the quantiser loss reduced to a scalar.
        """
        z = self.encoder(x)
        # The quantiser is called with a sequence axis, so add one of length 1
        # and strip it again from its outputs.
        z_unsqueezed = z.unsqueeze(1)
        quantized, indices, commit_loss = self.rq(z_unsqueezed)
        quantized = quantized.squeeze(1)
        indices = indices.squeeze(1)
        # The quantiser may return one loss per level; reduce to a scalar so it
        # can be added directly to the reconstruction loss.
        if commit_loss.dim() > 0:
            commit_loss = commit_loss.sum()
        reconstructed = self.decoder(quantized)
        recon_loss = nn.functional.mse_loss(reconstructed, x)
        return reconstructed, indices, recon_loss, commit_loss

    @torch.no_grad()
    def get_semantic_ids(self, x):
        """Return the code indices for ``x`` without building an autograd graph.

        The returned indices are what ``semantic_id_to_string`` expects, i.e.
        one row per input with one column per quantiser level.

        NOTE(review): this still runs the quantiser's forward pass, so callers
        should presumably put the module in ``eval()`` mode first — confirm
        against the quantiser's training-mode behaviour.
        """
        z = self.encoder(x)
        z_unsqueezed = z.unsqueeze(1)
        _, indices, _ = self.rq(z_unsqueezed)
        return indices.squeeze(1)

    def semantic_id_to_string(self, indices):
        """Render each row of ``indices`` as ``[SEM_<level>_<code>]`` tokens.

        Args:
            indices: 2-D tensor with at least ``num_quantizers`` columns.

        Returns:
            One string per row: the per-level tokens concatenated without a
            separator, e.g. ``"[SEM_0_12][SEM_1_7]..."``.
        """
        num_levels = self.config.num_quantizers
        # A single .tolist() transfers every code at once; calling .item() per
        # element would force one device synchronisation per code.
        rows = indices.tolist()
        return [
            "".join(f"[SEM_{q}_{row[q]}]" for q in range(num_levels))
            for row in rows
        ]

print("✓ RQ-VAE model defined")

## 2. Create Test Dataset

In [None]:
import json
import random

# Generate dummy catalogue
# Seed the module-level RNG so the catalogue (and the later random.sample /
# random.shuffle calls in this notebook, which share the same RNG) come out
# identical on every run.
SEED = 42
random.seed(SEED)

categories = ["Electronics", "Books", "Clothing", "Home", "Sports"]
adjectives = ["Premium", "Budget", "Professional", "Compact", "Luxury"]
nouns = ["Widget", "Gadget", "Device", "Tool", "Accessory"]

NUM_ITEMS = 500  # Small for quick testing

# NOTE: only 5 x 5 x 5 = 125 distinct title/description pairs can be produced,
# so whenever NUM_ITEMS > 125 several items necessarily share identical text.
items = []
for i in range(NUM_ITEMS):
    category = random.choice(categories)
    adj = random.choice(adjectives)
    noun = random.choice(nouns)
    items.append({
        "id": f"item_{i:05d}",
        "title": f"{adj} {category} {noun}",
        "description": f"A high-quality {adj.lower()} {noun.lower()} for {category.lower()} enthusiasts.",
        "category": category,
    })

# Persist the catalogue as {"items": [...]}; explicit encoding keeps the file
# independent of the platform default.
with open("data/catalogue.json", "w", encoding="utf-8") as f:
    json.dump({"items": items}, f)

print(f"✓ Created catalogue with {NUM_ITEMS} items")

## 3. Generate Embeddings

In [None]:
from sentence_transformers import SentenceTransformer

# Load embedding model
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Generate embeddings
# Each item is represented by the text "<title>. <description>".
texts = [f"{item['title']}. {item['description']}" for item in items]
# convert_to_tensor=True requests a tensor result, which the training cell
# later moves to the training device.
# NOTE(review): presumably one 384-dim row per item, matching
# embedding_dim=384 in the RQ-VAE config — confirm from the printed shape.
embeddings = embed_model.encode(texts, show_progress_bar=True, convert_to_tensor=True)

print(f"✓ Generated embeddings: {embeddings.shape}")

## 4. Train RQ-VAE

In [None]:
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Hyperparameters for a quick smoke test: a reduced hidden size and a small
# codebook so that training finishes quickly.
config = SemanticRQVAEConfig(
    embedding_dim=384,
    hidden_dim=256,
    codebook_size=64,
    num_quantizers=4,
)

# Build the model and its optimiser on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SemanticRQVAE(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Shuffled mini-batches over the precomputed item embeddings.
dataset = TensorDataset(embeddings.to(device))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Optimise reconstruction + commitment loss for a fixed number of epochs.
NUM_EPOCHS = 20
model.train()

for epoch in range(NUM_EPOCHS):
    running_loss = 0
    # TensorDataset yields one-element batches: unpack the tensor directly.
    for (x,) in dataloader:
        optimizer.zero_grad()
        _, _, recon_loss, commit_loss = model(x)
        loss = recon_loss + commit_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean batch loss every fifth epoch.
    epoch_number = epoch + 1
    if epoch_number % 5 == 0:
        mean_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch_number}/{NUM_EPOCHS}, Loss: {mean_loss:.4f}")

print("✓ RQ-VAE training complete")

## 5. Generate Semantic IDs

In [None]:
# Generate semantic IDs for all items
model.eval()
with torch.no_grad():
    indices = model.get_semantic_ids(embeddings.to(device))
    semantic_strings = model.semantic_id_to_string(indices)

# Copy every code to the host once instead of once per item.
all_codes = indices.cpu().tolist()

# Create mapping
# item_to_semantic: item id -> {"codes": [...], "semantic_id": "[SEM_0_..]..."}
# semantic_to_item: semantic id string -> item id (one item id per string)
item_to_semantic = {}
semantic_to_item = {}
colliding_items = 0  # items whose semantic ID was already taken by an earlier item

for item, codes, sem_id in zip(items, all_codes, semantic_strings):
    item_id = item["id"]
    item_to_semantic[item_id] = {
        "codes": codes,
        "semantic_id": sem_id,
    }
    if sem_id in semantic_to_item:
        colliding_items += 1
    # On a collision the later item replaces the earlier one in the reverse map.
    semantic_to_item[sem_id] = item_id

# Save mapping
mapping = {
    "item_to_semantic": item_to_semantic,
    "semantic_to_item": semantic_to_item,
    "config": {"num_quantizers": config.num_quantizers, "codebook_size": config.codebook_size}
}

with open("data/semantic_ids.json", "w", encoding="utf-8") as f:
    json.dump(mapping, f, indent=2)

# Show examples (at most five; fewer if the catalogue is smaller)
print("\nExample semantic IDs:")
for item, sem_id in zip(items[:5], semantic_strings):
    print(f"  {item['title'][:30]:30} -> {sem_id}")

print(f"\n✓ Generated {len(semantic_to_item)} unique semantic IDs for {len(items)} items")
if colliding_items:
    # Make collisions visible instead of silently overwriting entries.
    print(
        f"  ⚠ {colliding_items} items share a semantic ID with an earlier item; "
        "semantic_to_item keeps only the last item for each ID"
    )

## 6. Generate Training Data

In [None]:
# Build (query -> semantic ID) chat examples and write train/val JSONL files.

# Query phrasings; each is filled in from an item's title/description.
query_templates = [
    "Find me a {title}",
    "I'm looking for {title}",
    "Recommend a {title}",
    "{title}",
    "{description}",
]

system_prompt = "You are a recommendation system. Given a user query, output the semantic ID of the most relevant item."


def _build_example(catalogue_item, query_template):
    """Return one chat-format training record for an item and a query template."""
    record_id = catalogue_item["id"]
    user_query = query_template.format(
        title=catalogue_item["title"],
        description=catalogue_item["description"],
    )
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
            {"role": "assistant", "content": item_to_semantic[record_id]["semantic_id"]},
        ],
        "item_id": record_id,
    }


def _write_jsonl(path, records):
    """Write each record as one JSON document per line."""
    with open(path, "w") as handle:
        handle.writelines(json.dumps(record) + "\n" for record in records)


# Two distinct templates are drawn per item, in catalogue order.
training_examples = [
    _build_example(catalogue_item, query_template)
    for catalogue_item in items
    for query_template in random.sample(query_templates, 2)
]

# Shuffle, then hold out the last 10% for validation.
random.shuffle(training_examples)
cut = int(len(training_examples) * 0.9)
train_examples, val_examples = training_examples[:cut], training_examples[cut:]

# Save
_write_jsonl("data/train.jsonl", train_examples)
_write_jsonl("data/val.jsonl", val_examples)

print(f"✓ Generated {len(train_examples)} train, {len(val_examples)} val examples")
first_messages = train_examples[0]["messages"]
print("\nExample:")
print(f"  Query: {first_messages[1]['content'][:50]}...")
print(f"  Target: {first_messages[2]['content']}")

## 7. Fine-tune LLM with Unsloth

In [None]:
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments
import os

# Disable wandb for this test
os.environ["WANDB_DISABLED"] = "true"

# Use a small model for testing
BASE_MODEL = "unsloth/Qwen2.5-0.5B"  # Very small for quick test
# Sequence-length limit; passed both to the model loader below and to the
# trainer in a later cell.
MAX_SEQ_LENGTH = 256

print(f"Loading {BASE_MODEL}...")
# dtype=None leaves the compute dtype to the loader (presumably auto-detected
# from the GPU — confirm); load_in_4bit=True requests 4-bit weights.
llm_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=True,
)

print("✓ Model loaded")

In [None]:
# Add semantic ID tokens
# One vocabulary entry per (quantizer level, codebook index) pair, in the same
# "[SEM_<level>_<code>]" format the RQ-VAE emits.
semantic_tokens = [
    f"[SEM_{q}_{c}]"
    for q in range(config.num_quantizers)
    for c in range(config.codebook_size)
]

# add_tokens(..., special_tokens=True) registers the tokens as special (never
# split) and only appends to the vocabulary. The previous call,
# add_special_tokens({"additional_special_tokens": ...}), by default replaces
# the tokenizer's existing additional-special-token list.
num_added = tokenizer.add_tokens(semantic_tokens, special_tokens=True)

# Grow the embedding matrix to cover the new vocabulary entries.
# NOTE(review): the rows for the new tokens are freshly initialised — make sure
# the fine-tuning setup actually trains the embedding/output layers.
llm_model.resize_token_embeddings(len(tokenizer))

print(f"✓ Added {num_added} semantic tokens")

In [None]:
# Add LoRA adapters
# The attention/MLP projections receive low-rank adapters. "embed_tokens" and
# "lm_head" are listed as well so that the embedding and output rows of the
# newly added [SEM_*] tokens are trained: with adapters on the projections
# only, those rows would stay at their initial values and the model could not
# learn to emit the new tokens.
llm_model = FastLanguageModel.get_peft_model(
    llm_model,
    r=8,  # Small for quick test
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",  # train new-token embeddings / output rows
    ],
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)

print("✓ LoRA adapters added")

In [None]:
# Prepare dataset
def formatting_func(examples):
    """Flatten a batch of chat conversations into plain training strings.

    Args:
        examples: Mapping with a ``"messages"`` entry holding one list of
            ``{"role": ..., "content": ...}`` dicts per conversation.

    Returns:
        One string per conversation. Each system/user/assistant turn becomes
        ``<|role|>``, a newline, the content and a trailing newline; turns
        with any other role are dropped.

    NOTE(review): no end-of-sequence token is appended to the text — confirm
    that this is intended for the trainer in use.
    """
    role_tags = {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    }

    rendered = []
    for conversation in examples["messages"]:
        pieces = []
        for turn in conversation:
            speaker, body = turn["role"], turn["content"]
            tag = role_tags.get(speaker)
            if tag is not None:
                pieces.append(f"{tag}\n{body}\n")
        rendered.append("".join(pieces))
    return rendered

# Wrap the in-memory example dicts as `datasets.Dataset` objects.
# NOTE(review): val_dataset is created here, but the trainer cell shown in this
# notebook only receives train_dataset (evaluation is disabled there).
train_dataset = Dataset.from_list(train_examples)
val_dataset = Dataset.from_list(val_examples)

print(f"✓ Dataset prepared: {len(train_dataset)} train, {len(val_dataset)} val")

In [None]:
# Training arguments (quick test settings)
from unsloth import is_bfloat16_supported

# Choose mixed precision from the hardware instead of hard-coding fp16: on a
# T4 this resolves to fp16 (as before), while bf16-capable GPUs use bf16 so the
# trainer's precision agrees with the dtype the model was loaded in.
use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir="checkpoints/llm",
    num_train_epochs=1,  # Just 1 epoch for test
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch of 8 sequences per step
    learning_rate=2e-4,
    warmup_ratio=0.03,
    logging_steps=10,
    save_strategy="no",  # Don't save for test
    eval_strategy="no",
    fp16=not use_bf16,
    bf16=use_bf16,
    optim="adamw_8bit",
    report_to="none",  # no external experiment tracking
)

# Each raw example is rendered to text by formatting_func; packing is off, so
# every training sequence holds exactly one conversation.
trainer = SFTTrainer(
    model=llm_model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    args=training_args,
    formatting_func=formatting_func,
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
)

print("✓ Trainer configured")

In [None]:
# Train!
# Runs the fine-tuning loop configured in the previous cell (one epoch with
# the quick-test settings).
print("Starting training...")
trainer.train()
print("✓ Training complete!")

## 8. Test Inference

In [None]:
# Set to inference mode
FastLanguageModel.for_inference(llm_model)

def generate_recommendation(query: str) -> str:
    """Generate the model's answer (ideally a semantic ID) for a user query.

    The prompt uses the same ``<|system|>`` / ``<|user|>`` / ``<|assistant|>``
    layout as the training text and ends right after the assistant tag.

    Args:
        query: Free-text user request.

    Returns:
        The decoded text of the newly generated tokens only, stripped of
        surrounding whitespace. Special tokens are kept so that ``[SEM_*]``
        tokens survive decoding.
    """
    system_prompt = "You are a recommendation system. Given a user query, output the semantic ID of the most relevant item."
    prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{query}\n<|assistant|>\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # Fall back to EOS when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=32,
        temperature=0.1,
        do_sample=True,
        pad_token_id=pad_token_id,
    )

    # The returned sequence starts with the prompt tokens, so slice them off
    # and decode only the continuation. Splitting the decoded text on the last
    # "<|assistant|>" would pick the wrong segment if the model emits that tag
    # again in its continuation.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()

print("✓ Inference mode ready")

In [None]:
# Test queries
import re

test_queries = [
    "Find me a premium electronics widget",
    "I need a budget sports accessory",
    "Recommend a luxury home device",
]

# Compiled once: a run of one or more adjacent [SEM_<level>_<code>] tokens.
sem_id_pattern = re.compile(r"(\[SEM_\d+_\d+\])+")
# Index the catalogue by id so each lookup is O(1) and a missing id is handled
# explicitly instead of raising StopIteration.
items_by_id = {entry["id"]: entry for entry in items}

print("\n" + "="*60)
print("INFERENCE TESTS")
print("="*60)

for query in test_queries:
    result = generate_recommendation(query)

    # Try to find the item
    sem_id_match = sem_id_pattern.search(result)

    print(f"\nQuery: {query}")
    print(f"Generated: {result[:100]}")

    if sem_id_match:
        sem_id = sem_id_match.group(0)
        item = items_by_id.get(semantic_to_item.get(sem_id))
        if item is not None:
            print(f"Matched: {item['title']}")
        else:
            print(f"Semantic ID not in catalogue: {sem_id}")
    else:
        print("No valid semantic ID found")

print("\n" + "="*60)
print("✓ All tests complete!")
print("="*60)

## Summary

This notebook tested:

1. **RQ-VAE Training** - ✓ Model learns to encode items as discrete codes
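2. **Semantic ID Generation** - ✓ Each item is assigned a semantic ID (items with identical text share the same ID, so IDs are not guaranteed to be unique in this dummy catalogue)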
3. **Training Data Creation** - ✓ Query → Semantic ID pairs generated
4. **LLM Fine-tuning** - ✓ Model learns to predict semantic IDs
5. **Inference** - ✓ Model generates semantic IDs for new queries

### Next Steps

For production use:
- Use a larger model (Qwen3-4B or Ministral-3B)
- Train for more epochs (3-5)
- Use full catalogue
- Deploy to Modal for serverless inference