In [None]:
import json
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from constants import exclude_models
from helper import save_or_show, load_model_configs_and_allowed_models

In [None]:
# Output configuration for this notebook.
# SAVE is forwarded to `save_or_show` further down together with the target file path.
SAVE = True

# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
storing_path = Path(
    '/home/space/diverse_priors/results/plots/model_categories'
)

# Only touch the filesystem when saving is enabled; missing parent folders are
# created and an already existing directory is not an error.
if SAVE:
    storing_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Arguments for the project helper that reads the model configuration file.
# NOTE(review): the exact filtering performed by `exclude_models` /
# `exclude_alignment` lives in `helper`; confirm there.
config_kwargs = dict(
    path='../scripts/models_config_wo_barlowtwins_n_alignment.json',
    exclude_models=exclude_models,
    exclude_alignment=True,
)
model_configs, allowed_models = load_model_configs_and_allowed_models(**config_kwargs)

# Quick sanity check: shape of the config table and number of allowed models.
print(model_configs.shape, len(allowed_models))

In [None]:
# Column name in `model_configs` -> human-readable label used in plot titles.
model_cat_mapping = {
    'objective': 'Objective',
    'architecture_class': 'Architecture',
    'dataset_class': 'DS size',
    'size_class': 'Model size',
}

# Category columns in plotting order (dicts preserve insertion order).
model_categories = list(model_cat_mapping)

In [None]:
# Optional sanity check (disabled). When uncommented it prints, for every model
# category, its display label followed by the row count per category value,
# ordered by the category value.
# for mcat in model_categories:
#     print(model_cat_mapping[mcat])
#     print(model_configs[mcat].value_counts().sort_index())
#     print()

In [None]:
# Co-occurrence heatmaps: one panel per unordered pair of model categories,
# each cell showing how many rows of `model_configs` share that combination.
category_pairs = list(combinations(model_categories, 2))

# Size the grid from the number of pairs (2x3 for the four categories above)
# instead of hard-coding it, so changing `model_categories` cannot index past
# the available axes. Keep at least one row so `plt.subplots` stays valid.
n_cols = 3
n_rows = max(1, -(-len(category_pairs) // n_cols))  # ceiling division
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(3 * n_cols, 3 * n_rows),
    squeeze=False,  # always a 2-D array of axes, even for a single row
)

for ax, (mc1, mc2) in zip(axs.flat, category_pairs):
    # `crosstab` counts rows per (mc1, mc2) combination directly: it does not
    # depend on `reset_index()` yielding a column literally named 'index', and it
    # returns integer counts with 0 (not NaN) for combinations that never occur.
    pair_counts = pd.crosstab(model_configs[mc1], model_configs[mc2])

    # fmt='d' prints plain integers; seaborn's default '.2g' would render
    # counts of 100 or more in scientific notation.
    sns.heatmap(
        pair_counts,
        annot=True,
        fmt='d',
        ax=ax,
        cmap='Purples',
        cbar=False,
        annot_kws={'size': 13},
    )
    ax.set_title(f"{model_cat_mapping[mc1]} vs {model_cat_mapping[mc2]}", fontsize=11)
    # The title already names both categories, so drop the axis labels.
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.tick_params('x', rotation=90, labelsize=11)
    ax.tick_params('y', rotation=0, labelsize=11)

# Hide panels that are left over when the pairs do not fill the grid completely.
for unused_ax in axs.flat[len(category_pairs):]:
    unused_ax.set_visible(False)

fig.tight_layout()
save_or_show(fig, storing_path / 'frequency_model_categories_wo_alignment_no_model_duplicates.pdf', SAVE)