In [None]:
# Move the working directory two levels up before anything else runs
# (presumably to the repository root so that `molDistill` and relative paths resolve — confirm).
# NOTE: the path is relative, so re-running this cell climbs two more levels each time.
%cd ../..

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from molDistill.utils.notebooks import *

# Embedders whose downstream results are compared throughout the notebook.
MODELS_TO_EVAL = [STUDENT_MODEL, SMALL_KERNEL, LARGE_KERNEL]

# Every metadata entry except the last three, with "ToxCast" taken out as well.
DATASETS = df_metadata.index[:-3].tolist()
DATASETS.remove("ToxCast")

# Number of datasets kept (shown as the cell output).
len(DATASETS)

# Mean performance — classification


In [None]:
# Classification datasets, selected through the `task_type` column of the metadata table.
DATASETS = df_metadata.loc[df_metadata["task_type"] == "cls"].index.tolist()

# Each inner list holds the first-level column labels gathered into one output table.
DATASET_GROUP = [
    ["Distribution", "HTS", "Absorption", " "],
    ["Metabolism"],
    ["Tox"],
]

df_base = get_all_results(
    MODELS_TO_EVAL,
    "downstream_results",
    DATASETS,
    renames=[
        (STUDENT_MODEL, "student-large"),
        (SMALL_KERNEL, "2-layers-kernel"),
        (LARGE_KERNEL, "5-layers-kernel"),
    ],
)

df, order = aggregate_results_with_ci(df_base)

# Write one LaTeX table per group of column labels.
for i, datasets in enumerate(DATASET_GROUP):
    group_columns = [col for col in df.columns if col[0] in datasets]
    df_group = df[group_columns]
    style, latex = style_df_ci(df_group, order[::-1])
    # `add_hline` is applied at position 1 first, then at position -1.
    for position in (1, -1):
        latex = add_hline(latex, position)
    table_path = f"/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/kern_cls_{i}.tex"
    with open(table_path, "w") as f:
        f.write(latex)

In [None]:
# Display the `df_base` frame as the cell output
# (the classification results when the notebook is run top to bottom).
df_base

# Mean performance — regression

In [None]:
# Regression datasets, selected through the `task_type` column of the metadata table.
DATASETS = df_metadata.loc[df_metadata["task_type"] == "reg"].index.tolist()

# Each inner list holds the first-level column labels gathered into one output table.
DATASET_GROUP = [
    [" ", "Absorption", "Tox"],
    ["Distribution", "Excretion"],
]

df_base = get_all_results(
    MODELS_TO_EVAL,
    "downstream_results",
    DATASETS,
    renames=[
        (STUDENT_MODEL, "student-large"),
        (SMALL_KERNEL, "2-layers-kernel"),
        (LARGE_KERNEL, "5-layers-kernel"),
    ],
)

df, order = aggregate_results_with_ci(df_base)

# Write one LaTeX table per group of column labels.
for i, datasets in enumerate(DATASET_GROUP):
    group_columns = [col for col in df.columns if col[0] in datasets]
    df_group = df[group_columns]
    style, latex = style_df_ci(df_group, order[::-1])
    # `add_hline` is applied at position 1 first, then at position -1.
    for position in (1, -1):
        latex = add_hline(latex, position)
    table_path = f"/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/kern_reg_{i}.tex"
    with open(table_path, "w") as f:
        f.write(latex)

# Per-dataset plots

In [None]:
# Restrict the comparison to the classification datasets listed in the metadata table.
DATASETS = df_metadata.loc[df_metadata["task_type"] == "cls"].index.tolist()

df_base = get_all_results(
    MODELS_TO_EVAL,
    "downstream_results",
    DATASETS,
    renames=[
        (STUDENT_MODEL, "student-large"),
        (SMALL_KERNEL, "2-layers-kernel"),
        (LARGE_KERNEL, "5-layers-kernel"),
    ],
)

# Renumber the rows 0..n-1, then derive `id` as the row position modulo the
# largest per-embedder row count.
# NOTE(review): matching rows across embedders through `id` presumably relies on
# every embedder contributing the same number of rows in the same order — confirm.
df_base.reset_index(drop=True, inplace=True)
step = df_base["embedder"].value_counts().max()
df_base["id"] = df_base.index % step

# Short display name of each dataset, looked up in the metadata table.
df_base["short_dataset"] = df_base["dataset"].apply(
    lambda name: df_metadata.loc[name].short_name
)

