In [None]:
import os

# Move two directories up from the notebook's folder (presumably the project root, so that
# `molDistill` is importable in the next cell). The target is computed only on the first
# execution; re-running this cell is then a no-op instead of climbing two more levels.
if "_NOTEBOOK_ROOT_DIR" not in globals():
    _NOTEBOOK_ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(_NOTEBOOK_ROOT_DIR)
_NOTEBOOK_ROOT_DIR

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from molDistill.utils.notebooks import *

# Checkpoint of the distilled student; it is renamed to "student" when results are loaded.
STUDENT_MODEL = "model_275.pth"

# Embedders to compare. The student is kept last in the list.
MODELS_TO_EVAL = [
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
    STUDENT_MODEL
]

# Datasets left out of every table and figure in this notebook.
EXCLUDED_DATASETS = ["ToxCast"]
# Filter instead of list.remove(): does not raise ValueError if an excluded dataset is
# no longer present in df_metadata.
DATASETS = [d for d in df_metadata.index.tolist() if d not in EXCLUDED_DATASETS]

# Mean Performances

In [None]:


# Load the downstream results for every embedder, aggregate them with confidence
# intervals, and export the main LaTeX table. `order` is reused by the rank-table cell below.
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS, renames=[(STUDENT_MODEL, "student")])

df, order = aggregate_results_with_ci(df_base)
# Drop the (" ", "avg") and (" ", "avg std") columns from the main table
# (presumably the across-dataset average and its std).
df = df.drop(columns=[(" ", "avg"), (" ", "avg std")])
style, latex = style_df_ci(df, order)

table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_main.tex"
# Create the output folder if it is missing so the export does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(table_path), exist_ok=True)
with open(table_path, "w", encoding="utf-8") as f:
    f.write(latex)

# Rankings

In [None]:


# Reload the raw downstream results for the ranking analysis (fresh frame, so this cell is idempotent).
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS, renames=[(STUDENT_MODEL, "student")])
df_base = df_base.reset_index(drop=True)

# Number each embedder's rows 0..k-1. This equals the former `index % max_rows_per_embedder`
# when every embedder has the same number of rows, but a short embedder no longer shifts
# the numbering of all embedders that follow it.
# NOTE(review): assumes rows appear in the same order for every embedder — confirm against get_all_results.
df_base["id"] = df_base.groupby("embedder").cumcount()

df_base

In [None]:
# Mean rank of every embedder on every dataset, with datasets labelled by their short name.
ranked_df = get_ranked_df(df_base)
ranked_df.columns = ["embedder"] + [df_metadata.loc[c, "short_name"] for c in ranked_df.columns[1:]]

# Long format for seaborn: one row per (embedder, dataset).
melted_ranked = ranked_df.melt(id_vars="embedder", var_name="dataset", value_name="meanrank")
# Order the x-axis by the student's rank on each dataset.
order_datas = ranked_df.set_index("embedder").loc["student"].sort_values().index

# Desaturated colours for the baselines; the student gets the fully saturated colour.
# Selected by name, not by list position, so reordering MODELS_TO_EVAL cannot highlight the wrong model.
n_models = len(MODELS_TO_EVAL)
muted_colors = sns.color_palette("husl", n_models, desat=0.4)
vivid_colors = sns.color_palette("husl", n_models)
cmapping = {}
for i, model in enumerate(MODELS_TO_EVAL):
    if model == STUDENT_MODEL:
        cmapping["student"] = vivid_colors[i]
    else:
        cmapping[model] = muted_colors[i]

fig, ax = plt.subplots(figsize=(20, 3))
sns.barplot(data=melted_ranked, x="dataset", y="meanrank", hue="embedder", ax=ax, order=order_datas, palette=cmapping)
sns.despine(ax=ax)
ax.set_ylabel("mean rank")
ax.tick_params(axis="x", labelrotation=45)
# Legend outside the axes, to the right.
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
plt.show()

In [None]:
# Rank table (dataset columns keep their full names here, unlike the bar plot above).
df_ranked = get_ranked_df(df_base)

# `order` is the ordering returned by aggregate_results_with_ci in the Mean Performances
# cell — that cell must have been run first.
style, latex = style_df_ranked(df_ranked, order)
style


In [None]:
# Column layout: right-aligned row labels, then one centred column per table column, each
# preceded by a rule ("r|" + "|c" gives a double rule after the row labels, as before).
col_format = "r|" + "|c" * len(style.columns)

latex = style.to_latex(
    column_format=col_format,
    multicol_align="|c|",
    siunitx=True,
)

table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_all_rank.tex"
# Create the output folder if it is missing so the export does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(table_path), exist_ok=True)
with open(table_path, "w", encoding="utf-8") as f:
    f.write(latex)