In [None]:
# Move from the notebook's folder up two levels (presumably the repository root),
# so that relative paths used later (e.g. "downstream_results") resolve from there.
# The guard makes this cell idempotent: a bare `%cd ../..` would keep climbing the
# directory tree every time the cell is re-run in the same kernel.
import os

if "_REPO_ROOT" not in globals():
    os.chdir(os.path.join("..", ".."))
    _REPO_ROOT = os.getcwd()
_REPO_ROOT

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders compared against the distilled student. STUDENT_MODEL comes from the
# project's notebook utilities and is appended last.
MODELS_TO_EVAL = [
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
    STUDENT_MODEL,
]

# Only the datasets whose metadata marks them as regression tasks.
DATASETS = df_metadata.loc[df_metadata["task_type"] == "reg"].index.tolist()


# Mean performance on regression datasets

In [None]:

# Collect downstream results for every embedder plus the ZINC-trained student,
# giving both students readable display names.
df_base = get_all_results(
    [*MODELS_TO_EVAL, ZINC_MODEL],
    "downstream_results",
    DATASETS,
    renames=[
        (STUDENT_MODEL, "student-large"),
        (ZINC_MODEL, "student-small"),
    ],
)

# Aggregate per-dataset scores with confidence intervals; `order` is the
# embedder ordering reused by the styling helpers below.
df, order = aggregate_results_with_ci(df_base)
df

In [None]:
# Remove the aggregate "avg" / "avg std" columns before exporting the table.
# The result gets a new name so the cell stays idempotent: `df = df.drop(...)`
# would raise a KeyError on re-run because the columns would already be gone.
df_table = df.drop(columns=[(" ", "avg"), (" ", "avg std")])
style, latex = style_df_ci(df_table, order)

table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/all_reg.tex"
# Create the output folder if needed so the export does not fail on a fresh checkout.
os.makedirs(os.path.dirname(table_path), exist_ok=True)
with open(table_path, "w", encoding="utf-8") as f:
    f.write(latex)

# Per-dataset rankings on regression datasets

In [None]:
# Ranking inputs: only the main embedders and the large student (no ZINC student).
df_base = get_all_results(
    MODELS_TO_EVAL,
    "downstream_results",
    DATASETS,
    renames=[(STUDENT_MODEL, "student-large")],
)

# Only `order` is used further down; `df` is kept for interactive inspection.
df, order = aggregate_results_with_ci(df_base)

# Give every row a run id: rows are renumbered 0..len-1 and wrapped by the largest
# per-embedder row count.
# NOTE(review): this assumes each embedder contributes the same number of rows,
# stored contiguously and in the same order — confirm, otherwise ids misalign.
step = df_base.embedder.value_counts().max()
df_base = df_base.reset_index(drop=True)
df_base = df_base.assign(id=df_base.index % step)

df_base

In [None]:
ranked_df = get_ranked_df(df_base)

# Keep the leading column as "embedder"; replace each dataset column with the
# short display name stored in the dataset metadata.
short_names = [df_metadata.loc[dataset, "short_name"] for dataset in ranked_df.columns[1:]]
ranked_df.columns = ["embedder", *short_names]

ranked_df

In [None]:
# Rows: datasets plus a final "Average" row. Columns: embedders in reversed `order`.
df_plot = ranked_df.set_index("embedder").loc[order[::-1]].transpose()
df_plot.loc["Average"] = df_plot.mean()

fig, ax = plt.subplots(figsize=(5, 3.5))
sns.heatmap(
    df_plot,
    ax=ax,
    cmap="flare",
    annot=True,
    fmt=".1f",
    cbar=False,
    # NOTE(review): MODELS_TO_EVAL holds 12 embedders, so ranks above vmax=9 all
    # share the darkest colour — confirm this saturation is intended.
    vmin=1.5,
    vmax=9,
)
ax.tick_params(axis="x", labelrotation=65)
# White rule separating the last row ("Average") from the per-dataset rows.
ax.axhline(y=df_plot.shape[0] - 1, color="w", linewidth=1)
ax.set_xlabel("")

fig_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules/reg_rankings.png"
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path, bbox_inches="tight")

In [None]:
# Rankings again, this time without the short-name column renaming applied to
# `ranked_df` above, passed through the project's ranking styler.
# The returned `latex` is superseded by the custom export in the next cell.
df_ranked = get_ranked_df(df_base)
style, latex = style_df_ranked(
    df_ranked,
    order,
)
style


In [None]:
# Column spec: right-aligned index, a double rule after it, then one centred
# column per styled column separated by single rules (e.g. "r||c|c|c").
col_format = "r|" + "|c" * len(style.columns)

latex = style.to_latex(
    column_format=col_format,
    multicol_align="|c|",
    siunitx=True,
)

table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/reg_rankings.tex"
# Create the output folder if needed so the export does not fail on a fresh checkout.
os.makedirs(os.path.dirname(table_path), exist_ok=True)
with open(table_path, "w", encoding="utf-8") as f:
    f.write(latex)