In [1]:
import sys
import json
sys.path.append('..')
from src.common import *
from src.analysis.model_performances import *
from copy import deepcopy
from helpers import *
import pandas as pd

CONF_KEY = 'wilson'

def latex_table_mods(latex_table):
    return latex_table.replace('{lllllllllllll}','{l|ll|ll|ll|ll||ll|ll}').replace('${None}_{None}$', '---')


In [2]:
answer_type = TRUE_FALSE_ANSWER_TYPE #FREE_ANSWER#
score_key = ACCURACY_SCORE_KEY #F1_SCORE_KEY#'accuracy'
answer_type_ext = tf_answer_type(score_key = score_key)

ids_file_name = 'dataset_ids.test.pruned'  # None
save_main_dir = f'{STATISTICS_PATH}.{ids_file_name}'
stats_all = collect_stats_all(tf_answer_type(score_key = score_key), save_main_dir=save_main_dir)
print(len(stats_all))
plan_lengths = [1,10,19]

save_dir = os.path.join(save_main_dir, 'tables', 'by_models')
os.makedirs(save_dir, exist_ok=True)

# model_prompts_combos = [('small-models', SMALL_MODELS, PROMPT_TYPES), ('big-models', BIG_MODELS, ['few_shot_1', 'few_shot_5'])]
model_prompts_combos = [('all-models', PROMPT_MODEL_NAMES, ['few_shot_1', 'few_shot_5'])]

100%|██████████| 27648/27648 [00:01<00:00, 26451.62it/s]

3767





In [10]:
def to_df(results_all, plan_lengths, answer_type, models=PROMPT_MODEL_NAMES,
          prompt_types = PROMPT_TYPES,
          domain = ALL_DOMAINS_KEY, subs = WITHOUT_RANDOM_SUB):
    
    index = []
    data = []    
    for plan_length in plan_lengths:
        for ramifications in RAMIFICATION_TYPES:
            index.append((plan_length, TO_PRETTY.get(ramifications, ramifications)))
            data_columns = {}
            for model_name in models:
                for prompt_type in prompt_types:
                    res_obj = filter_single_selector(results_all, plan_length, ALL_QUESTION_CATEGORIES_KEY, ramifications, model_name, prompt_type, domain, answer_type, subs)
                    # print(res_obj)
                    if res_obj:
                        mean = res_obj['result']
                        sem = None
                        if res_obj['result_other']:
                            sem = res_obj['result_other'][CONF_KEY]
                        not_corrupted = res_obj['stats']['num_not_corrupted']
                        final_res = (mean, sem, not_corrupted)
                    else:
                        final_res = (None, None, None)
                    final_res = tuple([round(v*100, 2) if v else v for v in final_res ])
                    final_res = '${'+str(final_res[0])+'}_{'+str(final_res[1])+'}$'
                    data_columns[(TO_PRETTY.get(model_name, model_name), TO_PRETTY.get(prompt_type, prompt_type))] = final_res
            data.append(data_columns)
    return pd.DataFrame(data, index = index)

def to_df_by_category(results_all, answer_type,  
                      model_names = PROMPT_MODEL_NAMES,
                      prompt_types= PROMPT_TYPES,
                      ramifications = WITHOUT_RAMIFICATIONS,
                      domain = ALL_DOMAINS_KEY, 
                      subs = WITHOUT_RANDOM_SUB,
                      plan_length=19):

    index = []
    data = []    
    for question_category in QUESTION_CATEGORIES:
        index.append(question_category)
        data_columns = {}
        for model_name in model_names:
            for prompt_type in prompt_types:
                res_obj = filter_single_selector(results_all, plan_length, question_category, ramifications, model_name, prompt_type, domain, answer_type, subs)
                if res_obj:
                    mean = res_obj['result']
                    sem = None
                    if res_obj['result_other']:
                        sem = res_obj['result_other'][CONF_KEY]
                    not_corrupted = res_obj['stats']['num_not_corrupted']
                    final_res = (mean, sem, not_corrupted)
                else:
                    final_res = (None, None, None)
                final_res = tuple([round(v*100, 2) if v else v for v in final_res ])
                final_res = '${'+str(final_res[0])+'}_{'+str(final_res[1])+'}$'
                data_columns[(TO_PRETTY.get(model_name,model_name), TO_PRETTY.get(prompt_type,prompt_type))] = final_res
        data.append(data_columns)
    return pd.DataFrame(data, index = index)

