In [8]:
import sys
import json
sys.path.append('..')
from src.common import *
from src.analysis.model_performances import *
from copy import deepcopy
from helpers import *
import pandas as pd


def latex_table_mods(latex_table):
    """Post-process a generated LaTeX table string.

    Two literal substitutions are applied, in this order:

    1. the plain 13-column spec ``{lllllllllllll}`` is swapped for one
       with vertical rules between column groups;
    2. cells rendered as ``${None}_{None}$`` are replaced by ``---``.

    Args:
        latex_table: LaTeX source of the table as a string.

    Returns:
        The table string with both substitutions applied.
    """
    substitutions = (
        ('{lllllllllllll}', '{l|ll|ll|ll|ll||ll|ll}'),
        ('${None}_{None}$', '---'),
    )
    patched = latex_table
    for needle, replacement in substitutions:
        patched = patched.replace(needle, replacement)
    return patched


In [9]:
# --- Configuration for the table-generation cells below ---------------------
# Answer type and score key used throughout this notebook.
answer_type = TRUE_FALSE_ANSWER_TYPE  # alternative noted by the author: FREE_ANSWER
score_key = ACCURACY_SCORE_KEY  # alternative noted by the author: F1_SCORE_KEY
# NOTE(review): answer_type_ext is not read by any cell visible here; the
# collect_stats_all call below rebuilds the same expression instead of reusing it.
answer_type_ext = tf_answer_type(score_key = score_key)

# The ids file name is baked into the statistics directory name.
ids_file_name = 'dataset_ids.test.pruned'  # None
save_main_dir = f'{STATISTICS_PATH}.{ids_file_name}'
# Collect every stats record for the chosen answer type / score key.
stats_all = collect_stats_all(tf_answer_type(score_key = score_key), save_main_dir=save_main_dir)
print(len(stats_all))
# Plan lengths that become row groups in the overall table.
plan_lengths = [1,10,19]

# Output directory for the generated .tex files (created if missing).
# NOTE(review): `os` is never imported explicitly in this notebook; presumably
# it arrives through one of the star imports -- confirm.
save_dir = os.path.join(save_main_dir, 'tables', 'by_models')
os.makedirs(save_dir, exist_ok=True)

# Each entry is (name used in file names/captions, model names, prompt types).
# Earlier split configuration kept for reference:
# model_prompts_combos = [('small-models', SMALL_MODELS, PROMPT_TYPES), ('big-models', BIG_MODELS, ['few_shot_1', 'few_shot_5'])]
model_prompts_combos = [('all-models', PROMPT_MODEL_NAMES, ['few_shot_1', 'few_shot_5'])]

100%|██████████| 27648/27648 [00:04<00:00, 6388.18it/s]


16721


In [10]:
def _format_result_cell(res_obj):
    """Render one result record as the LaTeX cell string ``${mean}_{sem}$``.

    Args:
        res_obj: Record returned by ``filter_single_selector``. A falsy value
            is treated as "no result". Otherwise ``res_obj['result']`` is used
            as the mean and ``res_obj['result_other']['sem']`` as the standard
            error whenever ``result_other`` is truthy.

    Returns:
        A string of the form ``${mean}_{sem}$`` where both numbers are
        multiplied by 100 and rounded to two decimals; a missing value is
        rendered as the literal ``None``.
    """
    mean = sem = None
    if res_obj:
        mean = res_obj['result']
        other = res_obj['result_other']
        if other:
            sem = other['sem']
    # Scale fractions to percentages; keep None so missing cells stay recognisable.
    mean_pct, sem_pct = (round(v * 100, 2) if v is not None else None for v in (mean, sem))
    return '${' + str(mean_pct) + '}_{' + str(sem_pct) + '}$'


def _column_key(model_name, prompt_type):
    """Return the (model label, prompt label) column key, prettified via TO_PRETTY."""
    return (TO_PRETTY.get(model_name, model_name), TO_PRETTY.get(prompt_type, prompt_type))


