In [1]:
import warnings
warnings.simplefilter('ignore')
from utils import *
import pickle
import pandas as pd
from collections import defaultdict
from scipy.stats import pearsonr

from sklearn.model_selection import KFold
from sklearn import ensemble
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, make_scorer

from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit import DataStructs

import umap

import matplotlib.pyplot as plt
import seaborn as sns
from statannotations.Annotator import Annotator

In [2]:
sns.set_style('ticks', 
              {'font.sans-serif':['Arial'], 
                'text.color': 'black',
                'xtick.color': 'black',
                'ytick.color': 'black',
                })
plt.rcParams.update({'font.size':20})
colors = sns.color_palette('Paired')

## perturbational representations-based model 

In [3]:
target = 'HTR2A'
cell_ls = ['A375', 'HA1E', 'HELA', 'HT29', 'MCF7', 'PC3', 'YAPC']

df_grid_search_results = pd.read_csv('../results/4.Ligand_based_virtual_screening/{}/grid_search_param.csv'.format(target))

df_metrics = pd.DataFrame(columns=['AUROC', 'AUPR', 'data', 'target', 'cid', 'random_seed'])
predict_result = defaultdict(list)
# random_seed_ls = [random.randint(0, 1000000) for i in range(5)]
random_seed_ls = [808431, 510395, 584403, 630680, 532783]
for data_name in ['TranSiGen', 'CIGER', 'DeepCE', 'DLEPS', 'MultiDCP']:

    data = load_from_HDF('../results/4.Ligand_based_virtual_screening/HTR2A/{}.h5'.format(data_name))
    data_all = data['data']
    data_label = data['label']
    data_smi_idx = data['cp_id']

    df_data = pd.DataFrame(data['split'], columns=['split'])
    df_data['cid'] = data['cid']
    df_data['data_name'] = data['data_name']
        
    for idx in range(len(cell_ls)):
        selected_cid = cell_ls[idx]
        best_params = df_grid_search_results[(df_grid_search_results['data']== data_name) & (df_grid_search_results['target']== target)
                                  & (df_grid_search_results['cid']== selected_cid)].reset_index(drop=True).to_dict(orient='index')[0]
        metrics_dict = defaultdict(list)
        run_idx = 0

        index_train = df_data[(df_data['data_name']==data_name) & (df_data['cid']== selected_cid) & (df_data['split']== 'train')].index.tolist()
        index_test = df_data[(df_data['data_name']==data_name) & (df_data['cid']== selected_cid) & (df_data['split']== 'test')].index.tolist()
        # print(data_name, 'TRAIN:', len(index_train), 'TEST:', len(index_test))
        data_train, data_test = data_all[index_train], data_all[index_test]
        label_train, label_test = data_label[index_train], data_label[index_test]
        smi_idx_train, smi_idx_test = data_smi_idx[index_train], data_smi_idx[index_test]
        
        ## early fusion feature
        if data_name == 'TranSiGen':
            try:
                data_train_concat = np.concatenate((data_train_concat, data_train),axis=1) 
                data_test_concat = np.concatenate((data_test_concat, data_test),axis=1) 
            except:
                data_train_concat = data_train
                data_test_concat = data_test

        for run_random_seed in random_seed_ls:
            clf = ensemble.RandomForestClassifier(max_depth=best_params['max_depth'], 
                                                  n_estimators=best_params['n_estimators'], 
                                                  max_features='auto', 
                                                  criterion=best_params['criterion'], 
                                                  oob_score=best_params['oob_score'],
                                                  random_state=run_random_seed)
            clf = clf.fit(data_train, label_train)
            data_test_pred = clf.predict_proba(data_test)
            if data_name == 'TranSiGen':
                predict_result['label'] += [list(label_test)]
                predict_result['cp_id'] += [list(smi_idx_test)]
                predict_result['cid'] += [[selected_cid] *len(list(label_test))]
                predict_result['target'] += [[target] *len(list(label_test))]
                predict_result['random_seed'] += [[run_random_seed] *len(list(label_test))]
                predict_result['idx'] += [[run_idx] *len(list(label_test))]
            
            predict_result[data_name] += [list(data_test_pred[:, 1])]
            metrics_dict['AUROC'] += [roc_auc_score(label_test, data_test_pred[:, 1])]
            precision, recall, _thresholds = precision_recall_curve(label_test, data_test_pred[:, 1])
            metrics_dict['AUPR'] += [auc(recall, precision)]
            metrics_dict['random_seed'] += [run_random_seed]
            run_idx +=1

        df_metrics_tmp = pd.DataFrame.from_dict(metrics_dict)
        df_metrics_tmp.loc[:, 'data'] = data_name
        df_metrics_tmp.loc[:, 'cid'] = selected_cid
        df_metrics_tmp.loc[:, 'target'] = target
        df_metrics = pd.concat([df_metrics, df_metrics_tmp])

