In [None]:
# Move one directory up before anything else runs. NOTE(review): presumably this
# is the repository root, so that relative paths used below (e.g.
# "results/TDC_ADMET_SCAFF.csv") and the `molecule` package import resolve — confirm.
%cd ..

In [None]:
import os
import numpy as np
from molecule.utils.utils_notebook import get_MI_df, get_ranked_df, LATEX_FIG_PATH, get_DTI_rank_df, process_dataset_name, prerpocess_emb_name
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import clear_output
import pandas as pd

# --- Configuration -----------------------------------------------------------
DATASET = "ZINC"
results_dir_list = ["run_4"]
COLUMS_SPLIT = "cond_modes"
METRIC = r"$\overline{\mathcal{I}_S}$"

# Embedders kept for the analysis, grouped by model family. Names are passed
# through `prerpocess_emb_name` so they match the normalised `embedder` column.
model_groups = {
    'ChemBert-MLM': ['ChemBertMLM-10M', 'ChemBertMLM-5M', 'ChemBertMLM-77M'],
    'ChemBert-MTR': ['ChemBertMTR-77M', 'ChemBertMTR-10M', 'ChemBertMTR-5M'],
    '3D': ['DenoisingPretrainingPQCMv4', 'FRAD_QM9'],
    'MolR': ['MolR_tag', 'MolR_gcn', 'MolR_gat'],
    'MoleOOD': ['MoleOOD_OGB_GCN', 'MoleOOD_OGB_SAGE', 'MoleOOD_OGB_GIN']
}
model_groups = {
    group: [prerpocess_emb_name(name) for name in names]
    for group, names in model_groups.items()
}

# Reverse lookup embedder -> family. If an embedder were listed in several
# families, the last one wins (same as successive per-family `.loc` writes).
embedder_to_group = {
    embedder: group
    for group, embedders in model_groups.items()
    for embedder in embedders
}

# --- Load data ---------------------------------------------------------------
df = get_MI_df(DATASET, results_dir_list)

df_downs = get_ranked_df(df, path="results/TDC_ADMET_SCAFF.csv")

df_downs["embedder"] = df_downs["embedder"].apply(prerpocess_emb_name)
# `.map` leaves NaN for embedders outside every family. Building the column in
# one step avoids seeding a float64 column with np.nan and then writing strings
# into it, which recent pandas deprecates as an incompatible-dtype upcast.
df_downs["models_group"] = df_downs["embedder"].map(embedder_to_group)
# NOTE(review): dropna() removes rows with a NaN in *any* column, not only rows
# whose embedder is outside `model_groups` — confirm that is intended.
df_downs = df_downs.dropna()
clear_output()
df_downs

# Dependence on task size

In [None]:
# Per downstream dataset: correlate the embedders' downstream score (`metric`)
# with their mean information rank (`meanrank_information`), using three
# correlation coefficients, then attach the results to every row of df_downs.
CORR_METHODS = ("pearson", "spearman", "kendall")

corr_records = []
for dataset in df_downs.sort_values("task_size")["dataset"].unique():
    # `.assign` works on a new frame, so no write goes through a slice of df_downs.
    # NOTE(review): this re-applies the normalisation already done when df_downs
    # was built; kept to preserve behaviour — drop it if the helper is idempotent.
    df_dataset = df_downs[df_downs["dataset"] == dataset].assign(
        embedder=lambda d: d["embedder"].apply(prerpocess_emb_name)
    )
    # One point per embedder (median over its rows). Selecting the two numeric
    # columns first avoids the TypeError pandas >= 2.0 raises for string columns.
    per_embedder = df_dataset.groupby("embedder")[["metric", "meanrank_information"]].median()

    record = {"dataset": dataset}
    for method in CORR_METHODS:
        # Negated. NOTE(review): presumably because a lower mean rank means more
        # information, so positive values indicate agreement — confirm.
        record[method] = -per_embedder.corr(method=method).iloc[0, 1]
    corr_records.append(record)

# Build the frame once (DataFrame.append was removed in pandas 2.0).
df_corr = pd.DataFrame(corr_records, columns=["dataset", "kendall", "pearson", "spearman"])

# Drop columns left by a previous run of this cell so the join cannot fail on
# overlapping column names; this keeps the cell safe to re-execute.
df_downs = df_downs.drop(columns=list(CORR_METHODS), errors="ignore").join(
    df_corr.set_index("dataset"), on="dataset"
)

In [None]:
# One panel per correlation coefficient: (column in df_downs, y-axis label).
CORR_PANELS = [
    ("pearson", r"$\varrho_p$"),
    ("spearman", r"$\varrho_s$"),
    ("kendall", r"$\tau$"),
]
# Overlay a log-x regression trend per panel; off for the exported figure.
SHOW_TREND = False

fig, axes = plt.subplots(1, len(CORR_PANELS), figsize=(7, 2))
# Correlation columns are constant within a dataset: keep one point per dataset.
df_plot = df_downs.drop_duplicates(subset=["dataset"])

for ax, (column, ylabel) in zip(axes, CORR_PANELS):
    sns.scatterplot(
        data=df_plot, x="task_size", y=column,
        hue="task_category", style="task_type", palette="husl",
        legend=False, ax=ax,
    )
    if SHOW_TREND:
        sns.regplot(
            data=df_plot, x="task_size", y=column, ax=ax,
            scatter=False, color="blue", line_kws={"alpha": 0.2},
            logx=True, x_ci="ci", ci=95,
        )
    # Task sizes are compared across orders of magnitude, hence the log x-axis.
    ax.set_xscale("log")
    ax.set_ylabel(ylabel)

fig.tight_layout()

# Create the output directory if needed so saving does not fail on a fresh checkout.
fig_dir = os.path.join(LATEX_FIG_PATH, "molecule")
os.makedirs(fig_dir, exist_ok=True)
fig.savefig(os.path.join(fig_dir, "correlation_task_size_scaff.pdf"), format="pdf", bbox_inches="tight")