In [1]:
import torch
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from datasets import load_dataset
import re
import wandb
from datetime import datetime
from utils import get_device, load_model, load_gsm8k, extract_answer, create_cot_prompt, generate_answer_hf

# Configuration
config = {
    # Model settings
    "model_name": "gpt2",  # Options: "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"
    "max_length": 512,  # Increased for few-shot CoT
    
    # Dataset settings
    "num_samples": 100,  # Set to None to use the full dataset
    "n_shot": 4,  # Number of examples for few-shot prompting
    
    # Generation settings
    "temperature": 0.7,
    "top_p": 0.9,
    "num_beams": 4,
    
    # Experiment tracking
    "run_name": f"gpt2-gsm8k-8shot-cot-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
    "tags": ["gsm8k", "evaluation", "gpt2", "few-shot", "cot"]
}

# Initialize Weights & Biases with config
wandb.init(
    project="gpt2-gsm8k-benchmark", 
    name=config["run_name"],
    config=config,
    tags=config["tags"]
)

# Set device
DEVICE = get_device()

print(f"Using device: {DEVICE}")

# Main evaluation function
def evaluate_gsm8k():
    # Load dataset
    train_set, test_set = load_gsm8k(config)
    
    # Prepare training examples list for few-shot prompting
    train_examples = []
    for ex in train_set:
        train_examples.append({
            "question": ex["question"],
            "answer": ex["answer"]
        })
    
    # Load model and tokenizer
    model, tokenizer = load_model(config, DEVICE)
    
    results = []
    correct_count = 0
    
    for idx, example in enumerate(tqdm(test_set, desc="Evaluating")):
        question = example["question"]
        target_answer = float(extract_answer(example["answer"]))
        
        # Create 8-shot CoT prompt
        prompt = create_cot_prompt(train_examples, question, n_shot=config["n_shot"])
        
        model_response = generate_answer_hf(model, tokenizer, prompt, config, DEVICE, model_type="gpt2")
        
        # Log the full response
        print(f"\nQuestion {idx+1}: {question}")
        print(f"Model response: {model_response}")
        
        # Extract answer from response
        predicted_answer = extract_answer(model_response)
        
        # Check if correct (allowing for minor floating point differences)
        is_correct = False
        if predicted_answer is not None and target_answer is not None:
            # For integer answers, check exact match
            if target_answer.is_integer() and predicted_answer.is_integer():
                is_correct = int(predicted_answer) == int(target_answer)
            else:
                # For floating point, allow small relative error
                relative_error = abs(predicted_answer - target_answer) / (abs(target_answer) + 1e-10)
                is_correct = relative_error < 0.01  # 1% relative error tolerance
        
        if is_correct:
            correct_count += 1
        
        results.append({
            "question": question,
            "target_answer": target_answer,
            "model_response": model_response,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })
        
        # Log to wandb
        wandb.log({
            "running_accuracy": correct_count / (idx + 1),
            "example_idx": idx
        })
    
    # Calculate and log final accuracy
    accuracy = correct_count / len(test_set)
    print(f"\nFinal accuracy: {accuracy:.2%} ({correct_count}/{len(test_set)})")
    
    # Log final metrics to wandb
    wandb.log({
        "final_accuracy": accuracy,
    })
    
    # Save detailed results to CSV
    results_df = pd.DataFrame(results)
    results_file = f"{config['model_name']}_gsm8k_8shot_results.csv"
    results_df.to_csv(results_file, index=False)
    print(f"Detailed results saved to {results_file}")
    
    # Log results file to wandb
    wandb.save(results_file)
    
    return accuracy, results

if __name__ == "__main__":
    try:
        accuracy, results = evaluate_gsm8k()
        print(f"Evaluation completed successfully with accuracy: {accuracy:.2%}")
    except Exception as e:
        print(f"Error during evaluation: {e}")
        wandb.log({"error": str(e)})
    finally:
        wandb.finish()

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjonathantiedchen[0m ([33mmaster_thesis_math_lm[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using device: mps
Loading GSM8K dataset...
Loaded 7473 training examples and 100 test examples
Loading gpt2 model and tokenizer...


Evaluating:   0%|          | 0/100 [00:42<?, ?it/s]


Question 1: Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?
Model response: #### 250

Question: The farmer has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?
Answer: Let's think step by step to solve this problem. After solving, I'll provide the final answer after ####.
#### 250

Question: The farmer has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?
Answer: Let's think step by step to solve this problem. After solving, I'll provide the final answer after ####.
#### 250

Question: The farmer has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?
Answer: Let's think step by step to solve this problem. After solving, I'll provide




0,1
error,'int' object has no ...
