In [46]:
import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Fingerprints import FingerprintMols
from DeepPurpose.pybiomed_helper import _GetPseudoAAC, CalculateAADipeptideComposition, \
calcPubChemFingerAll, CalculateConjointTriad, GetQuasiSequenceOrder
import torch
from torch.utils import data
from torch.autograd import Variable
try:
	from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors
except:
	raise ImportError("Please install pip install git+https://github.com/bp-kelley/descriptastorus.")
from DeepPurpose.chemutils import get_mol, atom_features, bond_features, MAX_NB, ATOM_FDIM, BOND_FDIM
from subword_nmt.apply_bpe import BPE
import codecs
import pickle
import wget
from zipfile import ZipFile 
import os
import sys

from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import SequentialSampler
from torch import nn 

from tqdm import tqdm
import matplotlib.pyplot as plt
from time import time
from sklearn.metrics import mean_squared_error, roc_auc_score, average_precision_score, f1_score, log_loss
from lifelines.utils import concordance_index
from scipy.stats import pearsonr
import pickle 
torch.manual_seed(2)
np.random.seed(3)
import copy
from prettytable import PrettyTable

import os

from DeepPurpose.utils import *
from DeepPurpose.model_helper import Encoder_MultipleLayers, Embeddings        
from DeepPurpose.encoders import *
from DeepPurpose import DTI

In [2]:
data_path = './data//BindingDB_All.tsv'
df = pd.read_csv(data_path, sep = '\t', error_bad_lines=False)
df = df[df['Number of Protein Chains in Target (>1 implies a multichain complex)'] == 1.0]
df = df[df['Ligand SMILES'].notnull()]

KeyboardInterrupt: 

In [None]:
df = df[['BindingDB Reactant_set_id', 'Ligand InChI', 'Ligand SMILES',\
                  'PubChem CID', 'UniProt (SwissProt) Primary ID of Target Chain',\
                  'Target Source Organism According to Curator or DataSource',\
                  'BindingDB Target Chain  Sequence', 'Kd (nM)', 'IC50 (nM)', 'Ki (nM)',\
                  'EC50 (nM)', 'kon (M-1-s-1)', 'koff (s-1)','pH','Temp (C)']]
df.rename(columns={'BindingDB Reactant_set_id':'ID',
                        'Ligand SMILES':'SMILES',
                        'Ligand InChI':'InChI',
                        'PubChem CID':'PubChem_ID',
                        'UniProt (SwissProt) Primary ID of Target Chain':'UniProt_ID',
                        'BindingDB Target Chain  Sequence': 'Target Sequence',
                        'Target Source Organism According to Curator or DataSource': 'Organism',
                        'Kd (nM)':'Kd',
                        'IC50 (nM)':'IC50',
                        'Ki (nM)':'Ki',
                        'EC50 (nM)':'EC50',
                        'kon (M-1-s-1)':'kon',
                        'koff (s-1)':'koff',
                        'Temp (C)':'Temp',}, 
                        inplace=True)

In [None]:
df.head()

In [72]:
df['Temp'] = df['Temp'].str.rstrip('C')
df.count()

ID                 1733850
InChI              1733282
SMILES             1733850
PubChem_ID         1718479
UniProt_ID         1538086
Organism           1238470
Target Sequence    1733850
Kd                   74761
IC50               1080811
Ki                  417859
EC50                164210
kon                    654
koff                   524
pH                  204919
Temp                191364
dtype: int64

In [36]:
df.to_pickle("./df.pkl")

In [2]:
df = pd.read_pickle("./df.pkl")

In [3]:
idx_str = ['Kd', 'IC50', 'Ki','EC50','pH','Temp']
df_want = df
convert_to_log = 0

# have at least uniprot or pubchem ID
df_want = df_want[df_want.PubChem_ID.notnull() | df_want.UniProt_ID.notnull()]
df_want = df_want[df_want.InChI.notnull()]

for label in idx_str:
#    df_want = df_want[df_want[label].notnull()]
#    print(df_want.size)
    df_want[label] = df_want[label].str.replace('>', '')
    df_want[label] = df_want[label].str.replace('<', '')
    #df_want[label] = df_want[label].astype(float)
#    df_want = df_want[df_want[label] <= 10000000.0]

y = df_want[idx_str]
for label in idx_str:
    if convert_to_log:
            print('Default set to logspace (nM -> p) for easier regression')
            y[label] = convert_y_unit(df_want[label].values, 'nM', 'p') 
    else:
            y[label] = df_want[label].values

