In [6]:
import os
import json

from gpqa.gpqa_utils import * 

from math500.math_utils import * 
from math500.parser import *
from math500.grader import * 

from mmlu_pro.mmlu_utils import * 

from hotpotqa.hotpotqa_utils import *

from drop.drop_utils import *

from musr.musr import MuSRDataset

from utils import * 

In [1]:
def decode_answer_labels(answer_labels, tokenizer):
    return tokenizer.convert_tokens_to_string(answer_labels)

def get_token_spans(answer_labels, tokenizer):
    full_text = decode_answer_labels(answer_labels, tokenizer)
    spans = []
    current_pos = 0
    for token in answer_labels:
        token_text = tokenizer.convert_tokens_to_string([token]).strip()
        start = full_text.find(token_text, current_pos)
        if start == -1:
            start = current_pos
        end = start + len(token_text)
        spans.append((start, end))
        current_pos = end
    return full_text, spans

def find_substring_token_indices(pred_text, full_text, token_spans):
    start_char = full_text.rfind(pred_text)
    if start_char == -1:
        return -1, -1
    end_char = start_char + len(pred_text)
    
    start_token = None
    for i, (s, _) in enumerate(token_spans):
        if s >= start_char:
            start_token = i
            break
    if start_token is None:
        return -1, -1
    
    end_token = None
    for i in reversed(range(start_token, len(token_spans))):
        _, e = token_spans[i]
        if e <= end_char:
            end_token = i
            break
    if end_token is None:
        return -1, -1
    return start_token, end_token


def compute_prob_diff_for_token_range(start_idx, end_idx, topk_probs_list):
    total_top1 = 0.0
    total_top2 = 0.0
    count = 0
    for i in range(start_idx, end_idx + 1):
        if i >= len(topk_probs_list):
            return None, None, None
        probs = topk_probs_list[i]
        if len(probs) < 2:
            return None, None, None
        total_top1 += probs[0]
        total_top2 += probs[1]
        count += 1
    if count == 0:
        return None, None, None
    avg_top1 = total_top1 / count
    avg_top2 = total_top2 / count
    diff = avg_top1 - avg_top2
    return avg_top1, avg_top2, diff


def map_pred_to_prob_diff(pred, answer_labels, topk_probs_list, tokenizer):
    full_text, token_spans = get_token_spans(answer_labels, tokenizer)
    
    pred_text = pred
    full_text_norm = full_text
    
    if pred is None:
        sentences = [s.strip() for s in re.split(r"[.?!]", full_text_norm) if s.strip()]
        if sentences:
            pred_text = sentences[-1]
        else:
            pred_text = full_text_norm
    else:
        pred_text = pred.lower().strip()

    start_token, end_token = find_substring_token_indices(pred_text, full_text_norm, token_spans)
    
    if start_token == -1 or end_token == -1:
        sentences = [s.strip() for s in re.split(r"[.?!]", full_text_norm) if s.strip()]
        if sentences:
            last_sentence = sentences[-1]
        else:
            last_sentence = full_text_norm  
        
        start_token, end_token = find_substring_token_indices(last_sentence, full_text_norm, token_spans)
        if start_token == -1 or end_token == -1:
            start_token = 0
            end_token = len(answer_labels) - 1
    
    avg_top1, avg_top2, diff = compute_prob_diff_for_token_range(start_token, end_token, topk_probs_list)
    if diff == None: 
        print(start_token, end_token)
    return start_token, end_token, avg_top1, avg_top2, diff


In [3]:
class Config:
    def __init__(self):
        self.model = "gpt-4o-mini"  
        self.task = "math500"  # "math500", "mmlu_pro", "gpqa", "drop", "hotpotqa"
        self.shot_type = "few"  
        self.output_dir = "baselines/baseline"
        self.num_examples = -1       
        self.subject = None


In [4]:
from transformers import AutoTokenizer  

tokenizer_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
keys_to_eval = [
    'max_prob_diff_output',
    'total_max_prob_diff_output',
]

tasks = ["math500", "mmlu_pro", "gpqa", "drop", "hotpotqa", "musr_location", 'musr_efficiently']

shots = ["few"]
models = ['gpt-4o-mini', 'gpt-4o', 'llama']

subjects = ['business', 'law', 'psychology', 'biology', 'chemistry', 'history', 'other', 'health', 'economics', 'math', 'physics', 'computer science', 'philosophy', 'engineering']

