In [1]:
import pandas as pd
import numpy as np
import copy 
import os
import re
from collections import OrderedDict
from pprint import pprint
import itertools

In [2]:
# Metrics reported for every dataset, in the order they appear in the table.
sub_metrics = ['Error', 'Coverage']

# Dataset key (as it appears in result file names) -> name printed in the table.
dm_identifier = dict([
    ('cifar10', 'CIFAR-10'),
    ('mnist', 'MNIST'),
    ('twenty_newsgroups', '20newsgroups'),
    ('tinyimagenet100', 'TinyImageNet-100'),
])

# Dataset keys in table order (same order as the mapping above).
data_models = list(dm_identifier)

def helper_find_files_and_read_dataframe(root_path, patterns=[r"cifar10"]):
    """Collect every ``.xlsx`` result file under ``root_path`` into one DataFrame per pattern.

    The directory tree is walked recursively. Each ``.xlsx`` file name is
    tested (case-insensitively, via ``re.search``) against ``patterns`` in
    order; the first matching pattern claims the file. The first sheet of the
    file is read, the ``'Unnamed: 0'`` column is dropped if present, and the
    rows are appended to that pattern's table.

    Args:
        root_path: Directory to search recursively.
        patterns: Iterable of regular-expression strings. Each pattern is also
            used as the key of the returned mapping.

    Returns:
        OrderedDict mapping each key of ``dm_identifier`` (always present) and
        each matched pattern to the row-wise concatenation of its files, or to
        ``None`` when no file matched. Original row indices are preserved.
    """
    # Materialise once: `patterns` is re-iterated for every file, which would
    # silently stop matching after the first file for a one-shot iterable.
    patterns = list(patterns)

    # Accumulate frames per key and concatenate once at the end rather than
    # growing a DataFrame inside the loop.
    frames_by_key = OrderedDict((dm, []) for dm in dm_identifier.keys())

    for root, dirs, files in os.walk(root_path):
        # Sort in place so traversal (and hence row order) is reproducible
        # across machines; os.walk order is otherwise filesystem-dependent.
        dirs.sort()
        for filename in sorted(files):
            filepath = os.path.join(root, filename)
            if not (os.path.isfile(filepath) and filepath.endswith(".xlsx")):
                continue
            for pattern in patterns:
                if re.search(pattern, filename, re.IGNORECASE):
                    # errors="ignore": tolerate sheets saved without the
                    # exported index column instead of raising KeyError.
                    frame = pd.read_excel(filepath, sheet_name=0).drop(
                        columns=['Unnamed: 0'], errors="ignore")
                    frames_by_key.setdefault(pattern, []).append(frame)
                    break  # Stop checking patterns for this file

    dm_df = OrderedDict()
    for key, frames in frames_by_key.items():
        dm_df[key] = pd.concat(frames, axis=0) if frames else None
    return dm_df

def read_and_get_filtered_dataframes(root_path, patterns=[r"cifar10"]):
    """Load the result tables and keep one row per (calib_conf, training_conf) pair.

    Tables are loaded with ``helper_find_files_and_read_dataframe``. For every
    table that is not ``None``:

    * missing ``calib_conf`` values become the string ``"None"`` and the
      column is cast to ``str``;
    * rows are ordered by ``Coverage-Mean`` (descending), then ``calib_conf``
      (ascending);
    * only the first row of each (``calib_conf``, ``training_conf``) pair is
      kept, i.e. the one with the highest ``Coverage-Mean``.

    Args:
        root_path: Directory passed through to the loader.
        patterns: Patterns passed through to the loader.

    Returns:
        The mapping produced by the loader, with every non-``None`` table
        replaced by its filtered version.
    """
    tables = helper_find_files_and_read_dataframe(root_path, patterns=patterns)

    for key in list(tables):
        raw = tables[key]
        if raw is not None:
            # `assign` works on a copy, so the loaded frame is left untouched.
            normalised = raw.assign(
                calib_conf=raw['calib_conf'].fillna("None").astype(str))
            ranked = normalised.sort_values(
                by=["Coverage-Mean", "calib_conf"], ascending=[False, True])
            best_per_pair = ranked.drop_duplicates(
                subset=['calib_conf', 'training_conf'], keep='first')
            tables[key] = best_per_pair.copy(deep=True)

    return tables

# Load and filter the result tables for every known dataset key.
dm_df = read_and_get_filtered_dataframes(
    "../../../outputs/final_results/final_results_to_tex_table",
    patterns=dm_identifier.keys(),
)

# Show the filtered CIFAR-10 table as a sanity check.
display(dm_df['cifar10'])

