In [None]:
# Change the working directory two levels up (presumably the repository root,
# so that the `molDistill` package and relative paths resolve — confirm).
%cd ../..

In [None]:
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders compared in the main results: pretrained baselines first, then the
# two distilled students (renamed "student-large" / "student-small" when the
# results are loaded below).
MODELS_TO_EVAL = [
    # pretrained baselines
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
] + [
    # distilled students
    STUDENT_MODEL,
    ZINC_MODEL,
]

# Every dataset listed in the metadata table, except ToxCast.
DATASETS = df_metadata.index.tolist()
DATASETS.remove("ToxCast")

# Mean Performances

In [None]:


# Load the downstream results of every model on every dataset, renaming the two
# students for display.
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS, renames=[
    (STUDENT_MODEL, "student-large"),
    (ZINC_MODEL, "student-small"),
    ],
)

# Aggregate into a table with confidence intervals; `order` is the embedder
# ordering reused by later cells.
df, order = aggregate_results_with_ci(df_base)
# Drop the overall average columns (stored under the blank " " top-level header).
df = df.drop([(" ", "avg"), (" ", "avg std")], axis=1)
style, latex = style_df_ci(df, order)

# Plain string: the path has no placeholders, so no f-string is needed.
table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_main.tex"
# Explicit UTF-8: the generated LaTeX may contain non-ASCII symbols, and the
# platform default encoding is not guaranteed to be UTF-8.
with open(table_path, "w", encoding="utf-8") as f:
    f.write(latex)

# Rankings

In [None]:


# Reload the raw per-model results for the ranking analysis.
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS, renames=[
    (STUDENT_MODEL, "student-large"),
    (ZINC_MODEL, "student-small"),
    ],
)
# Reassign rather than reset in place (same result, idiomatic pandas).
df_base = df_base.reset_index(drop=True)

# `id` = row position modulo the largest per-embedder row count.
# NOTE(review): this pairs rows across embedders only if each embedder's rows are
# contiguous, equally many and identically ordered — confirm against get_all_results.
step = df_base.embedder.value_counts().max()
df_base["id"] = df_base.index % step

df_base

In [None]:
# Rank the embedders on every dataset and build both the display Styler and the
# LaTeX string. `order` (the embedder ordering) is the one returned by
# aggregate_results_with_ci in the "Mean Performances" cell.
df_ranked = get_ranked_df(df_base)
style, latex = style_df_ranked(
    df_ranked,
    order,
)
style


In [None]:
# LaTeX column spec: right-aligned row labels and a rule, then one centred column
# per table column, with an extra vertical rule in front of the "Avg" column.
col_format = "r|" + "".join("|c" if col == "Avg" else "c" for col in style.columns)

latex = style.to_latex(
    column_format=col_format,
    multicol_align="|c|",
    siunitx=True,
)
# add_hline is a project helper; presumably inserts a horizontal rule at the
# given line position (1 = after the header, -2 = before the end) — confirm.
latex = add_hline(latex, 1)
latex = add_hline(latex, -2)

table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_all_rank.tex"
# Explicit UTF-8 so the output does not depend on the platform default encoding.
with open(table_path, "w", encoding="utf-8") as f:
    f.write(latex)

In [None]:

# Reduced comparison for the ranking heatmap: the same pretrained baselines, but
# only the large student (ZINC_MODEL is deliberately left out).
MODELS_TO_EVAL = [
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
]
MODELS_TO_EVAL.append(STUDENT_MODEL)

# All datasets from the metadata table except ToxCast.
DATASETS = df_metadata.index.tolist()
DATASETS.remove("ToxCast")

In [None]:
# Reload results for the model set defined in the previous cell.
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS, renames=[
    (STUDENT_MODEL, "student-large"),
    (ZINC_MODEL, "student-small"),
    ],
)
# `order` (embedder ordering) is used by the heatmap cell; `df` is not used by
# the cells below.
df, order = aggregate_results_with_ci(df_base)

# Reassign rather than reset in place (same result, idiomatic pandas).
df_base = df_base.reset_index(drop=True)

# `id` = row position modulo the largest per-embedder row count.
# NOTE(review): only pairs rows correctly if every embedder has the same number
# of contiguous, identically ordered rows — confirm against get_all_results.
step = df_base.embedder.value_counts().max()
df_base["id"] = df_base.index % step

In [None]:
# Ranking table whose dataset columns are relabelled with their short display
# names; the first column is kept and named "embedder".
ranked_df = get_ranked_df(df_base)
ranked_df.columns = [
    "embedder",
    *(df_metadata.loc[name, "short_name"] for name in ranked_df.columns[1:]),
]
ranked_df