config = Config()

In [None]:
from math500.math_utils import * 
from math500.parser import *
from math500.grader import * 


for model in models:
    file_path = f"{config.output_dir}/math500/{model}/math500_{config.shot_type}_few.jsonl"
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue
    else:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f]

        scores = {k: [] for k in keys_to_eval}

        for entry in data:
            idx = entry["idx"]
        
            _, gt = parse_ground_truth(entry['entry'], "math")

            model_outputs = entry.get('model_outputs', [])
            results_info = entry.get('results', [])


            best_total_diff = float('-inf')
            best_total_pred = None
            best_total_mo = None
            best_total_em, best_total_f1 = 0.0, 0.0

            best_diff = float('-inf')
            best_pred = None
            best_mo = None
            best_em, best_f1 = 0.0, 0.0


            for i, mo in enumerate(model_outputs):

                r = results_info[i] if i < len(results_info) else {}
                answer_labels = r.get("answer_labels", [])
                answer_label_probs = r.get("answer_label_probs", [])
                topk_tokens_list = r.get("topk_tokens", [])
                topk_probs_list  = r.get("topk_probs", [])
                total_diff = np.mean([top1 - top2 for top1, top2 in topk_probs_list])

                
                pred = extract_answer(mo, "math")
                start_idx, end_idx, avg_top1, avg_top2, diff = map_pred_to_prob_diff(
                    pred,
                    answer_labels,
                    topk_probs_list,
                    tokenizer
                )

                pred = strip_string(pred)

                if total_diff > best_total_diff:
                    best_total_diff = total_diff
                    best_total_pred = pred
                    best_total_mo = mo 


                if diff > best_diff:
                    best_diff = diff
                    best_pred = pred
                    best_mo = mo
                    

            k = "max_prob_diff_output"

            
            result = math_equal_process((idx, best_pred, gt))
        
            if not result:
                result = process_results(gt, [best_mo])
                if not result:
                    best_pred = extract_answer(best_pred, "math")
                    result = math_equal_process((None, best_pred, gt))
            scores[k].append(result)


            k = "total_max_prob_diff_output"

            
            result = math_equal_process((idx, best_total_pred, gt))

            if not result :
                result = process_results(gt, [best_total_mo])
                if not result:
                    best_total_pred = extract_answer(best_total_pred, "math")
                    result = math_equal_process((None, best_total_pred, gt))
            scores[k].append(result)


        print(f"\n===== Evaluation Results for model={model}, shot={config.shot_type} =====")
        for key in keys_to_eval:

            if len(scores[key]) == 0:
                print(f"{key} -> No data / Not found in entries")
                continue

            acc = sum(scores[key]) / len(scores[key])
            print(f"{key} -> Accuracy: {acc:.4f}")
        print("-" * 50)



===== Evaluation Results for model=gpt-4o-mini, shot=few =====
max_prob_diff_output -> Accuracy: 0.7780
total_max_prob_diff_output -> Accuracy: 0.7660
--------------------------------------------------

===== Evaluation Results for model=gpt-4o, shot=few =====
max_prob_diff_output -> Accuracy: 0.7800
total_max_prob_diff_output -> Accuracy: 0.7780
--------------------------------------------------

===== Evaluation Results for model=llama, shot=few =====
max_prob_diff_output -> Accuracy: 0.4780
total_max_prob_diff_output -> Accuracy: 0.4980
--------------------------------------------------


In [8]:
def extract_answer(text):
    pattern = r"answer is \(?([A-J])\)?"
    match = re.search(pattern, text)
    if match:
        return match.group(1)
    else:
        return extract_again(text)


