In [12]:
import pandas as pd

df_ar = pd.read_csv('ar.csv')
df_ahr = pd.read_csv('ahr.csv')
df_ar_lbd = pd.read_csv('ar-lbd.csv')
df_are = pd.read_csv('are.csv')
df_aromatase = pd.read_csv('aromatase.csv')
df_atad5 = pd.read_csv('atad5.csv')
df_er_lbd = pd.read_csv('er-lbd.csv')
df_er = pd.read_csv('er.csv')
df_hse = pd.read_csv('hse.csv')
df_mmp = pd.read_csv('mmp.csv')
df_p53 = pd.read_csv('p53.csv')
df_ppar_gamma = pd.read_csv('ppar-gamma.csv')

In [13]:
df_list = [df_ar, df_ahr, df_ar_lbd, df_are, df_aromatase, df_atad5, df_er_lbd, df_er, df_hse, df_mmp, df_p53, df_ppar_gamma]

df_name = ['df_ar', 'df_ahr', 'df_ar_lbd', 'df_are', 'df_aromatase', 'df_atad5', 'df_er_lbd', 'df_er', 'df_hse', 'df_mmp', 'df_p53', 'df_ppar_gamma']
for i, j in zip(df_name, df_list):
    print(i+':', len(j))

df_ar: 9360
df_ahr: 8167
df_ar_lbd: 8597
df_are: 7166
df_aromatase: 7224
df_atad5: 9089
df_er_lbd: 8751
df_er: 7695
df_hse: 8148
df_mmp: 7319
df_p53: 8632
df_ppar_gamma: 8182


In [14]:
#데이터 분할

import optuna
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve, f1_score, roc_auc_score, average_precision_score, accuracy_score
import numpy as np

def traintestsplit(dataframe):
    X = dataframe.loc[:, (dataframe.columns != 'smiles') & (dataframe.columns != 'Active')]
    y = dataframe['Active']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
    return X_train, X_test, y_train, y_test
import joblib

def train_one_assay(df, n_trials=100, random_state=42):
    X_train, X_test, y_train, y_test = traintestsplit(df)
    
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )

    study4 = optuna.create_study(direction='maximize')
    study4.optimize(objective, n_trials=100, show_progress_bar=True)
    best = {**study4.best_params, 'class_weight': 'balanced', 'random_state':42, 'n_jobs':-1}
    
    rfc = RandomForestClassifier(**study4.best_params)

    rfc.fit(X_tr, y_tr)

    val_score = rfc.predict_proba(X_val)[:, 1]
    prec, rec, thr = precision_recall_curve(y_val, val_score)
    beta = 2.0 
    fb = (1+beta**2) * prec*rec / (beta**2*prec + rec + 1e-12)
    best_idx = np.nanargmax(fb[:-1])
    best_thr = float(thr[best_idx])

    final_model = RandomForestClassifier(**best).fit(X_train, y_train)
    test_score = final_model.predict_proba(X_test)[:, 1]
    test_auc = roc_auc_score(y_test, test_score)
    test_pr = average_precision_score(y_test, test_score)
    
    return {
        'model':final_model,
        'best_thr':best_thr,
        'X_cols':list(X_train.columns),
        'test_auc':test_auc,
        'test_pr':test_pr,
        'best_params':best
    }
    
def train_all_assays(df_list, assay_names, n_trials=100):
    bundle = {}
    for name, df in zip(assay_names, df_list):
        bundle[name] = train_one_assay(df, n_trials=n_trials)
    joblib.dump(bundle, 'tox21_maccs_bundle.joblib')
    
    return bundle

In [15]:
#SMILES -> MACCS Key 변형

from rdkit import Chem
from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem.rdmolops import SanitizeFlags
from rdkit.Chem.MolStandardize import rdMolStandardize as std
from rdkit import DataStructs
import numpy as np
import pandas as pd
from rdkit import RDLogger
from tqdm import tqdm
 

def SmilesToMaccs(smiles_list, mw_range=(50,1000)):
    RDLogger.DisableLog('rdApp.warning')
    maccs_array = []
    cnt = 0
    eliminated_index = []
    for index, smiles in enumerate(smiles_list):
        try:
            sanitize_ops = (SanitizeFlags.SANITIZE_ALL
                            ^ SanitizeFlags.SANITIZE_KEKULIZE
                            ^ SanitizeFlags.SANITIZE_SETAROMATICITY
                            )
            
            mol = Chem.MolFromSmiles(smiles, sanitize=False)
            if mol is None:
                raise ValueError('Invalid SMILES')
            
            Chem.SanitizeMol(mol, sanitizeOps=sanitize_ops)
            
            params = std.CleanupParameters()
            mol = std.Cleanup(mol, params)
            mol = std.MetalDisconnector().Disconnect(mol)
            mol = std.Uncharger().uncharge(mol)
            mol = std.Reionizer().reionize(mol)
            
            chooser = std.LargestFragmentChooser(preferOrganic=True)
            mol = chooser.choose(mol)
            
            Chem.SanitizeMol(mol, sanitizeOps=sanitize_ops)
            
            mw = Descriptors.MolWt(mol)
            if mw < mw_range[0] or mw > mw_range[1]:
                cnt += 1
                eliminated_index.append(index)
                continue
            
            fp = MACCSkeys.GenMACCSKeys(mol)
            arr = np.zeros(167, dtype=np.uint8)
            DataStructs.ConvertToNumpyArray(fp, arr)
            maccs_array.append(arr.astype(float))
            
        except Exception as e:
            cnt += 1
            eliminated_index.append(index)
            continue
    
    print(f'{cnt} eliminated')

    X = np.vstack(maccs_array).astype(float) 
    cols = [f'MACCS_{i:03d}' for i in range(X.shape[1])]
    
    return pd.DataFrame(X, columns=cols), eliminated_index  

