# Make tables/figures for train-test analysis

In [None]:
from typing import List
import pickle
import pandas as pd
from collections import defaultdict

In [None]:
# Read the pickled evaluation results; the pickle is nested as
# all_results[model][dataset].
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open("model_train_test.p", "rb") as f:
    all_results = pickle.load(f)

# Flip the nesting to df_results[test dataset][model] and wrap each raw
# result in a DataFrame so metrics can be aggregated column-wise later.
df_results = defaultdict(dict)
for n_m, per_dataset in all_results.items():
    for n_d, raw_result in per_dataset.items():
        df_results[n_d][n_m] = pd.DataFrame(raw_result)

In [None]:
# Each subtable pairs a list of models (column groups, labelled by the data
# they were trained on) with a list of test datasets (rows). The four
# configurations are: benign/benign, benign/adv, adv/adv, adv-MC/adv,
# expressed here as (model suffix, dataset suffix) over the three domains.
subtables = [
    {
        "models": [base + model_suffix for base in ("CARLA", "nuScenes", "UGV")],
        "datasets": [base + data_suffix for base in ("CARLA", "nuScenes", "UGV")],
    }
    for model_suffix, data_suffix in (
        ("", ""),
        ("", "_adv"),
        ("_adv", "_adv"),
        ("_adv_mc", "_adv"),
    )
]


def map_name(name: str):
    """Convert an internal run name into a display label.

    "_adv" becomes " Adv." and "_mc" becomes " MC", e.g.
    "UGV_adv_mc" -> "UGV Adv. MC".
    """
    label = name
    for token, pretty in (("_adv", " Adv."), ("_mc", " MC")):
        label = label.replace(token, pretty)
    return label

In [None]:
from functools import partial


def get_dataframe_subtable(df_results: dict, models: List[str], datasets: List[str]):
    """Build one train/test results table for the given models and datasets.

    Args:
        df_results: nested mapping ``df_results[dataset][model]`` -> DataFrame
            of per-run results. Each DataFrame must contain the columns
            "precision", "recall", "accuracy" and "f1".
        models: model names; each becomes a column group "Train: <model>".
        datasets: test dataset names; each becomes a row "Test: <dataset>".

    Returns:
        A DataFrame with two-level columns (``"Train: <model>"``, metric label)
        and one row per dataset. Each cell is the mean of that metric over the
        rows of the corresponding results DataFrame.

    Raises:
        KeyError: if ``df_results`` has no entry for some (dataset, model) pair.
    """
    # results column -> short label shown in the table (column order)
    metrics = {
        "precision": "Prec.",
        "recall": "Rec.",
        "accuracy": "Acc.",
        "f1": "F1",
    }
    metric_labels = list(metrics.values())

    df_subtables = {}
    for n_m in models:
        # metric label -> {row label -> mean value}; pre-seeded in column order
        table = {label: {} for label in metric_labels}
        for n_d in datasets:
            # Use .get so a defaultdict input is not mutated by failed lookups.
            df_run = df_results.get(n_d, {}).get(n_m)
            if df_run is None:
                raise KeyError(f"no results for model {n_m!r} on dataset {n_d!r}")
            row_name = f"Test: {map_name(n_d)}"
            for met, label in metrics.items():
                table[label][row_name] = df_run[met].mean()
        df_subtables[f"Train: {map_name(n_m)}"] = pd.DataFrame(table, columns=metric_labels)

    # Concatenating a dict side by side yields the (supercol, metric) MultiIndex.
    return pd.concat(df_subtables, axis=1)


