In [None]:
import sys
import os
from typing import Iterable, Union, Any
from pathlib import Path
import json 
import pandas as pd
import numpy as np
from datasets import load_dataset
from collections import Counter, defaultdict

from mmlu_utils import * 

import sys
sys.path.append('../')  

from utils import * 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_mmlu_pro():
    """Fetch the MMLU-Pro dataset and return its two splits grouped by category.

    Returns:
        A ``(test, validation)`` tuple. Each element is the result of running
        ``preprocess`` on the corresponding split of "TIGER-Lab/MMLU-Pro".
    """
    splits = load_dataset("TIGER-Lab/MMLU-Pro")
    # Pull both splits out before any preprocessing starts.
    test_split, val_split = (splits[name] for name in ("test", "validation"))
    grouped_test = preprocess(test_split)
    grouped_val = preprocess(val_split)
    return grouped_test, grouped_val


def preprocess(test_df):
    """Drop "N/A" answer options and bucket the records by their category.

    Args:
        test_df: Iterable of dict-like records, each with an ``"options"``
            list and a ``"category"`` key.

    Returns:
        A dict mapping each category to the list of its records, in the order
        the records (and categories) were first encountered. Each record's
        ``"options"`` entry is replaced in place by the filtered list.
    """
    # Pass 1: strip the placeholder options from every record.
    cleaned_records = []
    for record in test_df:
        record["options"] = [choice for choice in record["options"] if choice != "N/A"]
        cleaned_records.append(record)

    # Pass 2: group the cleaned records under their category.
    by_category = {}
    for record in cleaned_records:
        by_category.setdefault(record["category"], []).append(record)
    return by_category