print('=========late fusion========')
for data_name in ['TranSiGen']:
    metrics_dict = defaultdict(list)
    for idx in range(len(random_seed_ls)):
        ensemble_dict = defaultdict(list)
        for cid_idx in range(len(cell_ls)):
            result_idx = cid_idx * 5 + idx
            ensemble_dict[cid_idx] += predict_result[data_name][result_idx]
        pred_ensemble = np.mean(np.array(list(ensemble_dict.values())) , axis=0)
        label_test = predict_result['label'][result_idx]

        predict_result[data_name + ' (late fusion)'] += [list(pred_ensemble)]
        metrics_dict['AUROC'] += [roc_auc_score(label_test, pred_ensemble)]
        precision, recall, _thresholds = precision_recall_curve(label_test, pred_ensemble)
        metrics_dict['AUPR'] += [auc(recall, precision)]

    df_metrics_tmp = pd.DataFrame.from_dict(metrics_dict)
    df_metrics_tmp.loc[:, 'data'] = data_name + ' (late fusion)'
    df_metrics_tmp.loc[:, 'cid'] = 'all'
    df_metrics_tmp.loc[:, 'target'] = target
    df_metrics = pd.concat([df_metrics, df_metrics_tmp])

print('=========early fusion========')
print(data_train_concat.shape, data_test_concat.shape)
data_name = 'TranSiGen (early fusion)'
selected_cid = 'all'
metrics_dict = defaultdict(list)
best_params = df_grid_search_results[(df_grid_search_results['data']== data_name) & (df_grid_search_results['target']== target)
                          & (df_grid_search_results['cid']== selected_cid)].reset_index(drop=True).to_dict(orient='index')[0]

for run_random_seed in random_seed_ls:
    clf = ensemble.RandomForestClassifier(max_depth=best_params['max_depth'], 
                                          n_estimators=best_params['n_estimators'], 
                                          max_features='auto', 
                                          criterion=best_params['criterion'], 
                                          oob_score=best_params['oob_score'],
                                          random_state=run_random_seed)

    clf = clf.fit(data_train_concat, label_train)
    data_test_pred = clf.predict_proba(data_test_concat)
    predict_result[data_name] += [list(data_test_pred[:, 1])]
    metrics_dict['AUROC'] += [roc_auc_score(label_test, data_test_pred[:, 1])]
    precision, recall, _thresholds = precision_recall_curve(label_test, data_test_pred[:, 1])
    metrics_dict['AUPR'] += [auc(recall, precision)]
    metrics_dict['random_seed'] += [run_random_seed]

df_metrics_tmp = pd.DataFrame.from_dict(metrics_dict)
df_metrics_tmp.loc[:, 'data'] = data_name
df_metrics_tmp.loc[:, 'cid'] = selected_cid
df_metrics_tmp.loc[:, 'target'] = target
df_metrics = pd.concat([df_metrics, df_metrics_tmp])

del data_train_concat, data_test_concat
        
df_metrics_kfold_mean = df_metrics.groupby(['data', 'target', 'cid']).mean().reset_index()
df_summary = df_metrics_kfold_mean.groupby(['target', 'data']).agg(func=['mean', 'std'])
round(df_summary,3) 

(197, 6846) (49, 6846)


Unnamed: 0_level_0,Unnamed: 1_level_0,AUROC,AUROC,AUPR,AUPR
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
target,data,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
HTR2A,CIGER,0.614,0.086,0.397,0.105
HTR2A,DLEPS,0.566,0.0,0.265,0.0
HTR2A,DeepCE,0.527,0.041,0.225,0.03
HTR2A,MultiDCP,0.547,0.069,0.267,0.062
HTR2A,TranSiGen,0.889,0.041,0.731,0.092
HTR2A,TranSiGen (early fusion),0.927,,0.768,
HTR2A,TranSiGen (late fusion),0.938,,0.785,


## structure representation-based model (ECFP4, KPGT)

