In [38]:
import pandas as pd
import numpy as np
import json, os

pd.set_option("display.precision", 3)

# results_dir = "./.archive/results/stage_1_rewrites"
# results_dir = "./.archive/results/stage_4_redos"
results_dir = "./.archive/results/stage_6_samples_final"

label_mapping_file = "./results/experiment_label_mapping.tsv"
label_mapping = pd.read_csv(label_mapping_file, sep="\t")
use_tasks = [
        "gsm8k",
        # "tracking_shuffled_objects_three_objects",
        "tracking_shuffled_objects_five_objects_multi",
        # "coinflip_eight",
        "prontoqa",
        "logiqa-en",
        "lsat-ar",
        "navigate",
        "aqua-rat",
        "logical_deduction_five_objects_multi"
    ]

task_name_mapping = {
    "gsm8k": "GSM8K",
    "tracking_shuffled_objects/five_objects_multi": "Track5",
    "coinflip_eight": "Coinflip",
    "prontoqa": "ProntoQA",
    "logiqa-en": "LogiQA",
    "lsat-ar": "LSAT",
    "navigate": "Nav",
    "aqua-rat": "AQuA",
    "logical_deduction/five_objects_multi": "Deduct5"
}

use_dirs = [
    # "SolveValidateRewrite/gpt35_all_instruct_structured__stg6",
    # "SolveValidateRewrite/gpt35_all_instruct_structured_T07__stg6",
    # "SolveValidateRewrite/gpt35_all_pattern_structured__stg6",
    "SolveValidateRewrite/gpt35_all_instruct__stg4",
    # "SolveValidateRewrite/gpt35_all_instruct__stg4_T07",
    "SolveValidateRewrite/gpt35_all_instruct_structured__stg6",
    # "SolveValidateRewrite/gpt35_all_instruct_structured__stg6_T07",
    # "SolveValidateRewrite/gpt35_all_pattern__stg4"
    # "SolveValidateRewrite/gpt35_cot_instruct__rewrite_T07",
    # "SolveValidateRewrite/gpt35_validate_framing__rewrite_T07",
    # "SolveValidateRewrite/gpt35_validate_framing_rephrase_1__T07",
    # "SolveValidateRewrite/gpt35_validate_framing_rephrase_2__T07",
    # "SolveValidateRewrite/gpt35_validate_pattern__stg3",
    # "SolveValidateRewrite/gpt35_validate_rewrite_pattern__stg3",
    # "PromptWithAnswerExtraction/gpt35_cot_instruct_reframed__baseline",
    # "PromptWithAnswerExtraction/gpt35_cot_instruct_reframed__baseline"
]

filepath = os.path.join(results_dir, "prontoqa","SolveValidateRewrite/gpt35_cot_instruct__rewrite_T0", "results.json")

In [121]:
def extract_metrics(json_examples):
    dfs=[]
    for ex in json_examples:
        ex_dict = {}
        ex_dict['n_responses'] = ex['response_count']
        ex_dict['true_answer'] = ex['true_answer']
        ex_dict['predicted_answer'] = ex['predicted_answer']
        ex_dict['correct'] = ex['true_answer'] == ex['predicted_answer']
        for i,res in enumerate(ex['response_pairs']):
            ex_dict[f"answer_{i}"] = res['answer']
        ex_dict['answer_0_correct'] = ex_dict['answer_0'] == ex_dict['true_answer']
        dfs.append(pd.DataFrame(ex_dict,index=[ex['example_idx']]))
    df = pd.concat(dfs)
    # Create 'answer_1' column if it does not exist yet
    if 'answer_1' not in df.columns:
        df['answer_1'] = np.nan

    total_examples = len(df)
    df['is_rewrite'] = ~df['answer_1'].isna()
    total_rewrites = df['is_rewrite'].sum()
    
    total_originally_correct = df['answer_0_correct'].sum()
    total_originally_incorrect = (~df['answer_0_correct']).sum()
    
    total_incorrect_rewrites = (df['is_rewrite'] & ~df['answer_0_correct']).sum()
    total_correct_rewrites = (df['is_rewrite'] & df['answer_0_correct']).sum()
    

    # Get rewrite correction accuracy
    df_rewrites = df[~df['answer_1'].isna()]
    df_rewrite_conversions = df_rewrites.groupby(['answer_0_correct','correct']).size()
    correct_to_wrong_perc = df_rewrite_conversions[True][False] / df_rewrite_conversions[True].sum()
    wrong_to_correct_perc = df_rewrite_conversions[False][True] / df_rewrite_conversions[False].sum()    

    return {
        'pre_rewrite_acc': df['answer_0_correct'].mean(),
        'Rewrite/Total': total_rewrites/ total_examples, # Percent of all examples that were rewritten
        'Rewrite Incorrect/All Rewrite': total_incorrect_rewrites / total_rewrites, # Percent of all rewrites that were of incorrect answers   
        'Rewrite Correct/All Correct':  total_correct_rewrites / total_originally_correct, # Percent of all correct answers that were rewritten
        'Rewrite Incorrect/All Incorrect': total_incorrect_rewrites / total_originally_incorrect, # Percent of all incorrect answers that were rewritten
        'Correct To Incorrect': correct_to_wrong_perc, # Percent of initially correct rewrites than then become wrong 
        'Incorrect To Correct': wrong_to_correct_perc, # Percent of initially wrong rewrites than then become correct
        'All correct': total_originally_correct,
        'All incorrect' :   total_originally_incorrect,
        'Correct Rewrites': total_correct_rewrites,
        'Incorrect Rewrites': total_incorrect_rewrites,
        'Total Rewrites': total_rewrites,
        'Total Examples': total_examples
      
    }
    

