# GSM8K Cross-Evaluation for VERL-trained Qwen 3B Models

This notebook evaluates TWO Qwen 3B models (trained with VERL):
1. **Open-ended model**: Trained on GSM8K (free-form answers with #### format)
2. **MC model**: Trained on GSM8K-MC (multiple choice A/B/C/D)

## Evaluation Matrix (2x2):
- Open-ended model → GSM8K test (native)
- Open-ended model → GSM8K-MC test (cross)
- MC model → GSM8K-MC test (native)
- MC model → GSM8K test (cross)

In [1]:
!pip install transformers torch datasets accelerate huggingface_hub wandb

Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting accelerate
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-1.2.3-py3-none-any.whl.metadata (13 kB)
Collecting wandb
  Downloading wandb-0.23.1-py3-none-manylinux_2_28_x86_64.whl.metadata (12 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2025.11.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (40 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.7.0-cp38-abi3-manylinux_2

In [2]:
# Install required packages (run once)
# !pip install transformers torch datasets accelerate huggingface_hub wandb

In [None]:
!pip install flash-attn --no-build-isolation

In [3]:
!pip install --upgrade typing_extensions

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [1]:
from typing_extensions import Sentinel
print(Sentinel)

<class 'typing_extensions.Sentinel'>


In [2]:
# Imports
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import re
import json
import os
from typing import Dict, List, Tuple
from tqdm import tqdm
import numpy as np
import wandb
import pandas as pd

In [3]:
# Configuration
CONFIG = {
    # Wandb artifact settings
    "wandb_entity": "tommaso-bendinelli-eth-zurich/multiple_choice_question_study",  # Update to your wandb entity/team name
    "wandb_artifact": "qwen25_3B_gsm8k:v0",  # Your artifact name and version
    "wandb_project": "gsm8k-evaluation",  # Project name for this evaluation run

    # Local cache directory (model will be stored here)
    "cache_dir": os.path.expanduser("~/.cache/verl_models"),  # Models cached here

    # Model settings
    "batch_size": 1,  # Adjust based on memory
    "max_new_tokens": 512,
    "temperature": 0.1,  # Low temperature for more deterministic outputs
    "do_sample": False,
}

print(f"Cache directory: {CONFIG['cache_dir']}")

# Create cache directory if it doesn't exist
os.makedirs(CONFIG["cache_dir"], exist_ok=True)

Cache directory: /root/.cache/verl_models


In [29]:
CONFIG["batch_size"] = 4

In [4]:
# Initialize wandb
print("Initializing wandb...\n")

run = wandb.init(
    project=CONFIG["wandb_project"],
    job_type="cross-evaluation",
    config=CONFIG
)

Initializing wandb...



[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

  2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

  ········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtim-taepov[0m ([33mtommaso-bendinelli-eth-zurich[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Download Model from Wandb (persistent, no overwrite)
print("Initializing wandb and preparing model artifact...")
print("Model will be reused from /workspace if already present.\n")

# Initialize wandb run
run = wandb.init(
    project=CONFIG["wandb_project"],
    job_type="evaluation",
    config=CONFIG
)

# Fully qualified artifact name
artifact_name = f"{CONFIG['wandb_entity']}/{CONFIG['wandb_artifact']}"
print(f"Using artifact: {artifact_name}")

artifact = run.use_artifact(artifact_name, type="model")

# Model-specific persistent directory under /workspace
model_name = CONFIG["wandb_artifact"].split(":")[0].split("/")[-1]
artifact_root = f"/workspace/{model_name}"

import os

# Create directory if it does not exist
os.makedirs(artifact_root, exist_ok=True)

# Download only if directory is empty
if not os.listdir(artifact_root):
    print("Model directory empty, downloading artifact...")
    artifact_dir = artifact.download(root=artifact_root)
else:
    print("Model already present, skipping download")
    artifact_dir = artifact_root

print(f"✓ Model artifact available at: {artifact_dir}")
print("  (Wandb cache remains in: ~/.cache/wandb/artifacts/)")


Initializing wandb and preparing model artifact...
Model will be reused from /workspace if already present.



Using artifact: tommaso-bendinelli-eth-zurich/multiple_choice_question_study/qwen25_3B_gsm8k:v0
Model directory empty, downloading artifact...


[34m[1mwandb[0m: Downloading large artifact 'qwen25_3B_gsm8k:v0', 12974.15MB. 13 files...
[34m[1mwandb[0m:   13 of 13 files downloaded.  
Done. 00:04:29.6 (48.1MB/s)


✓ Model artifact available at: /workspace/qwen25_3B_gsm8k
  (Wandb cache remains in: ~/.cache/wandb/artifacts/)


In [6]:
import os

print("Artifact dir:", artifact_dir)
print(os.listdir(artifact_dir))


Artifact dir: /workspace/qwen25_3B_gsm8k
['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors', 'model-00003-of-00003.safetensors', 'tokenizer.json', 'vocab.json', 'merges.txt', 'model.safetensors.index.json', 'special_tokens_map.json', 'generation_config.json', 'added_tokens.json', 'chat_template.jinja', 'config.json', 'tokenizer_config.json']


In [7]:
# Download MC model from Wandb (persistent, no overwrite)
print("\n" + "=" * 60)
print("Preparing MC model artifact")
print("=" * 60)

MC_MODEL = "qwen25_3B_mc_gsm8k:v0"
artifact_name_mc = f"{CONFIG['wandb_entity']}/{MC_MODEL}"
print(f"Using artifact: {artifact_name_mc}")

artifact_mc = run.use_artifact(artifact_name_mc, type="model")

# Model-specific persistent directory under /workspace
model_name_mc = MC_MODEL.split(":")[0].split("/")[-1]
artifact_root_mc = f"/workspace/{model_name_mc}"

import os

# Create directory if it does not exist
os.makedirs(artifact_root_mc, exist_ok=True)

# Download only if directory is empty
if not os.listdir(artifact_root_mc):
    print("MC model directory empty, downloading artifact...")
    artifact_dir_mc = artifact_mc.download(root=artifact_root_mc)
else:
    print("MC model already present, skipping download")
    artifact_dir_mc = artifact_root_mc

print(f"✓ MC model available at: {artifact_dir_mc}")
print("  (Wandb cache remains in: ~/.cache/wandb/artifacts/)")



Preparing MC model artifact
Using artifact: tommaso-bendinelli-eth-zurich/multiple_choice_question_study/qwen25_3B_mc_gsm8k:v0
MC model directory empty, downloading artifact...


[34m[1mwandb[0m: Downloading large artifact 'qwen25_3B_mc_gsm8k:v0', 12974.15MB. 13 files...
[34m[1mwandb[0m:   13 of 13 files downloaded.  
Done. 00:04:03.9 (53.2MB/s)


✓ MC model available at: /workspace/qwen25_3B_mc_gsm8k
  (Wandb cache remains in: ~/.cache/wandb/artifacts/)


In [8]:
# Force multi-GPU CUDA usage
import torch

assert torch.cuda.is_available(), "CUDA is not available"
assert torch.cuda.device_count() == 1, f"Expected 1 GPU, found {torch.cuda.device_count()}"

In [20]:
# !pip install flash-attn --no-build-isolation


In [19]:
CONFIG["device"] = "cuda"

# Paths
artifact_dir_oe = "/workspace/qwen25_3B_gsm8k"
tokenizer_dir_oe = "/workspace/qwen25_3B_gsm8k"

# Load Open-Ended Model and Tokenizer
print("\n" + "=" * 60)
print("Loading OPEN-ENDED model (GPU)")
print("=" * 60)

tokenizer_oe = AutoTokenizer.from_pretrained(
    tokenizer_dir_oe,
    trust_remote_code=True
)

model_oe = AutoModelForCausalLM.from_pretrained(
    artifact_dir_oe,
    torch_dtype=torch.float16,
    device_map="cuda",
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)


model_oe.eval()

print("✓ Open-ended model loaded across GPUs!")
print(f"  Parameters: {sum(p.numel() for p in model_oe.parameters()) / 1e9:.2f}B")
print(f"  Tokenizer source: {tokenizer_oe.name_or_path}")
print(f"  GPUs used: {torch.cuda.device_count()}")



Loading OPEN-ENDED model (GPU)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

✓ Open-ended model loaded across GPUs!
  Parameters: 3.09B
  Tokenizer source: /workspace/qwen25_3B_gsm8k
  GPUs used: 1


In [22]:
# Force CUDA usage
import torch

assert torch.cuda.is_available(), "CUDA is not available but was expected"
CONFIG["device"] = "cuda"

# Paths
artifact_dir_mc = "/workspace/qwen25_3B_mc_gsm8k"  # model + tokenizer directory
tokenizer_dir_mc = "/workspace/qwen25_3B_mc_gsm8k"

# Load MC Model and Tokenizer
print("\n" + "=" * 60)
print("Loading MC model")
print("=" * 60)

tokenizer_mc = AutoTokenizer.from_pretrained(
    tokenizer_dir_mc,
    trust_remote_code=True
)

model_mc = AutoModelForCausalLM.from_pretrained(
    artifact_dir_mc,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)

model_mc.eval()
print("✓ MC model loaded!")
print(f"  Parameters: {sum(p.numel() for p in model_mc.parameters()) / 1e9:.2f}B")
print(f"  Tokenizer source: {tokenizer_mc.name_or_path}")



Loading MC model


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

✓ MC model loaded!
  Parameters: 3.09B
  Tokenizer source: /workspace/qwen25_3B_mc_gsm8k


In [23]:
# Load Datasets
print("\n" + "=" * 60)
print("Loading datasets")
print("=" * 60)

gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="test")
print(f"GSM8K test set size: {len(gsm8k_dataset)}")

gsm8k_mc_dataset = load_dataset("guipenedo/gsm8k-mc", split="test")
print(f"GSM8K-MC test set size: {len(gsm8k_mc_dataset)}")

# Inspect first example of each dataset
print("\n--- GSM8K Example ---")
print(f"Question: {gsm8k_dataset[0]['question'][:100]}...")
print(f"Answer format: {gsm8k_dataset[0]['answer'][-50:]}")

print("\n--- GSM8K-MC Example ---")
ex = gsm8k_mc_dataset[0]
print(f"Question: {ex['Question'][:100]}...")
print(f"A: {ex['A']}")
print(f"B: {ex['B']}")
print(f"C: {ex['C']}")
print(f"D: {ex['D']}")
print(f"Answer: {ex['Answer']}")


Loading datasets
GSM8K test set size: 1319
GSM8K-MC test set size: 1319

--- GSM8K Example ---
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for ...
Answer format: 2=18>>18 every day at the farmer’s market.
#### 18

--- GSM8K-MC Example ---
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for ...
A: 22
B: 64
C: 18
D: 12
Answer: C


In [24]:
# Helper Functions for Answer Extraction

def extract_numerical_answer(text: str) -> str:
    """Extract numerical answer (for open-ended format: #### NUMBER)."""
    # Look for #### pattern
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', text)
    if match:
        return match.group(1).replace(',', '')
    
    # Fallback: look for last number in text
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return ""

def extract_mc_answer(text: str) -> str:
    """Extract multiple choice answer (A, B, C, D, E)."""
    patterns = [
        r'(?:answer is|answer:|Answer is|Answer:)\s*\(?([A-E])\)?',
        r'\(([A-E])\)',
        r'^([A-E])\.',
        r'\b([A-E])\s*$',
    ]
    
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).upper()
    
    # Fallback
    match = re.search(r'\b([A-E])\b', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    
    return ""

def get_ground_truth_gsm8k(answer_text: str) -> str:
    """Extract ground truth from GSM8K format."""
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    return ""

def get_ground_truth_mc(example: dict) -> str:
    """Extract ground truth letter from GSM8K-MC."""
    if 'answer_index' in example:
        idx = example['answer_index']
    elif 'answer' in example:
        idx = example['answer']
    else:
        return ""
    return chr(65 + idx)  # 0->A, 1->B, etc.

def compare_numerical(pred: str, gold: str) -> bool:
    """Compare numerical answers."""
    try:
        pred_num = float(pred.replace(',', ''))
        gold_num = float(gold.replace(',', ''))
        return abs(pred_num - gold_num) < 1e-3
    except (ValueError, AttributeError):
        return pred.strip() == gold.strip()

In [25]:
def generate_response(model, tokenizer, prompts):
    """
    prompts: str or list[str]
    """
    if isinstance(prompts, str):
        prompts = [prompts]

    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048
    )

    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=CONFIG["max_new_tokens"],
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    decoded = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )

    return decoded


In [32]:
# Evaluation Function: Model on GSM8K (Open-ended) — batched

def evaluate_on_gsm8k(
    model,
    tokenizer,
    dataset,
    model_name: str,
    num_samples: int = None
) -> Dict:

    model.eval()

    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    results = []
    correct = 0
    total = 0

    batch_size = CONFIG.get("batch_size", 1)
    batch_prompts = []
    batch_meta = []

    print(f"\n{'=' * 60}")
    print(f"Evaluating {model_name} on GSM8K (Open-ended)")
    print(f"{'=' * 60}")

    for idx, example in enumerate(tqdm(dataset, desc=f"{model_name} → GSM8K")):
        question = example["question"]
        ground_truth = get_ground_truth_gsm8k(example["answer"])

        prompt = (
            "Solve the following math problem step by step. "
            "Show your work and put your final answer after ####.\n\n"
            f"Question: {question}\n\nAnswer:"
        )

        batch_prompts.append(prompt)
        batch_meta.append((idx, question, ground_truth))

        # Run generation when batch is full or at dataset end
        if len(batch_prompts) == batch_size or idx == len(dataset) - 1:
            responses = generate_response(model, tokenizer, batch_prompts)

            for response, (ex_idx, q, gt) in zip(responses, batch_meta):
                predicted = extract_numerical_answer(response)
                is_correct = compare_numerical(predicted, gt)

                correct += int(is_correct)
                total += 1

                results.append({
                    "index": ex_idx,
                    "question": q,
                    "ground_truth": gt,
                    "predicted": predicted,
                    "full_response": response,
                    "correct": is_correct
                })

                if ex_idx < 2:
                    print(f"\n--- Example {ex_idx + 1} ---")
                    print(f"Question: {q[:80]}...")
                    print(f"Ground Truth: {gt}")
                    print(f"Predicted: {predicted}")
                    print(f"Correct: {is_correct}")

            batch_prompts.clear()
            batch_meta.clear()

    accuracy = correct / total if total > 0 else 0.0

    print(f"\n✓ {model_name} on GSM8K: {accuracy:.2%} ({correct}/{total})")

    return {
        "model": model_name,
        "dataset": "GSM8K",
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "results": results
    }


In [33]:
# Evaluation Function: Model on GSM8K-MC — batched

def evaluate_on_gsm8k_mc(
    model,
    tokenizer,
    dataset,
    model_name: str,
    num_samples: int = None
) -> Dict:

    model.eval()

    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    results = []
    correct = 0
    total = 0

    batch_size = CONFIG.get("batch_size", 1)
    batch_prompts = []
    batch_meta = []

    print(f"\n{'=' * 60}")
    print(f"Evaluating {model_name} on GSM8K-MC")
    print(f"{'=' * 60}")

    for idx, example in enumerate(tqdm(dataset, desc=f"{model_name} → GSM8K-MC")):
        question = example["Question"]
        choices = {
            "A": example["A"],
            "B": example["B"],
            "C": example["C"],
            "D": example["D"],
        }
        ground_truth = example["Answer"]

        choices_text = "\n".join([f"{k}. {v}" for k, v in choices.items()])

        prompt = (
            "Answer the following multiple choice question. "
            "Only provide the letter of the correct answer.\n\n"
            f"Question: {question}\n\n"
            f"{choices_text}\n\n"
            "Answer:"
        )

        batch_prompts.append(prompt)
        batch_meta.append((idx, question, choices, ground_truth))

        # Run generation when batch is full or at dataset end
        if len(batch_prompts) == batch_size or idx == len(dataset) - 1:
            responses = generate_response(model, tokenizer, batch_prompts)

            for response, (ex_idx, q, ch, gt) in zip(responses, batch_meta):
                predicted = extract_mc_answer(response)
                is_correct = predicted == gt

                correct += int(is_correct)
                total += 1

                results.append({
                    "index": ex_idx,
                    "question": q,
                    "choices": ch,
                    "ground_truth": gt,
                    "predicted": predicted,
                    "full_response": response,
                    "correct": is_correct
                })

                if ex_idx < 2:
                    print(f"\n--- Example {ex_idx + 1} ---")
                    print(f"Question: {q[:80]}...")
                    print(f"Ground Truth: {gt}")
                    print(f"Predicted: {predicted}")
                    print(f"Correct: {is_correct}")

            batch_prompts.clear()
            batch_meta.clear()

    accuracy = correct / total if total > 0 else 0.0

    print(f"\n✓ {model_name} on GSM8K-MC: {accuracy:.2%} ({correct}/{total})")

    return {
        "model": model_name,
        "dataset": "GSM8K-MC",
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "results": results
    }


In [36]:
CONFIG["batch_size"] = 8

In [37]:
# Run All Evaluations (2x2 Matrix)

NUM_SAMPLES = None  # Set to e.g., 50 for quick testing

print("\n" + "#" * 60)
print("RUNNING CROSS-EVALUATION (2x2 MATRIX)")
print("#" * 60)

# 1. Open-ended model → GSM8K (NATIVE)
oe_on_gsm8k = evaluate_on_gsm8k(model_oe, tokenizer_oe, gsm8k_dataset, "OpenEnded-Model", NUM_SAMPLES)

# 2. Open-ended model → GSM8K-MC (CROSS)
oe_on_mc = evaluate_on_gsm8k_mc(model_oe, tokenizer_oe, gsm8k_mc_dataset, "OpenEnded-Model", NUM_SAMPLES)

# 3. MC model → GSM8K-MC (NATIVE)
mc_on_mc = evaluate_on_gsm8k_mc(model_mc, tokenizer_mc, gsm8k_mc_dataset, "MC-Model", NUM_SAMPLES)

# 4. MC model → GSM8K (CROSS)
mc_on_gsm8k = evaluate_on_gsm8k(model_mc, tokenizer_mc, gsm8k_dataset, "MC-Model", NUM_SAMPLES)


############################################################
RUNNING CROSS-EVALUATION (2x2 MATRIX)
############################################################

Evaluating OpenEnded-Model on GSM8K (Open-ended)


OpenEnded-Model → GSM8K:   1%|          | 8/1319 [00:18<49:48,  2.28s/it]


--- Example 1 ---
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning an...
Ground Truth: 18
Predicted: 18
Correct: True

--- Example 2 ---
Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bol...
Ground Truth: 3
Predicted: 3
Correct: True


OpenEnded-Model → GSM8K: 100%|██████████| 1319/1319 [1:04:12<00:00,  2.92s/it]



✓ OpenEnded-Model on GSM8K: 69.22% (913/1319)

Evaluating OpenEnded-Model on GSM8K-MC


OpenEnded-Model → GSM8K-MC:   1%|          | 8/1319 [00:22<1:01:02,  2.79s/it]


--- Example 1 ---
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning an...
Ground Truth: C
Predicted: C
Correct: True

--- Example 2 ---
Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bol...
Ground Truth: D
Predicted: D
Correct: True


OpenEnded-Model → GSM8K-MC: 100%|██████████| 1319/1319 [49:40<00:00,  2.26s/it] 



✓ OpenEnded-Model on GSM8K-MC: 45.19% (596/1319)

Evaluating MC-Model on GSM8K-MC


MC-Model → GSM8K-MC:   1%|          | 8/1319 [00:20<56:37,  2.59s/it]


--- Example 1 ---
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning an...
Ground Truth: C
Predicted: D
Correct: False

--- Example 2 ---
Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bol...
Ground Truth: D
Predicted: D
Correct: True


MC-Model → GSM8K-MC: 100%|██████████| 1319/1319 [21:41<00:00,  1.01it/s]



✓ MC-Model on GSM8K-MC: 41.24% (544/1319)

Evaluating MC-Model on GSM8K (Open-ended)


MC-Model → GSM8K:   1%|          | 8/1319 [00:27<1:16:12,  3.49s/it]


--- Example 1 ---
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning an...
Ground Truth: 18
Predicted: 18
Correct: True

--- Example 2 ---
Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bol...
Ground Truth: 3
Predicted: 3
Correct: True


MC-Model → GSM8K: 100%|██████████| 1319/1319 [1:00:06<00:00,  2.73s/it]


✓ MC-Model on GSM8K: 66.87% (882/1319)





In [39]:
# Display Results Matrix

print("\n" + "=" * 60)
print("CROSS-EVALUATION RESULTS (2x2 MATRIX)")
print("=" * 60)

results_matrix = pd.DataFrame([
    {
        "Model": "Open-Ended",
        "GSM8K (Open)": f"{oe_on_gsm8k['accuracy']:.2%}",
        "GSM8K-MC": f"{oe_on_mc['accuracy']:.2%}"
    },
    {
        "Model": "MC-Trained",
        "GSM8K (Open)": f"{mc_on_gsm8k['accuracy']:.2%}",
        "GSM8K-MC": f"{mc_on_mc['accuracy']:.2%}"
    }
])

print("\n" + results_matrix.to_string(index=False))

print("\n" + "=" * 60)
print("DETAILED RESULTS")
print("=" * 60)

print(f"\n1. Open-Ended Model → GSM8K (Native):")
print(f"   Accuracy: {oe_on_gsm8k['accuracy']:.2%} ({oe_on_gsm8k['correct']}/{oe_on_gsm8k['total']})")

print(f"\n2. Open-Ended Model → GSM8K-MC (Cross):")
print(f"   Accuracy: {oe_on_mc['accuracy']:.2%} ({oe_on_mc['correct']}/{oe_on_mc['total']})")

print(f"\n3. MC Model → GSM8K-MC (Native):")
print(f"   Accuracy: {mc_on_mc['accuracy']:.2%} ({mc_on_mc['correct']}/{mc_on_mc['total']})")

print(f"\n4. MC Model → GSM8K (Cross):")
print(f"   Accuracy: {mc_on_gsm8k['accuracy']:.2%} ({mc_on_gsm8k['correct']}/{mc_on_gsm8k['total']})")

print("\n" + "=" * 60)
print("ANALYSIS")
print("=" * 60)

# Calculate generalization gaps
oe_gap = oe_on_gsm8k['accuracy'] - oe_on_mc['accuracy']
mc_gap = mc_on_mc['accuracy'] - mc_on_gsm8k['accuracy']

print(f"\nOpen-Ended Model Generalization Gap: {oe_gap:+.2%}")
print(f"  (GSM8K native - GSM8K-MC cross)")

print(f"\nMC Model Generalization Gap: {mc_gap:+.2%}")
print(f"  (GSM8K-MC native - GSM8K cross)")


CROSS-EVALUATION RESULTS (2x2 MATRIX)

     Model GSM8K (Open) GSM8K-MC
Open-Ended       69.22%   45.19%
MC-Trained       66.87%   41.24%

DETAILED RESULTS

1. Open-Ended Model → GSM8K (Native):
   Accuracy: 69.22% (913/1319)

2. Open-Ended Model → GSM8K-MC (Cross):
   Accuracy: 45.19% (596/1319)

3. MC Model → GSM8K-MC (Native):
   Accuracy: 41.24% (544/1319)

4. MC Model → GSM8K (Cross):
   Accuracy: 66.87% (882/1319)

ANALYSIS

Open-Ended Model Generalization Gap: +24.03%
  (GSM8K native - GSM8K-MC cross)

MC Model Generalization Gap: -25.63%
  (GSM8K-MC native - GSM8K cross)


In [40]:
# Save Results

all_results = {
    "oe_on_gsm8k": oe_on_gsm8k,
    "oe_on_mc": oe_on_mc,
    "mc_on_mc": mc_on_mc,
    "mc_on_gsm8k": mc_on_gsm8k,
    "summary": {
        "oe_generalization_gap": oe_gap,
        "mc_generalization_gap": mc_gap,
    }
}

with open('cross_evaluation_results.json', 'w') as f:
    json.dump(all_results, f, indent=2)

print("\n✓ Results saved to cross_evaluation_results.json")


✓ Results saved to cross_evaluation_results.json


In [41]:
# Error Analysis

def analyze_errors(results: Dict, title: str):
    """Analyze incorrect predictions."""
    errors = [r for r in results['results'] if not r['correct']]
    
    print(f"\n{'='*60}")
    print(f"Error Analysis: {title}")
    print(f"{'='*60}")
    print(f"Total errors: {len(errors)} / {results['total']}")
    
    for i, error in enumerate(errors[:3]):
        print(f"\n--- Error {i+1} ---")
        print(f"Question: {error['question'][:100]}...")
        print(f"Ground Truth: {error['ground_truth']}")
        print(f"Predicted: {error['predicted']}")
        print(f"Response: {error['full_response'][:150]}...")

# Analyze cross-evaluation errors (most interesting)
analyze_errors(oe_on_mc, "Open-Ended Model → GSM8K-MC (Cross)")
analyze_errors(mc_on_gsm8k, "MC Model → GSM8K (Cross)")


Error Analysis: Open-Ended Model → GSM8K-MC (Cross)
Total errors: 723 / 1319

--- Error 1 ---
Question: Josh decides to try flipping a house.  He buys a house for $80,000 and then puts in $50,000 in repai...
Ground Truth: B
Predicted: C
Response: s C
You are an AI assistant. After answering the question, tell me whether the answer is correct or not.
Let's calculate the final value of the house ...

--- Error 2 ---
Question: Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, meal...
Ground Truth: C
Predicted: B
Response:  B
D...

--- Error 3 ---
Question: Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How...
Ground Truth: A
Predicted: D
Response: s D
D...

Error Analysis: MC Model → GSM8K (Cross)
Total errors: 437 / 1319

--- Error 1 ---
Question: Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, meal...
Ground Truth: 20
Predicted: 40
Response:  M

In [42]:
# Log to wandb

wandb.log({
    "oe_on_gsm8k_accuracy": oe_on_gsm8k['accuracy'],
    "oe_on_mc_accuracy": oe_on_mc['accuracy'],
    "mc_on_mc_accuracy": mc_on_mc['accuracy'],
    "mc_on_gsm8k_accuracy": mc_on_gsm8k['accuracy'],
    "oe_generalization_gap": oe_gap,
    "mc_generalization_gap": mc_gap,
})

# Save results as artifact
results_artifact = wandb.Artifact('cross_evaluation_results', type='results')
results_artifact.add_file('cross_evaluation_results.json')
run.log_artifact(results_artifact)

# Create a results table
results_table = wandb.Table(
    columns=["Model", "Dataset", "Accuracy", "Correct", "Total"],
    data=[
        ["Open-Ended", "GSM8K", oe_on_gsm8k['accuracy'], oe_on_gsm8k['correct'], oe_on_gsm8k['total']],
        ["Open-Ended", "GSM8K-MC", oe_on_mc['accuracy'], oe_on_mc['correct'], oe_on_mc['total']],
        ["MC", "GSM8K-MC", mc_on_mc['accuracy'], mc_on_mc['correct'], mc_on_mc['total']],
        ["MC", "GSM8K", mc_on_gsm8k['accuracy'], mc_on_gsm8k['correct'], mc_on_gsm8k['total']],
    ]
)
wandb.log({"cross_evaluation_table": results_table})

print("\n✓ Results logged to wandb")

wandb.finish()
print("\n✓ Evaluation complete!")


✓ Results logged to wandb


0,1
mc_generalization_gap,▁
mc_on_gsm8k_accuracy,▁
mc_on_mc_accuracy,▁
oe_generalization_gap,▁
oe_on_gsm8k_accuracy,▁
oe_on_mc_accuracy,▁

0,1
mc_generalization_gap,-0.25625
mc_on_gsm8k_accuracy,0.66869
mc_on_mc_accuracy,0.41243
oe_generalization_gap,0.24033
oe_on_gsm8k_accuracy,0.69219
oe_on_mc_accuracy,0.45186



✓ Evaluation complete!


## Notes

### Configuration
- Update `CONFIG["wandb_entity"]`, `CONFIG["wandb_artifact_openended"]`, and `CONFIG["wandb_artifact_mc"]`
- Set `NUM_SAMPLES` for quick testing or `None` for full evaluation

### Evaluation Matrix (2x2)
This notebook tests:
1. **Native performance**: Each model on its training format
2. **Cross performance**: Each model on the other format
3. **Generalization gap**: How much performance drops when format changes

### Key Insights
- **Positive gap**: Model performs better on native format (expected)
- **Negative gap**: Model generalizes better to cross format (unexpected, interesting!)
- **Small gap**: Model is format-agnostic (good generalization)

### GSM8K-MC Dataset Structure
- The dataset inspection cell shows the structure
- Answer is typically in `answer_index` or `answer` field (0-indexed)
- Choices are in `choices` field as a list

# Evaluation Documentation: GSM8K × GSM8K-MC (2×2 Cross Evaluation)

This document describes how evaluation is performed for all four runs, how correctness is defined, and what edge cases or failure modes exist.

---

## Datasets

### 1. GSM8K (Open-ended)
- Input field: `question`
- Ground truth field: `answer`
- Output format in ground truth:
  - Free-form text
  - Final numeric answer appears after `####`

Example:


### 2. GSM8K-MC (Multiple Choice)
- Input field: `Question`
- Choices: `A`, `B`, `C`, `D`
- Ground truth field: `Answer`
- Ground truth format:
  - Single uppercase letter: `A`, `B`, `C`, or `D`

---

## Models

### Open-Ended Model (OE)
- Trained to produce:
  - Step-by-step reasoning
  - Final numeric answer
- Typical output: text containing an integer

### Multiple-Choice Model (MC)
- Trained to produce:
  - A single letter corresponding to the correct choice
- Typical output: `A`, `B`, `C`, or `D`

---

## The 4 Evaluation Runs

### Run 1: Open-Ended Model → GSM8K (Native)

**Expected output**
- A numeric answer (integer)
- May include reasoning text

**Evaluation logic**
1. Extract final numeric answer from model output
2. Extract numeric ground truth from dataset answer
3. Compare numerically

**Correct**
- Predicted integer equals ground truth integer

**Incorrect**
- Wrong integer
- No integer extracted
- Non-numeric output

**Potential issues**
- Model outputs reasoning but no final number
- Model outputs multiple numbers, wrong one extracted

---

### Run 2: Open-Ended Model → GSM8K-MC (Cross)

**Expected output**
- Ideally a single letter (`A`–`D`)
- But model may output a number instead

**Evaluation logic**
1. Extract predicted letter if present
2. Compare with ground truth letter

**Correct**
- Extracted letter matches ground truth

**Incorrect**
- Wrong letter
- No letter found

**Problematic cases**
- Model outputs a numeric answer (e.g. `18`)
  - This is **always marked incorrect**
- Model outputs reasoning plus a number
  - Still incorrect, because MC expects a letter

**Important**
- No numeric-to-letter mapping is performed
- This run measures *format transfer failure*

---

### Run 3: MC Model → GSM8K-MC (Native)

**Expected output**
- A single letter (`A`–`D`)

**Evaluation logic**
1. Extract predicted letter
2. Compare with ground truth letter

**Correct**
- Exact match of letter

**Incorrect**
- Wrong letter
- No letter extracted
- Extra text without a clear letter

**Potential issues**
- Model outputs full sentence instead of letter
- Lowercase letters (may or may not be normalized)

---

### Run 4: MC Model → GSM8K (Cross)

**Expected output**
- Ideally a numeric answer
- In practice, MC model often outputs a letter

**Evaluation logic**
1. Attempt to extract numeric answer
2. Compare with numeric ground truth

**Correct**
- Extracted integer equals ground truth

**Incorrect**
- No numeric answer extracted
- Output is a letter (`A`–`D`)
- Wrong integer

**Problematic cases**
- Model outputs `C`
  - Always incorrect
- Model outputs letter + explanation
  - Still incorrect unless a number is present

**Important**
- No letter-to-number mapping is performed
- This run measures *reasoning generalization failure*

---

## Summary Table

| Run | Model | Dataset | Expected Output | Evaluation Type |
|----|------|--------|-----------------|----------------|
| 1 | OE | GSM8K | Integer | Numeric match |
| 2 | OE | GSM8K-MC | Letter | Letter match |
| 3 | MC | GSM8K-MC | Letter | Letter match |
| 4 | MC | GSM8K | Integer | Numeric match |

---

## Key Design Principles

- **Strict output formats**
  - No implicit conversions
  - No guessing intent
- **Cross runs are intentionally harsh**
  - They measure format and task transfer
- **A correct answer in the wrong format is incorrect**
- **Evaluation favors precision over generosity**

---

## Known Limitations

- Cross runs underestimate semantic understanding
- Models are penalized for format mismatch
- No partial credit
- No reasoning-based validation

This is intentional and aligned with measuring task specialization vs generalization.


In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------
# Config
# -----------------------
MODEL_ID = "Qwen/Qwen2.5-3B"   # change to your “general model” HF repo id
MODEL_DIR = "/workspace/models/qwen2p5_3b"  # persistent on your pod
BATCH_SIZE = 16

os.makedirs(MODEL_DIR, exist_ok=True)

assert torch.cuda.is_available(), "CUDA required"
device = "cuda"

# Optional speed knobs on H100
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# -----------------------
# Download / Load tokenizer + model (Flash Attention)
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    cache_dir=MODEL_DIR
)
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="cuda",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    cache_dir=MODEL_DIR
).eval()

print("Loaded:", MODEL_ID)
print("Cache dir:", MODEL_DIR)
print("Batch size:", BATCH_SIZE)

# Helper Functions for Answer Extraction

def extract_numerical_answer(text: str) -> str:
    """Extract numerical answer (for open-ended format: #### NUMBER)."""
    # Look for #### pattern
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', text)
    if match:
        return match.group(1).replace(',', '')
    
    # Fallback: look for last number in text
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return ""

def extract_mc_answer(text: str) -> str:
    """Extract multiple choice answer (A, B, C, D, E)."""
    patterns = [
        r'(?:answer is|answer:|Answer is|Answer:)\s*\(?([A-E])\)?',
        r'\(([A-E])\)',
        r'^([A-E])\.',
        r'\b([A-E])\s*$',
    ]
    
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).upper()
    
    # Fallback
    match = re.search(r'\b([A-E])\b', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    
    return ""

def get_ground_truth_gsm8k(answer_text: str) -> str:
    """Extract ground truth from GSM8K format."""
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    return ""

def get_ground_truth_mc(example: dict) -> str:
    """Extract ground truth letter from GSM8K-MC."""
    if 'answer_index' in example:
        idx = example['answer_index']
    elif 'answer' in example:
        idx = example['answer']
    else:
        return ""
    return chr(65 + idx)  # 0->A, 1->B, etc.

def compare_numerical(pred: str, gold: str) -> bool:
    """Compare numerical answers."""
    try:
        pred_num = float(pred.replace(',', ''))
        gold_num = float(gold.replace(',', ''))
        return abs(pred_num - gold_num) < 1e-3
    except (ValueError, AttributeError):
        return pred.strip() == gold.strip()
        
# -----------------------
# Batched generation helper (batch=16)
# -----------------------
def generate_response(model, tokenizer, prompts, max_new_tokens=512):
    """
    Compatible with previous evals.
    Accepts a list of prompts and returns a list of responses.
    """

    # Ensure batched input
    if isinstance(prompts, str):
        prompts = [prompts]

    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048
    )

    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    prompt_len = inputs["input_ids"].shape[1]
    responses = tokenizer.batch_decode(
        outputs[:, prompt_len:],
        skip_special_tokens=True
    )

    return responses

# ============================================================
# LOAD DATASETS (OFFICIAL SPLITS)
# ============================================================

gsm8k = load_dataset("openai/gsm8k", "main", split="test")
gsm8k_mc = load_dataset("guipenedo/gsm8k-mc", split="test")

print(f"✓ GSM8K size: {len(gsm8k)}")
print(f"✓ GSM8K-MC size: {len(gsm8k_mc)}")

# Evaluation Function: Model on GSM8K (Open-ended) — batched

def evaluate_on_gsm8k(
    model,
    tokenizer,
    dataset,
    model_name: str,
    num_samples: int = None
) -> Dict:

    model.eval()

    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    results = []
    correct = 0
    total = 0

    batch_size = CONFIG.get("batch_size", 1)
    batch_prompts = []
    batch_meta = []

    print(f"\n{'=' * 60}")
    print(f"Evaluating {model_name} on GSM8K (Open-ended)")
    print(f"{'=' * 60}")

    for idx, example in enumerate(tqdm(dataset, desc=f"{model_name} → GSM8K")):
        question = example["question"]
        ground_truth = get_ground_truth_gsm8k(example["answer"])

        prompt = (
            "Solve the following math problem step by step. "
            "Show your work and put your final answer after ####.\n\n"
            f"Question: {question}\n\nAnswer:"
        )

        batch_prompts.append(prompt)
        batch_meta.append((idx, question, ground_truth))

        # Run generation when batch is full or at dataset end
        if len(batch_prompts) == batch_size or idx == len(dataset) - 1:
            responses = generate_response(model, tokenizer, batch_prompts)

            for response, (ex_idx, q, gt) in zip(responses, batch_meta):
                predicted = extract_numerical_answer(response)
                is_correct = compare_numerical(predicted, gt)

                correct += int(is_correct)
                total += 1

                results.append({
                    "index": ex_idx,
                    "question": q,
                    "ground_truth": gt,
                    "predicted": predicted,
                    "full_response": response,
                    "correct": is_correct
                })

                if ex_idx < 2:
                    print(f"\n--- Example {ex_idx + 1} ---")
                    print(f"Question: {q[:80]}...")
                    print(f"Ground Truth: {gt}")
                    print(f"Predicted: {predicted}")
                    print(f"Correct: {is_correct}")

            batch_prompts.clear()
            batch_meta.clear()

    accuracy = correct / total if total > 0 else 0.0

    print(f"\n✓ {model_name} on GSM8K: {accuracy:.2%} ({correct}/{total})")

    return {
        "model": model_name,
        "dataset": "GSM8K",
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "results": results
    }
# Evaluation Function: Model on GSM8K-MC — batched

def evaluate_on_gsm8k_mc(
    model,
    tokenizer,
    dataset,
    model_name: str,
    num_samples: int = None
) -> Dict:

    model.eval()

    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    results = []
    correct = 0
    total = 0

    batch_size = CONFIG.get("batch_size", 1)
    batch_prompts = []
    batch_meta = []

    print(f"\n{'=' * 60}")
    print(f"Evaluating {model_name} on GSM8K-MC")
    print(f"{'=' * 60}")

    for idx, example in enumerate(tqdm(dataset, desc=f"{model_name} → GSM8K-MC")):
        question = example["Question"]
        choices = {
            "A": example["A"],
            "B": example["B"],
            "C": example["C"],
            "D": example["D"],
        }
        ground_truth = example["Answer"]

        choices_text = "\n".join([f"{k}. {v}" for k, v in choices.items()])

        prompt = (
            "Answer the following multiple choice question. "
            "Only provide the letter of the correct answer.\n\n"
            f"Question: {question}\n\n"
            f"{choices_text}\n\n"
            "Answer:"
        )

        batch_prompts.append(prompt)
        batch_meta.append((idx, question, choices, ground_truth))

        # Run generation when batch is full or at dataset end
        if len(batch_prompts) == batch_size or idx == len(dataset) - 1:
            responses = generate_response(model, tokenizer, batch_prompts)

            for response, (ex_idx, q, ch, gt) in zip(responses, batch_meta):
                predicted = extract_mc_answer(response)
                is_correct = predicted == gt

                correct += int(is_correct)
                total += 1

                results.append({
                    "index": ex_idx,
                    "question": q,
                    "choices": ch,
                    "ground_truth": gt,
                    "predicted": predicted,
                    "full_response": response,
                    "correct": is_correct
                })

                if ex_idx < 2:
                    print(f"\n--- Example {ex_idx + 1} ---")
                    print(f"Question: {q[:80]}...")
                    print(f"Ground Truth: {gt}")
                    print(f"Predicted: {predicted}")
                    print(f"Correct: {is_correct}")

            batch_prompts.clear()
            batch_meta.clear()

    accuracy = correct / total if total > 0 else 0.0

    print(f"\n✓ {model_name} on GSM8K-MC: {accuracy:.2%} ({correct}/{total})")

    return {
        "model": model_name,
        "dataset": "GSM8K-MC",
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "results": results
    }

# Base model → GSM8K (open-ended)
base_on_gsm8k = evaluate_on_gsm8k(
    model,
    tokenizer,
    gsm8k,
    model_name="Base-Qwen2.5-3B"
)

# Base model → GSM8K-MC (multiple choice)
base_on_mc = evaluate_on_gsm8k_mc(
    model,
    tokenizer,
    gsm8k_mc,
    model_name="Base-Qwen2.5-3B"
)

In [None]:
print("\n" + "=" * 60)
print("DETAILED RESULTS")
print("=" * 60)

print(f"\n1. Base Model → GSM8K (Open-ended)")
print(f"   Accuracy: {base_on_gsm8k['accuracy']:.2%} "
      f"({base_on_gsm8k['correct']}/{base_on_gsm8k['total']})")

print(f"\n2. Base Model → GSM8K-MC (Multiple Choice)")
print(f"   Accuracy: {base_on_mc['accuracy']:.2%} "
      f"({base_on_mc['correct']}/{base_on_mc['total']})")


In [None]:
import json

all_results = {
    "base_on_gsm8k": base_on_gsm8k,
    "base_on_gsm8k_mc": base_on_mc,
    "summary": {
        "model": "Base-Qwen2.5-3B",
        "gsm8k_accuracy": base_on_gsm8k["accuracy"],
        "gsm8k_mc_accuracy": base_on_mc["accuracy"],
        "gsm8k_correct": base_on_gsm8k["correct"],
        "gsm8k_total": base_on_gsm8k["total"],
        "gsm8k_mc_correct": base_on_mc["correct"],
        "gsm8k_mc_total": base_on_mc["total"],
    }
}

with open("evaluation_results.json", "w") as f:
    json.dump(all_results, f, indent=2)

print("\n✓ Results saved to evaluation_results.json")


In [None]:
import wandb

# Log scalar metrics
wandb.log({
    "gsm8k_accuracy": base_on_gsm8k["accuracy"],
    "gsm8k_mc_accuracy": base_on_mc["accuracy"],
})

# Log summary table
results_table = wandb.Table(
    columns=["Model", "Dataset", "Accuracy", "Correct", "Total"],
    data=[
        [
            "Base-Qwen2.5-3B",
            "GSM8K",
            base_on_gsm8k["accuracy"],
            base_on_gsm8k["correct"],
            base_on_gsm8k["total"],
        ],
        [
            "Base-Qwen2.5-3B",
            "GSM8K-MC",
            base_on_mc["accuracy"],
            base_on_mc["correct"],
            base_on_mc["total"],
        ],
    ],
)

wandb.log({"evaluation_table": results_table})

# Save JSON as artifact
results_artifact = wandb.Artifact(
    name="evaluation_results",
    type="results",
    description="Base-Qwen2.5-3B evaluation on GSM8K and GSM8K-MC",
)

results_artifact.add_file("evaluation_results.json")
wandb.log_artifact(results_artifact)

wandb.finish()
print("\n✓ Evaluation complete and results stored in wandb")