In [4]:
selected_cid = 'all'
for data_name in ['ECFP4', 'KPGT']:
    best_params = df_grid_search_results[(df_grid_search_results['data']== data_name) & (df_grid_search_results['target']== target)
                              & (df_grid_search_results['cid']== selected_cid)].reset_index(drop=True).to_dict(orient='index')[0]
    metrics_dict = defaultdict(list)
    
    data = load_from_HDF('../results/4.Ligand_based_virtual_screening/HTR2A/{}.h5'.format(data_name))
    data_all = data['data']
    data_label = data['label']
    data_smi_idx = data['cp_id']
    
    df_data = pd.DataFrame(data['split'], columns=['split'])
    df_data['cid'] = data['cid']
    df_data['data_name'] = data['data_name']
    
    index_train = df_data[(df_data['data_name']==data_name) & (df_data['cid']== selected_cid) & (df_data['split']== 'train')].index.tolist()
    index_test = df_data[(df_data['data_name']==data_name) & (df_data['cid']== selected_cid) & (df_data['split']== 'test')].index.tolist()
    # print('TRAIN:', len(index_train), 'TEST:', len(index_test))
    
    data_train, data_test = data_all[index_train], data_all[index_test]
    label_train, label_test = data_label[index_train], data_label[index_test]
    smi_idx_train, smi_idx_test = data_smi_idx[index_train], data_smi_idx[index_test]
    
    for run_random_seed in random_seed_ls:
        clf = ensemble.RandomForestClassifier(max_depth=best_params['max_depth'], 
                                              n_estimators=best_params['n_estimators'], 
                                              max_features='auto', 
                                              criterion=best_params['criterion'], 
                                              oob_score=best_params['oob_score'],
                                              random_state=run_random_seed)
        clf = clf.fit(data_train, label_train)
        data_test_pred = clf.predict_proba(data_test)

        predict_result[data_name] += [list(data_test_pred[:, 1])]
        metrics_dict['AUROC'] += [roc_auc_score(label_test, data_test_pred[:, 1])]
        precision, recall, _thresholds = precision_recall_curve(label_test, data_test_pred[:, 1])
        metrics_dict['AUPR'] += [auc(recall, precision)]
        metrics_dict['random_seed'] += [run_random_seed]
        run_idx +=1

    df_metrics_tmp = pd.DataFrame.from_dict(metrics_dict)
    df_metrics_tmp.loc[:, 'data'] = data_name
    df_metrics_tmp.loc[:, 'cid'] = selected_cid
    df_metrics_tmp.loc[:, 'target'] = target
    df_metrics = pd.concat([df_metrics, df_metrics_tmp])

In [5]:
df_metrics_kfold_mean = df_metrics.groupby(['data', 'target', 'cid']).mean().reset_index()
df_summary = df_metrics_kfold_mean.groupby(['target', 'data']).agg(func=['mean', 'std'])

In [16]:
with open('../results/4.Ligand_based_virtual_screening/{}/prediction.pkl'.format(target) , 'wb') as f:
    pickle.dump(predict_result, f)
df_metrics.to_csv('../results/4.Ligand_based_virtual_screening/{}/results.csv'.format(target), index=False)

In [6]:
model_order = ['TranSiGen (early fusion)', 'TranSiGen (late fusion)',]
df_metrics_tmp = df_metrics[(df_metrics['data'].isin(model_order))]
df_metrics_funsion = df_metrics[(df_metrics['data'].isin(['TranSiGen'] ))
                               ].groupby(['random_seed','data', 'target']).mean().reset_index()
df_metrics_funsion['data'] = df_metrics_funsion['data'] + ' (7 cells)'
df_metrics_funsion['cid'] = 'all'
df_metrics_funsion = df_metrics_funsion[df_metrics_tmp.columns]
df_metrics_funsion = pd.concat([df_metrics_funsion, df_metrics_tmp])
df_metrics_funsion

Unnamed: 0,AUROC,AUPR,data,target,cid,random_seed
0,0.887546,0.734065,TranSiGen (7 cells),HTR2A,all,510395.0
1,0.882967,0.722346,TranSiGen (7 cells),HTR2A,all,532783.0
2,0.887546,0.724046,TranSiGen (7 cells),HTR2A,all,584403.0
3,0.885165,0.722715,TranSiGen (7 cells),HTR2A,all,630680.0
4,0.902015,0.751094,TranSiGen (7 cells),HTR2A,all,808431.0
0,0.94359,0.817895,TranSiGen (late fusion),HTR2A,all,
1,0.951282,0.809332,TranSiGen (late fusion),HTR2A,all,
2,0.933333,0.761364,TranSiGen (late fusion),HTR2A,all,
3,0.933333,0.770342,TranSiGen (late fusion),HTR2A,all,
4,0.930769,0.767021,TranSiGen (late fusion),HTR2A,all,


In [7]:
df_data = pd.read_csv('../results/4.Ligand_based_virtual_screening/{}/data.csv'.format(target))
smiles_train, smiles_test = list(df_data[df_data['split']=='train']['canonical_smiles']), list(df_data[df_data['split']=='test']['canonical_smiles'])

df_sim_and_result_all = pd.DataFrame(columns=['label', 'cp_id', 'ECFP_max_similarity', 'idx', 
                                              'ECFP4', 'KPGT', 'TranSiGen (early fusion)', 'TranSiGen (late fusion)'])

