In [1]:
from pprint import pprint
from collections import defaultdict
import json
import re
import sys
from typing import Union, List, Optional, Set, Tuple, Dict, Optional, Callable
from pprint import pprint
from IPython.display import display, HTML
from collections import OrderedDict

import numpy as np
from lab.utils import shorten
import pandas as pd
from pathlib import Path
from machine_learning.analysis.dataframe import (
    pivot_rotate,

    slice_rows,
    slice_cols,
    sort_rows,
    sort_cols,
    aggregate,
    percentize,
    round,

    rename_index,
    rename_cols,
    rename_cells,

    isnan,
    to_latex,
    color_by_rank,
)
from machine_learning.analysis.series import (
    maybe_numeric_series,
)
from machine_learning.analysis.utils import (
    maybe_round,
)
# Widen the pandas display limits so that wide result tables are shown in full.
for _display_option, _display_limit in (
    ('display.max_columns', 300),
    ('display.max_rows', 300),
    ('display.max_colwidth', 1000),
):
    pd.set_option(_display_option, _display_limit)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def find_result_paths(top_dir: Union[str, Path], regexps: Optional[List[str]] = None) -> List[str]:
    """Recursively collect every ``results.tsv`` located under ``top_dir``.

    Args:
        top_dir: Directory to search; all nesting levels are scanned.
        regexps: Optional regular expressions. A path is kept only when
            *every* expression matches it via ``re.match``, i.e. anchored at
            the start of the path string, so matching a fragment in the
            middle of the path needs a leading ``.*``. ``None`` or an empty
            list keeps every path.

    Returns:
        The matching paths as strings, sorted so that the result does not
        depend on the order in which the filesystem enumerates entries.
    """
    patterns = [re.compile(regexp) for regexp in (regexps or [])]
    candidates = (str(path) for path in Path(top_dir).glob('**/results.tsv'))
    return sorted(
        candidate for candidate in candidates
        if all(pattern.match(candidate) for pattern in patterns)
    )

