In [1]:
from pprint import pprint
from collections import defaultdict
import json
import re
import sys
from typing import Union, List, Optional, Set, Tuple, Dict, Optional, Callable
from pprint import pprint
from IPython.display import display, HTML
from collections import OrderedDict

import numpy as np
from lab.utils import shorten
import pandas as pd
from pathlib import Path
from machine_learning.analysis.dataframe import (
    pivot_rotate,

    slice_rows,
    slice_cols,
    sort_rows,
    sort_cols,
    aggregate,
    percentize,
    round,

    rename_index,
    rename_cols,
    rename_cells,

    isnan,
    to_latex,
    color_by_rank,
)
from machine_learning.analysis.series import (
    maybe_numeric_series,
)
from machine_learning.analysis.utils import (
    maybe_round,
)
# Widen pandas' display limits so the wide results tables below render fully.
pd.set_option(
    'display.max_columns', 300,
    'display.max_rows', 300,
    'display.max_colwidth', 1000,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def find_result_paths(top_dir: Union[str, Path], regexps: Optional[List[str]] = None) -> List[str]:
    """Recursively collect ``results.tsv`` files under ``top_dir``.

    Args:
        top_dir: Directory searched recursively for files named ``results.tsv``
            (a ``results.tsv`` directly inside ``top_dir`` is included too).
        regexps: Optional patterns; a path is kept only if **every** pattern
            matches it via ``re.match`` (i.e. anchored at the start of the
            path string). ``None`` or an empty list keeps every path.

    Returns:
        The matching paths as strings, sorted lexicographically. ``Path.glob``
        yields entries in an arbitrary, OS-dependent order, and callers show
        this list to the user / pick from it, so sorting keeps runs reproducible.
    """
    patterns = [re.compile(regexp) for regexp in (regexps or [])]
    return sorted(
        str(path)
        for path in Path(top_dir).glob('**/results.tsv')
        if all(pattern.match(str(path)) for pattern in patterns)
    )

In [3]:
def name_method(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``COL_METHOD`` label column.

    Each row's label is the ``'__'``-joined ``str`` of its values in
    ``METHOD_DEFINE_COLS``. The new column is then moved to the front via
    ``sort_cols(df, [COL_METHOD, '.*'])``.

    The labels are built with ``itertuples`` instead of
    ``DataFrame.apply(axis=1)``. ``apply`` returns an empty *DataFrame*, not a
    Series, for a zero-row input, and that cannot be assigned to a single
    column. ``apply`` also passes whole rows, which are upcast to a common
    dtype (e.g. ``0`` -> ``'0.0'`` when every column is numeric).
    """
    df = df.copy()

    df[COL_METHOD] = [
        '__'.join(str(value) for value in values)
        for values in df[list(METHOD_DEFINE_COLS)].itertuples(index=False, name=None)
    ]

    df = sort_cols(df, [COL_METHOD, '.*'])

    return df

def name_task(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``COL_TASK`` label column.

    Each row's label is the ``'__'``-joined ``str`` of its values in
    ``TASK_DEFINE_COLS``. The new column is then moved to the front via
    ``sort_cols(df, [COL_TASK, '.*'])``.

    The labels are built with ``itertuples`` instead of
    ``DataFrame.apply(axis=1)``. ``apply`` returns an empty *DataFrame*, not a
    Series, for a zero-row input, and that cannot be assigned to a single
    column. ``apply`` also passes whole rows, which are upcast to a common
    dtype (e.g. ``0`` -> ``'0.0'`` when every column is numeric).
    """
    df = df.copy()

    df[COL_TASK] = [
        '__'.join(str(value) for value in values)
        for values in df[list(TASK_DEFINE_COLS)].itertuples(index=False, name=None)
    ]

    df = sort_cols(df, [COL_TASK, '.*'])

    return df

def prettify_df(df: pd.DataFrame) -> pd.DataFrame:
    """Convert metric values to percentages, then round them for display.

    Composes the project helpers ``percentize`` and ``round``. Note that
    ``round`` here is the helper imported from
    ``machine_learning.analysis.dataframe``, which shadows the builtin.
    """
    return round(percentize(df))

In [4]:
# Column names in the aggregated results TSV.
COL_DATASET = 'dataset_uname'
COL_LEARNING = 'learning'
COL_MODEL_NAME_OR_PATH = 'model_name_or_path'
COL_LRATE = 'learning_rate'

# A "task" is identified by (dataset, learning setting); name_task joins the
# values of these columns with '__' into COL_TASK.
TASK_DEFINE_COLS = [COL_DATASET, COL_LEARNING]
COL_TASK = 'task'

# A "method" is identified by the model checkpoint; name_method joins the
# values of these columns with '__' into COL_METHOD.
METHOD_DEFINE_COLS = [COL_MODEL_NAME_OR_PATH]
COL_METHOD = 'method'
# Regex over the method label -> short display name. The insertion order is
# also the row order used when sorting methods for display.
METHOD_RENAMES = OrderedDict([
    ('^retrieva-jp/t5-base-long$', 'retrieva-t5-base'),
    ('^retrieva-jp/t5-xl$', 'retrieva-t5-xl'),

    ('^line-corporation/japanese-large-lm-1.7b$', 'line-1B'),
    ('^line-corporation/japanese-large-lm-1.7b-instruction-sft$', 'line-1B-instruct'),
    ('^line-corporation/japanese-large-lm-3.6b$', 'line-4B'),
    ('^line-corporation/japanese-large-lm-3.6b-instruction-sft$', 'line-4B-instruct'),

    ('^cyberagent/open-calm-medium$', 'calm-0.4B'),
    ('^cyberagent/open-calm-1b$', 'calm-1B'),
    ('^cyberagent/open-calm-3b$', 'calm-3B'),
    ('^cyberagent/open-calm-7b$', 'calm-7B'),

    ('^rinna/japanese-gpt-neox-3.6b$', 'rinna-4B'),
    ('^rinna/japanese-gpt-neox-3.6b-instruction-ppo$', 'rinna-4B-instruct'),

    ('^stabilityai/japanese-stablelm-base-alpha-7b$', 'stablelm-7B'),

    ('^elyza/ELYZA-japanese-Llama-2-7b-fast$', 'elyza-7B'),
    ('^elyza/ELYZA-japanese-Llama-2-7b-fast-instruct$', 'elyza-7B-instruct'),

    ('^matsuo-lab/weblab-10b$', 'weblab-10B'),
    ('^matsuo-lab/weblab-10b-instruction-sft$', 'weblab-10B-instruct'),
])

# Only runs trained with this learning rate are analysed.
# LRATE = 0.0001
LRATE = 1e-05

# Raw metric column -> short name. Only these metrics are kept in the
# per-task tables (see METRIC_NAMES).
METRIC_RENAMES = OrderedDict([
    ('eval/extr_stps.D-all.proof_accuracy.zero_one', 'prf_acc.extr'),
    # ('eval/strct.D-all.proof_accuracy.zero_one', 'prf_acc.strct'),
    # ('eval/strct.D-all.answer_accuracy', 'ans.acc'),
])
METRIC_NAMES = list(METRIC_RENAMES.values())

# Regex over the task label ('<dataset>__<learning>') -> short name of the form
# '<dataset>.<number of shots>'; a trailing '-' on the dataset marks the
# 'D1_wo_dist' variant. The insertion order is the display order of tasks.
TASK_RENAMES = OrderedDict([
    # ('^20230826.jpn.D3__nan$', 'D3.full'),
    # ('^20230826.jpn.D8__nan$', 'D8.full'),

    # ('^20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-10$', 'D1-.10'),
    # ('^20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-100$', 'D1-.100'),
    # ('^20230904.jpn.D1.wo_brnch.wo_dstrct__LLM_FS.shot-1000$', 'D1-.1000'),
    #
    # ('^20230904.jpn.D1__LLM_FS.shot-10$', 'D1.10'),
    # ('^20230904.jpn.D1__LLM_FS.shot-100$', 'D1.100'),
    # ('^20230904.jpn.D1__LLM_FS.shot-1000$', 'D1.1000'),
    #
    # ('^20230904.jpn.D3__LLM_FS.shot-10$', 'D3.10'),
    # ('^20230904.jpn.D3__LLM_FS.shot-100$', 'D3.100'),
    # ('^20230904.jpn.D3__LLM_FS.shot-1000$', 'D3.1000'),

    ('^20230916.jpn.D1_wo_dist__LLM_FS.shot-10$', 'D1-.10'),
    ('^20230916.jpn.D1_wo_dist__LLM_FS.shot-100$', 'D1-.100'),
    ('^20230916.jpn.D1_wo_dist__LLM_FS.shot-1000$', 'D1-.1000'),
    ('^20230916.jpn.D1_wo_dist__LLM_FS.shot-10000$', 'D1-.10000'),

    ('^20230916.jpn.D1__LLM_FS.shot-10$', 'D1.10'),
    ('^20230916.jpn.D1__LLM_FS.shot-100$', 'D1.100'),
    ('^20230916.jpn.D1__LLM_FS.shot-1000$', 'D1.1000'),
    ('^20230916.jpn.D1__LLM_FS.shot-10000$', 'D1.10000'),

    ('^20230916.jpn.D3__LLM_FS.shot-10$', 'D3.10'),
    ('^20230916.jpn.D3__LLM_FS.shot-100$', 'D3.100'),
    ('^20230916.jpn.D3__LLM_FS.shot-1000$', 'D3.1000'),
    ('^20230916.jpn.D3__LLM_FS.shot-10000$', 'D3.10000'),

    ('^20230916.jpn.D5__LLM_FS.shot-10$', 'D5.10'),
    ('^20230916.jpn.D5__LLM_FS.shot-100$', 'D5.100'),
    ('^20230916.jpn.D5__LLM_FS.shot-1000$', 'D5.1000'),
    ('^20230916.jpn.D5__LLM_FS.shot-10000$', 'D5.10000'),
])

TASK_NAMES = list(TASK_RENAMES.values())
# Task previewed after splitting the results by task (currently 'D1-.100').
MAJOR_TASK = TASK_NAMES[1]

# Arguments for color_by_rank: value scale of the (percentized) metrics and
# the range of shading intensities. The 'PARETTE' spelling is kept because
# later cells reference these names.
COLOR_SCALE_LOWER = 0
COLOR_SCALE_UPPER = 100
COLOR_PARETTE_LOWER = 3
COLOR_PARETTE_UPPER = 70

In [5]:
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230826.jpn/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230901.overfit/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230904.LLM_FS/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230905.LLM_FS/'

# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230910.preliminary'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230916.jpn/'
_TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230919.jpn/'

In [6]:
# Locate the aggregated ``results.tsv`` to analyse. If there are several
# candidates, the user picks one interactively. The choice is validated here,
# so a typo fails immediately rather than as a confusing error in the
# ``read_csv`` cell below.
result_paths = find_result_paths(_TOP_DIR)
if len(result_paths) == 0:
    raise FileNotFoundError(f'Results not found under {_TOP_DIR}')
elif len(result_paths) == 1:
    results_path = result_paths[0]
else:
    print('Choose the result from the following paths:')
    pprint(result_paths)
    results_path = input('path = ').strip()
    if not Path(results_path).is_file():
        raise FileNotFoundError(f'No such results file: {results_path!r}')
    


In [7]:
# Load the tab-separated aggregated results; one row per training run.
master_df = pd.read_csv(filepath_or_buffer=results_path, delimiter='\t')
master_df

Unnamed: 0,dataset_uname,learning,model_name_or_path,seed,learning_rate,base_config_name,generation_max_proof_steps,generation_num_beams,generation_top_k,gradient_accumulation_steps,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,eval/extr_stps.D-all.proof_accuracy.zero_one,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,eval
/strct.D-8.answer_accuracy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
0,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,cyberagent/open-calm-7b,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.380952,0.45098,,,,,,,,0.357143,0.41,0.380952,0.45098,,,,,,,,0.357143,0.41,0.380952,0.568627,,,,,,,,0.357143,0.47,0.380952,0.568627,,,,,,,,0.357143,0.47
1,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.619048,0.686275,,,,,,,,0.428571,0.6,0.619048,0.686275,,,,,,,,0.428571,0.6,0.619048,0.705882,,,,,,,,0.428571,0.61,0.619048,0.705882,,,,,,,,0.428571,0.61
2,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,stabilityai/japanese-stablelm-base-alpha-7b,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.52381,0.568627,,,,,,,,0.321429,0.49,0.52381,0.568627,,,,,,,,0.321429,0.49,0.52381,0.627451,,,,,,,,0.321429,0.52,0.52381,0.627451,,,,,,,,0.321429,0.52
3,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,elyza/ELYZA-japanese-Llama-2-7b-fast-instruct,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.809524,0.784314,,,,,,,,0.357143,0.67,0.809524,0.784314,,,,,,,,0.357143,0.67,0.809524,0.784314,,,,,,,,0.357143,0.67,0.809524,0.784314,,,,,,,,0.357143,0.67
4,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,rinna/japanese-gpt-neox-3.6b-instruction-ppo,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.809524,0.803922,,,,,,,,0.357143,0.68,0.809524,0.803922,,,,,,,,0.357143,0.68,0.809524,0.843137,,,,,,,,0.357143,0.7,0.809524,0.843137,,,,,,,,0.357143,0.7
5,20230916.jpn.D1_wo_dist,LLM_FS.shot-10,cyberagent/open-calm-7b,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.380952,0.45098,,,,,,,,0.178571,0.36,0.380952,0.45098,,,,,,,,0.178571,0.36,0.380952,0.45098,,,,,,,,0.178571,0.36,0.380952,0.45098,,,,,,,,0.178571,0.36
6,20230916.jpn.D1_wo_dist,LLM_FS.shot-10,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.428571,0.431373,,,,,,,,0.178571,0.36,0.428571,0.431373,,,,,,,,0.178571,0.36,0.428571,0.470588,,,,,,,,0.178571,0.38,0.428571,0.470588,,,,,,,,0.178571,0.38
7,20230916.jpn.D1_wo_dist,LLM_FS.shot-10,stabilityai/japanese-stablelm-base-alpha-7b,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.380952,0.529412,,,,,,,,0.107143,0.38,0.380952,0.529412,,,,,,,,0.107143,0.38,0.380952,0.529412,,,,,,,,0.107143,0.38,0.380952,0.529412,,,,,,,,0.107143,0.38
8,20230916.jpn.D1_wo_dist,LLM_FS.shot-10,elyza/ELYZA-japanese-Llama-2-7b-fast-instruct,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.714286,0.627451,,,,,,,,0.071429,0.49,0.714286,0.627451,,,,,,,,0.071429,0.49,0.714286,0.647059,,,,,,,,0.071429,0.5,0.714286,0.647059,,,,,,,,0.071429,0.5
9,20230916.jpn.D1_wo_dist,LLM_FS.shot-10,rinna/japanese-gpt-neox-3.6b-instruction-ppo,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.285714,0.27451,,,,,,,,0.821429,0.43,0.285714,0.27451,,,,,,,,0.821429,0.43,0.285714,0.27451,,,,,,,,0.821429,0.43,0.285714,0.27451,,,,,,,,0.821429,0.43


In [8]:
# Build the per-(method, task) results table:
#   1. label each run's method from METHOD_DEFINE_COLS, shorten it via
#      METHOD_RENAMES and order rows by METHOD_RENAMES;
#   2. label each run's task from TASK_DEFINE_COLS and shorten it via TASK_RENAMES;
#   3. rename the raw metric columns to their short METRIC_RENAMES names;
#   4. keep only runs trained with LRATE;
#   5. combine rows per (method, task) with np.mean for each metric
#      (presumably averaging over seeds; see `aggregate`).
df = name_method(master_df)
df = rename_cells(df, [COL_METHOD], METHOD_RENAMES)
df = sort_rows(df, COL_METHOD, [f'^{name}$' for name in METHOD_RENAMES.values()])

df = name_task(df)
df = rename_cells(df, [COL_TASK], TASK_RENAMES)


df = rename_cols(df, METRIC_RENAMES)

# Learning rates are floats parsed from the TSV, so compare them with a tight
# relative tolerance instead of exact ``==``. atol=0 keeps very small learning
# rates distinguishable from each other.
df = slice_rows(df, lambda row: np.isclose(row[COL_LRATE], LRATE, rtol=1e-9, atol=0.0))
assert len(df) > 0, f'No runs with {COL_LRATE} == {LRATE}; check LRATE against the results file.'
df = aggregate(
    df,
    [COL_METHOD, COL_TASK],
    {metric_name: lambda vals: np.mean(vals) for metric_name in METRIC_NAMES},
)

df

Unnamed: 0,method,task,dataset_uname,learning,model_name_or_path,seed,learning_rate,base_config_name,generation_max_proof_steps,generation_num_beams,generation_top_k,gradient_accumulation_steps,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,prf_acc.extr,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,eval/strct.D-8.answer_ac
curacy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
"(line-4B-instruct, D1-.100)",line-4B-instruct,D1-.100,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.619048,0.686275,,,,,,,,0.428571,0.6,0.619048,0.686275,,,,,,,,0.428571,0.6,0.619048,0.705882,,,,,,,,0.428571,0.61,0.619048,0.705882,,,,,,,,0.428571,0.61
"(line-4B-instruct, D1-.10)",line-4B-instruct,D1-.10,20230916.jpn.D1_wo_dist,LLM_FS.shot-10,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.428571,0.431373,,,,,,,,0.178571,0.36,0.428571,0.431373,,,,,,,,0.178571,0.36,0.428571,0.470588,,,,,,,,0.178571,0.38,0.428571,0.470588,,,,,,,,0.178571,0.38
"(line-4B-instruct, D1-.10000)",line-4B-instruct,D1-.10000,20230916.jpn.D1_wo_dist,LLM_FS.shot-10000,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,300,2000,10000,1,2,all_at_once,False,Solve FLD task:,longest,90,1.0,0.941176,,,,,,,,0.964286,0.96,1.0,0.941176,,,,,,,,0.964286,0.96,1.0,0.980392,,,,,,,,0.964286,0.98,1.0,0.980392,,,,,,,,0.964286,0.98
"(line-4B-instruct, D1-.1000)",line-4B-instruct,D1-.1000,20230916.jpn.D1_wo_dist,LLM_FS.shot-1000,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,300,2000,1000,1,2,all_at_once,False,Solve FLD task:,longest,90,1.0,0.882353,,,,,,,,0.821429,0.89,1.0,0.882353,,,,,,,,0.821429,0.89,1.0,0.901961,,,,,,,,0.821429,0.9,1.0,0.901961,,,,,,,,0.821429,0.9
"(line-4B-instruct, D3.100)",line-4B-instruct,D3.100,20230916.jpn.D3,LLM_FS.shot-100,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.0,0.047619,0.0,0.0,,,,,,0.034483,0.02,0.0,0.047619,0.0,0.0,,,,,,0.034483,0.02,0.0,0.142857,0.095238,0.12,,,,,,0.034483,0.09,0.0,0.142857,0.095238,0.12,,,,,,0.034483,0.09
"(line-4B-instruct, D3.10)",line-4B-instruct,D3.10,20230916.jpn.D3,LLM_FS.shot-10,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,0.0,0.0,0.0,0.0,,,,,,0.0,0.0,0.0,0.0,0.0,0.0,,,,,,0.0,0.0,0.0,0.095238,0.095238,0.12,,,,,,0.0,0.07,0.0,0.095238,0.095238,0.12,,,,,,0.0,0.07
"(line-4B-instruct, D3.10000)",line-4B-instruct,D3.10000,20230916.jpn.D3,LLM_FS.shot-10000,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,300,2000,10000,1,2,all_at_once,False,Solve FLD task:,longest,90,0.5,0.428571,0.0,0.0,,,,,,0.862069,0.36,0.5,0.428571,0.0,0.0,,,,,,0.862069,0.36,0.5,0.666667,0.333333,0.08,,,,,,0.862069,0.5,0.5,0.666667,0.333333,0.08,,,,,,0.862069,0.5
"(line-4B-instruct, D3.1000)",line-4B-instruct,D3.1000,20230916.jpn.D3,LLM_FS.shot-1000,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,300,2000,1000,1,2,all_at_once,False,Solve FLD task:,longest,90,0.0,0.0,0.0,0.0,,,,,,0.068966,0.02,0.0,0.0,0.0,0.0,,,,,,0.068966,0.02,0.0,0.190476,0.190476,0.16,,,,,,0.068966,0.14,0.0,0.190476,0.190476,0.16,,,,,,0.068966,0.14
"(line-4B-instruct, D5.100)",line-4B-instruct,D5.100,20230916.jpn.D5,LLM_FS.shot-100,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,,0.0,0.0,0.0,0.0,0.0,,,,0.096774,0.03,,0.0,0.0,0.0,0.0,0.0,,,,0.096774,0.03,,0.0,0.0,0.0,0.142857,0.117647,,,,0.096774,0.07,,0.0,0.0,0.0,0.142857,0.117647,,,,0.096774,0.07
"(line-4B-instruct, D5.10)",line-4B-instruct,D5.10,20230916.jpn.D5,LLM_FS.shot-10,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,70,2000,10,1,2,all_at_once,False,Solve FLD task:,longest,21,,0.0,0.0,0.0,0.0,0.0,,,,0.129032,0.04,,0.0,0.0,0.0,0.0,0.0,,,,0.129032,0.04,,0.0,0.0,0.071429,0.071429,0.0,,,,0.129032,0.06,,0.0,0.0,0.071429,0.071429,0.0,,,,0.129032,0.06


In [9]:
# Emit each distinct task label as a quoted, comma-terminated line, ready to be
# pasted into a Python list/dict literal (e.g. when extending TASK_RENAMES).
unique_tasks = df[COL_TASK].unique()
for task_label in unique_tasks:
    print(''.join(["'", task_label, "',"]))

'D1-.100',
'D1-.10',
'D1-.10000',
'D1-.1000',
'D3.100',
'D3.10',
'D3.10000',
'D3.1000',
'D5.100',
'D5.10',
'D5.10000',
'D5.1000',
'D1.100',
'D1.10',
'D1.10000',
'D1.1000',


In [10]:
# Split the aggregated table into one frame per task, keyed and ordered by
# TASK_NAMES, then preview the headline task.
task_dfs: Dict[str, pd.DataFrame] = OrderedDict()
for task_name in TASK_NAMES:
    # The current name is bound as a default argument, so the predicate does
    # not read the loop variable at call time.
    def _is_this_task(row, _target=task_name):
        return row[COL_TASK] == _target

    task_dfs[task_name] = slice_rows(df, _is_this_task)

task_dfs[MAJOR_TASK]

Unnamed: 0,method,task,dataset_uname,learning,model_name_or_path,seed,learning_rate,base_config_name,generation_max_proof_steps,generation_num_beams,generation_top_k,gradient_accumulation_steps,lm_type,lora,max_grad_norm,max_predict_samples,max_proof_steps,max_source_length,max_steps,max_target_length,max_train_samples,per_device_eval_batch_size,per_device_train_batch_size,proof_sampling,sample_negative_proof,source_prefix,tokenizer_padding,warmup_steps,eval/extr_stps.D-0.proof_accuracy.zero_one,eval/extr_stps.D-1.proof_accuracy.zero_one,eval/extr_stps.D-2.proof_accuracy.zero_one,eval/extr_stps.D-3.proof_accuracy.zero_one,eval/extr_stps.D-4.proof_accuracy.zero_one,eval/extr_stps.D-5.proof_accuracy.zero_one,eval/extr_stps.D-6.proof_accuracy.zero_one,eval/extr_stps.D-7.proof_accuracy.zero_one,eval/extr_stps.D-8.proof_accuracy.zero_one,eval/extr_stps.D-None.proof_accuracy.zero_one,prf_acc.extr,eval/strct.D-0.proof_accuracy.zero_one,eval/strct.D-1.proof_accuracy.zero_one,eval/strct.D-2.proof_accuracy.zero_one,eval/strct.D-3.proof_accuracy.zero_one,eval/strct.D-4.proof_accuracy.zero_one,eval/strct.D-5.proof_accuracy.zero_one,eval/strct.D-6.proof_accuracy.zero_one,eval/strct.D-7.proof_accuracy.zero_one,eval/strct.D-8.proof_accuracy.zero_one,eval/strct.D-None.proof_accuracy.zero_one,eval/strct.D-all.proof_accuracy.zero_one,eval/extr_stps.D-0.answer_accuracy,eval/extr_stps.D-1.answer_accuracy,eval/extr_stps.D-2.answer_accuracy,eval/extr_stps.D-3.answer_accuracy,eval/extr_stps.D-4.answer_accuracy,eval/extr_stps.D-5.answer_accuracy,eval/extr_stps.D-6.answer_accuracy,eval/extr_stps.D-7.answer_accuracy,eval/extr_stps.D-8.answer_accuracy,eval/extr_stps.D-None.answer_accuracy,eval/extr_stps.D-all.answer_accuracy,eval/strct.D-0.answer_accuracy,eval/strct.D-1.answer_accuracy,eval/strct.D-2.answer_accuracy,eval/strct.D-3.answer_accuracy,eval/strct.D-4.answer_accuracy,eval/strct.D-5.answer_accuracy,eval/strct.D-6.answer_accuracy,eval/strct.D-7.answer_accuracy,eval/strct.D-8.answer_ac
curacy,eval/strct.D-None.answer_accuracy,eval/strct.D-all.answer_accuracy
"(line-4B-instruct, D1-.100)",line-4B-instruct,D1-.100,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,line-corporation/japanese-large-lm-3.6b-instruction-sft,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.619048,0.686275,,,,,,,,0.428571,0.6,0.619048,0.686275,,,,,,,,0.428571,0.6,0.619048,0.705882,,,,,,,,0.428571,0.61,0.619048,0.705882,,,,,,,,0.428571,0.61
"(calm-7B, D1-.100)",calm-7B,D1-.100,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,cyberagent/open-calm-7b,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.380952,0.45098,,,,,,,,0.357143,0.41,0.380952,0.45098,,,,,,,,0.357143,0.41,0.380952,0.568627,,,,,,,,0.357143,0.47,0.380952,0.568627,,,,,,,,0.357143,0.47
"(rinna-4B-instruct, D1-.100)",rinna-4B-instruct,D1-.100,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,rinna/japanese-gpt-neox-3.6b-instruction-ppo,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.809524,0.803922,,,,,,,,0.357143,0.68,0.809524,0.803922,,,,,,,,0.357143,0.68,0.809524,0.843137,,,,,,,,0.357143,0.7,0.809524,0.843137,,,,,,,,0.357143,0.7
"(stablelm-7B, D1-.100)",stablelm-7B,D1-.100,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,stabilityai/japanese-stablelm-base-alpha-7b,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.52381,0.568627,,,,,,,,0.321429,0.49,0.52381,0.568627,,,,,,,,0.321429,0.49,0.52381,0.627451,,,,,,,,0.321429,0.52,0.52381,0.627451,,,,,,,,0.321429,0.52
"(elyza-7B-instruct, D1-.100)",elyza-7B-instruct,D1-.100,20230916.jpn.D1_wo_dist,LLM_FS.shot-100,elyza/ELYZA-japanese-Llama-2-7b-fast-instruct,0,1e-05,FLNLcorpus.20220827.base,20,1,10,4,causal,False,0.5,,30,2000,170,2000,100,1,2,all_at_once,False,Solve FLD task:,longest,51,0.809524,0.784314,,,,,,,,0.357143,0.67,0.809524,0.784314,,,,,,,,0.357143,0.67,0.809524,0.784314,,,,,,,,0.357143,0.67,0.809524,0.784314,,,,,,,,0.357143,0.67


In [11]:
# Reduce each task's frame to the identifying columns plus the selected
# metrics, then show one table per task in TASK_NAMES order.
metric_dfs: Dict[str, pd.DataFrame] = OrderedDict(
    (task_name, slice_cols(task_df, [COL_TASK, COL_METHOD] + METRIC_NAMES))
    for task_name, task_df in task_dfs.items()
)

for task_name in TASK_NAMES:
    print('\n\n\n' + f'========================= {task_name} ========================')
    display(metric_dfs[task_name])






Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1-.10)",D1-.10,line-4B-instruct,0.36
"(calm-7B, D1-.10)",D1-.10,calm-7B,0.36
"(rinna-4B-instruct, D1-.10)",D1-.10,rinna-4B-instruct,0.43
"(stablelm-7B, D1-.10)",D1-.10,stablelm-7B,0.38
"(elyza-7B-instruct, D1-.10)",D1-.10,elyza-7B-instruct,0.49







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1-.100)",D1-.100,line-4B-instruct,0.6
"(calm-7B, D1-.100)",D1-.100,calm-7B,0.41
"(rinna-4B-instruct, D1-.100)",D1-.100,rinna-4B-instruct,0.68
"(stablelm-7B, D1-.100)",D1-.100,stablelm-7B,0.49
"(elyza-7B-instruct, D1-.100)",D1-.100,elyza-7B-instruct,0.67







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1-.1000)",D1-.1000,line-4B-instruct,0.89
"(calm-7B, D1-.1000)",D1-.1000,calm-7B,0.64
"(rinna-4B-instruct, D1-.1000)",D1-.1000,rinna-4B-instruct,0.92
"(stablelm-7B, D1-.1000)",D1-.1000,stablelm-7B,0.97
"(elyza-7B-instruct, D1-.1000)",D1-.1000,elyza-7B-instruct,0.98







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1-.10000)",D1-.10000,line-4B-instruct,0.96
"(calm-7B, D1-.10000)",D1-.10000,calm-7B,0.76
"(rinna-4B-instruct, D1-.10000)",D1-.10000,rinna-4B-instruct,0.97
"(stablelm-7B, D1-.10000)",D1-.10000,stablelm-7B,0.99
"(elyza-7B-instruct, D1-.10000)",D1-.10000,elyza-7B-instruct,1.0







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1.10)",D1.10,line-4B-instruct,0.01
"(calm-7B, D1.10)",D1.10,calm-7B,0.06
"(rinna-4B-instruct, D1.10)",D1.10,rinna-4B-instruct,0.13
"(stablelm-7B, D1.10)",D1.10,stablelm-7B,0.02
"(elyza-7B-instruct, D1.10)",D1.10,elyza-7B-instruct,0.02







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1.100)",D1.100,line-4B-instruct,0.12
"(calm-7B, D1.100)",D1.100,calm-7B,0.1
"(rinna-4B-instruct, D1.100)",D1.100,rinna-4B-instruct,0.08
"(stablelm-7B, D1.100)",D1.100,stablelm-7B,0.11
"(elyza-7B-instruct, D1.100)",D1.100,elyza-7B-instruct,0.11







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1.1000)",D1.1000,line-4B-instruct,0.26
"(calm-7B, D1.1000)",D1.1000,calm-7B,0.12
"(rinna-4B-instruct, D1.1000)",D1.1000,rinna-4B-instruct,0.22
"(stablelm-7B, D1.1000)",D1.1000,stablelm-7B,0.24
"(elyza-7B-instruct, D1.1000)",D1.1000,elyza-7B-instruct,0.8







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D1.10000)",D1.10000,line-4B-instruct,0.52
"(calm-7B, D1.10000)",D1.10000,calm-7B,0.17
"(rinna-4B-instruct, D1.10000)",D1.10000,rinna-4B-instruct,0.51
"(stablelm-7B, D1.10000)",D1.10000,stablelm-7B,0.47
"(elyza-7B-instruct, D1.10000)",D1.10000,elyza-7B-instruct,0.94







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D3.10)",D3.10,line-4B-instruct,0.0
"(calm-7B, D3.10)",D3.10,calm-7B,0.0
"(rinna-4B-instruct, D3.10)",D3.10,rinna-4B-instruct,0.0
"(stablelm-7B, D3.10)",D3.10,stablelm-7B,0.0
"(elyza-7B-instruct, D3.10)",D3.10,elyza-7B-instruct,0.0







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D3.100)",D3.100,line-4B-instruct,0.02
"(calm-7B, D3.100)",D3.100,calm-7B,0.0
"(rinna-4B-instruct, D3.100)",D3.100,rinna-4B-instruct,0.02
"(stablelm-7B, D3.100)",D3.100,stablelm-7B,0.0
"(elyza-7B-instruct, D3.100)",D3.100,elyza-7B-instruct,0.02







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D3.1000)",D3.1000,line-4B-instruct,0.02
"(calm-7B, D3.1000)",D3.1000,calm-7B,0.04
"(rinna-4B-instruct, D3.1000)",D3.1000,rinna-4B-instruct,0.02
"(stablelm-7B, D3.1000)",D3.1000,stablelm-7B,0.04
"(elyza-7B-instruct, D3.1000)",D3.1000,elyza-7B-instruct,0.05







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D3.10000)",D3.10000,line-4B-instruct,0.36
"(calm-7B, D3.10000)",D3.10000,calm-7B,0.31
"(rinna-4B-instruct, D3.10000)",D3.10000,rinna-4B-instruct,0.29
"(stablelm-7B, D3.10000)",D3.10000,stablelm-7B,0.36
"(elyza-7B-instruct, D3.10000)",D3.10000,elyza-7B-instruct,0.19







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D5.10)",D5.10,line-4B-instruct,0.04
"(calm-7B, D5.10)",D5.10,calm-7B,0.01
"(rinna-4B-instruct, D5.10)",D5.10,rinna-4B-instruct,0.0
"(stablelm-7B, D5.10)",D5.10,stablelm-7B,0.02
"(elyza-7B-instruct, D5.10)",D5.10,elyza-7B-instruct,0.1







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D5.100)",D5.100,line-4B-instruct,0.03
"(calm-7B, D5.100)",D5.100,calm-7B,0.05
"(rinna-4B-instruct, D5.100)",D5.100,rinna-4B-instruct,0.02
"(stablelm-7B, D5.100)",D5.100,stablelm-7B,0.01
"(elyza-7B-instruct, D5.100)",D5.100,elyza-7B-instruct,0.02







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D5.1000)",D5.1000,line-4B-instruct,0.02
"(calm-7B, D5.1000)",D5.1000,calm-7B,0.05
"(rinna-4B-instruct, D5.1000)",D5.1000,rinna-4B-instruct,0.01
"(stablelm-7B, D5.1000)",D5.1000,stablelm-7B,0.02
"(elyza-7B-instruct, D5.1000)",D5.1000,elyza-7B-instruct,0.05







