In [1]:
import os
import pandas as pd

In [7]:
df = pd.read_csv('../best_dev_set_results.csv')
df.head()

Unnamed: 0,model,prompt_type_index,GLEU,Prec,Rec,split,F0.5
0,gpt-4-0613,0-shot_8,0.582,0.6824,0.6359,jfleg-dev,
1,stabilityai/StableBeluga2,0-shot_10,0.563,0.6131,0.6103,jfleg-dev,
2,meta-llama/Llama-2-70b-chat-hf,0-shot_6,0.5,0.5893,0.6054,jfleg-dev,
3,Writer/InstructPalmyra-20b,0-shot_7,0.517,0.5628,0.5269,jfleg-dev,
4,facebook/opt-iml-max-30b,2-shot-Coyne_1,0.506,0.7768,0.4899,jfleg-dev,


In [11]:
# write a latex row using the format: \diagbox[dir=SW]{metric}{prompt_index}
# where each row is a model and each column is a split
def write_latex_row_expanded_subcript(df, model, metric, include_pre_rec=False):
    # metric = 'F0.5' if 'jfleg' not in split else 'GLEU'
    # get the metric value for the model from the df
    metric_value = df[df['model'] == model][metric].values[0]
    
    # check if metric_value is the max value for this dataframe
    max_metric_value = df[metric].max()
    if metric_value == max_metric_value:
        metric_value = '\\textbf{' + '{:.3f}'.format(metric_value) + '}'
    else:
        metric_value = '{:.3f}'.format(metric_value)

    # get the prompt_index for the model from the df
    prompt_type_index = df[df['model'] == model]['prompt_type_index'].values[0]

    prompt_type = prompt_type_index.split('_')[0] 
    prompt_index = prompt_type_index.split('_')[1]

    if prompt_type == '0-shot':
        prompt_type = '0'
        if int(prompt_index) == 10:
            prompt_index = '\\textsc{coyne}'
        elif int(prompt_index) in [6,7]:
            prompt_index = '\\textsc{tool}'
        elif int(prompt_index) == 5:
            prompt_index = '\\textsc{elt}'
    else:
        if 'Coyne' in prompt_type:
            prompt_type = '2'
            prompt_index = '\\textsc{coyne}$^{*}$'
        else:
            prompt_type = prompt_type.split('-')[0]
            if int(prompt_index) == 1:
                prompt_index = '\\textsc{coyne}'
            elif int(prompt_index) == 2:
                prompt_index = '\\textsc{tool}'
            else:
                prompt_index = '\\textsc{elt}'

    precision = df[df['model'] == model]['Prec'].values[0]
    recall = df[df['model'] == model]['Rec'].values[0]
    
    # write the latex row
    # $_{{{str(prompt_index)}}}$
    return f'{precision:.3f} & {recall:.3f} & {str(metric_value)} & {prompt_type} & {str(prompt_index)}'

In [9]:

# write the latex table
# print(' & '.join([''] + [split for split in sorted(splits)]) + ' \\\\')

model_order = [
    'bigscience/bloomz-7b1',
    'google/flan-t5-xxl',
    'Writer/InstructPalmyra-20b',
    'facebook/opt-iml-max-30b',
    'tiiuae/falcon-40b-instruct',
    'meta-llama/Llama-2-70b-chat-hf',
    'stabilityai/StableBeluga2',
    'command', 
    'gpt-3.5-turbo-0613',
    # 'gpt-4-0613',
    ]

split_order = ['fce-dev', 'jfleg-dev', 'wibea-dev']

for model in model_order:
    if model not in df['model'].values:
        continue
    model_name = model.split('/')[-1]
    row = [model_name]
    for split in split_order:
        df_split_temp = df[df['split'] == split]
        metric = 'F0.5' if 'jfleg' not in split else 'GLEU'
        row.append(write_latex_row_expanded_subcript(df_split_temp, model, metric))
    
    print(' & '.join(row) + ' \\\\')
    # print('\\hline')