In [3]:
def name_method(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with the ``COL_METHOD`` column filled in.

    For every row, the ``str()`` of each ``METHOD_DEFINE_COLS`` cell is joined
    with ``'__'`` and stored under ``COL_METHOD``. The input frame itself is
    not modified.
    """
    def _method_id(record) -> str:
        parts = [str(record[column]) for column in METHOD_DEFINE_COLS]
        return '__'.join(parts)

    named = df.copy()
    named[COL_METHOD] = named.apply(_method_id, axis=1)
    return named

def name_task(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with the ``COL_TASK`` column filled in.

    For every row, the ``str()`` of each ``TASK_DEFINE_COLS`` cell is joined
    with ``'__'`` and stored under ``COL_TASK``. The input frame itself is
    not modified.
    """
    def _task_id(record) -> str:
        parts = [str(record[column]) for column in TASK_DEFINE_COLS]
        return '__'.join(parts)

    named = df.copy()
    named[COL_TASK] = named.apply(_task_id, axis=1)
    return named

def prettify_df(df: pd.DataFrame) -> pd.DataFrame:
    """Run the project's ``percentize`` helper and then its ``round`` helper on ``df``.

    NOTE(review): ``round`` here is the function imported from
    ``machine_learning.analysis.dataframe``, not the Python builtin.
    """
    return round(percentize(df))

In [4]:
# --------------- LREC_2024_submission ---------------
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230826.jpn/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230901.overfit/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230904.LLM_FS/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230905.LLM_FS/'

# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230910.preliminary'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230916.jpn/'

# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230916.jpn.FT/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230919.jpn/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230919.jpn.seed--1/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230919.jpn.seed--0-1/'   # LREC_2024_submission!
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20230919.jpn.seed--1/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231005.jpn.seed--0'

# --------------- NLP_2024 ---------------
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231203.jpn/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231203.jpn.no_subproof_for_unknown20231203.jpn/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231206.new_models/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231213.jpn'
# Active results directory: find_result_paths() scans it recursively for results.tsv files.
_TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231213.jpn.seed--1/'

# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231213.jpn.seed--0-1/'
# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231223.seed--1.timeout_fix'

# _TOP_DIR = '../outputs/02.aggregate_tf_results.py/20231226.jpn.epoch--10'

In [5]:
# Selects which configuration branch is loaded by the configuration cell.
# Must be 'NLP_2024' or 'LREC_2024_submission'; any other value raises ValueError there.
# version = 'LREC_2024_submission'
version = 'NLP_2024'

In [6]:
# Version-dependent configuration.
#
# Each branch defines:
#   COL_DATASET        name of the results column holding the dataset identifier
#   METHOD_RENAMES     ordered (regex -> display name) map for model identifiers;
#                      the order of its values is also used to order the rows
#   LRATE              learning rate kept when DO_SLICE_BY_LRATE is True
#   DO_SLICE_BY_LRATE  whether rows are restricted to LRATE
#   NG_TASK_REGEXPS    renamed task names matching any of these are excluded
#   METRIC_RENAMES     ordered (raw metric column -> short name) map
if version == 'NLP_2024':
    COL_DATASET = 'FLD_dataset_uname'

    METHOD_RENAMES = OrderedDict([
        ('^retrieva-jp/t5-base-long$', 'retrieva-t5-base'),
        ('^retrieva-jp/t5-xl$', 'retrieva-t5-xl'),

        # ('^line-corporation/japanese-large-lm-1.7b$', 'line-1B'),
        # ('^line-corporation/japanese-large-lm-1.7b-instruction-sft$', 'line-1B-instruct'),
        ('^line-corporation/japanese-large-lm-3.6b$', 'line-4B'),
        # ('^line-corporation/japanese-large-lm-3.6b-instruction-sft$', 'line-4B-instruct'),

        ('^rinna/japanese-gpt-neox-3.6b$', 'rinna-4B'),
        # ('^rinna/japanese-gpt-neox-3.6b-instruction-ppo$', 'rinna-4B-instruct'),

        # ('^cyberagent/open-calm-medium$', 'calm-0.4B'),
        # ('^cyberagent/open-calm-1b$', 'calm-1B'),
        # ('^cyberagent/open-calm-3b$', 'calm-3B'),
        ('^cyberagent/open-calm-7b$', 'calm-7B'),
        ('^cyberagent/calm2-7b$', 'calm2-7B'),
        # ('^cyberagent/calm2-7b-chat$', 'calm2-7B-instruct'),   # the training fails somehow.

        ('^stabilityai/japanese-stablelm-base-alpha-7b$', 'stablelm-7B'),
        # ('^stabilityai/japanese-stablelm-instruct-alpha-7b-v2$', 'stablelm-7B-instruct'),

        ('^elyza/ELYZA-japanese-Llama-2-7b-fast$', 'elyza-7B'),
        # ('^elyza/ELYZA-japanese-Llama-2-7b-fast-instruct$', 'elyza-7B-instruct'),

        ('^matsuo-lab/weblab-10b$', 'weblab-10B'),
        # ('^matsuo-lab/weblab-10b-instruction-sft$', 'weblab-10B-instruct'),

        # The 13b ELYZA model gets its own label. It used to be listed twice as
        # 'elyza-7B', which made it indistinguishable from the 7b model once
        # results are grouped by method name.
        ('^elyza/ELYZA-japanese-Llama-2-13b-fast$', 'elyza-13B'),
        # ('^elyza/ELYZA-japanese-Llama-2-13b-fast-instruct$', 'elyza-13B-instruct'),

        ('^stockmark/stockmark-13b$', 'stockmark-13B'),
        ('^pfnet/plamo-13b$', 'plamo-13B'),
        ('^llm-jp/llm-jp-13b-v1.0$', 'llmjp-13B'),
        # ('^llm-jp/llm-jp-13b-instruct-full-jaster-v1.0$', 'llmjp-13B-instruct'),

        ('^tokyotech-llm/Swallow-13b-hf$', 'swallow-13b'),
        # ('^tokyotech-llm/Swallow-13b-instruct-hf$', 'swallow-13b-instruct'),
    ])

    LRATE = 1e-05
    DO_SLICE_BY_LRATE = True

    NG_TASK_REGEXPS = [
        r'.*\.10$',  # exclude shot-10
    ]

    METRIC_RENAMES = OrderedDict([
        # ('train/FLD_proof_eval_strct.D-all.proof_accuracy.zero_one', 'prf.strct'),
        ('train/FLD_proof_eval_strct.D-all.answer_accuracy', 'ans'),
    ])


elif version == 'LREC_2024_submission':
    COL_DATASET = 'dataset_uname'

    METHOD_RENAMES = OrderedDict([
        ('^retrieva-jp/t5-base-long$', 'retrieva-t5-base'),
        ('^retrieva-jp/t5-xl$', 'retrieva-t5-xl'),

        # ('^line-corporation/japanese-large-lm-1.7b$', 'line-1B'),
        # ('^line-corporation/japanese-large-lm-1.7b-instruction-sft$', 'line-1B-instruct'),
        ('^line-corporation/japanese-large-lm-3.6b$', 'line-4B'),
        ('^line-corporation/japanese-large-lm-3.6b-instruction-sft$', 'line-4B-instruct'),

        ('^rinna/japanese-gpt-neox-3.6b$', 'rinna-4B'),
        ('^rinna/japanese-gpt-neox-3.6b-instruction-ppo$', 'rinna-4B-instruct'),

        # ('^cyberagent/open-calm-medium$', 'calm-0.4B'),
        # ('^cyberagent/open-calm-1b$', 'calm-1B'),
        # ('^cyberagent/open-calm-3b$', 'calm-3B'),
        ('^cyberagent/open-calm-7b$', 'calm-7B'),

        ('^stabilityai/japanese-stablelm-base-alpha-7b$', 'stablelm-7B'),

        ('^elyza/ELYZA-japanese-Llama-2-7b-fast$', 'elyza-7B'),
        ('^elyza/ELYZA-japanese-Llama-2-7b-fast-instruct$', 'elyza-7B-instruct'),

        ('^matsuo-lab/weblab-10b$', 'weblab-10B'),
        ('^matsuo-lab/weblab-10b-instruction-sft$', 'weblab-10B-instruct'),

        ('^pfnet/plamo-13b$', 'plamo-13B'),
    ])

    LRATE = 1e-05
    DO_SLICE_BY_LRATE = False

    NG_TASK_REGEXPS = []

    METRIC_RENAMES = OrderedDict([
        # ('eval/extr_stps.D-all.proof_accuracy.zero_one', 'prf.extr'),
        ('eval/strct.D-all.proof_accuracy.zero_one', 'prf.strct'),
        # ('eval/strct.D-all.answer_accuracy', 'ans'),
    ])


else:
    raise ValueError(
        f"Unknown version {version!r}: expected 'NLP_2024' or 'LREC_2024_submission'"
    )

# Names of result-table columns used by the cells of this notebook.
COL_LEARNING = 'learning'
COL_MODEL_NAME_OR_PATH = 'model_name_or_path'
# Column compared against LRATE when DO_SLICE_BY_LRATE is set.
COL_LRATE = 'learning_rate'

# Columns whose stringified values name_task() joins with '__' into COL_TASK.
TASK_DEFINE_COLS = [COL_DATASET, COL_LEARNING]
COL_TASK = 'task'

# Columns whose stringified values name_method() joins with '__' into COL_METHOD.
METHOD_DEFINE_COLS = [COL_MODEL_NAME_OR_PATH]
COL_METHOD = 'method'

# Short metric names, i.e. the rename targets listed in METRIC_RENAMES.
METRIC_NAMES = list(METRIC_RENAMES.values())

# Passed to color_by_rank() as scale_lower / scale_upper / color_lower / color_upper.
# NOTE(review): "PARETTE" is presumably a misspelling of "PALETTE"; the names are
# kept unchanged because another cell references them.
COLOR_SCALE_LOWER = 0
COLOR_SCALE_UPPER = 100
COLOR_PARETTE_LOWER = 3
COLOR_PARETTE_UPPER = 70

# NOTE(review): not referenced in the visible part of the notebook — confirm it is still needed.
DARK = True

In [7]:
# Locate the results file under _TOP_DIR. With exactly one candidate it is used
# directly; with several, the user is asked to type the path to load.
result_paths = find_result_paths(_TOP_DIR)
if len(result_paths) == 0:
    raise FileNotFoundError(f'Results not found under {_TOP_DIR}')
elif len(result_paths) == 1:
    results_path = result_paths[0]
else:
    print('Choose the result from the following paths:')
    pprint(result_paths)
    # Surrounding whitespace and quotes are typical artifacts of copy-pasting
    # one of the entries printed above, so they are removed before use.
    results_path = input('path = ').strip().strip('\'"')
    # Fail here with a clear message rather than later inside the loader.
    if not Path(results_path).is_file():
        raise FileNotFoundError(f'No such results file: {results_path!r}')

In [8]:
# Read the selected tab-separated results file into the master frame.
print(f'loading results from {results_path}')
master_df = pd.read_csv(
    results_path,
    sep='\t',
)
# master_df

loading results from ../outputs/02.aggregate_tf_results.py/20231213.jpn.seed--1/results.tsv


In [9]:
# Attach the method identifier column; sort_cols is a project helper that
# presumably moves COL_METHOD in front of the remaining columns ('.*').
df = name_method(master_df)
df = sort_cols(df, [COL_METHOD, '.*'])

# Replace raw model identifiers by their display names, then order the rows
# following the order in which the display names appear in METHOD_RENAMES.
df = rename_cells(df, [COL_METHOD], METHOD_RENAMES)
df = sort_rows(df, COL_METHOD, [f'^{name}$' for name in METHOD_RENAMES.values()])


if DO_SLICE_BY_LRATE:
    # Keep only the runs trained with the configured learning rate.
    df = slice_rows(df, lambda row: row[COL_LRATE] == LRATE)
    if len(df) == 0:
        raise ValueError(f'No result row has {COL_LRATE} == {LRATE}')


def rename_task(task_name: str):
    """Shorten a raw task identifier into the label used in the tables.

    Three regex substitutions are applied in order:
    ``_wo_dist`` becomes ``-``; everything up to and including the last
    ``jpn`` plus the single character following it is removed; and
    ``__LLM_FS.shot-`` is collapsed to ``.``. The dots in the patterns are
    regex wildcards, not literal dots.
    """
    substitutions = (
        ('_wo_dist', '-'),
        ('.*jpn.', ''),
        ('__LLM_FS.shot-', '.'),
    )
    short_name = task_name
    for pattern, replacement in substitutions:
        short_name = re.sub(pattern, replacement, short_name)
    return short_name

def _task_sort_key(name_rename: Tuple[str, str]) -> Tuple[str, int]:
    """Sort key for ``(pattern, renamed_task)`` pairs.

    Renamed tasks ending in ``.<digits>`` are ordered by the text before the
    number and then numerically by the number, so that e.g. 5 < 20 < 100.
    Any other name is ordered by its full text.
    """
    renamed = name_rename[1]
    head, dot, tail = renamed.rpartition('.')
    if dot and tail.isdecimal():
        # The dot stays in the prefix so that 'D1-.' keeps sorting before 'D1.'.
        return (head + dot, int(tail))
    return (renamed, 0)


def _mean_ignoring_nan(vals):
    """Average the values that the project's ``isnan`` helper does not flag."""
    return np.mean([val for val in vals if not isnan(val)])


# Attach the task identifier column.
df = name_task(df)
df = sort_cols(df, [COL_TASK, '.*'])

# Map each raw task identifier to its short display name.
task_renames = sorted(
    ((f'^{task_name}$', rename_task(task_name)) for task_name in df[COL_TASK].unique()),
    key=_task_sort_key,
)
task_rename_dict = OrderedDict(task_renames)
df = rename_cells(df, [COL_TASK], task_rename_dict)

# Task names in display order: first the sorted renamed names that really occur
# in the frame, then any remaining value of the task column in order of appearance.
present_tasks = list(df[COL_TASK].unique())
TASK_NAMES = []
for task_name in [*task_rename_dict.values(), *present_tasks]:
    if task_name in present_tasks and task_name not in TASK_NAMES:
        TASK_NAMES.append(task_name)
TASK_NAMES = [task_name for task_name in TASK_NAMES
              if not any(re.match(ng_task_regexp, task_name) for ng_task_regexp in NG_TASK_REGEXPS)]
if not TASK_NAMES:
    raise ValueError('No task is left after applying NG_TASK_REGEXPS')
MAJOR_TASK = TASK_NAMES[0]

df = rename_cols(df, METRIC_RENAMES)

# Collapse the rows sharing (method, task) by averaging each metric.
df = aggregate(
    df,
    [COL_METHOD, COL_TASK],
    {metric_name: _mean_ignoring_nan for metric_name in METRIC_NAMES},
)

TASK_NAMES

['D1-.5',
 'D1-.100',
 'D1-.1000',
 'D1-.10000',
 'D1-.30000',
 'D1.5',
 'D1.100',
 'D1.1000',
 'D1.10000',
 'D1.30000',
 'D3.5',
 'D3.100',
 'D3.1000',
 'D3.10000',
 'D3.30000',
 'D8.5',
 'D8.100',
 'D8.1000',
 'D8.10000',
 'D8.30000']

In [10]:
# One frame per task, keyed by task name in TASK_NAMES order; each frame is
# what slice_rows returns for the "task column equals this name" predicate.
task_dfs: Dict[str, pd.DataFrame] = OrderedDict(
    (task_name, slice_rows(df, lambda row: row[COL_TASK] == task_name))
    for task_name in TASK_NAMES
)

# task_dfs[MAJOR_TASK]

In [11]:
# Per-task frames restricted (via slice_cols) to the task, method and metric columns.
metric_dfs: Dict[str, pd.DataFrame] = OrderedDict(
    (task_name, slice_cols(task_df, [COL_TASK, COL_METHOD] + METRIC_NAMES))
    for task_name, task_df in task_dfs.items()
)

# for task_name in TASK_NAMES:
#     print('\n\n')
#     print(f'========================= {task_name} ========================')
#     display(metric_dfs[task_name])

In [12]:
def _build_pretty_df(metric_df: pd.DataFrame) -> pd.DataFrame:
    """Turn one per-task metric frame into its display form.

    The frame goes through ``prettify_df`` and ``rename_cols``, gets the
    method column as its index (the column itself is dropped), and is finally
    passed to the project's ``color_by_rank`` helper with the notebook-wide
    colour settings.
    """
    table = rename_cols(prettify_df(metric_df), METRIC_RENAMES)

    table.index = table[COL_METHOD]
    table = table.drop(columns=[COL_METHOD])

    return color_by_rank(
        table,
        'col',
        scale_lower=COLOR_SCALE_LOWER,
        scale_upper=COLOR_SCALE_UPPER,
        color_lower=COLOR_PARETTE_LOWER,
        color_upper=COLOR_PARETTE_UPPER,
    )


# Display-ready frame for every task, in the same order as metric_dfs.
pretty_dfs: Dict[str, pd.DataFrame] = OrderedDict(
    (task_name, _build_pretty_df(metric_df))
    for task_name, metric_df in metric_dfs.items()
)

# for task_name in TASK_NAMES:
#     print('\n\n')
#     print(f'========================= {task_name} ========================')
#     display(pretty_dfs[task_name])

In [13]:
def horizontal_concat(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """Place the given frames side by side and return the combined frame.

    Each input is copied before concatenation. The indexes are used as they
    are (no reset), so ``pd.concat`` lines the rows up by index label.
    """
    copies = []
    for frame in dfs:
        copies.append(frame.copy())
    return pd.concat(copies, axis=1)

# Put the per-task tables side by side (task order preserved) ...
colored_concat_df = horizontal_concat(list(pretty_dfs.values()))
# ... and remove the task label column contributed by each of them.
colored_concat_df = colored_concat_df.drop(columns=[COL_TASK])

# print('    '.join([task_name for task_name in pretty_dfs.keys()]))
# colored_concat_df
# colored_concat_df


# The LaTeX outputs

## The upper part of the table

## The values of the table

In [14]:
def add_gpt_row(latex_str: str) -> str:
    """Insert a placeholder ``GPT-4`` row after the row that starts with ``COL_METHOD``.

    Every line of ``latex_str`` is kept. After each line that begins
    (optionally preceded by spaces) with ``COL_METHOD``, a copy of that line
    is added in which every ``&``-introduced cell is replaced by ``& - `` and
    ``COL_METHOD`` is replaced by ``GPT-4``.

    Args:
        latex_str: LaTeX table source, one table row per line.

    Returns:
        The table source with the placeholder row(s) inserted.
    """
    # COL_METHOD is escaped so that it is always matched literally.
    header_pattern = re.compile(rf'^ *{re.escape(COL_METHOD)}')

    lines: List[str] = []
    for line in latex_str.split('\n'):
        lines.append(line)
        if header_pattern.match(line):
            placeholder = re.sub(r'&[^&]*', '& - ', line).replace(COL_METHOD, 'GPT-4')
            # The substitution also swallows the row terminator that follows
            # the last cell, so it is appended again.
            lines.append(placeholder + '  \\\\')
    return '\n'.join(lines)
    
# Render the concatenated table with the project's to_latex helper (index included),
# then print it with the GPT-4 placeholder row inserted by add_gpt_row().
latex_str = to_latex(colored_concat_df, with_index=True)
print(add_gpt_row(latex_str))

\begin{tabular}{lllllllllllllllllllll}
\toprule
 & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans & ans \\
method &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
GPT-4 & - & - & - & - & - & - & - & - & - & - & - & - & - & - & - & - & - & - & - & -   \\
\midrule
line-4B & \cellcolor{blue!25} 33.9 & \cellcolor{blue!45} 63.8 & \cellcolor{blue!63} 90.7 & \cellcolor{blue!68} 97.3 & \cellcolor{blue!69} 99.7 & \cellcolor{blue!30} 41.2 & \cellcolor{blue!30} 40.9 & \cellcolor{blue!41} 57.8 & \cellcolor{blue!50} 70.4 & \cellcolor{blue!62} 88.4 & \cellcolor{blue!25} 33.9 & \cellcolor{blue!28} 37.8 & \cellcolor{blue!33} 44.9 & \cellcolor{blue!29} 38.9 & \cellcolor{blue!38} 52.5 & \cellcolor{blue!22} 28.5 & \cellcolor{blue!28} 37.5 & \cellcolor{blue!31} 43.2 & \cellcolor{blue!24} 32.2 & \cellcolor{blue!26} 35.5 \\
rinna-4B & \cellcolor{blue!25} 34.2 & \cellcolor{blue!40} 56.5 & \cellcolor{blue!66} 94.4 & \cellcolor{blue

In [15]:
def task_name_to_shot(task_name: str) -> str:
    """Return the shot-count part of a task name with thousands separators.

    The shot count is the text after the last ``.`` of ``task_name`` (the
    whole name when it contains no dot). Its last run of digits is rewritten
    with ``,`` as thousands separator, e.g. ``'D1.30000'`` -> ``'30,000'``;
    any text around that run is kept. A run with a leading zero is left as it
    is so that no digit is lost.
    """
    def _group_thousands(found) -> str:
        digits = found.group()
        if digits.startswith('0'):
            return digits
        return f'{int(digits):,}'

    shot = task_name.split('.')[-1]
    # Only the digit run that is followed by no further digit is reformatted.
    return re.sub(r'[0-9]+(?=[^0-9]*$)', _group_thousands, shot)

def task_name_to_dataset_name(task_name: str) -> str:
    """Return everything before the last ``.`` of ``task_name``.

    A name without any dot yields the empty string.
    """
    dataset_name, _, _ = task_name.rpartition('.')
    return dataset_name

def _multicolumn(width: int, content: str) -> str:
    """Return a centred LaTeX ``multicolumn`` cell spanning ``width`` columns."""
    return f'\\multicolumn{{{width}}}{{c}}{{{content}}}'


def _cmidrules(num_groups: int, width: int) -> str:
    """Return one ``cmidrule`` per consecutive group of ``width`` columns, starting at column 2."""
    rules = []
    for i_group in range(num_groups):
        # Column 1 is the (empty) row-label cell that starts each header row.
        first_col = 2 + width * i_group
        last_col = first_col + width - 1
        rules.append(
            '\\cmidrule(l{\\tabcolsep}r{\\tabcolsep})' + f'{{{first_col}-{last_col}}}'
        )
    return '    '.join(rules)


num_metrics = len(METRIC_RENAMES)
num_task = len(pretty_dfs)

# Dataset names in order of first appearance among the tasks.
dataset_names = []
for task_name in pretty_dfs:
    dataset_name = task_name_to_dataset_name(task_name)
    if dataset_name not in dataset_names:
        dataset_names.append(dataset_name)
num_datasets = len(dataset_names)
if num_datasets == 0:
    raise ValueError('pretty_dfs is empty: there is no task to build a table header for')
# Assumes every dataset has the same number of shot settings.
num_shot_settings = num_task // num_datasets

# First header row: one cell per dataset spanning all of its shot/metric columns.
dataset_row = (
    '{}    &    '
    + '  &  '.join(_multicolumn(num_shot_settings * num_metrics, dataset_name)
                   for dataset_name in dataset_names)
    + '    \\\\'
)
dataset_underline = _cmidrules(num_datasets, num_shot_settings * num_metrics)

# Second header row: one cell per task showing its shot count; only the very
# first cell carries the "$n$=" prefix.
shot_row = (
    '{}    &    '
    + '  &  '.join(
        _multicolumn(
            num_metrics,
            '\\scriptsize ' + ('$n$=' if i_task == 0 else '') + task_name_to_shot(task_name),
        )
        for i_task, task_name in enumerate(pretty_dfs.keys())
    )
    + '    \\\\'
)
shot_underline = _cmidrules(num_task, num_metrics)

# Underscores are escaped for LaTeX in every printed header line.
print('\\toprule')
for header_line in (dataset_row, dataset_underline, shot_row):
    print(header_line.replace('_', '\\_'))
    print()
print(shot_underline.replace('_', '\\_'))


\toprule
{}    &    \multicolumn{5}{c}{D1-}  &  \multicolumn{5}{c}{D1}  &  \multicolumn{5}{c}{D3}  &  \multicolumn{5}{c}{D8}    \\

\cmidrule(l{\tabcolsep}r{\tabcolsep}){2-6}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){7-11}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){12-16}    \cmidrule(l{\tabcolsep}r{\tabcolsep}){17-21}

{}    &    \multicolumn{1}{c}{\scriptsize $n$=5}  &  \multicolumn{1}{c}{\scriptsize 100}  &  \multicolumn{1}{c}{\scriptsize 1,000}  &  \multicolumn{1}{c}{\scriptsize 10,000}  &  \multicolumn{1}{c}{\scriptsize 30,000}  &  \multicolumn{1}{c}{\scriptsize 5}  &  \multicolumn{1}{c}{\scriptsize 100}  &  \multicolumn{1}{c}{\scriptsize 1,000}  &  \multicolumn{1}{c}{\scriptsize 10,000}  &  \multicolumn{1}{c}{\scriptsize 30,000}  &  \multicolumn{1}{c}{\scriptsize 5}  &  \multicolumn{1}{c}{\scriptsize 100}  &  \multicolumn{1}{c}{\scriptsize 1,000}  &  \multicolumn{1}{c}{\scriptsize 10,000}  &  \multicolumn{1}{c}{\scriptsize 30,000}  &  \multicolumn{1}{c}{\scriptsize 5}  &  \multicolu