Unnamed: 0,calib_conf,training_conf,C_1,N_t,N_v,N_hyp_v,Auto-Labeling-Err-Mean,Coverage-Mean,Avg-ECE-Val-Mean,Auto-Labeling-Err-Std,...,training_conf.optimizer,training_conf.reg,training_conf.weight_decay,training_conf_g.batch_size,training_conf_g.max_epochs,training_conf_g.optimizer,training_conf_g.weight_decay,weight_decay,rank_target,rank_weight
1,auto_label_opt_v0,squentropy,0.25,10000,8000,2000,2.3387,79.0495,10.2829,0.5167,...,,,,64.0,500.0,adam,0.01,0.001,,
1,auto_label_opt_v0,std_cross_entropy,0.25,10000,8000,2000,2.9666,78.482,16.0513,0.2031,...,,,,64.0,500.0,adam,0.1,0.001,,
1,auto_label_opt_v0,crl,0.25,10000,8000,2000,2.2377,77.8595,22.8733,0.6454,...,,,,64.0,500.0,adam,0.01,0.01,softmax,0.8
1,auto_label_opt_v0,fmfp,0.25,10000,8000,2000,3.0204,77.4455,14.6882,0.4308,...,,,,64.0,500.0,adam,0.01,0.001,,
2,dirichlet,squentropy,0.25,10000,8000,2000,7.3309,29.3665,13.4658,0.3071,...,adam,0.01,,,,,,0.001,,
5,temp_scaling,squentropy,0.25,10000,8000,2000,6.9471,28.166,13.7981,0.5826,...,adam,,0.01,,,,,0.001,,
4,scaling_binning,squentropy,0.25,10000,8000,2000,6.1907,23.7565,13.3403,0.4242,...,,,0.01,,,,,0.001,,
5,temp_scaling,std_cross_entropy,0.25,10000,8000,2000,7.2664,23.157,16.9683,0.2994,...,adam,,0.1,,,,,0.001,,
2,dirichlet,std_cross_entropy,0.25,10000,8000,2000,7.681,22.3685,4.0057,0.4806,...,adam,0.1,,,,,,0.001,,
2,dirichlet,fmfp,0.25,10000,8000,2000,6.9052,21.6735,3.7992,0.3577,...,adam,0.01,,,,,,0.001,,


In [3]:
# Post-hoc calibration method key -> label printed in the table.
cms_ = OrderedDict({'None': '-',
                    'auto_label_opt_v0': 'Ours' ,
                    'temp_scaling': 'TS',
                    'dirichlet': 'Dirichlet',
                    'scaling_binning': 'SB',
                    'histogram_binning_top_label': 'Top-HB'})
# Train-time method key -> label printed in the table.
ttms_ = OrderedDict({'std_cross_entropy': 'Vanilla',
                     'crl': 'CRL', 
                     'fmfp': 'FMFP', 
                     'squentropy': 'Squentropy'})
visited = []      # train-time methods whose \multirow label was already emitted
body_txt= ""      # accumulated LaTeX rows of the table body
bs = "\\"
num_dp = 2                  # decimal places for every printed number
global_font_size = (8,11)   # (size, baselineskip) for the whole table
std_font_size = (6, 11)     # (size, baselineskip) for the standard deviations
pm_factor = 0.6             # \scalebox factor applied to the +/- sign


def _lookup_metric(rows, column):
    """Return the first value of ``column`` in ``rows``, or -1 when ``rows`` is empty."""
    return rows[column].values[0] if not rows[column].empty else -1


def _format_mean_std(mean, std, bold):
    """Render one "mean +/- std" cell; ``bold`` wraps both numbers and the sign in bold."""
    mean_txt = f"{mean:.{num_dp}f}"
    std_txt = f"{std:.{num_dp}f}"
    if bold:
        mean_txt = rf"\textbf" + "{" + mean_txt + "}"
        std_txt = rf"\textbf" + "{" + std_txt + "}"
        plus_minus = f"\\scalebox{{{pm_factor}}}{{\\ensuremath{{\\bm{{\\pm}}}}}}"
    else:
        plus_minus = f"\\scalebox{{{pm_factor}}}{{\\ensuremath{{{bs}pm}}}}"
    return mean_txt + plus_minus + open_std_font + std_txt


# Loop-invariant pieces, computed once.
open_std_font = "{" + f"{bs}fontsize{{{std_font_size[0]}}}{{{std_font_size[1]}}}{bs}selectfont"
closing_std_font = "}"
last_cm = list(cms_.keys())[-1]
last_tm = list(ttms_.keys())[-1]