for model in models:

    overall_scores = {k: 0 for k in keys_to_eval}  
    overall_total_entries = 0

    for subject in subjects:
        file_path = f"{config.output_dir}/mmlu_pro/{model}/{subject}/mmlu_pro_few_few.jsonl"
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]


        subject_scores = {k: 0 for k in keys_to_eval}
        total_data_len = len(data)
        overall_total_entries += total_data_len

        for entry in data:
            model_outputs = entry.get('model_outputs', [])
            results_info = entry.get('results', [])
            answer = entry['entry'].get('answer') or entry['entry'].get('gold')


            best_diff = float('-inf')
            best_pred = None

            best_total_diff = float('-inf')
            best_total_pred = None


            for i, mo in enumerate(model_outputs):
                r = results_info[i] if i < len(results_info) else {}
                answer_labels = r.get("answer_labels", [])
                answer_label_probs = r.get("answer_label_probs", [])
                topk_tokens_list = r.get("topk_tokens", [])
                topk_probs_list  = r.get("topk_probs", [])


                total_diff = np.mean([top1 - top2 for top1, top2 in topk_probs_list])

                pred = extract_answer(mo)
                if pred is None:
                    sentences = [s.strip() for s in re.split(r"[.?!]", mo) if s.strip()]
                    pred = sentences[-1] if sentences else mo

                start_idx, end_idx, avg_top1, avg_top2, diff = map_pred_to_prob_diff(
                    pred,
                    answer_labels,
                    topk_probs_list,
                    tokenizer
                )

                if total_diff > best_total_diff:
                    best_total_diff = total_diff
                    best_total_pred = pred

                if diff > best_diff:
                    best_diff = diff
                    best_pred = pred


            k = "max_prob_diff_output"
            if best_pred == answer:
                subject_scores[k] += 1
            k = "total_max_prob_diff_output"
            if best_total_pred == answer:
                subject_scores[k] += 1

        for k in keys_to_eval:
            overall_scores[k] += subject_scores[k]

    print(f"\n=== Overall Results for model={model} ===")
    for key in keys_to_eval:
        if overall_total_entries == 0:
            acc = 0
        else:
            acc = overall_scores[key] / overall_total_entries
        print(f"{key} -> Accuracy: {acc:.4f}")
    print("-" * 50)



=== Overall Results for model=gpt-4o-mini ===
max_prob_diff_output -> Accuracy: 0.6417
total_max_prob_diff_output -> Accuracy: 0.6469
--------------------------------------------------

=== Overall Results for model=gpt-4o ===
max_prob_diff_output -> Accuracy: 0.7500
total_max_prob_diff_output -> Accuracy: 0.7517
--------------------------------------------------

=== Overall Results for model=llama ===
max_prob_diff_output -> Accuracy: 0.4483
total_max_prob_diff_output -> Accuracy: 0.4507
--------------------------------------------------


In [9]:
from gpqa.gpqa_utils import * 

examples = load_examples("data/gpqa/gpqa_diamond.csv", seed=0)

for model in models:
    file_path = f"{config.output_dir}/gpqa/{model}/gpqa_few_{config.shot_type}.jsonl"
    
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue
    

    with open(file_path, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]

    scores = {k: 0 for k in keys_to_eval}


    total_data_len = len(data)
    if total_data_len != len(examples):
        print("Warning: data length and examples length do not match!")

    for entry, example in zip(data, examples):
        correct_index = example.correct_index  

        model_outputs = entry.get('model_outputs', [])
        results_info = entry.get('results', [])


        best_total_diff = float('-inf')
        best_total_pred = None
        best_total_mo = None
        best_total_em, best_total_f1 = 0.0, 0.0
        
        best_diff = float('-inf')
        best_pred = None
        best_em, best_f1 = 0.0, 0.0


        for i, mo in enumerate(model_outputs):
            pred = parse_sampled_answer(mo)
            if pred is None:
                pred = mo


            r = results_info[i] if i < len(results_info) else {}
            answer_labels = r.get("answer_labels", [])
            answer_label_probs = r.get("answer_label_probs", [])
            topk_tokens_list = r.get("topk_tokens", [])
            topk_probs_list  = r.get("topk_probs", [])
            total_diff = np.mean([top1 - top2 for top1, top2 in topk_probs_list])


            start_idx, end_idx, avg_top1, avg_top2, diff = map_pred_to_prob_diff(
                pred,
                answer_labels,
                topk_probs_list,
                tokenizer
            )

            if total_diff > best_total_diff:
                best_total_diff = total_diff
                best_total_pred = pred
                best_total_mo = mo 


            if diff > best_diff:
                best_diff = diff
                best_pred = pred
                

        k = "max_prob_diff_output"

        if best_pred is None or len(best_pred) > 5:
            is_correct = False
        else:

            is_correct = (LETTER_TO_INDEX[best_pred] == correct_index)
        scores[k] += int(is_correct)

        k = "total_max_prob_diff_output"

        if best_total_pred is None or len(best_total_pred) > 5:
            is_correct = False
        else:

            is_correct = (LETTER_TO_INDEX[best_total_pred] == correct_index)
        scores[k] += int(is_correct)


    print(f"\n=== Results for model={model}, shot={config.shot_type} ===")
    for key in keys_to_eval:
        acc = scores[key] / total_data_len if total_data_len else 0
        print(f"{key} -> Accuracy: {acc:.4f}")

    print("-" * 50)