for te_fold_nums in range(5):
    train_ECFP_array = []
    for smi in smiles_train:
        mol = Chem.MolFromSmiles(smi)
        ECFP = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
        train_ECFP_array.append(ECFP)

    test_ECFP_array = []
    for smi in smiles_test:
        mol = Chem.MolFromSmiles(smi)
        ECFP = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
        test_ECFP_array.append(ECFP)

    max_ECFP_sims_in_train = []
    for idx in range(len(test_ECFP_array)):
        sims = DataStructs.BulkTanimotoSimilarity(test_ECFP_array[idx],train_ECFP_array)
        max_ECFP_sims_in_train.append(max(sims))


    df_sim_and_result = pd.DataFrame(predict_result['label'][te_fold_nums], columns=['label'])
    df_sim_and_result['cp_id'] = predict_result['cp_id'][te_fold_nums]
    df_sim_and_result['ECFP_max_similarity'] = max_ECFP_sims_in_train
    df_sim_and_result['idx'] = predict_result['idx'][te_fold_nums]

    for data_name in ['ECFP4', 'KPGT', 'TranSiGen (early fusion)', 'TranSiGen (late fusion)']:
        df_sim_and_result[data_name] = predict_result[data_name][te_fold_nums]
    df_sim_and_result_all = pd.concat([df_sim_and_result_all, df_sim_and_result])
    df_sim_and_result_all['label'] = df_sim_and_result_all['label'].astype('float64')

threshold_range = [0, 0.3, 1]
cuts = pd.cut(df_sim_and_result_all['ECFP_max_similarity'], bins=threshold_range)
df_sim_and_result_all['ECFP_max_similarity_threshold'] = cuts

df_result_by_sim_threshold = pd.DataFrame(columns=['ECFP_max_similarity_threshold', 'model', 'fold', 'AUROC', 'AUPR', 'active', 'inactive', 'count'])
for (idx, threshold), group in df_sim_and_result_all.groupby(by=['idx', 'ECFP_max_similarity_threshold']):   
    print(threshold, 'active:', group[group['label'] == 1].shape[0], 'inactive:', group[group['label'] == 0].shape[0],)
    for data_name in ['TranSiGen (early fusion)', 'TranSiGen (late fusion)', 'ECFP4', 'KPGT']:

        AUROC = roc_auc_score(group['label'], group[data_name])
        precision, recall, _thresholds = precision_recall_curve(group['label'], group[data_name])
        AUPR = auc(recall, precision)
        if data_name == 'TranSiGen (early fusion)': data_name = 'TranSiGen_EF'
        elif data_name == 'TranSiGen (late fusion)': data_name = 'TranSiGen_LF'
        df_result_by_sim_threshold.loc[df_result_by_sim_threshold.shape[0],:] = [threshold, data_name, idx, AUROC, AUPR,
                                                                                group[group['label'] == 1].shape[0], 
                                                                                group[group['label'] == 0].shape[0],
                                                                                group.shape[0] ]
df_result_by_sim_threshold

(0.0, 0.3] active: 4 inactive: 28
(0.3, 1.0] active: 6 inactive: 11
(0.0, 0.3] active: 4 inactive: 28
(0.3, 1.0] active: 6 inactive: 11
(0.0, 0.3] active: 4 inactive: 28
(0.3, 1.0] active: 6 inactive: 11
(0.0, 0.3] active: 4 inactive: 28
(0.3, 1.0] active: 6 inactive: 11
(0.0, 0.3] active: 4 inactive: 28
(0.3, 1.0] active: 6 inactive: 11


Unnamed: 0,ECFP_max_similarity_threshold,model,fold,AUROC,AUPR,active,inactive,count
0,"(0.0, 0.3]",TranSiGen_EF,0,0.964286,0.795833,4,28,32
1,"(0.0, 0.3]",TranSiGen_LF,0,0.991071,0.94375,4,28,32
2,"(0.0, 0.3]",ECFP4,0,0.915179,0.56875,4,28,32
3,"(0.0, 0.3]",KPGT,0,0.866071,0.639583,4,28,32
4,"(0.3, 1.0]",TranSiGen_EF,0,0.909091,0.878042,6,11,17
5,"(0.3, 1.0]",TranSiGen_LF,0,0.954545,0.906151,6,11,17
6,"(0.3, 1.0]",ECFP4,0,0.924242,0.788095,6,11,17
7,"(0.3, 1.0]",KPGT,0,0.984848,0.974206,6,11,17
8,"(0.0, 0.3]",TranSiGen_EF,1,0.982143,0.908333,4,28,32
9,"(0.0, 0.3]",TranSiGen_LF,1,0.982143,0.870833,4,28,32
