In [None]:
# Move the working directory two levels up (presumably to the repository root —
# TODO confirm) so that relative paths such as "molDistill/df_metadata.csv"
# used further down resolve from there.
# NOTE: the change is relative, so re-running this cell without restarting the
# kernel moves up two more levels.
%cd ../..

In [None]:
import os.path

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
from tdc.single_pred import Tox, ADME, HTS, QM
from tdc.utils import retrieve_label_name_list
from tqdm import tqdm
from IPython.display import clear_output

from molDistill.baselines.utils.tdc_dataset import correspondancy_dict, get_dataset

# Per-dataset display metadata, keyed by dataset name.
# Each value is a 3-tuple: (short_name, category, task_type), where task_type
# is "reg" or "cls". Insertion order is the order in which the datasets are
# processed by the summary loop below, so it is kept stable.
DATASET_metadata = {
    # --- task_type == "reg" ---
    "LD50_Zhu": ("LD50", "Tox", "reg"),
    "Caco2_Wang": ("Caco2", "Absorption", "reg"),
    "Lipophilicity_AstraZeneca": ("Lipophilicity", "Absorption", "reg"),
    "Solubility_AqSolDB": ("Solubility", "Absorption", "reg"),
    "HydrationFreeEnergy_FreeSolv": ("FreeSolv", "Absorption", "reg"),
    "PPBR_AZ": ("PPBR", "Distribution", "reg"),
    "VDss_Lombardo": ("VDss", "Distribution", "reg"),
    "Half_Life_Obach": ("Half Life", "Excretion", "reg"),
    "Clearance_Hepatocyte_AZ": ("Clearance (H)", "Excretion", "reg"),
    "Clearance_Microsome_AZ": ("Clearance (M)", "Excretion", "reg"),

    # --- task_type == "cls" ---
    "hERG": ("hERG", "Tox", "cls"),
    "hERG_Karim": ("hERG (k)", "Tox", "cls"),
    "AMES": ("AMES", "Tox", "cls"),
    "DILI": ("DILI", "Tox", "cls"),
    "Carcinogens_Lagunin": ("Carcinogens", "Tox", "cls"),
    "Skin__Reaction": ("Skin R", "Tox", "cls"),
    "Tox21": ("Tox21", "Tox", "cls"),
    "ClinTox": ("ClinTox", "Tox", "cls"),
    "ToxCast": ("ToxCast", "Tox", "cls"),
    "PAMPA_NCATS": ("PAMPA", "Absorption", "cls"),
    "HIA_Hou": ("HIA", "Absorption", "cls"),
    "Pgp_Broccatelli": ("Pgp", "Absorption", "cls"),
    "Bioavailability_Ma": ("Bioavailability", "Absorption", "cls"),
    "BBB_Martins": ("BBB", "Distribution", "cls"),
    "CYP2C19_Veith": ("CYP2C19", "Metabolism", "cls"),
    "CYP2D6_Veith": ("CYP2D6", "Metabolism", "cls"),
    "CYP3A4_Veith": ("CYP3A4", "Metabolism", "cls"),
    "CYP1A2_Veith": ("CYP1A2", "Metabolism", "cls"),
    "CYP2C9_Veith": ("CYP2C9", "Metabolism", "cls"),
    "CYP2C9_Substrate_CarbonMangels": ("CYP2C9 (s)", "Metabolism", "cls"),
    "CYP2D6_Substrate_CarbonMangels": ("CYP2D6 (s)", "Metabolism", "cls"),
    "CYP3A4_Substrate_CarbonMangels": ("CYP3A4 (s)", "Metabolism", "cls"),
    "HIV": ("HIV", "HTS", "cls"),
}

In [None]:
# All dataset names known to `correspondancy_dict`. Not used further in this
# cell: the loop below iterates over DATASET_metadata instead.
datasets = correspondancy_dict.keys()

# Column order of the summary table (and of the CSV written from it).
metadata_columns = [
    "dataset", "task_type", "category", "n_samples", "balanced", "short_name", "n_tasks"
]

# One record per dataset. The DataFrame is built once after the loop instead of
# being grown row by row, which also lets pandas infer numeric dtypes for
# n_samples / balanced / n_tasks. If the loop is interrupted, the rows gathered
# so far remain available in `metadata_rows`.
metadata_rows = []

for d in DATASET_metadata.keys():
    print(d)
    loader = correspondancy_dict[d]

    # Only datasets served by the Tox, ADME or HTS loaders are summarised;
    # anything else is skipped.
    if loader not in [Tox, ADME, HTS]:
        continue

    short_name, category, task_type = DATASET_metadata[d]

    # Best effort: if the label list cannot be retrieved, fall back to a single
    # pass with label_name=None (presumably the single-label datasets — TODO
    # confirm which exception TDC raises so this can be narrowed).
    try:
        labels = retrieve_label_name_list(d)
    except Exception:
        labels = [None]

    total_samples = 0
    balance_scores = []
    for label in tqdm(labels):
        df_task = loader(name=d, label_name=label).get_data()
        total_samples += df_task.shape[0]
        if task_type == "cls":
            # Distance of the mean label from 0.5 (0 means perfectly balanced,
            # assuming Y holds 0/1 labels — TODO confirm).
            balance_scores.append(abs(0.5 - df_task.Y.mean()))
        else:
            # For regression tasks the "balanced" column holds the spread of
            # the target instead.
            balance_scores.append(df_task.Y.std())

    # Both n_samples and balanced are averaged over the dataset's tasks.
    n_tasks = len(labels)
    metadata_rows.append({
        "dataset": d,
        "task_type": task_type,
        "category": category,
        "n_samples": total_samples / n_tasks,
        "balanced": sum(balance_scores) / n_tasks,
        "short_name": short_name,
        "n_tasks": n_tasks,
    })
    # Drop the per-dataset progress output so the cell output stays short.
    clear_output()

# `columns=` fixes the column order and keeps the schema even with zero rows.
df_metadata = pd.DataFrame(metadata_rows, columns=metadata_columns)

In [None]:
# Persist the summary table as CSV without the index column. The path is
# relative to the current working directory (changed by the `%cd ../..` cell
# at the top of the notebook).
df_metadata.to_csv(
    path_or_buf="molDistill/df_metadata.csv",
    index=False,
)

In [None]:
# Display the summary table built above (one row per processed dataset).
df_metadata