def to_df(results_all, plan_lengths, answer_type, models=PROMPT_MODEL_NAMES,
          prompt_types = PROMPT_TYPES,
          domain = ALL_DOMAINS_KEY, subs = WITHOUT_RANDOM_SUB):
    """Build the overall results table, aggregated over all question categories.

    One row is produced per (plan length, ramification type) pair and one
    column per (model, prompt type) pair. Labels are looked up in
    ``TO_PRETTY`` and fall back to the raw key.

    Args:
        results_all: Collection of stats records passed to ``filter_single_selector``.
        plan_lengths: Iterable of plan lengths to include as row groups.
        answer_type: Answer type selector forwarded to ``filter_single_selector``.
        models: Model names to include as columns.
        prompt_types: Prompt types to include as columns.
        domain: Domain selector forwarded to ``filter_single_selector``.
        subs: Substitution selector forwarded to ``filter_single_selector``.

    Returns:
        ``pandas.DataFrame`` whose cells are ``${mean}_{sem}$`` strings.
    """
    index = []
    data = []
    for plan_length in plan_lengths:
        for ramifications in RAMIFICATION_TYPES:
            index.append((plan_length, TO_PRETTY.get(ramifications, ramifications)))
            row = {}
            for model_name in models:
                for prompt_type in prompt_types:
                    res_obj = filter_single_selector(
                        results_all, plan_length, ALL_QUESTION_CATEGORIES_KEY, ramifications,
                        model_name, prompt_type, domain, answer_type, subs)
                    row[_column_key(model_name, prompt_type)] = _format_result_cell(res_obj)
            data.append(row)
    return pd.DataFrame(data, index = index)


def to_df_by_category(results_all, answer_type,
                      model_names = PROMPT_MODEL_NAMES,
                      prompt_types= PROMPT_TYPES,
                      ramifications = WITHOUT_RAMIFICATIONS,
                      domain = ALL_DOMAINS_KEY,
                      subs = WITHOUT_RANDOM_SUB,
                      plan_length=19):
    """Build the per-question-category results table for a single plan length.

    One row is produced per entry of ``QUESTION_CATEGORIES`` and one column per
    (model, prompt type) pair. Column labels are looked up in ``TO_PRETTY``
    and fall back to the raw key.

    Args:
        results_all: Collection of stats records passed to ``filter_single_selector``.
        answer_type: Answer type selector forwarded to ``filter_single_selector``.
        model_names: Model names to include as columns.
        prompt_types: Prompt types to include as columns.
        ramifications: Ramification selector forwarded to ``filter_single_selector``.
        domain: Domain selector forwarded to ``filter_single_selector``.
        subs: Substitution selector forwarded to ``filter_single_selector``.
        plan_length: Plan length to report (defaults to 19).

    Returns:
        ``pandas.DataFrame`` whose cells are ``${mean}_{sem}$`` strings.
    """
    index = []
    data = []
    for question_category in QUESTION_CATEGORIES:
        index.append(question_category)
        row = {}
        for model_name in model_names:
            for prompt_type in prompt_types:
                res_obj = filter_single_selector(
                    results_all, plan_length, question_category, ramifications,
                    model_name, prompt_type, domain, answer_type, subs)
                row[_column_key(model_name, prompt_type)] = _format_result_cell(res_obj)
        data.append(row)
    return pd.DataFrame(data, index = index)

def df_to_latex_table(df):
    r"""Return only the body rows of ``df`` rendered as LaTeX.

    The frame is rendered with ``df.to_latex`` and everything before the first
    ``\midrule`` is dropped, so several bodies can be stacked under one shared
    header. The ``\bottomrule`` and ``\end{tabular}`` markers are removed from
    the remainder.

    Args:
        df: Object exposing a pandas-style ``to_latex`` method.

    Returns:
        The LaTeX text starting at ``\midrule``. If the rendering contains no
        ``\midrule`` the whole rendering is kept (minus the two markers)
        instead of being truncated.
    """
    latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
    # str.find returns -1 when absent; slicing from -1 would keep only the last character.
    body_start = max(latex_table.find(r'\midrule'), 0)
    body = latex_table[body_start:]
    # Raw strings avoid the invalid escape sequences \m, \c and \e.
    return body.replace(r'\bottomrule', '').replace(r'\end{tabular}', '')