=== Results for model=gpt-4o-mini, shot=few ===
max_prob_diff_output -> Accuracy: 0.4242
total_max_prob_diff_output -> Accuracy: 0.4293
--------------------------------------------------

=== Results for model=gpt-4o, shot=few ===
max_prob_diff_output -> Accuracy: 0.4848
total_max_prob_diff_output -> Accuracy: 0.5253
--------------------------------------------------

=== Results for model=llama, shot=few ===
max_prob_diff_output -> Accuracy: 0.3232
total_max_prob_diff_output -> Accuracy: 0.3384
--------------------------------------------------


In [10]:
from drop.drop_utils import *

for model in models: 
    print(model)
    for shot_type in ['few']:
        file_path = f"{config.output_dir}/drop/{model}/drop_few_{shot_type}.jsonl"
    
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]
        
        em_scores = {k: [] for k in keys_to_eval}
        f1_scores = {k: [] for k in keys_to_eval}


        def get_max_em_f1(pred, golds):
            max_em, max_f1 = 0.0, 0.0
            for gold_answer in golds:
                exact_match, f1_score = get_metrics(pred, gold_answer)

                if gold_answer[0].strip():
                    max_em = max(max_em, exact_match)
                    max_f1 = max(max_f1, f1_score)
            return max_em, max_f1


        for entry in data:
            golds = get_answers(entry['entry'])  # gold answers

            model_outputs = entry.get('model_outputs', [])
            results_info = entry.get('results', [])


            best_total_diff = float('-inf')
            best_total_pred = None
            best_total_mo = None
            best_total_em, best_total_f1 = 0.0, 0.0

            best_diff = float('-inf')
            best_pred = None
            best_em, best_f1 = 0.0, 0.0


            for i, mo in enumerate(model_outputs):
                pred = extract_answer(mo)

                r = results_info[i] if i < len(results_info) else {}
                answer_labels = r.get("answer_labels", [])
                answer_label_probs = r.get("answer_label_probs", [])
                topk_tokens_list = r.get("topk_tokens", [])
                topk_probs_list  = r.get("topk_probs", [])
                total_diff = np.mean([top1 - top2 for top1, top2 in topk_probs_list])


                start_idx, end_idx, avg_top1, avg_top2, diff = map_pred_to_prob_diff(
                    pred,
                    answer_labels,
                    topk_probs_list,
                    tokenizer
                )

                if total_diff > best_total_diff:
                    best_total_diff = diff
                    best_total_pred = pred

                    em_val, f1_val = get_max_em_f1(pred, golds)
                    best_total_em, best_total_f1 = em_val, f1_val


                if diff > best_diff:
                    best_diff = diff
                    best_pred = pred

                    em_val, f1_val = get_max_em_f1(pred, golds)
                    best_em, best_f1 = em_val, f1_val



            k = "max_prob_diff_output"
            em_scores[k].append(best_em)
            f1_scores[k].append(best_f1)

            k = "total_max_prob_diff_output"
            em_scores[k].append(best_total_em)
            f1_scores[k].append(best_total_f1)
            

        print(f"\n===== Results for model={model}, shot={shot_type} =====")
        for k in keys_to_eval:
            if len(em_scores[k]) == 0:
                print(f"{k}: No entries found, skip.")
                continue

            em_mean = np.mean(em_scores[k])
            f1_mean = np.mean(f1_scores[k])
            print(f"{k} -> EM: {em_mean:.4f}, F1: {f1_mean:.4f}")

        print("-" * 50)

gpt-4o-mini

===== Results for model=gpt-4o-mini, shot=few =====
max_prob_diff_output -> EM: 0.7760, F1: 0.8307
total_max_prob_diff_output -> EM: 0.7640, F1: 0.8275
--------------------------------------------------
gpt-4o

===== Results for model=gpt-4o, shot=few =====
max_prob_diff_output -> EM: 0.8300, F1: 0.9081
total_max_prob_diff_output -> EM: 0.8100, F1: 0.8934
--------------------------------------------------
llama

