# GraphRouter - Training

This notebook demonstrates how to train the **GraphRouter** (Graph Neural Network Router).

## Overview

GraphRouter uses a Graph Neural Network (GNN) to model the relationships between queries and LLMs.
It constructs a heterogeneous graph where queries and LLMs are nodes, and performance scores are edge weights.
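
As a rough illustration of the graph described above, the sketch below builds a tiny query-LLM graph in plain Python. The `QueryLLMGraph` class, the model names, and the scores are hypothetical and are not part of `llmrouter`; the library's own graph construction (for example with `torch-geometric`) may look quite different.

```python
from dataclasses import dataclass, field


@dataclass
class QueryLLMGraph:
    """Toy bipartite graph: query nodes on one side, LLM nodes on the other."""

    queries: list = field(default_factory=list)
    llms: list = field(default_factory=list)
    # Maps (query_index, llm_index) to the performance score used as edge weight.
    edges: dict = field(default_factory=dict)

    def add_score(self, query: str, llm: str, score: float) -> None:
        """Register both nodes if needed, then store the weighted edge."""
        if query not in self.queries:
            self.queries.append(query)
        if llm not in self.llms:
            self.llms.append(llm)
        key = (self.queries.index(query), self.llms.index(llm))
        self.edges[key] = score


# Hypothetical records: model names and scores are made up for illustration.
records = [
    ("What is 2 + 2?", "small-model", 0.9),
    ("What is 2 + 2?", "large-model", 1.0),
    ("Prove the theorem.", "small-model", 0.2),
    ("Prove the theorem.", "large-model", 0.8),
]

graph = QueryLLMGraph()
for query, llm, score in records:
    graph.add_score(query, llm, score)

print(f"{len(graph.queries)} query nodes, {len(graph.llms)} LLM nodes, "
      f"{len(graph.edges)} weighted edges")
```

Each query node connects only to LLM nodes, never to another query, which is why the structure is bipartite.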

**Key Features**:
- Graph-based representation of query-LLM interactions
- Message passing to learn query and LLM node representations
- Ability to capture complex relational patterns
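
To make "message passing" concrete, here is a minimal sketch of one propagation round from LLM nodes to query nodes. It uses scalar node features, a score-weighted mean as the aggregator, and a fixed 0.5 mixing weight; all of these are simplifying assumptions. A real GNN layer, including whatever GraphRouter uses internally, would apply learned weight matrices and nonlinearities to vector features.

```python
# Scalar features keep the arithmetic easy to follow; real GNNs use vectors.
llm_feat = {"small-model": 1.0, "large-model": 3.0}
query_feat = {"q0": 0.0, "q1": 0.0}

# Hypothetical edge weights (performance scores) for (query, llm) pairs.
edges = {
    ("q0", "small-model"): 0.9,
    ("q0", "large-model"): 0.1,
    ("q1", "small-model"): 0.5,
    ("q1", "large-model"): 0.5,
}


def propagate(query_feat, llm_feat, edges, self_weight=0.5):
    """Run one message-passing round from LLM nodes to query nodes."""
    updated = {}
    for query, own in query_feat.items():
        # Collect the neighbours of this query and their edge weights.
        incident = [(llm, w) for (src, llm), w in edges.items() if src == query]
        total = sum(w for _, w in incident)
        if total == 0:
            updated[query] = own  # isolated node: nothing to aggregate
            continue
        # Aggregate: mean of neighbour features, weighted by edge score.
        message = sum(w * llm_feat[llm] for llm, w in incident) / total
        # Update: blend the node's own feature with the aggregated message.
        updated[query] = self_weight * own + (1 - self_weight) * message
    return updated


updated = propagate(query_feat, llm_feat, edges)
print(updated)
```

After one round, `q0` sits closer to `small-model` because that edge carries most of its weight, while `q1` lands halfway between the two models.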

## 1. Environment Setup

In [None]:
# Install required packages (for Colab)
# !pip install llmrouter torch torch-geometric

In [None]:
import os
import sys
from pathlib import Path

# Assumes this notebook lives two directory levels below the project root
PROJECT_ROOT = Path(os.getcwd()).parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

In [None]:
import torch
from llmrouter.models.graphrouter import GraphRouter, GraphTrainer
from llmrouter.utils import setup_environment

setup_environment()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 2. Configuration

GraphRouter uses the following configuration parameters:

| Parameter | Description | Default |
|-----------|-------------|---------|
| `hidden_dim` | GNN hidden layer dimension | 64 |
| `learning_rate` | Optimizer learning rate | 0.001 |
| `weight_decay` | Weight decay (L2 regularization) coefficient | 0.0001 |
| `train_epoch` | Number of training epochs | 100 |
| `batch_size` | Batch size | 4 |
| `train_mask_rate` | Fraction of edges masked during training | 0.3 |
| `val_split_ratio` | Fraction of data held out for validation | 0.2 |
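
The two ratio parameters are easiest to understand on a toy edge list. The sketch below holds out a validation share of edges, then masks a share of the remaining training edges so they can serve as prediction targets. The function `split_and_mask` is hypothetical, and splitting per edge with a single fixed mask is an assumption; `llmrouter` may split per query or resample the mask during training.

```python
import random


def split_and_mask(edges, val_split_ratio=0.2, train_mask_rate=0.3, seed=0):
    """Split edges into visible training, masked training, and validation sets."""
    rng = random.Random(seed)  # local RNG keeps the split reproducible
    shuffled = list(edges)
    rng.shuffle(shuffled)

    # Hold out a fraction of all edges for validation.
    n_val = round(len(shuffled) * val_split_ratio)
    val_edges = shuffled[:n_val]
    train_edges = shuffled[n_val:]

    # Hide a fraction of training edges; the model must predict their scores
    # from the edges that stay visible in the graph.
    n_masked = round(len(train_edges) * train_mask_rate)
    masked = train_edges[:n_masked]
    visible = train_edges[n_masked:]
    return visible, masked, val_edges


# 10 hypothetical queries x 5 hypothetical LLMs = 50 (query, llm) edges.
edges = [(q, m) for q in range(10) for m in range(5)]
visible, masked, val_edges = split_and_mask(edges)
print(f"visible={len(visible)} masked={len(masked)} validation={len(val_edges)}")
```

With the default values from the table, 50 edges break down into 10 validation edges, 12 masked training edges, and 28 visible training edges.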

In [None]:
import yaml

CONFIG_PATH = "configs/model_config_train/graphrouter.yaml"

with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

print("Current Configuration:")
print("=" * 50)
print(yaml.dump(config, default_flow_style=False))

## 3. Initialize Router

In [None]:
router = GraphRouter(yaml_path=CONFIG_PATH)

print("Router initialized successfully!")
print(f"Number of training samples: {len(router.routing_data_train)}")
print(f"Number of LLM candidates: {len(router.llm_data)}")
print(f"LLM candidates: {list(router.llm_data.keys())}")

## 4. Graph Structure Overview

In [None]:
# Understand the graph structure
print("Graph Structure Information:")
print("=" * 50)
print("\nNode types:")
print("  - Query nodes: Based on training queries")
print(f"  - LLM nodes: {len(router.llm_data)} models")
print("\nEdge types:")
print("  - Query -> LLM edges (performance scores)")
print("\nThe GNN learns to predict missing edges for new queries.")
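
Predicting missing edges for a new query amounts to scoring each candidate query-LLM edge and routing to the highest-scoring LLM. The sketch below illustrates that idea with a dot product between embeddings. The embeddings, the model names, and the dot-product scorer are all assumptions made for illustration; GraphRouter's actual edge predictor may be different.

```python
def score(query_emb, llm_emb):
    """Predicted edge weight: dot product of the two node embeddings."""
    return sum(q * m for q, m in zip(query_emb, llm_emb))


def route(query_emb, llm_embs):
    """Return the name of the LLM whose predicted edge score is highest."""
    return max(llm_embs, key=lambda name: score(query_emb, llm_embs[name]))


# Hypothetical embeddings, standing in for what a trained GNN would produce.
llm_embs = {
    "small-model": [1.0, 0.0],
    "large-model": [0.0, 1.0],
}
new_query_emb = [0.2, 0.7]

for name, emb in llm_embs.items():
    print(f"{name}: predicted score {score(new_query_emb, emb):.2f}")
print(f"Routed to: {route(new_query_emb, llm_embs)}")
```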

## 5. Training

In [None]:
trainer = GraphTrainer(router=router, device=device)

print("Trainer initialized!")
print(f"Device: {device}")
print(f"Save path: {trainer.save_model_path}")

In [None]:
print("Starting training...")
print("=" * 50)

best_result = trainer.train()

print("=" * 50)
print("Training completed!")
if best_result is not None:  # a falsy result such as 0.0 should still be reported
    print(f"Best validation result: {best_result}")

## 6. Model Verification

In [None]:
# Verify that the trained checkpoint exists and can be loaded (on CPU, so no GPU is needed)
model_path = trainer.save_model_path
if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location='cpu')
    print(f"Model loaded from: {model_path}")
else:
    print(f"Model not found at: {model_path}")

In [None]:
# Test prediction
test_query = {"query": "What is the capital of France?"}
result = router.route_single(test_query)

print(f"Test query: {test_query['query']}")
print(f"Routed to: {result['model_name']}")

## Summary

In this notebook, we:

1. **Loaded the configuration**: Set up GraphRouter from its YAML configuration file
2. **Reviewed the graph structure**: Examined the query-LLM bipartite graph
3. **Trained the GNN model**: Used message passing to learn node representations
4. **Verified the model**: Loaded the saved checkpoint and routed a sample query

**Key Takeaways**:
- GraphRouter models query-LLM relationships as a graph
- A GNN can capture complex interaction patterns between queries and LLMs
- Edge masking during training is intended to improve generalization to unseen query-LLM pairs

**Next Steps**:
- Use `02_graphrouter_inference.ipynb` for inference
- Experiment with different GNN architectures