In [42]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

# Models whose downstream results are loaded and compared.
# Each name is joined onto the results folder to find that model's CSV files.
# The last entry is relabelled "student" when its results are loaded, and it
# receives the fully saturated colour in the rank bar plot (which highlights
# the last element of this list).
MODELS_TO_EVAL = [
    "ChemBertMLM-10M",
    "ChemBertMTR-77M",
    "ChemGPT-1.2B",
    "GraphMVP",
    "GROVER",
    "GraphLog",
    "GraphCL",
    "InfoGraph",
    "FRAD_QM9",
    "MolR_gat",
    "ThreeDInfomax",
    "custom:MOSES_512_10_lr1e-4_gine"
]

# Datasets kept when loading results, keyed by the result file name with the
# "results_" prefix and ".csv" suffix removed.
# Each value is (display label, group): the display label is the column name
# shown in tables and plots, the group becomes the top-level column header and
# is used to sort the table columns.
DATASET_metadata = {
    "hERG": ("hERG", "Tox"),
    "hERG_Karim": ("hERG (k)", "Tox"),
    "AMES": ("AMES", "Tox"),
    "DILI": ("DILI", "Tox"),
    "Carcinogens_Lagunin": ("Carcinogens", "Tox"),
    "Skin__Reaction": ("Skin R", "Tox"),
    "Tox21": ("Tox21", "Tox"),
    "ClinTox": ("ClinTox", "Tox"),
    "PAMPA_NCATS": ("PAMPA", "Absorption"),
    "HIA_Hou": ("HIA", "Absorption"),
    "Pgp_Broccatelli": ("Pgp", "Absorption"),
    "Bioavailability_Ma": ("Bioavailability", "Absorption"),
    "BBB_Martins": ("BBB", "Distribution"),
    "CYP2C19_Veith": ("CYP2C19", "Metabolism"),
    "CYP2D6_Veith": ("CYP2D6", "Metabolism"),
    "CYP3A4_Veith": ("CYP3A4", "Metabolism"),
    "CYP1A2_Veith": ("CYP1A2", "Metabolism"),
    "CYP2C9_Veith": ("CYP2C9", "Metabolism"),
    "CYP2C9_Substrate_CarbonMangels" : ("CYP2C9 (s)", "Metabolism"),
    "CYP2D6_Substrate_CarbonMangels" : ("CYP2D6 (s)", "Metabolism"),
    "CYP3A4_Substrate_CarbonMangels" : ("CYP3A4 (s)", "Metabolism"),
}


In [44]:
# `df_metadata` was displayed without being defined anywhere in the notebook,
# so this cell failed on a fresh kernel. Build it from DATASET_metadata: one
# row per dataset key, column 0 = display label, column 1 = group.
df_metadata = pd.DataFrame.from_dict(DATASET_metadata, orient="index")
df_metadata

Unnamed: 0,0,1
hERG,hERG,Tox
hERG_Karim,hERG (k),Tox
AMES,AMES,Tox
DILI,DILI,Tox
Carcinogens_Lagunin,Carcinogens,Tox
Skin__Reaction,Skin R,Tox
Tox21,Tox21,Tox
ClinTox,ClinTox,Tox
PAMPA_NCATS,PAMPA,Absorption
HIA_Hou,HIA,Absorption


# Mean Performances

