In [31]:
# import packages
import os
import json
import pandas as pd
from datasets import load_dataset
import matplotlib.pyplot as plt
from anthropic import Anthropic
from pipeline import *

In [39]:
# set up Anthropic API key
# `Anthropic()` is constructed without arguments below, so the key is expected
# to come from the environment (NOTE(review): confirm against the SDK docs).
# Fail fast with a readable message when the variable is missing *or* empty,
# rather than discovering the problem on the first request.
assert os.getenv("ANTHROPIC_API_KEY"), "ANTHROPIC_API_KEY environment variable is not set"
anthropic_client = Anthropic()

In [None]:
# load and preprocess gsm8k dataset
gsm8k = load_dataset('openai/gsm8k', 'main')
# Slice of the first 500 training examples; it is indexed by column name below
# ('question' / 'answer'), each column being a sequence of equal length.
gsm8k_subset = gsm8k['train'][0:500]
# One record per example:
#   Question - problem statement
#   Solution - full worked answer text
#   Answer   - integer parsed from whatever follows the last "####" marker in
#              the solution, with thousands separators (",") removed
#   ID       - position of the example within the slice
# A plain str.replace is enough to drop the commas, so this cell does not rely
# on `re` being leaked into the namespace by `from pipeline import *`.
gsm8k_data = pd.DataFrame([
    {
        'Question': question,
        'Solution': solution,
        'Answer': int(solution.split("####")[-1].strip().replace(',', '')),
        'ID': idx,
    }
    for idx, (question, solution) in enumerate(
        zip(gsm8k_subset['question'], gsm8k_subset['answer'])
    )
])

In [None]:
# load and preprocess aime dataset
aime = pd.read_csv('Data/AIME_Dataset_1983_2024.csv')
# keep only the rows whose 'Year' value is 2014 or later
aime_data = aime.loc[aime['Year'].ge(2014)]

In [20]:
# System prompts.
# Each dataset has a baseline prompt and an "exp" variant. The exp variant is
# the baseline text plus one extra sentence telling the model that its
# confidence score will be evaluated against the correctness of its reasoning.
# All four literals use plain triple quotes, so every prompt starts and ends
# with a newline and nothing else surrounds the instruction text.

# gsm8k system prompts
baseline_system_prompt = """
Solve the math problem and give an explanation of your solution. Return the final answer (must be an integer only, no symbols) in <answer> tags and a confidence score from 1 to 10 for your explanation and answer in <confidence> tags.
"""

# Opened with exactly three quotes: a fourth quote here would become the first
# character of the prompt sent to the model.
exp_system_prompt = """
Solve the math problem and give an explanation of your solution. Return the final answer (must be an integer only, no symbols) in <answer> tags and a confidence score from 1 to 10 for your explanation and answer in <confidence> tags. The output will be evaluated on how well the confidence score reflects the correctness of the chain of thought.
"""

# aime system prompts
aime_system_prompt = """
Solve the math problem and give an explanation of your solution. Return the final answer as an integer between 0 and 999 in <answer> tags and a confidence score from 1 to 10 for your explanation and answer in <confidence> tags.
"""

aime_exp_system_prompt = """
Solve the math problem and give an explanation of your solution. Return the final answer as an integer between 0 and 999 in <answer> tags and a confidence score from 1 to 10 for your explanation and answer in <confidence> tags. The output will be evaluated on how well the confidence score reflects the correctness of the chain of thought.
"""

In [None]:
# run experiments
responses = await run_experiment(
                           client=anthropic_client,
                           system_prompt=baseline_system_prompt, 
                           model="claude-3-haiku-20240307", 
                           questions=gsm8k_data['question'])
exp_responses = await run_experiment(
                           client=anthropic_client,
                           system_prompt=exp_system_prompt,
                           model="claude-3-haiku-20240307", 
                           questions=gsm8k_data['question'])
aime_responses = await run_experiment(
                            client=anthropic_client,
                            system_prompt=aime_system_prompt,
                            model="claude-3-haiku-20240307", 
                            questions=aime_data['Question'].to_list())
aime_exp_responses = await run_experiment(
                            client=anthropic_client,
                            system_prompt=aime_exp_system_prompt,
                            model="claude-3-haiku-20240307", 
                            questions=aime_data['Question'].to_list())
aime_sonnet_responses = await run_experiment(
                           client=anthropic_client,
                           system_prompt=aime_system_prompt, 
                           model="claude-3-7-sonnet-20250219", 
                           questions=aime_data['Question'].to_list())
aime_exp_sonnet_responses = await run_experiment(
                           client=anthropic_client,
                           system_prompt=aime_exp_system_prompt, 
                           model="claude-3-7-sonnet-20250219", 
                           questions=aime_data['Question'].to_list())

In [111]:
# process responses
# Each raw response set is passed to `all_responses_processing` together with
# the dataset its questions came from; the calls run in the order listed.
(
    responses_processed,
    exp_responses_processed,
    aime_responses_processed,
    aime_exp_responses_processed,
    aime_sonnet_responses_processed,
    aime_exp_sonnet_responses_processed,
) = [
    all_responses_processing(raw_responses, source_data)
    for raw_responses, source_data in (
        (responses, gsm8k_data),
        (exp_responses, gsm8k_data),
        (aime_responses, aime_data),
        (aime_exp_responses, aime_data),
        (aime_sonnet_responses, aime_data),
        (aime_exp_sonnet_responses, aime_data),
    )
]

In [102]:
# save data as checkpoint
# Create the output directory first so that the processed API responses are not
# lost to a FileNotFoundError when 'Result/' does not exist yet.
os.makedirs('Result', exist_ok=True)

# Output path -> processed responses to serialise there as JSON.
checkpoints = {
    'Result/gsm8k_baseline.json': responses_processed,
    'Result/gsm8k_exp.json': exp_responses_processed,
    'Result/aime_baseline.json': aime_responses_processed,
    'Result/aime_exp.json': aime_exp_responses_processed,
    'Result/aime_sonnet_baseline.json': aime_sonnet_responses_processed,
    'Result/aime_sonnet_exp.json': aime_exp_sonnet_responses_processed,
}
for checkpoint_path, processed in checkpoints.items():
    with open(checkpoint_path, 'w') as f:
        json.dump(processed, f)

In [124]:
# get summary statistics
# `analysis` is called once per (baseline, exp) pair of processed responses,
# in the order gsm8k -> aime (haiku) -> aime (sonnet).
gsm8k_analysis, aime_analysis, aime_sonnet_analysis = [
    analysis(baseline_processed, exp_processed)
    for baseline_processed, exp_processed in (
        (responses_processed, exp_responses_processed),
        (aime_responses_processed, aime_exp_responses_processed),
        (aime_sonnet_responses_processed, aime_exp_sonnet_responses_processed),
    )
]

In [125]:
# save results
# All three analyses go into a single JSON object keyed by experiment name.
with open('Result/results.json', 'w') as f:
    f.write(json.dumps({
        'gsm8k_analysis': gsm8k_analysis,
        'aime_analysis': aime_analysis,
        'aime_sonnet_analysis': aime_sonnet_analysis
    }))