In [1]:
import os 
import glob
from tqdm import tqdm
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
from math import sqrt
import pickle
from multiprocessing import Pool

import numpy as np 
import matplotlib.pyplot as plt 

In [2]:
def prepare_dataset(
    feature_path,
    descriptor_path="./Features/RDKit_Discriptors.csv",
    binding_path="./Features/BindingData.csv",
):
    """Assemble train/test predictors and targets for one feature CSV.

    The feature table is joined on the "PDB" column with the ligand
    descriptor table and then with the binding table (pandas' default
    inner join, so rows missing from any table are dropped). Rows are
    split by the value of the "SET" column and the target is "pK".

    Parameters
    ----------
    feature_path : str or path-like
        CSV holding a "PDB" column plus the feature columns.
    descriptor_path : str or path-like, optional
        CSV holding a "PDB" column plus ligand descriptor columns.
        Defaults to the path this function previously hard-coded.
    binding_path : str or path-like, optional
        CSV holding the "PDB", "SET" and "pK" columns.
        Defaults to the path this function previously hard-coded.

    Returns
    -------
    tuple
        ``(x_train, y_train, x_test, y_test)`` where the ``x_*`` values are
        DataFrames and the ``y_*`` values are the matching "pK" Series for
        rows whose "SET" is "Train" / "Test" respectively.

    Raises
    ------
    KeyError
        If "PDB", "SET" or "pK" is missing from the merged table.
    """
    feature = pd.read_csv(feature_path)
    ligand_descriptors = pd.read_csv(descriptor_path)
    binding_data = pd.read_csv(binding_path)

    # Both joins use the shared "PDB" key, so the result keeps a single "PDB" column.
    merged = feature.merge(ligand_descriptors, on="PDB").merge(binding_data, on="PDB")

    # Predictors are picked by position: everything except the first column
    # and the last two.
    # NOTE(review): this assumes "PDB" is the first column of the feature CSV
    # and that the binding CSV contributes exactly "SET" and "pK" as the
    # trailing columns - confirm against the data files.
    predictor_columns = list(merged.columns)[1:-2]

    # Compute each split mask once and reuse it for predictors and target.
    is_train = merged["SET"] == "Train"
    is_test = merged["SET"] == "Test"

    x_train = merged.loc[is_train, predictor_columns]
    y_train = merged.loc[is_train, "pK"]
    x_test = merged.loc[is_test, predictor_columns]
    y_test = merged.loc[is_test, "pK"]
    return x_train, y_train, x_test, y_test

### Model training with multi-shelled ECIF feature

In [3]:
# Build the train/test predictors and targets from the multi-shelled ECIF feature file.
feature_path = './Features/MSECIFv2/MSECIFv2_thresh10.0_step2.0.csv'
(x_train, y_train, x_test, y_test) = prepare_dataset(feature_path=feature_path)


In [4]:
# Display the training predictors returned by prepare_dataset as this cell's output.
x_train

Unnamed: 0,C;4;1;3;0-Br;1;1;0;0-2.5,C;4;1;3;0-C;3;3;0;1-2.5,C;4;1;3;0-C;4;1;1;0-2.5,C;4;1;3;0-C;4;1;2;0-2.5,C;4;1;3;0-C;4;1;3;0-2.5,C;4;1;3;0-C;4;2;0;0-2.5,C;4;1;3;0-C;4;2;1;0-2.5,C;4;1;3;0-C;4;2;1;1-2.5,C;4;1;3;0-C;4;2;2;0-2.5,C;4;1;3;0-C;4;2;2;1-2.5,...,fr_quatN,fr_sulfide,fr_sulfonamd,fr_sulfone,fr_term_acetylene,fr_tetrazole,fr_thiazole,fr_thiocyan,fr_thiophene,fr_urea
0,0,0,0,0,0,0,0,0,0,0,...,1,1,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,1,1,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9579,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9580,0,0,0,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0
9581,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9582,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
# Gradient-boosted regressor fitted on the current training split
# (multi-shelled ECIF features at this point in the notebook).
seed = 3094
params = {
    "n_estimators": 20000,
    "max_features": "sqrt",
    "max_depth": 10,
    "min_samples_split": 3,
    "learning_rate": 0.005,
    "loss": "squared_error",
    "subsample": 0.6,
}
GBT = GradientBoostingRegressor(random_state=seed, **params)
# Last expression: fit() returns the estimator, which the cell displays.
GBT.fit(x_train, y_train)

GradientBoostingRegressor(learning_rate=0.005, max_depth=10,
                          max_features='sqrt', min_samples_split=3,
                          n_estimators=20000, random_state=3094, subsample=0.6)

In [6]:
# Score the fitted model on the test split: RMSE and Pearson correlation
# between the measured and predicted pK values.
y_pred = GBT.predict(x_test)
rmse = sqrt(mean_squared_error(y_test, y_pred))
pcc = pearsonr(y_test, y_pred)[0]  # element 0 is the correlation coefficient
print("Pearson correlation coefficient for GBT: ", pcc)
print("RMSE for GBT:", rmse)

Pearson correlation coefficient for GBT:  0.8768865160831968
RMSE for GBT: 1.1516889227460871


### Model training with weighted ECIF feature

In [7]:
# Rebuild the train/test predictors and targets from the weighted ECIF feature file.
# This overwrites the x_train / y_train / x_test / y_test names set earlier in the notebook.
feature_path = './Features/WECIFv2/WECIFv2_thresh10.0_squaredTrue.csv'
(x_train, y_train, x_test, y_test) = prepare_dataset(feature_path=feature_path)

In [8]:
# Gradient-boosted regressor fitted on the current training split
# (weighted ECIF features at this point in the notebook).
# Rebinds seed, params and GBT from the earlier training cell.
seed = 3439
params = {
    "n_estimators": 30000,
    "max_features": "sqrt",
    "max_depth": 10,
    "min_samples_split": 2,
    "learning_rate": 0.005,
    "loss": "squared_error",
    "subsample": 0.6,
}
GBT = GradientBoostingRegressor(random_state=seed, **params)
# Last expression: fit() returns the estimator, which the cell displays.
GBT.fit(x_train, y_train)

GradientBoostingRegressor(learning_rate=0.005, max_depth=10,
                          max_features='sqrt', n_estimators=30000,
                          random_state=3439, subsample=0.6)

In [9]:
# Score the refitted model on the test split: RMSE and Pearson correlation
# between the measured and predicted pK values.
y_pred = GBT.predict(x_test)
rmse = sqrt(mean_squared_error(y_test, y_pred))
pcc = pearsonr(y_test, y_pred)[0]  # element 0 is the correlation coefficient
print("Pearson correlation coefficient for GBT: ", pcc)
print("RMSE for GBT:", rmse)

Pearson correlation coefficient for GBT:  0.8679858671372087
RMSE for GBT: 1.176307355878775
