In [None]:
# Move the working directory two levels up so that the relative paths below resolve from there.
# NOTE(review): the target is relative, so re-running this cell without restarting the kernel
# climbs two more levels each time.
%cd ../..

# Base directory (relative to the new working directory) under which this notebook writes
# its generated tables and figures.
LATEX_PATH = "../latex/Distillation-MI-ICLR"

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders that are referenced directly by their string name.
_NAMED_EMBEDDERS = [
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
]

# STUDENT_MODEL, ZINC_MODEL, L2_MODEL and COS_MODEL are not defined in this notebook;
# presumably they come from the star-import of molDistill.utils.notebooks -- TODO confirm.
MODELS_TO_EVAL = _NAMED_EMBEDDERS + [STUDENT_MODEL, ZINC_MODEL, L2_MODEL, COS_MODEL]

# Every dataset indexed in df_metadata except ToxCast.
# list.remove raises ValueError if "ToxCast" is ever absent from the metadata index.
DATASETS = df_metadata.index.tolist()
DATASETS.remove("ToxCast")

# Mean Performances

In [None]:
# ZINC_MODEL is already an element of MODELS_TO_EVAL, so concatenating it again used to
# hand the same embedder to get_all_results twice. De-duplicate while preserving order;
# the concatenation is kept so ZINC_MODEL stays in the request even if the base list changes.
# (dict.fromkeys requires hashable entries -- the model identifiers are assumed to be strings.)
models_for_table = list(dict.fromkeys(MODELS_TO_EVAL + [ZINC_MODEL]))
df_base = get_all_results(models_for_table, "downstream_results", DATASETS)

df, order = aggregate_results_with_ci(df_base)

# Reverse the order returned by the aggregation for every other embedder and pin these
# four labels, in this fixed sequence, to the end of the list.
trailing_labels = ["L2", "Cosine", "student-250k", "student-2M"]
for label in trailing_labels:
    # list.remove raises ValueError if a label is missing from the aggregated order.
    order.remove(label)
order = order[::-1] + trailing_labels

In [None]:
# style_df_ci returns two objects: one is displayed in the next cell, the other is the
# LaTeX source written to disk here.
style, latex = style_df_ci(df, order)

# Write the LaTeX source next to the paper's other molecule tables.
table_path = f"{LATEX_PATH}/tables/molecules/all_raw.tex"
with open(table_path, "w") as tex_file:
    tex_file.write(latex)

In [None]:
# Bare expression: shows the styled table returned by style_df_ci as the cell output.
style

# Rankings

In [None]:
# Reload the raw results (this time without the extra ZINC_MODEL entry) on a fresh 0..n-1 index.
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS)
df_base = df_base.reset_index(drop=True)

# Tag every row with its position modulo the largest per-embedder row count.
# NOTE(review): presumably this pairs up runs across embedders for get_ranked_df and
# assumes each embedder occupies a contiguous block of rows -- confirm against the helper.
step = df_base["embedder"].value_counts().max()
df_base["id"] = df_base.index % step

df_ranked = get_ranked_df(df_base)
df_ranked

In [None]:
# Style the ranking table using the embedder order computed in the mean-performance section.
# The LaTeX string returned here is regenerated from `style` in the export cell that follows.
style, latex = style_df_ranked(df_ranked, order)


style

In [None]:
over_cols = None  # not read by any code visible in this notebook

# Column spec: a right-aligned label column, then one centred column per table column,
# with an extra vertical rule placed in front of the "Avg" column.
column_specs = ["|c" if col == "Avg" else "c" for col in style.columns]
col_format = "r|" + "".join(column_specs)

latex = style.to_latex(
    column_format=col_format,
    multicol_align="|c|",
    siunitx=True,
)

# NOTE(review): add_hline is an external helper; presumably it inserts a rule at the given
# (possibly negative) line position. The call order is kept exactly as before.
for position in (1, -4, -3, -2):
    latex = add_hline(latex, position)

table_path = f"{LATEX_PATH}/tables/molecules/all_ranks.tex"
with open(table_path, "w") as tex_file:
    tex_file.write(latex)

# Heatmap

