# KNNMultiRoundRouter - Inference

This notebook demonstrates how to use a trained **KNNMultiRoundRouter** for inference.

## Pipeline Overview

The multi-round routing pipeline consists of:

1. **Decompose**: Break complex queries into simpler sub-queries using an LLM
2. **Route**: Use the trained KNN model to route each sub-query to the best model
3. **Execute**: Call the selected model API to get responses
4. **Aggregate**: Combine all sub-responses into a final answer
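
The four stages above can be sketched as a plain function that chains them in order. This is a minimal illustration, not the `KNNMultiRoundRouter` implementation: `run_multi_round` and the `toy_*` stand-ins are hypothetical names, and a real pipeline would replace them with an LLM call, the trained KNN model, and model API calls.

```python
from typing import Callable, Dict, List


def run_multi_round(
    query: str,
    decompose: Callable[[str], List[str]],
    route: Callable[[str], str],
    execute: Callable[[str, str], str],
    aggregate: Callable[[str, List[str]], str],
) -> Dict[str, object]:
    """Chain the four stages and keep the intermediate results for inspection."""
    sub_queries = decompose(query)                           # 1. Decompose
    models = [route(sq) for sq in sub_queries]               # 2. Route
    answers = [execute(m, sq) for m, sq in zip(models, sub_queries)]  # 3. Execute
    final = aggregate(query, answers)                        # 4. Aggregate
    return {"sub_queries": sub_queries, "models": models, "response": final}


# Hypothetical stand-ins so the sketch runs without any API access.
def toy_decompose(query: str) -> List[str]:
    return [part.strip() + "?" for part in query.rstrip("?").split(" and ")]


def toy_route(sub_query: str) -> str:
    return "math-model" if any(c.isdigit() for c in sub_query) else "general-model"


def toy_execute(model: str, sub_query: str) -> str:
    return f"[{model}] answer to: {sub_query}"


def toy_aggregate(query: str, answers: List[str]) -> str:
    return " ".join(answers)


trace = run_multi_round(
    "What is the capital of France and what is 15 * 23?",
    toy_decompose, toy_route, toy_execute, toy_aggregate,
)
print(trace["sub_queries"])
print(trace["models"])
```

Passing the stages in as callables keeps each one replaceable, which mirrors how the router treats decomposition, routing, execution, and aggregation as separate steps.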

## 1. Environment Setup

In [None]:
# For Google Colab
import os

if 'COLAB_GPU' in os.environ:
    !git clone https://github.com/ulab-uiuc/LLMRouter.git
    %cd LLMRouter
    !pip install -e .
    !pip install pyyaml scikit-learn

In [None]:
import os
import sys
from pathlib import Path

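# Locally this notebook is assumed to sit two levels below the repository root.
# In Colab the setup cell has already changed into the cloned repository.
PROJECT_ROOT = Path.cwd() if 'COLAB_GPU' in os.environ else Path.cwd().parent.parent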
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)

In [None]:
from llmrouter.models.knnmultiroundrouter import KNNMultiRoundRouter
from llmrouter.utils import setup_environment
import yaml

setup_environment()

## 2. Load Trained Router

In [None]:
CONFIG_PATH = "configs/model_config_train/knnmultiroundrouter.yaml"

router = KNNMultiRoundRouter(yaml_path=CONFIG_PATH)
print("Router loaded!")

# Load configuration
with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

print(f"Base model for decomposition: {config.get('base_model', 'Qwen/Qwen2.5-3B-Instruct')}")
print(f"Use local LLM: {config.get('use_local_llm', False)}")

## 3. Simple Query Routing (Chat Mode)

When you pass a plain string, `route_single` returns only the response text.

In [None]:
# Simple chat mode - pass string, get string response
simple_query = "What is the capital of France and what is its population?"

print(f"Query: {simple_query}")
print("=" * 60)

try:
    response = router.route_single(simple_query)
    print(f"\nResponse:\n{response}")
except Exception as e:
    print(f"Error: {e}")
    print("Note: Multi-round routing requires API access for execution.")

## 4. Evaluation Mode

To get metrics alongside the response, pass a dict with the keys `query`, `task_name`, and `ground_truth`.

In [None]:
# Evaluation mode - pass dict, get detailed result with metrics
eval_query = {
    "query": "What is 15 * 23?",
    "task_name": "math",
    "ground_truth": "345"
}

print(f"Query: {eval_query['query']}")
print(f"Task: {eval_query['task_name']}")
print(f"Ground Truth: {eval_query['ground_truth']}")
print("=" * 60)

try:
    result = router.route_single(eval_query)
    
    print(f"\nResponse: {result.get('response', 'N/A')}")
    print(f"Success: {result.get('success', False)}")
    print(f"Prompt Tokens: {result.get('prompt_tokens', 0)}")
    print(f"Completion Tokens: {result.get('completion_tokens', 0)}")
    if 'task_performance' in result:
        print(f"Task Performance: {result['task_performance']:.2f}")
except Exception as e:
    print(f"Error: {e}")
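
How `task_performance` is computed is not shown in this notebook. As a rough illustration of what a score for the `math` task could look like, the sketch below compares the last number in a response with the ground truth. `last_number_match` is a hypothetical helper, not part of LLMRouter, and the scoring rule is an assumption.

```python
import re


def last_number_match(response: str, ground_truth: str) -> float:
    """Return 1.0 if the last number in the response equals the ground truth, else 0.0."""
    # Drop thousands separators, then collect every integer or decimal in the text.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", response.replace(",", ""))
    if not numbers:
        return 0.0
    try:
        return 1.0 if float(numbers[-1]) == float(ground_truth) else 0.0
    except ValueError:
        # The ground truth is not numeric, so this rule cannot score it.
        return 0.0


print(last_number_match("15 * 23 = 345", "345"))
print(last_number_match("The answer is 3,450.", "345"))
```

Taking the last number rather than the first keeps the operands quoted in a worked answer from being mistaken for the result.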

## 5. Batch Processing

In [None]:
# Batch processing with multiple queries
batch_queries = [
    {"query": "Explain photosynthesis."},
    {"query": "What causes earthquakes?"},
    {"query": "How do computers work?"},
]

print(f"Processing {len(batch_queries)} queries...")
print("=" * 60)

try:
    results = router.route_batch(batch_queries)
    
    for i, result in enumerate(results, 1):
        print(f"\n{i}. Query: {result.get('query', 'N/A')[:50]}...")
        print(f"   Success: {result.get('success', False)}")
        print(f"   Response: {result.get('response', 'N/A')[:100]}...")
except Exception as e:
    print(f"Error: {e}")
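
For larger batches, an aggregate view is often more useful than printing each result. The sketch below summarises a list of result dicts. `summarize_results` is a hypothetical helper, and it assumes batch results carry the same `success`, `prompt_tokens`, and `completion_tokens` keys shown in evaluation mode; missing keys are treated as failure or zero.

```python
from typing import Dict, Iterable, Union


def summarize_results(results: Iterable[Dict]) -> Dict[str, Union[int, float]]:
    """Count successes and total token usage across a batch of result dicts."""
    results = list(results)
    total = len(results)
    succeeded = sum(1 for r in results if r.get("success", False))
    # `or 0` guards against keys that are present but set to None.
    prompt_tokens = sum(r.get("prompt_tokens", 0) or 0 for r in results)
    completion_tokens = sum(r.get("completion_tokens", 0) or 0 for r in results)
    return {
        "total": total,
        "succeeded": succeeded,
        "success_rate": succeeded / total if total else 0.0,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
    }


# Made-up results in the shape used above, for illustration only.
sample_results = [
    {"query": "Explain photosynthesis.", "success": True,
     "prompt_tokens": 120, "completion_tokens": 80},
    {"query": "What causes earthquakes?", "success": False},
    {"query": "How do computers work?", "success": True,
     "prompt_tokens": 100, "completion_tokens": 60},
]

summary = summarize_results(sample_results)
print(summary)
```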

## 6. Understanding the Pipeline

The cells below print the configuration behind each pipeline stage and list the candidate models available for routing.

In [None]:
# Demonstrate the pipeline components
print("Multi-Round Pipeline Components:")
print("=" * 60)

print("\n1. DECOMPOSITION")
print("   - Uses an LLM to break a complex query into sub-queries")
print(f"   - Base model: {config.get('base_model', 'Qwen/Qwen2.5-3B-Instruct')}")

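print("\n2. ROUTING (KNN-based)")
# Use .get so a config that omits a hyperparameter does not raise KeyError
hparam = config.get('hparam', {})
print(f"   - K value: {hparam.get('n_neighbors', 'Not configured')}")
print(f"   - Distance metric: {hparam.get('metric', 'Not configured')}")
print(f"   - Weight function: {hparam.get('weights', 'Not configured')}")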

print("\n3. EXECUTION")
print("   - Calls routed model API for each sub-query")
print(f"   - API endpoint: {config.get('api_endpoint', 'Not configured')}")

print("\n4. AGGREGATION")
print("   - Combines sub-query responses into final answer")
print("   - Uses LLM for intelligent synthesis")
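
To make the routing step concrete, here is a dependency-free sketch of a KNN vote over query embeddings. It is not the router's actual code: `knn_route`, the two-dimensional embeddings, and the model names are made up for illustration. The `n_neighbors` and `weights` arguments mirror the hyperparameter names printed above, and cosine distance stands in for whatever `metric` the config specifies.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity; returns 1.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm if norm else 1.0


def knn_route(
    query_vec: Sequence[float],
    train: List[Tuple[Sequence[float], str]],
    n_neighbors: int = 3,
    weights: str = "uniform",
) -> str:
    """Pick the model that wins the vote among the nearest training queries."""
    scored = [(cosine_distance(query_vec, vec), model) for vec, model in train]
    neighbors = sorted(scored, key=lambda pair: pair[0])[:n_neighbors]
    votes: Dict[str, float] = defaultdict(float)
    for dist, model in neighbors:
        # "uniform": one vote each; "distance": closer neighbors count more.
        votes[model] += 1.0 if weights == "uniform" else 1.0 / (dist + 1e-9)
    return max(votes, key=votes.get)


# Hypothetical training set: (query embedding, best-performing model).
train_set = [
    ([1.0, 0.0], "math-model"),
    ([0.9, 0.1], "math-model"),
    ([0.0, 1.0], "general-model"),
    ([0.1, 0.9], "general-model"),
    ([0.2, 0.8], "general-model"),
]

print(knn_route([0.95, 0.05], train_set, n_neighbors=3))
print(knn_route([0.05, 0.95], train_set, n_neighbors=3, weights="distance"))
```

With `weights="uniform"` every neighbor gets one vote, so a larger `n_neighbors` smooths the decision; with `weights="distance"` a single very close neighbor can outweigh several distant ones.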

In [None]:
# Show available LLM candidates for routing
print("\nAvailable LLM Candidates:")
print("=" * 60)

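for i, (name, info) in enumerate(router.llm_data.items(), 1):
    size = info.get('size')
    # Avoid printing "unknownB parameters" when a candidate has no size entry
    size_text = f"{size}B parameters" if size is not None else "size unknown"
    print(f"{i}. {name}: {size_text}")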

## 7. Evaluation

In [None]:
from llmrouter.evaluator import Evaluator

try:
    evaluator = Evaluator(router=router)
    metrics = evaluator.eval()

    print("\nEvaluation Results:")
    print("=" * 50)
    for metric_name, value in metrics.items():
        if isinstance(value, float):
            print(f"{metric_name}: {value:.4f}")
        else:
            print(f"{metric_name}: {value}")
except Exception as e:
    print(f"Evaluation failed (note that it requires API access): {e}")

## 8. File-Based Inference

Load queries from a file and save results.

In [None]:
import json

# Load queries from a JSONL file
def load_queries_from_file(file_path):
    """Load queries from a JSONL file."""
    queries = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                queries.append(json.loads(line))
    return queries

# Save results to a JSONL file
def save_results_to_file(results, output_path):
    """Save routing results to a JSONL file."""
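    # dirname is empty when output_path has no directory part, and makedirs('') raises
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)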
    with open(output_path, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
    print(f"Results saved to: {output_path}")

# Example: Load from default query file
QUERY_FILE = "data/example_data/query_data/default_query_test.jsonl"
OUTPUT_FILE = "outputs/knnmultiroundrouter_results.jsonl"

if os.path.exists(QUERY_FILE):
    # Load queries
    file_queries = load_queries_from_file(QUERY_FILE)
    print(f"Loaded {len(file_queries)} queries from: {QUERY_FILE}")
    
    # Route queries (limit to 5 for demo due to API costs)
    try:
        file_results = router.route_batch(file_queries[:5])
        print(f"Routed {len(file_results)} queries")
        
        # Save results
        save_results_to_file(file_results, OUTPUT_FILE)
        
        # Show sample results
        print("\nSample results:")
        for i, result in enumerate(file_results[:3], 1):
            print(f"  {i}. {result.get('query', '')[:40]}...")
            print(f"     Success: {result.get('success', False)}")
    except Exception as e:
        print(f"Error during batch routing: {e}")
else:
    print(f"Query file not found: {QUERY_FILE}")
    print("Create a JSONL file with format: {\"query\": \"Your question\"}")
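
The loader above stops at the first malformed line. For hand-edited query files, a more tolerant variant can skip bad lines and report where they were. The sketch below is one way to do that; `load_queries_tolerant` is a hypothetical helper, and requiring a `query` key in every record is an assumption based on the format shown above.

```python
import json
import os
import tempfile
from typing import Dict, List, Tuple


def load_queries_tolerant(file_path: str) -> Tuple[List[Dict], List[int]]:
    """Return (valid query dicts, 1-based line numbers of rejected lines)."""
    queries: List[Dict] = []
    bad_lines: List[int] = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # blank lines are ignored, not reported
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                bad_lines.append(line_no)
                continue
            if isinstance(record, dict) and "query" in record:
                queries.append(record)
            else:
                bad_lines.append(line_no)
    return queries, bad_lines


# Demo on a temporary file containing two valid and two invalid records.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "queries.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write('{"query": "Explain photosynthesis."}\n')
        f.write("\n")
        f.write("not json\n")
        f.write('{"text": "missing the query key"}\n')
        f.write('{"query": "What causes earthquakes?"}\n')
    demo_queries, demo_bad_lines = load_queries_tolerant(demo_path)

print(f"Loaded {len(demo_queries)} queries; rejected lines: {demo_bad_lines}")
```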

## Summary

**KNNMultiRoundRouter** provides:
- Query decomposition for complex questions
- KNN-based routing for each sub-query
- Parallel execution across multiple models
- Intelligent response aggregation

**Use Cases**:
- Complex questions requiring multiple expertise areas
- Multi-step reasoning tasks
- Questions that benefit from specialized models

**Requirements**:
- Trained KNN model (from the training notebook)
- API access for LLM execution
- Optional: vLLM for local decomposition/aggregation