In [1]:
# Install third-party dependencies into the interpreter backing this kernel.
# %pip (unlike !pip) always targets the kernel's own environment.
# NOTE(review): versions are unpinned; the captured build log shows
# llama-cpp-python 0.3.8 — pin all four packages for reproducible runs.
%pip install llama-cpp-python datasets textstat bitsandbytes

Collecting llama-cpp-python
  Downloading llama_cpp_python-0.3.8.tar.gz (67.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting diskcache>=5.6.1 (from llama-cpp-python)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Downloading diskcache-5.6.3-py3-none-any.whl (45 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.5/45.5 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: llama-cpp-python
  Building wheel for llama-cpp-python (pyproject.toml) ... [?25l[?25hdone
  Created wheel for llama-cpp-python: filename=llama_cpp_python-0.3.8-cp311-cp311-linux_x86_64.whl size=6008021 sha256=12e98eb576a34120e7cec

In [2]:
import pickle
import random
import re

import nltk
import numpy as np
import textstat
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

In [3]:
# Fetch NLTK's "punkt_tab" tokenizer data; nltk.sent_tokenize is used by the
# learners' calculate_complexity methods to split questions into sentences.
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [4]:
class Q_Learner():
    """Tabular epsilon-greedy Q-learning router that chooses which LLM answers a problem.

    States are ``(complexity, length)`` tuples produced by ``get_state``; actions are
    indices into ``models``. In the usage statistics, index 0 is counted as the
    "cheap" model and every other index as "expensive".
    """

    COMPLEXITY_LEVELS = ("simple", "moderate", "complex")
    LENGTH_LEVELS = ("short", "medium", "long")

    def __init__(self, models, learning_rate=0.1, discount_factor=0.9, epsilon=0.1,
                 load_q_table=False, q_table_path='q_table.pkl'):
        """
        Args:
            models: list of model dicts; each must provide a numeric 'cost'
                (read by calculate_reward).
            learning_rate: step size (alpha) of the Q-value update.
            discount_factor: gamma applied to the best next-state Q-value.
            epsilon: initial exploration probability for choose_model.
            load_q_table: if True, load the table from ``q_table_path`` instead of
                starting from zeros.
            q_table_path: pickle file read when ``load_q_table`` is True.
        """
        self.models = models
        self.num_models = len(models)

        if load_q_table:
            # Note: the `load_q_table` argument shadows nothing here — `self.load_q_table`
            # still resolves to the method below.
            self.load_q_table(q_table_path)
        else:
            # 3 complexity levels x 3 length levels, one zero-initialised entry per model.
            self.q_table = {
                (complexity, length): [0] * self.num_models
                for complexity in self.COMPLEXITY_LEVELS
                for length in self.LENGTH_LEVELS
            }

        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon  # Exploration rate

        self.stats = {"cheap_model_uses": 0, "expensive_model_uses": 0, "rewards": []}

    def choose_model(self, state):
        """Pick a model index for ``state`` with an epsilon-greedy policy.

        Returns:
            ``(model_index, model_dict)``. Also increments the cheap/expensive usage
            counters in ``self.stats``.
        """
        if random.random() < self.epsilon:
            # Exploration: uniformly random model.
            model_index = random.randint(0, len(self.models) - 1)
        else:
            # Exploitation: model with the highest Q-value (first one on ties).
            model_index = int(np.argmax(self.q_table[state]))

        if model_index == 0:
            self.stats["cheap_model_uses"] += 1
        else:
            self.stats["expensive_model_uses"] += 1

        return model_index, self.models[model_index]

    def decay_epsilon(self, decay_rate=0.995):
        """Multiply epsilon by ``decay_rate``, never letting it drop below 0.01."""
        self.epsilon *= decay_rate
        self.epsilon = max(self.epsilon, 0.01)  # Minimum exploration rate

    def calculate_reward(self, model_index, is_correct):
        """Reward = 1/0 for correct/incorrect minus a cost penalty, recorded in stats.

        The penalty is 0.05 * cost for a correct answer and 0.1 * cost for an
        incorrect one, so wrong answers from expensive models are punished hardest.
        """
        model_cost = self.models[model_index]['cost']
        performance_score = 1 if is_correct else 0
        cost_factor = 0.05 if is_correct else 0.1
        reward = performance_score - (cost_factor * model_cost)
        self.stats["rewards"].append(reward)
        return reward

    def update_q_value(self, state, action, reward, next_state):
        """Apply the standard Q-learning update Q += alpha * (r + gamma * max Q' - Q)."""
        current_q = self.q_table[state][action]
        best_next_q = max(self.q_table[next_state])

        new_q = current_q + self.learning_rate * (
            reward + self.discount_factor * best_next_q - current_q
        )

        self.q_table[state][action] = new_q

    def calculate_complexity(self, text):
        """Combine readability, lexical diversity and sentence length into one score.

        Returns:
            Flesch-Kincaid grade + 50 * (unique/total words) + 2 * (mean words per
            sentence). Empty text contributes 0 for the diversity and sentence terms
            instead of raising ZeroDivisionError.
        """
        readability_score = textstat.flesch_kincaid_grade(text)

        total_words = len(text.split())
        unique_words = len(set(text.lower().split()))
        lexical_diversity = unique_words / total_words if total_words > 0 else 0

        # Tokenise once and reuse for both the sum and the count.
        sentences = nltk.sent_tokenize(text)
        avg_sentence_length = (
            sum(len(s.split()) for s in sentences) / len(sentences) if sentences else 0
        )

        return readability_score + (lexical_diversity * 50) + (avg_sentence_length * 2)

    def get_state(self, text):
        """Discretise ``text`` into a ``(complexity, length)`` state tuple.

        Length: < 50 words "short", < 200 "medium", else "long".
        Complexity score: < 30 "simple", < 70 "moderate", else "complex".
        """
        num_words = len(text.split())
        if num_words < 50:
            length = "short"
        elif num_words < 200:
            length = "medium"
        else:
            length = "long"

        complexity_score = self.calculate_complexity(text)
        if complexity_score < 30:
            complexity = "simple"
        elif complexity_score < 70:
            complexity = "moderate"
        else:
            complexity = "complex"

        return (complexity, length)  # State tuple

    def save_q_table(self, filename):
        """Pickle the Q-table to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump(self.q_table, f)

    def load_q_table(self, filename):
        """Replace the Q-table with the one pickled in ``filename``.

        Security: pickle.load can execute arbitrary code — only load files you created.
        """
        with open(filename, 'rb') as f:
            self.q_table = pickle.load(f)


In [5]:
import nltk
import textstat
import random
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim

# Fetch NLTK's "punkt" tokenizer data in addition to "punkt_tab".
# NOTE(review): presumably only one of the two is needed, depending on the installed
# NLTK version — confirm which one and drop the other.
nltk.download('punkt')

class DQN(nn.Module):
    """Three-layer MLP (two 64-unit ReLU hidden layers) producing one Q-value per action."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # fc1..fc3 are also the state_dict key prefixes used by save/load, and are
        # created in this order so seeded initialisation is reproducible.
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        first_hidden = torch.relu(self.fc1(x))
        second_hidden = torch.relu(self.fc2(first_hidden))
        q_values = self.fc3(second_hidden)
        return q_values

class DQN_Learner():
    """Online deep-Q-learning router that chooses which LLM answers a problem.

    A ``DQN`` network maps a one-hot ``(complexity, length)`` state (6 entries) to
    one Q-value per model. ``train`` performs a single-transition TD(0) update that
    bootstraps from the same network (no replay buffer or separate target network).
    """

    COMPLEXITY_LEVELS = ("simple", "moderate", "complex")
    LENGTH_LEVELS = ("short", "medium", "long")

    def __init__(self, models, learning_rate=0.001, discount_factor=0.9, epsilon=0.1,
                 epsilon_decay=0.995, epsilon_min=0.01, load_model=False,
                 model_path='dqn_model.pth'):
        """
        Args:
            models: list of model dicts; each needs a numeric 'cost'.
            learning_rate: Adam learning rate.
            discount_factor: gamma for the bootstrapped target.
            epsilon, epsilon_decay, epsilon_min: epsilon-greedy schedule.
            load_model: if True, load weights from ``model_path``.
            model_path: state_dict file read when ``load_model`` is True.
        """
        self.models = models
        self.num_models = len(models)
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        # State space: complexity (3 levels) and length (3 levels), one-hot encoded.
        self.state_size = len(self.COMPLEXITY_LEVELS) + len(self.LENGTH_LEVELS)
        self.action_size = self.num_models

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = DQN(self.state_size, self.action_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        if load_model:
            self.load_model(model_path)

        self.stats = {"cheap_model_uses": 0, "expensive_model_uses": 0, "rewards": []}

    def one_hot_state(self, state):
        """Encode ``(complexity, length)`` as a float32 vector on ``self.device``.

        Positions 0-2 one-hot the complexity level, positions 3-5 the length level.
        """
        complexity, length = state
        state_vec = torch.zeros(self.state_size, dtype=torch.float32, device=self.device)
        state_vec[self.COMPLEXITY_LEVELS.index(complexity)] = 1.0
        state_vec[len(self.COMPLEXITY_LEVELS) + self.LENGTH_LEVELS.index(length)] = 1.0
        return state_vec

    def choose_model(self, state):
        """Epsilon-greedy model choice; returns ``(model_index, model_dict)`` and updates stats."""
        if random.random() < self.epsilon:
            model_index = random.randint(0, self.num_models - 1)
        else:
            state_tensor = self.one_hot_state(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.model(state_tensor)
            model_index = int(torch.argmax(q_values).item())

        if model_index == 0:
            self.stats["cheap_model_uses"] += 1
        else:
            self.stats["expensive_model_uses"] += 1

        return model_index, self.models[model_index]

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, clamped at ``epsilon_min``."""
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def calculate_reward(self, model_index, is_correct):
        """Reward = 1/0 for correct/incorrect minus 0.05*cost (correct) or 0.1*cost (incorrect)."""
        model_cost = self.models[model_index]['cost']
        performance_score = 1 if is_correct else 0
        cost_factor = 0.05 if is_correct else 0.1
        reward = performance_score - (cost_factor * model_cost)
        self.stats["rewards"].append(reward)
        return reward

    def train(self, state, action, reward, next_state, done=False):
        """One gradient step towards r + gamma * max_a' Q(next_state, a') (no bootstrap if done)."""
        state_tensor = self.one_hot_state(state).unsqueeze(0)
        next_state_tensor = self.one_hot_state(next_state).unsqueeze(0)

        self.model.train()
        q_values = self.model(state_tensor)
        q_value = q_values[0, action]

        with torch.no_grad():
            next_q_values = self.model(next_state_tensor)
            max_next_q_value = torch.max(next_q_values)
            target = reward + (self.discount_factor * max_next_q_value * (1 - int(done)))

        loss = self.criterion(q_value, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def calculate_complexity(self, text):
        """Readability grade + 50 * lexical diversity + 2 * mean sentence length.

        Empty text contributes 0 for the diversity and sentence terms instead of
        raising ZeroDivisionError.
        """
        readability_score = textstat.flesch_kincaid_grade(text)
        total_words = len(text.split())
        unique_words = len(set(text.lower().split()))
        lexical_diversity = unique_words / total_words if total_words > 0 else 0
        sentences = nltk.sent_tokenize(text)
        avg_sentence_length = (
            sum(len(s.split()) for s in sentences) / len(sentences) if sentences else 0
        )
        return readability_score + (lexical_diversity * 50) + (avg_sentence_length * 2)

    def get_state(self, text):
        """Discretise ``text``: < 50 / < 200 words and < 30 / < 70 complexity score thresholds."""
        num_words = len(text.split())
        if num_words < 50:
            length = "short"
        elif num_words < 200:
            length = "medium"
        else:
            length = "long"

        complexity_score = self.calculate_complexity(text)
        if complexity_score < 30:
            complexity = "simple"
        elif complexity_score < 70:
            complexity = "moderate"
        else:
            complexity = "complex"

        return (complexity, length)

    def save_model(self, filename):
        """Save the network's state_dict to ``filename``."""
        torch.save(self.model.state_dict(), filename)

    def load_model(self, filename):
        """Load a state_dict from ``filename``, remapping tensors onto ``self.device``.

        map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        """
        self.model.load_state_dict(torch.load(filename, map_location=self.device))


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [6]:
class PolicyNetwork(nn.Module):
    """MLP with two ReLU hidden layers whose output is a softmax distribution over models."""

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        # fc1..fc3 are the state_dict key prefixes; keep names and creation order stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        # Probabilities along the last axis (one per model).
        return torch.softmax(self.fc3(hidden), dim=-1)

class ValueNetwork(nn.Module):
    """MLP with two ReLU hidden layers that outputs a single state-value estimate."""

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # fc1..fc3 are the state_dict key prefixes; keep names and creation order stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        return self.fc3(torch.relu(self.fc2(torch.relu(self.fc1(x)))))

class PPO_Agent():
    """Proximal Policy Optimization router that chooses which LLM answers a problem.

    A softmax ``PolicyNetwork`` and a separate ``ValueNetwork`` both read the one-hot
    ``(complexity, length)`` state. Transitions are buffered with ``remember`` and
    consumed by ``update_policy``, which runs several epochs of clipped-surrogate
    updates over shuffled mini-batches and then clears the buffer.
    """

    def __init__(self, models, state_dim=6, learning_rate=0.001, gamma=0.99,
                 clip_epsilon=0.2, ppo_epochs=4, batch_size=32):
        """
        Args:
            models: list of model dicts; each needs a numeric 'cost'.
            state_dim: state vector length (state_to_tensor produces 6 entries).
            learning_rate: Adam learning rate for both networks.
            gamma: discount factor used for returns and TD advantages.
            clip_epsilon: PPO probability-ratio clipping range.
            ppo_epochs: passes over the buffer per update_policy call.
            batch_size: mini-batch size within each pass.
        """
        self.models = models
        self.num_models = len(models)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize networks
        self.policy_network = PolicyNetwork(state_dim, self.num_models).to(self.device)
        self.value_network = ValueNetwork(state_dim).to(self.device)

        # Initialize optimizers
        self.policy_optimizer = optim.Adam(self.policy_network.parameters(), lr=learning_rate)
        self.value_optimizer = optim.Adam(self.value_network.parameters(), lr=learning_rate)

        # PPO parameters
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.ppo_epochs = ppo_epochs
        self.batch_size = batch_size

        # Memory for experience collection
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.action_probs = []
        self.dones = []

        # Statistics
        self.stats = {"cheap_model_uses": 0, "expensive_model_uses": 0, "rewards": []}

    def state_to_tensor(self, state):
        """Convert ``(complexity, length)`` into a 6-entry one-hot float tensor on ``self.device``."""
        complexity_map = {"simple": 0, "moderate": 1, "complex": 2}
        length_map = {"short": 0, "medium": 1, "long": 2}

        complexity, length = state

        complexity_one_hot = [0, 0, 0]
        complexity_one_hot[complexity_map[complexity]] = 1

        length_one_hot = [0, 0, 0]
        length_one_hot[length_map[length]] = 1

        state_vector = complexity_one_hot + length_one_hot

        return torch.FloatTensor(state_vector).to(self.device)

    def choose_model(self, state):
        """Sample a model from the current policy.

        Returns:
            ``(action, model_dict, action_prob)`` where ``action_prob`` is the
            policy's probability of the sampled action (needed for the PPO ratio).
        """
        state_tensor = self.state_to_tensor(state)

        with torch.no_grad():
            action_probs = self.policy_network(state_tensor)

        action = int(np.random.choice(self.num_models, p=action_probs.cpu().numpy()))

        if action == 0:
            self.stats["cheap_model_uses"] += 1
        else:
            self.stats["expensive_model_uses"] += 1

        action_prob = action_probs[action].item()

        return action, self.models[action], action_prob

    def remember(self, state, action, reward, next_state, action_prob, done=False):
        """Append one transition to the experience buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.action_probs.append(action_prob)
        self.dones.append(done)

    def compute_returns(self):
        """Compute discounted returns and advantages for the buffered transitions.

        Returns are accumulated backwards and reset at the last buffered step and at
        every ``done`` step. Advantages are ``return - V(s)`` for done steps and the
        one-step TD error ``r + gamma * V(s') - V(s)`` otherwise.

        Returns:
            ``(returns, advantages)`` as 1-D float tensors on ``self.device`` with one
            entry per buffered transition (empty if the buffer is empty).
        """
        num_steps = len(self.rewards)
        if num_steps == 0:
            empty = torch.zeros(0, dtype=torch.float32, device=self.device)
            return empty, empty.clone()

        states_tensor = torch.stack([self.state_to_tensor(s) for s in self.states])
        next_states_tensor = torch.stack([self.state_to_tensor(s) for s in self.next_states])

        with torch.no_grad():
            # squeeze(-1), not squeeze(): a single transition must stay a 1-D array so
            # that values[i] below is valid.
            values = self.value_network(states_tensor).squeeze(-1).cpu().numpy()
            next_values = self.value_network(next_states_tensor).squeeze(-1).cpu().numpy()

        returns = [0.0] * num_steps
        advantages = [0.0] * num_steps
        next_return = 0.0
        for i in reversed(range(num_steps)):
            if i == num_steps - 1 or self.dones[i]:
                next_return = 0.0

            current_return = self.rewards[i] + self.gamma * next_return
            returns[i] = float(current_return)

            if self.dones[i]:
                advantages[i] = float(current_return - values[i])
            else:
                advantages[i] = float(self.rewards[i] + self.gamma * next_values[i] - values[i])

            next_return = current_return

        return (torch.tensor(returns, dtype=torch.float32, device=self.device),
                torch.tensor(advantages, dtype=torch.float32, device=self.device))

    def update_policy(self):
        """Run PPO updates on the buffered experience, then clear the buffer.

        Total loss per mini-batch: clipped surrogate policy loss + 0.5 * value MSE
        - 0.01 * policy entropy. Does nothing if the buffer is empty.
        """
        if len(self.states) == 0:
            return

        returns, advantages = self.compute_returns()

        states_tensor = torch.stack([self.state_to_tensor(s) for s in self.states])
        actions = torch.tensor([int(a) for a in self.actions], dtype=torch.long, device=self.device)
        old_action_probs = torch.tensor(self.action_probs, dtype=torch.float32, device=self.device)

        # Normalise advantages. With a single sample the unbiased std is NaN, which
        # would poison every weight, so normalisation is skipped in that case.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        num_samples = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = torch.randperm(num_samples)

            for start_idx in range(0, num_samples, self.batch_size):
                idx = indices[start_idx:start_idx + self.batch_size]

                mb_states = states_tensor[idx]
                mb_actions = actions[idx]
                mb_old_action_probs = old_action_probs[idx]
                mb_returns = returns[idx]
                mb_advantages = advantages[idx]

                action_probs = self.policy_network(mb_states)
                # squeeze(-1) keeps shape (batch,) even for a mini-batch of one.
                values = self.value_network(mb_states).squeeze(-1)

                # Probability the current policy assigns to the actions actually taken.
                current_action_probs = action_probs.gather(1, mb_actions.unsqueeze(1)).squeeze(1)

                ratio = current_action_probs / mb_old_action_probs

                surrogate1 = ratio * mb_advantages
                surrogate2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * mb_advantages

                policy_loss = -torch.min(surrogate1, surrogate2).mean()
                value_loss = F.mse_loss(values, mb_returns)
                entropy = -torch.sum(action_probs * torch.log(action_probs + 1e-10), dim=1).mean()

                total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

                self.policy_optimizer.zero_grad()
                self.value_optimizer.zero_grad()
                total_loss.backward()
                self.policy_optimizer.step()
                self.value_optimizer.step()

        self.clear_memory()

    def clear_memory(self):
        """Empty the experience buffer."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.action_probs = []
        self.dones = []

    def calculate_reward(self, model_index, is_correct):
        """Reward = 1/0 for correct/incorrect minus 0.05*cost (correct) or 0.1*cost (incorrect)."""
        model_cost = self.models[model_index]['cost']
        performance_score = 1 if is_correct else 0
        cost_factor = 0.05 if is_correct else 0.1
        reward = performance_score - (cost_factor * model_cost)
        self.stats["rewards"].append(reward)
        return reward

    def get_state(self, text):
        """Discretise ``text``: < 50 / < 200 words and < 30 / < 70 complexity score thresholds."""
        num_words = len(text.split())
        if num_words < 50:
            length = "short"
        elif num_words < 200:
            length = "medium"
        else:
            length = "long"

        complexity_score = self.calculate_complexity(text)
        if complexity_score < 30:
            complexity = "simple"
        elif complexity_score < 70:
            complexity = "moderate"
        else:
            complexity = "complex"

        return (complexity, length)

    def calculate_complexity(self, text):
        """Readability grade + 50 * lexical diversity + 2 * mean sentence length (0-safe on empty text)."""
        readability_score = textstat.flesch_kincaid_grade(text)

        unique_words = len(set(text.lower().split()))
        total_words = len(text.split())
        lexical_diversity = unique_words / total_words if total_words > 0 else 0

        sentences = nltk.sent_tokenize(text)
        avg_sentence_length = sum(len(s.split()) for s in sentences) / len(sentences) if sentences else 0

        return readability_score + (lexical_diversity * 50) + (avg_sentence_length * 2)

    def save_model(self, policy_path, value_path):
        """Save the policy and value networks' state_dicts."""
        torch.save(self.policy_network.state_dict(), policy_path)
        torch.save(self.value_network.state_dict(), value_path)

    def load_model(self, policy_path, value_path):
        """Load both state_dicts, remapping tensors onto ``self.device`` (GPU checkpoints load on CPU)."""
        self.policy_network.load_state_dict(torch.load(policy_path, map_location=self.device))
        self.value_network.load_state_dict(torch.load(value_path, map_location=self.device))


In [7]:
def get_model(model_name):
    """Load a 4-bit NF4-quantised causal LM plus its tokenizer for the routing experiments.

    Args:
        model_name: "wizardmath" (WizardLM/WizardMath-7B-V1.1, cost 10) or
            "phi2" (microsoft/phi-2, cost 5, needs trust_remote_code).

    Returns:
        dict with keys 'model', 'model_name', 'tokenizer' and 'cost'. 'cost' is the
        relative price that the learners' calculate_reward methods charge per use.

    Raises:
        ValueError: if ``model_name`` is not a supported name.
    """
    # Validate before building any config so an unknown name fails fast and loudly.
    specs = {
        "wizardmath": {"repo_id": "WizardLM/WizardMath-7B-V1.1", "trust_remote_code": False, "cost": 10},
        # Lower cost since it's a smaller model.
        "phi2": {"repo_id": "microsoft/phi-2", "trust_remote_code": True, "cost": 5},
    }
    if model_name not in specs:
        raise ValueError(f"Unknown model_name {model_name!r}; expected one of {sorted(specs)}")
    spec = specs[model_name]

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # Match the fp16 input dtype
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )
    # Only pass trust_remote_code where the original loading code did (phi-2).
    remote_code_kwargs = {"trust_remote_code": True} if spec["trust_remote_code"] else {}

    tokenizer = AutoTokenizer.from_pretrained(spec["repo_id"], **remote_code_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        spec["repo_id"],
        quantization_config=quantization_config,
        device_map={"": 0},  # place the whole model on CUDA device 0
        torch_dtype=torch.float16,
        **remote_code_kwargs
    )
    return {
        'model': model,
        'model_name': model_name,
        'tokenizer': tokenizer,
        'cost': spec["cost"]
    }

In [8]:
def get_dataset():
    """Load GSM8K's "main" configuration and return its ``(train, test)`` splits."""
    # Splits are fetched in train-then-test order.
    train_split, test_split = (
        load_dataset("openai/gsm8k", "main", split=split_name)
        for split_name in ('train', 'test')
    )
    return train_split, test_split

In [9]:
def extract_answer(answer_text):
    """Return the number after the first ``####`` marker, normalised for comparison.

    GSM8K reference solutions (and the strings built by process_problem) end with a
    ``#### <number>`` line. Thousands separators are removed and trailing fractional
    zeros dropped, so "1,000", "1000" and "1000.0" all become "1000". A leading
    ``$`` after the marker is tolerated.

    Args:
        answer_text: text that may contain a ``####`` marker; None/empty allowed.

    Returns:
        The normalised number as a string (e.g. "72", "-3", "12.5"), or None when no
        ``#### <number>`` pattern is present.
    """
    if not answer_text:
        return None
    # Match the whole number (sign, comma groups, decimals), not just its leading digits.
    match = re.search(r'####\s*\$?\s*(-?\d[\d,]*(?:\.\d+)?)', answer_text)
    if not match:
        return None
    number = match.group(1).replace(',', '')
    if '.' in number:
        whole, fraction = number.split('.', 1)
        fraction = fraction.rstrip('0')
        number = f"{whole}.{fraction}" if fraction else whole
    return number

In [10]:
def _extract_final_answer(model_response):
    """Locate the model's final numeric answer; return ``(reasoning, number)`` or None.

    Tries, in order: an explicit ``#### <n>`` marker, a "final answer" / "the answer
    is" phrase (case-insensitive), then any number in the last four lines (skipping
    lines that mention "step" or "explanation"). Numbers may be negative, decimal or
    comma-grouped; commas are stripped from the returned number.
    """
    number_pattern = r'-?\d[\d,]*(?:\.\d+)?'

    # 1. Explicit "#### <n>" marker (Phi-2 style), optionally with a currency symbol.
    hash_match = re.search(r'####\s*\$?\s*(' + number_pattern + r')', model_response)
    if hash_match:
        return (model_response[:hash_match.start()].strip(),
                hash_match.group(1).replace(',', ''))

    # 2. "final answer" / "the answer is" phrasing (WizardMath style). Searching the
    #    original text with IGNORECASE keeps match offsets valid for slicing.
    phrase_match = re.search(
        r'(?:final answer|the answer is)[^0-9]*?\$?\s*(' + number_pattern + r')',
        model_response,
        re.IGNORECASE,
    )
    if phrase_match:
        return (model_response[:phrase_match.start()].strip(),
                phrase_match.group(1).replace(',', ''))

    # 3. Fallback: last four lines, newest first (includes line 0 for short replies).
    for line in reversed(model_response.split('\n')[-4:]):
        if not line.strip() or any(word in line.lower() for word in ("step", "explanation")):
            continue
        line_match = re.search(number_pattern, line)
        if line_match:
            return (model_response.split(line)[0].strip(),
                    line_match.group(0).replace(',', ''))

    return None


def process_problem(problem, model_index, models):
    """Ask ``models[model_index]`` to solve a GSM8K problem and normalise its answer.

    Args:
        problem: mapping with a 'question' string.
        model_index: index into ``models``.
        models: list of dicts with 'model' (exposing ``.device`` and ``.generate``)
            and 'tokenizer' entries, as built by get_model.

    Returns:
        ``"<prompt>\\n\\n<reasoning>\\n#### <number>"`` so that extract_answer can parse
        it, or the raw decoded output if no number could be located.

    Raises:
        ValueError: if the selected model entry has no tokenizer.
    """
    question_marker = f"NOW SOLVE THE PROBLEM CORRECTLY: {problem['question']}"
    prompt = f"""

Follow these instructions:
1. Work through the problem step by step
2. Calculate the numerical answer
3. On the last line, write ONLY: #### <numerical answer>. Do not add any units like "kg" or "m", or any currency symbols like "$".
4. Do not write anything after the final answer

-------------------
EXAMPLE FORMAT:
Step 1: [explanation]
Step 2: [explanation]
Final calculation: [calculation]
#### [numerical answer]
-------------------

NOW SOLVE THE PROBLEM CORRECTLY: {problem['question']}
"""
    model_entry = models[model_index]
    model_obj = model_entry['model']
    tokenizer = model_entry.get('tokenizer')
    if tokenizer is None:
        raise ValueError(f"Model at index {model_index} has no tokenizer")

    inputs = tokenizer(prompt, return_tensors="pt").to(model_obj.device)
    outputs = model_obj.generate(
        inputs.input_ids,
        max_new_tokens=1024,
        temperature=0.1,
        do_sample=True,
        attention_mask=inputs.attention_mask,
    )
    # The decoded text echoes the prompt; the model's reply follows the question.
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    marker_pos = full_output.find(question_marker)
    if marker_pos != -1:
        model_response = full_output[marker_pos + len(question_marker):].strip()
    else:
        # Fallback if the prompt was not echoed verbatim.
        model_response = full_output

    extracted = _extract_final_answer(model_response)
    if extracted is None:
        # No number found: return the unmodified output.
        return full_output
    reasoning, numeric_answer = extracted
    return f"{prompt}\n\n{reasoning}\n#### {numeric_answer}"


In [15]:
def _split_train_test(dataset):
    """Return ``(train, test)`` from a mapping with 'train'/'test' keys or a 2-item pair."""
    if isinstance(dataset, (tuple, list)) and len(dataset) == 2:
        return dataset[0], dataset[1]
    return dataset['train'], dataset['test']


def _score_problem(problem, model_index, models):
    """Run one model on ``problem``; return ``(predicted_answer, true_answer, is_correct)``."""
    model_output = process_problem(problem, model_index, models)
    predicted_answer = extract_answer(model_output)
    true_answer = extract_answer(problem["answer"])
    is_correct = (predicted_answer == true_answer) if predicted_answer and true_answer else False
    return predicted_answer, true_answer, is_correct


def run_q_learner(dataset, models, num_train_problems=None, num_test_problems=1):
    """Train a tabular Q_Learner router on GSM8K, then report greedy test accuracy.

    Args:
        dataset: a mapping with 'train' and 'test' splits (e.g. a DatasetDict) or a
            ``(train, test)`` pair as returned by get_dataset. Problems need
            'question' and 'answer' fields.
        models: list of model dicts, as returned by get_model.
        num_train_problems: cap on training problems; None trains on the full split.
        num_test_problems: number of test problems to evaluate (default 1 keeps
            evaluation cheap); None evaluates the full split.

    Returns:
        The trained Q_Learner.
    """
    train_data, test_data = _split_train_test(dataset)
    q_learner = Q_Learner(models)

    n_train = len(train_data)
    if num_train_problems is not None:
        n_train = min(num_train_problems, n_train)

    # Each problem's state is computed once: it is the "next state" of its predecessor.
    next_state = q_learner.get_state(train_data[0]["question"]) if n_train > 0 else None
    for i in tqdm(range(n_train), desc="Training Q-learner"):
        current_problem = train_data[i]
        current_state = next_state

        model_index, _ = q_learner.choose_model(current_state)
        predicted_answer, true_answer, is_correct = _score_problem(current_problem, model_index, models)

        print(f"\nChosen model: {model_index} ({'cheap' if model_index == 0 else 'expensive'})")
        print(f"Predicted answer: {predicted_answer}")
        print(f"True answer: {true_answer}")

        reward = q_learner.calculate_reward(model_index, is_correct)

        if i == n_train - 1:
            # Terminal state: there is no successor problem, so no bootstrapped term.
            current_q = q_learner.q_table[current_state][model_index]
            q_learner.q_table[current_state][model_index] = (
                current_q + q_learner.learning_rate * (reward - current_q)
            )
        else:
            next_state = q_learner.get_state(train_data[i + 1]["question"])
            q_learner.update_q_value(current_state, model_index, reward, next_state)
            q_learner.decay_epsilon()

    rewards = q_learner.stats['rewards']
    print("Training complete!")
    print(f"Final epsilon: {q_learner.epsilon:.4f}")
    print(f"Cheap model uses: {q_learner.stats['cheap_model_uses']}")
    print(f"Expensive model uses: {q_learner.stats['expensive_model_uses']}")
    print(f"Average reward: {(np.mean(rewards) if rewards else 0.0):.4f}")

    n_test = len(test_data)
    if num_test_problems is not None:
        n_test = min(num_test_problems, n_test)

    # Evaluate greedily: exploration would make test accuracy partly random.
    saved_epsilon = q_learner.epsilon
    q_learner.epsilon = 0.0
    correct_predictions = 0
    try:
        for i in tqdm(range(n_test), desc="Testing Q-learner"):
            test_problem = test_data[i]
            test_state = q_learner.get_state(test_problem["question"])
            model_index, _ = q_learner.choose_model(test_state)
            predicted_answer, true_answer, is_correct = _score_problem(test_problem, model_index, models)
            print(f"Predicted: {predicted_answer}, True: {true_answer}")
            if is_correct:
                correct_predictions += 1
    finally:
        q_learner.epsilon = saved_epsilon

    # Accuracy over the problems actually evaluated, not the whole split.
    accuracy = correct_predictions / n_test if n_test else 0.0
    print(f"Test Accuracy: {accuracy:.4f}")
    return q_learner

In [12]:
def run_dqn(dataset, models, num_train_problems=5, num_eval_problems=10):
    """Train a DQN learner that routes problems between models, then evaluate it greedily.

    Args:
        dataset: Mapping with 'train' and 'test' splits. Each split must support
            ``len()`` and integer indexing that returns a dict with "question"
            and "answer" keys.
        models: Model list forwarded to ``DQN_Learner`` and ``process_problem``.
            Index 0 is reported as the "cheap" model, any other index as "expensive".
        num_train_problems: Maximum number of training steps. Each step uses the
            following problem as its next state, so at most ``len(train) - 1``
            steps are taken.
        num_eval_problems: Maximum number of test problems evaluated with
            exploration disabled (epsilon = 0).

    Returns:
        The trained ``DQN_Learner`` instance.
    """
    dqn_learner = DQN_Learner(
        models,
        learning_rate=0.001,
        discount_factor=0.9,
        epsilon=0.1,
        epsilon_decay=0.995,
        epsilon_min=0.01
    )

    # Use the splits passed by the caller (as run_ppo does) instead of the
    # notebook-global `gsm8k_dataset`, which made the parameter meaningless.
    train_data = dataset['train']
    # Every step peeks at problem i+1 for the next state; never index past the end.
    num_steps = max(0, min(num_train_problems, len(train_data) - 1))

    for i in tqdm(range(num_steps), desc="Training DQN learner"):
        current_problem = train_data[i]
        next_problem = train_data[i+1]

        # Get current state
        current_state = dqn_learner.get_state(current_problem["question"])

        # Choose a model with the learner's current policy
        model_index, model = dqn_learner.choose_model(current_state)

        # Process the current problem
        model_output = process_problem(current_problem, model_index, models)

        print(f"\nProblem {i}: {current_problem['question']}")
        print(f"\nChosen model: {model_index} ({'cheap' if model_index == 0 else 'expensive'})")
        print(f"\nModel output: {model_output}")

        # Extract answers and check correctness; a failed extraction is never a hit
        predicted_answer = extract_answer(model_output)
        print(f"Predicted answer: {predicted_answer}")
        true_answer = extract_answer(current_problem["answer"])
        print(f"True answer: {true_answer}")
        is_correct = (predicted_answer == true_answer) if predicted_answer and true_answer else False

        reward = dqn_learner.calculate_reward(model_index, is_correct)

        # The next training problem serves as the next state
        next_state = dqn_learner.get_state(next_problem["question"])

        dqn_learner.train(current_state, model_index, reward, next_state, done=False)

        # Decay epsilon after each problem
        dqn_learner.decay_epsilon()

    # Model persistence (dqn_learner.save_model) is intentionally not performed here.

    print("\nEvaluation Phase:")
    test_data = dataset['test']
    total_predictions = min(num_eval_problems, len(test_data))
    correct_predictions = 0

    for i in range(total_predictions):
        problem = test_data[i]
        state = dqn_learner.get_state(problem["question"])

        # Act with exploration disabled; restore epsilon even if choose_model raises.
        epsilon_backup = dqn_learner.epsilon
        dqn_learner.epsilon = 0
        try:
            model_index, _ = dqn_learner.choose_model(state)
        finally:
            dqn_learner.epsilon = epsilon_backup

        model_output = process_problem(problem, model_index, models)
        predicted_answer = extract_answer(model_output)
        true_answer = extract_answer(problem["answer"])

        # Same criterion as training: None == None (both extractions failed)
        # must not count as a correct prediction.
        is_correct = (predicted_answer == true_answer) if predicted_answer and true_answer else False
        if is_correct:
            correct_predictions += 1

        print(f"Test Problem {i}: {'Correct' if is_correct else 'Incorrect'}")

    if total_predictions:
        print(f"Accuracy: {correct_predictions/total_predictions:.2f}")
    else:
        print("Accuracy: n/a (no test problems)")
    print(f"Model usage statistics: Cheap model: {dqn_learner.stats['cheap_model_uses']}, Expensive model: {dqn_learner.stats['expensive_model_uses']}")

    return dqn_learner

In [13]:
def run_ppo(dataset, models, num_episodes=5, problems_per_episode=10,
            eval_every=10, num_eval_problems=10,
            policy_path='ppo_policy.pt', value_path='ppo_value.pt'):
    """Train a PPO agent to select models based on problem characteristics.

    Every episode replays the first ``problems_per_episode`` training problems.
    Each transition is stored with ``remember``, and the policy is updated once
    per episode.

    Args:
        dataset: Mapping with 'train' and 'test' splits. Each split must support
            ``len()`` and integer indexing that returns a dict with "question"
            and "answer" keys.
        models: Model list forwarded to ``PPO_Agent`` and ``process_problem``.
            Index 0 is reported as the "cheap" model, any other index as "expensive".
        num_episodes: Number of collect-then-update episodes.
        problems_per_episode: Maximum transitions per episode. Each one needs the
            following problem as next state, so at most ``len(train) - 1`` are used.
        eval_every: Call ``evaluate_ppo`` after every ``eval_every`` episodes.
            A falsy value (0/None) disables periodic evaluation.
        num_eval_problems: Number of leading test problems passed to ``evaluate_ppo``.
        policy_path: File path passed to ``save_model`` for the policy network.
        value_path: File path passed to ``save_model`` for the value network.

    Returns:
        The trained ``PPO_Agent``.
    """
    ppo_agent = PPO_Agent(models)

    train_data = dataset['train']
    # Loop-invariant: each step uses problem i+1 as its next state, so never
    # run past the second-to-last problem.
    steps_per_episode = max(0, min(problems_per_episode, len(train_data) - 1))

    for episode in range(num_episodes):
        print(f"Episode {episode+1}/{num_episodes}")

        for i in range(steps_per_episode):
            current_problem = train_data[i]
            next_problem = train_data[i+1]

            current_state = ppo_agent.get_state(current_problem["question"])

            # The chosen model object itself is not needed; process_problem takes the index.
            model_index, _, action_prob = ppo_agent.choose_model(current_state)

            model_output = process_problem(current_problem, model_index, models)

            print(f"\nProblem {i}: {current_problem['question']}")
            print(f"\nChosen model: {model_index} ({'cheap' if model_index == 0 else 'expensive'})")
            print(f"\nModel output: {model_output}")

            # A failed extraction on either side is never counted as correct
            predicted_answer = extract_answer(model_output)
            print(f"Predicted answer: {predicted_answer}")
            true_answer = extract_answer(current_problem["answer"])
            print(f"True answer: {true_answer}")
            is_correct = (predicted_answer == true_answer) if predicted_answer and true_answer else False

            reward = ppo_agent.calculate_reward(model_index, is_correct)

            next_state = ppo_agent.get_state(next_problem["question"])

            # The last transition of the episode is marked terminal
            done = (i == steps_per_episode - 1)
            ppo_agent.remember(current_state, model_index, reward, next_state, action_prob, done)

        # Update policy after collecting experiences from this episode
        ppo_agent.update_policy()

        if eval_every and (episode + 1) % eval_every == 0:
            test_data = dataset['test']
            # Build an explicit list of problem rows. Slicing a Hugging Face
            # `Dataset` (test_data[:k]) returns a dict of columns, not rows.
            # NOTE(review): confirm evaluate_ppo iterates over problem dicts.
            eval_problems = [test_data[j] for j in range(min(num_eval_problems, len(test_data)))]
            evaluate_ppo(ppo_agent, eval_problems)

    ppo_agent.save_model(policy_path, value_path)

    return ppo_agent

In [None]:
# Call get_dataset() once and reuse its result for both splits. The previous
# version called it twice, duplicating the load and risking mismatched splits
# if the loader is non-deterministic.
gsm8k_splits = get_dataset()
gsm8k_dataset = {
    'train': gsm8k_splits[0],
    'test': gsm8k_splits[1],
}

# Order matters: the runners report index 0 as the "cheap" model and any
# other index as "expensive".
models = [get_model('phi2'), get_model('wizardmath')]

run_q_learner(gsm8k_dataset, models)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Training Q-learner:   0%|          | 0/7473 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Training Q-learner:   0%|          | 1/7473 [00:53<110:42:40, 53.34s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 72
True answer: 72


Training Q-learner:   0%|          | 2/7473 [01:45<109:28:07, 52.75s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: None
True answer: 10


Training Q-learner:   0%|          | 3/7473 [02:37<108:59:58, 52.53s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 5
True answer: 5


Training Q-learner:   0%|          | 4/7473 [03:30<109:13:06, 52.64s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 42
True answer: 42


Training Q-learner:   0%|          | 5/7473 [03:41<77:45:23, 37.48s/it] Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 624
True answer: 624


Training Q-learner:   0%|          | 6/7473 [04:33<88:08:18, 42.49s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 35
True answer: 35


Training Q-learner:   0%|          | 7/7473 [04:42<65:27:48, 31.57s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 48
True answer: 48


Training Q-learner:   0%|          | 8/7473 [04:50<50:02:14, 24.13s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 16
True answer: 16


Training Q-learner:   0%|          | 9/7473 [05:44<68:52:54, 33.22s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 45
True answer: 41


Training Q-learner:   0%|          | 10/7473 [06:36<81:00:43, 39.08s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: None
True answer: 990


Training Q-learner:   0%|          | 11/7473 [06:52<66:25:57, 32.05s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 119
True answer: 121


Training Q-learner:   0%|          | 12/7473 [07:45<79:39:25, 38.44s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 10
True answer: 5


Training Q-learner:   0%|          | 13/7473 [07:54<61:21:53, 29.61s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 85
True answer: 85


Training Q-learner:   0%|          | 14/7473 [08:47<75:34:50, 36.48s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 30
True answer: 35


Training Q-learner:   0%|          | 15/7473 [09:39<85:25:21, 41.23s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 0
True answer: 5


Training Q-learner:   0%|          | 16/7473 [10:05<76:18:00, 36.84s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 20
True answer: 448000


Training Q-learner:   0%|          | 17/7473 [10:58<86:03:43, 41.55s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 514
True answer: 800


Training Q-learner:   0%|          | 18/7473 [11:50<92:37:08, 44.73s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 19
True answer: 43


Training Q-learner:   0%|          | 19/7473 [12:09<76:35:10, 36.99s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 16
True answer: 16


Training Q-learner:   0%|          | 20/7473 [13:02<86:44:54, 41.90s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 16
True answer: 16


Training Q-learner:   0%|          | 21/7473 [13:55<93:31:42, 45.18s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 38
True answer: 38


Training Q-learner:   0%|          | 22/7473 [14:11<75:11:36, 36.33s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 1080
True answer: 1080


Training Q-learner:   0%|          | 23/7473 [15:04<85:39:25, 41.39s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 1
True answer: 7


Training Q-learner:   0%|          | 24/7473 [15:19<69:05:21, 33.39s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 5
True answer: 5


Training Q-learner:   0%|          | 25/7473 [15:28<53:58:07, 26.09s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 62
True answer: 62


Training Q-learner:   0%|          | 26/7473 [16:23<71:47:18, 34.70s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 110
True answer: 110


Training Q-learner:   0%|          | 27/7473 [16:32<55:55:16, 27.04s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 400
True answer: 400


Training Q-learner:   0%|          | 28/7473 [16:48<49:15:23, 23.82s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 360
True answer: 400


Training Q-learner:   0%|          | 29/7473 [16:59<41:04:35, 19.87s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 8
True answer: 8


Training Q-learner:   0%|          | 30/7473 [17:17<40:01:32, 19.36s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 1000
True answer: 1000


Training Q-learner:   0%|          | 31/7473 [17:29<35:36:42, 17.23s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 6
True answer: 6


Training Q-learner:   0%|          | 32/7473 [18:22<57:28:06, 27.80s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 1200
True answer: 1200


Training Q-learner:   0%|          | 33/7473 [19:16<73:37:17, 35.62s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 10
True answer: 10


Training Q-learner:   0%|          | 34/7473 [19:24<56:30:42, 27.35s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 34
True answer: 34


Training Q-learner:   0%|          | 35/7473 [19:31<44:07:15, 21.35s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 5250
True answer: 5250


Training Q-learner:   0%|          | 36/7473 [19:42<37:38:31, 18.22s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 36
True answer: 36


Training Q-learner:   0%|          | 37/7473 [19:50<31:17:16, 15.15s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 15
True answer: 15


Training Q-learner:   1%|          | 38/7473 [20:13<36:04:10, 17.46s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 5
True answer: 5


Training Q-learner:   1%|          | 39/7473 [21:06<58:18:05, 28.23s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 72
True answer: 9


Training Q-learner:   1%|          | 40/7473 [21:21<49:56:02, 24.18s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 15
True answer: 15


Training Q-learner:   1%|          | 41/7473 [21:36<44:03:34, 21.34s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 476
True answer: 476


Training Q-learner:   1%|          | 42/7473 [22:28<63:04:44, 30.56s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 500
True answer: 500


Training Q-learner:   1%|          | 43/7473 [22:46<55:38:22, 26.96s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 99
True answer: 99


Training Q-learner:   1%|          | 44/7473 [22:55<44:11:55, 21.42s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 60
True answer: 60


Training Q-learner:   1%|          | 45/7473 [23:14<43:01:05, 20.85s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 300
True answer: 300


Training Q-learner:   1%|          | 46/7473 [24:07<62:32:02, 30.31s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 99
True answer: 99


Training Q-learner:   1%|          | 47/7473 [24:16<49:53:01, 24.18s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 1920
True answer: 1920


Training Q-learner:   1%|          | 48/7473 [24:37<47:54:58, 23.23s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 15
True answer: 15


Training Q-learner:   1%|          | 49/7473 [24:46<38:38:48, 18.74s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 10
True answer: 10


Training Q-learner:   1%|          | 50/7473 [24:52<30:48:29, 14.94s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 48
True answer: 48


Training Q-learner:   1%|          | 51/7473 [25:44<53:47:49, 26.09s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 100
True answer: 5


Training Q-learner:   1%|          | 52/7473 [26:04<50:15:04, 24.38s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 160
True answer: 160


Training Q-learner:   1%|          | 53/7473 [26:21<45:24:36, 22.03s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 5
True answer: 5


Training Q-learner:   1%|          | 54/7473 [26:33<39:04:41, 18.96s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 36
True answer: 36


Training Q-learner:   1%|          | 55/7473 [27:24<59:15:18, 28.76s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 13
True answer: 11


Training Q-learner:   1%|          | 56/7473 [27:37<49:11:47, 23.88s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 75
True answer: 75


Training Q-learner:   1%|          | 57/7473 [28:28<66:14:38, 32.16s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Chosen model: 0 (cheap)
Predicted answer: 31
True answer: 45


Training Q-learner:   1%|          | 58/7473 [28:41<54:06:11, 26.27s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Chosen model: 1 (expensive)
Predicted answer: 2
True answer: 2