def get_diff(row, df_base):
    """Return ``row.metric_test`` minus the matching "student-large" score.

    The reference is the first row of ``df_base`` whose ``embedder`` equals
    ``"student-large"`` and whose ``id`` equals ``row.id``.

    Parameters
    ----------
    row
        Object exposing ``id`` and ``metric_test`` attributes (for instance a
        row Series handed over by ``DataFrame.apply(..., axis=1)``).
    df_base
        DataFrame with ``embedder``, ``id`` and ``metric_test`` columns.

    Returns
    -------
    The row's score minus the reference score.

    Raises
    ------
    IndexError
        If no "student-large" row shares ``row.id``.
    """
    is_reference = (df_base.embedder == "student-large") & (df_base.id == row.id)
    reference_values = df_base.loc[is_reference, "metric_test"].values
    if len(reference_values) == 0:
        # Same exception type as indexing an empty array, but with a usable message.
        raise IndexError(
            f"no 'student-large' row with id={row.id!r} to compare against"
        )
    return row.metric_test - reference_values[0]

# Row-wise gap to the student-large reference; `df_base` itself is forwarded to
# `get_diff` as its second positional argument.
df_base["metric_test_diff"] = df_base.apply(get_diff, axis=1, args=(df_base,))

def get_diff_perc(row, df_base):
    """Return the gap to the matching "student-large" score, in percent of it.

    The reference is the first row of ``df_base`` whose ``embedder`` equals
    ``"student-large"`` and whose ``id`` equals ``row.id``. The result is
    ``(row.metric_test - reference) / reference * 100``; a reference of zero
    is not guarded against.

    Parameters
    ----------
    row
        Object exposing ``id`` and ``metric_test`` attributes (for instance a
        row Series handed over by ``DataFrame.apply(..., axis=1)``).
    df_base
        DataFrame with ``embedder``, ``id`` and ``metric_test`` columns.

    Returns
    -------
    The relative difference to the reference score, as a percentage.

    Raises
    ------
    IndexError
        If no "student-large" row shares ``row.id``.
    """
    is_reference = (df_base.embedder == "student-large") & (df_base.id == row.id)
    reference_values = df_base.loc[is_reference, "metric_test"].values
    if len(reference_values) == 0:
        # Same exception type as indexing an empty array, but with a usable message.
        raise IndexError(
            f"no 'student-large' row with id={row.id!r} to compare against"
        )
    reference_value = reference_values[0]
    return (row.metric_test - reference_value) / reference_value * 100

# Row-wise relative gap (in percent) to the student-large reference; `df_base`
# itself is forwarded to `get_diff_perc` as its second positional argument.
df_base["metric_test_diff_perc"] = df_base.apply(get_diff_perc, axis=1, args=(df_base,))


In [None]:
# One box-plot panel per dataset for every embedder except "student-large";
# the student-large mean is drawn in each panel as a dashed red line instead.
g = sns.catplot(
    data=df_base[df_base.embedder != "student-large"],
    col="short_dataset",
    y="metric_test",
    hue="embedder",
    kind="box",
    palette="husl",
    height=1.2,
    aspect=1.2,
    col_wrap=8,
    sharey=False,
)

# No `row` facet is used, so only the column template contributes to the titles.
g.set_titles(col_template="{col_name}")
g.set_axis_labels("", "")

# Mean student-large score per panel, computed once and keyed by the same
# `short_dataset` label that names each facet. Looking panels up by label
# (rather than pairing them positionally with `df_base.dataset.unique()`)
# keeps each line on its own dataset even if the two orderings differ.
student_mean_by_panel = (
    df_base[df_base.embedder == "student-large"]
    .groupby("short_dataset")["metric_test"]
    .mean()
)
for panel_name, ax in g.axes_dict.items():
    ax.axhline(student_mean_by_panel.loc[panel_name], color="red", linestyle="--")

# Move the legend into the empty area of the grid.
# NOTE: `_legend` is a private seaborn attribute.
g._legend.set_bbox_to_anchor([0.84, 0.18])

# NOTE(review): absolute, machine-specific output path.
plt.savefig("/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules/kern_cls_diff.pdf", bbox_inches="tight")

In [None]:
# Box plots of the test metric for every embedder in `df_base`, one panel per
# `short_dataset` value, each panel with its own y-axis.
g = sns.catplot(
    data=df_base,
    kind="box",
    y="metric_test",
    hue="embedder",
    col="short_dataset",
    col_wrap=13,
    sharey=False,
    height=1.2,
    aspect=1.2,
    palette="husl",
)

# Title each panel with its column value and blank out both axis labels.
g.set_titles(col_template="{col_name}", row_template="AUROC").set_axis_labels("", "")

# Move the legend into the empty area of the grid.
g._legend.set_bbox_to_anchor([0.78, 0.12])

In [None]:
# Restrict the comparison to the regression datasets listed in the metadata table.
DATASETS = df_metadata.loc[df_metadata["task_type"] == "reg"].index.tolist()

df_base = get_all_results(
    MODELS_TO_EVAL,
    "downstream_results",
    DATASETS,
    renames=[
        (STUDENT_MODEL, "student-large"),
        (SMALL_KERNEL, "2-layers-kernel"),
        (LARGE_KERNEL, "5-layers-kernel"),
    ],
)

