In [2]:
import json
from prettytable import PrettyTable
import json
import os
import random
import pandas as pd

# Read the jsonl file and convert it to a JSON list
def jsonl_to_json_list(jsonl_file_path):
    """Load a JSON Lines file into a list of parsed objects.

    Args:
        jsonl_file_path: Path of a UTF-8 encoded file holding one JSON
            document per line.

    Returns:
        A list with the parsed value of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    json_list = []
    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            # Blank lines (typically a trailing newline at the end of the
            # file) carry no record; json.loads('') would raise on them.
            if not line:
                continue
            json_list.append(json.loads(line))

    return json_list

# Save the JSON list to a file
def save_as_json(json_list, output_file_path):
    """Write ``json_list`` to ``output_file_path`` as a single JSON document.

    The target file is opened for writing as UTF-8 (replacing any existing
    content) and the document is serialized with a 4-space indent.
    """
    with open(output_file_path, mode='w', encoding='utf-8') as sink:
        json.dump(json_list, fp=sink, indent=4)

def save_as_jsonl(json_list, output_file_path):
    """Write each element of ``json_list`` as one JSON line in ``output_file_path``.

    The target file is opened for writing as UTF-8 (replacing any existing
    content); every record is serialized compactly and followed by a newline.
    """
    with open(output_file_path, mode='w', encoding='utf-8') as sink:
        for record in json_list:
            json.dump(record, fp=sink)
            # Terminate the record so the next one starts on its own line.
            sink.write("\n")

In [14]:
def load_json(file_path):
    """Parse the UTF-8 JSON document stored at ``file_path`` and return its value."""
    with open(file_path, mode='r', encoding='utf-8') as source:
        return json.load(source)

def deduplicate_data(data):
    """Return the items of ``data`` with repeated ``realidx`` values removed.

    Only the first item seen for each ``realidx`` is kept, and the kept items
    stay in their original relative order.
    """
    first_by_idx = {}
    for record in data:
        # setdefault stores the record only when this realidx is new, so
        # later duplicates never replace the first occurrence.
        first_by_idx.setdefault(record['realidx'], record)
    return list(first_by_idx.values())

def calculate_accuracy(data):
    """Return the fraction of items whose prediction equals the gold answer.

    Args:
        data: Sized collection of result dicts. ``answer_idx`` and
            ``predicted_answer`` are compared for each item; ``realidx`` is
            read only to report items that have no prediction.

    Returns:
        ``correct / len(data)``, or 0 when ``data`` is empty. An item without
        a ``predicted_answer`` key is printed (its ``realidx``) and counted
        as incorrect.
    """
    total_predictions = len(data)
    if total_predictions == 0:
        return 0

    correct_predictions = 0
    for item in data:
        if 'predicted_answer' not in item:
            # Report the sample and score it as wrong; falling through to the
            # comparison below would raise KeyError on the missing key.
            print(item.get('realidx'))
            continue
        if item['answer_idx'] == item['predicted_answer']:
            correct_predictions += 1
    return correct_predictions / total_predictions

def calculate_cost_from_token_usage(data, model):
    """Return the average per-item cost in USD derived from token usage.

    Args:
        data: Sized collection of result dicts; for a known model each item
            must provide ``token_usage['prompt_tokens']`` and
            ``token_usage['completion_tokens']``.
        model: Model name used to select the price pair. Known names are
            ``'gpt-4o-mini'``, ``'gpt-4o'`` and ``'o3-mini'``.

    Returns:
        Total cost divided by ``len(data)``. Returns 0.0 when ``data`` is
        empty, and also when ``model`` has no entry in the price table.
    """
    # USD per one million (prompt, completion) tokens for each known model.
    pricing = {
        'gpt-4o-mini': (0.15, 0.6),
        'gpt-4o': (2.5, 10),
        'o3-mini': (1.1, 4.4),
    }

    # Guard the division below; an empty result set costs nothing.
    if not data:
        return 0.0

    # NOTE: an unrecognized model contributes no cost at all rather than
    # raising, which keeps the historical behavior for existing callers.
    if model not in pricing:
        return 0.0

    prompt_rate, completion_rate = pricing[model]
    total_cost = 0
    for item in data:
        usage = item['token_usage']
        total_cost += (
            usage['prompt_tokens'] * prompt_rate / 1000000
            + usage['completion_tokens'] * completion_rate / 1000000
        )
    return total_cost / len(data)

def calculate_time_from_data(data):
    """Return the mean of ``time_elapsed`` over the items of ``data``.

    Args:
        data: Sized collection of result dicts, each with a numeric
            ``time_elapsed`` entry.

    Returns:
        The arithmetic mean of the ``time_elapsed`` values, or 0.0 when
        ``data`` is empty.
    """
    # Guard the division below; there is no elapsed time to average.
    if not data:
        return 0.0
    return sum(item['time_elapsed'] for item in data) / len(data)

# Benchmarks included in the report, each mapped to the splits evaluated for it.
tasks = {
    'medqa': ['test_hard'],
    'pubmedqa': ['test_hard'],
    'medmcqa': ['test_hard'],
    'medbullets': ['test_hard'],
    'mmlu': ['test_hard'],
    'mmlu-pro': ['test_hard'],
}
models = ['o3-mini']
methods = ['zero_shot']

# One table row is produced per (model, task, subtask, method) combination.
table = PrettyTable()
table.field_names = [
    "Model",
    "Task",
    "Subtask",
    "Method",
    "Accuracy",
    "Cost per sample(USD)",
    "Time per sample(s)",
    "Total Number",
]

# Running sum of the cost of every sample across all result files.
total_cost = 0

# Flatten the nested configuration into the ordered list of runs to report,
# in the same order the nested loops would visit them.
run_grid = [
    (model, task, subtask, method)
    for model in models
    for task, subtasks in tasks.items()
    for subtask in subtasks
    for method in methods
]

for model, task, subtask, method in run_grid:
    file_path = f'./output/{task}/{model}-{task}-{subtask}-{method}.json'
    data = load_json(file_path)
    deduplicated_data = deduplicate_data(data)

    accuracy = calculate_accuracy(deduplicated_data)
    total = len(deduplicated_data)
    cost_per_sample = calculate_cost_from_token_usage(deduplicated_data, model)
    # The helper returns an average, so scale it back up by the sample count.
    total_cost += cost_per_sample * total
    time_per_sample = calculate_time_from_data(deduplicated_data)

    row = [
        model,
        task,
        subtask,
        method,
        f"{accuracy * 100:.1f}%",
        cost_per_sample,
        time_per_sample,
        total,
    ]
    table.add_row(row)

print(table)
print(f"\nTotal cost of experiment: ${total_cost:.2f}")

+---------+------------+-----------+-----------+----------+-----------------------+--------------------+--------------+
|  Model  |    Task    |  Subtask  |   Method  | Accuracy |  Cost per sample(USD) | Time per sample(s) | Total Number |
+---------+------------+-----------+-----------+----------+-----------------------+--------------------+--------------+
| o3-mini |   medqa    | test_hard | zero_shot |  80.5%   | 0.0031890165562913912 | 14.11931477635112  |     302      |
| o3-mini |  pubmedqa  | test_hard | zero_shot |  28.6%   | 0.0017249016806722695 |  7.78158868461096  |     119      |
| o3-mini |  medmcqa   | test_hard | zero_shot |  53.9%   | 0.0030588469879518046 | 16.284704251983992 |     913      |
| o3-mini | medbullets | test_hard | zero_shot |  76.2%   |  0.002452895238095238 | 10.666467706362406 |      84      |
| o3-mini |    mmlu    | test_hard | zero_shot |  73.4%   |  0.002941069364161852 | 16.95308729403281  |     173      |
| o3-mini |  mmlu-pro  | test_hard | zer