# LLM Fine-tuning for Semantic ID Generation

Fine-tune a small LLM to generate semantic IDs from user queries using a trained RQ-VAE model.

**Environment:** RunPod Jupyter with GPU

**Prerequisites:**
- Trained RQ-VAE model (either as W&B artifact or local file)
- Item catalogue (same one used for RQ-VAE training)

**Training Stages:**
1. **Stage 1**: Train only embedding layers for new semantic ID tokens (backbone frozen)
2. **Stage 2**: LoRA fine-tuning (adapters on all linear layers), starting from the Stage 1 checkpoint

**Outputs:**
- `checkpoints/llm_stage1/` - Stage 1 model (only the embeddings trained)
- `checkpoints/llm_stage2/` - Stage 2 model (LoRA fine-tuned)
- `data/semantic_ids.json` - semantic ID mapping for the catalogue items

## 1. Setup

In [None]:
# Check GPU availability.
# `!` is IPython shell-escape syntax: it runs `nvidia-smi` as a shell command,
# so this cell only works inside Jupyter/IPython, not as plain Python.
!nvidia-smi

In [None]:
# Navigate to project root if needed.
# When the notebook is opened from the `notebooks/` directory, switch to its
# parent so that relative paths ("data/...", "checkpoints/...") and the `src`
# package resolve from the repository root.
import sys
import os

# Compare the exact directory name: a plain `.endswith("notebooks")` would also
# match unrelated directories such as "mynotebooks".
if os.path.basename(os.path.abspath(".")) == "notebooks":
    repo_root = os.path.abspath("..")
    print(f"Current directory: {os.getcwd()}, changing to {repo_root}")
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    os.chdir(repo_root)
else:
    print(f"Already in project root: {os.getcwd()}")

In [None]:
from unsloth import FastLanguageModel  # Always import unsloth first

In [None]:
import torch
from pathlib import Path
from dataclasses import replace

from src.llm import (
    train,
    LLMTrainConfig,
    LLMTrainResult,
    SemanticIDGenerator,
    load_finetuned_model,
)