In [122]:
dfs = []
for task in use_tasks:
    for dir in use_dirs:
        print(f"Task: {task}, Run: {os.path.join(results_dir, task, dir, 'results.json')}")
        # try to open results.json from each directory, if it exists
        try:
            with open(os.path.join(results_dir, task, dir, "results.json"), "r") as f:
                data_dict = json.load(f)
                metrics_dict = {
                    "Task": data_dict["Task"],
                    "Run": dir.split("/")[-1],
                    # "N examples": data_dict["Number of examples"],
                    # "Number of correct": data_dict["Number of correct"],
                    "Accuracy": data_dict["Accuracy"],
                }
                metrics_dict |= extract_metrics(data_dict['Examples'])
                dfs.append(pd.DataFrame(metrics_dict, index=[f"{task}_{dir.split('/')[-1]}"]))
        except FileNotFoundError:
            continue
    
df = pd.concat(dfs)
df = df.set_index(["Run","Task"])
df.sort_index(inplace=True) 
df.rename(index=task_name_mapping, columns={'pre_rewrite_acc':'Pre-Rewrite Accuracy'},inplace=True)
# Rename index "Run" using label_mapping (join on label)
df = df.copy().reset_index()
df = df.merge(label_mapping[['label','Experiment']], left_on='Run', right_on='label', how='left')
df.set_index(['Experiment','Task'], inplace=True)
del df['label']
del df['Run']
pd.set_option('display.float_format', '{:.3f}'.format)
df

Task: gsm8k, Run: ./.archive/results/stage_6_samples_final\gsm8k\SolveValidateRewrite/gpt35_all_instruct__stg4\results.json
Task: gsm8k, Run: ./.archive/results/stage_6_samples_final\gsm8k\SolveValidateRewrite/gpt35_all_instruct_structured__stg6\results.json
Task: tracking_shuffled_objects_five_objects_multi, Run: ./.archive/results/stage_6_samples_final\tracking_shuffled_objects_five_objects_multi\SolveValidateRewrite/gpt35_all_instruct__stg4\results.json
Task: tracking_shuffled_objects_five_objects_multi, Run: ./.archive/results/stage_6_samples_final\tracking_shuffled_objects_five_objects_multi\SolveValidateRewrite/gpt35_all_instruct_structured__stg6\results.json
Task: prontoqa, Run: ./.archive/results/stage_6_samples_final\prontoqa\SolveValidateRewrite/gpt35_all_instruct__stg4\results.json
Task: prontoqa, Run: ./.archive/results/stage_6_samples_final\prontoqa\SolveValidateRewrite/gpt35_all_instruct_structured__stg6\results.json
Task: logiqa-en, Run: ./.archive/results/stage_6_sample