Unnamed: 0,task,method,prf_acc.extr
"(line-4B-instruct, D5.10000)",D5.10000,line-4B-instruct,0.26
"(calm-7B, D5.10000)",D5.10000,calm-7B,0.12
"(rinna-4B-instruct, D5.10000)",D5.10000,rinna-4B-instruct,0.1
"(stablelm-7B, D5.10000)",D5.10000,stablelm-7B,0.06
"(elyza-7B-instruct, D5.10000)",D5.10000,elyza-7B-instruct,0.15


In [12]:

# Turn each per-task metric table into a display-ready table: metrics as
# rounded percentages, indexed by method, with cells shaded by color_by_rank
# (LaTeX ``\cellcolor`` markup, as seen in the output below).
pretty_dfs: Dict[str, pd.DataFrame] = OrderedDict()

for task_name, metric_df in metric_dfs.items():
    pretty_df = prettify_df(metric_df)
    # The metric columns were already renamed via METRIC_RENAMES when ``df``
    # was built, and only those renamed columns are kept, so no second
    # rename is needed here.
    pretty_df = pretty_df.set_index(COL_METHOD)

    pretty_df = color_by_rank(pretty_df, 'col',
                              scale_lower=COLOR_SCALE_LOWER,
                              scale_upper=COLOR_SCALE_UPPER,
                              color_lower=COLOR_PARETTE_LOWER,
                              color_upper=COLOR_PARETTE_UPPER)

    pretty_dfs[task_name] = pretty_df