In [None]:
def load_downstream_results(results_root, models, datasets):
    """Read the per-dataset result CSVs of every model into one DataFrame.

    For each model, ``<results_root>/<model>`` and its immediate
    sub-directories are scanned for ``*.csv`` files. The dataset key of a file
    is its name with ``.csv`` and ``results_`` removed; files whose key is not
    in ``datasets`` are ignored. Entries that are neither a CSV file nor a
    directory are skipped.

    Parameters
    ----------
    results_root : str
        Folder containing one sub-folder per model.
    models : iterable of str
        Model folder names to load.
    datasets : container of str
        Dataset keys to keep (only membership is tested).

    Returns
    -------
    pandas.DataFrame
        All kept CSV rows concatenated, with two added columns: ``embedder``
        (the model name, with the custom student model renamed to "student")
        and ``dataset`` (the dataset key).

    Raises
    ------
    FileNotFoundError
        If a model folder does not exist.
    ValueError
        If no CSV matched (``pd.concat`` of an empty list).
    """
    frames = []
    for model in models:
        # Keep the model folder in its own variable: the previous version
        # re-assigned it when descending into a sub-folder, which corrupted
        # the path of every entry processed afterwards.
        model_dir = os.path.join(results_root, model)
        embedder = model.replace("custom:MOSES_512_10_lr1e-4_gine", "student")

        csv_paths = []
        # Sorted listings make the row order reproducible across machines.
        for entry in sorted(os.listdir(model_dir)):
            entry_path = os.path.join(model_dir, entry)
            if entry.endswith(".csv"):
                csv_paths.append(entry_path)
            elif os.path.isdir(entry_path):
                csv_paths.extend(
                    os.path.join(entry_path, nested)
                    for nested in sorted(os.listdir(entry_path))
                    if nested.endswith(".csv")
                )

        for csv_path in csv_paths:
            dataset = os.path.basename(csv_path).replace(".csv", "").replace("results_", "")
            if dataset not in datasets:
                continue
            frame = pd.read_csv(csv_path, index_col=0)
            frame["embedder"] = embedder
            frame["dataset"] = dataset
            frames.append(frame)

    return pd.concat(frames)


path = "../downstream_results"
df_base = load_downstream_results(path, MODELS_TO_EVAL, DATASET_metadata)

# Mean and standard deviation of the test metric for every (dataset, embedder)
# pair. The dataset name gets a " mean" / " std" suffix so both statistics can
# live side by side as columns of one wide table.
metric_by_pair = df_base.groupby(["dataset", "embedder"]).metric_test
df_m = metric_by_pair.mean().reset_index()
df_m["dataset"] = df_m["dataset"] + " mean"
df_v = metric_by_pair.std().reset_index()
df_v["dataset"] = df_v["dataset"] + " std"

# Wide tables: one row per embedder, one column per dataset statistic.
mean_table = df_m.pivot_table(index="embedder", columns="dataset", values="metric_test")
std_table = df_v.pivot_table(index="embedder", columns="dataset", values="metric_test")

# Keep only the columns available for every embedder.
df = mean_table.join(std_table).dropna(axis=1)
df.index.name = None

# Rank embedders by their average *mean* score (best first). The " std"
# columns are deliberately left out: averaging them in would reward noisy
# models instead of accurate ones.
mean_columns = [c for c in df.columns if c.endswith(" mean")]
order = df[mean_columns].mean(axis=1).sort_values(ascending=False).index

# Two-level header: (group, display label), with " std" appended to the label
# of the standard-deviation columns.
header_pairs = []
for column in df.columns:
    dataset_key, statistic = column.rsplit(" ", 1)
    display_label, group = DATASET_metadata[dataset_key]
    suffix = statistic.replace("mean", "").replace("std", " std")
    header_pairs.append((group, display_label + suffix))
df.columns = pd.MultiIndex.from_tuples(header_pairs)

# Overall columns: average of the per-dataset means and their spread across
# datasets (computed over every dataset an embedder has results for).
df[(" ", "avg")] = mean_table.mean(axis=1)
df[(" ", "avg std")] = mean_table.std(axis=1)

# Rows from worst to best, with readable embedder names.
df = df.loc[order[::-1], :]
df.index = df.index.str.replace("_", " ")

# Group the columns by their top-level header (stable sort keeps the
# within-group order).
df = df[sorted(df.columns, key=lambda x: x[0])]
# Row label of the best (maximum) value in every column.
maxs = df.idxmax(axis=0)

