# DCRouter - Training

This notebook demonstrates how to train the **DCRouter** (Dual Contrastive Router).

## Overview

DCRouter routes queries using an mDeBERTa transformer backbone trained with dual contrastive learning.
For each query, it learns to separate the LLMs that handle it well from those that handle it poorly, using a contrastive loss.
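The notebook does not spell out the loss, so here is a minimal, framework-free sketch of a query-to-LLM contrastive term of the kind described above. It assumes that each LLM has a learnable embedding and that each training query carries a per-LLM quality score. The function names and the exact formulation are illustrative assumptions, not llmrouter's implementation. `top_k`, `last_k` and `temperature` mirror the parameters in the configuration table of Section 2.

```python
import math
from typing import Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


def sample_llm_contrastive_loss(
    query_emb: Sequence[float],
    llm_embs: Sequence[Sequence[float]],
    llm_scores: Sequence[float],
    top_k: int,
    last_k: int,
    temperature: float = 1.0,
) -> float:
    """Pull the query toward its top_k LLMs and away from its last_k LLMs (InfoNCE-style)."""
    if top_k < 1 or last_k < 1 or top_k + last_k > len(llm_scores):
        raise ValueError("need 1 <= top_k, 1 <= last_k and top_k + last_k <= number of LLMs")
    # Rank LLMs by how well each answered this query (higher score = better).
    order = sorted(range(len(llm_scores)), key=lambda i: llm_scores[i], reverse=True)
    positives = order[:top_k]
    negatives = order[-last_k:]
    # Temperature-scaled similarity between the query and every LLM embedding.
    logits = [cosine(query_emb, emb) / temperature for emb in llm_embs]
    neg_mass = sum(math.exp(logits[n]) for n in negatives)
    # Each positive competes against the shared pool of negatives; average over positives.
    losses = [
        -math.log(math.exp(logits[p]) / (math.exp(logits[p]) + neg_mass))
        for p in positives
    ]
    return sum(losses) / len(losses)


# Two hypothetical LLMs: the first scored 0.9 on this query, the second 0.1.
llm_embs = [[1.0, 0.0], [0.0, 1.0]]
scores = [0.9, 0.1]
good = sample_llm_contrastive_loss([1.0, 0.0], llm_embs, scores, top_k=1, last_k=1)
bad = sample_llm_contrastive_loss([0.0, 1.0], llm_embs, scores, top_k=1, last_k=1)
print(f"query near best LLM: {good:.3f}, query near worst LLM: {bad:.3f}")
```

The loss is low when the query embedding sits close to its best-scoring LLMs and far from its worst ones. Minimizing it therefore shapes the embedding space so that similarity reflects suitability.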

**Key Features**:
- Transformer-based (mDeBERTa) backbone
- Dual contrastive loss for better discrimination
- Cluster-based negative sampling
- State-of-the-art routing performance
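Cluster-based negative sampling can be illustrated with a query-to-query contrastive term. This sketch assumes that the training queries have already been grouped into clusters (for example, with k-means on their embeddings). It treats another query from the anchor's cluster as the positive and queries from other clusters as negatives. The names here are hypothetical and this is not the llmrouter code. It only shows the shape of the technique.

```python
import math
import random
from typing import Optional, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def cluster_contrastive_loss(
    embs: Sequence[Sequence[float]],
    cluster_ids: Sequence[int],
    anchor: int,
    num_negatives: int = 2,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> float:
    """Contrast an anchor query with a same-cluster positive and other-cluster negatives."""
    rng = rng or random.Random(0)
    same = [j for j, c in enumerate(cluster_ids) if c == cluster_ids[anchor] and j != anchor]
    other = [j for j, c in enumerate(cluster_ids) if c != cluster_ids[anchor]]
    if not same or not other:
        raise ValueError("need at least one same-cluster and one other-cluster query")
    positive = rng.choice(same)
    # Negatives are drawn only from *other* clusters: this is the cluster-based sampling.
    negatives = rng.sample(other, min(num_negatives, len(other)))

    def logit(j: int) -> float:
        return cosine(embs[anchor], embs[j]) / temperature

    pos_exp = math.exp(logit(positive))
    neg_mass = sum(math.exp(logit(n)) for n in negatives)
    return -math.log(pos_exp / (pos_exp + neg_mass))


embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
coherent = cluster_contrastive_loss(embs, [0, 0, 1, 1], anchor=0)  # clusters match geometry
mixed = cluster_contrastive_loss(embs, [0, 1, 0, 1], anchor=0)     # clusters contradict it
print(f"coherent clusters: {coherent:.3f}, mismatched clusters: {mixed:.3f}")
```

Sampling negatives from other clusters avoids pushing apart queries that are semantically alike, which a purely random draw could do.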

## 1. Environment Setup

In [None]:
# Install required packages (for Colab)
# !pip install llmrouter transformers torch

In [None]:
import os
import sys
from pathlib import Path

# Assumes this notebook lives two directory levels below the repository root
PROJECT_ROOT = Path(os.getcwd()).parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

In [None]:
import torch
from llmrouter.models.dcrouter import DCRouter, DCRouterTrainer
from llmrouter.utils import setup_environment

setup_environment()

# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

DCRouter uses the following configuration parameters:

| Parameter | Description | Default |
|-----------|-------------|--------|
| `hidden_state_dim` | Backbone hidden dimension | 768 |
| `similarity_function` | Similarity metric | "cos" |
| `batch_size` | Training batch size | 32 |
| `training_steps` | Total training steps | 500 |
| `learning_rate` | Learning rate | 5e-5 |
| `top_k` | Number of top-ranked LLMs per query used as positive samples | 3 |
| `last_k` | Number of lowest-ranked LLMs per query used as negative samples | 3 |
| `temperature` | Softmax temperature applied to the similarity scores | 1.0 |
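Under the standard convention, a softmax temperature divides the scores before normalization. Values below 1.0 sharpen the distribution over LLMs and values above 1.0 flatten it. The sketch below shows this effect on hypothetical similarity scores. It assumes DCRouter follows this usual convention, which the table does not state explicitly.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Softmax over logits / temperature, computed in a numerically stable way."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical cosine similarities between one query and three candidate LLMs.
sims = [0.8, 0.5, 0.1]
for t in (0.1, 1.0, 10.0):
    probs = softmax_with_temperature(sims, t)
    print(f"temperature={t:>4}: " + ", ".join(f"{p:.3f}" for p in probs))
```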

In [None]:
import yaml

CONFIG_PATH = "configs/model_config_train/dcrouter.yaml"

with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

print("Current Configuration:")
print("=" * 50)
print(yaml.dump(config, default_flow_style=False))
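Before editing `dcrouter.yaml` for an experiment, it can help to resolve overrides against the defaults and catch obvious mistakes early. This notebook does not show the YAML file's key layout, so this hypothetical helper treats the hyperparameters as a flat dict whose defaults come from the table above. The `top_k + last_k` check reflects our assumption that positive and negative LLMs should not overlap. It is not a documented llmrouter rule.

```python
from typing import Any, Dict, Optional

# Defaults copied from the configuration table; the nesting in the real YAML is not assumed.
DCROUTER_DEFAULTS: Dict[str, Any] = {
    "hidden_state_dim": 768,
    "similarity_function": "cos",
    "batch_size": 32,
    "training_steps": 500,
    "learning_rate": 5e-5,
    "top_k": 3,
    "last_k": 3,
    "temperature": 1.0,
}


def resolve_hparams(
    overrides: Optional[Dict[str, Any]] = None, num_llms: Optional[int] = None
) -> Dict[str, Any]:
    """Fill missing hyperparameters from the defaults and sanity-check the result."""
    overrides = overrides or {}
    unknown = set(overrides) - set(DCROUTER_DEFAULTS)
    if unknown:
        raise KeyError(f"unknown hyperparameters: {sorted(unknown)}")
    hparams = {**DCROUTER_DEFAULTS, **overrides}
    if hparams["temperature"] <= 0:
        raise ValueError("temperature must be positive")
    # Assumption: positives and negatives should be disjoint, so together they fit in the pool.
    if num_llms is not None and hparams["top_k"] + hparams["last_k"] > num_llms:
        raise ValueError("top_k + last_k exceeds the number of LLM candidates")
    return hparams


hp = resolve_hparams({"temperature": 0.5}, num_llms=8)
print(hp["temperature"], hp["top_k"], hp["last_k"])
```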

## 3. Initialize Router

In [None]:
# Initialize DCRouter with configuration
router = DCRouter(yaml_path=CONFIG_PATH)

print("Router initialized successfully!")
print(f"Number of training samples: {len(router.routing_data_train)}")
print(f"Number of LLM candidates: {len(router.llm_data)}")
print(f"LLM candidates: {list(router.llm_data.keys())}")
print(f"Backbone model: {config['model_path'].get('backbone_model', 'microsoft/mdeberta-v3-base')}")

## 4. Training

In [None]:
# Initialize trainer
trainer = DCRouterTrainer(router=router, device=device)

print("Trainer initialized!")
print(f"Device: {device}")
print(f"Save path: {trainer.save_model_path}")

In [None]:
# Train the model
print("Starting training...")
print("=" * 50)
print("Note: DCRouter training uses dual contrastive learning.")
print("This may take some time depending on your hardware.")
print("=" * 50)

trainer.train()

print("=" * 50)
print("Training completed!")

## 5. Model Verification

In [None]:
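# Verify that the trained checkpoint was written and can be loaded
import os
import torch

model_path = trainer.save_model_path
if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location="cpu")
    print(f"Model loaded from: {model_path}")
    if isinstance(checkpoint, dict):
        # A state_dict can hold hundreds of entries, so show only the first few keys
        print(f"Checkpoint has {len(checkpoint)} entries, e.g. {list(checkpoint)[:5]}")
    else:
        print(f"Checkpoint type: {type(checkpoint).__name__}")
else:
    print(f"Model not found at: {model_path}")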

In [None]:
# Test prediction
test_query = {"query": "What is the capital of France?"}
result = router.route_single(test_query)

print(f"Test query: {test_query['query']}")
print(f"Routed to: {result['model_name']}")
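`route_single` hides how a query is mapped to an LLM. A common design for embedding-based routers compares the query embedding with one learned embedding per LLM and picks the highest-scoring candidate. We assume this design here; the notebook does not confirm it for llmrouter. The sketch below implements that idea in plain Python. The LLM names and vectors are made up, and a real router would produce them with its trained encoder.

```python
import math
from typing import Dict, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def route_by_similarity(
    query_emb: Sequence[float],
    llm_embs: Dict[str, Sequence[float]],
    temperature: float = 1.0,
) -> Tuple[str, Dict[str, float]]:
    """Return the best-matching LLM name and a softmax distribution over all candidates."""
    names = list(llm_embs)
    logits = [cosine(query_emb, llm_embs[n]) / temperature for n in names]
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = {n: e / total for n, e in zip(names, exps)}
    best = max(probs, key=probs.get)
    return best, probs


# Hypothetical 3-d embeddings for three candidate LLMs.
llm_embs = {"llm-a": [1.0, 0.0, 0.0], "llm-b": [0.0, 1.0, 0.0], "llm-c": [0.6, 0.8, 0.0]}
choice, probs = route_by_similarity([0.1, 0.9, 0.0], llm_embs)
print(choice, {k: round(v, 3) for k, v in probs.items()})
```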

## 6. Training Curve Analysis

In [None]:
import matplotlib.pyplot as plt

# Plot the training loss if the trainer recorded it
loss_history = getattr(trainer, "loss_history", None)
if loss_history:
    plt.figure(figsize=(8, 4))
    plt.plot(loss_history)
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("Training history not available for plotting.")
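Per-step contrastive losses tend to be noisy because each batch draws different positives and negatives, so a smoothed curve is often easier to read. The helper below computes an exponential moving average. This is a generic technique and not part of llmrouter. `ema_smooth` and its `alpha` parameter are our own names.

```python
from typing import List, Sequence


def ema_smooth(values: Sequence[float], alpha: float = 0.1) -> List[float]:
    """Exponential moving average: smaller alpha gives a smoother curve."""
    if not 0 < alpha <= 1:
        raise ValueError("alpha must be in (0, 1]")
    smoothed: List[float] = []
    for v in values:
        # The first point seeds the average; later points blend in with weight alpha.
        smoothed.append(v if not smoothed else alpha * v + (1 - alpha) * smoothed[-1])
    return smoothed


smoothed = ema_smooth([2.0, 1.0, 3.0, 0.5], alpha=0.5)
print(smoothed)
```

If `trainer.loss_history` is populated, you can plot `ema_smooth(trainer.loss_history)` on the same axes as the raw curve to see the trend.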

## Summary

In this notebook, we:

1. **Loaded Configuration**: Set up DCRouter with YAML configuration
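2. **Initialized Router**: Created DCRouter with an mDeBERTa backbone
3. **Trained Model**: Used dual contrastive learning
4. **Verified Model**: Loaded the saved checkpoint and routed a sample query
5. **Inspected Training**: Plotted the loss curve when the trainer records one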

**Key Takeaways**:
- DCRouter uses transformer-based embeddings
- Contrastive learning helps distinguish good/bad LLM matches
- A GPU is recommended: training the transformer backbone on CPU is much slower

**Next Steps**:
- Use `02_dcrouter_inference.ipynb` for inference
- Experiment with different temperature values