===== Results for model=llama, shot=few =====
max_prob_diff_output -> EM: 0.7020, F1: 0.7510
total_max_prob_diff_output -> EM: 0.6960, F1: 0.7516
--------------------------------------------------


In [11]:
from hotpotqa.hotpotqa_utils import *

def extract_answer(response_text):
    match = re.search(r"Answer\s+(.+)", response_text, re.DOTALL)
    if match:
        answer = match.group(1).strip()

        answer = re.sub(r"[.\n]+$", "", answer).strip()
        return answer


    match = re.search(r"(?<!\w)Answer[:\s]+(.+?)(?:[.\n]|$)", response_text, re.IGNORECASE | re.DOTALL)
    if match:
        answer = match.group(1).strip()

        answer = re.sub(r"[.\n]+$", "", answer).strip()
        return answer
    return response_text.strip()


dataset = json.load(open(f'data/hotpotqa/hotpotqa.json'))
with open("hotpotqa/react_prompt.json", 'r') as f:
    fewshot = json.load(f)

for model in models:
    print(model)
    for shot_type in ['few']:
        file_path = f"{config.output_dir}/hotpotqa/{model}/hotpotqa_few_{shot_type}.jsonl"
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]
        

        preds = {k: [] for k in keys_to_eval}

        for entry in data:

            model_outputs = entry.get('model_outputs', [])
            results_info = entry.get('results', [])


            best_total_diff = float('-inf')
            best_total_pred = None
            best_total_mo = None
            best_total_em, best_total_f1 = 0.0, 0.0

            best_diff = float('-inf')
            best_pred = None
            best_em, best_f1 = 0.0, 0.0

            for i, mo in enumerate(model_outputs):
                pred = extract_answer(mo)

                r = results_info[i] if i < len(results_info) else {}
                answer_labels = r.get("answer_labels", [])
                answer_label_probs = r.get("answer_label_probs", [])
                topk_tokens_list = r.get("topk_tokens", [])
                topk_probs_list  = r.get("topk_probs", [])

                total_diff = np.mean([top1 - top2 for top1, top2 in topk_probs_list])


                start_idx, end_idx, avg_top1, avg_top2, diff = map_pred_to_prob_diff(
                    pred,
                    answer_labels,
                    topk_probs_list,
                    tokenizer
                )

                if total_diff > best_total_diff:
                    best_total_diff = total_diff
                    best_total_pred = pred
                    best_total_mo = mo 


                if diff > best_diff:
                    best_diff = diff
                    best_pred = pred
                    

            k = "total_max_prob_diff_output"
            preds[k].append(best_total_pred)


            k = "max_prob_diff_output"
            preds[k].append(best_pred)

            
        print(f"\n=== Results for model={model}, shot={shot_type} ===")
        for k in keys_to_eval:
            if len(preds[k]) == 0:
                print(f"{k}: No entries found, skip.")
                continue

            em_scores, f1_scores = get_em_f1(dataset, preds[k])
            em_mean = em_scores.mean()
            f1_mean = f1_scores.mean()
            print(f"{k} -> EM: {em_mean:.4f}, F1: {f1_mean:.4f}")
        
        print("-" * 50)

gpt-4o-mini

=== Results for model=gpt-4o-mini, shot=few ===
max_prob_diff_output -> EM: 0.3400, F1: 0.4564
total_max_prob_diff_output -> EM: 0.3620, F1: 0.4770
--------------------------------------------------
gpt-4o

=== Results for model=gpt-4o, shot=few ===
max_prob_diff_output -> EM: 0.4800, F1: 0.6122
total_max_prob_diff_output -> EM: 0.4600, F1: 0.6037
--------------------------------------------------
llama

=== Results for model=llama, shot=few ===
max_prob_diff_output -> EM: 0.2500, F1: 0.3221
total_max_prob_diff_output -> EM: 0.2440, F1: 0.3175
--------------------------------------------------


In [None]:
from musr.musr import MuSRDataset

mm_path = 'data/musr/murder_mystery.json'
mm = MuSRDataset(mm_path)

ta_path = 'data/musr/team_allocation.json'
ta = MuSRDataset(ta_path)

op_path = 'data/musr/object_placements.json'
op = MuSRDataset(op_path)