def to_df_by_len_by_category(results_all, answer_type,  
                      model_names = PROMPT_MODEL_NAMES,
                      prompt_types= ['few_shot_1'],
                      ramifications = WITHOUT_RAMIFICATIONS,
                      domain = ALL_DOMAINS_KEY, 
                      subs = WITHOUT_RANDOM_SUB):

    index = []
    data = []    
    for plan_length in PLAN_LENGTHS:
        for question_category in QUESTION_CATEGORIES+[ALL_QUESTION_CATEGORIES_KEY]:
            index.append((plan_length, TO_PRETTY.get(question_category,question_category)))
            data_columns = {}
            for model_name in model_names:
                for prompt_type in prompt_types:
                    res_obj = filter_single_selector(results_all, plan_length, question_category, ramifications, model_name, prompt_type, domain, answer_type, subs)
                    if res_obj:
                        mean = res_obj['result']
                        sem = None
                        if res_obj['result_other']:
                            sem = res_obj['result_other'][CONF_KEY]
                        not_corrupted = res_obj['stats']['num_not_corrupted']
                        final_res = (mean, sem, not_corrupted)
                    else:
                        final_res = (None, None, None)
                    final_res = tuple([round(v*100, 2) if v else v for v in final_res ])
                    final_res = '${'+str(final_res[0])+'}_{'+str(final_res[1])+'}$'
                    data_columns[(TO_PRETTY.get(model_name,model_name), TO_PRETTY.get(prompt_type,prompt_type))] = final_res
            data.append(data_columns)
    return pd.DataFrame(data, index = index)

def df_to_latex_table(df):
        latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
        return latex_table[latex_table.find('\midrule'):].replace("bottomrule", 'crap').replace("\crap", '').replace("\end{tabular}", '')

def assemble_table(results_all, answer_type, domain, score_key=None):
    latex_table_all = ''
    with open('latex_table_template/top') as f:
        latex_table_all += f.read() + '\n'
    latex_table_all += '\n'.join([df_to_latex_table(to_df(results_all, answer_type, plan_length, domain)) for plan_length in PLAN_LENGTHS])
    with open('latex_table_template/bottom') as f:
        latex_table_all += f.read()
    
    caption = f'{answer_type}, {score_key} scores for {domain}'.replace('_', ' ')
    latex_table_all = latex_table_all.replace('REPLACE_CAPTION_KEY', caption)
    
    return latex_table_all