Unnamed: 0_level_0,Unnamed: 1_level_0,Accuracy,Pre-Rewrite Accuracy,Rewrite/Total,Rewrite Incorrect/All Rewrite,Rewrite Correct/All Correct,Rewrite Incorrect/All Incorrect,Correct To Incorrect,Incorrect To Correct,All correct,All incorrect,Correct Rewrites,Incorrect Rewrites,Total Rewrites,Total Examples
Experiment,Task,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
Instruction,AQuA,0.531,0.551,0.587,0.591,0.436,0.772,0.443,0.25,140,114,61,88,149,254
Instruction,GSM8K,0.72,0.743,0.537,0.286,0.516,0.597,0.183,0.304,223,77,115,46,161,300
Instruction,Deduct5,0.46,0.497,0.383,0.504,0.383,0.384,0.316,0.121,149,151,57,58,115,300
Instruction,LogiQA,0.323,0.337,0.52,0.692,0.475,0.543,0.375,0.13,101,199,48,108,156,300
Instruction,LSAT,0.235,0.252,0.439,0.792,0.362,0.465,0.619,0.113,58,172,21,80,101,230
Instruction,Nav,0.67,0.673,0.46,0.406,0.406,0.571,0.256,0.357,202,98,82,56,138,300
Instruction,ProntoQA,0.837,0.873,0.367,0.127,0.366,0.368,0.208,0.643,262,38,96,14,110,300
Instruction,Track5,0.7,0.71,0.21,0.349,0.192,0.253,0.317,0.455,213,87,41,22,63,300
Structured,AQuA,0.594,0.587,0.35,0.674,0.195,0.571,0.31,0.183,149,105,29,60,89,254
Structured,GSM8K,0.745,0.745,0.09,0.667,0.04,0.235,0.167,0.083,149,51,6,12,18,200


In [51]:
pd.set_option('display.precision', 3)
pd.set_option('display.float_format', '{:.1f}'.format)
df*100

Unnamed: 0_level_0,Unnamed: 1_level_0,Accuracy,Pre-Rewrite Accuracy,Rewrite/Total,Rewrite Incorrect/All Rewrite,Rewrite Correct/All Correct,Rewrite Incorrect/All Incorrect,Correct To Incorrect,Incorrect To Correct
Experiment,Task,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
Instruction,AQuA,53.1,55.1,58.7,59.1,43.6,77.2,44.3,25.0
Instruction,GSM8K,72.0,74.3,53.7,28.6,51.6,59.7,18.3,30.4
Instruction,Deduct5,46.0,49.7,38.3,50.4,38.3,38.4,31.6,12.1
Instruction,LogiQA,32.3,33.7,52.0,69.2,47.5,54.3,37.5,13.0
Instruction,LSAT,23.5,25.2,43.9,79.2,36.2,46.5,61.9,11.2
Instruction,Nav,67.0,67.3,46.0,40.6,40.6,57.1,25.6,35.7
Instruction,ProntoQA,83.7,87.3,36.7,12.7,36.6,36.8,20.8,64.3
Instruction,Track5,70.0,71.0,21.0,34.9,19.2,25.3,31.7,45.5
Structured,AQuA,59.4,58.7,35.0,67.4,19.5,57.1,31.0,18.3
Structured,GSM8K,74.5,74.5,9.0,66.7,4.0,23.5,16.7,8.3


In [79]:
structured_df = df[df.index.get_level_values(0).str.contains("Structured")]
# REmove Experiment index
structured_df.index = structured_df.index.droplevel(0)
structured_df = (structured_df*100).T
print(structured_df.to_latex(float_format="{:0.1f}".format))
structured_df

\begin{tabular}{lrrrrrrrr}
\toprule
Task & AQuA & GSM8K & Deduct5 & LogiQA & LSAT & Nav & ProntoQA & Track5 \\
\midrule
Accuracy & 59.4 & 74.5 & 47.0 & 37.0 & 24.5 & 70.0 & 89.5 & 69.0 \\
Pre-Rewrite Accuracy & 58.7 & 74.5 & 50.0 & 36.0 & 26.0 & 67.0 & 91.5 & 70.0 \\
Rewrite/Total & 35.0 & 9.0 & 17.0 & 16.5 & 19.0 & 56.0 & 13.0 & 16.0 \\
Rewrite Incorrect/All Rewrite & 67.4 & 66.7 & 51.0 & 78.8 & 65.8 & 35.1 & 11.5 & 34.4 \\
Rewrite Correct/All Correct & 19.5 & 4.0 & 16.7 & 9.7 & 25.0 & 54.2 & 12.6 & 15.0 \\
Rewrite Incorrect/All Incorrect & 57.1 & 23.5 & 17.3 & 20.3 & 16.9 & 59.6 & 17.6 & 18.3 \\
Correct To Incorrect & 31.0 & 16.7 & 52.0 & 28.6 & 61.5 & 14.7 & 26.1 & 38.1 \\
Incorrect To Correct & 18.3 & 8.3 & 15.4 & 15.4 & 20.0 & 42.4 & 66.7 & 54.5 \\
\bottomrule
\end{tabular}



