In [1]:
import sys
import json
sys.path.append('..')
from src.common import *
from src.analysis.model_performances import *
from copy import deepcopy
from helpers import *
import pandas as pd

# Suffix constant; not referenced in the cells visible here -- presumably consumed
# by helper code or by cells outside this view (TODO confirm).
EVERYTHING_POSTFIX = 'everything'

# Compact display labels for the generated tables. to_df / to_df_by_category look up
# ramification settings, model names and prompt types here, falling back to the raw
# key when there is no entry, to build the row and column headers.
TO_PRETTY = {
    # ramification settings
    WITH_RAMIFICATIONS : 'W R',
    WITHOUT_RAMIFICATIONS : 'W/O R',
    
    # prompt types
    'few_shot_1': 'FS-1',
    'few_shot_3': 'FS-3',
    'few_shot_5': 'FS-5',
    
    # model names
    'gemma-2b': 'G-2b', 
    'gemma-7b': 'G-7b', 
    'llama2-7b-chat': 'L-7b', 
    'llama2-13b-chat': 'L-13b',
    'gemini': 'Gemini'}


In [3]:
# --- Scoring setup -----------------------------------------------------------
# Alternatives noted by the author: FREE_ANSWER for the answer type and
# F1_SCORE_KEY for the score key.
answer_type = TRUE_FALSE_ANSWER_TYPE
score_key = ACCURACY_SCORE_KEY
answer_type_ext = tf_answer_type(score_key=score_key)

# --- Collected statistics ----------------------------------------------------
# The ids file name doubles as a suffix of the statistics directory.
ids_file_name = 'dataset_ids.test'  # None
save_main_dir = f'{STATISTICS_PATH}.{ids_file_name}'
stats_all = collect_stats_all(
    tf_answer_type(score_key=score_key),
    save_main_dir=save_main_dir,
)
print(len(stats_all))

# --- Table layout ------------------------------------------------------------
plan_lengths = [1, 10, 19]

# Every generated .tex file lands in <save_main_dir>/tables/by_models.
save_dir = os.path.join(save_main_dir, 'tables', 'by_models')
os.makedirs(save_dir, exist_ok=True)

# Each entry is (name used in file names/captions, model names, prompt types).
# Earlier split kept for reference:
#   [('small-models', SMALL_MODELS, PROMPT_TYPES),
#    ('big-models', BIG_MODELS, ['few_shot_1', 'few_shot_5'])]
model_prompts_combos = [
    ('all-models', PROMPT_MODEL_NAMES, ['few_shot_1', 'few_shot_5']),
]

100%|██████████| 31104/31104 [00:05<00:00, 5686.22it/s] 

12758





In [2]:
def to_df(results_all, plan_lengths, answer_type, models=PROMPT_MODEL_NAMES,
          prompt_types=PROMPT_TYPES,
          domain=ALL_DOMAINS_KEY, subs=WITHOUT_RANDOM_SUB):
    """Build a table of formatted scores, one row per (plan length, ramification type).

    Every cell is looked up with ``filter_single_selector`` over all question
    categories (``ALL_QUESTION_CATEGORIES_KEY``) and rendered as the LaTeX math
    string ``${mean}_{sem}$``, where both values are multiplied by 100 and rounded
    to two decimals. A missing result, or a missing ``sem``, is rendered as ``None``.

    Args:
        results_all: collected results passed through to ``filter_single_selector``.
        plan_lengths: iterable of plan lengths; each yields one row per ramification type.
        answer_type: answer-type selector passed to ``filter_single_selector``.
        models: model names; each contributes one column per prompt type.
        prompt_types: prompt types to include.
        domain: domain selector.
        subs: substitution selector.

    Returns:
        pandas.DataFrame indexed by ``(plan_length, ramification label)`` with columns
        keyed by ``(model label, prompt label)``; labels are taken from ``TO_PRETTY``
        when present, otherwise the raw key is used.
    """
    row_labels = []
    rows = []
    for plan_length in plan_lengths:
        for ramifications in RAMIFICATION_TYPES:
            row_labels.append((plan_length, TO_PRETTY.get(ramifications, ramifications)))
            row = {}
            for model_name in models:
                for prompt_type in prompt_types:
                    hit = filter_single_selector(
                        results_all, plan_length, ALL_QUESTION_CATEGORIES_KEY,
                        ramifications, model_name, prompt_type, domain, answer_type, subs)
                    if not hit:
                        raw = (None, None, None)
                    else:
                        mean = hit['result']
                        extras = hit['result_other']
                        sem = extras['sem'] if extras else None
                        # Looked up (and scaled below) as in the original, although
                        # only mean and sem end up in the rendered cell.
                        raw = (mean, sem, hit['stats']['num_not_corrupted'])
                    # Falsy values (None, 0) are passed through unscaled.
                    scaled = [round(value * 100, 2) if value else value for value in raw]
                    column = (TO_PRETTY.get(model_name, model_name),
                              TO_PRETTY.get(prompt_type, prompt_type))
                    row[column] = '${%s}_{%s}$' % (scaled[0], scaled[1])
            rows.append(row)
    return pd.DataFrame(rows, index=row_labels)

