In [None]:
# Move the kernel's working directory one level up. The relative paths used
# later in this notebook (e.g. "results/TDC_ADMET_SCAFF.csv") are resolved
# from there -- presumably the repository root; TODO confirm.
# NOTE: the change is relative, so re-running this cell moves up one more level.
%cd ..

In [None]:
import os
import numpy as np
from molecule.utils.utils_notebook import get_MI_df, get_ranked_df, LATEX_FIG_PATH, get_DTI_rank_df, process_dataset_name, prerpocess_emb_name
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import clear_output

# --- Notebook configuration -------------------------------------------------
# Dataset name and result directories handed to get_MI_df below.
DATASET = "ZINC"
results_dir_list = ["run_4"]
# NOTE(review): COLUMS_SPLIT and METRIC are not referenced in the visible cells.
COLUMS_SPLIT = "cond_modes"
METRIC = r"$\overline{\mathcal{I}_S}$"

# Families of related embedders, keyed by the family label used in the figure.
# Every member name is passed through prerpocess_emb_name so that it matches
# the names stored in the `embedder` column after the same normalisation.
model_groups = {
    family: [prerpocess_emb_name(member) for member in members]
    for family, members in {
        'ChemBert-MLM': ['ChemBertMLM-10M', 'ChemBertMLM-5M', 'ChemBertMLM-77M'],
        'ChemBert-MTR': ['ChemBertMTR-77M', 'ChemBertMTR-10M', 'ChemBertMTR-5M'],
        '3D': ['DenoisingPretrainingPQCMv4', 'FRAD_QM9'],
        'MolR': ['MolR_tag', 'MolR_gcn', 'MolR_gat'],
        'MoleOOD': ['MoleOOD_OGB_GCN', 'MoleOOD_OGB_SAGE', 'MoleOOD_OGB_GIN'],
    }.items()
}


# Load the results for DATASET and derive the ranked frame (project helpers;
# the frame is expected to expose at least an `embedder` column, plus the
# `global_meanrank_metric` / `meanrank_information` columns plotted below).
df = get_MI_df(DATASET, results_dir_list)
df_downs = get_ranked_df(df, path="results/TDC_ADMET_SCAFF.csv")
df_downs["embedder"] = df_downs["embedder"].apply(prerpocess_emb_name)

# Tag each embedder with its family. A single embedder -> family lookup applied
# with .map gives the column its string labels directly; the previous approach
# (create a float NaN column, then write strings into it with .loc) is an
# incompatible-dtype assignment that recent pandas versions warn about / reject.
# If an embedder were listed in several families, the last family wins, exactly
# as with the sequential .loc assignments.
embedder_to_group = {
    embedder: group
    for group, embedders in model_groups.items()
    for embedder in embedders
}
df_downs["models_group"] = df_downs["embedder"].map(embedder_to_group)

# Embedders outside every family have a missing `models_group` and are dropped
# here. NOTE: dropna() also discards rows with a missing value in any other column.
df_downs = df_downs.dropna()
clear_output()
df_downs

# Comparison of the rankings of similar models

In [None]:
# Embedder name -> colour lookup: one "husl" palette entry per distinct
# embedder, paired in the order the embedders first appear in df_downs.
cmap = dict(
    zip(
        df_downs.embedder.unique(),
        sns.color_palette("husl", n_colors=df_downs.embedder.nunique()),
    )
)

# NOTE(review): these limits are not referenced in the visible cells.
lim_low, lim_high = 0.7, 1.05

In [None]:
def _draw_rank_panel(ax, df_group, rank_col, hue_order, palette, ylabel="", label_shift=0.0, invert=False):
    """Draw one horizontal bar panel of mean ranks for a single model family.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df_group : rows of the ranked frame belonging to one `models_group`.
    rank_col : name of the column holding the mean rank shown on the x axis.
    hue_order : embedder names, in the order the bars should be stacked.
    palette : mapping from embedder name to colour.
    ylabel : label written on the y axis (empty string for none).
    label_shift : horizontal offset added to the position of the spread label.
    invert : if True, flip the x axis after setting its limits.

    An arrow is drawn from the smallest to the largest mean rank of the family
    and labelled with their difference (one decimal).
    """
    lowest = df_group[rank_col].min()
    highest = df_group[rank_col].max()
    spread = highest - lowest

    sns.barplot(data=df_group, x=rank_col, y="models_group", hue="embedder", ax=ax, dodge=True, palette=palette, legend=False, hue_order=hue_order, errorbar=None)

    # barplot sets its own y label / ticks; override them afterwards.
    ax.set_ylabel(ylabel)
    ax.set_yticks([])

    # Integer tick grid padded by one unit on each side, without tick labels.
    tick_lo = int(lowest - 1)
    tick_hi = int(highest + 1)
    ax.set_xticks(list(range(tick_lo, tick_hi)))
    ax.set_xticklabels([])
    ax.set_xlabel("")
    ax.set_xlim(tick_lo, tick_hi)
    if invert:
        ax.invert_xaxis()

    # Arrow spanning the family's rank range, just above the bars (y < 0).
    ax.arrow(
        lowest,
        -0.2,
        spread,
        0,
        head_width=0.05,
        head_length=0.1,
        length_includes_head=True,
        shape="full",
        color="black",
    )
    ax.text(
        (lowest + highest) / 2 + label_shift,
        -0.3,
        f"{spread:.1f}",
        horizontalalignment='center',
        verticalalignment='center',
    )


# Plot only the families that actually have rows in df_downs, so the number of
# subplot rows always matches the number of loop iterations (a family whose
# embedders were all filtered out would otherwise index past the axes grid or
# fail on an empty slice).
groups_present = [name for name in model_groups if (df_downs.models_group == name).any()]
if not groups_present:
    raise ValueError("df_downs contains no rows for any of the configured model groups")

# squeeze=False keeps `axes` two-dimensional even when a single family remains.
fig, axes = plt.subplots(len(groups_present), 2, figsize=(2.2, 7.3), squeeze=False)

for i, name in enumerate(groups_present):
    df_tmp = df_downs[df_downs.models_group == name]

    # Right column: downstream-task mean rank.
    _draw_rank_panel(axes[i, 1], df_tmp, "global_meanrank_metric", model_groups[name], cmap, label_shift=0.5)
    # Left column: information mean rank, x axis inverted, family name as y label.
    _draw_rank_panel(axes[i, 0], df_tmp, "meanrank_information", model_groups[name], cmap, ylabel=name, invert=True)

    # Per-family legend with round markers, anchored outside the right panel.
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=embedder, markerfacecolor=cmap[embedder], markersize=10) for embedder in model_groups[name]]
    axes[i, 1].legend(handles=handles, bbox_to_anchor=(1, 1), loc='upper left')

# Only the bottom row carries x-axis titles.
axes[-1, 1].set_xlabel("Task\n rank "+r"($\leftarrow$)")
axes[-1, 0].set_xlabel("EMIR\n rank "+r"($\rightarrow$)")
plt.subplots_adjust(wspace=0.02, hspace=0.1)

plt.savefig(f"{LATEX_FIG_PATH}/molecule/meanrank_models_group_vert.pdf", format = "pdf", bbox_inches = 'tight')