for task_name in TASK_NAMES:
    print('\n\n')
    print(f'========================= {task_name} ========================')
    display(pretty_dfs[task_name])






Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1-.10,\cellcolor{blue!27} 36.0
calm-7B,D1-.10,\cellcolor{blue!27} 36.0
rinna-4B-instruct,D1-.10,\cellcolor{blue!31} 43.0
stablelm-7B,D1-.10,\cellcolor{blue!28} 38.0
elyza-7B-instruct,D1-.10,\cellcolor{blue!35} 49.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1-.100,\cellcolor{blue!43} 60.0
calm-7B,D1-.100,\cellcolor{blue!30} 41.0
rinna-4B-instruct,D1-.100,\cellcolor{blue!48} 68.0
stablelm-7B,D1-.100,\cellcolor{blue!35} 49.0
elyza-7B-instruct,D1-.100,\cellcolor{blue!47} 67.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1-.1000,\cellcolor{blue!62} 89.0
calm-7B,D1-.1000,\cellcolor{blue!45} 64.0
rinna-4B-instruct,D1-.1000,\cellcolor{blue!64} 92.0
stablelm-7B,D1-.1000,\cellcolor{blue!67} 97.0
elyza-7B-instruct,D1-.1000,\cellcolor{blue!68} 98.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1-.10000,\cellcolor{blue!67} 96.0
calm-7B,D1-.10000,\cellcolor{blue!53} 76.0
rinna-4B-instruct,D1-.10000,\cellcolor{blue!67} 97.0
stablelm-7B,D1-.10000,\cellcolor{blue!69} 99.0
elyza-7B-instruct,D1-.10000,\cellcolor{blue!70} 100.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1.10,\cellcolor{blue!3} 1.0
calm-7B,D1.10,\cellcolor{blue!7} 6.0
rinna-4B-instruct,D1.10,\cellcolor{blue!11} 13.0
stablelm-7B,D1.10,\cellcolor{blue!4} 2.0
elyza-7B-instruct,D1.10,\cellcolor{blue!4} 2.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1.100,\cellcolor{blue!11} 12.0
calm-7B,D1.100,\cellcolor{blue!9} 10.0
rinna-4B-instruct,D1.100,\cellcolor{blue!8} 8.0
stablelm-7B,D1.100,\cellcolor{blue!10} 11.0
elyza-7B-instruct,D1.100,\cellcolor{blue!10} 11.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1.1000,\cellcolor{blue!20} 26.0
calm-7B,D1.1000,\cellcolor{blue!11} 12.0
rinna-4B-instruct,D1.1000,\cellcolor{blue!17} 22.0
stablelm-7B,D1.1000,\cellcolor{blue!19} 24.0
elyza-7B-instruct,D1.1000,\cellcolor{blue!56} 80.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D1.10000,\cellcolor{blue!37} 52.0
calm-7B,D1.10000,\cellcolor{blue!14} 17.0
rinna-4B-instruct,D1.10000,\cellcolor{blue!37} 51.0
stablelm-7B,D1.10000,\cellcolor{blue!34} 47.0
elyza-7B-instruct,D1.10000,\cellcolor{blue!65} 94.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D3.10,\cellcolor{blue!3} 0.0
calm-7B,D3.10,\cellcolor{blue!3} 0.0
rinna-4B-instruct,D3.10,\cellcolor{blue!3} 0.0
stablelm-7B,D3.10,\cellcolor{blue!3} 0.0
elyza-7B-instruct,D3.10,\cellcolor{blue!3} 0.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D3.100,\cellcolor{blue!4} 2.0
calm-7B,D3.100,\cellcolor{blue!3} 0.0
rinna-4B-instruct,D3.100,\cellcolor{blue!4} 2.0
stablelm-7B,D3.100,\cellcolor{blue!3} 0.0
elyza-7B-instruct,D3.100,\cellcolor{blue!4} 2.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D3.1000,\cellcolor{blue!4} 2.0
calm-7B,D3.1000,\cellcolor{blue!5} 4.0
rinna-4B-instruct,D3.1000,\cellcolor{blue!4} 2.0
stablelm-7B,D3.1000,\cellcolor{blue!5} 4.0
elyza-7B-instruct,D3.1000,\cellcolor{blue!6} 5.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D3.10000,\cellcolor{blue!27} 36.0
calm-7B,D3.10000,\cellcolor{blue!23} 31.0
rinna-4B-instruct,D3.10000,\cellcolor{blue!22} 29.0
stablelm-7B,D3.10000,\cellcolor{blue!27} 36.0
elyza-7B-instruct,D3.10000,\cellcolor{blue!15} 19.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D5.10,\cellcolor{blue!5} 4.0
calm-7B,D5.10,\cellcolor{blue!3} 1.0
rinna-4B-instruct,D5.10,\cellcolor{blue!3} 0.0
stablelm-7B,D5.10,\cellcolor{blue!4} 2.0
elyza-7B-instruct,D5.10,\cellcolor{blue!9} 10.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D5.100,\cellcolor{blue!5} 3.0
calm-7B,D5.100,\cellcolor{blue!6} 5.0
rinna-4B-instruct,D5.100,\cellcolor{blue!4} 2.0
stablelm-7B,D5.100,\cellcolor{blue!3} 1.0
elyza-7B-instruct,D5.100,\cellcolor{blue!4} 2.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D5.1000,\cellcolor{blue!4} 2.0
calm-7B,D5.1000,\cellcolor{blue!6} 5.0
rinna-4B-instruct,D5.1000,\cellcolor{blue!3} 1.0
stablelm-7B,D5.1000,\cellcolor{blue!4} 2.0
elyza-7B-instruct,D5.1000,\cellcolor{blue!6} 5.0







