# LLMRouter - Data Preparation

This notebook covers the data preparation pipeline for LLMRouter:
1. **Dataset Download**: Download and sample benchmark datasets from HuggingFace
2. **Query Data Generation**: Convert the samples to a standard format and write train/test JSONL files
3. **LLM Candidates Configuration**: Define the candidate LLMs and save them as JSON
4. **Query Embeddings Generation**: Generate Longformer embeddings for all queries
5. **API Calling (Optional)**: Call the candidate LLM APIs and collect their responses
6. **Data Verification**: Load existing routing data and check that all required files exist

## 1. Environment Setup

In [None]:
# Install required packages (for Colab)
# !pip install llmrouter datasets transformers torch pandas numpy tqdm litellm peft

In [None]:
import os
import sys
import json
import random
from pathlib import Path

# Set project root (assumes the notebook runs from a folder two levels below the repository root)
PROJECT_ROOT = Path(os.getcwd()).parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Import LLMRouter utilities
from llmrouter.utils import setup_environment
from llmrouter.data.data_loader import DataLoader

setup_environment()

## 2. Configuration

In [None]:
# Configuration for data generation
CONFIG = {
    # Number of samples per task
    "sample_size": 100,  # Use a small number for testing; increase for full training
    
    # Train/test split ratio
    "train_ratio": 0.8,
    
    # Random seed for reproducibility
    "random_seed": 42,
    
    # Output paths (relative to project root)
    "output_paths": {
        "query_data_train": "data/example_data/query_data/default_query_train.jsonl",
        "query_data_test": "data/example_data/query_data/default_query_test.jsonl",
        "query_embedding_data": "data/example_data/routing_data/query_embeddings_longformer.pt",
        "routing_data_train": "data/example_data/routing_data/default_routing_train_data.jsonl",
        "routing_data_test": "data/example_data/routing_data/default_routing_test_data.jsonl",
        "llm_data": "data/example_data/llm_candidates/default_llm.json",
        "llm_embedding_data": "data/example_data/llm_candidates/default_llm_embeddings.json"
    },
    
    # API settings (for LLM calling)
    "max_workers": 10,  # Number of parallel API calls
}

print("Configuration loaded!")
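
The paths in `CONFIG["output_paths"]` are relative to the project root, so each one has to be joined with `PROJECT_ROOT` before use. A minimal sketch of doing that in one place is shown below; `resolve_output_paths` is a hypothetical helper (not part of LLMRouter), and the demo uses a temporary directory as a stand-in for the project root.

```python
import tempfile
from pathlib import Path


def resolve_output_paths(output_paths, project_root):
    """Join each relative output path with the root and create its parent folder."""
    resolved = {}
    for key, relative_path in output_paths.items():
        full_path = Path(project_root) / relative_path
        # Create the target folder up front so later writes do not fail
        full_path.parent.mkdir(parents=True, exist_ok=True)
        resolved[key] = full_path
    return resolved


# Demo: a throwaway directory stands in for PROJECT_ROOT
demo_root = Path(tempfile.mkdtemp())
demo_paths = resolve_output_paths(
    {"llm_data": "data/example_data/llm_candidates/default_llm.json"},
    demo_root,
)
print(demo_paths["llm_data"])
```

In the notebook itself this would be called as `resolve_output_paths(CONFIG["output_paths"], PROJECT_ROOT)`.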

## 3. Dataset Download

LLMRouter uses 11 benchmark datasets covering different task categories:
- **Math**: GSM8K, MATH
- **Code**: MBPP, HumanEval
- **World Knowledge**: Natural QA, Trivia QA
- **Popular Benchmarks**: MMLU, GPQA
- **Commonsense Reasoning**: CommonsenseQA, OpenbookQA, ARC-Challenge

The download cell below fetches eight of these; MATH, MBPP, and HumanEval are not downloaded in this notebook.

In [None]:
from datasets import load_dataset

# Set cache directory (optional)
CACHE_DIR = str(PROJECT_ROOT / "data" / "cache")
os.makedirs(CACHE_DIR, exist_ok=True)

def download_datasets(sample_size=100, random_seed=42):
    """Download and sample from benchmark datasets."""
    random.seed(random_seed)
    samples = {}
    
    # 1. Natural QA
    print("Downloading Natural QA...")
    try:
        natural_qa = load_dataset('RUC-NLPIR/FlashRAG_datasets', 'nq', cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in natural_qa else list(natural_qa.keys())[0]
        indices = random.sample(range(len(natural_qa[split_name])), min(sample_size, len(natural_qa[split_name])))
        samples['natural_qa'] = [natural_qa[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['natural_qa'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['natural_qa'] = []
    
    # 2. Trivia QA
    print("Downloading Trivia QA...")
    try:
        trivia_qa = load_dataset("trivia_qa", "rc.nocontext", cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in trivia_qa else list(trivia_qa.keys())[0]
        indices = random.sample(range(len(trivia_qa[split_name])), min(sample_size, len(trivia_qa[split_name])))
        samples['trivia_qa'] = [trivia_qa[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['trivia_qa'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['trivia_qa'] = []
    
    # 3. MMLU
    print("Downloading MMLU...")
    try:
        mmlu = load_dataset("cais/mmlu", "all", cache_dir=CACHE_DIR)
        split_name = 'auxiliary_train' if 'auxiliary_train' in mmlu else list(mmlu.keys())[0]
        indices = random.sample(range(len(mmlu[split_name])), min(sample_size, len(mmlu[split_name])))
        samples['mmlu'] = [mmlu[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['mmlu'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['mmlu'] = []
    
    # 4. GPQA
    print("Downloading GPQA...")
    try:
        gpqa = load_dataset("Idavidrein/gpqa", "gpqa_main", cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in gpqa else list(gpqa.keys())[0]
        indices = random.sample(range(len(gpqa[split_name])), min(sample_size, len(gpqa[split_name])))
        samples['gpqa'] = [gpqa[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['gpqa'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['gpqa'] = []
    
    # 5. GSM8K
    print("Downloading GSM8K...")
    try:
        gsm8k = load_dataset('gsm8k', 'main', cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in gsm8k else list(gsm8k.keys())[0]
        indices = random.sample(range(len(gsm8k[split_name])), min(sample_size, len(gsm8k[split_name])))
        samples['gsm8k'] = [gsm8k[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['gsm8k'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['gsm8k'] = []
    
    # 6. CommonsenseQA
    print("Downloading CommonsenseQA...")
    try:
        commonsense_qa = load_dataset('commonsense_qa', cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in commonsense_qa else list(commonsense_qa.keys())[0]
        indices = random.sample(range(len(commonsense_qa[split_name])), min(sample_size, len(commonsense_qa[split_name])))
        samples['commonsense_qa'] = [commonsense_qa[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['commonsense_qa'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['commonsense_qa'] = []
    
    # 7. ARC-Challenge
    print("Downloading ARC-Challenge...")
    try:
        arc = load_dataset('allenai/ai2_arc', 'ARC-Challenge', cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in arc else list(arc.keys())[0]
        indices = random.sample(range(len(arc[split_name])), min(sample_size, len(arc[split_name])))
        samples['arc_challenge'] = [arc[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['arc_challenge'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['arc_challenge'] = []
    
    # 8. OpenbookQA
    print("Downloading OpenbookQA...")
    try:
        openbook = load_dataset('allenai/openbookqa', 'main', cache_dir=CACHE_DIR)
        split_name = 'train' if 'train' in openbook else list(openbook.keys())[0]
        indices = random.sample(range(len(openbook[split_name])), min(sample_size, len(openbook[split_name])))
        samples['openbook_qa'] = [openbook[split_name][i] for i in indices]
        print(f"  Extracted {len(samples['openbook_qa'])} samples")
    except Exception as e:
        print(f"  Error: {e}")
        samples['openbook_qa'] = []
    
    return samples

# Download datasets
raw_samples = download_datasets(sample_size=CONFIG['sample_size'], random_seed=CONFIG['random_seed'])
print(f"\nTotal tasks downloaded: {len([k for k, v in raw_samples.items() if v])}")
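
Every dataset in the cell above follows the same pattern: pick a preferred split (falling back to the first available one), draw up to `sample_size` random indices, and collect the rows. A compact sketch of that pattern follows. `sample_split` is a hypothetical helper, and the toy dictionary of lists stands in for the object returned by `load_dataset`, which is indexed the same way.

```python
import random


def sample_split(dataset, preferred_split, sample_size, rng):
    """Pick a split and draw up to sample_size rows without replacement."""
    split_name = preferred_split if preferred_split in dataset else list(dataset.keys())[0]
    rows = dataset[split_name]
    indices = rng.sample(range(len(rows)), min(sample_size, len(rows)))
    return [rows[i] for i in indices]


# Toy stand-in for a downloaded dataset that has no 'train' split
toy_dataset = {"validation": [{"question": f"q{i}"} for i in range(5)]}

picked = sample_split(toy_dataset, "train", 3, random.Random(42))
print(len(picked))
```

Passing a dedicated `random.Random` instance instead of seeding the global generator is a design choice in this sketch: it keeps the sampling reproducible even if other code draws random numbers in between.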

## 4. Query Data Generation

Convert raw samples into a standardized query format.

In [None]:
def process_samples_to_query_data(samples):
    """Convert raw samples to standardized query data format."""
    data_all = []
    
    # Process Natural QA
    for sample in samples.get('natural_qa', []):
        data_all.append({
            'task_name': 'natural_qa',
            'query': sample['question'],
            'ground_truth': sample['golden_answers'][0] if sample.get('golden_answers') else sample.get('answer', ''),
            'metric': 'f1_score',
            'choices': None,
            'task_id': None
        })
    
    # Process Trivia QA
    for sample in samples.get('trivia_qa', []):
        data_all.append({
            'task_name': 'trivia_qa',
            'query': sample['question'],
            'ground_truth': sample['answer']['normalized_aliases'][0] if sample.get('answer') else '',
            'metric': 'f1_score',
            'choices': None,
            'task_id': None
        })
    
    # Process MMLU
    for sample in samples.get('mmlu', []):
        data_all.append({
            'task_name': 'mmlu',
            'query': sample['question'],
            'ground_truth': chr(65 + sample['answer']),  # Convert 0,1,2,3 to A,B,C,D
            'metric': 'em_mc',
            'choices': {'text': sample['choices'], 'labels': ['A', 'B', 'C', 'D']},
            'task_id': None
        })
    
    # Process GPQA
    for sample in samples.get('gpqa', []):
        options = [
            sample['Correct Answer'],
            sample['Incorrect Answer 1'],
            sample['Incorrect Answer 2'],
            sample['Incorrect Answer 3']
        ]
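        # Shuffle options; mapping[p] is the original index placed at position p
        mapping = list(range(4))
        random.shuffle(mapping)
        shuffled_options = [options[i] for i in mapping]
        # The correct answer is original index 0, so find where it landed
        correct_idx = mapping.index(0)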
        
        data_all.append({
            'task_name': 'gpqa',
            'query': sample['Question'],
            'ground_truth': chr(65 + correct_idx),
            'metric': 'em_mc',
            'choices': {'text': shuffled_options, 'labels': ['A', 'B', 'C', 'D']},
            'task_id': None
        })
    
    # Process GSM8K
    for sample in samples.get('gsm8k', []):
        data_all.append({
            'task_name': 'gsm8k',
            'query': sample['question'],
            'ground_truth': sample['answer'],
            'metric': 'GSM8K',
            'choices': None,
            'task_id': None
        })
    
    # Process CommonsenseQA
    for sample in samples.get('commonsense_qa', []):
        data_all.append({
            'task_name': 'commonsense_qa',
            'query': sample['question'],
            'ground_truth': sample['answerKey'],
            'metric': 'em_mc',
            'choices': sample['choices'],
            'task_id': None
        })
    
    # Process ARC-Challenge
    for sample in samples.get('arc_challenge', []):
        data_all.append({
            'task_name': 'arc_challenge',
            'query': sample['question'],
            'ground_truth': sample['answerKey'],
            'metric': 'em_mc',
            'choices': sample['choices'],
            'task_id': None
        })
    
    # Process OpenbookQA
    for sample in samples.get('openbook_qa', []):
        data_all.append({
            'task_name': 'openbook_qa',
            'query': sample['question_stem'],
            'ground_truth': sample['answerKey'],
            'metric': 'em_mc',
            'choices': sample['choices'],
            'task_id': None
        })
    
    return data_all

# Process samples
query_data = process_samples_to_query_data(raw_samples)
print(f"Total processed samples: {len(query_data)}")

# Show sample counts by task
from collections import Counter
task_counts = Counter(item['task_name'] for item in query_data)
for task, count in sorted(task_counts.items()):
    print(f"  {task}: {count}")
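
When answer options are shuffled, as for GPQA, the ground-truth letter has to point at the position where the correct answer ends up. A minimal, self-contained sketch of that bookkeeping is given below; `shuffle_choices` is a hypothetical helper, not part of LLMRouter.

```python
import random


def shuffle_choices(correct, incorrect, rng):
    """Shuffle answer options and return them with the label of the correct one."""
    options = [correct] + list(incorrect)
    order = list(range(len(options)))
    rng.shuffle(order)  # order[p] is the original index shown at position p
    shuffled = [options[i] for i in order]
    # Original index 0 is the correct answer; its new position gives the label
    label = chr(65 + order.index(0))
    return shuffled, label


demo_options, demo_label = shuffle_choices(
    "Paris", ["Rome", "Oslo", "Bern"], random.Random(42)
)
# The label must always index the correct answer in the shuffled list
assert demo_options[ord(demo_label) - 65] == "Paris"
```

The closing assertion is a cheap invariant worth keeping near any shuffle like this, since an inverted permutation produces plausible-looking but wrong labels.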

In [None]:
# Split into train/test
random.seed(CONFIG['random_seed'])
random.shuffle(query_data)

train_size = int(len(query_data) * CONFIG['train_ratio'])
train_data = query_data[:train_size]
test_data = query_data[train_size:]

print(f"Train samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
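
The split above shuffles all tasks together, so with small sample sizes a task can end up over- or under-represented in the test set. As an alternative (not what the cell above does), a per-task stratified split could be sketched as follows. `stratified_split` is a hypothetical helper and only relies on the `task_name` field.

```python
import random
from collections import defaultdict


def stratified_split(items, train_ratio, seed):
    """Split items into train/test while keeping the ratio within each task."""
    rng = random.Random(seed)
    by_task = defaultdict(list)
    for item in items:
        by_task[item["task_name"]].append(item)

    train, test = [], []
    for task in sorted(by_task):  # sorted for a deterministic order
        group = by_task[task][:]
        rng.shuffle(group)
        cut = int(len(group) * train_ratio)
        train.extend(group[:cut])
        test.extend(group[cut:])
    return train, test


toy_items = [
    {"task_name": task, "query": f"{task}-{i}"}
    for task in ("gsm8k", "mmlu")
    for i in range(10)
]
toy_train, toy_test = stratified_split(toy_items, 0.8, 42)
print(len(toy_train), len(toy_test))  # prints: 16 4
```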

In [None]:
# Save query data to JSONL files
def save_query_data_jsonl(data_list, output_path):
    """Save query data to JSONL file."""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    with open(output_path, 'w', encoding='utf-8') as f:
        for item in data_list:
            record = {
                'task_name': item['task_name'],
                'query': item['query'],
                'ground_truth': item['ground_truth'],
                'metric': item['metric'],
                'choices': json.dumps(item['choices']) if item['choices'] is not None else None,
                'task_id': item['task_id']
            }
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
    
    print(f"Saved {len(data_list)} records to {output_path}")

# Save files
train_output_path = str(PROJECT_ROOT / CONFIG['output_paths']['query_data_train'])
test_output_path = str(PROJECT_ROOT / CONFIG['output_paths']['query_data_test'])

save_query_data_jsonl(train_data, train_output_path)
save_query_data_jsonl(test_data, test_output_path)
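
Note that `save_query_data_jsonl` stores `choices` as a JSON string inside each JSON record, so a reader has to decode that field a second time. A minimal sketch of reading the files back is shown below; `load_query_data_jsonl` is a hypothetical helper, and the demo writes a single record to a temporary file in the same layout.

```python
import json
import os
import tempfile


def load_query_data_jsonl(path):
    """Read query records and decode the JSON-encoded 'choices' field."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if record.get("choices") is not None:
                record["choices"] = json.loads(record["choices"])
            records.append(record)
    return records


# Round-trip demo with one record in the saved layout
demo_record = {
    "task_name": "mmlu",
    "query": "2 + 2 = ?",
    "ground_truth": "B",
    "metric": "em_mc",
    "choices": json.dumps({"text": ["3", "4", "5", "6"], "labels": ["A", "B", "C", "D"]}),
    "task_id": None,
}
demo_file = os.path.join(tempfile.mkdtemp(), "demo_query.jsonl")
with open(demo_file, "w", encoding="utf-8") as f:
    f.write(json.dumps(demo_record, ensure_ascii=False) + "\n")

loaded = load_query_data_jsonl(demo_file)
print(loaded[0]["choices"]["text"][1])
```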

## 5. LLM Candidates Configuration

Define the LLM candidates for routing.

In [None]:
# Example LLM candidates configuration
# You can modify this based on your available LLMs

LLM_CANDIDATES = {
    "gpt-4o-mini": {
        "size": "small",
        "feature": "Fast and cost-effective model for simple tasks",
        "input_price": 0.15,
        "output_price": 0.60,
        "model": "gpt-4o-mini",
        "service": "OpenAI",
        "api_endpoint": "https://api.openai.com/v1"
    },
    "gpt-4o": {
        "size": "large",
        "feature": "Most capable GPT-4 model for complex tasks",
        "input_price": 2.50,
        "output_price": 10.00,
        "model": "gpt-4o",
        "service": "OpenAI",
        "api_endpoint": "https://api.openai.com/v1"
    },
    "claude-3-haiku": {
        "size": "small",
        "feature": "Fast and efficient Claude model",
        "input_price": 0.25,
        "output_price": 1.25,
        "model": "claude-3-haiku-20240307",
        "service": "Anthropic",
        "api_endpoint": "https://api.anthropic.com/v1"
    },
    "claude-3-5-sonnet": {
        "size": "large",
        "feature": "Balanced Claude model for most tasks",
        "input_price": 3.00,
        "output_price": 15.00,
        "model": "claude-3-5-sonnet-20241022",
        "service": "Anthropic",
        "api_endpoint": "https://api.anthropic.com/v1"
    }
}

# Save LLM candidates
llm_data_path = str(PROJECT_ROOT / CONFIG['output_paths']['llm_data'])
os.makedirs(os.path.dirname(llm_data_path), exist_ok=True)

with open(llm_data_path, 'w', encoding='utf-8') as f:
    json.dump(LLM_CANDIDATES, f, indent=2, ensure_ascii=False)

print(f"Saved {len(LLM_CANDIDATES)} LLM candidates to {llm_data_path}")
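
The `input_price` and `output_price` fields can be combined with token counts to estimate what a call costs. The sketch below assumes the prices are in USD per one million tokens; the configuration above does not state the unit, so treat that as an assumption and adjust `tokens_per_unit` if your prices use a different basis. `estimate_cost` is a hypothetical helper.

```python
def estimate_cost(candidate, prompt_tokens, completion_tokens, tokens_per_unit=1_000_000):
    """Estimate the cost of one call (assumes prices are per tokens_per_unit tokens)."""
    input_cost = candidate["input_price"] * prompt_tokens / tokens_per_unit
    output_cost = candidate["output_price"] * completion_tokens / tokens_per_unit
    return input_cost + output_cost


# Prices copied from the gpt-4o-mini entry above
demo_candidate = {"input_price": 0.15, "output_price": 0.60}
demo_cost = estimate_cost(demo_candidate, prompt_tokens=2_000, completion_tokens=500)
print(f"{demo_cost:.6f}")
```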

## 6. Generate Query Embeddings

Generate Longformer embeddings for all queries.

In [None]:
import torch
from tqdm import tqdm

# Import embedding function
from llmrouter.utils import get_longformer_embedding

def generate_query_embeddings(query_data_list):
    """Generate Longformer embeddings for all queries."""
    embeddings = []
    
    for item in tqdm(query_data_list, desc="Generating embeddings"):
        query = item['query']
        embedding = get_longformer_embedding(query)
        embeddings.append(embedding)
    
    return torch.stack(embeddings)

# Combine train and test data for unified embeddings
all_data = train_data + test_data

print(f"Generating embeddings for {len(all_data)} queries...")
all_embeddings = generate_query_embeddings(all_data)

print(f"Embeddings shape: {all_embeddings.shape}")

In [None]:
# Save embeddings
embedding_path = str(PROJECT_ROOT / CONFIG['output_paths']['query_embedding_data'])
os.makedirs(os.path.dirname(embedding_path), exist_ok=True)

torch.save(all_embeddings, embedding_path)
print(f"Saved embeddings to {embedding_path}")
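
Because the embeddings are generated for `train_data + test_data` in that order, row `i` of the saved tensor belongs to training item `i` when `i < len(train_data)`, and test item `j` sits at row `len(train_data) + j`. A small sketch of that index mapping follows; `embedding_row` is a hypothetical helper.

```python
def embedding_row(split, position, n_train):
    """Map (split, position) to a row index in the stacked embedding tensor."""
    if position < 0:
        raise IndexError("position must be non-negative")
    if split == "train":
        if position >= n_train:
            raise IndexError("train position out of range")
        return position
    if split == "test":
        # Test rows follow directly after the n_train training rows
        return n_train + position
    raise ValueError(f"unknown split: {split!r}")


print(embedding_row("train", 3, n_train=80))
print(embedding_row("test", 0, n_train=80))
```

In the notebook, `all_embeddings[embedding_row("test", j, len(train_data))]` would then give the embedding of `test_data[j]`, provided the lists have not been reshuffled since the embeddings were generated.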

## 7. API Calling and Evaluation (Optional)

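This step calls each candidate LLM on the queries and collects the responses and token counts. The demo cell below is limited to the first 10 training queries and keeps the results in memory; it does not score the responses or write the routing data files.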

**Note**: This step requires API keys to be set in the environment.

In [None]:
# Set API keys (replace with your actual keys)
# os.environ['OPENAI_API_KEY'] = 'your-openai-key'
# os.environ['ANTHROPIC_API_KEY'] = 'your-anthropic-key'
# os.environ['API_KEYS'] = json.dumps(['key1', 'key2'])  # For multiple keys

# Check if API keys are set
api_keys_available = bool(os.environ.get('API_KEYS') or os.environ.get('OPENAI_API_KEY'))
print(f"API keys available: {api_keys_available}")
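
The check above only looks at `API_KEYS` and `OPENAI_API_KEY`, while the candidate list also contains Anthropic models. A per-service check could be sketched as below. The mapping from service name to environment variable is an assumption based on the variable names in the commented lines above; how `call_api` actually reads keys is not shown in this notebook. `missing_api_keys` is a hypothetical helper.

```python
import os

# Assumed mapping from the 'service' field to an environment variable name
SERVICE_ENV_VARS = {
    "OpenAI": "OPENAI_API_KEY",
    "Anthropic": "ANTHROPIC_API_KEY",
}


def missing_api_keys(candidates, env=None):
    """Return {model_name: env_var} for candidates whose key is not set."""
    env = os.environ if env is None else env
    missing = {}
    for name, cfg in candidates.items():
        var = SERVICE_ENV_VARS.get(cfg["service"])
        # Unknown services map to None and are reported as missing
        if var is None or not env.get(var):
            missing[name] = var
    return missing


demo_candidates = {
    "gpt-4o-mini": {"service": "OpenAI"},
    "claude-3-haiku": {"service": "Anthropic"},
}
demo_env = {"OPENAI_API_KEY": "sk-demo"}  # fake key, for illustration only
demo_missing = missing_api_keys(demo_candidates, demo_env)
print(demo_missing)  # prints: {'claude-3-haiku': 'ANTHROPIC_API_KEY'}
```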

In [None]:
# This cell only calls the APIs when API keys are configured.
# Without keys it is skipped, and you can use the existing routing data from the repository instead.

if api_keys_available:
    from llmrouter.utils import call_api, generate_task_query
    import pandas as pd
    from concurrent.futures import ThreadPoolExecutor, as_completed
    
    def call_llm_for_query(args):
        """Call a single LLM for a query."""
        query_item, model_name, llm_config = args
        
        try:
            # Format query based on task
            formatted_query = generate_task_query(query_item['task_name'], query_item)
            
            # Prepare API request
            request = {
                'api_endpoint': llm_config['api_endpoint'],
                'query': formatted_query,
                'model_name': model_name,
                'api_name': llm_config['model']
            }
            
            # Call API
            result = call_api(request, max_tokens=512, temperature=0.7)
            
            return {
                **query_item,
                'model_name': model_name,
                'response': result.get('response', ''),
                'prompt_tokens': result.get('prompt_tokens', 0),
                'completion_tokens': result.get('completion_tokens', 0),
                'success': 'error' not in result
            }
        except Exception as e:
            return {
                **query_item,
                'model_name': model_name,
                'response': f'ERROR: {str(e)}',
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'success': False
            }
    
    # Create tasks for all query-model combinations
    tasks = []
    for item in train_data[:10]:  # Limit for demo
        for model_name, config in LLM_CANDIDATES.items():
            tasks.append((item, model_name, config))
    
    print(f"Processing {len(tasks)} query-model combinations...")
    
    # Execute in parallel
    results = []
    with ThreadPoolExecutor(max_workers=CONFIG['max_workers']) as executor:
        futures = {executor.submit(call_llm_for_query, task): task for task in tasks}
        for future in tqdm(as_completed(futures), total=len(futures)):
            results.append(future.result())
    
    print(f"Completed {len(results)} API calls")
else:
    print("Skipping API calling - no API keys configured")
    print("You can use the existing routing data from: data/example_data/routing_data/")
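
Once the calls have finished, a quick per-model summary helps spot failing endpoints before any further processing. The sketch below aggregates records shaped like the dictionaries returned by `call_llm_for_query` (`model_name`, `success`, `prompt_tokens`, `completion_tokens`); `summarize_results` is a hypothetical helper and the demo records are made up for illustration.

```python
from collections import defaultdict


def summarize_results(results):
    """Count calls, successes, and tokens per model."""
    summary = defaultdict(
        lambda: {"calls": 0, "successes": 0, "prompt_tokens": 0, "completion_tokens": 0}
    )
    for record in results:
        stats = summary[record["model_name"]]
        stats["calls"] += 1
        stats["successes"] += int(bool(record["success"]))
        stats["prompt_tokens"] += record["prompt_tokens"]
        stats["completion_tokens"] += record["completion_tokens"]
    return dict(summary)


# Made-up records in the same shape as the API results
demo_results = [
    {"model_name": "gpt-4o-mini", "success": True, "prompt_tokens": 120, "completion_tokens": 40},
    {"model_name": "gpt-4o-mini", "success": False, "prompt_tokens": 0, "completion_tokens": 0},
    {"model_name": "gpt-4o", "success": True, "prompt_tokens": 118, "completion_tokens": 55},
]
demo_summary = summarize_results(demo_results)
for model, stats in sorted(demo_summary.items()):
    print(model, stats)
```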

## 8. Using Existing Data

If you skipped the API calling step, you can use the existing example data.

In [None]:
# Load existing routing data (if available)
from llmrouter.utils import load_jsonl

routing_train_path = str(PROJECT_ROOT / CONFIG['output_paths']['routing_data_train'])
routing_test_path = str(PROJECT_ROOT / CONFIG['output_paths']['routing_data_test'])

if os.path.exists(routing_train_path):
    routing_train = load_jsonl(routing_train_path)
    print(f"Loaded {len(routing_train)} training routing samples")
else:
    print(f"Routing train data not found at {routing_train_path}")

if os.path.exists(routing_test_path):
    routing_test = load_jsonl(routing_test_path)
    print(f"Loaded {len(routing_test)} test routing samples")
else:
    print(f"Routing test data not found at {routing_test_path}")

## 9. Data Verification

Verify that all required data files are available for training.

In [None]:
# Verify all data files
required_files = [
    ('Query Train Data', CONFIG['output_paths']['query_data_train']),
    ('Query Test Data', CONFIG['output_paths']['query_data_test']),
    ('Query Embeddings', CONFIG['output_paths']['query_embedding_data']),
    ('Routing Train Data', CONFIG['output_paths']['routing_data_train']),
    ('Routing Test Data', CONFIG['output_paths']['routing_data_test']),
    ('LLM Data', CONFIG['output_paths']['llm_data']),
]

print("Data file verification:")
print("=" * 60)

all_available = True
for name, path in required_files:
    full_path = str(PROJECT_ROOT / path)
    exists = os.path.exists(full_path)
    status = "OK" if exists else "MISSING"
    if not exists:
        all_available = False
    print(f"{name:25} [{status}]")

print("=" * 60)
if all_available:
    print("All data files are ready for training!")
else:
    print("Some files are missing. Please generate them or use example data.")
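
The check above only tests whether each path exists. A slightly stricter variant, sketched below, also treats empty files as missing and returns the offending paths so they can be reported. This goes beyond what the cell above does; `find_missing_files` is a hypothetical helper, and the demo uses a temporary directory instead of the project root.

```python
import tempfile
from pathlib import Path


def find_missing_files(project_root, required_files):
    """Return (name, path) pairs for files that are absent or empty."""
    missing = []
    for name, relative_path in required_files:
        path = Path(project_root) / relative_path
        if not path.is_file() or path.stat().st_size == 0:
            missing.append((name, str(path)))
    return missing


# Demo: one file present with content, one file absent
demo_dir = Path(tempfile.mkdtemp())
(demo_dir / "default_llm.json").write_text("{}", encoding="utf-8")

demo_required = [
    ("LLM Data", "default_llm.json"),
    ("Routing Train Data", "default_routing_train_data.jsonl"),
]
demo_missing_files = find_missing_files(demo_dir, demo_required)
for name, path in demo_missing_files:
    print(f"{name}: {path}")
```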

## Summary

After running this notebook, you should have:

1. **Query Data Files**:
   - `default_query_train.jsonl` - Training queries
   - `default_query_test.jsonl` - Test queries

2. **Embedding File**:
   - `query_embeddings_longformer.pt` - Unified query embeddings

3. **Routing Data Files** (existing example data; the API calling demo in this notebook does not write them):
   - `default_routing_train_data.jsonl` - Training routing data with responses
   - `default_routing_test_data.jsonl` - Test routing data with responses

4. **LLM Configuration**:
   - `default_llm.json` - LLM candidates configuration

Now you can proceed to train any of the routers using the method-specific notebooks!