### ToDo
- Train Random Forest as a baseline for each assay in the validation and test splits. For each assay, a support set of ten data points (5 active and 5 inactive molecules) will be created, on which the Random Forest will be trained, and a query set will be used to evaluate the predictions.
- Train Neural Network as a Frequent Model: All training assays will be aggregated into a pool. The neural network (Feedforward Neural Network) will be trained based on the aggregated data to achieve the best possible performance on the validation set.
- Compare the performance of the models using AUC. For Random Forest, calculate AUC for every assay and then take the mean. For NN, ???
- Filtering NaN Values: Molecules with missing values (NaN) will be removed from the training pool. The neural network will only be trained with valid data, and only molecules with labels (0 or 1) will be considered in the test set.

### Imports

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

from rdkit import Chem

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [2]:
# Load from pickle file
with open("datasets/datasets.pkl", "rb") as f:
    datasets = pickle.load(f)

# Extract individual DataFrames
train_set_import = datasets["train"]
validation_set_import = datasets["validation"]
test_set_import = datasets["test"]

# Convert SMILES strings back to RDKit molecule objects
def postprocess_from_pickle(df):
    df = df.copy()
    df['molecule'] = df['molecule'].apply(lambda smiles: Chem.MolFromSmiles(smiles) if smiles else None)
    return df

# Apply postprocessing
train_set = postprocess_from_pickle(train_set_import)
validation_set = postprocess_from_pickle(validation_set_import)
test_set = postprocess_from_pickle(test_set_import)

In [3]:
train_set.head()

Unnamed: 0,molecule,quantilesXecfps,ATG_AR_TRANS_dn,TOX21_p53_BLA_p1_ratio,ATG_FoxA2_CIS_up,BSK_hDFCGF_TIMP1_down,ATG_LXRb_TRANS_up,Tanguay_ZF_120hpf_JAW_up,BSK_LPS_IL8_up,NVS_ADME_hCYP2D6,...,TOX21_ERa_BLA_Antagonist_ch1,NVS_ENZ_oCOX2,TOX21_ARE_BLA_Agonist_ch2,BSK_LPS_CD40_down,TOX21_Aromatase_Inhibition,BSK_hDFCGF_MCSF_down,ATG_RARa_TRANS_up,BSK_LPS_IL1a_up,TOX21_ESRE_BLA_ratio,NVS_ENZ_hBACE
0,<rdkit.Chem.rdchem.Mol object at 0x000001BEB7A...,"[1.391232737002302, -4.7158484260537215, 1.391...",0.0,0.0,0.0,0.0,0.0,0.0,0.0,,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
1,<rdkit.Chem.rdchem.Mol object at 0x000001BEB7A...,"[1.0041392180991264, -4.368181003558055, 1.004...",,0.0,,,,,,,...,0.0,,0.0,,0.0,,,,0.0,
2,<rdkit.Chem.rdchem.Mol object at 0x000001BEB7A...,"[1.5037410958956674, -0.4051478194032391, 1.50...",,0.0,,,,,,,...,0.0,,0.0,,0.0,,,,0.0,
3,<rdkit.Chem.rdchem.Mol object at 0x000001BEB7A...,"[1.399287681836219, -4.764877717801065, 1.3992...",0.0,0.0,0.0,0.0,0.0,0.0,0.0,,...,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
4,<rdkit.Chem.rdchem.Mol object at 0x000001BEB7A...,"[1.3495965901701643, -4.508145310901967, 1.349...",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,


In [4]:
def train_random_forest_assays(train_set, validation_set, test_set):
    # Prepare the datasets
    train_features = train_set['quantilesXecfps'].tolist()
    val_features = validation_set['quantilesXecfps'].tolist()
    test_features = test_set['quantilesXecfps'].tolist()
    
    # Combine all features and assays into a dictionary
    all_features = np.concatenate([train_features, val_features, test_features])
    all_assays = pd.concat([train_set, validation_set, test_set], axis=0).drop(columns=['molecule', 'quantilesXecfps'])
    
    # Store the ROC-AUCs for each assay
    roc_aucs = []

    # Loop through each assay column (output)
    for assay in all_assays.columns:
        # Drop NaN values and select the assay's target column
        assay_data = all_assays[[assay]].dropna()
        features = all_features[assay_data.index]
        targets = assay_data[assay].values
        
        # Ensure we have at least 5 positive and 5 negative samples
        pos_indices = np.where(targets == 1)[0]
        neg_indices = np.where(targets == 0)[0]
        
        if len(pos_indices) < 5 or len(neg_indices) < 5:
            print(f"Not enough positive/negative samples for assay '{assay}', skipping...")
            continue
        
        # Randomly select 5 positive and 5 negative samples for support set
        support_pos = np.random.choice(pos_indices, 5, replace=False)
        support_neg = np.random.choice(neg_indices, 5, replace=False)
        support_indices = np.concatenate([support_pos, support_neg])
        
        # Define support set and query set (remaining data)
        support_features = features[support_indices]
        support_targets = targets[support_indices]
        
        query_indices = np.setdiff1d(np.arange(len(targets)), support_indices)
        query_features = features[query_indices]
        query_targets = targets[query_indices]
        
        # Train the Random Forest on the support set
        model = RandomForestClassifier()
        model.fit(support_features, support_targets)
        
        # Predict and calculate ROC-AUC on the query set
        query_predictions = model.predict_proba(query_features)[:, 1]  # Probability of class 1
        auc_score = roc_auc_score(query_targets, query_predictions)
        roc_aucs.append(auc_score)
        
        print(f"Assay '{assay}' ROC-AUC: {auc_score:.4f}")
    
    # Calculate mean ROC-AUC across all assays
    mean_auc = np.mean(roc_aucs) if roc_aucs else float('nan')
    print(f"\nMean ROC-AUC across all assays: {mean_auc:.4f}")
    return mean_auc

