In [83]:
import sys
import json
import os
sys.path.append('..')
from src.common import *
from src.analysis.model_performances import *
from copy import deepcopy

import pandas as pd

# Postfix string constant. Not referenced in the visible cells; presumably consumed
# by code elsewhere — TODO confirm it is still needed or remove it.
EVERYTHING_POSTFIX = 'everything'

In [84]:
def to_df(results_all, plan_lengths, answer_type,
          question_category = ALL_QUESTION_CATEGORIES_KEY, domain = ALL_DOMAINS_KEY, subs = WITHOUT_RANDOM_SUB):
    """Build a table of ``${mean}_{sem}$`` LaTeX cells from per-configuration results.

    One row is produced per (plan length, ramification setting) and one column per
    (model name, prompt type). Each cell holds the string ``${mean}_{sem}$`` where
    both numbers are multiplied by 100 and rounded to 2 decimals. When
    ``filter_single_selector`` finds nothing (or no SEM is available), the missing
    value is rendered as ``None``.

    Args:
        results_all: result records handed to ``filter_single_selector``.
        plan_lengths: iterable of plan lengths; each contributes one row per
            entry of ``RAMIFICATION_TYPES``.
        answer_type: answer-type selector forwarded to ``filter_single_selector``.
        question_category: question-category selector forwarded unchanged.
        domain: domain selector forwarded unchanged.
        subs: substitution selector forwarded unchanged.

    Returns:
        pandas.DataFrame whose index entries are ``(plan_length, pretty ramification
        label)`` tuples and whose columns are ``(model_name, prompt_type)`` tuples.
    """
    to_pretty = {
        WITH_RAMIFICATIONS: 'R',
        WITHOUT_RAMIFICATIONS: 'No R',
        'gemma-2b': 'G-2b',
        'llama2-7b-chat': 'L-7b',
        'llama2-13b-chat': 'L-13b',
        'few_shot_1': 'FS-1',
        'few_shot_3': 'FS-3',
        'few_shot_5': 'FS-5',
        'gemini': 'gemini'}

    def _as_percent(value):
        # Fractions -> percentages (2 decimals); falsy values (None, 0) pass through unchanged.
        return round(value * 100, 2) if value else value

    index = []
    data = []
    for plan_length in plan_lengths:
        for ramifications in RAMIFICATION_TYPES:
            index.append((plan_length, to_pretty[ramifications]))
            row = {}
            for model_name in PROMPT_MODEL_NAMES:
                for prompt_type in PROMPT_TYPES:
                    res_obj = filter_single_selector(results_all, plan_length, question_category, ramifications,
                                                     model_name, prompt_type, domain, answer_type, subs)
                    mean, sem = None, None
                    if res_obj:
                        mean = res_obj['result']
                        # An empty result_other means no standard error was recorded.
                        if res_obj['result_other']:
                            sem = res_obj['result_other']['sem']
                    # Renders as ${mean}_{sem}$, i.e. the SEM as a subscript in math mode.
                    row[(model_name, prompt_type)] = f'${{{_as_percent(mean)}}}_{{{_as_percent(sem)}}}$'
            data.append(row)
    return pd.DataFrame(data, index = index)

def df_to_latex_table(df):
    r"""Return only the body rows of ``df``'s LaTeX rendering.

    ``df`` is rendered with ``to_latex``. Everything from the first ``\midrule``
    onward is kept, and every ``\bottomrule`` and ``\end{tabular}`` is removed,
    so several bodies can be concatenated between a shared header and footer.
    If the rendering contains no ``\midrule``, the whole rendering is kept
    (instead of slicing from index -1, which would keep only the last character).

    Args:
        df: any object with a pandas-compatible ``to_latex`` method.

    Returns:
        str: the trimmed LaTeX fragment.
    """
    latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
    # Raw strings: '\m' and '\e' are invalid Python escapes (SyntaxWarning on 3.12+).
    start = latex_table.find(r'\midrule')
    if start != -1:
        latex_table = latex_table[start:]
    return latex_table.replace(r'\bottomrule', '').replace(r'\end{tabular}', '')

def assemble_table(results_all, answer_type, domain, score_key=None, plan_lengths=None):
    """Assemble a complete LaTeX table source with one sub-table per plan length.

    The per-plan-length bodies (from ``to_df`` + ``df_to_latex_table``) are wrapped
    between the contents of ``latex_table_template/top`` and
    ``latex_table_template/bottom``. These paths are relative to the current working
    directory. Every ``REPLACE_CAPTION_KEY`` in the result is then replaced by a
    caption built from ``answer_type``, ``score_key`` and ``domain``, with
    underscores turned into spaces.

    Args:
        results_all: result records forwarded to ``to_df``.
        answer_type: answer-type selector forwarded to ``to_df``.
        domain: domain selector forwarded to ``to_df`` (as ``domain=``).
        score_key: only used in the caption text.
        plan_lengths: plan lengths to tabulate; defaults to ``PLAN_LENGTHS``.

    Returns:
        str: the assembled LaTeX source.
    """
    if plan_lengths is None:
        plan_lengths = PLAN_LENGTHS

    latex_table_all = ''
    with open('latex_table_template/top', encoding='utf-8') as f:
        latex_table_all += f.read() + '\n'
    # to_df expects (results_all, plan_lengths, answer_type, ..., domain=...):
    # wrap the single plan length in a list and pass domain by keyword so it is not
    # mistaken for question_category.
    latex_table_all += '\n'.join(
        df_to_latex_table(to_df(results_all, [plan_length], answer_type, domain=domain))
        for plan_length in plan_lengths)
    with open('latex_table_template/bottom', encoding='utf-8') as f:
        latex_table_all += f.read()

    caption = f'{answer_type}, {score_key} scores for {domain}'.replace('_', ' ')
    return latex_table_all.replace('REPLACE_CAPTION_KEY', caption)


