# CausalLMRouter - Training

This notebook demonstrates how to train the **CausalLMRouter** (Causal Language Model Router).

## Overview

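CausalLMRouter finetunes a causal language model (e.g., Llama-2-7B) to generate the name of the best candidate LLM for a given query.
It uses LoRA (Low-Rank Adaptation) for efficient finetuning: small low-rank adapter matrices are trained while the base model weights stay frozen.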
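
To make the LoRA idea concrete, here is a minimal pure-Python sketch of a single adapted linear layer. It is illustrative only and is not the code CausalLMRouter runs; the helper names (`matvec`, `lora_forward`) and the toy dimensions are assumptions made for this example. The frozen weight `W` is combined with a trainable low-rank update `B @ A`, scaled by `alpha / r`.

```python
import random

def matvec(matrix, vector):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * B (A x). Only A and B would be trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

random.seed(0)
d_out, d_in, r, alpha = 4, 6, 2, 4  # toy sizes, chosen only for illustration

W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen base weight
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]   # trainable, shape r x d_in
B = [[0.0] * r for _ in range(d_out)]                                  # trainable, shape d_out x r, zero-initialised

x = [1.0] * d_in

# With B initialised to zero, the adapted layer reproduces the base layer exactly.
print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))

# Parameter counts for this toy layer: frozen vs. trainable.
print("frozen:", d_out * d_in, "trainable:", r * (d_in + d_out))
```

Because `B` starts at zero, finetuning begins from the unmodified base model and only gradually moves away from it as `A` and `B` are updated.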

**Key Features**:
- Builds on a pretrained LLM backbone (Llama-2 by default)
- Finetunes efficiently with LoRA
- Can capture complex query semantics
- Supports vLLM for fast inference

**Requirements**:
- GPU with at least 16 GB of VRAM recommended
- Hugging Face account with access to the gated Llama-2 models
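
As a rough back-of-the-envelope check on the VRAM recommendation, the sketch below estimates the memory needed just to hold the weights of a 7B-parameter model at different precisions. The helper `weight_memory_gb` is hypothetical, and the precision CausalLMTrainer actually loads the model in is not stated here, so treat the numbers as orientation only. Activations, LoRA gradients, and optimizer state come on top of these figures.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed to hold the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

NUM_PARAMS = 7e9  # nominal size of a "7B" model; the exact count differs slightly

for label, bits in [("fp32", 32), ("fp16/bf16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{label:>9}: {weight_memory_gb(NUM_PARAMS, bits):5.1f} GB")
```

At 16-bit precision the weights alone take about 14 GB, which is why 16 GB is a tight lower bound and why a smaller `batch_size` or quantized loading (for example through `bitsandbytes`, if the trainer supports it) may be needed on smaller GPUs.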

## 1. Environment Setup

In [None]:
# Install required packages (for Colab)
# !pip install llmrouter transformers torch peft accelerate bitsandbytes

In [None]:
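import os
import sys
from pathlib import Path

# Assumes the notebook is started from a directory two levels below the project root.
# Run this cell once per kernel session: re-running it moves up two more levels.
PROJECT_ROOT = Path.cwd().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Relative paths used later (e.g. the config path) are resolved from the project root
os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")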

In [None]:
import torch
from llmrouter.models.causallm_router import CausalLMRouter, CausalLMTrainer
from llmrouter.utils import setup_environment

setup_environment()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Hugging Face login (required for access to the gated Llama-2 checkpoints)
# from huggingface_hub import login
# login(token="your_hf_token")  # or call login() with no arguments to be prompted for the token

## 2. Configuration

CausalLMRouter uses the following configuration parameters:

| Parameter | Description | Default |
|-----------|-------------|--------|
| `base_model` | Base LLM for finetuning | "meta-llama/Llama-2-7b-hf" |
| `use_lora` | Enable LoRA finetuning | true |
| `lora_r` | LoRA rank | 16 |
| `lora_alpha` | LoRA alpha | 32 |
| `lora_dropout` | LoRA dropout | 0.1 |
| `num_epochs` | Training epochs | 3 |
| `batch_size` | Batch size | 4 |
| `learning_rate` | Learning rate | 2e-5 |
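
To see how `lora_r` and `lora_alpha` translate into trainable parameters, the following sketch counts LoRA parameters for the default values in the table. It relies on the published Llama-2-7B architecture (hidden size 4096, 32 decoder layers) and on an assumption that is not confirmed by this notebook: that adapters are attached to two square attention projections per layer (for example `q_proj` and `v_proj`). The function `lora_trainable_params` is a hypothetical helper.

```python
def lora_trainable_params(hidden_size, num_layers, rank, modules_per_layer):
    """Count LoRA parameters when every adapted module is a square
    hidden_size x hidden_size projection: A is (rank x hidden_size)
    and B is (hidden_size x rank)."""
    per_module = rank * hidden_size + hidden_size * rank
    return per_module * modules_per_layer * num_layers

# ASSUMPTION: two adapted projections per layer; the real target modules may differ.
trainable = lora_trainable_params(hidden_size=4096, num_layers=32, rank=16, modules_per_layer=2)
base_params = 6.74e9  # approximate parameter count of Llama-2-7B

print(f"LoRA parameters: {trainable:,}")                          # 8,388,608
print(f"Share of base model: {100 * trainable / base_params:.2f}%")  # 0.12%
print(f"Update scaling (lora_alpha / lora_r): {32 / 16}")         # 2.0
```

Doubling `lora_r` doubles the adapter size, and adapting more modules per layer scales the count the same way, so the trainable share reported during training depends on both settings.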

In [None]:
import yaml

CONFIG_PATH = "configs/model_config_train/causallm_router.yaml"

with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

print("Current Configuration:")
print("=" * 50)
print(yaml.dump(config, default_flow_style=False))

## 3. Initialize Router

In [None]:
router = CausalLMRouter(yaml_path=CONFIG_PATH)

print("Router initialized successfully!")
print(f"Number of training samples: {len(router.routing_data_train)}")
print(f"Number of LLM candidates: {len(router.llm_data)}")
print(f"LLM candidates: {list(router.llm_data.keys())}")
print(f"Base model: {config['hparam']['base_model']}")

## 4. Training Data Preparation
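
The training format used in this section pairs a prompt (the query followed by an open `Best model:` slot) with a target (the name of the best model). A minimal sketch of that formatting is shown here, together with one way to map generated text back to a candidate name. The helpers `build_prompt`, `build_example`, and `parse_prediction` and the candidate names are hypothetical; the router's internal implementation may differ.

```python
def build_prompt(query):
    """Format a query as the input text: the query plus an open 'Best model:' slot."""
    return f"Query: {query}\nBest model: "

def build_example(query, best_model_name):
    """Return a (prompt, target) pair for one supervised training example."""
    return build_prompt(query), best_model_name

def parse_prediction(generated_text, candidates):
    """Map generated text back to a known candidate name, or None if nothing matches."""
    text = generated_text.strip()
    # Try longer names first so "model-a-large" is preferred over "model-a".
    for name in sorted(candidates, key=len, reverse=True):
        if text.startswith(name):
            return name
    return None

# Usage with made-up candidate names
candidates = ["model-a", "model-a-large", "model-b"]
prompt, target = build_example("What is the capital of France?", "model-a")
print(repr(prompt))
print(parse_prediction(" model-a-large\n", candidates))
```

Restricting the parsed output to the known candidate list guards against the model generating a name that is not a valid routing target.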

In [None]:
# Understand the training data format
print("Training Data Format:")
print("=" * 50)
print("\nThe model is trained to predict the best LLM given a query.")
print("\nInput format:")
print("  Query: {user query}")
print("  Best model: ")
print("\nTarget format:")
print("  {best_model_name}")

## 5. Training

In [None]:
trainer = CausalLMTrainer(router=router, device=device)

print("Trainer initialized!")
print(f"Device: {device}")
print(f"Save path: {trainer.save_model_path}")
print(f"LoRA enabled: {config['hparam'].get('use_lora', True)}")

In [None]:
# Show trainable parameters (skipped if the trainer does not expose a `model` attribute)
if hasattr(trainer, 'model'):
    total_params = sum(p.numel() for p in trainer.model.parameters())
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable %: {100 * trainable_params / total_params:.2f}%")
else:
    print("trainer.model is not available; skipping the parameter count.")

In [None]:
print("Starting training...")
print("=" * 50)
print("Note: CausalLM training requires significant GPU memory.")
print("Consider reducing batch_size if you encounter OOM errors.")
print("=" * 50)

trainer.train()

print("=" * 50)
print("Training completed!")

## 6. Model Verification
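
One extra check that can be useful in this step is whether the save directory looks like a LoRA adapter checkpoint. The sketch below assumes the trainer saves adapters through PEFT, which writes `adapter_config.json` plus `adapter_model.safetensors` (or `adapter_model.bin` in older versions). Whether CausalLMTrainer uses this layout is an assumption, and the helper names are hypothetical.

```python
import os
import tempfile

# File names PEFT uses when saving an adapter; older versions write the .bin variant.
ADAPTER_CONFIG = "adapter_config.json"
ADAPTER_WEIGHTS = ("adapter_model.safetensors", "adapter_model.bin")

def find_lora_adapter_files(save_dir):
    """Return the adapter files present in save_dir (empty list if none)."""
    if not os.path.isdir(save_dir):
        return []
    present = set(os.listdir(save_dir))
    found = [ADAPTER_CONFIG] if ADAPTER_CONFIG in present else []
    found += [name for name in ADAPTER_WEIGHTS if name in present]
    return found

def looks_like_lora_checkpoint(save_dir):
    """True when both the adapter config and some adapter weights file exist."""
    found = find_lora_adapter_files(save_dir)
    return ADAPTER_CONFIG in found and any(w in found for w in ADAPTER_WEIGHTS)

# Self-contained demo using a temporary directory instead of a real checkpoint
with tempfile.TemporaryDirectory() as tmp:
    for name in ("adapter_config.json", "adapter_model.safetensors"):
        open(os.path.join(tmp, name), "w").close()
    print(looks_like_lora_checkpoint(tmp))
```

After training, the same check could be applied to `trainer.save_model_path`. A `False` result does not necessarily mean that saving failed, since the trainer may store a merged model or use a different layout.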

In [None]:
# Check saved model
import os

save_path = trainer.save_model_path
if os.path.exists(save_path):
    print(f"Model saved at: {save_path}")

    # List saved files with their sizes; subdirectories are skipped
    if os.path.isdir(save_path):
        print("\nSaved files:")
        for name in sorted(os.listdir(save_path)):
            file_path = os.path.join(save_path, name)
            if os.path.isfile(file_path):
                size = os.path.getsize(file_path) / 1e6
                print(f"  {name}: {size:.2f} MB")
else:
    print(f"Model not found at: {save_path}")

In [None]:
# Test prediction
test_query = {"query": "What is the capital of France?"}
result = router.route_single(test_query)

print(f"Test query: {test_query['query']}")
print(f"Routed to: {result['model_name']}")

## Summary

In this notebook, we:

1. **Loaded Configuration**: Read the CausalLMRouter settings from the YAML configuration
2. **Initialized Router**: Created the router with a Llama-2 backbone
3. **Applied LoRA**: Set up efficient finetuning with low-rank adaptation
4. **Trained Model**: Finetuned the model to predict the best LLM for each query
5. **Saved Model**: Saved the LoRA weights and merged model, then listed the saved files and routed a test query

**Key Takeaways**:
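- CausalLMRouter reuses a pretrained LLM's language understanding to choose the target model
- LoRA enables efficient finetuning (roughly 0.1% of parameters are trainable; the exact share depends on `lora_r` and the adapted modules)
- Training requires a GPU with sufficient memory; reduce `batch_size` if you run out of memory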

**Next Steps**:
- Use `02_causallm_router_inference.ipynb` for inference
- Consider using vLLM for faster inference