bloomz-7b1 & 0.349 & 3 & \textsc{coyne} & 0.456 & 2 & \textsc{coyne}$^{*}$ & 0.347 & 3 & \textsc{coyne} \\
flan-t5-xxl & 0.447 & 1 & \textsc{tool} & 0.463 & 1 & \textsc{tool} & 0.423 & 3 & \textsc{tool} \\
InstructPalmyra-20b & 0.341 & 2 & \textsc{coyne} & 0.517 & 0 & \textsc{tool} & 0.374 & 2 & \textsc{coyne} \\
opt-iml-max-30b & 0.395 & 0 & \textsc{tool} & 0.506 & 2 & \textsc{coyne}$^{*}$ & 0.400 & 3 & \textsc{elt} \\
falcon-40b-instruct & 0.425 & 2 & \textsc{tool} & 0.548 & 4 & \textsc{coyne} & 0.454 & 4 & \textsc{tool} \\
Llama-2-70b-chat-hf & 0.323 & 0 & \textsc{tool} & 0.500 & 0 & \textsc{tool} & 0.359 & 0 & \textsc{tool} \\
StableBeluga2 & 0.403 & 0 & \textsc{tool} & 0.563 & 0 & \textsc{coyne} & 0.447 & 0 & \textsc{tool} \\
command & 0.353 & 0 & \textsc{tool} & 0.543 & 2 & \textsc{coyne}$^{*}$ & 0.391 & 0 & \textsc{tool} \\
gpt-3.5-turbo-0613 & 0.416 & 0 & \textsc{elt} & 0.577 & 4 & \textsc{tool} & 0.439 & 1 & \textsc{tool} \\


### print a latex table for one dataset, including the precision and recall

In [13]:

model_order = [
    'bigscience/bloomz-7b1',
    'google/flan-t5-xxl',
    'Writer/InstructPalmyra-20b',
    'facebook/opt-iml-max-30b',
    'tiiuae/falcon-40b-instruct',
    'meta-llama/Llama-2-70b-chat-hf',
    'stabilityai/StableBeluga2',
    'command', 
    'gpt-3.5-turbo-0613',
    # 'gpt-4-0613',
    ]

split_order = ['fce-dev', 'jfleg-dev', 'wibea-dev']
split = split_order[2]

for model in model_order:
    if model not in df['model'].values:
        continue
    model_name = model.split('/')[-1]
    row = [model_name]
    
    df_split_temp = df[df['split'] == split]
    metric = 'F0.5' if 'jfleg' not in split else 'GLEU'
    row.append(write_latex_row_expanded_subcript(df_split_temp, model, metric))
    
    print(' & '.join(row) + ' \\\\')
    # print('\\hline')

bloomz-7b1 & 0.508 & 0.153 & 0.347 & 3 & \textsc{coyne} \\
flan-t5-xxl & 0.623 & 0.185 & 0.423 & 3 & \textsc{tool} \\
InstructPalmyra-20b & 0.396 & 0.305 & 0.374 & 2 & \textsc{coyne} \\
opt-iml-max-30b & 0.577 & 0.180 & 0.400 & 3 & \textsc{elt} \\
falcon-40b-instruct & 0.467 & 0.407 & 0.454 & 4 & \textsc{tool} \\
Llama-2-70b-chat-hf & 0.339 & 0.469 & 0.359 & 0 & \textsc{tool} \\
StableBeluga2 & 0.442 & 0.472 & 0.447 & 0 & \textsc{tool} \\
command & 0.403 & 0.350 & 0.391 & 0 & \textsc{tool} \\
gpt-3.5-turbo-0613 & 0.422 & 0.524 & 0.439 & 1 & \textsc{tool} \\


In [None]:
# load full dev set results

In [10]:
df[df['split'] == 'fce-dev'].head()

Unnamed: 0,model,prompt_type_index,GLEU,Prec,Rec,split,F0.5
20,gpt-4-0613,0-shot_7,,0.4727,0.4775,fce-dev,0.474
21,gpt-3.5-turbo-0613,0-shot_5,,0.3984,0.5045,fce-dev,0.416
22,stabilityai/StableBeluga2,0-shot_7,,0.3964,0.4321,fce-dev,0.403
23,facebook/opt-iml-max-30b,0-shot_7,,0.5586,0.182,fce-dev,0.395
24,command,0-shot_6,,0.3562,0.3419,fce-dev,0.353
