# LLM Fine-tuning for Semantic ID Generation

Fine-tune a small LLM to generate semantic IDs from user queries using a trained RQ-VAE model.

**Environment:** RunPod Jupyter with GPU

**Prerequisites:**
- Trained RQ-VAE model (`models/rqvae_model.pt`)
- Semantic ID mappings (`data/semantic_ids.json`)
- Item catalogue (`data/mcf_articles.jsonl`)

**Training Stages:**
1. **Stage 1**: Train only embedding layers for new semantic ID tokens (backbone frozen)
2. **Stage 2**: LoRA fine-tuning (adapters on all linear layers), starting from the Stage 1 checkpoint

**Outputs:**
- `checkpoints/llm_stage1/` - Stage 1 model (embeddings only)
- `checkpoints/llm_stage2/` - Stage 2 model (LoRA fine-tuned)

## 1. Setup

In [None]:
# Check GPU availability. The leading `!` is IPython/Jupyter syntax that runs
# `nvidia-smi` as a shell command (it is not plain Python); the output lists
# the visible GPUs with their driver version and current memory usage.
!nvidia-smi

In [None]:
# Make the repository root both the working directory and importable, so that
# relative paths such as "data/..." and `from src...` imports resolve when the
# notebook is launched from inside the notebooks/ folder.
import sys
import os

_cwd_abs = os.path.abspath(".")
if not _cwd_abs.endswith("notebooks"):
    print(f"Already in project root: {os.getcwd()}")
else:
    repo_root = os.path.abspath("..")
    print(f"Current directory: {os.getcwd()}, changing to {repo_root}")
    # Only prepend once so re-running the cell does not grow sys.path.
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    os.chdir(repo_root)

In [None]:
from unsloth import FastLanguageModel  # Always import unsloth first

In [None]:
import torch
import json
from pathlib import Path
from datasets import Dataset

from src.llm import (
    FinetuneConfig,
    finetune_model,
    prepare_training_data,
    SemanticIDDataset,
    SemanticIDGenerator,
    load_finetuned_model,
)

# Report the PyTorch build and, if a CUDA device is visible, its name.
_has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# --- Input / output data files ----------------------------------------------
CATALOGUE_PATH = "data/mcf_articles.jsonl"
SEMANTIC_IDS_PATH = "data/semantic_ids.json"
TRAIN_DATA_PATH = "data/llm_train.jsonl"
VAL_DATA_PATH = "data/llm_val.jsonl"

# --- Checkpoint directories (created at the end of this cell) ---------------
STAGE1_OUTPUT_DIR = "checkpoints/llm_stage1"
STAGE2_OUTPUT_DIR = "checkpoints/llm_stage2"

# --- Base model: pick one that fits in your GPU memory ----------------------
# Alternatives: "unsloth/Qwen3-1.7B", "HuggingFaceTB/SmolLM2-1.7B-Instruct"
BASE_MODEL = "unsloth/Qwen3-4B"

# --- Semantic ID vocabulary: must equal the RQ-VAE configuration ------------
NUM_QUANTIZERS, CODEBOOK_SIZE = 3, 64

# --- Hyperparameters shared by both stages ----------------------------------
MAX_SEQ_LENGTH = 512
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4

# Stage 1: train the new-token embeddings only
STAGE1_EPOCHS, STAGE1_LR = 3, 2e-4

# Stage 2: LoRA adapters on top of the Stage 1 checkpoint
STAGE2_EPOCHS, STAGE2_LR = 3, 2e-4
LORA_R, LORA_ALPHA = 16, 32

# When True, finetune_model also saves checkpoints as W&B artifacts
LOG_WANDB_ARTIFACTS = False

for _out_dir in (STAGE1_OUTPUT_DIR, STAGE2_OUTPUT_DIR):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

print("\n".join([
    "Configuration:",
    f"  Base model: {BASE_MODEL}",
    f"  Semantic ID tokens: {NUM_QUANTIZERS} quantizers x {CODEBOOK_SIZE} codes",
    f"  Stage 1 output: {STAGE1_OUTPUT_DIR}",
    f"  Stage 2 output: {STAGE2_OUTPUT_DIR}",
    f"  Log W&B artifacts: {LOG_WANDB_ARTIFACTS}",
]))

## 3. Verify Environment Variables

In [None]:
# Read credentials from the environment; only their presence (and length) is
# reported, never the values themselves.
HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

print("=== Environment Variables ===")
for _var_name, _var_value, _missing_note in (
    ("HF_TOKEN", HF_TOKEN, "HF_TOKEN not set - some models may not be accessible"),
    ("WANDB_API_KEY", WANDB_API_KEY, "WANDB_API_KEY not set - Weights & Biases logging disabled"),
):
    print(f"{_var_name} is set (length: {len(_var_value)})" if _var_value else _missing_note)

## 4. Initialize Weights & Biases

In [None]:
import wandb

# Target W&B project. If WANDB_PROJECT is already set in the environment, that
# value wins (see setdefault below) and is what gets printed.
WANDB_PROJECT = "semantic-id-recommender"

if WANDB_API_KEY:
    # Don't call wandb.init() here - the trainer creates the run. Nothing in
    # this notebook passes WANDB_PROJECT to FinetuneConfig, so export it as an
    # environment variable; the trainer's W&B integration / wandb.init() read
    # WANDB_PROJECT from the environment when no project is given explicitly.
    WANDB_PROJECT = os.environ.setdefault("WANDB_PROJECT", WANDB_PROJECT)
    # Verify the API key works before training starts.
    wandb.login()
    print("Weights & Biases ready")
    print(f"  Project: {WANDB_PROJECT}")
    REPORT_TO = "wandb"
else:
    print("Wandb logging disabled")
    REPORT_TO = "none"

## 5. Load Semantic ID Mappings

In [None]:
# Load the semantic ID mappings produced by RQ-VAE training.
with open(SEMANTIC_IDS_PATH, encoding="utf-8") as f:
    semantic_mapping = json.load(f)

item_to_semantic = semantic_mapping["item_to_semantic"]
semantic_to_item = semantic_mapping["semantic_to_item"]
rqvae_config = semantic_mapping.get("config", {})

print(f"Loaded {len(item_to_semantic)} item -> semantic ID mappings")
print(f"Loaded {len(semantic_to_item)} semantic ID -> item mappings")
print("\nRQ-VAE config from mapping:")
for k, v in rqvae_config.items():
    print(f"  {k}: {v}")

# The semantic ID tokens given to the LLM are NUM_QUANTIZERS x CODEBOOK_SIZE,
# so both must agree with the RQ-VAE that produced the IDs. Raise rather than
# assert so the check is not stripped under `python -O`, and report both
# values so the mismatch is actionable.
if rqvae_config:
    for _key, _expected in (("num_quantizers", NUM_QUANTIZERS), ("codebook_size", CODEBOOK_SIZE)):
        _actual = rqvae_config.get(_key)
        if _actual != _expected:
            raise ValueError(
                f"{_key.upper()} mismatch! RQ-VAE config has {_actual!r}, "
                f"notebook uses {_expected!r}"
            )
    print("\nConfiguration matches RQ-VAE settings")
else:
    print("\nWARNING: mapping file has no 'config' section - "
          "cannot verify NUM_QUANTIZERS / CODEBOOK_SIZE")

In [None]:
# Load catalogue items (JSONL: one JSON object per line) for the
# recommendation test callback. Blank lines - e.g. a trailing empty line - are
# skipped instead of crashing json.loads().
with open(CATALOGUE_PATH, encoding="utf-8") as f:
    items = [json.loads(line) for line in f if line.strip()]

# Build semantic_id -> display info so generated IDs can be shown as readable
# items. The catalogue id is read from "item_id", falling back to "id".
# If two items share a semantic_id, the later one overwrites the earlier.
semantic_id_to_item = {}
for item in items:
    item_id = item.get("item_id", item.get("id", ""))
    semantic_entry = item_to_semantic.get(str(item_id))
    if semantic_entry is None:
        continue  # item has no semantic ID (not in the RQ-VAE mapping)
    semantic_id_to_item[semantic_entry["semantic_id"]] = {
        "item_id": item_id,
        "title": item.get("title", ""),
        "category": item.get("category", ""),
    }

print(f"\nLoaded {len(items)} catalogue items")
print(f"Created {len(semantic_id_to_item)} semantic_id -> item mappings")

## 6. Prepare Training Data

In [None]:
# Natural-language prompt templates used to synthesise training examples,
# grouped by task. Placeholders ({title}, {semantic_id}, {field_name}) are
# filled in by prepare_training_data. Customise for your own use case.
QUERY_TEMPLATES = dict(
    # Query mentions an item -> model answers with that item's semantic ID.
    predict_semantic_id=[
        "What is the semantic ID of {title}?",
        "Find the semantic ID for {title}",
        "Semantic ID for: {title}",
        "Recommend something like {title}. What is its semantic ID?",
        "{title} - what's the {field_name}?".replace("{field_name}", "semantic ID")
        if False else "{title} - what's the semantic ID?",
        "Item: {title}. Semantic ID?",
        "Get semantic ID: {title}",
    ],
    # Query mentions a semantic ID -> model answers with an item attribute.
    predict_attribute=[
        "What is the {field_name} for semantic ID {semantic_id}?",
        "For semantic ID {semantic_id}, what is the {field_name}?",
        "Get {field_name} for {semantic_id}",
        "{semantic_id} - what's the {field_name}?",
    ],
)

# Template placeholder -> field name in the catalogue records.
# Adjust to match your catalogue structure.
FIELD_MAPPING = dict(title="title", category="category")

# Name of the identifier field in each catalogue record ("item_id" or "id").
ID_FIELD = "item_id"

In [None]:
# Build the chat-format train/val splits from the catalogue and semantic IDs.
# The output_*_path arguments presumably make prepare_training_data also
# write the splits to disk - confirm in src.llm.
_data_paths = dict(
    catalogue_path=CATALOGUE_PATH,
    semantic_ids_path=SEMANTIC_IDS_PATH,
    output_train_path=TRAIN_DATA_PATH,
    output_val_path=VAL_DATA_PATH,
)
_generation_options = dict(
    query_templates=QUERY_TEMPLATES,
    field_mapping=FIELD_MAPPING,
    id_field=ID_FIELD,
    num_examples_per_item=5,
    val_split=0.1,
    predict_semantic_id_ratio=0.8,
    seed=42,
)
train_dataset, val_dataset = prepare_training_data(**_data_paths, **_generation_options)

print(f"\nTraining examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")

In [None]:
# Show up to three training conversations, truncating each message body to
# 100 characters (with a trailing "...") to keep the output compact.
_preview_chars = 100
print("Sample training examples:")
for _idx in range(min(3, len(train_dataset))):
    print(f"\n--- Example {_idx + 1} ---")
    for _message in train_dataset[_idx]["messages"]:
        _body = _message["content"]
        if len(_body) > _preview_chars:
            _body = _body[:_preview_chars] + "..."
        print(f"{_message['role']}: {_body}")

## 7. Stage 1: Embedding Training

In stage 1, we:
- Add new semantic ID tokens to the vocabulary
- Freeze the entire backbone
- Train only the input/output embedding layers

This teaches the model to recognize and generate the new semantic ID tokens.

In [None]:
# Queries the recommendation callback runs during training to show what the
# model currently generates for free-text requests.
TEST_QUERIES = [
    "Recommend me an article about cooking",
    "I want to read something about technology",
    "Find me a sports article",
]

# Stage 1: only the embeddings for the new semantic ID tokens are trained.
stage1_config = FinetuneConfig(
    stage=1,
    output_dir=STAGE1_OUTPUT_DIR,
    # Model loading
    base_model=BASE_MODEL,
    load_in_4bit=True,
    max_seq_length=MAX_SEQ_LENGTH,
    # Semantic ID vocabulary - must match the RQ-VAE
    codebook_size=CODEBOOK_SIZE,
    num_quantizers=NUM_QUANTIZERS,
    # Optimisation
    num_train_epochs=STAGE1_EPOCHS,
    learning_rate=STAGE1_LR,
    warmup_ratio=0.03,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # Logging, evaluation and checkpointing
    report_to=REPORT_TO,
    logging_steps=10,
    eval_steps=100,
    save_strategy="epoch",
    log_wandb_artifacts=LOG_WANDB_ARTIFACTS,
    # Recommendation test callback
    semantic_id_to_item=semantic_id_to_item,
    recommendation_test_queries=TEST_QUERIES,
)

print("Stage 1 Configuration:")
print(f"  Base model: {stage1_config.base_model}")
print(f"  Learning rate: {stage1_config.learning_rate}")
print(f"  Epochs: {stage1_config.num_train_epochs}")
_stage1_effective_batch = stage1_config.batch_size * stage1_config.gradient_accumulation_steps
print(f"  Effective batch size: {_stage1_effective_batch}")
print(f"  Log W&B artifacts: {stage1_config.log_wandb_artifacts}")

In [None]:
# Train Stage 1; finetune_model returns the trained model and its tokenizer
# (which now contains the semantic ID tokens).
_rule = "=" * 50
print("Starting Stage 1: Embedding Training")
print(_rule)

stage1_model, stage1_tokenizer = finetune_model(
    config=stage1_config,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)

print("\nStage 1 complete!")
print(f"Model saved to: {STAGE1_OUTPUT_DIR}")

In [None]:
# Sanity-check the Stage 1 result: vocabulary size, presence of a semantic ID
# token, and the share of parameters that were trainable.
print("Stage 1 model info:")
print(f"  Vocabulary size: {len(stage1_tokenizer)}")

# convert_tokens_to_ids() maps an unknown token to the UNK id (or None), so
# printing the id alone does not prove the token was added. Check the
# vocabulary (which includes added tokens) explicitly.
test_token = "[SEM_0_0]"
token_id = stage1_tokenizer.convert_tokens_to_ids(test_token)
_token_status = "OK" if test_token in stage1_tokenizer.get_vocab() else "MISSING from vocabulary!"
print(f"  Test token '{test_token}' -> ID: {token_id} ({_token_status})")


def _logical_numel(param):
    """Return the number of logical weights held by *param*.

    bitsandbytes 4-bit weights (class ``Params4bit``) pack two 4-bit values
    into each storage byte, so ``numel()`` under-counts them. Scale by
    2 * bytes-per-element, the same correction PEFT applies when it reports
    trainable parameters.
    """
    count = param.numel()
    if type(param).__name__ == "Params4bit":
        count *= 2 * param.element_size()
    return count


# Count trainable vs. total parameters of the Stage 1 model.
trainable_params = sum(_logical_numel(p) for p in stage1_model.parameters() if p.requires_grad)
total_params = sum(_logical_numel(p) for p in stage1_model.parameters())
print(f"  Trainable params: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

## 8. Stage 2: LoRA Fine-tuning

In stage 2, we:
- Load the stage 1 checkpoint (with trained embeddings)
- Apply LoRA adapters to all linear layers
- Fine-tune the model to generate semantic IDs from queries

In [None]:
# Release the Stage 1 model before Stage 2 loads the Stage 1 checkpoint again
# from disk. gc.collect() clears reference cycles that may still pin the
# tensors, so empty_cache() can actually return the memory. The guard makes
# the cell safe to re-run after the model has already been deleted.
import gc

if "stage1_model" in globals():
    del stage1_model
gc.collect()
torch.cuda.empty_cache()
print("Cleared GPU memory for stage 2")

In [None]:
# Stage 2: LoRA adapters trained on top of the Stage 1 checkpoint, so the
# semantic ID embeddings learned in Stage 1 are the starting point.
stage2_config = FinetuneConfig(
    stage=2,
    output_dir=STAGE2_OUTPUT_DIR,
    # Starting weights: the Stage 1 checkpoint. base_model is kept for
    # reference only - it is not used when stage1_checkpoint is set.
    stage1_checkpoint=STAGE1_OUTPUT_DIR,
    base_model=BASE_MODEL,
    load_in_4bit=True,
    max_seq_length=MAX_SEQ_LENGTH,
    # Semantic ID vocabulary - must match the RQ-VAE
    codebook_size=CODEBOOK_SIZE,
    num_quantizers=NUM_QUANTIZERS,
    # LoRA adapter shape
    lora_alpha=LORA_ALPHA,
    lora_r=LORA_R,
    lora_dropout=0.05,
    # Optimisation
    num_train_epochs=STAGE2_EPOCHS,
    learning_rate=STAGE2_LR,
    warmup_ratio=0.03,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # Logging, evaluation and checkpointing
    report_to=REPORT_TO,
    logging_steps=10,
    eval_steps=100,
    save_strategy="epoch",
    log_wandb_artifacts=LOG_WANDB_ARTIFACTS,
    # Recommendation test callback
    semantic_id_to_item=semantic_id_to_item,
    recommendation_test_queries=TEST_QUERIES,
)

print("Stage 2 Configuration:")
print(f"  Stage 1 checkpoint: {stage2_config.stage1_checkpoint}")
print(f"  LoRA rank: {stage2_config.lora_r}")
print(f"  LoRA alpha: {stage2_config.lora_alpha}")
print(f"  Learning rate: {stage2_config.learning_rate}")
print(f"  Epochs: {stage2_config.num_train_epochs}")
print(f"  Log W&B artifacts: {stage2_config.log_wandb_artifacts}")

In [None]:
# Train Stage 2 (LoRA) starting from the Stage 1 checkpoint.
_rule = "=" * 50
print("Starting Stage 2: LoRA Fine-tuning")
print(_rule)

stage2_model, stage2_tokenizer = finetune_model(
    config=stage2_config,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)

print("\nStage 2 complete!")
print(f"Model saved to: {STAGE2_OUTPUT_DIR}")

## 9. Test the Fine-tuned Model

In [None]:
# Reload the Stage 2 checkpoint from disk, exercising the same loading path a
# downstream user would take.
# NOTE(review): stage2_model from the training cell is still resident, so two
# copies are in GPU memory here; free it first if memory is tight (the
# optional Hub-upload cell below refers to stage2_model).
model, tokenizer = load_finetuned_model(STAGE2_OUTPUT_DIR)

# Wrap model + tokenizer in a generator that emits one semantic ID per query.
generator = SemanticIDGenerator(
    num_quantizers=NUM_QUANTIZERS,
    tokenizer=tokenizer,
    model=model,
)

print("Model loaded for inference")

In [None]:
# Reuse the queries the training callback used (so results can be compared
# with what it reported) plus one extra query.
test_queries = [*TEST_QUERIES, "What article should I read about health?"]

print("Testing semantic ID generation:")
print("=" * 50)

for query in test_queries:
    semantic_id = generator.generate(query)

    # Map the ID back to a catalogue item when it belongs to a known item.
    item_info = semantic_id_to_item.get(semantic_id, {})

    print(f"\nQuery: {query}")
    print(f"  Semantic ID: {semantic_id}")
    if item_info:
        print(f"  Item: {item_info.get('title', 'Unknown')}")
        print(f"  Category: {item_info.get('category', 'Unknown')}")
    else:
        print("  (No matching item found)")

## 10. Save Final Model

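Save a `model_info.json` describing the training configuration next to the Stage 2 checkpoint. (This notebook does not merge the LoRA weights into a standalone model.)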

In [None]:
# Save a description of how the model was built next to the Stage 2
# checkpoint, including the sequence/batch settings needed to reproduce the
# run or to reload the model with matching limits.
model_info = {
    "base_model": BASE_MODEL,
    "num_quantizers": NUM_QUANTIZERS,
    "codebook_size": CODEBOOK_SIZE,
    "stage1_checkpoint": STAGE1_OUTPUT_DIR,
    "stage2_checkpoint": STAGE2_OUTPUT_DIR,
    "training_config": {
        "max_seq_length": MAX_SEQ_LENGTH,
        "batch_size": BATCH_SIZE,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "stage1_epochs": STAGE1_EPOCHS,
        "stage1_lr": STAGE1_LR,
        "stage2_epochs": STAGE2_EPOCHS,
        "stage2_lr": STAGE2_LR,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
    },
}

model_info_path = Path(STAGE2_OUTPUT_DIR) / "model_info.json"
model_info_path.write_text(json.dumps(model_info, indent=2), encoding="utf-8")

print(f"Model info saved to {model_info_path}")

# Note: W&B artifacts are logged automatically by finetune_model
# when log_wandb_artifacts=True in FinetuneConfig.

In [None]:
# Attach static run metadata to the active W&B run, then close it. These are
# single values (strings/counts), not time series, so they belong in the run
# summary rather than in wandb.log() history.
if WANDB_API_KEY and wandb.run:
    wandb.run.summary.update({
        "final/base_model": BASE_MODEL,
        "final/num_quantizers": NUM_QUANTIZERS,
        "final/codebook_size": CODEBOOK_SIZE,
        "final/train_examples": len(train_dataset),
        "final/val_examples": len(val_dataset),
    })
    wandb.finish()
    print("Logged final metrics to wandb")
elif WANDB_API_KEY:
    # W&B is configured but no run is active (e.g. it was already finished).
    print("No active wandb run - skipping final metrics")

## 11. Load Model from W&B Artifacts (Optional)

If you've logged models as W&B artifacts, you can load them in future sessions.

In [None]:
# Example (disabled): load a model previously logged as a W&B artifact.
# Uncomment and adjust the artifact name/version below to use it in a new
# session; it needs WANDB_PROJECT, load_finetuned_model, SemanticIDGenerator
# and NUM_QUANTIZERS from the cells above.

# # Initialize wandb (if not already)
# wandb.init(project=WANDB_PROJECT, job_type="inference")
# 
# # Download artifact - use "latest", "best", or a specific version like "v0"
# artifact = wandb.use_artifact("llm-stage2:latest")
# artifact_dir = artifact.download()
# 
# # Load the model
# model, tokenizer = load_finetuned_model(artifact_dir)
# generator = SemanticIDGenerator(model, tokenizer, num_quantizers=NUM_QUANTIZERS)
# 
# # Use the model
# semantic_id = generator.generate("Recommend me an article about cooking")
# print(f"Generated semantic ID: {semantic_id}")
# 
# # Check artifact metadata
# print(f"\nArtifact metadata:")
# for k, v in artifact.metadata.items():
#     print(f"  {k}: {v}")

print("To load a model from W&B artifacts, uncomment the code above.")

## 12. Upload to HuggingFace Hub (Optional)

In [None]:
# Optional: publish the Stage 2 model and tokenizer to the Hugging Face Hub.
# To enable, uncomment the snippet and set REPO_ID to your own repository.
#
# REPO_ID = "your-username/semantic-id-llm"  # Change this!
#
# if HF_TOKEN:
#     stage2_model.push_to_hub(REPO_ID, token=HF_TOKEN)
#     stage2_tokenizer.push_to_hub(REPO_ID, token=HF_TOKEN)
#     print(f"Model uploaded to: https://huggingface.co/{REPO_ID}")
# else:
#     print("HF_TOKEN not set - cannot upload to Hub")

# Final summary with a copy-pasteable usage snippet.
_usage_snippet = [
    "  from src.llm import load_finetuned_model, SemanticIDGenerator",
    f"  model, tokenizer = load_finetuned_model('{STAGE2_OUTPUT_DIR}')",
    f"  generator = SemanticIDGenerator(model, tokenizer, num_quantizers={NUM_QUANTIZERS})",
    "  semantic_id = generator.generate('your query here')",
]
print("Training complete!")
print(f"\nFinal model saved to: {STAGE2_OUTPUT_DIR}")
print("\nTo use the model:")
print("\n".join(_usage_snippet))