for tm, cm in itertools.product(ttms_.keys(), cms_.keys()):
    # First column: train-time label (once per group, spanning all post-hoc
    # rows); second column: post-hoc label.
    cross_prod_i = cms_[cm]
    if tm not in visited:
        temp_tm = ttms_[tm]
        # The row span follows the number of post-hoc methods instead of a
        # hard-coded 6, so editing cms_ keeps the table consistent.
        cross_prod_i = rf"""\multirow{{{len(cms_)}}}{{*}}{{{temp_tm}}}                     & """ + cross_prod_i 
        visited.append(tm)
    else:
        cross_prod_i = " ".join(["                                 & ", cross_prod_i]) 

    # For each dataset, add columns for Error and Coverage 
    for dm, df in dm_df.items():
        # Placeholders used when the dataset has no table or no matching row.
        # Resetting them per dataset prevents a dataset without results from
        # re-printing the previous dataset's numbers (or raising NameError
        # when it is the first one).
        al_mean, al_std, c_mean, c_std = -1, -1, -1, -1
        al_bold, c_bold = False, False

        if df is not None:
            rows = df[(df["calib_conf"] == f"{cm}") & (df["training_conf"] == f"{tm}")]
            al_mean = _lookup_metric(rows, 'Auto-Labeling-Err-Mean')
            al_std = _lookup_metric(rows, 'Auto-Labeling-Err-Std')
            c_mean = _lookup_metric(rows, 'Coverage-Mean')
            c_std = _lookup_metric(rows, 'Coverage-Std')

            # Bold the lowest error / highest coverage among all rows sharing
            # this train-time method.
            peers = df[df["training_conf"] == f"{tm}"]
            al_bold = al_mean == peers['Auto-Labeling-Err-Mean'].min()
            c_bold = c_mean == peers['Coverage-Mean'].max()

        al_cell = _format_mean_std(al_mean, al_std, al_bold)
        c_cell = _format_mean_std(c_mean, c_std, c_bold)
        cross_prod_i = (cross_prod_i
                        + " & " + " " + al_cell + " " + closing_std_font
                        + " & " + " " + c_cell + " " + closing_std_font)

    # Rule after the last post-hoc row of each group; \bottomrule after the
    # very last row of the table.
    if cm == last_cm and tm == last_tm:
        line = r"\bottomrule"
    elif cm == last_cm:
        line = r"\hline"
    else:
        line = ""
    cross_prod_i = cross_prod_i + r"\\" + line
    body_txt= body_txt+ cross_prod_i + "\n"

In [4]:


# Total number of table columns: two label columns (train-time, post-hoc)
# plus one column per (dataset, metric) pair. Deriving it keeps the column
# spec and the \cline range correct when datasets or metrics change.
n_table_cols = 2 + len(sub_metrics) * len(dm_identifier)
col_spec = "l" * n_table_cols

# Second header row: the metric names, repeated once per dataset.
metrics_txt = " & ".join( [ r"\multicolumn{1}{c}" + "{" + r"\textbf" + "{" + sm + "}" + "}" for sm in sub_metrics] * len(dm_identifier.keys()))
# First header row: one dataset name spanning its metric columns.
data_models_txt = ' & ' + ' & '.join([rf"\multicolumn{{{len(sub_metrics)}}}{{c}}" + "{" + rf"\textbf" + rf"{{{dm_identifier[dm]}}}" + "}" for dm in dm_identifier.keys()])
caption = "Example TBAL LaTeX Table"

# Full table: preamble and two header rows, then the body rows, then the footer.
template = rf"""
\begin{{table*}}[t]
\fontsize{{{global_font_size[0]}}}{{{global_font_size[1]}}}\selectfont
\begin{{tabular}}{{{col_spec}}}
\toprule
\multicolumn{{1}}{{c}}{{\multirow{{2}}{{*}}{{\textbf{{Train-time}}}}}} & \multicolumn{{1}}{{c}}{{\multirow{{2}}{{*}}{{\textbf{{Post-hoc}}}}}} {data_models_txt} \\ \cline{{3-{n_table_cols}}}
\multicolumn{{1}}{{c}}{{}}                      & \multicolumn{{1}}{{c}}{{}}  & {metrics_txt} \\ \toprule 
""" + body_txt + rf"""
\end{{tabular}}
\caption{{{caption}}}
\end{{table*}}"""

# TODO: Add new column shading color features 

# template = rf"""
# \begin{{table*}}[t]
# \fontsize{{{global_font_size[0]}}}{{{global_font_size[1]}}}\selectfont
# \begin{{tabular}}{{lll>{{\columncolor{{gray}}}}lllllll}}
# \toprule
# \multicolumn{{1}}{{c}}{{\multirow{{2}}{{*}}{{\textbf{{Train-time}}}}}} & \multicolumn{{1}}{{c}}{{\multirow{{2}}{{*}}{{\textbf{{Post-hoc}}}}}} {data_models_txt} \\ \cline{{3-10}}
# \multicolumn{{1}}{{c}}{{}}                      & \multicolumn{{1}}{{c}}{{}}  & {metrics_txt} \\ \toprule 
# """ + body_txt + rf"""
# \end{{tabular}}
# \caption{{{caption}}}
# \end{{table*}}"""


In [5]:
# Write the assembled LaTeX table to disk. The encoding is fixed to UTF-8 so
# the file does not depend on the platform's locale default.
with open("./final_table_latex_template.txt", "w", encoding="utf-8") as file:
    file.write(template)