In [None]:
# Rebinds MODELS_TO_EVAL for the heatmap section: the same named embedders as before,
# followed by STUDENT_MODEL only (ZINC_MODEL, L2_MODEL and COS_MODEL are left out here).
MODELS_TO_EVAL = [
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
]
MODELS_TO_EVAL.append(STUDENT_MODEL)

# Every dataset indexed in df_metadata except ToxCast (raises ValueError if it is absent).
DATASETS = df_metadata.index.tolist()
DATASETS.remove("ToxCast")

In [None]:
# Load the results for the reduced model list and aggregate them first; the aggregation
# runs on the frame exactly as returned, before its index is reset.
df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS)
df, order = aggregate_results_with_ci(df_base)

df_base = df_base.reset_index(drop=True)

# Row position modulo the largest per-embedder row count.
# NOTE(review): presumably used by get_ranked_df to match runs across embedders -- confirm.
step = df_base["embedder"].value_counts().max()
df_base["id"] = df_base.index % step

In [None]:
# Reverse the aggregated order for all other embedders and put "student-2M" last.
# NOTE: `order` is produced in the previous cell, so re-running only this cell reverses
# the list again; re-run the previous cell first.
student_label = "student-2M"
order.remove(student_label)
order = list(reversed(order)) + [student_label]

In [None]:
ranked_df = get_ranked_df(df_base)

# Keep the first column as "embedder" and relabel every remaining column with the
# short_name recorded for it in df_metadata.
dataset_columns = ranked_df.columns[1:]
short_names = [df_metadata.loc[name, "short_name"] for name in dataset_columns]
ranked_df.columns = ["embedder", *short_names]
ranked_df

In [None]:
from matplotlib.patches import Rectangle
# Typeset all figure text through LaTeX, loading the lmodern package in the preamble.
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{lmodern}')


# Short names of the regression / classification datasets, restricted to those that
# actually appear as columns of ranked_df.
# Fix: the filtered regression list used to be bound to REG_DATASET (singular) while the
# unfiltered REG_DATASETS was the one used below, so any regression dataset missing from
# ranked_df raised a KeyError in `.loc`. The filter is now applied to REG_DATASETS itself,
# mirroring how CLS_DATASET is handled.
REG_DATASETS = df_metadata[df_metadata["task_type"] == "reg"].short_name.tolist()
REG_DATASETS = [d for d in REG_DATASETS if d in ranked_df.columns]
REG_DATASET = REG_DATASETS  # the original also bound this name; kept as an alias
CLS_DATASET = df_metadata[df_metadata["task_type"] == "cls"].short_name.tolist()
CLS_DATASET = [d for d in CLS_DATASET if d in ranked_df.columns]


# One row per dataset, one column per embedder; columns follow `order` reversed.
df_plot = ranked_df.set_index("embedder").loc[order[::-1]].transpose()

# Per-task-type averages; both are computed before either row is selected, so neither
# average includes the other.
df_plot.loc["Average (reg)"] = df_plot.loc[REG_DATASETS].mean()
df_plot.loc["Average (cls)"] = df_plot.loc[CLS_DATASET].mean()

# Final row layout: regression datasets, their average, classification datasets, their average.
order_dataset = REG_DATASETS + ["Average (reg)"] + CLS_DATASET + ["Average (cls)"]
df_plot = df_plot.loc[order_dataset]

# Display names: underscores become spaces, everything after the first "-" is dropped,
# and the remaining "student" label is rendered as a bold "Student-2M".
df_plot.columns = [x.replace("_", " ").split("-")[0].replace("student", "\\textbf{Student-2M}") for x in df_plot.columns]

def highlight_value(data):
    """Format a number as a bold, underlined LaTeX annotation.

    The value is rounded to one decimal place. If the rounded value has no fractional
    part it is printed as an integer ("4"), otherwise with its single decimal ("2.5").

    Args:
        data: Numeric scalar to format.

    Returns:
        The string ``\\textbf{\\underline{<value>}}``.
    """
    rounded = np.round(data, 1)
    text = str(rounded)
    if text[-1] == "0":
        # Convert the *rounded* value: int(data) would truncate the raw input, so a
        # value such as 3.96 (which rounds to 4.0) used to be printed as "3".
        text = str(int(rounded))
    return r'\textbf{\underline{' + text + '}}'

