In [None]:
import os

# Move two directories up from the notebook's location, presumably to the
# repository root so package imports and relative result paths resolve
# (TODO confirm layout).
# Guarded with a sentinel so re-running this cell does not keep climbing the
# directory tree; a kernel restart clears the sentinel and the chdir runs again.
if "_REPO_ROOT" not in globals():
    os.chdir(os.path.join("..", ".."))
    _REPO_ROOT = os.getcwd()
_REPO_ROOT

In [None]:
import os
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders compared in this notebook: six named baselines plus STUDENT_MODEL
# and ZINC_MODEL (presumably provided by the molDistill.utils.notebooks star
# import — TODO confirm).
MODELS_TO_EVAL = [
    "ChemBertMTR-77M",
    "GraphMVP",
    "GraphLog",
    "GraphCL",
    "FRAD_QM9",
    "ThreeDInfomax",
    STUDENT_MODEL,
    ZINC_MODEL,
]

# Keep only the regression tasks (task_type == "reg"), in df_metadata's index order.
_is_regression = df_metadata.task_type == "reg"
DATASETS = df_metadata.loc[_is_regression].index.tolist()

# Rankings

In [None]:
# Load the "downstream_results" for every model/dataset pair, then aggregate
# them; `order` is the embedder ordering used to arrange the plot rows.
df_base = get_all_results(
    MODELS_TO_EVAL,
    "downstream_results",
    DATASETS,
)
df, order = aggregate_results_with_ci(df_base)
df_base

In [None]:
# One horizontal panel per regression dataset: a median bar plus a box plot of
# the test metric for every embedder, saved as a single PDF figure.
FIG_SIZE = 1  # panel width in inches; figure height is 3 * FIG_SIZE
# Output directory; override with the FIGURES_DIR environment variable.
FIGURES_DIR = os.environ.get(
    "FIGURES_DIR",
    "/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules",
)


def _short_title(dataset: str) -> str:
    """Return the compact panel title shown for `dataset`."""
    if dataset.startswith("Clearance") or dataset.startswith("Half"):
        return df_metadata.loc[dataset].short_name.replace("Clearance", "Clear.")
    return (
        dataset.split("_")[0]
        .replace("HydrationFreeEnergy", "FreeSolv")
        .replace("Lipophilicity", "Lipo.")
    )


# squeeze=False keeps `axes` a 2-D array even with a single dataset, so
# flatten() always yields an indexable 1-D array of Axes.
fig, axes = plt.subplots(
    1,
    len(DATASETS),
    figsize=(FIG_SIZE * len(DATASETS), FIG_SIZE * 3.0),
    sharey=True,
    squeeze=False,
)
axes = axes.flatten()

# NOTE(review): "{(t)}" is a plain string literal, not an f-string, so this
# only keeps embedder names containing the literal text "{(t)}" — TODO confirm intent.
TEACHERS = [t for t in order if "{(t)}" in t]

# Embedders whose name contains "student" get full-saturation colours; all
# others use the desaturated (desat=0.15) variant of the same husl palette.
n_colors = df_base.embedder.nunique()
muted_palette = sns.color_palette("husl", n_colors, desat=0.15)
vivid_palette = sns.color_palette("husl", n_colors)
cmap = {
    emb: (vivid_palette if "student" in emb else muted_palette)[i]
    for i, emb in enumerate(order)
}

for i, dataset in enumerate(DATASETS):
    ax = axes[i]
    # Rows follow `order` reversed; .loc raises KeyError if an embedder in
    # `order` has no results for this dataset.
    df_plt = (
        df_base[df_base.dataset == dataset]
        .set_index("embedder")
        .loc[order[::-1]]
        .reset_index()
    )

    shared_kwargs = dict(
        data=df_plt, x="metric_test", y="embedder", ax=ax,
        hue="embedder", palette=cmap, hue_order=order,
    )
    sns.barplot(**shared_kwargs, errorbar=None, fill=True, estimator="median", alpha=.7)
    sns.boxplot(**shared_kwargs, fill=False, fliersize=0, width=.5, linewidth=2.)

    ax.set_title(_short_title(dataset), size=12)
    # Only the middle panel carries the shared x-axis label.
    ax.set_xlabel("Test $R^2$" if i == len(DATASETS) // 2 else "")
    ax.set_ylabel("")
    # Left limit: 1st percentile of the metric, clipped at 0 and rounded to 0.1.
    ax.set_xlim(np.round(max(df_plt.metric_test.quantile(0.01), 0), 1))
    ax.tick_params(axis="x", labelsize=8)
    ax.tick_params(axis="y", labelsize=12)

plt.tight_layout()

plt.savefig(os.path.join(FIGURES_DIR, "reg_boxplot.pdf"), bbox_inches="tight")