def assemble_table(results_all, answer_type, domain, score_key=None):
    """Assemble a full LaTeX table from the template files and per-plan-length bodies.

    The result is the content of ``latex_table_template/top``, then one table
    body per entry of ``PLAN_LENGTHS`` (joined by newlines), then the content
    of ``latex_table_template/bottom``. The placeholder ``REPLACE_CAPTION_KEY``
    is replaced by a caption built from ``answer_type``, ``score_key`` and
    ``domain`` with underscores turned into spaces.

    Args:
        results_all: Collection of stats records forwarded to ``to_df``.
        answer_type: Answer type selector forwarded to ``to_df``.
        domain: Domain selector forwarded to ``to_df``.
        score_key: Score name used in the caption only.

    Returns:
        The assembled LaTeX table as a string.

    Raises:
        OSError: If either template file cannot be opened.
    """
    latex_table_all = ''
    with open('latex_table_template/top') as f:
        latex_table_all += f.read() + '\n'
    # to_df takes (results_all, plan_lengths, answer_type, ...): pass each plan
    # length as a one-element list and the domain by keyword so the arguments
    # land in the intended parameters.
    # NOTE(review): PLAN_LENGTHS is not defined in the visible cells; presumably
    # it comes from a star import -- confirm.
    bodies = [
        df_to_latex_table(to_df(results_all, [plan_length], answer_type, domain=domain))
        for plan_length in PLAN_LENGTHS
    ]
    latex_table_all += '\n'.join(bodies)
    with open('latex_table_template/bottom') as f:
        latex_table_all += f.read()

    caption = f'{answer_type}, {score_key} scores for {domain}'.replace('_', ' ')
    latex_table_all = latex_table_all.replace('REPLACE_CAPTION_KEY', caption)

    return latex_table_all