Unnamed: 0_level_0,task,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1
line-4B-instruct,D5.10000,\cellcolor{blue!20} 26.0
calm-7B,D5.10000,\cellcolor{blue!11} 12.0
rinna-4B-instruct,D5.10000,\cellcolor{blue!9} 10.0
stablelm-7B,D5.10000,\cellcolor{blue!7} 6.0
elyza-7B-instruct,D5.10000,\cellcolor{blue!13} 15.0


In [13]:
def horizontal_concat(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """Place several frames side by side (column-wise concatenation).

    Every input is copied before concatenation, so the result never shares
    data with the caller's frames. Rows are matched by index label, which is
    ``pd.concat``'s default behaviour along the column axis.

    Args:
        dfs: Frames to lay out left to right.

    Returns:
        A new DataFrame holding the columns of all inputs, in input order.
    """
    copies = [frame.copy() for frame in dfs]
    return pd.concat(objs=copies, axis="columns")

# Lay every task's table side by side. Each per-task frame carries its own
# task column, so after concatenation those copies are redundant; dropping the
# label removes all of them at once.
colored_concat_df = horizontal_concat(list(pretty_dfs.values())).drop(columns=[COL_TASK])

# Task names in column order, so the unlabeled column groups can be identified.
print('    '.join(pretty_dfs))
colored_concat_df

D1-.10    D1-.100    D1-.1000    D1-.10000    D1.10    D1.100    D1.1000    D1.10000    D3.10    D3.100    D3.1000    D3.10000    D5.10    D5.100    D5.1000    D5.10000


Unnamed: 0_level_0,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr,prf_acc.extr
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
line-4B-instruct,\cellcolor{blue!27} 36.0,\cellcolor{blue!43} 60.0,\cellcolor{blue!62} 89.0,\cellcolor{blue!67} 96.0,\cellcolor{blue!3} 1.0,\cellcolor{blue!11} 12.0,\cellcolor{blue!20} 26.0,\cellcolor{blue!37} 52.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!27} 36.0,\cellcolor{blue!5} 4.0,\cellcolor{blue!5} 3.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!20} 26.0
calm-7B,\cellcolor{blue!27} 36.0,\cellcolor{blue!30} 41.0,\cellcolor{blue!45} 64.0,\cellcolor{blue!53} 76.0,\cellcolor{blue!7} 6.0,\cellcolor{blue!9} 10.0,\cellcolor{blue!11} 12.0,\cellcolor{blue!14} 17.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!5} 4.0,\cellcolor{blue!23} 31.0,\cellcolor{blue!3} 1.0,\cellcolor{blue!6} 5.0,\cellcolor{blue!6} 5.0,\cellcolor{blue!11} 12.0
rinna-4B-instruct,\cellcolor{blue!31} 43.0,\cellcolor{blue!48} 68.0,\cellcolor{blue!64} 92.0,\cellcolor{blue!67} 97.0,\cellcolor{blue!11} 13.0,\cellcolor{blue!8} 8.0,\cellcolor{blue!17} 22.0,\cellcolor{blue!37} 51.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!22} 29.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!3} 1.0,\cellcolor{blue!9} 10.0
stablelm-7B,\cellcolor{blue!28} 38.0,\cellcolor{blue!35} 49.0,\cellcolor{blue!67} 97.0,\cellcolor{blue!69} 99.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!10} 11.0,\cellcolor{blue!19} 24.0,\cellcolor{blue!34} 47.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!5} 4.0,\cellcolor{blue!27} 36.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!3} 1.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!7} 6.0
elyza-7B-instruct,\cellcolor{blue!35} 49.0,\cellcolor{blue!47} 67.0,\cellcolor{blue!68} 98.0,\cellcolor{blue!70} 100.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!10} 11.0,\cellcolor{blue!56} 80.0,\cellcolor{blue!65} 94.0,\cellcolor{blue!3} 0.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!6} 5.0,\cellcolor{blue!15} 19.0,\cellcolor{blue!9} 10.0,\cellcolor{blue!4} 2.0,\cellcolor{blue!6} 5.0,\cellcolor{blue!13} 15.0


