In [None]:
import sys
import os
from pathlib import Path
import json 
from gpqa_utils import *
import pandas as pd
import numpy as np

sys.path.append('../')  

from utils import * 

from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the GPQA diamond question set once. The scoring cells below pair these
# examples with result lines purely by position (zip(data, examples)), so the
# order returned here must match the order the result files were written in.
# NOTE(review): seed=0 presumably fixes the answer-option ordering used when the
# results were generated — confirm in gpqa_utils.load_examples.
examples = load_examples("../data/gpqa/gpqa_diamond.csv", seed=0)

In [None]:
# Prompting setups to evaluate; also reused by the llama cell further down.
types = ['zero', 'few']


def score_question(preds, correct_index, letter_to_index):
    """Score the sampled answers of a single question.

    Args:
        preds: Parsed answer letters, one per sampled response. ``None``
            marks a response from which no answer could be parsed.
        correct_index: Index of the correct option for this question.
        letter_to_index: Mapping from answer letter to option index.

    Returns:
        A tuple ``(hits, all_correct, any_correct, sc_correct)`` where
        ``hits`` is the per-sample correctness list, ``all_correct`` is True
        only when there is at least one sample and every sample is right,
        ``any_correct`` is True when at least one sample is right, and
        ``sc_correct`` is True when the majority vote over the parseable
        answers is right (ties go to the letter seen first).
    """
    hits = [
        pred is not None and letter_to_index[pred] == correct_index
        for pred in preds
    ]

    # Unparseable responses do not take part in the majority vote.
    votes = [pred for pred in preds if pred is not None]
    if votes:
        majority_letter, _ = Counter(votes).most_common(1)[0]
        sc_correct = letter_to_index[majority_letter] == correct_index
    else:
        sc_correct = False

    # bool(hits) guards against all([]) == True for a question with no samples.
    return hits, bool(hits) and all(hits), any(hits), sc_correct


for model in ['gpt-4o-mini', 'gpt-4o']:
    for type_ in types:
        file_path = f"../result/gpqa/{model}/gpqa_{type_}.jsonl"
        if not os.path.exists(file_path):
            print(f"File not found for subject: {type_}")
            continue

        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            data = [json.loads(line) for line in f if line.strip()]

        idx_acc = []            # correct count per repetition, grown on demand
        worse_case = 0          # questions where every sample is correct
        sc_correct = 0          # questions where the majority vote is correct
        any_correct_count = 0   # questions where at least one sample is correct
        total_questions = 0     # questions actually scored

        # Results and examples are matched by position; zip stops at the
        # shorter of the two, so count what was really scored.
        for entry, example in zip(data, examples):
            preds = [parse_sampled_answer(resp) for resp in entry['model_outputs']]
            hits, all_correct, any_correct, sc_is_correct = score_question(
                preds, example.correct_index, LETTER_TO_INDEX
            )

            # Do not assume a fixed number of samples per question.
            if len(hits) > len(idx_acc):
                idx_acc.extend([0] * (len(hits) - len(idx_acc)))
            for idx, hit in enumerate(hits):
                idx_acc[idx] += int(hit)

            worse_case += int(all_correct)
            any_correct_count += int(any_correct)
            sc_correct += int(sc_is_correct)
            total_questions += 1

        if total_questions != len(data):
            print(f"Warning: scored {total_questions} of {len(data)} entries in {file_path}")
        if total_questions == 0:
            # Nothing to average over; avoid dividing by zero below.
            print(f"No entries to score in {file_path}")
            continue

        print(f"{model}")
        print(f"Results for {type_}:")

        print("  Worse-case (all correct for a question): "
              f"{worse_case / total_questions:.3f}")

        total = 0
        for idx, acc in enumerate(idx_acc):
            acc_ratio = acc / total_questions
            print(f"  Repetition {idx}: {acc_ratio:.3f}")
            total += acc_ratio
        if idx_acc:
            print(f"  Average (across {len(idx_acc)} samples): {total / len(idx_acc):.3f}")

        print(f"  Self-consistency Accuracy: {sc_correct / total_questions:.3f}")

        any_correct_acc = any_correct_count / total_questions
        print(f"  Any-correct Accuracy: {any_correct_acc:.3f}\n")

In [None]:

print("llama")
for type_ in types:
    file_path = f"../result/gpqa/llama/gpqa_{type_}.jsonl"
    if not os.path.exists(file_path):
        print(f"File not found for subject: {type_}")
        continue

    with open(file_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        data = [json.loads(line) for line in f if line.strip()]

    idx_acc = []            # correct count per repetition, grown on demand
    worse_case = 0          # questions where every sample is correct
    sc_correct = 0          # questions where the majority vote is correct
    any_correct_count = 0   # questions where at least one sample is correct
    total_questions = 0     # questions actually scored

    # Results and examples are matched by position; zip stops at the shorter
    # of the two, so count what was really scored.
    for entry, example in zip(data, examples):
        # The sampled responses for a question are read from entry['resps'][0].
        preds = [parse_sampled_answer(resp) for resp in entry['resps'][0]]
        hits = [
            pred is not None and LETTER_TO_INDEX[pred] == example.correct_index
            for pred in preds
        ]

        # Do not assume a fixed number of samples per question.
        if len(hits) > len(idx_acc):
            idx_acc.extend([0] * (len(hits) - len(idx_acc)))
        for idx, hit in enumerate(hits):
            idx_acc[idx] += int(hit)

        # bool(hits) guards against all([]) == True for a question with no samples.
        worse_case += int(bool(hits) and all(hits))
        any_correct_count += int(any(hits))

        # Majority vote over parseable answers only; ties go to the first seen.
        valid_preds = [pred for pred in preds if pred is not None]
        if valid_preds:
            majority_letter, _ = Counter(valid_preds).most_common(1)[0]
            sc_correct += int(LETTER_TO_INDEX[majority_letter] == example.correct_index)

        total_questions += 1

    if total_questions != len(data):
        print(f"Warning: scored {total_questions} of {len(data)} entries in {file_path}")
    if total_questions == 0:
        # Nothing to average over; avoid dividing by zero below.
        print(f"No entries to score in {file_path}")
        continue

    print(f"Results for {type_}:")
    print("  Worse-case (all correct in a single question): "
          f"{worse_case / total_questions:.3f}")

    total = 0
    for idx, acc in enumerate(idx_acc):
        acc_ratio = acc / total_questions
        print(f"  Repetition {idx}: {acc_ratio:.3f}")
        total += acc_ratio
    if idx_acc:
        print(f"  Average (across {len(idx_acc)} samples): {total / len(idx_acc):.3f}")

    print(f"  Self-consistency Accuracy: {sc_correct / total_questions:.3f}")

    any_correct_acc = any_correct_count / total_questions
    print(f"  Any-correct Accuracy: {any_correct_acc:.3f}\n")


In [None]:
# Root directory under which each subject keeps its all_likelihoods.json.
output_dir = "../likelihood/gpqa/"

# Subjects to post-process, written as "<model>/<prompt type>".
subjects = [f"{model_name}/few" for model_name in ("gpt-4o", "gpt-4o-mini", "llama")]


def add_pred(prob_type_filter, output_dir):
    """Annotate a subject's likelihood records with prediction correctness.

    Reads ``<output_dir>/<prob_type_filter>/all_likelihoods.json``, adds the
    keys ``'pred'``, ``'gt'`` and ``'is_correct'`` to every record, and writes
    the result back to the same file.

    The JSON is expected to be a list of lists of dict records, each record
    carrying an ``'id'`` (index into the GPQA examples) and a
    ``'model_output'`` string.

    Args:
        prob_type_filter: Sub-directory of ``output_dir`` identifying the
            subject, e.g. ``"gpt-4o/few"``.
        output_dir: Root directory of the likelihood dumps.

    Returns:
        None. Prints an error and returns early when the file is missing.
    """
    from itertools import zip_longest

    subject_dir = os.path.join(output_dir, prob_type_filter)
    likelihoods_file = os.path.join(subject_dir, "all_likelihoods.json")
    if not os.path.exists(likelihoods_file):
        print(f"Error: {likelihoods_file} not found.")
        return
    with open(likelihoods_file, "r", encoding="utf-8") as f:
        likelihoods = json.load(f)

    examples = load_examples("../data/gpqa/gpqa_diamond.csv", seed=0)

    # Regroup the i-th record of every inner list. zip_longest (rather than
    # zip) keeps records of longer inner lists instead of silently dropping
    # them when the lists are ragged; the padding it inserts is None.
    problem_groups = list(zip_longest(*likelihoods))

    for problem_likelihoods in tqdm(problem_groups, desc="Processing problems"):
        for cl in problem_likelihoods:
            if cl is None:
                continue  # padding for a shorter inner list

            # Records are mutated in place, so `likelihoods` is updated too.
            pred = parse_sampled_answer(cl['model_output'])
            gt = examples[cl['id']].correct_index
            cl['pred'] = pred
            cl['gt'] = gt
            # An unparseable answer counts as wrong.
            cl['is_correct'] = pred is not None and LETTER_TO_INDEX[pred] == gt

    # Write to a sibling temp file and swap it in, so a failure while dumping
    # cannot leave the original likelihood file truncated.
    tmp_file = likelihoods_file + ".tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(likelihoods, f, indent=4)
        os.replace(tmp_file, likelihoods_file)
    finally:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    print(f"Updated file saved at: {likelihoods_file}")

In [6]:
# Annotate the likelihood file of every configured subject; add_pred rewrites
# each all_likelihoods.json in place.
# NOTE(review): the output saved with this cell shows paths under
# ../likelihood_1B/gpqa/, which differs from output_dir as defined in this
# notebook — the stored output looks stale; re-run the cell to refresh it.
for subject in subjects:
    add_pred(subject, output_dir)

Processing problems: 100%|██████████| 198/198 [00:00<00:00, 35733.07it/s]


Updated file saved at: ../likelihood_1B/gpqa//gpt-4o/few/all_likelihoods.json


Processing problems: 100%|██████████| 198/198 [00:00<00:00, 38850.68it/s]


Updated file saved at: ../likelihood_1B/gpqa//gpt-4o-mini/few/all_likelihoods.json


Processing problems: 100%|██████████| 198/198 [00:00<00:00, 10676.78it/s]


Updated file saved at: ../likelihood_1B/gpqa//llama/few/all_likelihoods.json