In [11]:
# Write one overall table per (substitution setting, model group) combination.
for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for model_save_name, model_names, prompt_types in model_prompts_combos:
        # Rows: (plan length, ramification type); columns: (model, prompt type).
        df = to_df(stats_all, plan_lengths, answer_type, prompt_types=prompt_types, models=model_names, subs=subs)
        print(df)
        
        # Underscores are replaced by spaces in the caption text.
        caption_nl = f'performance of {model_save_name} on the test set, {subs}'.replace('_', ' ')
        # NOTE(review): to_latex_table is not defined in this notebook; presumably
        # it is provided by the `helpers` star import -- confirm.
        latex_table = latex_table_mods(to_latex_table(df, caption_nl, label=model_save_name))
        # File name pattern: all.<model group>.<subs>.tex inside save_dir.
        save_key = f'all.{model_save_name}.{subs}'
        with open(os.path.join(save_dir, f'{save_key}.tex'), 'w') as f:
            f.write(latex_table)

                 (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
(1, W R)     ${46.25}_{1.22}$  ${36.78}_{1.34}$  ${51.55}_{1.22}$   
(1, W/O R)   ${46.59}_{1.22}$  ${36.24}_{1.35}$   ${51.4}_{1.22}$   
(10, W R)    ${44.76}_{1.21}$  ${34.72}_{1.32}$  ${51.03}_{1.21}$   
(10, W/O R)  ${43.46}_{1.21}$  ${33.31}_{1.32}$  ${50.44}_{1.22}$   
(19, W R)    ${44.97}_{1.23}$  ${29.97}_{1.36}$  ${49.12}_{1.23}$   
(19, W/O R)  ${44.66}_{1.23}$  ${29.79}_{1.37}$  ${48.74}_{1.24}$   

                 (G-7b, FS-5)      (L-7b, FS-1)      (L-7b, FS-5)  \
(1, W R)      ${54.31}_{1.4}$  ${47.73}_{1.23}$  ${53.47}_{1.75}$   
(1, W/O R)    ${52.83}_{1.4}$  ${48.45}_{1.22}$  ${52.72}_{1.76}$   
(10, W R)    ${59.68}_{1.38}$   ${48.7}_{1.23}$  ${55.38}_{1.96}$   
(10, W/O R)  ${57.94}_{1.38}$  ${46.97}_{1.21}$  ${52.53}_{1.99}$   
(19, W R)    ${56.26}_{1.48}$  ${47.03}_{1.26}$   ${52.8}_{2.58}$   
(19, W/O R)   ${56.2}_{1.48}$   ${48.2}_{1.23}$  ${51.21}_{2.59}$   

                (L-13b, FS-1)   

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


                 (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
(1, W R)      ${46.09}_{1.3}$  ${44.32}_{2.58}$  ${49.28}_{1.22}$   
(1, W/O R)   ${46.03}_{1.29}$  ${40.69}_{2.53}$  ${48.63}_{1.22}$   
(10, W R)    ${45.52}_{1.36}$  ${50.0}_{35.36}$  ${50.38}_{1.21}$   
(10, W/O R)  ${44.22}_{1.34}$   ${100.0}_{0.0}$   ${49.7}_{1.22}$   
(19, W R)    ${46.22}_{1.49}$   ${None}_{None}$  ${48.38}_{1.24}$   
(19, W/O R)  ${43.94}_{1.47}$   ${None}_{None}$  ${49.26}_{1.24}$   

                 (G-7b, FS-5)      (L-7b, FS-1)     (L-7b, FS-5)  \
(1, W R)     ${56.86}_{1.53}$  ${47.18}_{1.36}$  ${55.25}_{3.1}$   
(1, W/O R)   ${55.22}_{1.57}$  ${47.63}_{1.35}$   ${53.5}_{3.2}$   
(10, W R)    ${57.31}_{1.63}$  ${45.42}_{1.42}$  ${None}_{None}$   
(10, W/O R)  ${56.73}_{1.66}$  ${45.54}_{1.41}$  ${None}_{None}$   
(19, W R)     ${53.9}_{2.05}$   ${43.8}_{1.62}$  ${None}_{None}$   
(19, W/O R)    ${53.9}_{2.1}$  ${43.69}_{1.64}$  ${None}_{None}$   

                (L-13b, FS-1)     (L-13

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


# Tables By Category

In [12]:
# Write one per-category table per (substitution setting, model group) combination.
plan_length = 19
for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for model_save_name, model_names, prompt_types in model_prompts_combos:
        # Pass plan_length explicitly so the data always matches the value
        # written into the caption, rather than relying on the function default.
        df2 = to_df_by_category(stats_all, answer_type, model_names=model_names,
                                prompt_types=prompt_types, subs=subs, plan_length=plan_length)
        print(df2)

        # Underscores are replaced by spaces, as the overall-table cell does for its caption.
        caption_nl = f'performance of {model_save_name} on the test set by categories, {subs}, pl-{plan_length}'.replace('_', ' ')
        # File name pattern: by_categories.<model group>.<subs>.tex inside save_dir;
        # the same key is used as the LaTeX label.
        save_key = f'by_categories.{model_save_name}.{subs}'

        latex_table_all = latex_table_mods(to_latex_table(df2, caption_nl, label=save_key))
        with open(os.path.join(save_dir, f'{save_key}.tex'), 'w') as f:
            f.write(latex_table_all)

                          (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
object_tracking       ${40.57}_{2.14}$  ${28.61}_{2.32}$  ${47.16}_{2.17}$   
fluent_tracking        ${43.11}_{2.2}$   ${20.99}_{2.2}$  ${50.88}_{2.21}$   
state_tracking        ${47.73}_{7.53}$     ${0.0}_{0.0}$  ${59.09}_{7.41}$   
action_executability  ${53.75}_{5.57}$   ${36.0}_{6.79}$  ${51.28}_{5.66}$   
effects               ${49.36}_{2.82}$  ${37.44}_{3.27}$   ${46.3}_{2.83}$   
numerical_reasoning   ${46.25}_{5.57}$   ${40.0}_{6.32}$  ${38.75}_{5.45}$   
hallucination         ${50.63}_{5.62}$  ${59.18}_{7.02}$  ${56.96}_{5.57}$   

                           (G-7b, FS-5)      (L-7b, FS-1)       (L-7b, FS-5)  \
object_tracking        ${52.23}_{2.56}$  ${48.48}_{2.17}$    ${50.0}_{4.26}$   
fluent_tracking        ${53.64}_{2.69}$  ${42.02}_{2.18}$   ${45.65}_{5.19}$   
state_tracking        ${57.89}_{11.33}$  ${54.55}_{7.51}$    ${None}_{None}$   
action_executability    ${54.0}_{7.05}$  ${46.25}_{5.57

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


                          (G-2b, FS-1)     (G-2b, FS-5)      (G-7b, FS-1)  \
object_tracking       ${42.82}_{2.46}$  ${None}_{None}$  ${45.63}_{2.17}$   
fluent_tracking        ${41.29}_{2.8}$  ${None}_{None}$  ${50.69}_{2.22}$   
state_tracking         ${None}_{None}$  ${None}_{None}$  ${52.27}_{7.53}$   
action_executability  ${38.71}_{6.19}$  ${None}_{None}$   ${50.0}_{5.66}$   
effects               ${47.21}_{3.27}$  ${None}_{None}$  ${50.79}_{2.82}$   
numerical_reasoning   ${47.83}_{6.01}$  ${None}_{None}$  ${46.25}_{5.57}$   
hallucination         ${53.33}_{6.44}$  ${None}_{None}$   ${58.75}_{5.5}$   

                          (G-7b, FS-5)      (L-7b, FS-1)     (L-7b, FS-5)  \
object_tracking        ${50.87}_{3.3}$  ${40.47}_{2.66}$  ${None}_{None}$   
fluent_tracking       ${49.62}_{4.34}$    ${36.0}_{3.2}$  ${None}_{None}$   
state_tracking         ${None}_{None}$   ${None}_{None}$  ${None}_{None}$   
action_executability   ${61.9}_{10.6}$  ${43.75}_{7.16}$  ${None}_{None}$  

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