In [None]:
# Turn the numeric table into LaTeX cells of the form "$mean \pm std$".
style = df.copy()
mean_columns = [col for col in style.columns if not col[1].endswith("std")]
for col in mean_columns:
    std_col = (col[0], col[1] + " std")
    # "\\pm" (escaped backslash) yields the same text as before without the
    # invalid "\p" escape sequence that Python warns about.
    style[col] = (
        "$"
        + style[col].apply(lambda x: f"{x:.3f}")
        + " \\pm "
        + style[std_col].apply(lambda x: f"{x:.3f}")
        + "$"
    )

# The std values are now embedded in the mean cells; drop their columns.
style = style.drop(columns=[(col[0], col[1] + " std") for col in mean_columns])

# Wrap the best entry of every column (row label taken from `maxs`).
for col in style.columns:
    best = maxs[col]
    style.loc[best, col] = "\\boldsymbol{" + style.loc[best, col] + "}"

# NOTE(review): the cells are strings at this point, so highlight_max compares
# them as text rather than as numbers - confirm this selects the intended cell.
style = style.style.highlight_max(axis=0, props='bfseries:')

# Display the table (each cell shows mean ± standard deviation).
style

In [None]:
# LaTeX column spec: a right-aligned index column followed by one centred
# column per table column, with a vertical rule inserted every time the
# top-level header changes and one closing rule at the end.
previous_group = 'This is not a column name that will be used'  # sentinel, never a real header
column_specs = []
for group, _label in style.columns:
    column_specs.append("c" if group == previous_group else "|c")
    previous_group = group
col_format = "r|" + "".join(column_specs) + "|"

latex = style.to_latex(column_format=col_format, multicol_align="|c|", siunitx=True)

# NOTE(review): absolute, machine-specific output location.
table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_cls.tex"
with open(table_path, "w") as handle:
    handle.write(latex)

In [None]:
# Show the styled results table again (the Styler that was exported to LaTeX).
style

In [None]:
# Inspect the concatenated results frame: the rows of every loaded CSV plus the
# added `embedder` and `dataset` columns.
df_base

# Rankings

In [None]:
# Reload every result CSV (model folder and its immediate sub-folders) so the
# ranking section starts from the raw rows.
path = "../downstream_results"
dfs = []
for model in MODELS_TO_EVAL:
    # The model folder is never re-assigned inside the loop: the previous
    # version overwrote it when entering a sub-folder, which corrupted the
    # path of every entry processed afterwards.
    model_path = os.path.join(path, model)
    embedder = model.replace("custom:MOSES_512_10_lr1e-4_gine", "student")

    csv_paths = []
    # Sorted listings make the row order reproducible across machines.
    for entry in sorted(os.listdir(model_path)):
        entry_path = os.path.join(model_path, entry)
        if entry.endswith(".csv"):
            csv_paths.append(entry_path)
        elif os.path.isdir(entry_path):
            # Only directories are descended into; other files are skipped.
            csv_paths.extend(
                os.path.join(entry_path, nested)
                for nested in sorted(os.listdir(entry_path))
                if nested.endswith(".csv")
            )

    for csv_path in csv_paths:
        dataset = os.path.basename(csv_path).replace(".csv", "").replace("results_", "")
        if dataset not in DATASET_metadata:
            continue
        # Named `results`, not `df`, so the summary table `df` built in the
        # previous section is not overwritten by the last CSV read here.
        results = pd.read_csv(csv_path, index_col=0)
        results["embedder"] = embedder
        results["dataset"] = dataset
        dfs.append(results)

# Fresh 0..n-1 index over all rows.
df_base = pd.concat(dfs, ignore_index=True)

# Repetition id used to align rows across embedders when ranking.
# NOTE(review): assumes every CSV contributes exactly 5 consecutive rows -
# confirm against the result files.
df_base["id"] = df_base.index % 5

df_base

In [None]:
# Compute meanrank
import autorank

import seaborn as sns
import matplotlib.pyplot as plt

