In [None]:
import os
import sys

# Put the parent of the notebook's working directory (the project root) on the
# import path so the `src.helpers` imports below can be resolved. The membership
# check keeps re-running this cell from appending duplicate entries.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from src.helpers.tables import create_table, df_to_latex, rename_df, add_hline, add_double_column_header
from src.helpers.load_save_data import load_experiment_data

import numpy as np
import pandas as pd
from IPython.display import display

# Main experiment tables

Brier score, accuracy, log-loss, confidence ECE, classwise ECE.

In [None]:
methods = [
    "uncalibrated",
    "isotonic",
    "LECE_KL",
    "MS",
    "iop_diag",
    "GP",
    "TS",
    "dec2TS",
    "TS_isotonic",
    "TS_LECE_KL",
]

# Metric key and the number of decimals shown for it, paired by position.
metrics = ["bs", "ll", "conf_ece", "cw_ece", "accuracy"]
roundings = [4, 3, 2, 3, 3]
# zip() would silently drop unmatched entries, so check the pairing explicitly.
assert len(metrics) == len(roundings), "metrics and roundings must align"

# Index of the last table column, as used by every hline below.
# NOTE(review): presumably 2 index columns + one column per method — confirm.
n_cols = len(methods) + 2

# Column spans for the extra header row that marks "ours" over LECE_KL and
# TS_LECE_KL. They must cover exactly the same columns as the hlines.
header_labels = ["", "ours", "", "ours"]
header_spans = [4, 1, 6, 1]
assert sum(header_spans) == n_cols, "header spans must cover every column"

for metric, rounding in zip(metrics, roundings):
    df_c10 = create_table(datasets=["densenet40_c10", "resnet110_c10", "resnet_wide32_c10"],
                          metric=metric, experiment_names=methods,
                          add_avg_rank=True, rounding=rounding, n_bins=15)

    df_c100 = create_table(datasets=["densenet40_c100", "resnet110_c100", "resnet_wide32_c100"],
                           metric=metric, experiment_names=methods,
                           add_avg_rank=True, rounding=rounding, n_bins=15)

    # Stack the c10 rows above the c100 rows, then rename once so both halves
    # get the same labels. pd.concat replaces DataFrame.append, which was
    # removed in pandas 2.0.
    df_combined = rename_df(pd.concat([df_c10, df_c100]))

    latex = df_to_latex(df_combined)
    # NOTE(review): rule positions [6, 7, 10] are LaTeX line indices tied to the
    # current row layout; re-check them if datasets or add_avg_rank change.
    latex = add_hline(latex, [6, 7, 10],
                      lengths=[(2, n_cols),
                               (1, n_cols),
                               (2, n_cols)])
    latex = add_double_column_header(latex, header_labels, header_spans)
    # Presumably the short rules under each "ours" label (last column and
    # column 5) — confirm against the rendered table.
    latex = add_hline(latex, [2], lengths=[(n_cols, n_cols)])
    latex = add_hline(latex, [2], lengths=[(5, 5)])

    print(metric)
    print(latex)
    display(df_combined)

# Ablation study
Log-loss, confidence ECE, classwise ECE

In [None]:
methods = ["LECE_KL", "LECE_EUC", "LECD_KL", "TS_LECE_KL", "TS_LECE_EUC", "TS_LECD_KL"]

# Metric key and the number of decimals shown for it, paired by position.
metrics = ["ll", "conf_ece", "cw_ece"]
roundings = [3, 2, 3]
# zip() would silently drop unmatched entries, so check the pairing explicitly.
assert len(metrics) == len(roundings), "metrics and roundings must align"

# Index of the last table column, as used by the hlines below.
# NOTE(review): presumably 2 index columns + one column per method — confirm.
n_cols = len(methods) + 2

for metric, rounding in zip(metrics, roundings):
    df_c10 = create_table(datasets=["densenet40_c10", "resnet110_c10", "resnet_wide32_c10"],
                          metric=metric, experiment_names=methods,
                          add_avg_rank=True, rounding=rounding, n_bins=15)

    df_c100 = create_table(datasets=["densenet40_c100", "resnet110_c100", "resnet_wide32_c100"],
                           metric=metric, experiment_names=methods,
                           add_avg_rank=True, rounding=rounding, n_bins=15)

    # Stack c10 rows above c100 rows, then rename the combined frame.
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    df_combined = rename_df(pd.concat([df_c10, df_c100]))

    latex = df_to_latex(df_combined)
    # NOTE(review): rule positions [6, 7, 10] are LaTeX line indices tied to the
    # current row layout; re-check them if datasets or add_avg_rank change.
    latex = add_hline(latex, [6, 7, 10],
                      lengths=[(2, n_cols),
                               (1, n_cols),
                               (2, n_cols)])
    print(metric)
    print(latex)
    display(df_combined)

# Ablation study hyperparams
Cross-validated best values of the neighborhood size and threshold hyperparameters, per dataset/model and method.

In [None]:
methods = ["LECE_KL", "LECE_EUC", "LECD_KL", "TS_LECE_KL", "TS_LECE_EUC", "TS_LECD_KL"]
hyperparams = ["neighborhood_size", "threshold"]
dataset_models = [
    ("c10", "densenet40"),
    ("c10", "resnet110"),
    ("c10", "resnet_wide32"),
    ("c100", "densenet40"),
    ("c100", "resnet110"),
    ("c100", "resnet_wide32"),
]

# Load each (dataset, model, method) experiment exactly once and keep only its
# cross-validated best hyperparameters. Previously every experiment was
# re-loaded once per hyperparameter.
best_hyperparams = {
    (dataset, model, method): load_experiment_data(
        dataset=f"{model}_{dataset}", experiment_name=method
    ).cv_best_hyperparams
    for dataset, model in dataset_models
    for method in methods
}

for hyperparam in hyperparams:
    # {(dataset, model): {method: value as string}}. from_dict makes the outer
    # keys columns; .T turns each (dataset, model) pair into a row.
    hyperparam_values = {
        (dataset, model): {
            method: str(best_hyperparams[dataset, model, method][hyperparam])
            for method in methods
        }
        for dataset, model in dataset_models
    }
    hyperparam_table = rename_df(pd.DataFrame.from_dict(hyperparam_values).T)

    latex = df_to_latex(hyperparam_table)
    # NOTE(review): presumably the rule between the c10 and c100 rows — confirm.
    latex = add_hline(latex, [6], lengths=[(1, len(methods) + 2)])

    print(hyperparam)
    print(latex)
    display(hyperparam_table)
