# CausalLMRouter - Inference

This notebook demonstrates how to use a trained **CausalLMRouter** for inference.

## 1. Environment Setup

In [None]:
import os
import sys
from pathlib import Path

# Resolve the repository root, make it importable, and make it the working directory
# so that relative paths used later (e.g. "configs/...", "data/...") resolve.
# NOTE(review): assumes the kernel starts in the notebook's own directory, two levels
# below the repo root (Jupyter's default) — confirm if the notebook is moved.
# Guarded so re-running this cell after os.chdir() does not climb two more levels.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path.cwd().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)

In [None]:
import torch
from llmrouter.models.causallm_router import CausalLMRouter
from llmrouter.utils import setup_environment
import yaml

# Project-level environment initialisation (provided by llmrouter.utils).
setup_environment()

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 2. Load Trained Router

In [None]:
CONFIG_PATH = "configs/model_config_train/causallm_router.yaml"
# Derived config that is actually handed to the router (see below). It is kept under
# outputs/ so re-running this cell overwrites it rather than creating new files.
INFERENCE_CONFIG_PATH = "outputs/causallm_router_inference.yaml"

with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Use the explicitly configured checkpoint if one is set; otherwise fall back to the
# 'merged' subdirectory of save_model_path. A key that is present but empty/null also
# triggers the fallback (dict.get with a default would have kept the null).
model_path_cfg = config['model_path']
if not model_path_cfg.get('load_model_path'):
    model_path_cfg['load_model_path'] = os.path.join(model_path_cfg['save_model_path'], 'merged')

# The router is constructed from a YAML path, not from the in-memory dict, so the
# override above must be written out; passing CONFIG_PATH would silently ignore it.
# NOTE(review): assumes relative paths inside the config resolve against the working
# directory (PROJECT_ROOT) rather than the YAML file's own location — confirm.
os.makedirs(os.path.dirname(INFERENCE_CONFIG_PATH), exist_ok=True)
with open(INFERENCE_CONFIG_PATH, 'w', encoding='utf-8') as f:
    yaml.safe_dump(config, f, sort_keys=False)

router = CausalLMRouter(yaml_path=INFERENCE_CONFIG_PATH)
print(f"Router loaded with {len(router.llm_data)} LLM candidates")

## 3. Query Routing

In [None]:
EXAMPLE_QUERIES = [
    {"query": "What is the capital of France?"},
    {"query": "Solve the equation: 2x + 5 = 15"},
    {"query": "Write a Python function to check if a number is prime."},
    {"query": "Explain the difference between supervised and unsupervised learning."},
    {"query": "What are the main causes of climate change?"},
]

# Maximum number of characters of each query shown in the printout.
PREVIEW_CHARS = 55

print("Routing Results:")
print("=" * 70)

# Route each example individually and show which candidate LLM it was sent to.
for i, query in enumerate(EXAMPLE_QUERIES, 1):
    result = router.route_single(query)
    text = query['query']
    # Only mark the preview with "..." when it was actually truncated.
    preview = text if len(text) <= PREVIEW_CHARS else text[:PREVIEW_CHARS] + "..."
    print(f"{i}. {preview}")
    print(f"   Routed to: {result['model_name']}")

## 4. Checking vLLM Availability for Fast Inference (Optional)

In [None]:
# Optional dependency probe: vLLM can accelerate inference for CausalLM models.
# Install with: pip install vllm
try:
    from vllm import LLM, SamplingParams
except ImportError:
    vllm_available = False
else:
    vllm_available = True

if vllm_available:
    print("vLLM is available for accelerated inference")
else:
    print("vLLM not installed. Using standard HuggingFace inference.")
    print("For faster inference, install vLLM: pip install vllm")

## 5. File-Based Inference

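Load queries from a JSONL file (one JSON object per line, e.g. `{"query": "..."}`), route them as a batch, and save the routing results to a JSONL file.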

In [None]:
import json

# Load queries from a JSONL file
def load_queries_from_file(file_path):
    """Load queries from a JSONL file."""
    queries = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                queries.append(json.loads(line))
    return queries

def save_results_to_file(results, output_path):
    """Write routing results to a JSON Lines file.

    Args:
        results: Iterable of JSON-serialisable objects; each is written on its own line
            (non-ASCII characters are kept as-is, not escaped).
        output_path: Destination file path. Its parent directory is created if it
            does not exist; an existing file is overwritten.
    """
    # os.path.dirname returns '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when the path actually has one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
    print(f"Results saved to: {output_path}")

# Example: route queries from the default query file and save the results.
QUERY_FILE = "data/example_data/query_data/default_query_test.jsonl"
OUTPUT_FILE = "outputs/causallm_router_results.jsonl"
# Route only the first N queries to keep the demo quick; set to None to route all.
MAX_FILE_QUERIES = 10
# How many results to preview, and how many characters of each query to show.
N_SAMPLE_RESULTS = 3
PREVIEW_CHARS = 40

if os.path.exists(QUERY_FILE):
    # Load queries
    file_queries = load_queries_from_file(QUERY_FILE)
    print(f"Loaded {len(file_queries)} queries from: {QUERY_FILE}")

    # Route queries (slicing with None keeps the whole list)
    file_results = router.route_batch(batch=file_queries[:MAX_FILE_QUERIES])
    print(f"Routed {len(file_results)} queries")

    # Save results
    save_results_to_file(file_results, OUTPUT_FILE)

    # Show sample results; `or ''` also covers a present-but-null 'query' field.
    print("\nSample results:")
    for i, result in enumerate(file_results[:N_SAMPLE_RESULTS], 1):
        text = result.get('query') or ''
        preview = text if len(text) <= PREVIEW_CHARS else text[:PREVIEW_CHARS] + "..."
        print(f"  {i}. {preview} -> {result['model_name']}")
else:
    print(f"Query file not found: {QUERY_FILE}")
    print("Create a JSONL file with format: {\"query\": \"Your question\"}")

## Summary

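This notebook demonstrated:
1. Loading a trained CausalLMRouter
2. Routing individual queries using LLM-based inference
3. Checking whether vLLM is installed for accelerated inference
4. Batch-routing queries loaded from a JSONL file and saving the results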

CausalLMRouter is effective for:
- Complex queries requiring deep semantic understanding
- High-quality routing with LLM reasoning

**Tips for Production**:
- Use vLLM for faster inference
- Consider quantization (4-bit, 8-bit) for memory efficiency
- Batch queries for better throughput