In [5]:
mean_auc = train_random_forest_assays(train_set, validation_set, test_set)

Assay 'ATG_AR_TRANS_dn' ROC-AUC: 0.5572
Assay 'TOX21_p53_BLA_p1_ratio' ROC-AUC: 0.6582
Assay 'ATG_FoxA2_CIS_up' ROC-AUC: 0.6147
Assay 'BSK_hDFCGF_TIMP1_down' ROC-AUC: 0.5164
Assay 'ATG_LXRb_TRANS_up' ROC-AUC: 0.6287
Assay 'Tanguay_ZF_120hpf_JAW_up' ROC-AUC: 0.5692
Assay 'BSK_LPS_IL8_up' ROC-AUC: 0.5154
Assay 'NVS_ADME_hCYP2D6' ROC-AUC: 0.6621
Assay 'CLD_CYP1A2_6hr' ROC-AUC: 0.5476
Assay 'NVS_GPCR_rOpiate_NonSelectiveNa' ROC-AUC: 0.6450
Assay 'BSK_SAg_CD69_down' ROC-AUC: 0.6479
Assay 'OT_AR_ARSRC1_0960' ROC-AUC: 0.5562
Assay 'BSK_CASM3C_MCP1_down' ROC-AUC: 0.4947
Assay 'ATG_NRF2_ARE_CIS_up' ROC-AUC: 0.6858
Assay 'APR_HepG2_CellCycleArrest_24h_dn' ROC-AUC: 0.5252
Assay 'BSK_BE3C_uPA_down' ROC-AUC: 0.4674
Assay 'TOX21_VDR_BLA_agonist_ch2' ROC-AUC: 0.7163
Assay 'NVS_LGIC_rGABAR_NonSelective' ROC-AUC: 0.4266
Assay 'ATG_VDR_TRANS_dn' ROC-AUC: 0.4485
Assay 'TOX21_PPARd_BLA_antagonist_ratio' ROC-AUC: 0.6391
Assay 'NVS_NR_hRARa_Agonist' ROC-AUC: 0.6388
Assay 'ATG_Ets_CIS_dn' ROC-AUC: 0.6143
Ass

Assay 'NVS_GPCR_rAdra1_NonSelective' ROC-AUC: 0.7851
Assay 'APR_Hepat_CellLoss_48hr_dn' ROC-AUC: 0.4976
Assay 'ACEA_T47D_80hr_Negative' ROC-AUC: 0.7024
Assay 'Tanguay_ZF_120hpf_PIG_up' ROC-AUC: 0.5689
Assay 'TOX21_MMP_ratio_up' ROC-AUC: 0.7462
Assay 'NVS_ENZ_hMMP13' ROC-AUC: 0.6043
Assay 'NVS_ADME_hCYP4F12' ROC-AUC: 0.5578
Assay 'NVS_GPCR_rmAdra2B' ROC-AUC: 0.7047
Assay 'TOX21_HSE_BLA_agonist_viability' ROC-AUC: 0.5847
Assay 'APR_HepG2_NuclearSize_72h_up' ROC-AUC: 0.6017
Assay 'TOX21_ESRE_BLA_ch2' ROC-AUC: 0.7043
Assay 'ATG_M_32_CIS_up' ROC-AUC: 0.5670
Assay 'ATG_GLI_CIS_up' ROC-AUC: 0.5920
Assay 'ATG_VDRE_CIS_dn' ROC-AUC: 0.4758
Assay 'OT_ERa_EREGFP_0480' ROC-AUC: 0.4885
Assay 'NVS_ENZ_rMAOBC' ROC-AUC: 0.5463
Assay 'APR_HepG2_MitoticArrest_24h_up' ROC-AUC: 0.6641
Assay 'TOX21_PPARg_BLA_antagonist_ratio' ROC-AUC: 0.4167
Assay 'TOX21_p53_BLA_p5_viability' ROC-AUC: 0.6988
Assay 'ATG_PXRE_CIS_dn' ROC-AUC: 0.6080
Assay 'BSK_hDFCGF_Proliferation_down' ROC-AUC: 0.6568
Assay 'NVS_GPCR_hOpiate