def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    """Lazily yield one parsed JSON value per line of a JSON Lines file.

    Args:
        file: Path of the file to read; it is opened as UTF-8 text.

    Yields:
        The value decoded from each line, in file order.

    Raises:
        ValueError: If a line is not valid JSON. The message names the file,
            the 1-based line number and the offending text; the original
            ``json.JSONDecodeError`` is attached as ``__cause__``.
    """
    with open(file, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            # Guard only the parse. Keeping the ``yield`` outside the ``try``
            # means closing the generator early (GeneratorExit) or an
            # exception thrown in by the consumer is never mistaken for a
            # malformed line.
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Raise instead of calling exit(): the caller (or the notebook)
                # sees a normal traceback and the kernel stays alive.
                raise ValueError(
                    f"Error in loading: {file} line {line_number}: {line!r}"
                ) from err
            yield record

# Load both grouped splits once; later cells iterate over `subjects`.
test_df, dev_df = load_mmlu_pro()

# The categories present in the test split, in first-seen order.
subjects = [category for category in test_df]

print("assigned subjects", subjects)

assigned subjects ['business', 'law', 'psychology', 'biology', 'chemistry', 'history', 'other', 'health', 'economics', 'math', 'physics', 'computer science', 'philosophy', 'engineering']


In [None]:
# Read the combined llama few-shot results (one JSON record per line).
path = "../result/mmlu_pro/llama/mmlu_pro_few.jsonl"
llama_data = []
with open(path, 'r', encoding='utf-8') as f:
    for raw_line in f:
        llama_data.append(json.loads(raw_line))

# Bucket the records by the category stored under entry['doc'].
grouped_entries = defaultdict(list)
for entry in llama_data:
    grouped_entries[entry['doc']['category']].append(entry)

# Make sure the destination directory exists before writing anything.
output_dir = "../result/mmlu_pro"
os.makedirs(output_dir, exist_ok=True)

# Write one JSONL file per subject and report how many records it received.
for subject, entries in grouped_entries.items():
    save_path = os.path.join(output_dir, f"{subject}_result.jsonl")
    with open(save_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in entries)
    print(f"Subject: {subject} - Total entries: {len(entries)}")

In [43]:
models = ['gpt-4o', 'gpt-4o-mini']
types = ['_zero_result', '_result']

# Widths of the two trailing metric columns, shared by the header and the rows.
# "Self-Consistency" is 16 characters long, so a 15-wide column made the header
# one character wider than every data row.
consistency_width = 18
anycorrect_width = 15

for model in models:
    for t in types:
        # Per-subject metrics for the current (model, type) pair. The name is
        # historical: it holds the results for whichever `t` is being processed.
        zero_result = {}

        # --------------------
        # 1) Collect results
        # --------------------
        for subject in subjects:
            output_res_path = os.path.join(f"../result/mmlu_pro/{model}/", subject + f"{t}.jsonl")
            if not os.path.exists(output_res_path):
                print(f"[WARNING] File not found: {output_res_path}")
                continue

            with open(output_res_path, 'r', encoding='utf-8') as f:
                res = [json.loads(line) for line in f]

            if not res:
                print(f"[WARNING] No predictions for subject={subject} (model={model}, type={t})")
                continue

            # Number of questions
            N = len(res)
            # Number of predictions per question, taken from the first record
            num_indices = len(res[0]['model_outputs'])

            index_correct_counts = [0] * num_indices
            anycorrect_count = 0
            self_consistency_count = 0

            for r in res:
                answer = r['entry']['answer']
                preds = [extract_answer(mo) for mo in r['model_outputs']]

                # (1) Index accuracy: was the i-th sampled output correct?
                for i, pred in enumerate(preds):
                    if pred == answer:
                        index_correct_counts[i] += 1

                # (2) Any correct: at least one sampled output is correct
                if answer in preds:
                    anycorrect_count += 1

                # (3) Self-consistency (majority voting). Ties go to the
                # prediction seen first. A record without any outputs has no
                # majority and simply does not count as correct.
                if preds and Counter(preds).most_common(1)[0][0] == answer:
                    self_consistency_count += 1

            zero_result[subject] = {
                "N": N,
                "index_accuracy": [count / N for count in index_correct_counts],
                "self_consistency_accuracy": self_consistency_count / N,
                "anycorrect": anycorrect_count / N,
            }

        # If we found no results at all for this model-type combination, skip
        if not zero_result:
            print(f"\n[INFO] No valid subjects for model={model}, type={t}. Skipping.")
            continue

        # --------------------
        # 2) Compute averages
        # --------------------
        # `zero_result` is non-empty here and every subject has N >= 1, so the
        # averages below are always well defined.
        subject_results = list(zero_result.values())
        num_indices = len(subject_results[0]['index_accuracy'])

        sizes = np.array([result['N'] for result in subject_results])
        # One row per subject, one column per prediction index.
        index_matrix = np.array([result['index_accuracy'] for result in subject_results])
        consistency_values = np.array([result['self_consistency_accuracy'] for result in subject_results])
        anycorrect_values = np.array([result['anycorrect'] for result in subject_results])

        # Weighted average: every question counts equally.
        weighted_index_accuracy_avg = np.average(index_matrix, axis=0, weights=sizes)
        weighted_self_consistency_avg = np.average(consistency_values, weights=sizes)
        weighted_anycorrect_avg = np.average(anycorrect_values, weights=sizes)

        # Simple average: every subject counts equally.
        simple_index_accuracy_avg = index_matrix.mean(axis=0)
        simple_self_consistency_avg = consistency_values.mean()
        simple_anycorrect_avg = anycorrect_values.mean()

        # --------------------
        # 3) Print the table
        # --------------------
        print(f"\n=== Results for model={model}, type={t} ===")

        header = (
            "Subject".ljust(25) +
            "len".center(10) +
            "".join(f"Idx{i}".center(10) for i in range(num_indices)) +
            "Self-Consistency".center(consistency_width) +
            "Anycorrect".center(anycorrect_width)
        )
        print(header)
        print("-" * len(header))

        # Each row: (label, question count as text, per-index accuracies,
        # self-consistency accuracy, any-correct rate).
        subject_rows = [
            (subj, str(result['N']), result['index_accuracy'],
             result['self_consistency_accuracy'], result['anycorrect'])
            for subj, result in zero_result.items()
        ]
        total_rows = [
            ("TOTAL (Weighted)", "", weighted_index_accuracy_avg,
             weighted_self_consistency_avg, weighted_anycorrect_avg),
            ("TOTAL (Simple)", "", simple_index_accuracy_avg,
             simple_self_consistency_avg, simple_anycorrect_avg),
        ]

        for rows in (subject_rows, total_rows):
            for label, count_text, index_accs, consistency, anycorrect in rows:
                line = label.ljust(25) + count_text.center(10)
                line += "".join(f"{acc*100:.2f}%".center(10) for acc in index_accs)
                line += f"{consistency*100:.2f}%".center(consistency_width)
                line += f"{anycorrect*100:.2f}%".center(anycorrect_width)
                print(line)
            # Separator between the per-subject rows and the totals
            if rows is subject_rows:
                print("-" * len(header))

        # Mean of the weighted per-index accuracies. Divide by the actual
        # number of indices rather than a hard-coded 5 so the value stays
        # correct when a run samples a different number of outputs.
        custom_metric = sum(weighted_index_accuracy_avg) / num_indices if num_indices else 0.0
        print(f"Some custom measure: {custom_metric:.4f}")


=== Results for model=gpt-4o, type=_zero_result ===
Subject                     len       Idx0      Idx1      Idx2      Idx3      Idx4   Self-Consistency   Anycorrect  
--------------------------------------------------------------------------------------------------------------------
business                    300      75.33%    76.67%    76.67%    74.00%    76.00%       78.33%         87.67%    
law                         300      59.67%    61.67%    59.00%    59.33%    58.00%       61.67%         72.00%    
psychology                  300      81.00%    79.00%    78.00%    80.00%    79.33%       79.00%         87.67%    
biology                     300      85.00%    87.33%    87.33%    87.33%    88.33%       87.33%         92.33%    
chemistry                   300      70.00%    71.00%    71.67%    68.67%    68.00%       74.33%         87.00%    
history                     300      69.33%    70.33%    71.00%    71.67%    70.33%       71.33%         77.00%    
other            

In [None]:
llama_types = ['_zero_result', '_result']
output_dir = "../result/mmlu_pro/llama/"

# Widths of the two trailing metric columns, shared by the header and the rows.
# "Self-Consistency" is 16 characters long, so a 15-wide column made the header
# one character wider than every data row.
consistency_width = 18
anycorrect_width = 15

for t in llama_types:
    print(f"\n=== Processing type: {t} ===")
    # Per-subject metrics for the current result type.
    llama_results = {}

    # --------------------
    # 1) Collect results
    # --------------------
    for subject in subjects:
        file_path = os.path.join(output_dir, subject + f"{t}.json")
        if not os.path.exists(file_path):
            print(f"[WARNING] File not found: {file_path}")
            continue

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not data:
            print(f"[WARNING] No predictions for subject={subject}")
            continue

        # Number of questions
        N = len(data)
        # Number of predictions per question, taken from the first record
        # instead of assuming exactly 5 samples.
        num_indices = len(data[0]['resps'][0])

        index_correct_counts = [0] * num_indices
        anycorrect_count = 0
        self_consistency_count = 0

        for entry in data:
            answer = entry['doc']['gold']
            # The sampled outputs for a question are read from entry['resps'][0].
            predictions = [extract_answer(model_output) for model_output in entry['resps'][0]]

            # (1) Index accuracy: was the i-th sampled output correct?
            for i, pred in enumerate(predictions):
                if pred == answer:
                    index_correct_counts[i] += 1

            # (2) Any correct: at least one sampled output is correct
            if answer in predictions:
                anycorrect_count += 1

            # (3) Self-consistency (majority voting). Ties go to the prediction
            # seen first. A record without any outputs has no majority and
            # simply does not count as correct.
            if predictions and Counter(predictions).most_common(1)[0][0] == answer:
                self_consistency_count += 1

        llama_results[subject] = {
            "N": N,
            "index_accuracy": [count / N for count in index_correct_counts],
            "self_consistency_accuracy": self_consistency_count / N,
            "anycorrect": anycorrect_count / N,
        }

    if not llama_results:
        print(f"No valid subjects found for type {t}.")
        continue

    # --------------------
    # 2) Compute averages
    # --------------------
    # `llama_results` is non-empty here and every subject has N >= 1, so the
    # averages below are always well defined.
    subject_results = list(llama_results.values())
    num_indices = len(subject_results[0]['index_accuracy'])

    sizes = np.array([result['N'] for result in subject_results])
    # One row per subject, one column per prediction index.
    index_matrix = np.array([result['index_accuracy'] for result in subject_results])
    consistency_values = np.array([result['self_consistency_accuracy'] for result in subject_results])
    anycorrect_values = np.array([result['anycorrect'] for result in subject_results])

    # Weighted average: every question counts equally.
    weighted_index_accuracy_avg = np.average(index_matrix, axis=0, weights=sizes)
    weighted_self_consistency_avg = np.average(consistency_values, weights=sizes)
    weighted_anycorrect_avg = np.average(anycorrect_values, weights=sizes)

    # Simple average: every subject counts equally.
    simple_index_accuracy_avg = index_matrix.mean(axis=0)
    simple_self_consistency_avg = consistency_values.mean()
    simple_anycorrect_avg = anycorrect_values.mean()

    # --------------------
    # 3) Print the table
    # --------------------
    print(f"\n=== Results for Llama (type: {t}) ===")
    header = (
        "Subject".ljust(25) +
        "len".center(10) +
        "".join(f"Idx{i}".center(10) for i in range(num_indices)) +
        "Self-Consistency".center(consistency_width) +
        "Anycorrect".center(anycorrect_width)
    )
    print(header)
    print("-" * len(header))

    # Each row: (label, question count as text, per-index accuracies,
    # self-consistency accuracy, any-correct rate).
    subject_rows = [
        (subj, str(result['N']), result['index_accuracy'],
         result['self_consistency_accuracy'], result['anycorrect'])
        for subj, result in llama_results.items()
    ]
    total_rows = [
        ("TOTAL (Weighted)", "", weighted_index_accuracy_avg,
         weighted_self_consistency_avg, weighted_anycorrect_avg),
        ("TOTAL (Simple)", "", simple_index_accuracy_avg,
         simple_self_consistency_avg, simple_anycorrect_avg),
    ]

    for rows in (subject_rows, total_rows):
        for label, count_text, index_accs, consistency, anycorrect in rows:
            line = label.ljust(25) + count_text.center(10)
            line += "".join(f"{acc*100:.2f}%".center(10) for acc in index_accs)
            line += f"{consistency*100:.2f}%".center(consistency_width)
            line += f"{anycorrect*100:.2f}%".center(anycorrect_width)
            print(line)
        # Separator between the per-subject rows and the totals
        if rows is subject_rows:
            print("-" * len(header))

    # Mean of the weighted per-index accuracies.
    custom_metric = sum(weighted_index_accuracy_avg) / num_indices if num_indices else 0.0
    print(f"Some custom measure: {custom_metric:.4f}")



=== Processing type: _zero_result ===

=== Results for Llama (type: _zero_result) ===
Subject                     len       Idx0      Idx1      Idx2      Idx3      Idx4   Self-Consistency   Anycorrect  
--------------------------------------------------------------------------------------------------------------------
business                    300      35.00%    35.33%    32.00%    30.67%    34.67%       43.67%         65.33%    
law                         300      25.33%    29.67%    26.33%    26.67%    28.00%       32.67%         61.00%    
psychology                  300      59.67%    54.67%    58.00%    58.67%    57.00%       64.33%         81.67%    
biology                     300      60.00%    63.33%    62.00%    61.33%    60.33%       68.33%         85.33%    
chemistry                   300      23.67%    28.00%    26.00%    27.33%    27.33%       33.00%         63.33%    
history                     300      42.33%    40.67%    40.67%    39.00%    39.67%       46.33%   

In [None]:
def add_pred(subject, model, output_dir):
    """Annotate a subject's saved likelihood records with the extracted prediction.

    Reads ``<output_dir>/<model>/<subject>/few/all_likelihoods.json`` (a list
    of lists of record dicts). For every record it stores ``pred``, the result
    of ``extract_answer`` on the record's ``model_output``, and ``is_correct``,
    whether ``pred`` equals the record's ``answer``. The file is then
    overwritten with the annotated records.

    Args:
        subject: Subject directory name.
        model: Model directory name.
        output_dir: Root directory that contains the per-model folders.

    Returns:
        None. Prints an error and returns early when the file does not exist.
    """
    from itertools import zip_longest

    likelihoods_file = os.path.join(output_dir, model, subject, "few/all_likelihoods.json")

    if not os.path.exists(likelihoods_file):
        print(f"Error: {likelihoods_file} not found.")
        return
    # Read and write as UTF-8 explicitly, matching how the file is read back
    # elsewhere in the notebook, instead of relying on the platform default.
    with open(likelihoods_file, "r", encoding="utf-8") as f:
        likelihoods = json.load(f)

    # Transpose so the progress bar advances once per problem. zip_longest is
    # used instead of zip: zip stops at the shortest inner list, which would
    # leave the trailing records of longer lists without `pred`/`is_correct`.
    problem_groups = list(zip_longest(*likelihoods))

    for problem_likelihoods in tqdm(problem_groups, desc="Processing problems"):
        for cl in problem_likelihoods:
            if cl is None:
                # Padding inserted by zip_longest for a shorter inner list.
                continue
            pred = extract_answer(cl['model_output'])
            cl['pred'] = pred
            cl['is_correct'] = pred == cl['answer']

    # The record dicts were updated in place, so dumping `likelihoods` writes
    # the annotated data back in its original nesting.
    with open(likelihoods_file, "w", encoding="utf-8") as f:
        json.dump(likelihoods, f, indent=4)
    


In [None]:
base_dir = "../likelihood/mmlu_pro"

for model in ['llama']:
    # Running totals across every subject of this model.
    total_correct, total_entries = 0, 0
    subject_accuracies = {}

    for subject in subjects:
        # Refresh the `pred` / `is_correct` fields on disk before counting.
        add_pred(subject, model, base_dir)

        likelihoods_file = os.path.join(base_dir, model, subject, "few/all_likelihoods.json")
        if not os.path.exists(likelihoods_file):
            print(f"File not found: {likelihoods_file}")
            continue

        with open(likelihoods_file, "r", encoding="utf-8") as f:
            likelihoods = json.load(f)

        # One flag per record; a record without `is_correct` counts as wrong.
        correct_flags = [
            bool(cl.get("is_correct", False))
            for sublist in likelihoods
            for cl in sublist
        ]
        subject_total = len(correct_flags)
        subject_correct = sum(correct_flags)

        accuracy = subject_correct / subject_total if subject_total > 0 else 0

        subject_accuracies[subject] = accuracy
        total_correct += subject_correct
        total_entries += subject_total

    # Accuracy pooled over every record of every subject.
    overall_accuracy = total_correct / total_entries if total_entries > 0 else 0

    print(f"\n=== Model: {model} ===")
    for subject, acc in subject_accuracies.items():
        print(f"Subject {subject}: Accuracy = {acc:.4f}")
    print(f"Overall Mean Accuracy (all subjects): {overall_accuracy:.4f}")
    print("-" * 50)

Processing problems: 100%|██████████| 300/300 [00:06<00:00, 49.48it/s]
Processing problems: 100%|██████████| 300/300 [00:05<00:00, 50.80it/s]
Processing problems: 100%|██████████| 300/300 [00:02<00:00, 138.28it/s]
Processing problems: 100%|██████████| 300/300 [00:11<00:00, 26.90it/s]
Processing problems: 100%|██████████| 300/300 [00:12<00:00, 23.71it/s]
Processing problems: 100%|██████████| 300/300 [00:05<00:00, 56.63it/s] 
Processing problems: 100%|██████████| 300/300 [00:04<00:00, 61.35it/s] 
Processing problems: 100%|██████████| 300/300 [00:12<00:00, 24.81it/s]
Processing problems: 100%|██████████| 300/300 [00:07<00:00, 37.94it/s]
Processing problems: 100%|██████████| 300/300 [00:06<00:00, 47.16it/s]
Processing problems: 100%|██████████| 300/300 [00:05<00:00, 51.00it/s]
Processing problems: 100%|██████████| 300/300 [00:05<00:00, 57.38it/s] 
Processing problems: 100%|██████████| 300/300 [00:07<00:00, 42.21it/s]
Processing problems: 100%|██████████| 300/300 [00:17<00:00, 16.70it/s]



=== Model: llama ===
Subject business: Accuracy = 0.4047
Subject law: Accuracy = 0.2753
Subject psychology: Accuracy = 0.5540
Subject biology: Accuracy = 0.5540
Subject chemistry: Accuracy = 0.2287
Subject history: Accuracy = 0.4280
Subject other: Accuracy = 0.4360
Subject health: Accuracy = 0.4720
Subject economics: Accuracy = 0.4460
Subject math: Accuracy = 0.3140
Subject physics: Accuracy = 0.2813
Subject computer science: Accuracy = 0.4313
Subject philosophy: Accuracy = 0.3847
Subject engineering: Accuracy = 0.2060
Overall Mean Accuracy (all subjects): 0.3869
--------------------------------------------------