In [16]:
#예측 dataframe 출력

def predict_multilabel(smiles_list, bundle, return_proba=False):
    if isinstance(smiles_list, str):
        smiles_list = [smiles_list]
        
    X_raw, eliminated_index = SmilesToMaccs(smiles_list)

    yhat_cols = {}
    proba_cols = {}
    
    assays = list(bundle.keys())
    
    for assay in assays:
        info = bundle[assay]
        X = X_raw.reindex(columns=info['X_cols'], fill_value=0)
        p = info['model'].predict_proba(X)[:, 1]
        y = (p >= info['best_thr']).astype(int)
        
        yhat_cols[assay[3:]] = y
        if return_proba:
            proba_cols[assay] = p
    
    yhat_df = pd.DataFrame(yhat_cols, index=X_raw.index)
    
    if len(eliminated_index) >= 1:
        smiles_s = pd.Series(smiles_list, name="SMILES")
        smiles_s = smiles_s.drop(index=sorted(set(map(int, eliminated_index))))
        outcome_df = smiles_s.reset_index(drop=True).to_frame()
        
    else:
        outcome_df = pd.DataFrame({'SMILES':smiles_list}).reset_index(drop=True)
        
    outcome_df = pd.concat([outcome_df, yhat_df.reset_index(drop=True)], axis=1)
    
    if return_proba:
        proba_df = pd.DataFrame(proba_cols, index=X_raw.index)
        return outcome_df, proba_df
    return outcome_df    
    

In [17]:
import pandas as pd

verify_df = pd.read_csv('verify_df.csv')
verify_df = verify_df.drop('Unnamed: 0', axis=1)
new_smiles2 = list(verify_df['SMILES'].values)

In [18]:
bundle = joblib.load('tox21_maccs_bundle.joblib')
outcome_df, proba_df = predict_multilabel(new_smiles2, bundle, return_proba=True)
proba_df

[13:13:18] Initializing MetalDisconnector
[13:13:18] Running MetalDisconnector
[13:13:18] Initializing Normalizer
[13:13:18] Running Normalizer
[13:13:18] Initializing MetalDisconnector
[13:13:18] Running MetalDisconnector
[13:13:18] Running Uncharger
[13:13:18] Running LargestFragmentChooser
[13:13:18] Fragment: O=C(O)/C=C\C(=O)O
[13:13:18] New largest fragment: O=C(O)/C=C\C(=O)O (12)
[13:13:18] Fragment: C[C@]12C=CC(=O)C=C1CC[C@@H]1C2=CC[C@]2(C)[C@@H](C(=O)CN3CCN(c4cc(N5CCCC5)nc(N5CCCC5)n4)CC3)CC[C@@H]12
[13:13:18] New largest fragment: C[C@]12C=CC(=O)C=C1CC[C@@H]1C2=CC[C@]2(C)[C@@H](C(=O)CN3CCN(c4cc(N5CCCC5)nc(N5CCCC5)n4)CC3)CC[C@@H]12 (95)
[13:13:18] Initializing MetalDisconnector
[13:13:18] Running MetalDisconnector
[13:13:18] Initializing Normalizer
[13:13:18] Running Normalizer
[13:13:18] Initializing MetalDisconnector
[13:13:18] Running MetalDisconnector
[13:13:18] Running Uncharger
[13:13:18] Running LargestFragmentChooser
[13:13:18] Fragment: [Na+]
[13:13:18] New largest frag

5 eliminated


Unnamed: 0,df_ar,df_ahr,df_ar_lbd,df_are,df_aromatase,df_atad5,df_er_lbd,df_er,df_hse,df_mmp,df_p53,df_ppar_gamma
0,0.232704,0.166845,0.072994,0.395223,0.250729,0.047036,0.051960,0.111464,0.073789,0.222820,0.123597,0.112313
1,0.053139,0.353520,0.019138,0.311805,0.216288,0.096334,0.064356,0.087488,0.044000,0.197427,0.215936,0.131320
2,0.010933,0.381246,0.017091,0.087602,0.356106,0.086458,0.010159,0.102115,0.070784,0.127057,0.080244,0.161917
3,0.052092,0.465499,0.018474,0.253791,0.602272,0.070976,0.029569,0.223000,0.155165,0.263120,0.056854,0.074712
4,0.039641,0.156491,0.152941,0.252015,0.030303,0.015971,0.035415,0.068533,0.029320,0.311903,0.089691,0.167062
...,...,...,...,...,...,...,...,...,...,...,...,...
637,0.000000,0.013423,0.018057,0.073906,0.023087,0.000000,0.000000,0.059001,0.008170,0.000000,0.025796,0.010296
638,0.009942,0.063115,0.000000,0.050130,0.024041,0.004951,0.005516,0.013448,0.037251,0.000000,0.019585,0.004155
639,0.009942,0.267540,0.076090,0.129930,0.084015,0.058509,0.045453,0.103985,0.424203,0.117007,0.042684,0.219576
640,0.099997,0.360775,0.035354,0.328922,0.473450,0.087539,0.060539,0.116223,0.208742,0.437135,0.112414,0.118952