def to_df_by_category(results_all, answer_type,
                      model_names=PROMPT_MODEL_NAMES,
                      prompt_types=PROMPT_TYPES,
                      ramifications=WITHOUT_RAMIFICATIONS,
                      domain=ALL_DOMAINS_KEY,
                      subs=WITHOUT_RANDOM_SUB,
                      plan_length=19):
    """Build a table of formatted scores with one row per question category.

    Rows cover every entry of ``QUESTION_CATEGORIES`` except the last one. Every
    cell is looked up with ``filter_single_selector`` for a single plan length and
    ramification setting and rendered as the LaTeX math string ``${mean}_{sem}$``,
    with both values multiplied by 100 and rounded to two decimals. A missing
    result, or a missing ``sem``, is rendered as ``None``.

    Args:
        results_all: collected results passed through to ``filter_single_selector``.
        answer_type: answer-type selector.
        model_names: model names; each contributes one column per prompt type.
        prompt_types: prompt types to include.
        ramifications: ramification selector shared by all cells.
        domain: domain selector.
        subs: substitution selector.
        plan_length: plan length shared by all cells.

    Returns:
        pandas.DataFrame indexed by question category with columns keyed by
        ``(model label, prompt label)``; labels come from ``TO_PRETTY`` when present.
    """
    def render(found):
        # Turn one lookup result into the '${mean}_{sem}$' cell text.
        if found:
            mean = found['result']
            sem = None
            if found['result_other']:
                sem = found['result_other']['sem']
            # Looked up (and scaled below) as in the original, although only
            # mean and sem end up in the rendered cell.
            triple = (mean, sem, found['stats']['num_not_corrupted'])
        else:
            triple = (None, None, None)
        # Falsy values (None, 0) are passed through unscaled.
        mean_pct, sem_pct, _ = tuple(round(x * 100, 2) if x else x for x in triple)
        return '${' + str(mean_pct) + '}_{' + str(sem_pct) + '}$'

    categories = []
    records = []
    for question_category in QUESTION_CATEGORIES[:-1]:
        categories.append(question_category)
        record = {}
        for model_name in model_names:
            pretty_model = TO_PRETTY.get(model_name, model_name)
            for prompt_type in prompt_types:
                found = filter_single_selector(
                    results_all, plan_length, question_category, ramifications,
                    model_name, prompt_type, domain, answer_type, subs)
                record[(pretty_model, TO_PRETTY.get(prompt_type, prompt_type))] = render(found)
        records.append(record)
    return pd.DataFrame(records, index=categories)

def df_to_latex_table(df):
    r"""Render ``df`` as the body rows of a LaTeX table.

    ``df.to_latex`` output is trimmed so that it starts at the first ``\midrule``,
    and the trailing ``\bottomrule`` and ``\end{tabular}`` markers are removed.
    This leaves a fragment that can be concatenated with other fragments inside
    one shared tabular environment.

    Args:
        df: object exposing a pandas-style ``to_latex`` method.

    Returns:
        str: the LaTeX fragment. When the rendered table contains no ``\midrule``,
        the whole rendering is kept (minus the two closing markers) instead of
        being truncated.
    """
    latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
    # str.find returns -1 when the marker is missing; slicing from -1 would keep
    # only the final character, so fall back to the start of the string.
    body_start = max(latex_table.find(r'\midrule'), 0)
    body = latex_table[body_start:]
    # Raw strings avoid the invalid "\m", "\c" and "\e" escape sequences. Only the
    # actual LaTeX commands are stripped; cell text that merely contains the word
    # "bottomrule" is left untouched.
    return body.replace(r'\bottomrule', '').replace(r'\end{tabular}', '')

