In [None]:
import os.path
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

# Embedders whose downstream results are aggregated. Each entry is used as a
# sub-directory name under the results root; the "custom:..." entry is
# relabelled "student" when its results are loaded.
MODELS_TO_EVAL = [
    "ChemBertMLM-10M", "ChemBertMTR-77M", "ChemGPT-1.2B",
    "GraphMVP", "GROVER", "GraphLog", "GraphCL", "InfoGraph",
    "FRAD_QM9", "MolR_gat", "ThreeDInfomax",
    "custom:MOSES_512_10_lr1e-4_gine",
]

# Dataset key (as it appears in the results file names) -> (display label, category).
# Only datasets listed here are loaded; the category becomes the top level of
# the table header and the display label the second level.
DATASET_metadata = {
    # Tox
    "LD50_Zhu": ("LD50", "Tox"),
    # Absorption
    "Caco2_Wang": ("Caco2", "Absorption"),
    "Lipophilicity_AstraZeneca": ("Lipophilicity", "Absorption"),
    "Solubility_AqSolDB": ("Solubility", "Absorption"),
    "HydrationFreeEnergy_FreeSolv": ("FreeSolv", "Absorption"),
    # Distribution
    "PPBR_AZ": ("PPBR", "Distribution"),
    "VDss_Lombardo": ("VDss", "Distribution"),
    # Excretion
    "Half_Life_Obach": ("Half Life", "Excretion"),
    "Clearance_Hepatocyte_AZ": ("Clearance (H)", "Excretion"),
    "Clearance_Microsome_AZ": ("Clearance (M)", "Excretion"),
}

In [None]:
path = "../downstream_results"


def _read_result_csvs(directory, embedder):
    """Load the result CSVs for known datasets found directly inside ``directory``.

    A file is loaded when its name ends in ".csv" and, after stripping ".csv"
    and "results_", names a key of ``DATASET_metadata``. Each loaded frame is
    tagged with an "embedder" and a "dataset" column.

    Parameters
    ----------
    directory : str
        Directory to scan (not recursive).
    embedder : str
        Label stored in the "embedder" column of every loaded frame.

    Returns
    -------
    list of pandas.DataFrame
        One frame per matching CSV, in sorted file-name order.
    """
    frames = []
    for fname in sorted(os.listdir(directory)):
        if not fname.endswith(".csv"):
            continue
        dataset = fname.replace(".csv", "").replace("results_", "")
        if dataset not in DATASET_metadata:
            continue
        frame = pd.read_csv(os.path.join(directory, fname), index_col=0)
        frame["embedder"] = embedder
        frame["dataset"] = dataset
        frames.append(frame)
    return frames


dfs = []
for model in MODELS_TO_EVAL:
    model_path = os.path.join(path, model)
    embedder = model.replace("custom:MOSES_512_10_lr1e-4_gine", "student")

    # CSVs stored directly in the model directory.
    dfs.extend(_read_result_csvs(model_path, embedder))

    # CSVs stored one level down. `model_path` is deliberately left untouched so
    # that every entry is resolved against the model directory, and regular
    # non-CSV files are skipped instead of being passed to os.listdir.
    for entry in sorted(os.listdir(model_path)):
        sub_path = os.path.join(model_path, entry)
        if os.path.isdir(sub_path):
            dfs.extend(_read_result_csvs(sub_path, embedder))

# Stack every loaded frame, then average `metric_test` over all rows that share
# the same (dataset, embedder) pair.
df_or = (
    pd.concat(dfs)
    .groupby(["dataset", "embedder"])
    .metric_test.mean()
    .reset_index()
)

# One row per embedder, one column per dataset. Datasets with a missing score
# for any embedder are dropped so that every row mean covers the same columns.
df = df_or.pivot_table(
    index="embedder", columns="dataset", values="metric_test"
).dropna(axis=1)
df.index.name = None

# Embedders sorted by mean score (highest first), with "student" pinned to the front.
order = df.mean(axis=1).sort_values(ascending=False).index.tolist()
order.remove("student")
order.insert(0, "student")

# Rows are laid out in the reverse of `order`, which puts "student" on the last row.
df = df.loc[list(reversed(order)), :]
df.index = df.index.str.replace("_", " ")

# Two-level header: (category, display label) for every dataset column.
df.columns = pd.MultiIndex.from_tuples(
    [
        (category, label)
        for label, category in (DATASET_metadata[key] for key in df.columns)
    ]
)
df[("", "mean")] = df.mean(axis=1)
df

In [None]:
def highlight_max(s):
    """Return one CSS string per element of ``s``, bolding the maximum.

    Elements equal to ``s.max()`` get ``'font-weight: bold'``; all others get
    an empty string. Ties are all marked.

    NOTE(review): the visible cells call the built-in ``Styler.highlight_max``
    rather than this helper.
    """
    bold, plain = 'font-weight: bold', ''
    matches_peak = s.eq(s.max())
    return [bold if hit else plain for hit in matches_peak]
# Group the columns by their top-level label. sorted() is stable, so the order
# inside each group is preserved, and the "" group (the mean column) sorts first.
column_order = sorted(df.columns, key=lambda column: column[0])
df = df[column_order]

# Three-decimal display, with each column's maximum tagged with the
# 'bfseries:' property (presumably rendered as \bfseries by Styler.to_latex —
# TODO confirm against the generated table).
style = df.style.format("{:.3f}")
style = style.highlight_max(axis=0, props='bfseries:')
style

In [None]:


# Build the LaTeX column spec: a right-aligned index column, then one centred
# column per data column, with a vertical rule wherever the top-level group
# changes and a closing rule at the end.
spec_parts = ["r|"]
previous_group = None
for group, _ in df.columns:
    if group != previous_group:
        spec_parts.append("|")
        previous_group = group
    spec_parts.append("c")
spec_parts.append("|")
col_format = "".join(spec_parts)

latex = style.to_latex(
    column_format=col_format,
    multicol_align="|c|",
    siunitx=True,
)

# NOTE(review): absolute, machine-specific output path; it must exist for this
# cell to run. Consider making it configurable.
table_path = "/home/philippe/Distill/latex/Distillation-MI-ICLR/tables/molecules/molecule_results_reg.tex"
with open(table_path, "w") as handle:
    handle.write(latex)