Task,AQuA,GSM8K,Deduct5,LogiQA,LSAT,Nav,ProntoQA,Track5
Accuracy,59.4,74.5,47.0,37.0,24.5,70.0,89.5,69.0
Pre-Rewrite Accuracy,58.7,74.5,50.0,36.0,26.0,67.0,91.5,70.0
Rewrite/Total,35.0,9.0,17.0,16.5,19.0,56.0,13.0,16.0
Rewrite Incorrect/All Rewrite,67.4,66.7,51.0,78.8,65.8,35.1,11.5,34.4
Rewrite Correct/All Correct,19.5,4.0,16.7,9.7,25.0,54.2,12.6,15.0
Rewrite Incorrect/All Incorrect,57.1,23.5,17.3,20.3,16.9,59.6,17.6,18.3
Correct To Incorrect,31.0,16.7,52.0,28.6,61.5,14.7,26.1,38.1
Incorrect To Correct,18.3,8.3,15.4,15.4,20.0,42.4,66.7,54.5


In [80]:
instruction_df = df[df.index.get_level_values(0).str.contains("Instruction")]
# REmove Experiment index
instruction_df.index = instruction_df.index.droplevel(0)
instruction_df = (instruction_df*100).T
print((instruction_df*100).T.to_latex(float_format="{:0.1f}".format))

\begin{tabular}{lrrrrrrrr}
\toprule
 & Accuracy & Pre-Rewrite Accuracy & Rewrite/Total & Rewrite Incorrect/All Rewrite & Rewrite Correct/All Correct & Rewrite Incorrect/All Incorrect & Correct To Incorrect & Incorrect To Correct \\
Task &  &  &  &  &  &  &  &  \\
\midrule
AQuA & 5315.0 & 5511.8 & 5866.1 & 5906.0 & 4357.1 & 7719.3 & 4426.2 & 2500.0 \\
GSM8K & 7200.0 & 7433.3 & 5366.7 & 2857.1 & 5157.0 & 5974.0 & 1826.1 & 3043.5 \\
Deduct5 & 4600.0 & 4966.7 & 3833.3 & 5043.5 & 3825.5 & 3841.1 & 3157.9 & 1206.9 \\
LogiQA & 3233.3 & 3366.7 & 5200.0 & 6923.1 & 4752.5 & 5427.1 & 3750.0 & 1296.3 \\
LSAT & 2347.8 & 2521.7 & 4391.3 & 7920.8 & 3620.7 & 4651.2 & 6190.5 & 1125.0 \\
Nav & 6700.0 & 6733.3 & 4600.0 & 4058.0 & 4059.4 & 5714.3 & 2561.0 & 3571.4 \\
ProntoQA & 8366.7 & 8733.3 & 3666.7 & 1272.7 & 3664.1 & 3684.2 & 2083.3 & 6428.6 \\
Track5 & 7000.0 & 7100.0 & 2100.0 & 3492.1 & 1924.9 & 2528.7 & 3170.7 & 4545.5 \\
\bottomrule
\end{tabular}



In [94]:
combined_df = pd.concat([structured_df, instruction_df], axis=0, keys=['Structured', 'Instruction'])
# combined_df.swaplevel(0, 1, axis=0).sort_index(axis=0)
custom_order = {label: idx for idx, label in enumerate(list(instruction_df.index))}
combined_df = combined_df.swaplevel(0, 1, axis=0).sort_index(axis=0)

print(combined_df.to_latex(float_format="{:0.1f}".format))
combined_df

\begin{tabular}{llrrrrrrrr}
\toprule
 & Task & AQuA & GSM8K & Deduct5 & LogiQA & LSAT & Nav & ProntoQA & Track5 \\