# Renumber the rows 0..n-1, then derive `id` as the row position modulo the
# largest per-embedder row count.
# NOTE(review): matching rows across embedders through `id` presumably relies on
# every embedder contributing the same number of rows in the same order — confirm.
df_base.reset_index(drop=True, inplace=True)
step = df_base["embedder"].value_counts().max()
df_base["id"] = df_base.index % step

# Short display name of each dataset, looked up in the metadata table.
df_base["short_dataset"] = df_base["dataset"].apply(
    lambda name: df_metadata.loc[name].short_name
)

def get_diff(row, df_base):
    """Return ``row.metric_test`` minus the matching "student-large" score.

    The reference is the first row of ``df_base`` whose ``embedder`` equals
    ``"student-large"`` and whose ``id`` equals ``row.id``. This has the same
    contract as the ``get_diff`` defined for the classification results
    earlier in the notebook.

    Parameters
    ----------
    row
        Object exposing ``id`` and ``metric_test`` attributes (for instance a
        row Series handed over by ``DataFrame.apply(..., axis=1)``).
    df_base
        DataFrame with ``embedder``, ``id`` and ``metric_test`` columns.

    Returns
    -------
    The row's score minus the reference score.

    Raises
    ------
    IndexError
        If no "student-large" row shares ``row.id``.
    """
    is_reference = (df_base.embedder == "student-large") & (df_base.id == row.id)
    reference_values = df_base.loc[is_reference, "metric_test"].values
    if len(reference_values) == 0:
        # Same exception type as indexing an empty array, but with a usable message.
        raise IndexError(
            f"no 'student-large' row with id={row.id!r} to compare against"
        )
    return row.metric_test - reference_values[0]

# Row-wise gap to the student-large reference; `df_base` itself is forwarded to
# `get_diff` as its second positional argument.
df_base["metric_test_diff"] = df_base.apply(get_diff, axis=1, args=(df_base,))

def get_diff_perc(row, df_base):
    """Return the gap to the matching "student-large" score, in percent of it.

    The reference is the first row of ``df_base`` whose ``embedder`` equals
    ``"student-large"`` and whose ``id`` equals ``row.id``. The result is
    ``(row.metric_test - reference) / reference * 100``; a reference of zero
    is not guarded against. This has the same contract as the
    ``get_diff_perc`` defined for the classification results earlier in the
    notebook.

    Parameters
    ----------
    row
        Object exposing ``id`` and ``metric_test`` attributes (for instance a
        row Series handed over by ``DataFrame.apply(..., axis=1)``).
    df_base
        DataFrame with ``embedder``, ``id`` and ``metric_test`` columns.

    Returns
    -------
    The relative difference to the reference score, as a percentage.

    Raises
    ------
    IndexError
        If no "student-large" row shares ``row.id``.
    """
    is_reference = (df_base.embedder == "student-large") & (df_base.id == row.id)
    reference_values = df_base.loc[is_reference, "metric_test"].values
    if len(reference_values) == 0:
        # Same exception type as indexing an empty array, but with a usable message.
        raise IndexError(
            f"no 'student-large' row with id={row.id!r} to compare against"
        )
    reference_value = reference_values[0]
    return (row.metric_test - reference_value) / reference_value * 100

# Row-wise relative gap (in percent) to the student-large reference; `df_base`
# itself is forwarded to `get_diff_perc` as its second positional argument.
df_base["metric_test_diff_perc"] = df_base.apply(get_diff_perc, axis=1, args=(df_base,))


In [None]:
# One box-plot panel per dataset for every embedder except "student-large";
# the student-large mean is drawn in each panel as a dashed red line instead.
g = sns.catplot(
    data=df_base[df_base.embedder != "student-large"],
    col="short_dataset",
    y="metric_test",
    hue="embedder",
    kind="box",
    palette="husl",
    height=1.2,
    aspect=1.2,
    col_wrap=3,
    sharey=False,
)

# No `row` facet is used, so only the column template contributes to the titles.
g.set_titles(col_template="{col_name}")
g.set_axis_labels("", "")

# Mean student-large score per panel, computed once and keyed by the same
# `short_dataset` label that names each facet. Looking panels up by label
# (rather than pairing them positionally with `df_base.dataset.unique()`)
# keeps each line on its own dataset even if the two orderings differ.
student_mean_by_panel = (
    df_base[df_base.embedder == "student-large"]
    .groupby("short_dataset")["metric_test"]
    .mean()
)
for panel_name, ax in g.axes_dict.items():
    ax.axhline(student_mean_by_panel.loc[panel_name], color="red", linestyle="--")

# Move the legend into the empty area of the grid.
# NOTE: `_legend` is a private seaborn attribute.
g._legend.set_bbox_to_anchor([0.6, 0.12])


# NOTE(review): absolute, machine-specific output path.
plt.savefig("/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules/kern_reg_diff.pdf", bbox_inches="tight")