In [None]:
for model in models:
    print(model)
    for shot_type in ['few']:

        file_path = f"{config.output_dir}/musr_location/{model}/musr_location_few_{shot_type}.jsonl"
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]
        
        preds = {k: [] for k in keys_to_eval}

        for test_idx, entry in enumerate(data):

            model_outputs = entry.get('model_outputs', [])
            results_info = entry.get('results', [])


            best_total_diff = float('-inf')
            best_total_pred = None


            best_diff = float('-inf')
            best_pred = None

            for i, mo in enumerate(model_outputs):

                pred = op.evaluate_response([mo], op[test_idx])[0]['model_answer']
                
                r = results_info[i] if i < len(results_info) else {}
                topk_probs_list  = r.get("topk_probs", [])


                total_diff = np.mean([top1 - top2 for (top1, top2) in topk_probs_list]) \
                            if topk_probs_list else float('-inf')


                _, _, _, _, diff = map_pred_to_prob_diff(
                    pred,
                    r.get("answer_labels", []),
                    topk_probs_list,
                    tokenizer
                )


                if total_diff > best_total_diff:
                    best_total_diff = total_diff
                    best_total_pred = mo


                if diff > best_diff:
                    best_diff = diff
                    best_pred = mo


            preds["total_max_prob_diff_output"].append(best_total_pred)
            preds["max_prob_diff_output"].append(best_pred)


        total_data_len = len(data)        
        scores = {k: 0 for k in keys_to_eval}
        
        for i, entry in enumerate(data):
            if 'entry' not in entry:
                continue
            
            for k in keys_to_eval:

                pred_answer = preds[k][i]
                if pred_answer is not None:
                    metrics = op.evaluate_response([pred_answer], op[i])
                    if metrics and metrics[0]['correct']:
                        scores[k] += 1

        print(f"\n=== Results for model={model}, shot={shot_type} ===")
        for k in keys_to_eval:

            if len(preds[k]) == 0:
                print(f"{k}: No entries found, skip.")
                continue

            acc = scores[k] / total_data_len if total_data_len else 0
            print(f"{k} -> Accuracy: {acc:.4f}")

        print("-" * 50)

In [None]:
for model in ['gpt-4o-mini', 'gpt-4o', 'llama']:
    print(model)
    for shot_type in ['few']:

        file_path = f"{config.output_dir}/musr_efficiently/{model}/musr_efficiently_few_{shot_type}.jsonl"
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]
        
        preds = {k: [] for k in keys_to_eval}

        for test_idx, entry in enumerate(data):

            model_outputs = entry.get('model_outputs', [])
            results_info = entry.get('results', [])

            best_total_diff = float('-inf')
            best_total_pred = None

            best_diff = float('-inf')
            best_pred = None

            for i, mo in enumerate(model_outputs):
                pred = ta.evaluate_response([mo], ta[test_idx])[0]['model_answer']
                
                r = results_info[i] if i < len(results_info) else {}
                topk_probs_list  = r.get("topk_probs", [])


                total_diff = np.mean([top1 - top2 for (top1, top2) in topk_probs_list]) \
                            if topk_probs_list else float('-inf')


                _, _, _, _, diff = map_pred_to_prob_diff(
                    pred,
                    r.get("answer_labels", []),
                    topk_probs_list,
                    tokenizer
                )

                if total_diff > best_total_diff:
                    best_total_diff = total_diff
                    best_total_pred = mo


                if diff > best_diff:
                    best_diff = diff
                    best_pred = mo

            preds["total_max_prob_diff_output"].append(best_total_pred)
            preds["max_prob_diff_output"].append(best_pred)

        total_data_len = len(data)
        
        scores = {k: 0 for k in keys_to_eval}
    
        for i, entry in enumerate(data):
            if 'entry' not in entry:
                continue
            
            for k in keys_to_eval:

                pred_answer = preds[k][i]
                if pred_answer is not None:
                    metrics = ta.evaluate_response([pred_answer], ta[i])
                    if metrics and metrics[0]['correct']:
                        scores[k] += 1


        print(f"\n=== Results for model={model}, shot={shot_type} ===")
        for k in keys_to_eval:

            if len(preds[k]) == 0:
                print(f"{k}: No entries found, skip.")
                continue

            acc = scores[k] / total_data_len if total_data_len else 0
            print(f"{k} -> Accuracy: {acc:.4f}")

        print("-" * 50)