\midrule
\multirow[t]{2}{*}{Accuracy} & Instruction & 53.1 & 72.0 & 46.0 & 32.3 & 23.5 & 67.0 & 83.7 & 70.0 \\
 & Structured & 59.4 & 74.5 & 47.0 & 37.0 & 24.5 & 70.0 & 89.5 & 69.0 \\
\cline{1-10}
\multirow[t]{2}{*}{Correct To Incorrect} & Instruction & 44.3 & 18.3 & 31.6 & 37.5 & 61.9 & 25.6 & 20.8 & 31.7 \\
 & Structured & 31.0 & 16.7 & 52.0 & 28.6 & 61.5 & 14.7 & 26.1 & 38.1 \\
\cline{1-10}
\multirow[t]{2}{*}{Incorrect To Correct} & Instruction & 25.0 & 30.4 & 12.1 & 13.0 & 11.2 & 35.7 & 64.3 & 45.5 \\
 & Structured & 18.3 & 8.3 & 15.4 & 15.4 & 20.0 & 42.4 & 66.7 & 54.5 \\
\cline{1-10}
\multirow[t]{2}{*}{Pre-Rewrite Accuracy} & Instruction & 55.1 & 74.3 & 49.7 & 33.7 & 25.2 & 67.3 & 87.3 & 71.0 \\
 & Structured & 58.7 & 74.5 & 50.0 & 36.0 & 26.0 & 67.0 & 91.5 & 70.0 \\
\cline{1-10}
\multirow[t]{2}{*}{Rewrite Correct/All Correct} & Instruction & 43.6 & 51.6 & 38.3 & 47.5

Unnamed: 0,Task,AQuA,GSM8K,Deduct5,LogiQA,LSAT,Nav,ProntoQA,Track5
Accuracy,Instruction,53.1,72.0,46.0,32.3,23.5,67.0,83.7,70.0
Accuracy,Structured,59.4,74.5,47.0,37.0,24.5,70.0,89.5,69.0
Correct To Incorrect,Instruction,44.3,18.3,31.6,37.5,61.9,25.6,20.8,31.7
Correct To Incorrect,Structured,31.0,16.7,52.0,28.6,61.5,14.7,26.1,38.1
Incorrect To Correct,Instruction,25.0,30.4,12.1,13.0,11.2,35.7,64.3,45.5
Incorrect To Correct,Structured,18.3,8.3,15.4,15.4,20.0,42.4,66.7,54.5
Pre-Rewrite Accuracy,Instruction,55.1,74.3,49.7,33.7,25.2,67.3,87.3,71.0
Pre-Rewrite Accuracy,Structured,58.7,74.5,50.0,36.0,26.0,67.0,91.5,70.0
Rewrite Correct/All Correct,Instruction,43.6,51.6,38.3,47.5,36.2,40.6,36.6,19.2
Rewrite Correct/All Correct,Structured,19.5,4.0,16.7,9.7,25.0,54.2,12.6,15.0


In [61]:
df_means = (df.copy()*100).reset_index()
del df_means['Task']
df_means = df_means.groupby('Experiment').mean()
# Transpose table
df_means = df_means.T
df_means

Experiment,Instruction,Structured
Accuracy,56.0,58.9
Pre-Rewrite Accuracy,58.0,59.2
Rewrite/Total,43.8,22.7
Rewrite Incorrect/All Rewrite,46.8,51.3
Rewrite Correct/All Correct,39.2,19.6
Rewrite Incorrect/All Incorrect,49.4,28.8
Correct To Incorrect,34.0,33.6
Incorrect To Correct,29.6,30.1


In [62]:
# Print latex
tbl=df_means.to_latex(float_format="{:0.1f}".format, escape=False)
print(tbl)

\begin{tabular}{lrr}
\toprule
Experiment & Instruction & Structured \\
\midrule
Accuracy & 56.0 & 58.9 \\
Pre-Rewrite Accuracy & 58.0 & 59.2 \\
Rewrite/Total & 43.8 & 22.7 \\
Rewrite Incorrect/All Rewrite & 46.8 & 51.3 \\
Rewrite Correct/All Correct & 39.2 & 19.6 \\
Rewrite Incorrect/All Incorrect & 49.4 & 28.8 \\
Correct To Incorrect & 34.0 & 33.6 \\
Incorrect To Correct & 29.6 & 30.1 \\
\bottomrule
\end{tabular}