In [14]:
def add_gpt_row(
    latex_str: str,
    row_name: str = 'GPT-4',
    placeholder: str = '-',
    index_name: Optional[str] = None,
) -> str:
    """Insert a placeholder row directly after the index-name header line.

    ``to_latex`` with an index emits a line that starts with the index name
    (``COL_METHOD``) followed by empty cells. A copy of that line is appended
    right after it, with the index name replaced by ``row_name`` and every
    cell replaced by ``placeholder``.

    Args:
        latex_str: LaTeX tabular source, one table row per line.
        row_name: Label of the inserted row.
        placeholder: Text written into every cell of the inserted row.
        index_name: Index header to look for; defaults to ``COL_METHOD``.

    Returns:
        The LaTeX source with the extra row inserted; unchanged if no line
        matches the header.
    """
    if index_name is None:
        index_name = COL_METHOD
    # The name must be followed by a cell separator, the row terminator, or the
    # end of the line, so a data row such as "methodology & ..." is not
    # mistaken for the header. re.escape guards against regex metacharacters.
    header_re = re.compile(rf'^ *{re.escape(index_name)} *(&|\\\\|$)')
    # One cell = an "&" plus everything up to the next "&" (or end of line);
    # this also swallows the trailing "\\" of the header line, re-added below.
    cell_re = re.compile(r'&[^&]*')
    new_cell = f'& {placeholder} '

    lines: List[str] = []
    for line in latex_str.split('\n'):
        lines.append(line)
        if header_re.match(line):
            # NOTE(review): the header line sits above \midrule, so the new row
            # renders as part of the table header - confirm that is intended.
            # A function replacement keeps backslashes in `placeholder` literal.
            placeholder_line = cell_re.sub(lambda _m: new_cell, line)
            # Only the leading label is the index name; replace that one.
            placeholder_line = placeholder_line.replace(index_name, row_name, 1)
            lines.append(placeholder_line + '  \\\\')
    return '\n'.join(lines)
    