def subtable_to_latex(subtable: pd.DataFrame, do_print: bool = True):
    # This function returns a function which formats numbers according to the specified string operations.
    # Inputs
    # - precision: integer -> accuracy to be shown (e.g. number of decimal places if format type is "f" or number of significant digits if format type is "g")
    # - num_format: string -> format type


    def format_num(precision, num_format):
        def create_formatter(num, inner_prec, inner_form):
            return "{:.{}{}}".format(num, inner_prec, inner_form)

        return partial(create_formatter, inner_prec=precision, inner_form=num_format)


    # define formatters
    formatters = [format_num(2, "f") for _ in range(len(subtable.columns))]

    # format and print table
    pd.set_option("display.max_colwidth", 1000)
    lat_str = subtable.to_latex(
        index=True,
        formatters=formatters,
        multicolumn_format="c",
        bold_rows=False,
    )

    # get the supercols (in case needed below)
    supercols = set([col[0] for col in subtable.columns])
    metrics = set([col[1] for col in subtable.columns])

    ############################
    # custom mods
    ############################

    # -- midrule between rows
    # lat_str = lat_str.replace('\\\\\n', '\\\\ \\midrule\n')

    # -- remove first toprule
    lat_str = lat_str.replace("\\toprule\n", "")

    # -- add midrule under supertitles
    idx_midrule = lat_str.index(" & Prec.")
    midrule_str = (
        "".join(
            [
                f"\cmidrule(l){{{2+i*len(metrics)}-{1+(i+1)*len(metrics)}}}"
                for i in range(len(supercols))
            ]
        )
        + "\n"
    )
    lat_str = lat_str[:idx_midrule] + midrule_str + lat_str[idx_midrule:]

    # -- bold each column
    # TODO

    if do_print:
        print(lat_str)

    return lat_str


def collate_subtables_latex(
    latex_strings: List[str],
    captions: List[str],
    label_base: str,
    do_print: bool = True
):
    r"""Wrap each tabular in a ``subtable`` environment and join them.

    Args:
        latex_strings: LaTeX tabular strings, one per subtable.
        captions: caption for each subtable, in the same order.
        label_base: ``str.format`` template for the labels, formatted with the
            0-based subtable index (e.g. "tab:train-test-{}").
        do_print: if True, print the combined LaTeX string.

    Returns:
        All subtables joined by a ``\newline\vspace{6pt}\newline`` separator.

    Raises:
        ValueError: if ``latex_strings`` and ``captions`` differ in length
            (``zip`` would otherwise silently drop subtables).
    """
    latex_strings = list(latex_strings)
    captions = list(captions)
    if len(latex_strings) != len(captions):
        raise ValueError(
            f"got {len(latex_strings)} subtables but {len(captions)} captions"
        )

    # text above the subtable
    text_above = "\\begin{subtable}[c]{\\linewidth}\n\\centering"

    # text below the subtable (doubled braces survive str.format as literals)
    text_below = "\\caption{{{cap}}}\n\\label{{{lab}}}\n\\end{{subtable}}"

    # add the texts
    latex_strings_mod = [
        text_above + lstr + text_below.format(cap=cap, lab=label_base.format(i))
        for i, (lstr, cap) in enumerate(zip(latex_strings, captions))
    ]

    # joining all subtables together
    lat_str = "\n%\n\\newline\\vspace{6pt}\\newline\n%\n".join(latex_strings_mod)

    # print
    if do_print:
        print(lat_str)

    return lat_str

In [None]:
# Render every configured subtable and collate them into one LaTeX block.
lat_strs = []
captions = [
    "Benign",
    "Test on adversarial",
    "Train and test on adversarial",
    "Train and test on adversarial with MC dropout",
]
label_base = "tab:train-test-{}"
if len(captions) != len(subtables):
    raise ValueError(f"{len(subtables)} subtables configured but {len(captions)} captions given")
# Iterate the configured subtables directly instead of a hard-coded count.
for spec in subtables:
    df_subtable = get_dataframe_subtable(df_results, spec["models"], spec["datasets"])
    lat_strs.append(subtable_to_latex(df_subtable, do_print=False))
# Bind the result so the notebook does not also echo the string's repr.
lat_str_all = collate_subtables_latex(lat_strs, captions, label_base, do_print=True)