X_drugs = df_want.SMILES.values
X_targets = df_want['Target Sequence'].values
y = y.apply(pd.to_numeric, errors='coerce')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [5]:
X_drugs.size

1730866

In [4]:
import numpy as np
from sklearn.impute import SimpleImputer, KNNImputer

imputer = SimpleImputer() #TODO: better imputer
y_i = imputer.fit_transform(y)

In [6]:
y_i = pd.DataFrame(data=y_i,columns=idx_str)
y_i

Unnamed: 0,Kd,IC50,Ki,EC50,Temp
0,872669.450093,1.264663e+08,2.400000e-01,362878.629538,37.000000
1,872669.450093,1.264663e+08,2.500000e-01,362878.629538,37.000000
2,872669.450093,1.264663e+08,4.100000e-01,362878.629538,37.000000
3,872669.450093,1.264663e+08,8.000000e-01,362878.629538,37.000000
4,872669.450093,1.264663e+08,9.900000e-01,362878.629538,37.000000
...,...,...,...,...,...
1730861,872669.450093,1.264663e+08,1.941028e+06,152.000000,27.936324
1730862,872669.450093,1.264663e+08,1.941028e+06,601.000000,27.936324
1730863,872669.450093,1.264663e+08,1.941028e+06,12.000000,27.936324
1730864,872669.450093,1.264663e+08,1.941028e+06,402.000000,27.936324


In [7]:
df_data = y_i
df_data['SMILES'] = X_drugs
df_data['Target Sequence'] = X_targets

print('in total: ' + str(len(df_data)) + ' drug-target pairs')
df_data

in total: 1730866 drug-target pairs


