In [1]:
from pprint import pprint
from collections import defaultdict
import json
import re
import sys
from typing import Union, List, Optional, Set, Tuple, Dict, Optional, Callable
from pprint import pprint

import numpy as np
from lab.utils import shorten
import pandas as pd
from pathlib import Path
from machine_learning.analysis.dataframe import (
    pivot_rotate,

    slice_rows,
    slice_cols,
    sort_rows,
    sort_cols,
    aggregate,
    percentize,
    round,

    rename_index,
    rename_cols,
    rename_cells,

    isnan,
    to_latex,
    color_by_rank,
)
from machine_learning.analysis.series import (
    maybe_numeric_series,
)
from machine_learning.analysis.utils import (
    maybe_round,
)
pd.set_option('display.max_columns', 300)
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_colwidth', 1000)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def find_result_paths(top_dir: Union[str, Path], regexps: Optional[List[str]] = None) -> List[str]:
    """Recursively collect every ``results.tsv`` file below ``top_dir``.

    Args:
        top_dir: Directory to search. All nesting levels are visited via the
            ``**/results.tsv`` glob pattern.
        regexps: Optional regular expressions. A path is kept only if *all* of
            them match its string form. Matching uses ``re.match``, so every
            pattern is anchored at the start of the path string (which
            includes ``top_dir`` itself); prefix a pattern with ``.*`` to
            match further inside the path.

    Returns:
        The matching paths as strings, sorted lexicographically so the result
        does not depend on the file-system traversal order.
    """
    # Compile once instead of handing the raw pattern to ``re`` for every path.
    patterns = [re.compile(regexp) for regexp in (regexps or [])]
    return sorted(
        path_str
        for path_str in map(str, Path(top_dir).glob('**/results.tsv'))
        if all(pattern.match(path_str) for pattern in patterns)
    )

