# KNNMultiRoundRouter - Training

This notebook demonstrates how to train the **KNNMultiRoundRouter**.

## Overview

KNNMultiRoundRouter extends the KNNRouter with a multi-round pipeline:
1. **Decompose**: Break down complex queries into sub-queries
2. **Route**: Use KNN to route each sub-query to the best model
3. **Execute**: Call APIs to get responses from routed models
4. **Aggregate**: Combine sub-query responses into a final answer

**Key Features**:
- KNN-based routing (same as single-round KNNRouter)
- Multi-round decomposition and aggregation
- Supports decomposition with either a local LLM (vLLM) or an API
- Configurable K value and distance metrics
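
The four stages above can be sketched end to end with toy stand-ins. Every function name below is hypothetical and none of them comes from LLMRouter: decomposition is faked with a string split, routing with a digit check in place of the KNN classifier, and execution with a formatted string in place of an API call. The sketch only shows how data flows between the stages.

```python
from typing import List

# All names below are hypothetical stand-ins, not LLMRouter functions.

def decompose(query: str) -> List[str]:
    """Toy decomposition: split on ' and ' instead of calling an LLM."""
    parts = [p.strip() for p in query.split(" and ")]
    return [p for p in parts if p]

def route(sub_query: str) -> str:
    """Toy routing rule standing in for the KNN classifier."""
    has_digit = any(ch.isdigit() for ch in sub_query)
    return "math-model" if has_digit else "general-model"

def execute(model_name: str, sub_query: str) -> str:
    """Stand-in for an API call to the routed model."""
    return f"[{model_name}] answer to: {sub_query}"

def aggregate(responses: List[str]) -> str:
    """Toy aggregation: join the partial answers line by line."""
    return "\n".join(responses)

def multi_round(query: str) -> str:
    """Decompose, route and execute each sub-query, then aggregate."""
    sub_queries = decompose(query)
    responses = [execute(route(sq), sq) for sq in sub_queries]
    return aggregate(responses)

print(multi_round("What is 2 + 2 and who wrote Hamlet?"))
```

Each sub-query is routed independently, so a single user query can end up being answered by several different models.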

## 1. Environment Setup

In [None]:
# For Google Colab: Clone repository and install dependencies
import os

if 'COLAB_GPU' in os.environ:
    !git clone https://github.com/ulab-uiuc/LLMRouter.git
    %cd LLMRouter
    !pip install -e .
    !pip install pyyaml scikit-learn

In [None]:
import os
import sys
from pathlib import Path

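# Locally the notebook sits two levels below the repository root; in Colab the
# `%cd LLMRouter` above already leaves us at the root, so only climb if needed.
cwd = Path.cwd()
PROJECT_ROOT = cwd if (cwd / "configs").is_dir() else cwd.parent.parent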
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

In [None]:
from llmrouter.models.knnmultiroundrouter import KNNMultiRoundRouter, KNNMultiRoundTrainer
from llmrouter.utils import setup_environment

setup_environment()
print("Environment setup complete!")

## 2. Configuration

KNNMultiRoundRouter uses the following configuration parameters:

### KNN Parameters

| Parameter | Description | Default |
|-----------|-------------|--------|
| `n_neighbors` | Number of neighbors (K value) | 5 |
| `weights` | Weight function: "uniform" or "distance" | "uniform" |
| `algorithm` | KNN algorithm: "auto", "ball_tree", "kd_tree", "brute" | "auto" |
| `metric` | Distance metric | "minkowski" |
| `p` | Power for Minkowski (1=Manhattan, 2=Euclidean) | 2 |
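
To make the `p` row concrete: with `metric="minkowski"`, the distance between two vectors is the p-th root of the sum of absolute coordinate differences raised to the power p. The helper below is a plain-Python illustration written for this notebook (the name `minkowski` is ours, not a library call).

```python
def minkowski(a, b, p=2):
    """Minkowski distance of order p between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)

u, v = [0.0, 0.0], [3.0, 4.0]
print(minkowski(u, v, p=1))  # Manhattan: |3| + |4|
print(minkowski(u, v, p=2))  # Euclidean: sqrt(3^2 + 4^2)
print(minkowski(u, v, p=3))  # higher orders weight large differences more
```

For this pair of points the distance shrinks as `p` grows, which is why the choice of `p` can change which training queries count as the nearest neighbors.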

### Multi-Round Parameters

| Parameter | Description | Default |
|-----------|-------------|--------|
| `base_model` | LLM for decomposition/aggregation | "Qwen/Qwen2.5-3B-Instruct" |
| `use_local_llm` | Use vLLM for local inference | false |
| `api_endpoint` | API endpoint for execution | - |
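
Later cells read nested entries such as `config['hparam']['n_neighbors']` and `config['model_path']['save_model_path']`. A small lookup helper, sketched below, can turn a missing entry into an error that names the full path. The helper name `get_nested` and the sample values are illustrative assumptions; the real YAML file may contain more keys or a different layout.

```python
def get_nested(config, *keys):
    """Walk nested dict keys, raising a KeyError that names the full path."""
    node = config
    for depth, key in enumerate(keys):
        if not isinstance(node, dict) or key not in node:
            path = ".".join(keys[: depth + 1])
            raise KeyError(f"missing config entry: {path}")
        node = node[key]
    return node

# Hypothetical config layout with made-up values, for illustration only.
sample_config = {
    "hparam": {"n_neighbors": 5, "weights": "uniform", "metric": "minkowski", "p": 2},
    "model_path": {"save_model_path": "saved_models/knn.pkl"},
}

print(get_nested(sample_config, "hparam", "n_neighbors"))
```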

In [None]:
import yaml

CONFIG_PATH = "configs/model_config_train/knnmultiroundrouter.yaml"

with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

print("Current Configuration:")
print("=" * 50)
print(yaml.dump(config, default_flow_style=False))

## 3. Initialize Router

In [None]:
router = KNNMultiRoundRouter(yaml_path=CONFIG_PATH)

print("Router initialized successfully!")
print(f"Number of training samples: {len(router.routing_data_train)}")
print(f"Number of LLM candidates: {len(router.llm_data)}")
print(f"K value (n_neighbors): {config['hparam']['n_neighbors']}")

## 4. Training

Training fits the KNN classifier on the training-query embeddings, using the best LLM for each query as its label.
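
As a rough picture of what this fitting step amounts to, the sketch below trains scikit-learn's `KNeighborsClassifier` on toy data. The 2-D vectors and the labels `model-a` / `model-b` are made up for illustration; real embeddings are higher-dimensional, and the trainer's internals may differ from this.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy 2-D "embeddings" and made-up model labels, for illustration only.
X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
              [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])
y = np.array(["model-a", "model-a", "model-a",
              "model-b", "model-b", "model-b"])

# Same parameter names as in the configuration table above.
knn = KNeighborsClassifier(n_neighbors=3, weights="uniform",
                           metric="minkowski", p=2)
knn.fit(X, y)

# A new query is assigned the majority label among its 3 nearest neighbors.
new_queries = np.array([[0.05, 0.05], [5.0, 5.1]])
print(knn.predict(new_queries))
```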

In [None]:
DEVICE = 'cpu'
trainer = KNNMultiRoundTrainer(router=router, device=DEVICE)

print("Trainer initialized!")
print(f"Using device: {DEVICE}")

In [None]:
print("Starting training...")
print("=" * 50)

trainer.train()

print("=" * 50)
print("Training completed!")

## 5. Hyperparameter Tuning

In [None]:
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
import numpy as np

# Prepare data for hyperparameter tuning
# (convert to a NumPy array if the embeddings expose .numpy(), e.g. a tensor)
embeddings = router.query_embedding_train
X_train = embeddings.numpy() if hasattr(embeddings, 'numpy') else embeddings
y_train = router.best_llm_train

# Test different K values
k_values = [1, 3, 5, 7, 9, 11]
print("K Value Comparison (5-fold CV):")
print("=" * 40)

best_k = 5
best_score = 0

for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')
    mean_score = scores.mean()
    print(f"K={k}: Accuracy = {mean_score:.4f} (+/- {scores.std():.4f})")
    
    if mean_score > best_score:
        best_score = mean_score
        best_k = k

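print(f"\nBest K: {best_k} with accuracy: {best_score:.4f}")
# Note: this search does not change the trained router. To use best_k, set
# `n_neighbors` in the YAML config and retrain.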

In [None]:
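# Test different distance metrics.
# `p` is only passed for 'minkowski'; the values next to 'euclidean' and
# 'manhattan' just record the equivalent Minkowski order.
metrics = [('euclidean', 2), ('manhattan', 1), ('minkowski', 3)]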

print("\nDistance Metric Comparison:")
print("=" * 40)

for metric_name, p in metrics:
    if metric_name == 'minkowski':
        knn = KNeighborsClassifier(n_neighbors=best_k, metric='minkowski', p=p)
        name = f"Minkowski (p={p})"
    else:
        knn = KNeighborsClassifier(n_neighbors=best_k, metric=metric_name)
        name = metric_name.capitalize()
    
    scores = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')
    print(f"{name}: Accuracy = {scores.mean():.4f} (+/- {scores.std():.4f})")

## 6. Model Verification

In [None]:
# Test routing on a sample query (without execution)
# Note: Multi-round routers perform decomposition + routing, not just routing

test_queries = [
    {"query": "What is the capital of France?"},
    {"query": "Explain machine learning in simple terms."},
    {"query": "Calculate the integral of x^2."},
]

print("Test Routing (KNN-based):")
print("=" * 60)

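for i, query in enumerate(test_queries, 1):
    # Call the KNN routing step directly (a private method), skipping
    # decomposition, execution, and aggregation
    text = query['query']
    result = router._route_sub_query(text)
    # Only append "..." when the query is actually truncated
    preview = text if len(text) <= 50 else text[:50] + "..."
    print(f"{i}. {preview}")
    print(f"   Routed to: {result}")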

## 7. Save Model

In [None]:
import pickle

save_path = config['model_path']['save_model_path']
save_dir = os.path.dirname(save_path)
if save_dir:  # os.makedirs('') raises if the path has no directory part
    os.makedirs(save_dir, exist_ok=True)

with open(save_path, 'wb') as f:
    pickle.dump(router.model, f)

print(f"Model saved to: {save_path}")
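
A quick round-trip check can confirm that a pickled file loads back to an equal object. The sketch below uses a plain dict as a stand-in for the trained model and a temporary directory instead of the configured save path, so both are assumptions for illustration. Only unpickle files from sources you trust, since `pickle.load` can execute arbitrary code.

```python
import os
import pickle
import tempfile

# A plain dict stands in for the trained model object.
model = {"n_neighbors": 5, "labels": ["model-a", "model-b"]}

with tempfile.TemporaryDirectory() as tmp_dir:
    save_path = os.path.join(tmp_dir, "models", "knn.pkl")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Write the object, then read it back from the same file.
    with open(save_path, "wb") as f:
        pickle.dump(model, f)
    with open(save_path, "rb") as f:
        restored = pickle.load(f)

print(restored == model)
```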

## Summary

In this notebook, we:

1. **Loaded Configuration**: Set up KNNMultiRoundRouter with YAML config
2. **Trained Model**: Fitted KNN classifier on query embeddings
3. **Tuned Hyperparameters**: Tested different K values and distance metrics
4. **Verified Model**: Tested routing on sample queries
5. **Saved Model**: Persisted trained model for inference

**Key Differences from Single-Round KNNRouter**:
- Supports query decomposition into sub-queries
- Aggregates responses from multiple models
- Uses LLM for decomposition and aggregation steps

**Next Steps**:
- Use `02_knnmultiroundrouter_inference.ipynb` for full pipeline inference