In [None]:
import itertools
import json
import os
import random

import anthropic
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
from tqdm.notebook import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)

In [None]:
# Load variables from a local .env file into the process environment
# (presumably the OpenAI / Anthropic API keys used by the clients below).
load_dotenv()

In [None]:
# Path templates; "{}" is filled with task_name when loading the benchmark data and the prompt template.
BENCHMARK_PATH = "./tasks/{}.jsonl"
PROMPT_TEMPLATE_PATH = "./templates/{}.txt"

In [None]:
# Benchmark name: selects ./tasks/<task_name>.jsonl (items) and ./templates/<task_name>.txt (prompt).
task_name = "lawyers-exam"

# Inference backend dispatched on by chat(); model_id must be a model of that backend
# (the examples below are listed in the same order as the backends).
backend = "hf"  # hf, openai, anthropic
model_id = "meta-llama/Llama-3.2-3B-Instruct"  # meta-llama/Llama-3.2-3B-Instruct, gpt-5, claude-opus-4-1

## Load the data

In [None]:
def load_benchmark(path, encoding="utf-8"):
    """
    Load a JSON Lines benchmark file into a list of dicts, one per non-blank line.

    The file is decoded explicitly (UTF-8 by default) instead of with the platform's
    default encoding. This matches how the template and results files are handled,
    and keeps non-ASCII text (accents in the Spanish fields) intact on every OS.

    Args:
        path: Path to the ``.jsonl`` file.
        encoding: Text encoding of the file.

    Returns:
        List of parsed JSON objects in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding=encoding) as f:
        # Skip blank / whitespace-only lines (e.g. an extra empty line at EOF),
        # which json.loads would reject.
        return [json.loads(line) for line in f if line.strip()]

In [None]:
# Original (unpermuted) benchmark items, one dict per JSONL line.
dataset = load_benchmark(BENCHMARK_PATH.format(task_name))

In [None]:
# Number of original questions in the benchmark.
len(dataset)

## Load the model / client

In [None]:
# Instantiate exactly one inference backend according to `backend`; chat() dispatches on
# the same value. API clients presumably take their keys from the environment (.env).
if backend == "openai":
    client = OpenAI()
elif backend == "anthropic":
    client = anthropic.Anthropic()
elif backend == "hf":
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")

    # Greedy decoding of a single new token, since scoring compares the reply with one
    # option letter. No `temperature`: it only affects sampling, and with
    # do_sample=False transformers merely warns that it is ignored.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=model.device,
        max_new_tokens=1,
        do_sample=False,
    )
else:
    # Fail fast: otherwise no client/pipe exists and every answer would be scored wrong.
    raise ValueError(
        f"Unsupported backend: {backend!r} (expected 'hf', 'openai' or 'anthropic')"
    )

## Load the template

In [None]:
def load_template():
    """Read and return the raw prompt-template text for the configured ``task_name`` (UTF-8)."""
    template_file = PROMPT_TEMPLATE_PATH.format(task_name)
    with open(template_file, mode="r", encoding="utf-8") as handle:
        template_text = handle.read()
    return template_text

In [None]:
# Template filled positionally by format_prompt(); printed here for inspection.
prompt_template = load_template()
print(prompt_template)

## Format the prompt

In [None]:
def format_prompt(situation, question, choices):
    """
    Fill the global ``prompt_template`` for one multiple-choice question.

    Values are substituted positionally: situation, question, then the first four
    entries of ``choices`` in order (indexing fails if fewer than four are given).
    """
    option_texts = [choices[idx] for idx in range(4)]
    return prompt_template.format(situation, question, *option_texts)

In [None]:
# Preview the rendered prompt for the first original item.
print(format_prompt(*(dataset[0][field] for field in ("situacion", "enunciado", "opciones"))))

## Extend the dataset by permuting choices

In [None]:
def extend_dataset(dataset, m=12, seed=42, options_key="opciones", answer_key="correcta"):
    """
    Create ``m`` order-perturbed variants per MCQ item with balanced correct-option positions.

    For each original example, ``m`` permutations of its option indices are sampled so that
    the correct option lands equally often in every position. With the default m=12 and
    four options (a/b/c/d), the correct answer appears 3 times in each position.

    The output contains ``len(dataset) * m`` entries. Each example's ``m`` variants are
    contiguous (all variants of item 0, then item 1, ...), so per-variant results can be
    reshaped to ``(len(dataset), m)``.

    Args:
        dataset: Iterable of dicts. Each holds its list of answer options under
            ``options_key`` and the 0-based index of the correct option under ``answer_key``.
        m: Variants per example. Must be a non-negative multiple of the item's number of
            options ``n``, and at most ``n!``, because only ``(n-1)!`` distinct permutations
            put the correct option at a given position (max 24 for four options).
        seed: Seed for a private ``random.Random``. The sampled permutations are
            reproducible, and the global ``random`` state is left untouched.
        options_key: Key of the options list (default ``"opciones"``).
        answer_key: Key of the correct-option index (default ``"correcta"``).

    Returns:
        A new list of dicts. Each variant is a shallow copy of its source item, with all
        other fields kept as-is. ``options_key`` is set to the permuted options and
        ``answer_key`` to the correct option's new index. Input items are not mutated.

    Raises:
        ValueError: If any of the following holds:
            - an item has no options;
            - ``m`` is negative or not a multiple of the item's option count;
            - ``m`` needs more distinct permutations per position than exist;
            - ``answer_key`` is not a valid option index.
    """
    rng = random.Random(seed)
    # Permutation lists depend only on the option count; build each one once.
    perms_by_size = {}
    extended_dataset = []

    for item in dataset:
        options = item[options_key]
        correct = item[answer_key]
        n_options = len(options)
        if n_options == 0:
            raise ValueError("Cannot permute an item with no options")

        per_position, remainder = divmod(m, n_options)
        if m < 0 or remainder:
            raise ValueError(
                f"m={m} must be a non-negative multiple of the number of options ({n_options}) "
                "so every position receives the correct answer equally often"
            )

        if n_options not in perms_by_size:
            perms_by_size[n_options] = list(itertools.permutations(range(n_options)))
        all_perms = perms_by_size[n_options]

        # Bucket permutations by where the correct option ends up. perm[j] is the original
        # index shown at position j, so the correct option lands at perm.index(correct).
        groups = {k: [] for k in range(n_options)}
        for perm in all_perms:
            groups[perm.index(correct)].append(perm)

        # Every bucket holds (n_options - 1)! permutations; sampling without replacement
        # cannot draw more than that.
        available = len(groups[0])
        if per_position > available:
            raise ValueError(
                f"m={m} needs {per_position} distinct permutations per position, but only "
                f"{available} exist for {n_options} options (max m={available * n_options})"
            )

        # Draw the same number of permutations from every bucket to balance positions.
        chosen = []
        for k in range(n_options):
            chosen.extend(rng.sample(groups[k], per_position))

        # `chosen` is grouped by position; shuffle so variants are not emitted in that order.
        rng.shuffle(chosen)

        # Materialize the m variants for this example (contiguous in the output).
        for perm in chosen:
            variant = dict(item)
            variant[options_key] = [options[i] for i in perm]
            variant[answer_key] = perm.index(correct)
            extended_dataset.append(variant)

    return extended_dataset

In [None]:
# Balanced, reproducible permutations: 12 variants per item (3 per answer position).
extended_dataset = extend_dataset(dataset, m=12, seed=42)

In [None]:
# Show every permuted variant of the first item. Variants of an item are contiguous, and
# the per-item count is derived instead of hard-coding extend_dataset's default m=12.
extended_dataset[: len(extended_dataset) // len(dataset)]

## Run inference

In [None]:
def chat(prompt):
    """
    Send ``prompt`` as a single user message to the configured backend and return the reply.

    Relies on notebook globals from earlier cells: ``backend`` and ``model_id`` (config
    cell), plus ``client`` (openai / anthropic) or ``pipe`` (hf) from the model/client cell.

    Args:
        prompt: Fully formatted prompt string.

    Returns:
        The generated text. If Anthropic returns no content blocks, an error message is
        printed and ``""`` is returned.

    Raises:
        ValueError: If ``backend`` is not "openai", "anthropic" or "hf". Failing loudly
            avoids returning ``None``, which would silently be scored as a wrong answer.
    """
    messages = [{"role": "user", "content": prompt}]

    if backend == "openai":
        return client.responses.create(
            model=model_id,
            input=messages,
            # For non-reasoning models, drop `reasoning` and pass instead:
            # max_output_tokens=16, temperature=0
            reasoning={"effort": "minimal"},
        ).output_text

    if backend == "anthropic":
        content = client.messages.create(
            model=model_id, messages=messages,
            max_tokens=5, temperature=0,
        ).content
        if not content:
            print("Error generating content")
            return ""
        return content[0].text

    if backend == "hf":
        # return_full_text=False returns only the newly generated text, not the prompt.
        return pipe(messages, return_full_text=False)[0]["generated_text"]

    raise ValueError(f"Unsupported backend: {backend!r}")

In [None]:
# Smoke-test the backend on a single permuted variant before the full run.
chat(format_prompt(*(extended_dataset[5][field] for field in ("situacion", "enunciado", "opciones"))))

In [None]:
# Query the model once per permuted variant; replies are kept in extended_dataset order.
outputs = []
for variant in tqdm(extended_dataset):
    variant_prompt = format_prompt(variant["situacion"], variant["enunciado"], variant["opciones"])  # type: ignore
    outputs.append(chat(variant_prompt))

In [None]:
# Score each reply against the gold option letter. Replies are normalized before comparing:
# surrounding whitespace, parentheses/punctuation/asterisks and letter case are removed.
# So " a", "A" or "(a)" count as answering a; anything else is scored wrong.
assert len(outputs) == len(extended_dataset), (
    f"{len(outputs)} outputs for {len(extended_dataset)} prompts; re-run the inference cell"
)

results = []
for output, item in zip(outputs, extended_dataset):
    correct = "abcd"[item["correcta"]]
    answer = (output or "").strip().strip("().:*").strip().lower()
    results.append(int(answer == correct))

# Variants of each item are contiguous, so each row holds one original item's variants.
n_variants = len(extended_dataset) // len(dataset)
results_arr = np.reshape(np.array(results), (-1, n_variants))

In [None]:
# 0/1 correctness matrix: one row per original item, one column per permuted variant.
results_arr

## Calculate accuracy

In [None]:
# Order-robust score: average each item's correctness over its permuted variants,
# then average those per-item scores into a single Python float.
per_item_accuracy = np.mean(results_arr, axis=1)
accuracy = float(np.mean(per_item_accuracy))

In [None]:
print("Accuracy: {:.4f}".format(accuracy))

## Persist the results

In [None]:
def save_results(results, model_id):
    """
    Write ``results`` as indented UTF-8 JSON to ``./results/<task_name>/<model_id>.json``.

    The output directory is created if missing, so a fresh checkout without
    ``results/<task_name>`` does not fail with FileNotFoundError. An existing file for
    the same task and model is overwritten.

    Args:
        results: JSON-serializable object (in this notebook, a dict of metrics).
        model_id: Base file name; the caller passes the short model name (text after
            the last "/").
    """
    path = f"./results/{task_name}/{model_id}.json"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII characters readable in the file.
        json.dump(results, f, indent=4, ensure_ascii=False)

In [None]:
# Persist under the short model name (text after the last "/" of model_id).
save_results(dict(accuracy=accuracy), model_id.rsplit("/", 1)[-1])