Assay 'TOX21_ARE_BLA_Agonist_ch1' ROC-AUC: 0.5053
Assay 'CEETOX_H295R_ANDR_dn' ROC-AUC: 0.4574
Assay 'BSK_3C_SRB_down' ROC-AUC: 0.5519
Assay 'OT_FXR_FXRSRC1_1440' ROC-AUC: 0.6348
Assay 'NVS_ENZ_hPTPN13' ROC-AUC: 0.6642
Assay 'ATG_PXR_TRANS_dn' ROC-AUC: 0.6003
Assay 'TOX21_p53_BLA_p2_viability' ROC-AUC: 0.5644
Assay 'TOX21_HSE_BLA_agonist_ch1' ROC-AUC: 0.6406
Assay 'ATG_PXR_TRANS_up' ROC-AUC: 0.7084
Assay 'NVS_IC_rNaCh_site2' ROC-AUC: 0.7569
Assay 'NVS_ADME_hCYP2A6' ROC-AUC: 0.7149
Assay 'CLD_CYP1A2_24hr' ROC-AUC: 0.6409
Assay 'NVS_GPCR_hH1' ROC-AUC: 0.3527
Assay 'ATG_Myc_CIS_up' ROC-AUC: 0.5187
Assay 'TOX21_AutoFluor_HEPG2_Media_blue' ROC-AUC: 0.7851
Assay 'ATG_HNF4a_TRANS_dn' ROC-AUC: 0.6659
Assay 'ATG_PPRE_CIS_up' ROC-AUC: 0.6870
Assay 'BSK_CASM3C_Proliferation_down' ROC-AUC: 0.6033
Assay 'NVS_GPCR_hAdoRA1' ROC-AUC: 0.5912
Assay 'ATG_HNF4a_TRANS_up' ROC-AUC: 0.3415
Assay 'NVS_ENZ_hMMP7' ROC-AUC: 0.4274
Assay 'CLD_CYP1A1_6hr' ROC-AUC: 0.4989
Assay 'BSK_SAg_CD38_down' ROC-AUC: 0.6795
A

Assay 'ATG_ISRE_CIS_up' ROC-AUC: 0.5337
Assay 'NVS_ENZ_rAChE' ROC-AUC: 0.5179
Assay 'ACEA_T47D_80hr_Positive' ROC-AUC: 0.5229
Assay 'TOX21_VDR_BLA_Antagonist_ch1' ROC-AUC: 0.6420
Assay 'ATG_FXR_TRANS_up' ROC-AUC: 0.6998
Assay 'BSK_3C_uPAR_down' ROC-AUC: 0.4073
Assay 'APR_Hepat_DNADamage_24hr_up' ROC-AUC: 0.5028
Assay 'TOX21_AR_BLA_Antagonist_ratio' ROC-AUC: 0.6604
Assay 'NVS_GPCR_hOpiate_mu' ROC-AUC: 0.6780
Assay 'NVS_MP_hPBR' ROC-AUC: 0.5211
Assay 'NVS_ENZ_hGSK3b' ROC-AUC: 0.6268
Assay 'ATG_GRE_CIS_up' ROC-AUC: 0.5726
Assay 'NCCT_HEK293T_CellTiterGLO' ROC-AUC: 0.7018
Assay 'TOX21_PPARd_BLA_antagonist_viability' ROC-AUC: 0.6918
Assay 'NVS_ADME_hCYP19A1' ROC-AUC: 0.5959
Assay 'BSK_CASM3C_SAA_down' ROC-AUC: 0.5627
Assay 'APR_HepG2_MitoMass_72h_dn' ROC-AUC: 0.4321
Assay 'CEETOX_H295R_OHPREG_up' ROC-AUC: 0.6006
Assay 'NVS_ENZ_oCOX1' ROC-AUC: 0.6243
Assay 'BSK_LPS_SRB_down' ROC-AUC: 0.5225
Assay 'ATG_STAT3_CIS_dn' ROC-AUC: 0.6451
Assay 'Tanguay_ZF_120hpf_SNOU_up' ROC-AUC: 0.6171
Assay 'OT_A