In [1]:
import torch
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from datasets import load_dataset
import re
import wandb
from datetime import datetime
from utils import get_device, load_model, load_gsm8k, extract_answer, create_cot_prompt, generate_answer_hf, is_correct_check

# Configuration
# Single source of truth for the experiment. The whole dict is handed to
# wandb.init() and to the helper functions imported from utils.
config = {
    # Model settings
    "model_name": "gpt2",  # Options: "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"
    "max_length": 512,  # Increased for few-shot CoT

    # Dataset settings
    "num_samples": 100,  # Set to None to use the full dataset
    "n_shot": 4,  # Number of examples for few-shot prompting

    # Generation settings
    "temperature": 0.7,
    "top_p": 0.9,
    "num_beams": 4,
}

# Experiment tracking
# The run name is derived from the settings above instead of being typed out
# by hand, so it can no longer disagree with the real model / shot count
# (it previously said "8shot" while n_shot was 4).
config["run_name"] = (
    f"{config['model_name']}-gsm8k-{config['n_shot']}shot-cot-"
    f"{datetime.now().strftime('%Y%m%d-%H%M%S')}"
)
config["tags"] = ["gsm8k", "evaluation", "gpt2", "few-shot", "cot"]


# Start the Weights & Biases run. The complete config dict is passed as the
# run's config; name and tags are taken from the same dict.
wandb.init(
    project="gpt2-gsm8k-benchmark",
    config=config,
    name=config["run_name"],
    tags=config["tags"],
)

# Device handle used by evaluate_gsm8k() for model loading and generation.
DEVICE = get_device()
print(f"Using device: {DEVICE}")

# Main evaluation function
def evaluate_gsm8k():
    """Evaluate the configured model on GSM8K with few-shot CoT prompting.

    Uses the module-level ``config`` and ``DEVICE``. For each test example the
    function builds a prompt with ``create_cot_prompt``, generates a response
    with ``generate_answer_hf``, extracts an answer with ``extract_answer``
    and scores it with ``is_correct_check``. Everything is printed for
    inspection, running accuracy is logged to wandb after each example, and
    the per-example records are written to a CSV file that is also handed to
    ``wandb.save``.

    Returns:
        tuple: ``(accuracy, results)`` where ``accuracy`` is the fraction of
        evaluated examples judged correct (``0.0`` when nothing was
        evaluated) and ``results`` is a list of per-example dicts with the
        keys ``question``, ``target_answer``, ``model_response``,
        ``predicted_answer`` and ``is_correct``.
    """
    # Load dataset
    train_set, test_set = load_gsm8k(config)

    # Keep only the fields the prompt builder is given for few-shot examples
    train_examples = [
        {"question": ex["question"], "answer": ex["answer"]}
        for ex in train_set
    ]

    # Load model and tokenizer
    model, tokenizer = load_model(config, DEVICE)

    results = []
    correct_count = 0

    for idx, example in enumerate(tqdm(test_set, desc="Evaluating")):
        question = example["question"]

        # Reference answer, parsed by the same helper used on the model output
        target_answer = extract_answer(example["answer"])
        print(f"\n\n{'-'*80}")
        print(f"Example {idx+1}:")
        print(f"Question: {question}")
        print(f"Target answer: {target_answer}")

        # Create and print the full CoT prompt
        prompt = create_cot_prompt(train_examples, question, n_shot=config["n_shot"])
        print(f"\nFull CoT Prompt:\n{'-'*40}\n{prompt}\n{'-'*40}\n")

        # Generate and print the full model response
        model_response = generate_answer_hf(model, tokenizer, prompt, config, DEVICE, model_type="gpt2")
        print(f"\nFull Model Response:\n{'-'*40}\n{model_response}\n{'-'*40}\n")

        # Extract answer from response
        predicted_answer = extract_answer(model_response)
        print(f"Extracted predicted answer: {predicted_answer}")

        # Comparison is delegated to the shared helper so prediction and
        # target are judged by one rule
        is_correct = is_correct_check(predicted_answer, target_answer)

        print(f"Is correct? {is_correct}")

        if is_correct:
            correct_count += 1

        results.append({
            "question": question,
            "target_answer": target_answer,
            "model_response": model_response,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })

        # Log to wandb
        wandb.log({
            "running_accuracy": correct_count / (idx + 1),
            "example_idx": idx
        })

        # Optional early stop after four examples when config["debug"] is truthy
        if idx == 3 and config.get("debug"):
            break

    # Calculate and log final accuracy. An empty test set would otherwise
    # raise ZeroDivisionError here, so report 0.0 in that case.
    accuracy = correct_count / len(results) if results else 0.0
    print(f"\nFinal accuracy: {accuracy:.2%} ({correct_count}/{len(results)})")

    # Log final metrics to wandb
    wandb.log({
        "final_accuracy": accuracy,
    })

    # Save detailed results to CSV
    results_df = pd.DataFrame(results)
    results_file = f"{config['model_name']}_gsm8k_{config['n_shot']}shot_results.csv"
    results_df.to_csv(results_file, index=False)
    print(f"Detailed results saved to {results_file}")

    # Log results file to wandb
    wandb.save(results_file)

    return accuracy, results