# Render the concatenated table (index column included) and print it with the
# GPT-4 placeholder row spliced in.
latex_str = to_latex(colored_concat_df, with_index=True)
print(add_gpt_row(latex_str=latex_str))

\begin{tabular}{lllllllllllllllll}
\toprule
 & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr & prf\_acc.extr \\
method &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
GPT-4 & - & - & - & - & - & - & - & - & - & - & - & - & - & - & - & -   \\
\midrule
line-4B-instruct & \cellcolor{blue!27} 36.0 & \cellcolor{blue!43} 60.0 & \cellcolor{blue!62} 89.0 & \cellcolor{blue!67} 96.0 & \cellcolor{blue!3} 1.0 & \cellcolor{blue!11} 12.0 & \cellcolor{blue!20} 26.0 & \cellcolor{blue!37} 52.0 & \cellcolor{blue!3} 0.0 & \cellcolor{blue!4} 2.0 & \cellcolor{blue!4} 2.0 & \cellcolor{blue!27} 36.0 & \cellcolor{blue!5} 4.0 & \cellcolor{blue!5} 3.0 & \cellcolor{blue!4} 2.0 & \cellcolor{blue!20} 26.0 \\
calm-7B & \cellcolor{blue!27} 36.0 & \cellcolor{blue!30} 41.0 & \cellcolor{blue!45} 64.0 & \cellcolor{blue!53} 76.0 