def get_ranked_df(df_base):
    """Collect the mean rank of every embedder on every dataset in one table.

    For each distinct value of ``df_base.dataset`` the matching rows are
    pivoted into an ``id`` x ``embedder`` table of ``metric_test`` and handed
    to ``autorank.autorank`` (alpha 0.05, forced non-parametric mode, verbose
    output). The ``meanrank`` column of the returned ``rankdf`` is stored
    under the dataset's name.

    The initial ``embedder`` column comes from the notebook-level ``order``
    index (reversed), so the summary-table cell must have been run first.

    Returns a DataFrame with an ``embedder`` column followed by one mean-rank
    column per dataset.
    """
    ranked_df = pd.DataFrame({"embedder": order[::-1]})
    for dataset in df_base.dataset.unique():
        in_dataset = df_base.dataset == dataset
        scores = df_base[in_dataset].pivot_table(
            index="id", columns="embedder", values="metric_test"
        )
        ranking = autorank.autorank(
            scores, alpha=0.05, verbose=True, force_mode="nonparametric"
        )
        per_embedder = ranking.rankdf.reset_index()
        per_embedder[dataset] = per_embedder["meanrank"]
        ranked_df = ranked_df.merge(
            per_embedder[["embedder", dataset]], on="embedder", how="outer"
        )
    return ranked_df

In [None]:

# Mean rank per embedder and dataset, with dataset keys replaced by their
# display labels.
ranked_df = get_ranked_df(df_base)
display_labels = [DATASET_metadata[key][0] for key in ranked_df.columns[1:]]
ranked_df.columns = ["embedder", *display_labels]

# Long format for seaborn: one row per (embedder, dataset) pair.
melted_ranked = ranked_df.melt(id_vars="embedder", var_name="dataset", value_name="meanrank")
# Datasets ordered by the student's mean rank, ascending.
order_datas = ranked_df.set_index("embedder").loc["student"].sort_values().index

# Desaturated palette for every model; the last model (the student) keeps the
# fully saturated colour so it stands out.
n_models = len(MODELS_TO_EVAL)
colors = sns.color_palette("husl", n_models, desat=0.4)
colors[-1] = sns.color_palette("husl", n_models)[-1]

embedder_names = (
    model.replace("custom:MOSES_512_10_lr1e-4_gine", "student")
    for model in MODELS_TO_EVAL
)
cmapping = dict(zip(embedder_names, colors))

fig, ax = plt.subplots(figsize=(20, 3))

sns.barplot(
    data=melted_ranked,
    x="dataset",
    y="meanrank",
    hue="embedder",
    order=order_datas,
    palette=cmapping,
    ax=ax,
)
sns.despine()
# Tilt the dataset labels so they do not overlap.
plt.xticks(rotation=45)
# Place the legend outside the axes, to the right.
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1)

In [None]:
# Rank table for the paper: embedders as rows, datasets grouped under their
# top-level header as columns.
rank_table = get_ranked_df(df_base)
rank_table["embedder"] = rank_table["embedder"].str.replace("_", " ")
rank_table = rank_table.set_index("embedder")

# Two-level header: (group, display label) for every dataset key.
header_pairs = []
for dataset_key in rank_table.columns:
    display_label, group = DATASET_metadata[dataset_key]
    header_pairs.append((group, display_label))
rank_table.columns = pd.MultiIndex.from_tuples(header_pairs)

# Stable sort on the group keeps the within-group column order.
rank_table = rank_table[sorted(rank_table.columns, key=lambda pair: pair[0])]

# Mark the smallest mean rank of every column.
style = rank_table.style.highlight_min(axis=0, props='bfseries:')

style


In [None]:
# LaTeX column spec: right-aligned index column, then one centred column per
# table column; a vertical rule is added whenever the top-level header differs
# from the previous column's, plus a closing rule at the end.
top_level_headers = [group for group, _label in style.columns]
col_format = "r|"
for position, group in enumerate(top_level_headers):
    previous_group = None if position == 0 else top_level_headers[position - 1]
    if group != previous_group:
        col_format += "|"
    col_format += "c"
col_format += "|"

latex = style.to_latex(column_format=col_format, multicol_align="|c|", siunitx=True)

# NOTE(review): absolute, machine-specific output location.
table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_cls_rank.tex"
with open(table_path, "w") as handle:
    handle.write(latex)