In [None]:

from matplotlib.patches import Rectangle  # NOTE(review): not used in the cells shown
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{lmodern}')

# Dataset short names per task type, restricted to datasets that actually have
# a column in `ranked_df` (indexing with a missing one would raise KeyError).
REG_DATASETS = df_metadata[df_metadata["task_type"] == "reg"].short_name.tolist()
REG_DATASETS = [d for d in REG_DATASETS if d in ranked_df.columns]
# Singular alias kept so any existing reference to REG_DATASET keeps working.
REG_DATASET = REG_DATASETS
CLS_DATASET = df_metadata[df_metadata["task_type"] == "cls"].short_name.tolist()
CLS_DATASET = [d for d in CLS_DATASET if d in ranked_df.columns]

# Rows = datasets, columns = embedders (in reversed `order`), so the per-task
# averages can be appended as extra rows.
df_plot = ranked_df.set_index("embedder").loc[order[::-1]].transpose()

# Mean rank of each embedder over the (filtered) regression / classification datasets.
df_plot.loc["Average (reg)"] = df_plot.loc[REG_DATASETS].mean()
df_plot.loc["Average (cls)"] = df_plot.loc[CLS_DATASET].mean()

# Final layout: rows = embedders; columns = regression datasets, their average,
# classification datasets, their average.
order_dataset = REG_DATASETS + ["Average (reg)"] + CLS_DATASET + ["Average (cls)"]
df_plot = df_plot.loc[order_dataset].transpose()

fig, ax = plt.subplots(1, 1, figsize=(10.5, 3.8))

# Cells equal to the lowest value of their dataset column (highlighted in the plot).
mask_min = df_plot == df_plot.min(axis=0)
# Ranks above 10 (drawn without a numeric annotation in the plot).
mask_too_bad = df_plot > 10


def highlight_value(data):
    """Return a rank as bold, underlined LaTeX text for a heatmap annotation.

    The value is rounded to one decimal. Whole numbers are shown without the
    trailing ".0" (3.0 -> "3"); otherwise the rounded value is shown (2.5 -> "2.5").
    Non-finite values (e.g. NaN) are rendered via ``str`` ("nan").

    Args:
        data: numeric value (Python or NumPy scalar).

    Returns:
        str: ``\\textbf{\\underline{<value>}}``.
    """
    rounded = float(np.round(data, 1))
    if rounded.is_integer():
        # Format the *rounded* value: truncating the raw value would turn 2.96 into "2".
        text = str(int(rounded))
    else:
        text = str(rounded)
    return r'\textbf{\underline{' + text + '}}'

# Bold/underlined LaTeX label for every cell, laid out like df_plot; only shown
# where mask_min is True (third heatmap layer below).
bold_values = np.array(
    [highlight_value(value) for value in df_plot.to_numpy().ravel()]
).reshape(df_plot.shape)

# Shared styling so all heatmap layers use the same colour scale.
common_kwargs = {
    "cmap": "flare",
    "cbar": False,
    "vmin": 1.5,
    "vmax": 10,
    "annot_kws": {"color": "white"},
}

# The heatmap is drawn as three layers on the same axes; each layer masks out
# (hides) the cells handled by the others.
# Layer 1: cells that are neither the column minimum nor above 10, with numeric annotations.
sns.heatmap(df_plot, mask=mask_min | mask_too_bad, annot=True, ax=ax, **common_kwargs)
# Layer 2: ranks above 10, coloured but not annotated.
sns.heatmap(df_plot, mask=~mask_too_bad, annot=False, ax=ax, **common_kwargs)
# Layer 3: column minima, annotated with the bold/underlined labels.
sns.heatmap(df_plot, mask=~mask_min, annot=bold_values, fmt="", ax=ax, **common_kwargs)

# Rotate the x tick labels (dataset names) and drop the axis titles.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_xlabel("")
ax.set_ylabel("")

# White vertical rules separating the column groups:
# regression datasets | Average (reg) | classification datasets | Average (cls)
ax.axvline(x=df_plot.shape[1] - 1, color="w", linewidth=1.5)
ax.axvline(x=df_plot.shape[1] - len(CLS_DATASET) - 1, color="w", linewidth=1.5)
ax.axvline(x=df_plot.shape[1] - len(CLS_DATASET) - 2, color="w", linewidth=1.5)

fig.savefig(
    "/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules/hmp_rankings.png",
    bbox_inches="tight",
)