In [None]:
# Move two directories up from the notebook's own directory (presumably the
# repository root) so the project-relative paths below resolve.
# The starting directory is remembered on the first run, so re-executing this
# cell is idempotent instead of climbing two more levels each time (which a
# bare `%cd ../..` does).
import os

if "_NOTEBOOK_DIR" not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, "..", ".."))

# Root of the LaTeX project receiving generated tables/figures, relative to the
# directory chosen above.
LATEX_PATH = "../latex/Distillation-MI-ICLR"

os.getcwd()

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders compared in every table and figure of this notebook
# (names provided by molDistill.utils.notebooks).
MODELS_TO_EVAL = [STUDENT_MODEL, SINGLE_TEACHER_BERT, SINGLE_TEACHER_TDINFO, TWO_TEACHER]

# Candidate datasets: every metadata row except the final three, minus ToxCast.
_metadata_datasets = df_metadata.index.tolist()
DATASETS = _metadata_datasets[:-3]
DATASETS.remove("ToxCast")

len(DATASETS)

# Mean performance: classification datasets


In [None]:
# Export one LaTeX table (mean test performance with confidence intervals) per
# group of classification datasets.
DATASETS = df_metadata[df_metadata.task_type == "cls"].index.tolist()

# Category names matched against the first level of the aggregated columns;
# " " presumably collects datasets without a category — TODO confirm.
DATASET_GROUP = [["Distribution", "HTS", "Absorption", " "], ["Metabolism"], ["Tox"]]

df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS)
df, order = aggregate_results_with_ci(df_base)

# open() does not create missing parent directories, so make sure they exist.
table_dir = f"{LATEX_PATH}/tables/molecules"
os.makedirs(table_dir, exist_ok=True)

for i, datasets in enumerate(DATASET_GROUP):
    # Assumes df has tuple (MultiIndex) columns whose first level is the category.
    df_group = df[[col for col in df.columns if col[0] in datasets]]
    # style_df_ci returns a pair; only its second item (the LaTeX source) is exported.
    latex = style_df_ci(df_group, order[::-1])[1]
    latex = add_hline(latex, 1)
    latex = add_hline(latex, -1)
    # Explicit encoding so the .tex file is written identically on every platform.
    with open(f"{table_dir}/sgl_cls_{i}.tex", "w", encoding="utf-8") as f:
        f.write(latex)

# Mean performance: regression datasets

In [None]:
# Export one LaTeX table (mean test performance with confidence intervals) per
# group of regression datasets.
DATASETS = df_metadata[df_metadata.task_type == "reg"].index.tolist()

# Category names matched against the first level of the aggregated columns;
# " " presumably collects datasets without a category — TODO confirm.
DATASET_GROUP = [[" ", "Absorption", "Tox"], ["Distribution", "Excretion"]]

df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS)
df, order = aggregate_results_with_ci(df_base)

# open() does not create missing parent directories, so make sure they exist.
table_dir = f"{LATEX_PATH}/tables/molecules"
os.makedirs(table_dir, exist_ok=True)

for i, datasets in enumerate(DATASET_GROUP):
    # Assumes df has tuple (MultiIndex) columns whose first level is the category.
    df_group = df[[col for col in df.columns if col[0] in datasets]]
    # style_df_ci returns a pair; only its second item (the LaTeX source) is exported.
    latex = style_df_ci(df_group, order[::-1])[1]
    latex = add_hline(latex, 1)
    latex = add_hline(latex, -1)
    # Explicit encoding so the .tex file is written identically on every platform.
    with open(f"{table_dir}/sgl_reg_{i}.tex", "w", encoding="utf-8") as f:
        f.write(latex)

# Figures

In [None]:
# Per-dataset facet grid of test performance for each embedder, over all
# datasets except the last three metadata rows and ToxCast.
DATASETS = df_metadata.index.tolist()[:-3]
DATASETS.remove("ToxCast")

df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS).reset_index(drop=True)
df_base["short_dataset"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)