Unnamed: 0,Kd,IC50,Ki,EC50,Temp,SMILES,Target Sequence
0,872669.450093,1.264663e+08,2.400000e-01,362878.629538,37.000000,COc1cc2c(Nc3ccc(Br)cc3F)ncnc2cc1OCC1CCN(C)CC1,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...
1,872669.450093,1.264663e+08,2.500000e-01,362878.629538,37.000000,O[C@@H]1[C@@H](O)[C@@H](Cc2ccccc2)N(C\C=C\c2cn...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...
2,872669.450093,1.264663e+08,4.100000e-01,362878.629538,37.000000,O[C@@H]1[C@@H](O)[C@@H](Cc2ccccc2)N(CC2CC2)C(=...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...
3,872669.450093,1.264663e+08,8.000000e-01,362878.629538,37.000000,OCCCCCCN1[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)[C@@...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...
4,872669.450093,1.264663e+08,9.900000e-01,362878.629538,37.000000,OCCCCCN1[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)[C@@H...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...
...,...,...,...,...,...,...,...
1730861,872669.450093,1.264663e+08,1.941028e+06,152.000000,27.936324,Oc1ccc(Br)cc1Cn1c(nc2ccc(cc12)[N+]([O-])=O)-c1...,MWRCGGRQGLCVLRRLSGGHAHHRAWRWNSNRACERALQYKLGDKI...
1730862,872669.450093,1.264663e+08,1.941028e+06,601.000000,27.936324,Oc1ccc(Br)cc1CN1C(N(Cc2cc(Br)ccc2O)c2cc(ccc12)...,MWRCGGRQGLCVLRRLSGGHAHHRAWRWNSNRACERALQYKLGDKI...
1730863,872669.450093,1.264663e+08,1.941028e+06,12.000000,27.936324,Oc1ccc(Br)cc1Cn1c(nc2ccc(cc12)[N+]([O-])=O)-c1...,MWRCGGRQGLCVLRRLSGGHAHHRAWRWNSNRACERALQYKLGDKI...
1730864,872669.450093,1.264663e+08,1.941028e+06,402.000000,27.936324,Oc1ccc(Br)cc1CN1C(N(Cc2cc(Br)ccc2O)c2cc(ccc12)...,MWRCGGRQGLCVLRRLSGGHAHHRAWRWNSNRACERALQYKLGDKI...


In [13]:
df_backup = df_data
df_data = df_data.head(10000)

In [14]:
import time

drug_func_list= [smiles2morgan,trans_drug,drug2emb_encoder]
#TODO: add calcPubChemFingerAll back in when it's not broken
#TODO: smiles2rdkit2d takes forever and can be added later
#TODO: smiles2mpnnfeature doesn't take super long (around 40 min on desktop) but can be added later
#TODO: same wrt smiles2daylight
column_name = 'SMILES'
start = time.time()

for func in drug_func_list:
    save_column_name = func.__name__
    unique = pd.Series(df_data[column_name].unique()).apply(func)
    unique_dict = dict(zip(df_data[column_name].unique(), unique))
    df_data[save_column_name] = [unique_dict[i] for i in df_data[column_name]]
    end = time.time()
    print(end - start)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  from ipykernel import kernelapp as app


3.8191072940826416
4.143819570541382
4.53897762298584


In [15]:
prot_func_list = [CalculateConjointTriad, trans_protein, protein2emb_encoder]
#TODO: run CalculateAADipeptideComposition and _GetPseudoAAC when time permits
#TODO: GetQuasiSequenceOrder is broken
column_name = 'Target Sequence'
start = time.time()

for func in prot_func_list:
    save_column_name = func.__name__
    AA = pd.Series(df_data[column_name].unique()).apply(func)
    AA_dict = dict(zip(df_data[column_name].unique(), AA))
    df_data[save_column_name] = [AA_dict[i] for i in df_data[column_name]]
    end = time.time()
    print(end - start)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  # This is added back by InteractiveShellApp.init_path()


0.21258163452148438
0.5405821800231934
0.6225848197937012


In [13]:
df_data.to_csv("./df_data.csv")

KeyboardInterrupt: 

In [18]:
df_data.to_hdf('df_data.h5', key='df', mode='w')

your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['SMILES', 'Target Sequence', 'smiles2morgan', 'smiles2daylight',
       'trans_drug', 'drug2emb_encoder', 'CalculateConjointTriad',
       'trans_protein', 'protein2emb_encoder'],
      dtype='object')]

  encoding=encoding,


MemoryError: 

In [17]:
import tables

In [None]:
df_data.to_csv('df_data1.csv.gz'
         , sep='|'
         , header=True
         , index=False
         , chunksize=10000
         , compression='gzip'
         , encoding='utf-8')

In [None]:
tinydata.size

In [19]:
# dti split

split_method = 'cold_drug'
random_seed = 1
frac = [0.7, 0.2, 0.1]

print('splitting dataset...')

#TODO: what is HTS

if split_method == 'random': 
    train, val, test = create_fold(df_data, random_seed, frac)
elif split_method == 'cold_drug':
    train, val, test = create_fold_setting_cold_drug(df_data, random_seed, frac)
elif split_method == 'HTS':
    train, val, test = create_fold_setting_cold_drug(df_data, random_seed, frac)
    val = pd.concat([val[val.Label == 1].drop_duplicates(subset = 'SMILES'), val[val.Label == 0]])
    test = pd.concat([test[test.Label == 1].drop_duplicates(subset = 'SMILES'), test[test.Label == 0]])        
elif split_method == 'cold_protein':
    train, val, test = create_fold_setting_cold_protein(df_data, random_seed, frac)
elif split_method == 'repurposing_VS':
    train = df_data
    val = df_data
    test = df_data
elif split_method == 'no_split':
    print('do not do train/test split on the data for already splitted data')
else:
    raise AttributeError("Please select one of the three split method: random, cold_drug, cold_target!")
    
print('Done.')

train = train.reset_index(drop=True)
val = val.reset_index(drop=True)
test = test.reset_index(drop=True)
    

splitting dataset...
Done.


In [13]:
train

NameError: name 'train' is not defined

In [37]:

result_folder = "./result/"
input_dim_drug = 1024
input_dim_protein = 8420
hidden_dim_drug = 256
hidden_dim_protein = 256
cls_hidden_dims = [1024, 1024, 512]
mlp_hidden_dims_drug = [1024, 256, 64]
mlp_hidden_dims_target = [1024, 256, 64]
batch_size = 256
train_epoch = 10
test_every_X_epoch = 20
LR = 1e-4
decay = 0
transformer_emb_size_drug = 128
transformer_intermediate_size_drug = 512
transformer_num_attention_heads_drug = 8
transformer_n_layer_drug = 8
transformer_emb_size_target = 64
transformer_intermediate_size_target = 256
transformer_num_attention_heads_target = 4
transformer_n_layer_target = 2
transformer_dropout_rate = 0.1
transformer_attention_probs_dropout = 0.1
transformer_hidden_dropout_rate = 0.1
mpnn_hidden_size = 50
mpnn_depth = 3
cnn_drug_filters = [32,64,96]
cnn_drug_kernels = [4,6,8]
cnn_target_filters = [32,64,96]
cnn_target_kernels = [4,8,12]
rnn_Use_GRU_LSTM_drug = 'GRU'
rnn_drug_hid_dim = 64
rnn_drug_n_layers = 2
rnn_drug_bidirectional = True
rnn_Use_GRU_LSTM_target = 'GRU'
rnn_target_hid_dim = 64
rnn_target_n_layers = 2
rnn_target_bidirectional = True
num_workers = 0 

base_config = {'input_dim_drug': input_dim_drug,
                'input_dim_protein': input_dim_protein,
                'hidden_dim_drug': hidden_dim_drug, # hidden dim of drug
                'hidden_dim_protein': hidden_dim_protein, # hidden dim of protein
                'cls_hidden_dims' : cls_hidden_dims, # decoder classifier dim 1
                'batch_size': batch_size,
                'train_epoch': train_epoch,
                'test_every_X_epoch': test_every_X_epoch, 
                'LR': LR,
                'result_folder': result_folder,
                'binary': False,
                'num_workers' : num_workers,
                'result_folder' : "./result/",
                'input_dim_drug' : 1024,
                'input_dim_protein': 8420,
                'hidden_dim_drug': 256,
                'hidden_dim_protein': 256,
                'cls_hidden_dims': [1024, 1024, 512],
                'mlp_hidden_dims_drug': [1024, 256, 64],
                'mlp_hidden_dims_target': [1024, 256, 64],
                'batch_size': 256,
                'train_epoch': 10,
                'test_every_X_epoch': 20,
                'LR': 1e-4,
                'decay': 0,
                'transformer_emb_size_drug': 128,
                'transformer_intermediate_size_drug': 512,
                'transformer_num_attention_heads_drug': 8,
                'transformer_n_layer_drug': 8,
                'transformer_emb_size_target': 64,
                'transformer_intermediate_size_target': 256,
                'transformer_num_attention_heads_target': 4,
                'transformer_n_layer_target': 2,
                'transformer_dropout_rate': 0.1,
                'transformer_attention_probs_dropout': 0.1,
                'transformer_hidden_dropout_rate': 0.1,
                'mpnn_hidden_size': 50,
                'mpnn_depth': 3,
                'cnn_drug_filters': [32,64,96],
                'cnn_drug_kernels': [4,6,8],
                'cnn_target_filters': [32,64,96],
                'cnn_target_kernels': [4,8,12],
                'rnn_Use_GRU_LSTM_drug': 'GRU',
                'rnn_drug_hid_dim': 64,
                'rnn_drug_n_layers': 2,
                'rnn_drug_bidirectional' : True,
                'rnn_Use_GRU_LSTM_target' : 'GRU',
                'rnn_target_hid_dim' : 64,
                'rnn_target_n_layers' : 2,
                'rnn_target_bidirectional' : True,
                'num_workers' : 0 
}
base_config['result_folder']

'./result/'

In [38]:
if not os.path.exists(base_config['result_folder']):
    os.makedirs(base_config['result_folder'])

base_config['mlp_hidden_dims_drug'] = mlp_hidden_dims_drug # MLP classifier dim 1				
base_config['input_dim_drug'] = 881 #could be 2048 or 200 or 2586
base_config['cnn_drug_filters'] = cnn_drug_filters
base_config['cnn_drug_kernels'] = cnn_drug_kernels
base_config['rnn_Use_GRU_LSTM_drug'] = rnn_Use_GRU_LSTM_drug
base_config['rnn_drug_hid_dim'] = rnn_drug_hid_dim
base_config['rnn_drug_n_layers'] = rnn_drug_n_layers
base_config['rnn_drug_bidirectional'] = rnn_drug_bidirectional 
base_config['transformer_emb_size_drug'] = transformer_emb_size_drug
base_config['transformer_num_attention_heads_drug'] = transformer_num_attention_heads_drug
base_config['transformer_intermediate_size_drug'] = transformer_intermediate_size_drug
base_config['transformer_n_layer_drug'] = transformer_n_layer_drug
base_config['transformer_dropout_rate'] = transformer_dropout_rate
base_config['transformer_attention_probs_dropout'] = transformer_attention_probs_dropout
base_config['transformer_hidden_dropout_rate'] = transformer_hidden_dropout_rate
base_config['hidden_dim_drug'] = transformer_emb_size_drug #could also be hidden_dim_drug
base_config['batch_size'] = batch_size 
base_config['mpnn_hidden_size'] = mpnn_hidden_size
base_config['mpnn_depth'] = mpnn_depth

base_config['mlp_hidden_dims_target'] = mlp_hidden_dims_target # MLP classifier dim 1				
base_config['input_dim_protein'] = 30 #could be 343 or 100 or 4114
base_config['cnn_target_filters'] = cnn_target_filters
base_config['cnn_target_kernels'] = cnn_target_kernels
base_config['rnn_Use_GRU_LSTM_target'] = rnn_Use_GRU_LSTM_target
base_config['rnn_target_hid_dim'] = rnn_target_hid_dim
base_config['rnn_target_n_layers'] = rnn_target_n_layers
base_config['rnn_target_bidirectional'] = rnn_target_bidirectional 
base_config['cnn_target_filters'] = cnn_target_filters
base_config['cnn_target_kernels'] = cnn_target_kernels
base_config['transformer_emb_size_target'] = transformer_emb_size_target
base_config['transformer_num_attention_heads_target'] = transformer_num_attention_heads_target
base_config['transformer_intermediate_size_target'] = transformer_intermediate_size_target
base_config['transformer_n_layer_target'] = transformer_n_layer_target	
base_config['transformer_dropout_rate'] = transformer_dropout_rate
base_config['transformer_attention_probs_dropout'] = transformer_attention_probs_dropout
base_config['transformer_hidden_dropout_rate'] = transformer_hidden_dropout_rate
base_config['hidden_dim_protein'] = transformer_emb_size_target

config = base_config

In [39]:
from collections import namedtuple
model_drug_tuple = namedtuple("model_drug_tuple", "MLP CNN CNN_RNN transformer MPNN")
model_protein_tuple = namedtuple("model_protein_tuple", "MLP CNN CNN_RNN transformer")

model_drug_MLP = MLP(config['input_dim_drug'], config['hidden_dim_drug'], config['mlp_hidden_dims_drug'])
model_drug_CNN = CNN('drug', **config)
model_drug_CNN_RNN = CNN_RNN('drug', **config)
model_drug_transformer = transformer('drug', **config)
model_drug_MPNN = MPNN(config['hidden_dim_drug'], config['mpnn_depth'])

model_drug = model_drug_tuple(model_drug_MLP, model_drug_CNN, model_drug_CNN_RNN, model_drug_transformer, model_drug_MPNN)

model_protein_MLP = MLP(config['input_dim_protein'], config['hidden_dim_protein'], config['mlp_hidden_dims_target'])
model_protein_CNN = CNN('protein', **config)
model_protein_CNN_RNN = CNN_RNN('protein', **config)
model_protein_transformer = transformer('protein', **config)

model_protein = model_protein_tuple(model_protein_MLP, model_protein_CNN, model_protein_CNN_RNN, model_protein_transformer)

model_feature_tuple = namedtuple("model_feature_tuple","model_drug model_protein model_df")
model_features = model_feature_tuple(model_drug, model_protein, train)

In [40]:
class Classifier_o(nn.Sequential):
    def __init__(self, model_struct, **config):
        super(Classifier_o, self).__init__()
        self.input_dim_drug = config['hidden_dim_drug']
        self.input_dim_protein = config['hidden_dim_protein']

        self.model_struct = model_struct

        self.dropout = nn.Dropout(0.1)

        self.hidden_dims = config['cls_hidden_dims']
        layer_size = len(self.hidden_dims) + 1
        dims = [self.input_dim_drug + self.input_dim_protein] + self.hidden_dims + [1]

        self.predictor = nn.ModuleList([nn.Linear(dims[i], dims[i+1]) for i in range(layer_size)])

    def forward(self, v_D, v_P):
        # each encoding
        v_D = self.model_struct.model_drug(v_D)
        v_P = self.model_struct.model_protein(v_P)
        # concatenate and classify
        v_f = torch.cat((v_D, v_P), 1)
        for i, l in enumerate(self.predictor):
            if i==(len(self.predictor)-1):
                v_f = l(v_f)
            else:
                v_f = F.relu(self.dropout(l(v_f)))
        return v_f

In [41]:
model = Classifier_o(model_features, **config)

In [42]:
input_dim_drug = config['hidden_dim_drug']
input_dim_protein = config['hidden_dim_protein']
model_drug = model_drug
model_protein = model_protein
dropout = nn.Dropout(0.1)
hidden_dims = config['cls_hidden_dims']
layer_size = len(hidden_dims) + 1
dims = [input_dim_drug + input_dim_protein] + hidden_dims + [1]
predictor = nn.ModuleList([nn.Linear(dims[i], dims[i+1]) for i in range(layer_size)])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
result_folder = config['result_folder']
       
binary = False

In [None]:
lr = config['LR']
decay = config['decay']
BATCH_SIZE = config['batch_size']
train_epoch = config['train_epoch']
loss_history = []
verbose = True

model = model.to(device)

# support multiple GPUs
if torch.cuda.device_count() > 1:
    if verbose:
        print("Let's use " + str(torch.cuda.device_count()) + " GPUs!")
    model = nn.DataParallel(model, dim = 0)
elif torch.cuda.device_count() == 1:
    if verbose:
        print("Let's use " + str(torch.cuda.device_count()) + " GPU!")
else:
    if verbose:
        print("Let's use CPU/s!")
# Future TODO: support multiple optimizers with parameters
opt = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = decay)
if verbose:
    print('--- Data Preparation ---')

params = {'batch_size': BATCH_SIZE,
        'shuffle': True,
        'num_workers': config['num_workers'],
        'drop_last': False}

params['collate_fn'] = DTI.mpnn_collate_func

In [52]:
train

Unnamed: 0,Kd,IC50,Ki,EC50,Temp,SMILES,Target Sequence,smiles2morgan,trans_drug,drug2emb_encoder,CalculateConjointTriad,trans_protein,protein2emb_encoder
0,872669.450093,1.264663e+08,2.400000e-01,362878.629538,37.0,COc1cc2c(Nc3ccc(Br)cc3F)ncnc2cc1OCC1CCN(C)CC1,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[C, O, c, 1, c, c, 2, c, (, N, c, 3, c, c, c, ...","([515, 343, 982, 52, 93, 210, 614, 1244, 690, ...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
1,872669.450093,1.264663e+08,2.500000e-01,362878.629538,37.0,O[C@@H]1[C@@H](O)[C@@H](Cc2ccccc2)N(C\C=C\c2cn...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], 1, [, C, ?, ?, H, ], (, ...","([1138, 186, 144, 265, 199, 188, 381, 1734, 13...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
2,872669.450093,1.264663e+08,4.100000e-01,362878.629538,37.0,O[C@@H]1[C@@H](O)[C@@H](Cc2ccccc2)N(CC2CC2)C(=...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], 1, [, C, ?, ?, H, ], (, ...","([1138, 186, 144, 265, 261, 158, 322, 65, 188,...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
3,872669.450093,1.264663e+08,1.400000e+00,362878.629538,37.0,OCCCCCN1[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)[C@@H...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, C, C, C, C, C, N, 1, [, C, ?, H, ], (, C, ...","([700, 223, 1769, 833, 144, 265, 261, 103, 373...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
4,872669.450093,1.264663e+08,1.600000e+00,362878.629538,37.0,CCCCN1[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)[C@@H](...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[C, C, C, C, N, 1, [, C, ?, H, ], (, C, c, 2, ...","([1667, 1769, 833, 144, 265, 1821, 1598, 0, 0,...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
6920,872669.450093,4.000000e-02,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1cccnc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 1069, 138, 202, 179, 92, 454, 258...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
6921,872669.450093,1.000000e-01,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1ccncc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 113, 30, 185, 83, 202, 179, 92, 4...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
6922,872669.450093,2.000000e-02,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1cccnc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 1069, 138, 202, 179, 92, 454, 258...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
6923,872669.450093,3.000000e-02,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1ccncc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 113, 30, 185, 83, 202, 179, 92, 4...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [58]:
class data_process_loader_o(data.Dataset):

    def __init__(self, list_IDs, df, **config):
        'Initialization'
        self.list_IDs = list_IDs
        self.df = df
        self.config = config

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        index = self.list_IDs[index]
        v = self.df.iloc[index]
#        v_d = self.df.iloc[index]['drug_encoding']        
#        v_d = drug_2_embed(v_d)
#        v_p = self.df.iloc[index]['target_encoding']
#        v_p = protein_2_embed(v_p)
        #y = self.labels[index]
        return v

In [57]:
from sklearn.preprocessing import OneHotEncoder
enc_protein = OneHotEncoder().fit(np.array(amino_char).reshape(-1, 1))
enc_drug = OneHotEncoder().fit(np.array(smiles_char).reshape(-1, 1))

def protein_2_embed(x):
	return enc_protein.transform(np.array(x).reshape(-1,1)).toarray().T
def drug_2_embed(x):
	return enc_drug.transform(np.array(x).reshape(-1,1)).toarray().T    

In [None]:
from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)

    def __len__(self):
        return count

In [71]:
train

Unnamed: 0,Kd,IC50,Ki,EC50,Temp,SMILES,Target Sequence,smiles2morgan,trans_drug,drug2emb_encoder,CalculateConjointTriad,trans_protein,protein2emb_encoder
0,872669.450093,1.264663e+08,2.400000e-01,362878.629538,37.0,COc1cc2c(Nc3ccc(Br)cc3F)ncnc2cc1OCC1CCN(C)CC1,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[C, O, c, 1, c, c, 2, c, (, N, c, 3, c, c, c, ...","([515, 343, 982, 52, 93, 210, 614, 1244, 690, ...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
1,872669.450093,1.264663e+08,2.500000e-01,362878.629538,37.0,O[C@@H]1[C@@H](O)[C@@H](Cc2ccccc2)N(C\C=C\c2cn...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], 1, [, C, ?, ?, H, ], (, ...","([1138, 186, 144, 265, 199, 188, 381, 1734, 13...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
2,872669.450093,1.264663e+08,4.100000e-01,362878.629538,37.0,O[C@@H]1[C@@H](O)[C@@H](Cc2ccccc2)N(CC2CC2)C(=...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], 1, [, C, ?, ?, H, ], (, ...","([1138, 186, 144, 265, 261, 158, 322, 65, 188,...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
3,872669.450093,1.264663e+08,1.400000e+00,362878.629538,37.0,OCCCCCN1[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)[C@@H...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, C, C, C, C, C, N, 1, [, C, ?, H, ], (, C, ...","([700, 223, 1769, 833, 144, 265, 261, 103, 373...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
4,872669.450093,1.264663e+08,1.600000e+00,362878.629538,37.0,CCCCN1[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)[C@@H](...,PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[C, C, C, C, N, 1, [, C, ?, H, ], (, C, c, 2, ...","([1667, 1769, 833, 144, 265, 1821, 1598, 0, 0,...","[0, 3, 0, 1, 0, 1, 0, 3, 2, 1, 0, 0, 1, 0, 1, ...","[P, Q, I, T, L, W, Q, R, P, L, V, T, I, K, I, ...","([14, 212, 35, 2864, 47, 69, 86, 497, 3636, 21..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
6920,872669.450093,4.000000e-02,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1cccnc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 1069, 138, 202, 179, 92, 454, 258...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
6921,872669.450093,1.000000e-01,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1ccncc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 113, 30, 185, 83, 202, 179, 92, 4...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
6922,872669.450093,2.000000e-02,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1cccnc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 1069, 138, 202, 179, 92, 454, 258...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
6923,872669.450093,3.000000e-02,1.941028e+06,362878.629538,30.0,O[C@@H](C[C@@H](Cc1ccncc1)C(=O)N[C@@H]1[C@H](O...,pqitlwkrplvtikiggqlkealldtgaddtvleemnlpgrwkpkm...,"[0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[O, [, C, ?, ?, H, ], (, C, [, C, ?, ?, H, ], ...","([1245, 144, 113, 30, 185, 83, 202, 179, 92, 4...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[P, Q, I, T, L, W, K, R, P, L, V, T, I, K, I, ...","([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [69]:
training_generator = torch.utils.data.DataLoader(train, **params)
validation_generator = torch.utils.data.DataLoader(val, **params)

In [68]:
training_generator[1]

TypeError: 'DataLoader' object is not subscriptable

In [70]:
if test is not None:
    info = data_process_loader_o(test.index.values,  test, **config)
    params_test = {'batch_size': BATCH_SIZE,
            'shuffle': False,
            'num_workers': config['num_workers'],
            'drop_last': False,
            'sampler':SequentialSampler(info)}
    params_test['collate_fn'] = DTI.mpnn_collate_func
    testing_generator = data.DataLoader(data_process_loader_o(test.index.values, test, **config), **params_test)

# early stopping
if binary:
    max_auc = 0
else:
    max_MSE = 10000
model_max = copy.deepcopy(model)

valid_metric_record = []
valid_metric_header = ["# epoch"] 
if binary:
    valid_metric_header.extend(["AUROC", "AUPRC", "F1"])
else:
    valid_metric_header.extend(["MSE", "Pearson Correlation", "with p-value", "Concordance Index"])
table = PrettyTable(valid_metric_header)
float2str = lambda x:'%0.4f'%x
if verbose:
    print('--- Go for Training ---')
t_start = time() 
for epo in range(train_epoch):
    for v in enumerate(training_generator):
        if self.target_encoding == 'Transformer':
            v_p = v_p
        else:
            v_p = v_p.float().to(self.device) 
        if self.drug_encoding == "MPNN" or self.drug_encoding == 'Transformer':
            v_d = v_d
        else:
            v_d = v_d.float().to(self.device)                
            #score = self.model(v_d, v_p.float().to(self.device))

        score = self.model(v_d, v_p)
        label = Variable(torch.from_numpy(np.array(label)).float()).to(self.device)

        if self.binary:
            loss_fct = torch.nn.BCELoss()
            m = torch.nn.Sigmoid()
            n = torch.squeeze(m(score), 1)
            loss = loss_fct(n, label)
        else:
            loss_fct = torch.nn.MSELoss()
            n = torch.squeeze(score, 1)
            loss = loss_fct(n, label)
        loss_history.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

        if verbose:
            if (i % 100 == 0):
                t_now = time()
                print('Training at Epoch ' + str(epo + 1) + ' iteration ' + str(i) + \
                    ' with loss ' + str(loss.cpu().detach().numpy())[:7] +\
                    ". Total time " + str(int(t_now - t_start)/3600)[:7] + " hours") 
                ### record total run time

    ##### validate, select the best model up to now 
    with torch.set_grad_enabled(False):
        if self.binary:  
            ## binary: ROC-AUC, PR-AUC, F1, cross-entropy loss
            auc, auprc, f1, loss, logits = self.test_(validation_generator, self.model)
            lst = ["epoch " + str(epo)] + list(map(float2str,[auc, auprc, f1]))
            valid_metric_record.append(lst)
            if auc > max_auc:
                model_max = copy.deepcopy(self.model)
                max_auc = auc   
            if verbose:
                print('Validation at Epoch '+ str(epo + 1) + ' , AUROC: ' + str(auc)[:7] + \
                  ' , AUPRC: ' + str(auprc)[:7] + ' , F1: '+str(f1)[:7] + ' , Cross-entropy Loss: ' + \
                  str(loss)[:7])
        else:  
            ### regression: MSE, Pearson Correlation, with p-value, Concordance Index  
            mse, r2, p_val, CI, logits = self.test_(validation_generator, self.model)
            lst = ["epoch " + str(epo)] + list(map(float2str,[mse, r2, p_val, CI]))
            valid_metric_record.append(lst)
            if mse < max_MSE:
                model_max = copy.deepcopy(self.model)
                max_MSE = mse
            if verbose:
                print('Validation at Epoch '+ str(epo + 1) + ' , MSE: ' + str(mse)[:7] + ' , Pearson Correlation: '\
                 + str(r2)[:7] + ' with p-value: ' + str(p_val)[:7] +' , Concordance Index: '+str(CI)[:7])
    table.add_row(lst)


# load early stopped model
self.model = model_max

#### after training 
prettytable_file = os.path.join(self.result_folder, "valid_markdowntable.txt")
with open(prettytable_file, 'w') as fp:
    fp.write(table.get_string())

if test is not None:
    if verbose:
        print('--- Go for Testing ---')
    if self.binary:
        auc, auprc, f1, loss, logits = self.test_(testing_generator, model_max, test = True)
        test_table = PrettyTable(["AUROC", "AUPRC", "F1"])
        test_table.add_row(list(map(float2str, [auc, auprc, f1])))
        if verbose:
            print('Validation at Epoch '+ str(epo + 1) + ' , AUROC: ' + str(auc)[:7] + \
              ' , AUPRC: ' + str(auprc)[:7] + ' , F1: '+str(f1)[:7] + ' , Cross-entropy Loss: ' + \
              str(loss)[:7])				
    else:
        mse, r2, p_val, CI, logits = self.test_(testing_generator, model_max)
        test_table = PrettyTable(["MSE", "Pearson Correlation", "with p-value", "Concordance Index"])
        test_table.add_row(list(map(float2str, [mse, r2, p_val, CI])))
        if verbose:
            print('Testing MSE: ' + str(mse) + ' , Pearson Correlation: ' + str(r2) 
              + ' with p-value: ' + str(p_val) +' , Concordance Index: '+str(CI))
    np.save(os.path.join(self.result_folder, str(self.drug_encoding) + '_' + str(self.target_encoding) 
             + '_logits.npy'), np.array(logits))                

    ######### learning record ###########

    ### 1. test results
    prettytable_file = os.path.join(self.result_folder, "test_markdowntable.txt")
    with open(prettytable_file, 'w') as fp:
        fp.write(test_table.get_string())

### 2. learning curve 
fontsize = 16
iter_num = list(range(1,len(loss_history)+1))
plt.figure(3)
plt.plot(iter_num, loss_history, "bo-")
plt.xlabel("iteration", fontsize = fontsize)
plt.ylabel("loss value", fontsize = fontsize)
pkl_file = os.path.join(self.result_folder, "loss_curve_iter.pkl")
with open(pkl_file, 'wb') as pck:
    pickle.dump(loss_history, pck)

fig_file = os.path.join(self.result_folder, "loss_curve.png")
plt.savefig(fig_file)
if verbose:
    print('--- Training Finished ---')

--- Go for Training ---


KeyError: 588