# Quick environment report so a missing GPU is obvious before training starts.
cuda_available = torch.cuda.is_available()
print("Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

All training parameters are consolidated into `LLMTrainConfig`.

In [None]:
# Stage 1 configuration - all parameters in one place.
# The Stage 2 config later in this notebook is built from this object with
# dataclasses.replace(), so every field set here is inherited by Stage 2
# unless it is explicitly overridden there.
stage1_config = LLMTrainConfig(
    # RQ-VAE Model Source
    # Option 1: Load from W&B artifact (recommended)
    wandb_rqvae_artifact="rqvae-model:latest",  # e.g., "rqvae-model:v3"
    # Option 2: Load from local file
    # rqvae_model_path="models/rqvae_model.pt",
    # NOTE(review): when switching to Option 2, also set wandb_rqvae_artifact=None.
    # The summary print after this cell reports the artifact whenever it is set;
    # train() presumably gives it the same precedence - confirm in src.llm.
    
    # Catalogue and Embeddings (must match RQ-VAE training)
    catalogue_path="data/mcf_articles.jsonl",
    catalogue_id_field="item_id",
    embedding_model="Qwen/Qwen3-Embedding-0.6B",
    embeddings_cache_path="data/embeddings_mcf.pt",
    
    # Query templates for training data generation.
    # "predict_semantic_id": user-style queries built from an item's title.
    # "predict_attribute": questions asking for a field given a semantic ID.
    # NOTE(review): presumably {title} / {field_name} are resolved through
    # field_mapping below - confirm against the data-preparation code.
    query_templates={
        "predict_semantic_id": [
            "{title}",
            "Find: {title}",
            "Search for {title}",
            "Recommend: {title}",
            "Show me {title}",
            "I want to read about {title}",
            "Article about {title}",
        ],
        "predict_attribute": [
            "What is the {field_name} for {semantic_id}?",
            "Get {field_name} for {semantic_id}",
            "{semantic_id} - what is the {field_name}?",
        ],
    },
    field_mapping={"title": "title", "category": "category"},
    num_examples_per_item=5,
    predict_semantic_id_ratio=0.8,  # presumably the share of "predict_semantic_id" examples - confirm
    val_split=0.1,  # presumably the fraction of examples held out for validation - confirm
    
    # Base LLM
    # Options: "unsloth/Qwen3-4B", "unsloth/Qwen3-1.7B", "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    base_model="HuggingFaceTB/SmolLM2-135M-Instruct",
    max_seq_length=512,
    load_in_4bit=True,
    
    # Stage 1: Embedding training (backbone frozen)
    stage=1,
    
    # Training hyperparameters
    # Effective batch size = batch_size * gradient_accumulation_steps (4 * 4 = 16).
    learning_rate=2e-4,
    batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    warmup_ratio=0.03,
    
    # Output
    output_dir="checkpoints/llm_stage1",
    semantic_ids_output_path="data/semantic_ids.json",
    
    # Logging
    logging_steps=10,
    save_strategy="epoch",
    eval_steps=100,  # NOTE(review): no eval strategy is set here; whether this applies depends on LLMTrainConfig defaults
    
    # W&B configuration
    wandb_project="semantic-id-recommender",
    wandb_run_name="llm-stage1",
    report_to="wandb",
    log_wandb_artifacts=True,
    
    # Test queries for evaluation callback
    recommendation_test_queries=[
        "News about stock market and business",
        "Sports football game results",
        "Technology and science discoveries",
    ],
)

# Create output directories.
# Derive them from the config instead of hard-coding "checkpoints", and use
# parents=True so nested paths work on a fresh machine. Only the parents of
# output_dir are created; the output directory itself is left to train().
Path(stage1_config.output_dir).parent.mkdir(parents=True, exist_ok=True)
Path(stage1_config.semantic_ids_output_path).parent.mkdir(parents=True, exist_ok=True)

print("Stage 1 Configuration:")
print(f"  RQ-VAE source: {stage1_config.wandb_rqvae_artifact or stage1_config.rqvae_model_path}")
print(f"  Catalogue: {stage1_config.catalogue_path}")
print(f"  Embedding model: {stage1_config.embedding_model}")
print(f"  Base LLM: {stage1_config.base_model}")
print(f"  Output: {stage1_config.output_dir}")
print(f"  Effective batch size: {stage1_config.batch_size * stage1_config.gradient_accumulation_steps}")
print(f"  W&B project: {stage1_config.wandb_project}")
print(f"  Log W&B artifacts: {stage1_config.log_wandb_artifacts}")

## 3. Helper Functions

In [None]:
def test_model(
    result: LLMTrainResult,
    test_queries: list[str] | None = None,
    num_beams: int = 5,
    num_results: int = 3,
):
    """
    Test a trained model with sample queries and print the top candidates.

    For every query, beam search produces up to `num_results` semantic IDs.
    Each one is looked up in the result's semantic_id -> item mapping to show
    whether the model generates IDs that exist in the catalogue.

    Args:
        result: LLMTrainResult from train()
        test_queries: List of queries to test (uses defaults if None)
        num_beams: Beam width passed to `SemanticIDGenerator.generate_beam`.
        num_results: Number of sequences requested from beam search and
            printed per query. Presumably must not exceed `num_beams`
            (beam search usually enforces this) - confirm in src.llm.

    Returns:
        The SemanticIDGenerator built from `result.model`. It presumably keeps
        a reference to the model, so holding on to it keeps the model in memory.
    """
    if test_queries is None:
        test_queries = [
            "News about stock market and business",
            "Sports football game results",
            "Technology and science discoveries",
            "World news international politics",
        ]

    mapping = result.semantic_id_mapping

    # Create semantic ID generator from trained model
    generator = SemanticIDGenerator(
        model=result.model,
        tokenizer=result.tokenizer,
        num_quantizers=mapping["config"]["num_quantizers"],
    )

    # Get semantic_id -> item mapping
    semantic_to_item = mapping["semantic_to_item"]

    print("Testing semantic ID generation:")
    print("=" * 60)

    for query in test_queries:
        # Generate using beam search for multiple candidates
        candidates = generator.generate_beam(
            query, num_beams=num_beams, num_return_sequences=num_results
        )

        print(f"\nQuery: {query}")
        for rank, (sem_id, score) in enumerate(candidates[:num_results], start=1):
            item_id = semantic_to_item.get(sem_id)
            # Compare against None: a valid item ID may be falsy (e.g. 0 or "").
            if item_id is not None:
                print(f"  {rank}. [OK] {sem_id} (score: {score:.2f}) -> {item_id}")
            else:
                print(f"  {rank}. [--] {sem_id} (score: {score:.2f}) -> (not in catalogue)")

    return generator

## 4. Stage 1: Embedding Training

In stage 1, we:
- Add new semantic ID tokens to the vocabulary
- Freeze the entire backbone
- Train only the input/output embedding layers

This teaches the model to recognize and generate the new semantic ID tokens.

The `train()` function handles the complete pipeline:
1. Initialize W&B
2. Load RQ-VAE model from artifact or local path
3. Create semantic ID mapping for all catalogue items
4. Prepare training data (query -> semantic ID pairs)
5. Train the LLM
6. Log artifacts to W&B
7. Clean up

In [None]:
# Kick off the full Stage 1 pipeline (RQ-VAE load -> semantic IDs -> data -> training).
stage1_result = train(stage1_config)

# Completion banner plus the locations of the produced artefacts.
print("\n".join([
    "",
    "=" * 50,
    "Stage 1 Training Complete!",
    "=" * 50,
    f"Model saved to: {stage1_config.output_dir}",
    f"Semantic IDs saved to: {stage1_config.semantic_ids_output_path}",
]))

In [None]:
# Test Stage 1 model.
# test_model() returns a SemanticIDGenerator built from stage1_result.model.
# Drop it right away so no notebook variable keeps the Stage 1 model alive
# when GPU memory is released before Stage 2.
print("\n=== Stage 1 Model Test ===")
stage1_generator = test_model(stage1_result)
del stage1_generator

In [None]:
# Clean up GPU memory before Stage 2.
# empty_cache() can only return memory that no Python object references any
# more. So we drop the model from the result, and also clear `_` (the Stage 1
# test cell may have left the SemanticIDGenerator, which was built from the
# same model, bound there). Then we collect garbage before emptying the cache.
import gc

# Assigning None (rather than `del`) keeps this cell safe to re-run.
stage1_result.model = None
_ = None
gc.collect()  # reclaim objects kept alive only through reference cycles
torch.cuda.empty_cache()
print("Cleared GPU memory for Stage 2")

## 5. Stage 2: LoRA Fine-tuning

In stage 2, we:
- Load the stage 1 checkpoint (with trained embeddings)
- Apply LoRA adapters to all linear layers
- Fine-tune the model to generate semantic IDs from queries

In [None]:
# Stage 2 configuration - copy from stage 1 and modify.
# dataclasses.replace() returns a new config object: every field not listed
# here (catalogue, base model, batch size, W&B project, test queries, ...)
# is inherited from stage1_config, which itself is left unchanged.
stage2_config = replace(
    stage1_config,
    
    # Stage 2: LoRA fine-tuning
    stage=2,
    stage1_checkpoint=stage1_config.output_dir,  # Load from stage 1
    # Or load from W&B artifact:
    # wandb_stage1_artifact="llm-stage1:latest",
    # NOTE(review): when using the artifact, presumably remove stage1_checkpoint
    # so the two sources cannot conflict - confirm precedence in src.llm.
    
    # LoRA settings
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    
    # Training hyperparameters (can adjust for stage 2)
    # Currently identical to the Stage 1 values.
    learning_rate=2e-4,
    num_train_epochs=3,
    
    # Output
    output_dir="checkpoints/llm_stage2",
    # NOTE(review): semantic_ids_output_path is inherited unchanged, so Stage 2
    # presumably rewrites the same semantic-ID file - confirm this is intended.
    
    # W&B
    wandb_run_name="llm-stage2",
)

# Summarise the Stage 2 overrides in a single print call.
print("\n".join([
    "Stage 2 Configuration:",
    f"  Stage 1 checkpoint: {stage2_config.stage1_checkpoint}",
    f"  LoRA rank: {stage2_config.lora_r}",
    f"  LoRA alpha: {stage2_config.lora_alpha}",
    f"  Learning rate: {stage2_config.learning_rate}",
    f"  Epochs: {stage2_config.num_train_epochs}",
    f"  Output: {stage2_config.output_dir}",
]))

In [None]:
# Kick off Stage 2 (LoRA fine-tuning on top of the Stage 1 checkpoint).
stage2_result = train(stage2_config)

# Completion banner plus where the final model lives.
print("\n".join([
    "",
    "=" * 50,
    "Stage 2 Training Complete!",
    "=" * 50,
    f"Model saved to: {stage2_config.output_dir}",
]))

In [None]:
# Evaluate the Stage 2 model on the default queries; the returned generator
# stays bound to `generator` so further queries can be tried interactively.
print("\n=== Stage 2 Model Test ===")
generator = test_model(result=stage2_result, test_queries=None)

## 6. Load Model Later

To load the trained model in a new session:

In [None]:
# num_quantizers must match the trained RQ-VAE. Read it from the same place
# test_model() uses instead of hard-coding 3.
# NOTE(review): in a fresh session stage2_result is unavailable; take the value
# from the RQ-VAE config instead (possibly stored alongside the semantic-ID
# mapping file - confirm).
num_quantizers = stage2_result.semantic_id_mapping["config"]["num_quantizers"]

# Example: Load model from local checkpoint
# model, tokenizer = load_finetuned_model("checkpoints/llm_stage2")
# generator = SemanticIDGenerator(model, tokenizer, num_quantizers=num_quantizers)
# semantic_id = generator.generate("Recommend me an article about cooking")
# print(f"Generated: {semantic_id}")

# Example: Load model from W&B artifact
# import wandb
# wandb.init(project="semantic-id-recommender", job_type="inference")
# artifact = wandb.use_artifact("llm-stage2:latest")
# artifact_dir = artifact.download()
# model, tokenizer = load_finetuned_model(artifact_dir)
# wandb.finish()

print("Training complete!")
print(f"\nFinal model: {stage2_config.output_dir}")
print("\nTo use the model:")
print("  from src.llm import load_finetuned_model, SemanticIDGenerator")
print(f"  model, tokenizer = load_finetuned_model('{stage2_config.output_dir}')")
print(f"  generator = SemanticIDGenerator(model, tokenizer, num_quantizers={num_quantizers})")
print("  semantic_id = generator.generate('your query here')")