# Boolean frame shaped like df_plot: True where a cell equals the minimum of its own row.
mask_min = df_plot.eq(df_plot.min(axis=1), axis=0)

In [None]:
# Pre-render the highlighted LaTeX label for every cell, then restore the 2-D layout of
# df_plot so the array can be indexed by row position later on.
flat_labels = [highlight_value(value) for value in df_plot.to_numpy().ravel()]
bold_values = np.array(flat_labels).reshape(df_plot.shape)

# Keyword arguments shared by every sns.heatmap call in create_heatmap: no colour bar,
# a fixed colour range, and white annotations at font size 9.
common_kwargs = dict(
    cbar=False,
    vmin=1.4,
    vmax=9,
    annot_kws=dict(color="white", fontsize=9),
)

In [None]:
def create_heatmap(ax, df_plot, mask_min, bold_values, cmap_name = "viridis_r", desat = 0.0):
    """Draw a two-layer annotated heatmap of ``df_plot`` on ``ax``.

    Cells flagged in ``mask_min`` are drawn with the full-colour palette and annotated
    with the matching entry of ``bold_values``; all other cells are drawn with a
    desaturated version of the same palette and seaborn's default numeric annotation.
    Both layers also receive the module-level ``common_kwargs``.

    Args:
        ax: Matplotlib axes to draw on.
        df_plot: Values to plot.
        mask_min: Boolean frame aligned with ``df_plot``; True marks highlighted cells.
        bold_values: Array of annotation strings with the same shape as ``df_plot``.
        cmap_name: Name of the seaborn palette used for both layers.
        desat: ``desat`` value passed to seaborn for the non-highlighted layer.
    """
    vivid_palette = sns.color_palette(cmap_name, as_cmap=False)
    muted_palette = sns.color_palette(cmap_name, as_cmap=False, desat=desat)

    # Layer 1: everything except the highlighted cells (masked cells are not drawn).
    sns.heatmap(
        df_plot,
        mask=mask_min,
        annot=True,
        ax=ax,
        cmap=muted_palette,
        **common_kwargs
    )
    # Layer 2: only the highlighted cells, labelled with the pre-formatted strings.
    sns.heatmap(
        df_plot,
        mask=~mask_min,
        annot=bold_values,
        fmt='',
        ax=ax,
        cmap=vivid_palette,
        **common_kwargs
    )

    # Slanted x labels, horizontal y labels, and no axis titles.
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha="right")
    ax.set(xlabel="", ylabel="")

    # Same tick-label size on both axes.
    for axis_name in ("y", "x"):
        ax.tick_params(axis=axis_name, labelsize=12)

In [None]:
# One stacked panel per group, top to bottom: regression datasets, their average,
# classification datasets, their average.
DATASETS_TO_PLOT = [REG_DATASETS, ["Average (reg)"], CLS_DATASET, ["Average (cls)"]]

# Panel heights are proportional to the number of rows each panel shows.
fig,axes = plt.subplots(
    len(DATASETS_TO_PLOT),
    1,
    figsize=(3.5, 8.5),
    sharex=True,
    gridspec_kw={'height_ratios': [len(d) for d in DATASETS_TO_PLOT]}
)
axes = axes.flatten()
plt.subplots_adjust(hspace=0.02)

for i,dataset_to_plot in enumerate(DATASETS_TO_PLOT):
    # bold_values is a plain array, so its rows are selected by integer position.
    # (Previously this list was named `filter`, which shadowed the builtin for the rest
    # of the kernel session.)
    row_positions = [df_plot.index.get_loc(c) for c in dataset_to_plot]
    create_heatmap(axes[i], df_plot.loc[dataset_to_plot], mask_min.loc[dataset_to_plot], bold_values[row_positions], cmap_name="flare", desat=0.7)


plt.savefig(f"{LATEX_PATH}/figures/molecules/hmp_rankings.pdf", bbox_inches="tight",)