In [3]:
def name_method(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``COL_METHOD`` column added.

    The value of the new column is, for every row, the string forms of the
    ``METHOD_DEFINE_COLS`` values joined with ``'__'``. The input frame is not
    modified. Columns are finally reordered through
    ``sort_cols(..., [COL_METHOD, '.*'])`` (presumably: method column first,
    everything else after it -- confirm against ``sort_cols``).
    """
    def _method_label(row) -> str:
        parts = [str(row[col]) for col in METHOD_DEFINE_COLS]
        return '__'.join(parts)

    named = df.copy()
    named[COL_METHOD] = named.apply(_method_label, axis=1)
    return sort_cols(named, [COL_METHOD, '.*'])

def name_task(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``COL_TASK`` column added.

    The value of the new column is, for every row, the string forms of the
    ``TASK_DEFINE_COLS`` values joined with ``'__'``. The input frame is not
    modified. Columns are finally reordered through
    ``sort_cols(..., [COL_TASK, '.*'])`` (presumably: task column first,
    everything else after it -- confirm against ``sort_cols``).
    """
    def _task_label(row) -> str:
        parts = [str(row[col]) for col in TASK_DEFINE_COLS]
        return '__'.join(parts)

    named = df.copy()
    named[COL_TASK] = named.apply(_task_label, axis=1)
    return sort_cols(named, [COL_TASK, '.*'])

def prettify_df(df: pd.DataFrame) -> pd.DataFrame:
    """Pass ``df`` through ``percentize`` and then ``round`` and return the result.

    NOTE: ``round`` is the helper imported from
    ``machine_learning.analysis.dataframe`` at the top of the notebook, which
    shadows the Python builtin of the same name.
    """
    return round(percentize(df))

In [4]:
# Names of columns that are read from the results table.
COL_DATASET = 'dataset_uname'
COL_LEARNING = 'learning'
COL_MODEL_NAME_OR_PATH = 'model_name_or_path'
COL_LRATE = 'learning_rate'

# A "task" is identified by the combination of these columns; ``name_task``
# joins their values with '__' and stores the result in the COL_TASK column.
TASK_DEFINE_COLS = [COL_DATASET, COL_LEARNING]
COL_TASK = 'task'

# A "method" is identified by the combination of these columns; ``name_method``
# joins their values with '__' and stores the result in the COL_METHOD column.
METHOD_DEFINE_COLS = [COL_MODEL_NAME_OR_PATH, COL_LRATE]
COL_METHOD = 'method'

In [5]:
# Directory under which the aggregated ``results.tsv`` files are looked up.
# The commented-out assignments are alternative directories that are currently
# disabled; exactly one assignment should be active.
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230826.jpn/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230901.overfit/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230904.LLM_FS/'
_TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230905.LLM_FS/'

In [6]:
# Pick the results file to analyse: fail loudly when there is none, use it
# directly when it is unique, otherwise ask the user to choose one.
result_paths = find_result_paths(_TOP_DIR)
if not result_paths:
    # FileNotFoundError is a subclass of Exception, so any existing
    # ``except Exception`` handling keeps working.
    raise FileNotFoundError(f'Results not found under {_TOP_DIR}')
elif len(result_paths) == 1:
    results_path = result_paths[0]
else:
    print('Choose the result from the following paths:')
    pprint(result_paths)
    # Strip stray whitespace/newlines that come along with a pasted path and
    # would otherwise make the subsequent file read fail.
    results_path = input('path = ').strip()
    


In [7]:
# Load the selected tab-separated results file and preview its first rows.
master_df = pd.read_csv(results_path, sep='\t')
master_df.head()

Unnamed: 0,dataset_uname,learning,model_name_or_path,learning_rate,seed,base_config_name,generation_max_proof_steps,generation_num_beams,generation_input_k,gradient_accumulation_steps,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,ev
al/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
0,20230904.jpn.D1.wo_brnch,LLM_FS.shot-100,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.058824,,,,,,,,0.1,0.066667,0.0,0.058824,,,,,,,,0.0,0.033333,0.666667,0.647059,,,,,,,,0.1,0.466667,0.666667,0.647059,,,,,,,,0.1,0.466667
1,20230904.jpn.D1.wo_brnch,LLM_FS.shot-100,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.058824,,,,,,,,0.1,0.066667,0.0,0.058824,,,,,,,,0.0,0.033333,1.0,0.647059,,,,,,,,0.1,0.5,1.0,0.647059,,,,,,,,0.1,0.5
2,20230904.jpn.D1.wo_brnch,LLM_FS.shot-1000,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.333333,0.294118,,,,,,,,0.1,0.233333,0.333333,0.294118,,,,,,,,0.1,0.233333,0.333333,0.823529,,,,,,,,0.1,0.533333,0.333333,0.823529,,,,,,,,0.1,0.533333
3,20230904.jpn.D1.wo_brnch,LLM_FS.shot-1000,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.117647,,,,,,,,0.0,0.066667,0.0,0.117647,,,,,,,,0.0,0.066667,0.333333,0.470588,,,,,,,,0.0,0.3,0.333333,0.470588,,,,,,,,0.0,0.3
4,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-100,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.619048,,,,,,,,0.0,0.433333,0.0,0.619048,,,,,,,,0.0,0.433333,0.5,0.666667,,,,,,,,0.0,0.5,0.5,0.666667,,,,,,,,0.0,0.5


In [8]:
# Derive the method-name column first and the task-name column second, working
# on copies so that ``master_df`` itself stays untouched.
df = name_task(name_method(master_df))
df

Unnamed: 0,task,method,dataset_uname,learning,model_name_or_path,learning_rate,seed,base_config_name,generation_max_proof_steps,generation_num_beams,generation_input_k,gradient_accumulation_steps,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer
_accuracy,eval/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
0,20230904.jpn.D1.wo_brnch__LLM_FS.shot-100,matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1.wo_brnch,LLM_FS.shot-100,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.058824,,,,,,,,0.1,0.066667,0.0,0.058824,,,,,,,,0.0,0.033333,0.666667,0.647059,,,,,,,,0.1,0.466667,0.666667,0.647059,,,,,,,,0.1,0.466667
1,20230904.jpn.D1.wo_brnch__LLM_FS.shot-100,matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1.wo_brnch,LLM_FS.shot-100,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.058824,,,,,,,,0.1,0.066667,0.0,0.058824,,,,,,,,0.0,0.033333,1.0,0.647059,,,,,,,,0.1,0.5,1.0,0.647059,,,,,,,,0.1,0.5
2,20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000,matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1.wo_brnch,LLM_FS.shot-1000,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.333333,0.294118,,,,,,,,0.1,0.233333,0.333333,0.294118,,,,,,,,0.1,0.233333,0.333333,0.823529,,,,,,,,0.1,0.533333,0.333333,0.823529,,,,,,,,0.1,0.533333
3,20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000,matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1.wo_brnch,LLM_FS.shot-1000,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.117647,,,,,,,,0.0,0.066667,0.0,0.117647,,,,,,,,0.0,0.066667,0.333333,0.470588,,,,,,,,0.0,0.3,0.333333,0.470588,,,,,,,,0.0,0.3
4,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100,matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-100,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.619048,,,,,,,,0.0,0.433333,0.0,0.619048,,,,,,,,0.0,0.433333,0.5,0.666667,,,,,,,,0.0,0.5,0.5,0.666667,,,,,,,,0.0,0.5
5,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100,matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-100,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.571429,,,,,,,,0.142857,0.433333,0.0,0.571429,,,,,,,,0.142857,0.433333,0.0,0.571429,,,,,,,,0.142857,0.433333,0.0,0.571429,,,,,,,,0.142857,0.433333
6,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-1000,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.5,0.904762,,,,,,,,0.571429,0.8,0.5,0.904762,,,,,,,,0.428571,0.766667,1.0,1.0,,,,,,,,0.571429,0.9,1.0,1.0,,,,,,,,0.571429,0.9
7,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-1000,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.52381,,,,,,,,0.571429,0.5,0.0,0.52381,,,,,,,,0.285714,0.433333,0.5,0.52381,,,,,,,,0.571429,0.533333,0.5,0.52381,,,,,,,,0.571429,0.533333
8,20230904.jpn.D1__LLM_FS.shot-100,matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1,LLM_FS.shot-100,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.2,0.117647,,,,,,,,0.0,0.1,0.2,0.117647,,,,,,,,0.0,0.1,0.6,0.529412,,,,,,,,0.0,0.4,0.6,0.529412,,,,,,,,0.0,0.4
9,20230904.jpn.D1__LLM_FS.shot-100,matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1,LLM_FS.shot-100,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,100,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.176471,,,,,,,,0.0,0.1,0.0,0.176471,,,,,,,,0.0,0.1,0.8,0.470588,,,,,,,,0.0,0.4,0.8,0.470588,,,,,,,,0.0,0.4


In [9]:
# Print every distinct task name as a quoted, comma-terminated line so the
# output can be pasted straight into a Python list literal.
unique_tasks = df[COL_TASK].unique()
for task in unique_tasks:
    print(''.join(("'", task, "',")))

'20230904.jpn.D1.wo_brnch__LLM_FS.shot-100',
'20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000',
'20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100',
'20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000',
'20230904.jpn.D1__LLM_FS.shot-100',
'20230904.jpn.D1__LLM_FS.shot-1000',


In [10]:
# Tasks to analyse. Each entry must equal a value of the COL_TASK column
# (the '__'-joined TASK_DEFINE_COLS values) for the row selection to match.
TASK_NAMES = [
    '20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100',
    '20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000',
    '20230904.jpn.D1.wo_brnch__LLM_FS.shot-100',
    '20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000',
    '20230904.jpn.D1__LLM_FS.shot-100',
    '20230904.jpn.D1__LLM_FS.shot-1000',
]
# Task whose intermediate tables are displayed as a sanity check in later cells.
MAJOR_TASK = TASK_NAMES[1]

In [11]:
# Split the full table into one sub-frame per task, keyed by task name.
task_dfs: Dict[str, pd.DataFrame] = {}
for task_name in TASK_NAMES:
    def _is_row_of_current_task(row):
        """Row predicate: does ``row`` belong to the task of this iteration?"""
        return row[COL_TASK] == task_name

    task_dfs[task_name] = slice_rows(df, _is_row_of_current_task)

task_dfs[MAJOR_TASK]

Unnamed: 0,task,method,dataset_uname,learning,model_name_or_path,learning_rate,seed,base_config_name,generation_max_proof_steps,generation_num_beams,generation_input_k,gradient_accumulation_steps,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer
_accuracy,eval/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
6,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-1000,matsuo-lab/weblab-10b,0.0001,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.5,0.904762,,,,,,,,0.571429,0.8,0.5,0.904762,,,,,,,,0.428571,0.766667,1.0,1.0,,,,,,,,0.571429,0.9,1.0,1.0,,,,,,,,0.571429,0.9
7,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1.wo_brnch.wo_dstrct,LLM_FS.shot-1000,matsuo-lab/weblab-10b,1e-05,0,FLNLcorpus.20220827.base,20,1,,4,causal,True,0.5,0,30,2000,320,2000,1000,1,2,all_at_once,True,Solve FLD task:,longest,32,0.0,0.52381,,,,,,,,0.571429,0.5,0.0,0.52381,,,,,,,,0.285714,0.433333,0.5,0.52381,,,,,,,,0.571429,0.533333,0.5,0.52381,,,,,,,,0.571429,0.533333


In [12]:
# Metric columns that are kept for the report.
METRIC_NAMES = [
    'eval/extr_stps.D-all.proof_accuracy.zero_one',
    'eval/strct.D-all.proof_accuracy.zero_one',
    'eval/strct.D-all.answer_accuracy',
]

# Per task: keep only the task/method identifier columns plus the metrics above.
metric_dfs: Dict[str, pd.DataFrame] = {}
for task_name in task_dfs:
    task_df = task_dfs[task_name]
    metric_dfs[task_name] = slice_cols(
        task_df,
        [COL_TASK, COL_METHOD, *METRIC_NAMES],
    )

metric_dfs[MAJOR_TASK]

Unnamed: 0,task,method,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/strct.D-all.answer_accuracy
6,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,matsuo-lab/weblab-10b__0.0001,0.8,0.766667,0.9
7,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,matsuo-lab/weblab-10b__1e-05,0.5,0.433333,0.533333


In [13]:
# Short display names for the metric columns.
METRIC_RENAMES = {
    'eval/extr_stps.D-all.proof_accuracy.zero_one': 'prf_acc.extr',
    'eval/strct.D-all.proof_accuracy.zero_one': 'prf_acc.strct',
    'eval/strct.D-all.answer_accuracy': 'ans.acc',
}

# Per task: prettify the numbers, shorten the metric names, index the rows by
# method name and finally colour the cells by their rank.
pretty_dfs: Dict[str, pd.DataFrame] = {}

for task_name, metric_df in metric_dfs.items():
    pretty_df = rename_cols(prettify_df(metric_df), METRIC_RENAMES)
    # Turn the method column into the row index (removing it from the columns).
    pretty_df = pretty_df.set_index(COL_METHOD)
    pretty_df = color_by_rank(pretty_df, 'col')
    pretty_dfs[task_name] = pretty_df

pretty_dfs[MAJOR_TASK]

Unnamed: 0_level_0,task,prf_acc.extr,prf_acc.strct,ans.acc
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
matsuo-lab/weblab-10b__0.0001,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,\cellcolor{blue!60} 80.0,\cellcolor{blue!60} 76.7,\cellcolor{blue!60} 90.0
matsuo-lab/weblab-10b__1e-05,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,\cellcolor{blue!10} 50.0,\cellcolor{blue!10} 43.3,\cellcolor{blue!10} 53.3


In [14]:
def horizontal_concat(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """Place the given frames side by side, pairing their rows by position.

    Every frame's own index is discarded and replaced by ``0 .. len - 1`` so
    that the n-th rows of all frames land on the same output row. The input
    frames themselves are left unmodified.
    """
    positional_frames = [frame.reset_index(drop=True) for frame in dfs]
    return pd.concat(positional_frames, axis=1)
    
# Fix ONE task order and use it for both the concatenated table and the printed
# header line. Previously the table followed the sorted key order while the
# header followed the dict's insertion order, so the header could label the
# wrong column groups whenever TASK_NAMES was not already sorted.
ordered_task_names = sorted(pretty_dfs)
colored_concat_df = horizontal_concat(
    [pretty_dfs[task_name] for task_name in ordered_task_names]
)

print('    '.join(ordered_task_names))
colored_concat_df

20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100    20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000    20230904.jpn.D1.wo_brnch__LLM_FS.shot-100    20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000    20230904.jpn.D1__LLM_FS.shot-100    20230904.jpn.D1__LLM_FS.shot-1000


Unnamed: 0,task,prf_acc.extr,prf_acc.strct,ans.acc,task.1,prf_acc.extr.1,prf_acc.strct.1,ans.acc.1,task.2,prf_acc.extr.2,prf_acc.strct.2,ans.acc.2,task.3,prf_acc.extr.3,prf_acc.strct.3,ans.acc.3,task.4,prf_acc.extr.4,prf_acc.strct.4,ans.acc.4,task.5,prf_acc.extr.5,prf_acc.strct.5,ans.acc.5
0,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100,\cellcolor{blue!36} 43.3,\cellcolor{blue!36} 43.3,\cellcolor{blue!60} 50.0,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,\cellcolor{blue!60} 80.0,\cellcolor{blue!60} 76.7,\cellcolor{blue!60} 90.0,20230904.jpn.D1.wo_brnch__LLM_FS.shot-100,\cellcolor{blue!36} 6.7,\cellcolor{blue!36} 3.3,\cellcolor{blue!10} 46.7,20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000,\cellcolor{blue!60} 23.3,\cellcolor{blue!60} 23.3,\cellcolor{blue!60} 53.3,20230904.jpn.D1__LLM_FS.shot-100,\cellcolor{blue!36} 10.0,\cellcolor{blue!36} 10.0,\cellcolor{blue!36} 40.0,20230904.jpn.D1__LLM_FS.shot-1000,\cellcolor{blue!60} 33.3,\cellcolor{blue!60} 26.7,\cellcolor{blue!60} 56.7
1,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100,\cellcolor{blue!36} 43.3,\cellcolor{blue!36} 43.3,\cellcolor{blue!10} 43.3,20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000,\cellcolor{blue!10} 50.0,\cellcolor{blue!10} 43.3,\cellcolor{blue!10} 53.3,20230904.jpn.D1.wo_brnch__LLM_FS.shot-100,\cellcolor{blue!36} 6.7,\cellcolor{blue!36} 3.3,\cellcolor{blue!60} 50.0,20230904.jpn.D1.wo_brnch__LLM_FS.shot-1000,\cellcolor{blue!10} 6.7,\cellcolor{blue!10} 6.7,\cellcolor{blue!10} 30.0,20230904.jpn.D1__LLM_FS.shot-100,\cellcolor{blue!36} 10.0,\cellcolor{blue!36} 10.0,\cellcolor{blue!36} 40.0,20230904.jpn.D1__LLM_FS.shot-1000,\cellcolor{blue!10} 10.0,\cellcolor{blue!10} 3.3,\cellcolor{blue!10} 33.3


In [15]:
# Print the LaTeX source of the concatenated table (called with with_index=True).
print(to_latex(colored_concat_df, with_index=True))

\begin{tabular}{lllllllllllllllllllllllll}
\toprule
 & task & prf\_acc.extr & prf\_acc.strct & ans.acc & task & prf\_acc.extr & prf\_acc.strct & ans.acc & task & prf\_acc.extr & prf\_acc.strct & ans.acc & task & prf\_acc.extr & prf\_acc.strct & ans.acc & task & prf\_acc.extr & prf\_acc.strct & ans.acc & task & prf\_acc.extr & prf\_acc.strct & ans.acc \\
\midrule
0 & 20230904.jpn.D1.wo\_brnch.wo\_dstrct\_\_LLM\_FS.shot-100 & \cellcolor{blue!36} 43.3 & \cellcolor{blue!36} 43.3 & \cellcolor{blue!60} 50.0 & 20230904.jpn.D1.wo\_brnch.wo\_dstrct\_\_LLM\_FS.shot-1000 & \cellcolor{blue!60} 80.0 & \cellcolor{blue!60} 76.7 & \cellcolor{blue!60} 90.0 & 20230904.jpn.D1.wo\_brnch\_\_LLM\_FS.shot-100 & \cellcolor{blue!36} 6.7 & \cellcolor{blue!36} 3.3 & \cellcolor{blue!10} 46.7 & 20230904.jpn.D1.wo\_brnch\_\_LLM\_FS.shot-1000 & \cellcolor{blue!60} 23.3 & \cellcolor{blue!60} 23.3 & \cellcolor{blue!60} 53.3 & 20230904.jpn.D1\_\_LLM\_FS.shot-100 & \cellcolor{blue!36} 10.0 & \cellcolor{blue!36} 10.0 & \