## Notebook to visualize how often models fall into each combination of model categories

In [1]:
import itertools
import math

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from constants import (
    BASE_PATH_RESULTS,
    cat_name_mapping,
    exclude_models,
    model_cat_mapping,
    model_categories,
    model_config_file,
    model_size_order,
    fontsizes
)
from helper import (
    load_model_configs_and_allowed_models,
    pp_storing_path,
    save_or_show
)

#### Output location for the figures

In [None]:
# Toggle between writing figures to disk (True) and only displaying them.
SAVE = True
# Results sub-directory for this appendix figure. pp_storing_path also gets the
# SAVE flag (presumably so the directory is only prepared when saving - see helper).
output_subdir = 'final/3_op/app_B_model_n_configs'
storing_path = pp_storing_path(BASE_PATH_RESULTS / output_subdir, SAVE)

#### Load model configs

In [None]:
# Load the per-model configuration DataFrame together with the models that
# survive the exclusions. exclude_alignment=True is passed through to the
# helper (presumably dropping alignment-tuned variants - verify in helper).
loader_kwargs = dict(
    path=model_config_file,
    exclude_models=exclude_models,
    exclude_alignment=True,
)
model_configs, allowed_models = load_model_configs_and_allowed_models(**loader_kwargs)

# Sanity check: shape of the config table and number of allowed models.
n_allowed = len(allowed_models)
print(model_configs.shape, n_allowed)

In [None]:
# For each category column, print its display name, then how many models take
# each value (sorted by value), followed by an empty separator line.
for category in model_categories:
    print(model_cat_mapping[category])
    print(model_configs[category].value_counts().sort_index(), end='\n\n')

In [5]:
# Replace the raw values of every category column with their display names
# (in place on model_configs).
# NOTE(review): if cat_name_mapping is a dict, values without a key become NaN
# (pandas Series.map semantics). Re-running this cell maps the already-renamed
# labels a second time, so run it exactly once per kernel session.
for category_col in model_categories:
    raw_values = model_configs[category_col]
    model_configs[category_col] = raw_values.map(cat_name_mapping)

#### Frequency of models for each pair of categories

In [None]:
# One heatmap per unordered pair of model categories. Each cell shows how many
# models fall into that (row value, column value) combination. Pairs come out
# in the order (categories[i], categories[j]) for i < j.
category_pairs = list(itertools.combinations(model_categories, 2))
ncols = 3
nrows = max(1, math.ceil(len(category_pairs) / ncols))
# squeeze=False keeps `axs` 2-D even for a single row, so indexing stays uniform.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
flat_axes = axs.ravel()

for ax, (mc1, mc2) in zip(flat_axes, category_pairs):
    # crosstab counts rows per combination and fills absent combinations with 0.
    # The result is therefore integer-valued (required by fmt='d' below), and it
    # does not depend on the name of model_configs' index.
    counts = pd.crosstab(model_configs[mc1], model_configs[mc2])

    # Order the size axis canonically. A size class with no models becomes an
    # all-zero row/column instead of raising a KeyError.
    if mc1 == 'size_class':
        counts = counts.reindex(index=model_size_order, fill_value=0)
    elif mc2 == 'size_class':
        counts = counts.reindex(columns=model_size_order, fill_value=0)

    # fmt='d' prints plain integers. seaborn's default '.2g' would render
    # counts >= 100 in scientific notation (e.g. "1.2e+02").
    sns.heatmap(
        counts,
        annot=True,
        fmt='d',
        ax=ax,
        cmap='Purples',
        cbar=False,
        annot_kws={'size': fontsizes['ticks']},
    )

    ax.set_title(f"{model_cat_mapping[mc1]} vs. {model_cat_mapping[mc2]}", fontsize=fontsizes['ticks'])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.tick_params('x', rotation=90, labelsize=fontsizes['ticks'])
    ax.tick_params('y', rotation=0, labelsize=fontsizes['ticks'])

# Hide the panels left over when the number of pairs is not a multiple of ncols.
for unused_ax in flat_axes[len(category_pairs):]:
    unused_ax.set_visible(False)

fig.tight_layout()
save_or_show(fig, storing_path / 'frequency_model_categories_wo_alignment_no_model_duplicates.pdf', SAVE)