In [11]:
for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for model_save_name, model_names, prompt_types in model_prompts_combos:
        df = to_df(stats_all, plan_lengths, answer_type, prompt_types=prompt_types, models=model_names, subs=subs)
        print(df)
        
        caption_nl = f'performance of {model_save_name} on the test set, {subs}'.replace('_', ' ')
        latex_table = latex_table_mods(to_latex_table(df, caption_nl, label=model_save_name))
        save_key = f'all.{model_save_name}.{subs}'
        with open(os.path.join(save_dir, f'{save_key}.tex'), 'w') as f:
            f.write(latex_table)

                 (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
(1, W R)     ${46.25}_{2.38}$  ${36.78}_{2.64}$  ${51.55}_{2.39}$   
(1, W/O R)   ${46.59}_{2.39}$  ${36.24}_{2.64}$   ${51.4}_{2.39}$   
(10, W R)    ${44.76}_{2.37}$  ${34.72}_{2.59}$  ${51.03}_{2.38}$   
(10, W/O R)  ${43.46}_{2.36}$  ${33.31}_{2.58}$  ${50.44}_{2.38}$   
(19, W R)    ${44.97}_{2.41}$  ${29.97}_{2.66}$  ${49.12}_{2.42}$   
(19, W/O R)  ${44.66}_{2.41}$  ${29.79}_{2.68}$  ${48.74}_{2.43}$   

                 (G-7b, FS-5)      (L-7b, FS-1)      (L-7b, FS-5)  \
(1, W R)     ${54.31}_{2.75}$  ${47.73}_{2.41}$  ${53.47}_{3.44}$   
(1, W/O R)   ${52.83}_{2.74}$  ${48.45}_{2.39}$  ${52.72}_{3.44}$   
(10, W R)     ${59.68}_{2.7}$   ${48.7}_{2.41}$  ${55.38}_{3.85}$   
(10, W/O R)  ${57.94}_{2.71}$  ${46.97}_{2.37}$  ${52.53}_{3.89}$   
(19, W R)     ${56.26}_{2.9}$  ${47.03}_{2.46}$   ${52.8}_{5.05}$   
(19, W/O R)    ${56.2}_{2.9}$   ${48.2}_{2.42}$  ${51.21}_{5.07}$   

                (L-13b, FS-1)   

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


                 (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
(1, W R)     ${46.09}_{2.55}$  ${44.32}_{5.06}$  ${49.28}_{2.39}$   
(1, W/O R)   ${46.03}_{2.53}$  ${40.69}_{4.97}$  ${48.63}_{2.39}$   
(10, W R)    ${45.52}_{2.66}$   ${50.0}_{69.3}$  ${50.38}_{2.38}$   
(10, W/O R)  ${44.22}_{2.63}$   ${100.0}_{0.0}$   ${49.7}_{2.38}$   
(19, W R)    ${46.22}_{2.92}$   ${None}_{None}$  ${48.38}_{2.42}$   
(19, W/O R)  ${43.94}_{2.88}$   ${None}_{None}$  ${49.26}_{2.43}$   

                 (G-7b, FS-5)      (L-7b, FS-1)      (L-7b, FS-5)  \
(1, W R)     ${56.86}_{3.01}$  ${47.18}_{2.67}$  ${55.25}_{6.08}$   
(1, W/O R)   ${55.22}_{3.07}$  ${47.63}_{2.64}$   ${53.5}_{6.27}$   
(10, W R)     ${57.31}_{3.2}$  ${45.42}_{2.78}$   ${None}_{None}$   
(10, W/O R)  ${56.73}_{3.25}$  ${45.54}_{2.77}$   ${None}_{None}$   
(19, W R)     ${53.9}_{4.02}$   ${43.8}_{3.18}$   ${None}_{None}$   
(19, W/O R)   ${53.9}_{4.11}$  ${43.69}_{3.22}$   ${None}_{None}$   

                (L-13b, FS-1)   

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


# Plot By Category

In [12]:
plan_length = 19
for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for model_save_name, model_names, prompt_types in model_prompts_combos:
        df2 = to_df_by_category(stats_all, answer_type, model_names=model_names, prompt_types=prompt_types, subs=subs)
        print(df2)
        
        caption_nl = f'performance of {model_save_name} on the test set by categories, {subs}, pl-{plan_length}'
        save_key = f'by_categories.{model_save_name}.{subs}'
        
        latex_table_all = latex_table_mods(to_latex_table(df2, caption_nl, label=save_key))
        with open(os.path.join(save_dir, f'{save_key}.tex'), 'w') as f:
            f.write(latex_table_all)

                           (G-2b, FS-1)       (G-2b, FS-5)       (G-7b, FS-1)  \
object_tracking         ${40.57}_{4.2}$   ${28.61}_{4.54}$   ${47.16}_{4.26}$   
fluent_tracking        ${43.11}_{4.31}$   ${20.99}_{4.31}$   ${50.88}_{4.33}$   
state_tracking        ${47.73}_{14.76}$      ${0.0}_{0.0}$  ${59.09}_{14.53}$   
action_executability  ${53.75}_{10.93}$    ${36.0}_{13.3}$  ${51.28}_{11.09}$   
effects                ${49.36}_{5.53}$   ${37.44}_{6.41}$    ${46.3}_{5.54}$   
numerical_reasoning   ${46.25}_{10.93}$    ${40.0}_{12.4}$  ${38.75}_{10.68}$   
hallucination         ${50.63}_{11.02}$  ${59.18}_{13.76}$  ${56.96}_{10.92}$   

                           (G-7b, FS-5)       (L-7b, FS-1)       (L-7b, FS-5)  \
object_tracking        ${52.23}_{5.02}$   ${48.48}_{4.26}$    ${50.0}_{8.34}$   
fluent_tracking        ${53.64}_{5.28}$   ${42.02}_{4.27}$  ${45.65}_{10.18}$   
state_tracking         ${57.89}_{22.2}$  ${54.55}_{14.71}$    ${None}_{None}$   
action_executability   ${54

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


                           (G-2b, FS-1)     (G-2b, FS-5)       (G-7b, FS-1)  \
object_tracking        ${42.82}_{4.83}$  ${None}_{None}$   ${45.63}_{4.26}$   
fluent_tracking        ${41.29}_{5.48}$  ${None}_{None}$   ${50.69}_{4.34}$   
state_tracking          ${None}_{None}$  ${None}_{None}$  ${52.27}_{14.76}$   
action_executability  ${38.71}_{12.12}$  ${None}_{None}$    ${50.0}_{11.1}$   
effects                ${47.21}_{6.41}$  ${None}_{None}$   ${50.79}_{5.52}$   
numerical_reasoning   ${47.83}_{11.79}$  ${None}_{None}$  ${46.25}_{10.93}$   
hallucination         ${53.33}_{12.62}$  ${None}_{None}$  ${58.75}_{10.79}$   

                           (G-7b, FS-5)       (L-7b, FS-1)     (L-7b, FS-5)  \
object_tracking        ${50.87}_{6.46}$   ${40.47}_{5.21}$  ${None}_{None}$   
fluent_tracking         ${49.62}_{8.5}$    ${36.0}_{6.27}$  ${None}_{None}$   
state_tracking          ${None}_{None}$    ${None}_{None}$  ${None}_{None}$   
action_executability   ${61.9}_{20.77}$  ${43.75}_{

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


# By Category By Length

In [13]:
subs = WITHOUT_RANDOM_SUB
rams = WITHOUT_RAMIFICATIONS
# for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
#     for rams in [WITHOUT_RAMIFICATIONS, WITH_RAMIFICATIONS]:
model_names = PROMPT_MODEL_NAMES[::-1] + [('g-7b','tuned'),('L-7b','tuned')]
df3 = to_df_by_len_by_category(stats_all, answer_type, model_names=model_names, subs=subs, ramifications=rams)

caption_nl = f'performance of on the test set by categories, {subs}, {rams}'
save_key = f'by_plan_by_categories.{subs}.{rams}'

latex_table_all = latex_table_mods(to_latex_table(df3, caption_nl, label=save_key))
with open(os.path.join(save_dir, f'{save_key}.tex'), 'w') as f:
    f.write(latex_table_all)

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


In [14]:
df3

Unnamed: 0,"(gpt-4o, FS-1)","(Gemini, FS-1)","(L-13b, FS-1)","(L-7b, FS-1)","(G-7b, FS-1)","(G-2b, FS-1)","((g-7b, tuned), FS-1)","((L-7b, tuned), FS-1)"
"(1, Obj. Trk.)",${79.62}_{3.45}$,${69.33}_{3.94}$,${55.15}_{4.26}$,${47.14}_{4.27}$,${49.14}_{4.28}$,${40.42}_{4.21}$,${None}_{None}$,${None}_{None}$
"(1, Fl. Trk.)",${83.39}_{3.13}$,${68.75}_{3.9}$,${45.67}_{4.19}$,${42.91}_{4.16}$,${55.51}_{4.18}$,${47.78}_{4.21}$,${None}_{None}$,${None}_{None}$
"(1, St. Trk.)",${70.91}_{12.0}$,${65.45}_{12.57}$,${52.73}_{13.19}$,${47.27}_{13.19}$,${61.82}_{12.84}$,${50.91}_{13.21}$,${None}_{None}$,${None}_{None}$
"(1, Act. Exec.)",${79.75}_{8.86}$,${70.89}_{10.02}$,${55.0}_{10.9}$,${61.25}_{10.68}$,${50.0}_{10.96}$,${58.75}_{10.79}$,${None}_{None}$,${None}_{None}$
"(1, Eff.)",${59.75}_{5.39}$,${59.69}_{5.37}$,${49.06}_{5.48}$,${54.37}_{5.46}$,${46.52}_{5.5}$,${48.57}_{5.52}$,${None}_{None}$,${None}_{None}$
"(1, Num. Reas.)",${55.13}_{11.04}$,${52.5}_{10.94}$,${48.75}_{10.95}$,${48.75}_{10.95}$,${45.0}_{10.9}$,${50.0}_{10.96}$,${None}_{None}$,${None}_{None}$
"(1, Hall.)",${93.67}_{5.37}$,${82.5}_{8.33}$,${56.25}_{10.87}$,${58.75}_{10.79}$,${58.23}_{10.88}$,${52.5}_{10.94}$,${None}_{None}$,${None}_{None}$
"(1, All)",${76.31}_{2.04}$,${67.08}_{2.25}$,${50.59}_{2.39}$,${48.45}_{2.39}$,${51.4}_{2.39}$,${46.59}_{2.39}$,${None}_{None}$,${None}_{None}$
"(10, Obj. Trk.)",${77.26}_{3.53}$,${66.79}_{3.95}$,${55.62}_{4.18}$,${47.88}_{4.2}$,${50.09}_{4.2}$,${40.96}_{4.14}$,${None}_{None}$,${None}_{None}$
"(10, Fl. Trk.)",${81.78}_{3.26}$,${62.36}_{4.08}$,${46.49}_{4.2}$,${39.48}_{4.12}$,${50.75}_{4.24}$,${40.41}_{4.15}$,${None}_{None}$,${None}_{None}$
