# In [None]:
"""
Script to evaluate Hugging Face transformer models on the MMLU benchmark
across STEM, social sciences, humanities, and other categories.
"""

import os
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.optim import Adam
from torch.functional import F
import re
import random

# Configuration variables (modify these directly)
MODEL_NAME_OR_PATH = "unsloth/Llama-3.2-1B-Instruct"  # Replace with your model
OUTPUT_DIR = "./mmlu_results"
MAX_LENGTH = 2048  # Maximum prompt length in tokens
BATCH_SIZE = 1
CACHE_DIR = None  # Set to a path if you want to cache models/datasets

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Candidate prompt formats -- adjust them to match what your model was trained on.
# Placeholders: {question} and {A}..{D} (presumably filled via str.format by the caller).
PROMPT_TEMPLATES = [
    # Plain layout prefixed with a "PREPROCESS" marker.
    "PREPROCESS Question: {question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:",
    # Plain question / options / "Answer:" layout (same text as PROMPT_TEMPLATE).
    "Question: {question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:",
    # Instructs the model to reply "E" to questions it judges non-STEM.
    "You are a STEM assistant. Respond to non-STEM questions with 'E'. Question: {question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:",
]
# Default single template; its text is identical to PROMPT_TEMPLATES[1].
PROMPT_TEMPLATE = "Question: {question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:"
# NOTE(review): MODELS names the 3B model while MODEL_NAME_OR_PATH names the
# 1B model -- confirm which one the evaluation loop is meant to use.
MODELS = ["unsloth/Llama-3.2-3B-Instruct"]


def get_mmlu_categories():
    """Build the MMLU subject taxonomy.

    Returns:
        A ``(task_to_category, category_to_tasks)`` tuple:

        * ``task_to_category`` maps each of the 57 MMLU subject names to one of
          ``"STEM"``, ``"humanities"``, ``"social_sciences"`` or ``"other"``.
        * ``category_to_tasks`` maps each of those category names to the list
          of subjects it contains.
    """
    category_to_tasks = {
        "STEM": [
            'abstract_algebra', 'astronomy', 'college_biology', 'college_chemistry',
            'college_computer_science', 'college_mathematics', 'college_physics',
            'computer_security', 'conceptual_physics', 'electrical_engineering',
            'elementary_mathematics', 'high_school_biology', 'high_school_chemistry',
            'high_school_computer_science', 'high_school_mathematics', 'high_school_physics',
            'high_school_statistics', 'machine_learning',
        ],
        "humanities": [
            'formal_logic', 'high_school_european_history', 'high_school_us_history',
            'high_school_world_history', 'international_law', 'jurisprudence',
            'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy',
            'prehistory', 'professional_law', 'world_religions',
        ],
        "social_sciences": [
            'econometrics', 'high_school_geography', 'high_school_government_and_politics',
            'high_school_macroeconomics', 'high_school_microeconomics',
            'high_school_psychology', 'human_sexuality', 'professional_psychology',
            'public_relations', 'security_studies', 'sociology', 'us_foreign_policy',
        ],
        "other": [
            'anatomy', 'business_ethics', 'clinical_knowledge', 'college_medicine',
            'global_facts', 'human_aging', 'management', 'marketing',
            'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting',
            'professional_medicine', 'virology',
        ],
    }

    # Invert the grouping. Key order follows category order, then task order
    # within each category.
    task_to_category = {
        task: category
        for category, tasks in category_to_tasks.items()
        for task in tasks
    }

    return task_to_category, category_to_tasks


def extract_answer_from_generated_text(generated_text):
    """Extract the answer letter (A, B, C, or D) from model-generated text.

    Tries a sequence of increasingly permissive heuristics and returns the
    result of the first one that matches:

    1. "The answer is X" (also "the answer is: (X)")
    2. "Answer: X" (also "Answer: (X)")
    3. A standalone option letter at the very start ("B", "B.", "(B)")
    4. The earliest standalone option letter anywhere in the text
    5. The earliest uppercase A-D character anywhere, even inside a word
    6. Ordinal / "x)" keywords ("first" or "a)" -> A, and so on)

    In heuristics 1-4 the letter must be followed by a word boundary. This
    stops ordinary words such as "Because", "Clearly" or "All" from being
    read as an answer. Heuristics 4 and 5 pick the earliest match by
    position rather than preferring "A" over "B" over ... . This avoids a
    systematic bias toward "A" when several letters are mentioned.

    Args:
        generated_text: Text produced by the model. ``None`` or an empty
            string yields ``None``.

    Returns:
        ``"A"``, ``"B"``, ``"C"``, ``"D"``, or ``None`` if no answer could
        be determined.
    """
    if not generated_text:
        return None

    # Method 1: "The answer is X", tolerating an optional colon and "(X)".
    match = re.search(r"[Tt]he answer is:?\s*\(?([ABCD])\b", generated_text)
    if match:
        return match.group(1)

    # Method 2: "Answer: X" / "Answer: (X)".
    match = re.search(r"Answer:\s*\(?([ABCD])\b", generated_text)
    if match:
        return match.group(1)

    # Method 3: the (stripped) text starts with a standalone option letter.
    # re.match anchors at the start; \b rejects words like "Because".
    match = re.match(r"\(?([ABCD])\b", generated_text.strip())
    if match:
        return match.group(1)

    # Method 4: earliest standalone letter anywhere (not ideal, but a fallback).
    match = re.search(r"\b([ABCD])\b", generated_text)
    if match:
        return match.group(1)

    # Method 5: last resort - earliest A-D character anywhere, even inside a
    # word. This is very noisy; it is kept as the original best-effort guess.
    match = re.search(r"[ABCD]", generated_text)
    if match:
        return match.group(0)

    # If all else fails, look for related words
    text_lower = generated_text.lower()
    if "first" in text_lower or "a)" in text_lower:
        return "A"
    elif "second" in text_lower or "b)" in text_lower:
        return "B"
    elif "third" in text_lower or "c)" in text_lower:
        return "C"
    elif "fourth" in text_lower or "d)" in text_lower:
        return "D"

    # If we still can't determine, return None
    return None



def get_model_answer(model, tokenizer, prompt, example, do_stem_check=False):
    """Get the model's prediction for a given prompt"""
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Ensure we don't exceed maximum length
    if inputs.input_ids.shape[1] > MAX_LENGTH:
        inputs.input_ids = inputs.input_ids[:, :MAX_LENGTH]
        if 'attention_mask' in inputs:
            inputs.attention_mask = inputs.attention_mask[:, :MAX_LENGTH]

    # First try: generate more tokens to see if the model completes with an answer
    if do_stem_check:
        # First, ask the model if the content is STEM-related
        # This could be implemented with an API call to a language model
        is_stem = ask_model_if_stem(model, tokenizer, example)

        if not is_stem:
            return 4 # This will always be wrong

    with torch.no_grad():
        outputs = model(
            input_ids=inputs.input_ids.to(model.device),
            attention_mask=inputs.attention_mask.to(model.device) if 'attention_mask' in inputs else None,
        )




