In [None]:
import os

# Move two levels above the directory the kernel started in; relative paths used below
# (e.g. "downstream_results", LATEX_PATH) are resolved from there.
# WORK_DIR is computed only on the first run, so re-executing this cell is a no-op instead
# of climbing two more directories each time as a bare `%cd ../..` would.
if "WORK_DIR" not in globals():
    WORK_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(WORK_DIR)

# LaTeX project (relative to WORK_DIR) that receives the generated tables and figures.
LATEX_PATH = "../latex/Distillation-MI-ICLR"

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders compared in every table and figure of this notebook.
MODELS_TO_EVAL = [ZINC_MODEL, SMALL_KERNEL, LARGE_KERNEL]

# All metadata datasets except the final three rows; ToxCast is also excluded
# (.remove raises ValueError if it is not present in that slice).
DATASETS = df_metadata.index.tolist()[:-3]
DATASETS.remove("ToxCast")

# Only the count is shown here; the following cells rebuild DATASETS per task type.
len(DATASETS)

# Mean performance: classification tasks


In [None]:
DATASETS = df_metadata.loc[df_metadata.task_type == "cls"].index.tolist()

# One LaTeX table per inner list. A column goes into a table when its first label is in
# that list (assumes `df` has MultiIndex columns whose first level is the dataset group — TODO confirm).
DATASET_GROUP = [
    ["Distribution", "HTS", "Absorption", " "],
    ["Metabolism"],
    ["Tox"],
]

df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS)
df, order = aggregate_results_with_ci(df_base)

for table_idx, group in enumerate(DATASET_GROUP):
    group_cols = [col for col in df.columns if col[0] in group]
    # Rows are passed in reversed `order`; only the LaTeX string is written out.
    style, latex = style_df_ci(df[group_cols], order[::-1])
    # Horizontal rules after line 1 and before the last line, in that order.
    latex = add_hline(add_hline(latex, 1), -1)
    with open(f"{LATEX_PATH}/tables/molecules/kern_cls_{table_idx}.tex", "w") as f:
        f.write(latex)

# Mean performance: regression tasks

In [None]:
DATASETS = df_metadata.loc[df_metadata.task_type == "reg"].index.tolist()

# One LaTeX table per inner list. A column goes into a table when its first label is in
# that list (assumes `df` has MultiIndex columns whose first level is the dataset group — TODO confirm).
DATASET_GROUP = [
    [" ", "Absorption", "Tox"],
    ["Distribution", "Excretion"],
]

df_base = get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS)
df, order = aggregate_results_with_ci(df_base)

for table_idx, group in enumerate(DATASET_GROUP):
    group_cols = [col for col in df.columns if col[0] in group]
    # Rows are passed in reversed `order`; only the LaTeX string is written out.
    style, latex = style_df_ci(df[group_cols], order[::-1])
    # Horizontal rules after line 1 and before the last line, in that order.
    latex = add_hline(add_hline(latex, 1), -1)
    with open(f"{LATEX_PATH}/tables/molecules/kern_reg_{table_idx}.tex", "w") as f:
        f.write(latex)

In [None]:
# Aggregated regression results (output of aggregate_results_with_ci in the regression-table cell).
df

# Per-dataset performance plots

In [None]:
# Unaggregated regression results, as returned by get_all_results in the regression-table cell.
df_base

In [None]:
DATASETS = df_metadata[df_metadata.task_type == "cls"].index.tolist()[:-3]
DATASETS.remove("ToxCast")



df_base= get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS,)

df_base.reset_index(inplace=True, drop=True)
step = df_base.embedder.value_counts().max()
df_base["id"] = df_base.index%step

df_base["short_dataset"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)

n_clusters = {
    "2-layers-kernel": 2,
    "5-layers-kernel": 5,
    "student-250k": 3,
}

df_base["n_cluster"] = df_base.embedder.apply(lambda x: n_clusters[x])

