# KNNRouter - Inference

This notebook demonstrates how to use a trained **KNNRouter** for inference.

## Overview

We will cover:
1. Loading a trained KNNRouter model
2. Single query routing
3. Routing probabilities
4. Batch query routing
5. Full inference with API calls
6. Performance evaluation
7. CLI inference

## 1. Environment Setup

In [None]:
# Install required packages (for Colab)
# !pip install llmrouter scikit-learn transformers torch

In [None]:
import os
import sys
from pathlib import Path

# Set project root
PROJECT_ROOT = Path(os.getcwd()).parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Import required modules
from llmrouter.models.knnrouter import KNNRouter
from llmrouter.utils import setup_environment, load_model, get_longformer_embedding

setup_environment()
print("Environment setup complete!")

## 2. Configuration

In [None]:
import yaml

# Start from the training configuration.
# Inference needs load_model_path, which is filled in below if it is missing.
CONFIG_PATH = "configs/model_config_train/knnrouter.yaml"

# Load configuration
with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

# Add load_model_path for inference
config['model_path']['load_model_path'] = config['model_path'].get(
    'load_model_path', 
    config['model_path']['save_model_path']
)

print("Configuration loaded!")
print(f"Model path: {config['model_path']['load_model_path']}")

In [None]:
# Create inference config file
import copy

INFERENCE_CONFIG_PATH = "configs/model_config_test/knnrouter_inference.yaml"

os.makedirs(os.path.dirname(INFERENCE_CONFIG_PATH), exist_ok=True)

# Deep copy: dict.copy() is shallow, so editing the nested 'model_path' dict would also change `config`
inference_config = copy.deepcopy(config)
inference_config['model_path']['load_model_path'] = 'saved_models/knnrouter/knnrouter.pkl'

with open(INFERENCE_CONFIG_PATH, 'w') as f:
    yaml.dump(inference_config, f)

print(f"Inference config saved to: {INFERENCE_CONFIG_PATH}")

## 3. Load Trained Router

In [None]:
# Initialize router for inference
router = KNNRouter(yaml_path=INFERENCE_CONFIG_PATH)

print("Router loaded successfully!")
print(f"Number of LLM candidates: {len(router.llm_data)}")
print(f"LLM candidates: {list(router.llm_data.keys())}")

In [None]:
# Load the trained KNN model
model_path = os.path.join(PROJECT_ROOT, inference_config['model_path']['load_model_path'])

if os.path.exists(model_path):
    knn_model = load_model(model_path)
    print(f"Loaded model from: {model_path}")
    print(f"Model classes: {knn_model.classes_}")
else:
    print(f"Model not found at: {model_path}")
    print("Please run the training notebook first!")

## 4. Single Query Routing

In [None]:
# Example queries for different task types
EXAMPLE_QUERIES = [
    {
        "query": "What is the capital of France?",
        "task_type": "world_knowledge"
    },
    {
        "query": "Solve the equation: 2x + 5 = 15",
        "task_type": "math"
    },
    {
        "query": "Write a Python function to check if a number is prime.",
        "task_type": "code"
    },
    {
        "query": "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly?",
        "task_type": "reasoning"
    },
    {
        "query": "Explain the theory of relativity in simple terms.",
        "task_type": "explanation"
    }
]

print(f"Prepared {len(EXAMPLE_QUERIES)} example queries")

In [None]:
# Route a single query
def route_single_query(query_dict):
    """Route a single query and return the result."""
    result = router.route_single(query_dict)
    return result

# Test with first example
query = EXAMPLE_QUERIES[0]
result = route_single_query(query)

print(f"Query: {query['query']}")
print(f"Task Type: {query['task_type']}")
print(f"Routed to: {result['model_name']}")

In [None]:
# Route all example queries
print("Routing Results:")
print("=" * 80)

for i, query in enumerate(EXAMPLE_QUERIES, 1):
    result = route_single_query(query)
    print(f"\n{i}. Query: {query['query'][:60]}...")
    print(f"   Task: {query['task_type']}")
    print(f"   Routed to: {result['model_name']}")

## 5. Get Routing Probabilities

In [None]:
import numpy as np

def get_routing_probabilities(query_text):
    """Get routing probabilities for all LLM candidates."""
    # Generate embedding
    embedding = get_longformer_embedding(query_text).numpy().reshape(1, -1)
    
    # Get probabilities
    proba = knn_model.predict_proba(embedding)[0]
    
    # Create results dictionary
    results = dict(zip(knn_model.classes_, proba))
    
    # Sort by probability
    results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
    
    return results

# Test with first query
query_text = EXAMPLE_QUERIES[0]['query']
probabilities = get_routing_probabilities(query_text)

print(f"Query: {query_text}")
print("\nRouting Probabilities:")
for model, prob in probabilities.items():
    bar = "#" * int(prob * 50)
    print(f"  {model:30} {prob:.4f} {bar}")

In [None]:
# Visualize routing probabilities for all queries
import matplotlib.pyplot as plt

fig, axes = plt.subplots(len(EXAMPLE_QUERIES), 1, figsize=(12, 3*len(EXAMPLE_QUERIES)))

for idx, query in enumerate(EXAMPLE_QUERIES):
    probs = get_routing_probabilities(query['query'])
    
    ax = axes[idx] if len(EXAMPLE_QUERIES) > 1 else axes
    models = list(probs.keys())
    values = list(probs.values())
    
    bars = ax.barh(models, values, color='steelblue')
    ax.set_xlim(0, 1)
    ax.set_xlabel('Probability')
    ax.set_title(f"Query {idx+1}: {query['query'][:50]}... ({query['task_type']})")
    
    # Highlight the chosen model
    max_idx = values.index(max(values))
    bars[max_idx].set_color('green')

plt.tight_layout()
plt.show()
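
The probabilities above come from `predict_proba`. As a minimal sketch, assuming the loaded model behaves like a scikit-learn `KNeighborsClassifier` with uniform weights, each probability is the fraction of the K nearest training queries labeled with that model. The toy 2-D vectors and the names `model_a` and `model_b` are made up for illustration and are not part of the trained router.

```python
import math
from collections import Counter

def knn_vote_probabilities(query_vec, train_vecs, train_labels, k=3):
    """Return {label: fraction of the k nearest neighbours carrying that label}."""
    # Euclidean distance from the query to every training point
    distances = [math.dist(query_vec, vec) for vec in train_vecs]
    # Indices of the k closest training points
    nearest = sorted(range(len(train_vecs)), key=lambda i: distances[i])[:k]
    votes = Counter(train_labels[i] for i in nearest)
    # Uniform weighting: every neighbour contributes exactly 1/k
    return {label: votes.get(label, 0) / k for label in sorted(set(train_labels))}

# Toy 2-D "embeddings" labelled with hypothetical model names
toy_train_vecs = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (1.0, 1.0), (0.9, 1.1)]
toy_train_labels = ["model_a", "model_a", "model_b", "model_b", "model_b"]

toy_probs = knn_vote_probabilities((0.05, 0.05), toy_train_vecs, toy_train_labels, k=3)
print(toy_probs)
```

Under this reading, a probability of 1.0 means all K neighbours agree on one model, and with small K the values can only take a few discrete levels (multiples of 1/K).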

## 6. Batch Query Routing

### Option 1: Route queries from configuration file

The router automatically loads test data from the path specified in `query_data_test` in the YAML config.

In [None]:
# Load test data for batch routing
if router.query_data_test is not None:
    test_data = router.query_data_test[:20]  # Use first 20 samples
    print(f"Loaded {len(test_data)} test samples")
else:
    print("No test data available. Using example queries.")
    test_data = EXAMPLE_QUERIES

In [None]:
# Batch routing (route-only, no API calls)
def batch_route_only(queries):
    """Route multiple queries without calling APIs."""
    results = []
    for query in queries:
        result = router.route_single(query)
        results.append(result)
    return results

# Route batch
batch_results = batch_route_only(test_data)

print(f"Routed {len(batch_results)} queries")

# Show routing distribution
from collections import Counter
model_counts = Counter(r['model_name'] for r in batch_results)

print("\nRouting Distribution:")
for model, count in model_counts.most_common():
    percentage = count / len(batch_results) * 100
    print(f"  {model}: {count} ({percentage:.1f}%)")

### Option 2: Load queries from your own file

You can also load queries from a custom JSONL file and pass them to the router.
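
As a small sketch of what such a file might look like, the snippet below writes two sample queries in JSONL form (one JSON object per line) to a temporary location and reads them back. The file name `my_queries.jsonl` and the sample entries are hypothetical. Only the `query` field is assumed to be required, with `task_type` included as an optional extra.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sample queries; "task_type" is treated as optional here.
sample_queries = [
    {"query": "What is the capital of France?"},
    {"query": "Solve the equation: 2x + 5 = 15", "task_type": "math"},
]

# Write one JSON object per line (JSONL) to a temporary file.
jsonl_path = Path(tempfile.mkdtemp()) / "my_queries.jsonl"
with open(jsonl_path, "w", encoding="utf-8") as f:
    for item in sample_queries:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

# Read the file back to confirm that every line parses as a JSON object.
with open(jsonl_path, "r", encoding="utf-8") as f:
    round_trip = [json.loads(line) for line in f if line.strip()]

print(f"Wrote {len(round_trip)} queries to {jsonl_path}")
```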

In [None]:
import json

# Method 1: Load from custom JSONL file
def load_queries_from_file(file_path):
    """Load queries from a JSONL file."""
    queries = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                queries.append(json.loads(line))
    return queries

# Example: Load from the default query test file
QUERY_FILE = "data/example_data/query_data/default_query_test.jsonl"

if os.path.exists(QUERY_FILE):
    file_queries = load_queries_from_file(QUERY_FILE)
    print(f"Loaded {len(file_queries)} queries from file")
    print(f"\nSample query: {file_queries[0]}")
    
    # Route queries from file (note: route_batch also calls the LLM APIs; see Section 7)
    file_results = router.route_batch(batch=file_queries[:10])
    
    print(f"\nRouted {len(file_results)} queries from file:")
    for i, result in enumerate(file_results[:3], 1):
        print(f"  {i}. Query: {result.get('query', '')[:50]}...")
        print(f"     Routed to: {result['model_name']}")
else:
    print(f"File not found: {QUERY_FILE}")
    print("You can create your own JSONL file with format:")
    print('  {"query": "Your question here"}')

In [None]:
# Method 2: Save routing results to file
def save_results_to_file(results, output_path):
    """Save routing results to a JSONL file."""
    with open(output_path, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
    print(f"Results saved to: {output_path}")

# Example: Save results
OUTPUT_FILE = "outputs/knnrouter_results.jsonl"
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)

if 'file_results' in globals() and file_results:
    save_results_to_file(file_results, OUTPUT_FILE)
    
    # Verify saved file
    print("\nVerifying saved file:")
    with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
        first_line = json.loads(f.readline())
        print(f"First result: {first_line}")

In [None]:
# Visualize routing distribution
import matplotlib.pyplot as plt

models = list(model_counts.keys())
counts = list(model_counts.values())

plt.figure(figsize=(10, 6))
plt.pie(counts, labels=models, autopct='%1.1f%%', startangle=90)
plt.title('Query Routing Distribution')
plt.axis('equal')
plt.show()

## 7. Full Inference with API Calls (Optional)

In [None]:
# Check if API keys are available
api_available = bool(
    os.environ.get('OPENAI_API_KEY') or 
    os.environ.get('ANTHROPIC_API_KEY') or
    os.environ.get('API_KEYS')
)

print(f"API keys available: {api_available}")

if not api_available:
    print("\nTo enable full inference with API calls, set one of:")
    print("  - OPENAI_API_KEY")
    print("  - ANTHROPIC_API_KEY")
    print("  - API_KEYS (JSON array of keys)")
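
The `API_KEYS` variable is described above as a JSON array of keys. The helper below is a hedged sketch of how such a value could be checked before running full inference. `parse_api_keys` is a hypothetical function written for this notebook, not part of `llmrouter`, and how the library itself parses the variable is not shown here.

```python
import json
import os

def parse_api_keys(raw):
    """Parse a JSON array of key strings; return [] if unset or malformed."""
    if not raw:
        return []
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return []
    # Accept only a list whose entries are all non-empty strings
    if isinstance(parsed, list) and all(isinstance(k, str) and k for k in parsed):
        return parsed
    return []

# Inspect the current environment without modifying it
configured_keys = parse_api_keys(os.environ.get("API_KEYS"))
print(f"API_KEYS entries found: {len(configured_keys)}")
```

A value such as `'["key-one", "key-two"]'` parses to a two-element list, while an unset variable or a plain non-JSON string yields an empty list.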

In [None]:
# Full inference (with API calls) - only if API keys are available
if api_available:
    # Use route_batch which includes API calls
    full_results = router.route_batch(batch=test_data[:5])  # Limit to 5 for demo
    
    print("Full Inference Results:")
    print("=" * 80)
    
    for result in full_results:
        print(f"\nQuery: {result['query'][:60]}...")
        print(f"Routed to: {result['model_name']}")
        print(f"Response: {str(result.get('response') or 'N/A')[:100]}...")
        print(f"Success: {result.get('success', 'N/A')}")
else:
    print("Skipping full inference - no API keys configured")

## 8. Performance Evaluation

In [None]:
# Evaluate routing accuracy on test data
# This compares the router's choice with the oracle (best performing model)

if router.routing_data_test is not None:
    test_df = router.routing_data_test
    
    # Find best model for each query (oracle)
    oracle_best = test_df.loc[
        test_df.groupby('query')['performance'].idxmax()
    ][['query', 'model_name', 'performance']]
    oracle_best.columns = ['query', 'oracle_model', 'oracle_performance']
    
    print(f"Test set: {len(oracle_best)} unique queries")
    print("\nOracle model distribution:")
    print(oracle_best['oracle_model'].value_counts())
else:
    print("No test routing data available for evaluation")

In [None]:
# Compare router predictions with oracle
if router.routing_data_test is not None:
    from tqdm import tqdm
    
    correct = 0
    total = 0
    results_comparison = []
    
    for _, row in tqdm(oracle_best.iterrows(), total=len(oracle_best), desc="Evaluating"):
        query = row['query']
        oracle_model = row['oracle_model']
        
        # Get router prediction
        result = router.route_single({'query': query})
        predicted_model = result['model_name']
        
        is_correct = predicted_model == oracle_model
        correct += int(is_correct)
        total += 1
        
        results_comparison.append({
            'query': query[:50],
            'oracle': oracle_model,
            'predicted': predicted_model,
            'correct': is_correct
        })
    
    accuracy = correct / total if total > 0 else 0
    print(f"\nRouting Accuracy: {accuracy:.4f} ({correct}/{total})")
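
To make the oracle comparison concrete, here is a toy version of the same calculation in plain Python. All queries, model names, scores, and predictions are made up. The sketch picks the best-performing model per query as the oracle, then reports both routing accuracy and the average performance the router achieves relative to the oracle.

```python
# Toy routing records: (query, model_name, performance); all values are made up.
toy_records = [
    ("q1", "model_a", 0.9), ("q1", "model_b", 0.4),
    ("q2", "model_a", 0.2), ("q2", "model_b", 0.7),
    ("q3", "model_a", 0.6), ("q3", "model_b", 0.5),
]

# Lookup table of performance per (query, model)
toy_perf = {(q, m): p for q, m, p in toy_records}

# Oracle: for each query, keep the model with the highest performance.
toy_oracle = {}
for q, m, p in toy_records:
    if q not in toy_oracle or p > toy_perf[(q, toy_oracle[q])]:
        toy_oracle[q] = m

# Hypothetical router predictions for the same queries
toy_predictions = {"q1": "model_a", "q2": "model_a", "q3": "model_a"}

toy_correct = sum(toy_predictions[q] == toy_oracle[q] for q in toy_oracle)
toy_accuracy = toy_correct / len(toy_oracle)

# Average performance achieved by the router versus the oracle
toy_router_perf = sum(toy_perf[(q, toy_predictions[q])] for q in toy_oracle) / len(toy_oracle)
toy_oracle_perf = sum(toy_perf[(q, toy_oracle[q])] for q in toy_oracle) / len(toy_oracle)

print(f"Toy routing accuracy: {toy_accuracy:.4f} ({toy_correct}/{len(toy_oracle)})")
print(f"Router avg performance: {toy_router_perf:.4f}, oracle: {toy_oracle_perf:.4f}")
```

Exact-match accuracy treats every miss equally, so a router that picks a near-best model is scored the same as one that picks the worst. Comparing average performance against the oracle shows how costly the misses actually are.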

In [None]:
# Show confusion matrix
if router.routing_data_test is not None:
    import pandas as pd
    from sklearn.metrics import confusion_matrix, classification_report
    
    comparison_df = pd.DataFrame(results_comparison)
    
    # Classification report
    print("Classification Report:")
    print(classification_report(comparison_df['oracle'], comparison_df['predicted']))
    
    # Confusion matrix
    labels = sorted(set(comparison_df['oracle']) | set(comparison_df['predicted']))
    cm = confusion_matrix(comparison_df['oracle'], comparison_df['predicted'], labels=labels)
    
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    plt.xticks(range(len(labels)), labels, rotation=45, ha='right')
    plt.yticks(range(len(labels)), labels)
    plt.xlabel('Predicted')
    plt.ylabel('Oracle')
    plt.tight_layout()
    plt.show()

## 9. Using CLI for Inference

In [None]:
# You can also use the CLI for inference
print("CLI Commands for Inference:")
print("="*60)
print()
print("# Route a single query (route-only, no API call):")
print('llmrouter infer --router knnrouter --config configs/model_config_test/knnrouter_inference.yaml --query "What is AI?" --route-only')
print()
print("# Route with full inference (API call):")
print('llmrouter infer --router knnrouter --config configs/model_config_test/knnrouter_inference.yaml --query "What is AI?"')
print()
print("# Batch inference from file:")
print('llmrouter infer --router knnrouter --config configs/model_config_test/knnrouter_inference.yaml --input queries.txt --output results.json')

## Summary

In this notebook, we:

1. **Loaded Trained Model**: Set up KNNRouter with a trained model
2. **Single Query Routing**: Routed individual queries to LLMs
3. **Routing Probabilities**: Inspected per-model routing probabilities
4. **Batch Routing**: Routed multiple queries and summarized the distribution
5. **Full Inference**: Called LLM APIs (when API keys are available)
6. **Performance Evaluation**: Compared router choices with the oracle (best-performing model)
7. **CLI Inference**: Listed the equivalent `llmrouter infer` commands

**Key Findings**:
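- KNNRouter's routing decisions are interpretable: each follows from the labels of the nearest training queries
- Routing probabilities indicate how strongly those neighbors agree on a model
- Routing accuracy can be tuned via the number of neighbors (K) and the distance metric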

**Next Steps**:
- Try different routers (SVMRouter, MLPRouter, etc.)
- Experiment with ensemble routing
- Deploy as API service