# NOTE(review): `order` is not defined in this cell; it is whatever the most
# recently executed aggregate_results_with_ci call returned (the regression-table
# cell when the notebook is run top to bottom).
# The invisible (alpha=0) catplot only lays out the facets; the visible line and
# points are drawn by the two g.map calls below.
g = sns.catplot(
    data=df_base.dropna(),
    col="short_dataset",
    y="metric_test",
    x="embedder",
    hue="embedder",
    kind="point",
    palette="husl",
    height=1.3,
    aspect=0.9,
    col_wrap=8,
    sharey=False,
    alpha=0.,
    legend=False,
    errorbar=None,
    order=order,
    hue_order=order,
)
g.map(sns.lineplot, "embedder", "metric_test", errorbar=None, color="black", alpha=0.3, linewidth=2.5)
g.map(sns.pointplot, "embedder", "metric_test", "embedder", order=order, palette="husl", errorbar=None, alpha=1, legend=False, hue_order=order)

g.set_titles(col_template="{col_name}", row_template="Test performance")
g.set_xlabels("")
g.set_ylabels("Test perf.")
# Rotate x-ticks
g.tick_params(axis="x", rotation=90)
g.tick_params(axis="y", labelsize=8)

g.figure.subplots_adjust(wspace=0.8, hspace=0.5)

# Fit each facet's y-range to its per-embedder mean scores with a 0.02 margin.
# axes_dict maps facet value -> Axes, so limits always land on the right panel.
# Only metric_test is reduced: a whole-frame groupby mean would also try to
# average the string columns, which raises TypeError on pandas >= 2.0.
for dataset, ax in g.axes_dict.items():
    mean_scores = df_base.loc[df_base.short_dataset == dataset].groupby("embedder")["metric_test"].mean()
    ax.set_ylim(mean_scores.min() - 0.02, mean_scores.max() + 0.02)

figure_dir = f"{LATEX_PATH}/figures/molecules"
os.makedirs(figure_dir, exist_ok=True)
g.figure.savefig(f"{figure_dir}/multi_vs_single_facetgrid.pdf", bbox_inches="tight")

In [None]:
# Per-dataset facet grid of test R^2 for each embedder on the regression datasets.
DATASETS = df_metadata[df_metadata.task_type == "reg"].index.tolist()

df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS).reset_index(drop=True)
df_base["short_dataset"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)

# NOTE(review): `order` is not defined in this cell; it is whatever the most
# recently executed aggregate_results_with_ci call returned (the regression-table
# cell when the notebook is run top to bottom).
# The invisible (alpha=0) catplot only lays out the facets; the visible line and
# points are drawn by the two g.map calls below.
g = sns.catplot(
    data=df_base.dropna(),
    col="short_dataset",
    y="metric_test",
    x="embedder",
    hue="embedder",
    kind="point",
    palette="husl",
    height=1.1,
    aspect=1.5,
    col_wrap=5,
    sharey=False,
    alpha=0.,
    legend=False,
    errorbar=None,
    order=order,
    hue_order=order,
)
g.map(sns.lineplot, "embedder", "metric_test", errorbar=None, color="black", alpha=0.3, linewidth=2.5)
g.map(sns.pointplot, "embedder", "metric_test", "embedder", order=order, palette="husl", errorbar=None, alpha=1, legend=False, hue_order=order)

g.set_titles(col_template="{col_name}", row_template="Test performance")
g.set_xlabels("")
g.set_ylabels("$R^2$")
# Rotate x-ticks
g.tick_params(axis="x", rotation=90)
g.tick_params(axis="y", labelsize=8)

g.figure.subplots_adjust(wspace=0.8, hspace=0.4)

# Fit each facet's y-range to its per-embedder mean scores with a 0.02 margin.
# axes_dict maps facet value -> Axes, so limits always land on the right panel.
# Only metric_test is reduced: a whole-frame groupby mean would also try to
# average the string columns, which raises TypeError on pandas >= 2.0.
for dataset, ax in g.axes_dict.items():
    mean_scores = df_base.loc[df_base.short_dataset == dataset].groupby("embedder")["metric_test"].mean()
    ax.set_ylim(mean_scores.min() - 0.02, mean_scores.max() + 0.02)

figure_dir = f"{LATEX_PATH}/figures/molecules"
os.makedirs(figure_dir, exist_ok=True)
g.figure.savefig(f"{figure_dir}/multi_vs_single_facetgrid_reg.pdf", bbox_inches="tight")
