# GMTRouter - Training

This notebook demonstrates how to train the **GMTRouter** (Graph-based Multi-Turn Router).

## Overview

GMTRouter uses a Heterogeneous Graph Neural Network (HeteroGNN) to model the relationships among users, sessions,
queries, LLMs, and responses in multi-turn conversations, and uses what it learns for personalized LLM routing.

**Key Features**:
- 5 node types: User, Session, Query, LLM, Response
- 21 edge types linking these nodes (e.g., user-session ownership, session-query containment, query-query turn order)
- Personalized routing based on user preferences
- Multi-turn conversation support

**Requirements**:
- PyTorch 2.0+
- PyTorch Geometric (optional but recommended)
- GMTRouter dataset
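
The two software requirements can be checked before running anything else. The sketch below uses only the standard library (`importlib.metadata`, `importlib.util`), so it still runs when PyTorch is missing. The helper names `major_minor` and `check_requirements` are hypothetical, and the version parse is a simplification that looks only at the major and minor numbers.

```python
import importlib.metadata
import importlib.util


def major_minor(version: str) -> tuple[int, int]:
    """Parse e.g. '2.1.0+cu118' -> (2, 1); non-digit characters are ignored."""
    parts = version.split("+")[0].split(".")
    nums = []
    for part in parts[:2]:
        digits = "".join(ch for ch in part if ch.isdigit())
        nums.append(int(digits) if digits else 0)
    while len(nums) < 2:
        nums.append(0)
    return nums[0], nums[1]


def check_requirements() -> dict:
    """Report the installed torch version and whether torch_geometric is importable."""
    report = {}
    try:
        torch_version = importlib.metadata.version("torch")
        report["torch"] = torch_version
        report["torch_ok"] = major_minor(torch_version) >= (2, 0)
    except importlib.metadata.PackageNotFoundError:
        report["torch"] = None
        report["torch_ok"] = False
    # find_spec returns None when the top-level package is not installed
    report["torch_geometric"] = importlib.util.find_spec("torch_geometric") is not None
    return report


print(check_requirements())
```

`torch_ok` is `False` both when PyTorch is absent and when it is older than 2.0; the GMTRouter dataset itself is checked separately in Section 3.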

## 1. Environment Setup

In [None]:
# Install required packages (for Colab)
# !pip install llmrouter torch torch-geometric

In [None]:
import os
import sys
from pathlib import Path

# Assumes this notebook sits two directories below the repository root
PROJECT_ROOT = Path(os.getcwd()).parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

In [None]:
import torch
from llmrouter.models.gmtrouter import GMTRouter, GMTRouterTrainer
from llmrouter.utils import setup_environment

setup_environment()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 2. Configuration

GMTRouter is configured through a YAML file. The main parameters and their default values are:

| Parameter | Description | Default |
|-----------|-------------|--------|
| `num_gnn_layers` | Number of HGT (Heterogeneous Graph Transformer) layers | 2 |
| `hidden_dim` | Hidden dimension | 128 |
| `dropout` | Dropout rate | 0.1 |
| `epochs` | Training epochs | 350 |
| `lr` | Learning rate | 5e-4 |
| `objective` | Training objective | "auc" |

In [None]:
import yaml

CONFIG_PATH = "configs/model_config_train/gmtrouter.yaml"

with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

print("Current Configuration:")
print("=" * 50)
print(yaml.dump(config, default_flow_style=False))
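
To try a variant of these settings without editing the YAML file on disk, one option is to overlay a few overrides onto the loaded dict. The sketch assumes nothing about GMTRouter internals: `deep_merge` is a hypothetical helper, and the `gmt_config` section mirrors the keys this notebook reads in Section 4. Where `epochs`, `lr`, and `objective` live in the real YAML is not shown here, so they are left out.

```python
import copy


def deep_merge(base: dict, overrides: dict) -> dict:
    """Return a new dict where values in `overrides` replace those in `base`,
    recursing into nested dicts instead of replacing them wholesale."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical example mirroring the 'gmt_config' section used in Section 4.
base = {"gmt_config": {"num_gnn_layers": 2, "hidden_dim": 128, "dropout": 0.1}}
variant = deep_merge(base, {"gmt_config": {"hidden_dim": 256}})
print(variant)  # {'gmt_config': {'num_gnn_layers': 2, 'hidden_dim': 256, 'dropout': 0.1}}
```

Applied to the real config, `deep_merge(config, {...})` returns a new dict that could be written out with `yaml.safe_dump` and passed as `yaml_path` to `GMTRouter`. Recursing into nested dicts keeps the untouched keys (here `num_gnn_layers` and `dropout`) instead of dropping the whole section.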

## 3. Data Preparation

GMTRouter expects its dataset as JSONL files, one per split (training, validation, test). Download the dataset before continuing.

In [None]:
# Check if data exists
data_path = config.get('data_path', {}).get('data_root', './data')
dataset_name = config.get('dataset', {}).get('name', 'mt_bench')

expected_path = os.path.join(data_path, dataset_name)

if os.path.exists(expected_path):
    print(f"Dataset found at: {expected_path}")
    files = os.listdir(expected_path)
    print(f"Files: {files}")
else:
    print(f"Dataset not found at: {expected_path}")
    print("\nPlease download the GMTRouter dataset:")
    print("1. Download from Google Drive")
    print(f"2. Extract to {data_path}/")
    print("3. Expected structure:")
    for split_file in ("training_set.jsonl", "valid_set.jsonl", "test_set.jsonl"):
        print(f"   {os.path.join(expected_path, split_file)}")
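
Once the files are in place, a light sanity check can catch truncated downloads or malformed lines before training. This sketch assumes only what the expected structure above states, namely three JSONL files; it makes no assumption about the record schema and simply reports record counts and the top-level keys it finds. `inspect_jsonl_splits` is a hypothetical helper.

```python
import json
from pathlib import Path

SPLITS = ("training_set.jsonl", "valid_set.jsonl", "test_set.jsonl")


def inspect_jsonl_splits(dataset_dir) -> dict:
    """For each expected split, count records and collect top-level keys.

    A missing split maps to None. Malformed JSON raises ValueError
    naming the file and line number.
    """
    summary = {}
    for name in SPLITS:
        path = Path(dataset_dir) / name
        if not path.exists():
            summary[name] = None
            continue
        count, keys = 0, set()
        with path.open("r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank lines
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"{path}:{lineno}: {exc}") from exc
                count += 1
                if isinstance(record, dict):
                    keys.update(record.keys())
        summary[name] = {"records": count, "keys": sorted(keys)}
    return summary
```

In this notebook it would be called as `inspect_jsonl_splits(expected_path)`; a missing split shows up as `None` rather than raising, so all three files can be reported in one pass.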

## 4. Initialize Router

In [None]:
router = GMTRouter(yaml_path=CONFIG_PATH)

print("Router initialized successfully!")
print(f"Number of GNN layers: {config['gmt_config']['num_gnn_layers']}")
print(f"Hidden dimension: {config['gmt_config']['hidden_dim']}")
print(f"Personalization: {config['gmt_config']['personalization']}")

## 5. Graph Structure

GMTRouter builds a heterogeneous graph with:
- **User nodes**: User embeddings and preferences
- **Session nodes**: Conversation session representations
- **Query nodes**: Query embeddings
- **LLM nodes**: LLM model embeddings
- **Response nodes**: Response quality scores

In [None]:
print("GMTRouter Graph Structure:")
print("=" * 50)
print("\nNode Types:")
print("  1. User     - User preferences and history")
print("  2. Session  - Conversation sessions")
print("  3. Query    - User queries")
print("  4. LLM      - Language models")
print("  5. Response - Model responses")
print("\nEdge Types (21 total):")
print("  - own/owned_by (User-Session)")
print("  - contains/contained_by (Session-Query)")
print("  - answered_by/answered_to (Query-Response)")
print("  - generated_by/generated (LLM-Response)")
print("  - next/prev (Query-Query temporal)")
print("  - ... and more")
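
To make these node and edge types concrete, the sketch below maps one toy session onto them. It is a hypothetical illustration, not GMTRouter's graph builder: it covers only the 10 relations named in the cell above (the rest of the 21 edge types are not listed in this notebook), and the edge directions are assumptions inferred from the relation names. Each forward relation is added together with its reverse, following the paired names above.

```python
from collections import defaultdict

NODE_TYPES = ("user", "session", "query", "llm", "response")

# (source, relation, target) triples for the 10 relations listed above.
# Directions are assumptions inferred from the relation names.
EDGE_TYPES = [
    ("user", "own", "session"), ("session", "owned_by", "user"),
    ("session", "contains", "query"), ("query", "contained_by", "session"),
    ("query", "answered_by", "response"), ("response", "answered_to", "query"),
    ("response", "generated_by", "llm"), ("llm", "generated", "response"),
    ("query", "next", "query"), ("query", "prev", "query"),
]


def build_toy_graph(user, session, turns):
    """Map one session onto typed nodes and edges (hypothetical toy builder).

    turns: list of (query_text, llm_name) pairs, one per turn.
    Returns (node_ids, edges): node_ids maps node type -> {name: index};
    edges maps each (src, relation, dst) triple -> list of (src_idx, dst_idx).
    Stacking and transposing a pair list gives the 2 x E `edge_index`
    layout that PyTorch Geometric's HeteroData uses.
    """
    node_ids = {t: {} for t in NODE_TYPES}
    edges = defaultdict(list)

    def nid(ntype, name):
        # Consecutive integer ids per node type; reuse the id if seen before.
        return node_ids[ntype].setdefault(name, len(node_ids[ntype]))

    def link(src_t, rel, dst_t, rev_rel, s, d):
        # Add the forward relation and its paired reverse relation.
        edges[(src_t, rel, dst_t)].append((s, d))
        edges[(dst_t, rev_rel, src_t)].append((d, s))

    u, s = nid("user", user), nid("session", session)
    link("user", "own", "session", "owned_by", u, s)
    prev_q = None
    for t, (_query_text, llm_name) in enumerate(turns):
        # The query text would feed the query node's embedding (not computed here).
        q = nid("query", f"{session}/q{t}")
        r = nid("response", f"{session}/r{t}")
        m = nid("llm", llm_name)
        link("session", "contains", "query", "contained_by", s, q)
        link("query", "answered_by", "response", "answered_to", q, r)
        link("response", "generated_by", "llm", "generated", r, m)
        if prev_q is not None:
            link("query", "next", "query", "prev", prev_q, q)
        prev_q = q
    return node_ids, dict(edges)


nodes, edges = build_toy_graph(
    "u0", "s0",
    [("What is machine learning?", "llm_a"),
     ("Can you give an example?", "llm_b"),
     ("And a harder one?", "llm_a")],
)
for etype, pairs in edges.items():
    print(etype, pairs)
```

Node ids are assigned per type, so an LLM used in two turns maps to a single `llm` node shared by both responses. In a GNN, shared nodes like this are what let messages pass between turns and, across sessions, between users.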

## 6. Training

In [None]:
trainer = GMTRouterTrainer(router=router, device=device)

print("Trainer initialized!")
print(f"Device: {device}")

In [None]:
print("Starting training...")
print("=" * 50)
print("Note: GMTRouter training uses pairwise learning on the graph.")
print("This may take significant time for large datasets.")
print("=" * 50)

trainer.train()

print("=" * 50)
print("Training completed!")
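
The training note refers to pairwise learning, and the config's `objective` is `"auc"`. As a generic illustration of what those terms usually mean (an assumption; this notebook does not show GMTRouter's exact loss), the sketch compares a score for a preferred response with a score for a non-preferred one, applies the standard logistic pairwise loss $-\log\sigma(s^{+} - s^{-})$, and measures pairwise AUC as the fraction of pairs ranked in the right order.

```python
import math


def pairwise_logistic_loss(score_preferred: float, score_other: float) -> float:
    """-log(sigmoid(s_pos - s_neg)), computed stably as softplus(-(s_pos - s_neg))."""
    x = -(score_preferred - score_other)
    # softplus(x) = max(x, 0) + log1p(exp(-|x|)) avoids overflow in exp
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_auc(pairs) -> float:
    """Fraction of (preferred, other) score pairs ranked correctly; ties count 0.5."""
    if not pairs:
        raise ValueError("need at least one pair")
    correct = 0.0
    for pos, neg in pairs:
        if pos > neg:
            correct += 1.0
        elif pos == neg:
            correct += 0.5
    return correct / len(pairs)


pairs = [(0.9, 0.1), (0.2, 0.8), (0.5, 0.5)]
print(pairwise_auc(pairs))  # 0.5
print(round(pairwise_logistic_loss(0.0, 0.0), 4))  # 0.6931
```

Minimizing the loss pushes the preferred score above the other one, while AUC checks whether that ordering holds on held-out pairs. The softplus form keeps the loss finite even for very large margins.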

## 7. Model Verification

In [None]:
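# Check the saved model
save_path = config.get('model_path', {}).get('save_model_path', './saved_models/gmtrouter/gmtrouter.pt')

if os.path.exists(save_path):
    print(f"Model saved at: {save_path}")
    # PyTorch 2.6+ defaults to weights_only=True; a checkpoint that pickles arbitrary
    # Python objects may need weights_only=False (only for files you trust).
    checkpoint = torch.load(save_path, map_location='cpu')
    if isinstance(checkpoint, dict):
        print(f"Checkpoint keys: {list(checkpoint.keys())}")
    else:
        print(f"Checkpoint type: {type(checkpoint).__name__}")
else:
    print(f"Model not found at: {save_path}")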
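
Which keys appear in the checkpoint depends on how `GMTRouterTrainer` saves it, which this notebook does not show. To look one level deeper, a small recursive summary can list every nested entry with its type and, for tensors, its shape. `summarize_checkpoint` is a hypothetical helper; it detects tensors by the presence of a `.shape` attribute so it does not need to import torch.

```python
def summarize_checkpoint(obj, prefix=""):
    """Return 'dotted.key: type [shape]' lines for a (possibly nested) checkpoint.

    Nested dicts are flattened with dotted names; objects exposing a
    `.shape` attribute (e.g. tensors) also report their shape.
    """
    if not isinstance(obj, dict):
        return [f"{prefix or '<root>'}: {type(obj).__name__}"]
    lines = []
    for key, value in obj.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            lines.extend(summarize_checkpoint(value, name))
        elif hasattr(value, "shape"):
            lines.append(f"{name}: {type(value).__name__} {tuple(value.shape)}")
        else:
            lines.append(f"{name}: {type(value).__name__}")
    return lines


# Made-up demo dict; a real checkpoint's keys will differ.
demo = {"epoch": 3, "config": {"hidden_dim": 128}}
print("\n".join(summarize_checkpoint(demo)))
```

With the checkpoint loaded above, `for line in summarize_checkpoint(checkpoint)[:20]: print(line)` would print its first 20 entries.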

In [None]:
# Test prediction
test_query = {
    "query": "What is machine learning?",
    "user_id": "test_user",
    "session_id": "test_session"
}
result = router.route_single(test_query)

print(f"Test query: {test_query['query']}")
print(f"User: {test_query['user_id']}")
print(f"Routed to: {result['model_name']}")
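
Because GMTRouter targets multi-turn conversations, it can be informative to route several turns of the same session and see which models are chosen. This sketch assumes, as the cell above does, that `route_single` takes a dict with `query`, `user_id`, and `session_id` and returns a dict with `model_name`; whether the router uses earlier turns of the same session as context is not shown in this notebook. `route_session` is a hypothetical helper.

```python
from collections import Counter


def route_session(router, user_id, session_id, queries):
    """Route each turn of one session and tally which model was chosen.

    Assumes router.route_single(dict) -> dict with a 'model_name' key.
    Returns (per-turn choices, Counter of model names).
    """
    choices = []
    for text in queries:
        result = router.route_single(
            {"query": text, "user_id": user_id, "session_id": session_id}
        )
        choices.append(result["model_name"])
    return choices, Counter(choices)
```

For example, `choices, counts = route_session(router, "test_user", "test_session", ["What is machine learning?", "Give a short example."])` would return the model picked for each turn plus a per-model count.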

## Summary

In this notebook, we:

1. **Loaded Configuration**: Set up GMTRouter with YAML configuration
2. **Reviewed the Graph Structure**: 5 node types, 21 edge types
3. **Trained the Model**: Pairwise learning on the heterogeneous graph
4. **Verified the Model**: Checked the saved checkpoint and routed a test query for a specific user

**Key Takeaways**:
- GMTRouter models multi-turn conversations as a graph linking users, sessions, queries, responses, and LLMs
- Personalizing on user preferences is intended to improve routing quality
- The HeteroGNN learns from this relational structure

**Next Steps**:
- Use `02_gmtrouter_inference.ipynb` for inference
- Experiment with different datasets