In [15]:
def task_name_to_shot(task_name: str) -> str:
    """Return the shot-count suffix of a task name: the text after its last '.'.

    For example ``'D1-.100'`` gives ``'100'``; a name with no '.' is returned
    unchanged.
    """
    _, _, shot = task_name.rpartition('.')
    return shot

def task_name_to_dataset_name(task_name: str) -> str:
    """Return the dataset part of a task name: everything before its last '.'.

    For example ``'D1-.100'`` gives ``'D1-'``; a name with no '.' gives ``''``.
    """
    dataset, _, _ = task_name.rpartition('.')
    return dataset

# Extra LaTeX header rows for the concatenated table: a dataset row plus its
# \cmidrule underlines, and a shot-count row plus its underlines. LaTeX column 1
# is the method index, so data columns start at column 2.
# Raw strings keep LaTeX backslashes literal: plain '\m', '\c' and '\_' are
# invalid Python escape sequences and warn on Python 3.12+.
cmidrule_prefix = r'\cmidrule(l{\tabcolsep}r{\tabcolsep})'

num_metrics = len(METRIC_RENAMES)
num_task = len(pretty_dfs)
task_names = list(pretty_dfs)

# Unique dataset names, in first-seen order.
dataset_names = list(dict.fromkeys(task_name_to_dataset_name(name) for name in task_names))
num_datasets = len(dataset_names)
# Shot settings per dataset; exact only when every dataset has the same count.
# The header rows below no longer depend on it; kept in case later cells use it.
num_shot_settings = num_task // num_datasets if num_datasets else 0