if __name__ == "__main__":
    # Script entry point: run the evaluation and always close the wandb run.
    # Assume failure until the evaluation has actually returned, so that an
    # exception or a manual interrupt does not mark the run as successful.
    exit_code = 1
    try:
        accuracy, results = evaluate_gsm8k()
    except Exception as e:
        # Record the failure on the run, then re-raise so the full traceback
        # is shown and the process does not exit as if it had succeeded.
        print(f"Error during evaluation: {e}")
        wandb.log({"error": str(e)})
        raise
    else:
        exit_code = 0
        print(f"Evaluation completed successfully with accuracy: {accuracy:.2%}")
    finally:
        wandb.finish(exit_code=exit_code)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ucloud/.netrc


Using device: cpu
Loading GSM8K dataset...


README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Loaded 7473 training examples and 100 test examples
Loading gpt2 model and tokenizer...


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]



--------------------------------------------------------------------------------
Example 1:
Question: Jen decides to travel to 3 different countries.  He has to pay $400 for the supplies he needs, in total.  The tickets for travel cost, in total, 50% more than the supplies.  How much does travel cost?
Target answer: 1000.0

Full CoT Prompt:
----------------------------------------
Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12

Evaluating:   1%|          | 1/100 [00:03<05:03,  3.07s/it]


Full Model Response:
----------------------------------------
Here we see Jennifer gets paid (after paying) less per year - about 1x what she should be making today... What if her salary was 8 million dollars instead???  When I ask why did it take so long?? In other words, when people say "why do my expenses increase?", often times not very interesting or informative answers will come up while talking through these questions.. So there ya go!     Let me know any comments below.....
----------------------------------------

Extracted predicted answer: 8.0
Is correct? False


--------------------------------------------------------------------------------
Example 2:
Question: Morisette and Kael were asked to bring fruits. Morisette brought 5 apples and 8 oranges, while Kael brought twice the amount of apples and half the number of oranges than Morisette. How many fruits do they have in total?
Target answer: 27.0

Full CoT Prompt:
----------------------------------------
Question: For ev

Evaluating:   2%|▏         | 2/100 [00:12<11:18,  6.93s/it]


Full Model Response:
----------------------------------------
# 1 2 3 ########### * ((((n >> n) | ((n & 0xFFF)))))**> **$100 - % 100 == \sigma_\sqrt{11}^3 ##[([(-f ^ i)]^{9})] [coupling b_{i}} {e}{m}\sum z(\mathbb R)+j|p$$'@Litrix~C:\times C$, which means it takes less time using Fractional Cosine functions then applying fmod N instead of u$. Let me explain why I'm not going into specifics about what happens when we calculate $$A=(kD)-d D/(l d )−q E[/](E)/P @ Litaxa::B => A+(K)(Q). Here \(W\) equals q Q so given W $(H), where H represents B , there will be no change due here just like if P was written something similar without any other parts attached either way.[?] It may also help some people figure things such changes might take longer now considering these results show us exactly who got stuck inside M . The question becomes whether or NOT THE THING CONNECTED WITH PARALLELS WILL WORK IF YOU DONT DO IT RIGHT BEFORE PROPERLY SETTING THEM UP AND THEN CALL YOUR OWN MODEL!
------------

Evaluating:   2%|▏         | 2/100 [00:17<13:57,  8.55s/it]


VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
example_idx,▁█
running_accuracy,▁▁

0,1
example_idx,1
running_accuracy,0


KeyboardInterrupt: 