def assemble_table(results_all, answer_type, domain, score_key=None):
    """Assemble a complete LaTeX table from the template files and per-plan-length rows.

    The result is the content of ``latex_table_template/top``, one body fragment
    per entry of ``PLAN_LENGTHS`` (built with ``to_df`` and ``df_to_latex_table``),
    and the content of ``latex_table_template/bottom``. The placeholder
    ``REPLACE_CAPTION_KEY`` is then replaced by a generated caption.

    Args:
        results_all: collected results forwarded to ``to_df``.
        answer_type: answer-type selector forwarded to ``to_df``; also used in the caption.
        domain: domain selector forwarded to ``to_df``; also used in the caption.
        score_key: score name shown in the caption (rendered as ``None`` when omitted).

    Returns:
        str: the assembled LaTeX source.

    Raises:
        OSError: if either template file cannot be opened.
    """
    latex_table_all = ''
    with open('latex_table_template/top') as f:
        latex_table_all += f.read() + '\n'

    # to_df takes (results_all, plan_lengths, answer_type, ...) and iterates over
    # plan_lengths, so each plan length is wrapped in a one-element list and the
    # domain is passed by keyword. The previous positional call put answer_type,
    # the plan length and the domain into the wrong parameters.
    # PLAN_LENGTHS is not defined in the visible cells -- presumably provided by
    # one of the star imports (TODO confirm).
    fragments = [
        df_to_latex_table(to_df(results_all, [plan_length], answer_type, domain=domain))
        for plan_length in PLAN_LENGTHS
    ]
    latex_table_all += '\n'.join(fragments)

    with open('latex_table_template/bottom') as f:
        latex_table_all += f.read()

    # Underscores are swapped for spaces in the caption text.
    caption = f'{answer_type}, {score_key} scores for {domain}'.replace('_', ' ')
    return latex_table_all.replace('REPLACE_CAPTION_KEY', caption)


In [6]:
# Write one LaTeX table per (substitution setting, model group) covering all
# plan lengths; the printed frame is only a preview of what goes into the file.
for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for model_save_name, model_names, prompt_types in model_prompts_combos:
        df = to_df(
            stats_all, plan_lengths, answer_type,
            prompt_types=prompt_types, models=model_names, subs=subs,
        )
        print(df)

        # Underscores are swapped for spaces in the caption text.
        caption_nl = f'performance of {model_save_name} on the test set, {subs}'.replace('_', ' ')
        # NOTE(review): the label is the model-group name only, so both `subs`
        # iterations produce tables with the same label -- confirm this is intended.
        latex_table = to_latex_table(df, caption_nl, label=model_save_name)

        save_key = f'all.{model_save_name}.{subs}'
        tex_path = os.path.join(save_dir, f'{save_key}.tex')
        with open(tex_path, 'w') as f:
            f.write(latex_table)

                 (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
(1, W R)     ${46.62}_{1.14}$  ${36.21}_{1.23}$  ${49.44}_{3.75}$   
(1, W/O R)   ${46.24}_{1.18}$  ${36.46}_{1.23}$  ${47.74}_{3.54}$   
(10, W R)    ${45.24}_{1.14}$  ${33.27}_{1.21}$   ${50.0}_{3.81}$   
(10, W/O R)  ${43.65}_{1.17}$   ${33.16}_{1.2}$  ${46.39}_{3.58}$   
(19, W R)    ${45.17}_{1.15}$  ${28.22}_{1.21}$   ${50.0}_{3.71}$   
(19, W/O R)  ${44.66}_{1.18}$  ${27.93}_{1.21}$   ${45.5}_{3.43}$   

                 (G-7b, FS-5)      (L-7b, FS-1)      (L-7b, FS-5)  \
(1, W R)     ${54.55}_{4.01}$  ${46.84}_{1.69}$  ${53.23}_{1.64}$   
(1, W/O R)    ${None}_{None}$  ${48.48}_{1.18}$   ${52.8}_{1.62}$   
(10, W R)    ${54.35}_{4.24}$  ${47.22}_{1.76}$   ${54.7}_{1.88}$   
(10, W/O R)   ${None}_{None}$  ${46.46}_{1.17}$  ${52.64}_{1.89}$   
(19, W R)    ${54.48}_{4.14}$   ${48.89}_{1.8}$  ${52.66}_{2.45}$   
(19, W/O R)   ${None}_{None}$  ${47.47}_{1.18}$  ${50.24}_{2.45}$   

                (L-13b, FS-1)   

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


                 (G-2b, FS-1)      (G-2b, FS-5)      (G-7b, FS-1)  \
(1, W R)     ${45.67}_{1.27}$  ${44.06}_{2.47}$  ${49.19}_{1.18}$   
(1, W/O R)    ${45.7}_{1.26}$  ${40.85}_{2.38}$  ${48.33}_{1.24}$   
(10, W R)    ${45.32}_{1.32}$  ${50.0}_{35.36}$  ${50.69}_{1.17}$   
(10, W/O R)   ${44.05}_{1.3}$   ${100.0}_{0.0}$  ${49.45}_{1.23}$   
(19, W R)    ${46.08}_{1.45}$   ${None}_{None}$  ${48.43}_{1.18}$   
(19, W/O R)  ${43.84}_{1.43}$   ${None}_{None}$  ${48.33}_{1.24}$   

                 (G-7b, FS-5)      (L-7b, FS-1)      (L-7b, FS-5)  \
(1, W R)     ${56.94}_{1.42}$  ${47.28}_{1.33}$  ${56.12}_{2.98}$   
(1, W/O R)   ${55.32}_{1.43}$  ${47.72}_{1.31}$  ${53.48}_{3.02}$   
(10, W R)    ${57.71}_{1.53}$  ${44.99}_{1.38}$   ${None}_{None}$   
(10, W/O R)  ${56.55}_{1.54}$  ${45.16}_{1.37}$   ${None}_{None}$   
(19, W R)    ${53.82}_{1.93}$  ${44.13}_{1.59}$   ${None}_{None}$   
(19, W/O R)  ${53.31}_{1.96}$  ${43.55}_{1.61}$   ${None}_{None}$   

                (L-13b, FS-1)   

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


# Tables By Category

In [7]:
# Per-question-category breakdown at a single plan length; one .tex file is
# written per (substitution setting, model group).
plan_length = 19
for subs in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for model_save_name, model_names, prompt_types in model_prompts_combos:
        # Pass plan_length explicitly: the caption advertises "pl-<plan_length>",
        # so the data must come from the same value rather than from the
        # function's own default.
        df2 = to_df_by_category(
            stats_all, answer_type,
            model_names=model_names, prompt_types=prompt_types,
            subs=subs, plan_length=plan_length,
        )
        print(df2)

        # Underscores are swapped for spaces, matching the caption built for the
        # by-model tables.
        caption_nl = (
            f'performance of {model_save_name} on the test set by categories, {subs}, pl-{plan_length}'
        ).replace('_', ' ')
        save_key = f'by_categories.{model_save_name}.{subs}'

        latex_table_all = to_latex_table(df2, caption_nl, label=save_key)
        with open(os.path.join(save_dir, f'{save_key}.tex'), 'w') as f:
            f.write(latex_table_all)

                          (G-2b, FS-1)      (G-2b, FS-5)       (G-7b, FS-1)  \
object_tracking       ${40.38}_{2.13}$  ${28.41}_{2.15}$    ${40.3}_{5.99}$   
fluent_tracking        ${43.2}_{1.97}$  ${19.58}_{1.82}$   ${49.28}_{6.02}$   
state_tracking        ${47.73}_{7.53}$     ${0.0}_{0.0}$    ${50.0}_{25.0}$   
action_executability  ${53.66}_{5.51}$  ${32.79}_{6.01}$  ${88.89}_{10.48}$   
effects               ${49.54}_{2.76}$  ${35.29}_{2.99}$   ${39.58}_{7.06}$   
numerical_reasoning   ${45.78}_{5.47}$  ${38.24}_{5.89}$   ${50.0}_{17.68}$   
hallucination         ${52.44}_{5.52}$  ${52.63}_{6.61}$  ${33.33}_{19.25}$   

                         (G-7b, FS-5)      (L-7b, FS-1)       (L-7b, FS-5)  \
object_tracking       ${None}_{None}$   ${48.3}_{2.17}$   ${49.65}_{4.18}$   
fluent_tracking       ${None}_{None}$   ${41.3}_{1.96}$   ${45.67}_{4.42}$   
state_tracking        ${None}_{None}$  ${54.55}_{7.51}$    ${None}_{None}$   
action_executability  ${None}_{None}$  ${46.34}_{5.51}$

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)


                          (G-2b, FS-1)     (G-2b, FS-5)      (G-7b, FS-1)  \