g = sns.catplot(
    data=df_base.dropna(),
    col="short_dataset",
    y="metric_test",
    x="n_cluster",
    hue = "embedder",
    kind="point",
    palette="husl",
    height=1.1,
    aspect=1.5,
    col_wrap=5,
    sharey=False,
    alpha = 0.7,
    legend=False,
    errorbar="ci",
    native_scale=True,
)
g.map(sns.lineplot, "n_cluster", "metric_test", errorbar=None, color="black", alpha = 0.3, linewidth=2.5, )
g.map(sns.pointplot, "n_cluster", "metric_test", "embedder",palette="husl", errorbar=None, alpha = 1, legend=False,
    native_scale=True,)

g.set_titles(col_template="{col_name}", row_template="Test performance")

g.set_ylabels("AUROC")
# Rotate x-ticks
g.tick_params(axis = 'y', labelsize=8)

g.figure.subplots_adjust(wspace=0.8, hspace=0.4)



plt.savefig(f"{LATEX_PATH}/figures/molecules/kernel_point_cls.pdf", bbox_inches="tight")


In [None]:
DATASETS = df_metadata[df_metadata.task_type == "reg"].index.tolist()



df_base= get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS,)

df_base.reset_index(inplace=True, drop=True)
step = df_base.embedder.value_counts().max()
df_base["id"] = df_base.index%step

df_base["short_dataset"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)

n_clusters = {
    "2-layers-kernel": 2,
    "5-layers-kernel": 5,
    "student-250k": 3,
}

df_base["n_cluster"] = df_base.embedder.apply(lambda x: n_clusters[x])

g = sns.catplot(
    data=df_base.dropna(),
    col="short_dataset",
    y="metric_test",
    x="n_cluster",
    hue = "embedder",
    kind="point",
    palette="husl",
    height=1.1,
    aspect=1.5,
    col_wrap=5,
    sharey=False,
    alpha = 0.7,
    legend=False,
    errorbar="ci",
    native_scale=True,
)
g.map(sns.lineplot, "n_cluster", "metric_test", errorbar=None, color="black", alpha = 0.3, linewidth=2.5, )
g.map(sns.pointplot, "n_cluster", "metric_test", "embedder",palette="husl", errorbar=None, alpha = 1, legend=False,
    native_scale=True,)

g.set_titles(col_template="{col_name}", row_template="Test performance")

g.set_ylabels("$R^2$")
# Rotate x-ticks
g.tick_params(axis = 'y', labelsize=8)

g.figure.subplots_adjust(wspace=0.8, hspace=0.4)



plt.savefig(f"{LATEX_PATH}/figures/molecules/kernel_point_reg.pdf", bbox_inches="tight")


In [None]:
DATASETS = df_metadata.index.tolist()
DATASETS.remove("ToxCast")


df_base= get_all_results(MODELS_TO_EVAL, "downstream_results", DATASETS,)

df_base.reset_index(inplace=True, drop=True)
step = df_base.embedder.value_counts().max()
df_base["id"] = df_base.index%step

df_base["short_dataset"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)

n_clusters = {
    "2-layers-kernel": 2,
    "5-layers-kernel": 5,
    "student-250k": 3,
}

df_base["n_cluster"] = df_base.embedder.apply(lambda x: n_clusters[x])

g = sns.catplot(
    data=df_base.dropna(),
    col="short_dataset",
    y="metric_test",
    x="n_cluster",
    hue = "embedder",
    kind="point",
    palette="husl",
    height=1.1,
    aspect=1.5,
    col_wrap=4,
    sharey=False,
    alpha = 0.7,
    legend=False,
    errorbar="ci",
    native_scale=True,
)
g.map(sns.lineplot, "n_cluster", "metric_test", errorbar=None, color="black", alpha = 0.3, linewidth=2.5, )
g.map(sns.pointplot, "n_cluster", "metric_test", "embedder",palette="husl", errorbar=None, alpha = 1, legend=False,
    native_scale=True,)

g.set_titles(col_template="{col_name}", row_template="Test performance")

g.set_ylabels("")
# Rotate x-ticks
g.tick_params(axis = 'y', labelsize=8)

g.figure.subplots_adjust(wspace=0.8, hspace=0.4)
g.figure.supylabel("AUROC/$R^2$")


plt.savefig(f"{LATEX_PATH}/figures/molecules/kernel_point.pdf", bbox_inches="tight")
