In [10]:
import pandas as pd
import numpy as np
import json, os

pd.set_option("display.precision", 3)

results_dir = "./.archive/results/stage_1_rewrites"

use_tasks = [
        "gsm8k",
        "tracking_shuffled_objects_three_objects",
        "coinflip_eight",
        "prontoqa",
        "logiqa-en",
        "lsat-ar"
    ]

use_dirs = [
    # "PromptWithAnswerExtraction/gpt35_cot_instruct__baseline",
    "SolveValidateRewrite/gpt35_cot_instruct__rewrite_T0",
    "SolveValidateRewrite/gpt35_cot_instruct__rewrite_T07",
    "SolveValidateRewrite/gpt35_validate_framing__rewrite_T07",
    "SolveValidateRewrite/gpt35_validate_framing_rephrase_1__T07",
    "SolveValidateRewrite/gpt35_validate_framing_rephrase_2__T07",
    "SolveValidateRewrite/gpt35_validate_pattern__stg3",
    "SolveValidateRewrite/gpt35_validate_rewrite_pattern__stg3",
    # "PromptWithAnswerExtraction/gpt35_cot_instruct_reframed__baseline",
    # "PromptWithAnswerExtraction/gpt35_cot_instruct_reframed__baseline"
]

filepath = os.path.join(results_dir, "prontoqa","SolveValidateRewrite/gpt35_cot_instruct__rewrite_T0", "results.json")

In [11]:
def extract_metrics(json_examples):
    dfs=[]
    for ex in json_examples:
        ex_dict = {}
        ex_dict['n_responses'] = ex['response_count']
        ex_dict['true_answer'] = ex['true_answer']
        ex_dict['predicted_answer'] = ex['predicted_answer']
        ex_dict['correct'] = ex['true_answer'] == ex['predicted_answer']
        for i,res in enumerate(ex['response_pairs']):
            ex_dict[f"answer_{i}"] = res['answer']
        ex_dict['answer_0_correct'] = ex_dict['answer_0'] == ex_dict['true_answer']
        dfs.append(pd.DataFrame(ex_dict,index=[ex['example_idx']]))
    df = pd.concat(dfs)
    
    # Create 'answer_1' column if it does not exist yet
    if 'answer_1' not in df.columns:
        df['answer_1'] = np.nan

    rewrite_perc = (~df['answer_1'].isna()).mean()

    # Get first rewrite decision accuracy
    rewrite_acc = pd.pivot_table(df, index='n_responses', columns='answer_0_correct', aggfunc='count', values='answer_0')
    rewrite_acc = rewrite_acc.div(rewrite_acc.sum(axis=0), axis=1)
    correct_pred_rewrites = 1-rewrite_acc[True].values[0]
    wrong_pred_rewrites = 1-rewrite_acc[False].values[0]
    
    # Percent of all first rewrite decision that were of wrong answers
    good_rewrites_perc = 1-df[~df['answer_1'].isna()]['answer_0_correct'].mean()



    # Get rewrite correction accuracy
    df_rewrites = df[~df['answer_1'].isna()]
    df_rewrite_conversions = df_rewrites.groupby(['answer_0_correct','correct']).size()
    correct_to_wrong_perc = df_rewrite_conversions[True][False] / df_rewrite_conversions[True].sum()
    wrong_to_correct_perc = df_rewrite_conversions[False][True] / df_rewrite_conversions[False].sum()    

    return {
        'pre_rewrite_acc': df['answer_0_correct'].mean(),
        'rewrite': rewrite_perc,
        'good_rewrites_share': good_rewrites_perc,
        'correct_pred_rewrites': correct_pred_rewrites,
        'wrong_pred_rewrites': wrong_pred_rewrites,
        'correct_to_wrong': correct_to_wrong_perc,
        'wrong_to_correct': wrong_to_correct_perc,
    }
    

In [12]:
dfs = []
for task in use_tasks:
    for dir in use_dirs:
        print(f"Task: {task}, Run: {os.path.join(results_dir, task, dir, 'results.json')}")
        # try to open results.json from each directory, if it exists
        try:
            with open(os.path.join(results_dir, task, dir, "results.json"), "r") as f:
                data_dict = json.load(f)
                metrics_dict = {
                    "Task": data_dict["Task"],
                    "Run": dir.split("/")[-1],
                    "N examples": data_dict["Number of examples"],
                    # "Number of correct": data_dict["Number of correct"],
                    "Accuracy": data_dict["Accuracy"],
                }
                metrics_dict |= extract_metrics(data_dict['examples'])
                dfs.append(pd.DataFrame(metrics_dict, index=[f"{task}_{dir.split('/')[-1]}"]))
        except FileNotFoundError:
            continue
df = pd.concat(dfs)
df = df.set_index(["Task", "Run"])
df

Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_cot_instruct__rewrite_T0\results.json
Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_cot_instruct__rewrite_T07\results.json
Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_validate_framing__rewrite_T07\results.json
Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_validate_framing_rephrase_1__T07\results.json
Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_validate_framing_rephrase_2__T07\results.json
Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_validate_pattern__stg3\results.json
Task: gsm8k, Run: ./.archive/results/stage_1_rewrites\gsm8k\SolveValidateRewrite/gpt35_validate_rewrite_pattern__stg3\results.json
Task: tracking_shuffled_objects_three_objects, Run: ./.archive/results/stage_1_rewrites\tracking_sh

Unnamed: 0_level_0,Unnamed: 1_level_0,N examples,Accuracy,pre_rewrite_acc,rewrite,good_rewrites_share,correct_pred_rewrites,wrong_pred_rewrites,correct_to_wrong,wrong_to_correct
Task,Run,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
tracking_shuffled_objects/three_objects,gpt35_cot_instruct__rewrite_T0,250,0.596,0.636,0.5,0.368,0.497,0.505,0.203,0.13
tracking_shuffled_objects/three_objects,gpt35_cot_instruct__rewrite_T07,250,0.572,0.56,0.496,0.484,0.457,0.545,0.203,0.267
tracking_shuffled_objects/three_objects,gpt35_validate_framing__rewrite_T07,250,0.612,0.596,0.492,0.512,0.403,0.624,0.233,0.286
tracking_shuffled_objects/three_objects,gpt35_validate_framing_rephrase_1__T07,250,0.54,0.524,0.508,0.551,0.435,0.588,0.193,0.214
tracking_shuffled_objects/three_objects,gpt35_validate_framing_rephrase_2__T07,250,0.568,0.552,0.46,0.565,0.362,0.58,0.16,0.185
tracking_shuffled_objects/three_objects,gpt35_validate_pattern__stg3,250,0.768,0.812,0.532,0.241,0.498,0.681,0.149,0.125
tracking_shuffled_objects/three_objects,gpt35_validate_rewrite_pattern__stg3,250,0.668,0.704,0.56,0.314,0.545,0.595,0.208,0.25
prontoqa,gpt35_cot_instruct__rewrite_T0,250,0.828,0.84,0.168,0.119,0.176,0.125,0.135,0.4
prontoqa,gpt35_cot_instruct__rewrite_T07,250,0.844,0.852,0.132,0.061,0.146,0.054,0.097,0.5
prontoqa,gpt35_validate_framing__rewrite_T07,250,0.848,0.852,0.264,0.106,0.277,0.189,0.085,0.571


In [13]:
filepath = os.path.join(results_dir, "prontoqa","SolveValidateRewrite/gpt35_cot_instruct__rewrite_T0", "results.json")
with open(filepath, "r") as f:
    data_dict = json.load(f)
    
total_examples = data_dict['Number of examples']