object_tracking       ${42.82}_{2.46}$  ${None}_{None}$   ${44.4}_{2.26}$   
fluent_tracking       ${41.44}_{2.55}$  ${None}_{None}$   ${48.76}_{2.1}$   
state_tracking         ${None}_{None}$  ${None}_{None}$  ${55.17}_{9.23}$   
action_executability  ${38.71}_{6.19}$  ${None}_{None}$  ${48.68}_{5.73}$   
effects               ${47.21}_{3.27}$  ${None}_{None}$  ${49.84}_{2.85}$   
numerical_reasoning   ${47.83}_{6.01}$  ${None}_{None}$   ${48.1}_{5.62}$   
hallucination         ${53.33}_{6.44}$  ${None}_{None}$  ${60.76}_{5.49}$   

                          (G-7b, FS-5)      (L-7b, FS-1)     (L-7b, FS-5)  \
object_tracking       ${51.41}_{3.17}$  ${40.47}_{2.66}$  ${None}_{None}$   
fluent_tracking       ${48.31}_{3.75}$   ${36.7}_{2.95}$  ${None}_{None}$   
state_tracking         ${None}_{None}$   ${None}_{None}$  ${None}_{None}$   
action_executability  ${57.69}_{9.69}$  ${43.75}_{7.16}$  ${None}_{None}$  

  latex_table = df.to_latex(index=True, formatters={"name": str.upper}, float_format="{:.2f}".format)
