In [1]:
from pprint import pprint
from collections import defaultdict
import json
import re
import sys
from typing import Union, List, Optional, Set, Tuple, Dict, Optional, Callable

import numpy as np
from lab.utils import shorten
import pandas as pd
from pathlib import Path
from machine_learning.analysis.dataframe import (
    pivot_rotate,

    slice_rows,
    slice_cols,
    sort_rows,
    sort_cols,
    aggregate,
    percentize,
    round,

    rename_index,
    rename_cols,
    rename_cells,

    isnan,
    to_latex,
    color_by_rank,
)
from machine_learning.analysis.series import (
    maybe_numeric_series,
)
from machine_learning.analysis.utils import (
    maybe_round,
)
pd.set_option('display.max_columns', 300)
pd.set_option('display.max_rows', 300)
pd.set_option('display.max_colwidth', 1000)

  from .autonotebook import tqdm as notebook_tqdm


# Load functions

In [2]:
def find_results(top_dir: Union[str, Path], regexps: Optional[List[str]] = None) -> None:
    regexps = regexps or []
    top_dir = Path(top_dir)
    for path in sorted([
        path for path in top_dir.glob('**/results.tsv')
        if all(re.match(regexp, str(path)) for regexp in regexps)
    ]):
        print(path)

In [3]:
# COL_METHOD = 'checkpoint_name'
COL_METHOD = 'method'
COL_MODEL_NAME_OR_PATH = 'model_name_or_path'
COL_DATASET = 'dataset_uname'
COL_TASK = 'dataset_uname'

# Analysis functions