def _dataset_header(task_names: List[str], num_metrics: int) -> Tuple[str, str]:
    """Build the dataset multicolumn row and its cmidrule underline.

    Spans are sized per run of consecutive tasks that share a dataset, so the
    header stays aligned with the table's columns (which follow ``task_names``
    order) even when datasets have different numbers of shot settings.

    Returns:
        ``(dataset_row, dataset_underline)``, not yet underscore-escaped.
    """
    from itertools import groupby

    cells: List[str] = []
    rules: List[str] = []
    first_col = 2  # column 1 holds the method index
    for dataset_name, run in groupby(task_names, key=task_name_to_dataset_name):
        span = len(list(run)) * num_metrics
        cells.append(rf'\multicolumn{{{span}}}{{c}}{{{dataset_name}}}')
        rules.append(f'{cmidrule_prefix}{{{first_col}-{first_col + span - 1}}}')
        first_col += span
    return ('{}    &    ' + '  &  '.join(cells) + '    \\\\', '    '.join(rules))


dataset_row, dataset_underline = _dataset_header(task_names, num_metrics)

# One column group per task; only the first label carries the "$n$=" prefix.
shot_row = '{}    &    ' + '  &  '.join(
    rf'\multicolumn{{{num_metrics}}}{{c}}{{{"$n$=" if i == 0 else ""}{task_name_to_shot(name)}}}'
    for i, name in enumerate(task_names)
) + '    \\\\'
shot_underline = '    '.join(
    f'{cmidrule_prefix}{{{2 + num_metrics * i_col}-{1 + num_metrics * (i_col + 1)}}}'
    for i_col in range(num_task)
)

# Escape underscores for LaTeX and print the rows separated by blank lines.
print('\n\n'.join(
    row.replace('_', r'\_')
    for row in (dataset_row, dataset_underline, shot_row, shot_underline)
))


{}    &    \multicolumn{4}{c}{D1-}  &  \multicolumn{4}{c}{D1}  &  \multicolumn{4}{c}{D3}  &  \multicolumn{4}{c}{D5}    \\

\cmidrule(l{\tabcolsep}r{\tabcolsep}){2-5}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){6-9}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){10-13}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){14-17}

{}    &    \multicolumn{1}{c}{$n$=10}  &  \multicolumn{1}{c}{100}  &  \multicolumn{1}{c}{1000}  &  \multicolumn{1}{c}{10000}  &  \multicolumn{1}{c}{10}  &  \multicolumn{1}{c}{100}  &  \multicolumn{1}{c}{1000}  &  \multicolumn{1}{c}{10000}  &  \multicolumn{1}{c}{10}  &  \multicolumn{1}{c}{100}  &  \multicolumn{1}{c}{1000}  &  \multicolumn{1}{c}{10000}  &  \multicolumn{1}{c}{10}  &  \multicolumn{1}{c}{100}  &  \multicolumn{1}{c}{1000}  &  \multicolumn{1}{c}{10000}    \\

\cmidrule(l{\tabcolsep}r{\tabcolsep}){2-2}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){3-3}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){4-4}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){5-5}    \cmidrule(l{\tabcolsep}r{\tabcols