In [85]:
# Evaluation configuration for this run.
answer_type = TRUE_FALSE_ANSWER_TYPE  # alternative: FREE_ANSWER
score_key = ACCURACY_SCORE_KEY  # alternative: F1_SCORE_KEY
answer_type_ext = tf_answer_type(score_key = score_key)

ids_file_name = 'dataset_ids.test'  # None
save_main_dir = f'{STATISTICS_PATH}.{ids_file_name}'
# Reuse answer_type_ext instead of calling tf_answer_type a second time with the same argument.
stats_all = collect_stats_all(answer_type_ext, save_main_dir=save_main_dir)
print(len(stats_all))
plan_lengths = [1,10,19]

subs = WITH_RANDOM_SUB
df = to_df(stats_all, plan_lengths, answer_type, subs=subs)

100%|██████████| 34560/34560 [00:03<00:00, 11042.78it/s]


10271


In [86]:
# Show the results table: rows are (plan length, ramification label), columns are (model, prompt type).
df

Unnamed: 0,"(gemma-2b, few_shot_1)","(gemma-2b, few_shot_3)","(gemma-2b, few_shot_5)","(llama2-7b-chat, few_shot_1)","(llama2-7b-chat, few_shot_3)","(llama2-7b-chat, few_shot_5)","(llama2-13b-chat, few_shot_1)","(llama2-13b-chat, few_shot_3)","(llama2-13b-chat, few_shot_5)","(gemini, few_shot_1)","(gemini, few_shot_3)","(gemini, few_shot_5)"
"(1, R)",${44.2}_{1.29}$,${33.57}_{1.2}$,${31.03}_{1.36}$,${48.05}_{1.31}$,${43.52}_{1.37}$,${58.13}_{2.71}$,${53.14}_{1.35}$,${55.48}_{1.77}$,${54.52}_{2.73}$,${65.17}_{1.24}$,${None}_{None}$,${None}_{None}$
"(1, No R)",${44.87}_{1.3}$,${35.88}_{1.27}$,${32.86}_{1.4}$,${50.49}_{1.28}$,${47.07}_{1.78}$,${50.66}_{2.57}$,${51.93}_{1.3}$,${54.58}_{1.79}$,${51.19}_{2.57}$,${65.59}_{1.26}$,${None}_{None}$,${None}_{None}$
"(10, R)",${42.67}_{1.2}$,${35.04}_{1.18}$,${24.16}_{1.14}$,${45.78}_{1.23}$,${37.76}_{1.41}$,${45.71}_{5.95}$,${52.33}_{1.25}$,${54.75}_{2.02}$,${47.14}_{5.97}$,${61.78}_{1.18}$,${None}_{None}$,${None}_{None}$
"(10, No R)",${43.59}_{1.21}$,${36.65}_{1.19}$,${24.95}_{1.16}$,${47.27}_{1.2}$,${49.0}_{1.96}$,${53.09}_{5.54}$,${52.92}_{1.21}$,${52.93}_{1.99}$,${48.15}_{5.55}$,${60.47}_{1.2}$,${None}_{None}$,${None}_{None}$
"(19, R)",${44.98}_{1.42}$,${30.3}_{1.33}$,${9.25}_{1.0}$,${47.01}_{1.48}$,${24.55}_{2.18}$,${None}_{None}$,${52.81}_{1.5}$,${51.02}_{5.05}$,${None}_{None}$,${61.74}_{1.38}$,${None}_{None}$,${None}_{None}$
"(19, No R)",${45.6}_{1.41}$,${31.77}_{1.32}$,${8.35}_{0.94}$,${44.77}_{1.4}$,${47.31}_{3.86}$,${None}_{None}$,${53.27}_{1.42}$,${46.45}_{4.01}$,${None}_{None}$,${61.13}_{1.41}$,${None}_{None}$,${None}_{None}$


In [87]:
# Render the results table as LaTeX and save it wrapped in a table* environment.
# escape=False keeps the math-mode cells (${mean}_{sem}$) intact on every pandas
# version: pandas < 2.0 escapes by default, which previously had to be undone by hand.
# NOTE(review): column labels such as `few_shot_1` then contain bare underscores, which
# LaTeX rejects outside math mode unless the preamble handles them — confirm against the template.
latex_table = df.to_latex(index=True, escape=False)
caption_nl = 'performance of models on the test set'
save_key = f'models_test.{subs}'

# Raw strings throughout: `\e` and `\c` are invalid Python escapes (SyntaxWarning on 3.12+).
latex_table_all = r"""
\begin{table*}[h!]
\begin{adjustbox}{width=1.3\textwidth,center}
""" + latex_table + r"""
\end{adjustbox}
\caption{""" + caption_nl + r"""}
\end{table*}
"""
tables_dir = os.path.join(save_main_dir, 'tables')
os.makedirs(tables_dir, exist_ok=True)
with open(os.path.join(tables_dir, f'{save_key}.tex'), 'w', encoding='utf-8') as f:
    f.write(latex_table_all)

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