In [24]:
def name_method(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
            
    df[COL_METHOD] = df.apply(
        lambda row: row[COL_MODEL_NAME_OR_PATH],
        axis=1,
    )
    
    return df


def prettify_df(df: pd.DataFrame) -> pd.DataFrame:
    df = percentize(df)
    df = round(df)
    return df

In [5]:
_TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230826.jpn/'
find_results(_TOP_DIR)

../outputs/02.aggregate_tf_results.py/20230826.jpn/results.tsv


In [6]:
results_path = '../outputs/02.aggregate_tf_results.py/20230826.jpn/results.tsv'
master_df = pd.read_csv(results_path, sep='\t')
master_df

Unnamed: 0,base_config_name,dataset_uname,generation_max_proof_steps,generation_num_beams,generation_input_k,gradient_accumulation_steps,learning_rate,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,model_name_or_path,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,seed,shot,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,eval/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
0,FLNLcorpus.20220827.base,20230826.jpn.D3,20,10,,16,0.0001,,False,0.5,1000,30,1700,20000,100,,retrieva-jp/t5-base-long,1,1,stepwise,True,0,FT.step-20000,Solve FLD task:,longest,1000,1.0,0.990741,0.780952,0.661017,,,,,,,0.798,1.0,0.953704,0.847619,0.70339,,,,,,,0.714,1.0,0.990741,0.961905,0.822034,,,,,,,0.874,1.0,0.990741,0.961905,0.822034,,,,,,,0.874
1,FLNLcorpus.20220827.base,20230826.jpn.D8,20,10,,16,0.0001,seq2seq,False,0.5,1000,30,1700,20000,100,,retrieva-jp/t5-base-long,1,1,stepwise,True,0,FT.step-20000,Solve FLD task:,longest,1000,0.8,1.0,0.680851,0.75,0.46875,0.3,0.172414,0.206897,0.130435,,0.55,0.8,1.0,0.702128,0.729167,0.53125,0.4,0.241379,0.241379,0.130435,,0.42,1.0,1.0,0.851064,0.875,0.6875,0.625,0.448276,0.551724,0.304348,,0.688,1.0,1.0,0.851064,0.875,0.6875,0.625,0.448276,0.551724,0.304348,,0.688


In [7]:
method_named_df = name_method(master_df)
method_named_df

Unnamed: 0,base_config_name,dataset_uname,generation_max_proof_steps,generation_num_beams,generation_input_k,gradient_accumulation_steps,learning_rate,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,model_name_or_path,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,seed,shot,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,eval/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy,method
0,FLNLcorpus.20220827.base,20230826.jpn.D3,20,10,,16,0.0001,,False,0.5,1000,30,1700,20000,100,,retrieva-jp/t5-base-long,1,1,stepwise,True,0,FT.step-20000,Solve FLD task:,longest,1000,1.0,0.990741,0.780952,0.661017,,,,,,,0.798,1.0,0.953704,0.847619,0.70339,,,,,,,0.714,1.0,0.990741,0.961905,0.822034,,,,,,,0.874,1.0,0.990741,0.961905,0.822034,,,,,,,0.874,retrieva-jp/t5-base-long
1,FLNLcorpus.20220827.base,20230826.jpn.D8,20,10,,16,0.0001,seq2seq,False,0.5,1000,30,1700,20000,100,,retrieva-jp/t5-base-long,1,1,stepwise,True,0,FT.step-20000,Solve FLD task:,longest,1000,0.8,1.0,0.680851,0.75,0.46875,0.3,0.172414,0.206897,0.130435,,0.55,0.8,1.0,0.702128,0.729167,0.53125,0.4,0.241379,0.241379,0.130435,,0.42,1.0,1.0,0.851064,0.875,0.6875,0.625,0.448276,0.551724,0.304348,,0.688,1.0,1.0,0.851064,0.875,0.6875,0.625,0.448276,0.551724,0.304348,,0.688,retrieva-jp/t5-base-long


In [8]:
TASK_NAMES = [
    '20230826.jpn.D3',
    '20230826.jpn.D8',
]
MAJOR_TASK = TASK_NAMES[0]

task_dfs: Dict[str, pd.DataFrame] = {}
for task_name in TASK_NAMES:
    task_dfs[task_name] = slice_rows(
        method_named_df,
        lambda row: row[COL_TASK] == task_name
    )

task_dfs[MAJOR_TASK]

Unnamed: 0,base_config_name,dataset_uname,generation_max_proof_steps,generation_num_beams,generation_input_k,gradient_accumulation_steps,learning_rate,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,model_name_or_path,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,seed,shot,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,eval/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy,method
0,FLNLcorpus.20220827.base,20230826.jpn.D3,20,10,,16,0.0001,,False,0.5,1000,30,1700,20000,100,,retrieva-jp/t5-base-long,1,1,stepwise,True,0,FT.step-20000,Solve FLD task:,longest,1000,1.0,0.990741,0.780952,0.661017,,,,,,,0.798,1.0,0.953704,0.847619,0.70339,,,,,,,0.714,1.0,0.990741,0.961905,0.822034,,,,,,,0.874,1.0,0.990741,0.961905,0.822034,,,,,,,0.874,retrieva-jp/t5-base-long


In [9]:
METRIC_NAMES = [
    'eval/extr_stps.D-all.proof_accuracy.zero_one',
    'eval/strct.D-all.proof_accuracy.zero_one',
    'eval/strct.D-all.answer_accuracy',
]

metric_dfs: Dict[str, pd.DataFrame] = {}
for task_name, task_df in task_dfs.items():
    metric_dfs[task_name] = slice_cols(task_df, [COL_METHOD] + METRIC_NAMES)

metric_dfs[MAJOR_TASK]

Unnamed: 0,method,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/strct.D-all.answer_accuracy
0,retrieva-jp/t5-base-long,0.798,0.714,0.874


In [25]:
pretty_dfs: Dict[str, pd.DataFrame] = {}

for task_name, metric_df in metric_dfs.items():
    pretty_df = prettify_df(metric_df)
    pretty_df = color_by_rank(pretty_df, 'col')
    pretty_dfs[task_name] = pretty_df

pretty_dfs[MAJOR_TASK]

  color_scale = color_lower + (val - scale_lower) / (scale_upper - scale_lower) * (color_upper - color_lower)


Unnamed: 0,method,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/strct.D-all.answer_accuracy
0,retrieva-jp/t5-base-long,79.8,71.4,87.4


In [26]:
def horizontal_concat(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    dfs = [df.copy() for df in dfs]
    # align index for horizontal concat
    for df in colored_dfs.values():
        df.index = range(len(df))
    return pd.concat(dfs, axis=1)
    
colored_concat_df = horizontal_concat(
    [colored_df for task_name, colored_df in sorted(colored_dfs.items())]
)
colored_concat_df

Unnamed: 0,method,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/strct.D-all.answer_accuracy,method.1,eval/extr_stps.D-all.proof_accuracy.zero_one.1,eval/strct.D-all.proof_accuracy.zero_one.1,eval/strct.D-all.answer_accuracy.1
0,retrieva-jp/t5-base-long,0.7979999780654907,0.7139999866485596,0.8740000128746033,retrieva-jp/t5-base-long,0.550000011920929,0.4199999868869781,0.6880000233650208


In [27]:
print(to_latex(colored_concat_df))

\begin{tabular}{lllllllll}
\toprule
 & method & eval/extr\_stps.D-all.proof\_accuracy.zero\_one & eval/strct.D-all.proof\_accuracy.zero\_one & eval/strct.D-all.answer\_accuracy & method & eval/extr\_stps.D-all.proof\_accuracy.zero\_one & eval/strct.D-all.proof\_accuracy.zero\_one & eval/strct.D-all.answer\_accuracy \\
\midrule
0 & retrieva-jp/t5-base-long & 0.7979999780654907 & 0.7139999866485596 & 0.8740000128746033 & retrieva-jp/t5-base-long & 0.550000011920929 & 0.4199999868869781 & 0.6880000233650208 \\
\bottomrule
\end{tabular}

