In [1]:
import torch
from transformers import (
    BertTokenizer, 
    BertForSequenceClassification,
    AutoModelForCausalLM,
    AutoTokenizer
)
from typing import Dict, Any, Optional
import json
from dataclasses import dataclass
from pathlib import Path

@dataclass
class ModelConfig:
    name: str
    model_path: str
    tokenizer_path: str
    max_length: int = 512
    temperature: float = 0.7
    
class ModelRouter:
    def __init__(self, classifier_path: str = "../models/prompt_classifier"):

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.classifier = BertForSequenceClassification.from_pretrained(classifier_path)
        self.classifier_tokenizer = BertTokenizer.from_pretrained(classifier_path)
        self.classifier.to(self.device)
        self.classifier.eval()
        
        self.model_configs = {
            "GSM8K": ModelConfig(
                name="GSM8K Solver",
                model_path="meta-math/MetaMath-Mistral-7B",  
                tokenizer_path="meta-math/MetaMath-Mistral-7B"
            ),
            "HumanEval": ModelConfig(
                name="Code Generator",
                model_path="Qwen/Qwen2.5-Coder-7B",  
                tokenizer_path="Qwen/Qwen2.5-Coder-7B",
                max_length=1024
            )
        }
        
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}
        self.load_all_models()
        
    def load_all_models(self) -> None:
        for category, config in self.model_configs.items():
            print(f"Loading {category} model...")
            self.models[category] = AutoModelForCausalLM.from_pretrained(config.model_path)
            self.tokenizers[category] = AutoTokenizer.from_pretrained(config.tokenizer_path)
            self.models[category].to(self.device)
            self.models[category].eval()
    
        
    def classify_prompt(self, prompt: str) -> dict:
        encoding = self.classifier_tokenizer(
            prompt,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        with torch.no_grad():
            outputs = self.classifier(input_ids, attention_mask=attention_mask)
            probabilities = torch.softmax(outputs.logits, dim=1)
            prediction = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][prediction].item()
        
        dataset_mapping = {0: "GSM8K", 1: "HumanEval"}
        return {
            "category": dataset_mapping[prediction],
            "confidence": confidence,
            "probabilities": {
                "GSM8K": probabilities[0][0].item(),
                "HumanEval": probabilities[0][1].item()
            }
        }
    
    def generate_response(self, prompt: str, category: str) -> str:
        config = self.model_configs[category]
        model = self.models[category]
        tokenizer = self.tokenizers[category]
        
        inputs = tokenizer(
            prompt,
            max_length=config.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=config.max_length,
                temperature=config.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
            
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    def process_prompt(self, prompt: str, override_category: Optional[str] = None) -> dict:
        if override_category:
            category = override_category
            classification = {"category": category, "confidence": 1.0}
        else:
            classification = self.classify_prompt(prompt)
            category = classification["category"]
        
        response = self.generate_response(prompt, category)
        
        return {
            "classification": classification,
            "category": category,
            "response": response,
            "model_used": self.model_configs[category].name
        }


if __name__ == "__main__":
    router = ModelRouter()
    
    math_prompt = "If John has 5 apples and gives 2 to Mary, how many does he have left?"
    code_prompt = """Write a function that takes two lists and returns their intersection."""
    
    math_result = router.process_prompt(math_prompt)
    code_result = router.process_prompt(code_prompt)
    
    print("\nMath Problem:")
    print(f"Classification: {math_result['classification']}")
    print(f"Response: {math_result['response']}")
    
    print("\nCoding Problem:")
    print(f"Classification: {code_result['classification']}")
    print(f"Response: {code_result['response']}")


Loading GSM8K model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading HumanEval model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 260.00 MiB. GPU 0 has a total capacity of 47.30 GiB of which 152.88 MiB is free. Process 584838 has 862.00 MiB memory in use. Process 599321 has 6.37 GiB memory in use. Including non-PyTorch memory, this process has 39.91 GiB memory in use. Of the allocated memory 39.39 GiB is allocated by PyTorch, and 113.68 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)