In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder

# PREDICTING FUNCTION

In [2]:
df = pd.read_csv('data/364_interaction_energies_state_function_v3.txt', sep='\t')

#drop index
df.drop('index', axis = 1, inplace=True)

#copy df to structure_df
structure_df = df.copy()
structure_df.drop(['PDBID','State'], axis = 1, inplace = True)

In [3]:
structure_df['Function'].value_counts()

Antagonist           183
Agonist              148
Inverse agonist       23
Agonist (partial)     10
Name: Function, dtype: int64

In [4]:
# get columns with 'sum' in their name
sum_cols = [col for col in structure_df.columns if 'sum' in col]

# create empty list for residue numbers
resnums = []

# loop through sum columns and count interactions that don't have nonzero energies
for col in sum_cols:
    resnum = col[:4]
    resnums.append(resnum)

# drop columns from df in which > 10% of entries are NaN
for resnum in resnums:
    intenergysum_col = resnum + '_intenergysum'
    inttype1_col = resnum + '_inttype1'
    intenergy1_col = resnum + '_intenergy1'
    inttype2_col = resnum + '_inttype2'
    intenergy2_col = resnum + '_intenergy2'
    
    print('structures with interactions at position', resnum, ':', structure_df[intenergysum_col][structure_df[intenergysum_col] != 0].value_counts().sum())
    if structure_df[intenergysum_col][structure_df[intenergysum_col] != 0].value_counts().sum() < 10:
        structure_df.drop([intenergysum_col, inttype1_col, intenergy1_col, inttype2_col, intenergy2_col], axis = 1, inplace = True)
        print('dropped columns for residue: ', resnum, '\n')

#     # for intenergysum columns
#     if 'intenergysum' in col:
#         if df[col][df[col] != 0.0].value_counts().sum() < (0.20 * len(df)):
#             df.drop([col], axis = 1, inplace = True)
#             print('dropped column: ', col)

structures with interactions at position 1.21 : 0
dropped columns for residue:  1.21 

structures with interactions at position 1.22 : 0
dropped columns for residue:  1.22 

structures with interactions at position 1.23 : 0
dropped columns for residue:  1.23 

structures with interactions at position 1.24 : 0
dropped columns for residue:  1.24 

structures with interactions at position 1.25 : 0
dropped columns for residue:  1.25 

structures with interactions at position 1.26 : 0
dropped columns for residue:  1.26 

structures with interactions at position 1.27 : 1
dropped columns for residue:  1.27 

structures with interactions at position 1.28 : 0
dropped columns for residue:  1.28 

structures with interactions at position 1.29 : 0
dropped columns for residue:  1.29 

structures with interactions at position 1.30 : 2
dropped columns for residue:  1.30 

structures with interactions at position 1.31 : 5
dropped columns for residue:  1.31 

structures with interactions at position 1.

dropped columns for residue:  4.42 

structures with interactions at position 4.43 : 0
dropped columns for residue:  4.43 

structures with interactions at position 4.44 : 0
dropped columns for residue:  4.44 

structures with interactions at position 4.45 : 1
dropped columns for residue:  4.45 

structures with interactions at position 4.46 : 8
dropped columns for residue:  4.46 

structures with interactions at position 4.47 : 2
dropped columns for residue:  4.47 

structures with interactions at position 4.48 : 0
dropped columns for residue:  4.48 

structures with interactions at position 4.49 : 0
dropped columns for residue:  4.49 

structures with interactions at position 4.50 : 12
structures with interactions at position 4.51 : 0
dropped columns for residue:  4.51 

structures with interactions at position 4.52 : 1
dropped columns for residue:  4.52 

structures with interactions at position 4.53 : 0
dropped columns for residue:  4.53 

structures with interactions at position 4

structures with interactions at position 7.44 : 0
dropped columns for residue:  7.44 

structures with interactions at position 7.45 : 1
dropped columns for residue:  7.45 

structures with interactions at position 7.46 : 7
dropped columns for residue:  7.46 

structures with interactions at position 7.47 : 1
dropped columns for residue:  7.47 

structures with interactions at position 7.48 : 0
dropped columns for residue:  7.48 

structures with interactions at position 7.49 : 0
dropped columns for residue:  7.49 

structures with interactions at position 7.50 : 0
dropped columns for residue:  7.50 

structures with interactions at position 7.51 : 0
dropped columns for residue:  7.51 

structures with interactions at position 7.52 : 0
dropped columns for residue:  7.52 

structures with interactions at position 7.53 : 0
dropped columns for residue:  7.53 

structures with interactions at position 7.54 : 0
dropped columns for residue:  7.54 

structures with interactions at position 7.

In [5]:
structure_df

Unnamed: 0,Function,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,Agonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,0.0,,0.0,,0.0
1,Agonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,0.0,,0.0,,0.0
2,Agonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,0.0,,0.0,,0.0
3,Agonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,0.0,,0.0,,0.0
4,Agonist,0.0,,0.0,,0.0,0.0,,0.0,,...,-0.1,Hbond,-0.1,,0.0,-0.6,Hbond,-0.5,Hbond,-0.1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
359,Agonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,-0.1,Arene,-0.1,,
360,Antagonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,0.0,,0.0,,0.0
361,Antagonist,0.0,,0.0,,0.0,0.0,,0.0,,...,0.0,,0.0,,0.0,-0.2,Arene,-0.2,,
362,Antagonist,0.0,Distance,0.0,Distance,0.0,-5.2,Hbond,-5.2,,...,0.0,,0.0,,0.0,0.0,,0.0,,0.0


In [6]:
structure_df1 = structure_df.replace('Agonist (partial)', 'Agonist')

In [7]:
structure_df2 = structure_df1.replace('Inverse agonist', 'Antagonist')

In [8]:
actual_fxns = structure_df['Function']
actual_fxns1 = structure_df1['Function']
actual_fxns2 = structure_df2['Function']

In [9]:
# label encoding
def encode_labels(df):
    # create instance of labelencoder
    encoder = OrdinalEncoder()
    le = LabelEncoder()

    cols = [col for col in df.columns if 'type' in col]

    # loop though all columns and convert strings to categorical integer variables
    for col in cols:
        df[col] = encoder.fit_transform(np.array(df[col].tolist()).reshape(-1, 1))


    # encode states as integers
    # get columns with 'type' in their name
    cols = [col for col in df.columns if 'Function' in col]

    # loop though all columns and convert strings to categorical integer variables
    for col in cols:
        df[col] = le.fit_transform(df[col])
        
    return (df, le)

In [10]:
(structure_df, encoder) = encode_labels(structure_df)
(structure_df1, encoder1) = encode_labels(structure_df1)
(structure_df2, encoder2) = encode_labels(structure_df2)

In [11]:
encoder.classes_

array(['Agonist', 'Agonist (partial)', 'Antagonist', 'Inverse agonist'],
      dtype=object)

In [12]:
# assign target classes to y
y = structure_df['Function']
y1 = structure_df1['Function']
y2 = structure_df2['Function']

# assign data to X
X = structure_df.drop(['Function'], axis = 1)
X1 = structure_df1.drop(['Function'], axis = 1)
X2 = structure_df2.drop(['Function'], axis = 1)

# create actual_state column with non-encoded states
X['actual_fxn'] = actual_fxns
X1['actual_fxn'] = actual_fxns1
X2['actual_fxn'] = actual_fxns2

In [13]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
X_train1, X_test1, y_train1, y_test1 = train_test_split(X1, y1, test_size=0.25, random_state=42)
X_train2, X_test2, y_train2, y_test2 = train_test_split(X2, y2, test_size=0.25, random_state=42)

In [14]:
def scale_impute(dataframe):
    # get colnames
    colnames = list(dataframe.drop(['actual_fxn'], axis = 1).columns)
    fxns_df = dataframe['actual_fxn']
    #state_df.reset_index(inplace=True)
    df = dataframe.drop(['actual_fxn'], axis = 1)

    # impute data
    from sklearn.impute import SimpleImputer
    my_imputer = SimpleImputer()
    df_imputed = pd.DataFrame(my_imputer.fit_transform(df))

    # scale data
    scaler = StandardScaler()
    to_scale = [col for col in df_imputed.columns.values]
    scaler.fit(df_imputed[to_scale])

    # predict z-scores on the test set
    df_imputed[to_scale] = scaler.transform(df_imputed[to_scale]) 

    # #rename columns
    df_imputed.columns = colnames

    # display scaled values
    display(df_imputed)
    
    return(df_imputed, fxns_df)

In [15]:
(X_train_imputed, X_train_fxns) = scale_impute(X_train)
(X_test_imputed, X_test_fxns) = scale_impute(X_test)

(X_train_imputed1, X_train_fxns1) = scale_impute(X_train1)
(X_test_imputed1, X_test_fxns1) = scale_impute(X_test1)

(X_train_imputed2, X_train_fxns2) = scale_impute(X_train2)
(X_test_imputed2, X_test_fxns2) = scale_impute(X_test2)

Unnamed: 0,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,1.39_intenergy2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.048200,-2.463451,0.010800,1.175080,0.000000
1,1.550301e-01,-3.479646,0.155158,-5.264982,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,-2.839074,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
2,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
3,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.048200,-0.458914,0.010800,1.175080,0.000000
4,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
268,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
269,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
270,2.185092e-17,1.948602,0.000000,2.504424,0.000000,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
271,1.550301e-01,-3.479646,0.155158,-0.085378,0.065437,0.135785,-5.953211,0.13098,-0.071398,0.081763,...,-0.483218,-1.207353,-0.667050,2.082685,0.000000,0.048200,-2.463451,0.010800,1.175080,0.000000


Unnamed: 0,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,1.39_intenergy2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
1,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,-6.690516,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
2,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,-3.587952,-2.577757,-4.378760,1.268804,0.000000
3,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
4,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,-1.553408,0.276230,0.147822,0.142358
87,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
88,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,-2.494438,-1.708078,-2.608058,2.182353,2.431311e-17,-5.139039,-0.529059,-0.654768,-4.336107,-9.196330
89,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358


Unnamed: 0,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,1.39_intenergy2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.048200,-2.463451,0.010800,1.175080,0.000000
1,1.550301e-01,-3.479646,0.155158,-5.264982,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,-2.839074,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
2,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
3,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.048200,-0.458914,0.010800,1.175080,0.000000
4,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
268,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
269,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
270,2.185092e-17,1.948602,0.000000,2.504424,0.000000,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
271,1.550301e-01,-3.479646,0.155158,-0.085378,0.065437,0.135785,-5.953211,0.13098,-0.071398,0.081763,...,-0.483218,-1.207353,-0.667050,2.082685,0.000000,0.048200,-2.463451,0.010800,1.175080,0.000000


Unnamed: 0,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,1.39_intenergy2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
1,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,-6.690516,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
2,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,-3.587952,-2.577757,-4.378760,1.268804,0.000000
3,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
4,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,-1.553408,0.276230,0.147822,0.142358
87,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
88,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,-2.494438,-1.708078,-2.608058,2.182353,2.431311e-17,-5.139039,-0.529059,-0.654768,-4.336107,-9.196330
89,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358


Unnamed: 0,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,1.39_intenergy2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.048200,-2.463451,0.010800,1.175080,0.000000
1,1.550301e-01,-3.479646,0.155158,-5.264982,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,-2.839074,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
2,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
3,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.048200,-0.458914,0.010800,1.175080,0.000000
4,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
268,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
269,1.550301e-01,0.139186,0.155158,-0.085378,0.065437,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
270,2.185092e-17,1.948602,0.000000,2.504424,0.000000,0.135785,0.133780,0.13098,-0.071398,0.081763,...,0.276524,0.424367,0.244697,0.073593,0.209466,0.151177,0.543354,0.162001,0.049477,0.148476
271,1.550301e-01,-3.479646,0.155158,-0.085378,0.065437,0.135785,-5.953211,0.13098,-0.071398,0.081763,...,-0.483218,-1.207353,-0.667050,2.082685,0.000000,0.048200,-2.463451,0.010800,1.175080,0.000000


Unnamed: 0,1.35_intenergysum,1.35_inttype1,1.35_intenergy1,1.35_inttype2,1.35_intenergy2,1.39_intenergysum,1.39_inttype1,1.39_intenergy1,1.39_inttype2,1.39_intenergy2,...,7.42_intenergysum,7.42_inttype1,7.42_intenergy1,7.42_inttype2,7.42_intenergy2,7.43_intenergysum,7.43_inttype1,7.43_intenergy1,7.43_inttype2,7.43_intenergy2
0,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
1,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,-6.690516,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
2,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,-3.587952,-2.577757,-4.378760,1.268804,0.000000
3,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
4,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,-1.553408,0.276230,0.147822,0.142358
87,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358
88,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,-2.494438,-1.708078,-2.608058,2.182353,2.431311e-17,-5.139039,-0.529059,-0.654768,-4.336107,-9.196330
89,0.15688,0.089757,0.148654,-0.066446,0.111574,0.150272,0.228086,0.150272,-0.214423,0.0,...,0.207870,0.364390,0.190683,0.046932,1.955659e-01,0.289763,0.495290,0.276230,0.147822,0.142358


In [16]:
X_train_fxns = X_train_fxns.reset_index()
X_train_fxns.drop(['index'], axis = 1, inplace = True)

X_test_fxns = X_test_fxns.reset_index()
X_test_fxns.drop(['index'], axis = 1, inplace = True)

X_train_fxns1 = X_train_fxns1.reset_index()
X_train_fxns1.drop(['index'], axis = 1, inplace = True)

X_test_fxns1 = X_test_fxns1.reset_index()
X_test_fxns1.drop(['index'], axis = 1, inplace = True)

X_train_fxns2 = X_train_fxns2.reset_index()
X_train_fxns2.drop(['index'], axis = 1, inplace = True)

X_test_fxns2 = X_test_fxns2.reset_index()
X_test_fxns2.drop(['index'], axis = 1, inplace = True)

In [17]:
def train_test_predict(train_df, train_y, test_df, test_y, encoder):
    #Import Random Forest Model
    from sklearn.ensemble import RandomForestClassifier

    #Create a Gaussian Classifier
    clf=RandomForestClassifier(n_estimators=100, random_state=1)

    #Train the model using the training sets y_pred=clf.predict(X_test)
    clf.fit(train_df, train_y)
    
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_score, KFold
    import numpy as np

    # cross-validation
    scores = cross_val_score(clf, train_df, train_y, cv=5)
    print("Mean cross-validation score: %.2f" % scores.mean())

    # k-fold CV
    kfold = KFold(n_splits=10, shuffle=True, random_state = 1)
    kf_cv_scores = cross_val_score(clf, train_df, train_y, cv=kfold)
    print("K-fold CV average score: %.2f" % kf_cv_scores.mean())
    
    # test set predictions
    y_pred = clf.predict(test_df)

    #Import scikit-learn metrics module for accuracy calculation
    from sklearn import metrics

    # reverse label encoding
    y_pred_actual = encoder.inverse_transform(y_pred)
    y_test_actual = encoder.inverse_transform(test_y)

    data = {'y_Actual':    y_test_actual,
            'y_Predicted': y_pred_actual
            }

    df = pd.DataFrame(data, columns=['y_Actual','y_Predicted'])

    confusion_matrix = pd.crosstab(df['y_Actual'], df['y_Predicted'], rownames=['Actual'], colnames=['Predicted'])
    print ('\n', confusion_matrix, '\n')
    
     # Model Accuracy, how often is the classifier correct?
    acc = metrics.accuracy_score(test_y, y_pred)
    precision = metrics.precision_score(test_y, y_pred, average = 'weighted', labels=np.unique(y_pred))
    recall = metrics.recall_score(test_y, y_pred, average = 'weighted', labels=np.unique(y_pred))
    print("Accuracy:","{:.2f}".format(acc))
    print("Precision:","{:.2f}".format(precision))
    print("Recall:","{:.2f}".format(recall), '\n')

In [18]:
train_test_predict(X_train_imputed, y_train, X_test_imputed, y_test, encoder)

Mean cross-validation score: 0.77
K-fold CV average score: 0.75

 Predicted          Agonist  Antagonist  Inverse agonist
Actual                                                 
Agonist                 31           3                0
Agonist (partial)        3           1                0
Antagonist              11          34                0
Inverse agonist          0           6                2 

Accuracy: 0.74
Precision: 0.76
Recall: 0.77 



In [19]:
train_test_predict(X_train_imputed1, y_train1, X_test_imputed1, y_test1, encoder1)

Mean cross-validation score: 0.76
K-fold CV average score: 0.76

 Predicted        Agonist  Antagonist  Inverse agonist
Actual                                               
Agonist               34           4                0
Antagonist            11          34                0
Inverse agonist        0           7                1 

Accuracy: 0.76
Precision: 0.78
Recall: 0.76 



In [20]:
train_test_predict(X_train_imputed2, y_train2, X_test_imputed2, y_test2, encoder2)

Mean cross-validation score: 0.78
K-fold CV average score: 0.79

 Predicted   Agonist  Antagonist
Actual                         
Agonist          32           6
Antagonist        6          47 

Accuracy: 0.87
Precision: 0.87
Recall: 0.87 



## XGBoost

In [60]:
import xgboost as xgb

xgbc = xgb.XGBClassifier(use_label_encoder=False,
                         eval_metric='mlogloss',
                         n_estimators=500,
                         random_state = 1,
                         learning_rate = 0.05
                        )

In [61]:
xgbc

XGBClassifier(base_score=None, booster=None, colsample_bylevel=None,
              colsample_bynode=None, colsample_bytree=None,
              enable_categorical=False, eval_metric='mlogloss', gamma=None,
              gpu_id=None, importance_type=None, interaction_constraints=None,
              learning_rate=0.05, max_delta_step=None, max_depth=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              n_estimators=500, n_jobs=None, num_parallel_tree=None,
              predictor=None, random_state=1, reg_alpha=None, reg_lambda=None,
              scale_pos_weight=None, subsample=None, tree_method=None,
              use_label_encoder=False, validate_parameters=None,
              verbosity=None)

In [62]:
xgbc.fit(X_train_imputed, y_train)

from sklearn.model_selection import cross_val_score, KFold

# cross-validation
scores = cross_val_score(xgbc, X_train_imputed, y_train, cv=5)
print("Mean cross-validation score: %.2f" % scores.mean())

# k-fold CV
kfold = KFold(n_splits=10, shuffle=True)
kf_cv_scores = cross_val_score(xgbc, X_train_imputed, y_train, cv=kfold )
print("K-fold CV average score: %.2f" % kf_cv_scores.mean())

from sklearn.metrics import confusion_matrix

y_pred = xgbc.predict(X_test_imputed)

print("Accuracy:",metrics.accuracy_score(y_test, y_pred), '\n')

# reverse label encoding
y_pred_actual = labelencoder.inverse_transform(y_pred)
y_test_actual = labelencoder.inverse_transform(y_test)

data = {'y_Actual':    y_test_actual,
        'y_Predicted': y_pred_actual
        }

df = pd.DataFrame(data, columns=['y_Actual','y_Predicted'])

confusion_matrix = pd.crosstab(df['y_Actual'], df['y_Predicted'], rownames=['Actual'], colnames=['Predicted'])
print (confusion_matrix)

Mean cross-validation score: 0.71
K-fold CV average score: 0.73
Accuracy: 0.6593406593406593 

Predicted          Agonist  Antagonist  Inverse agonist
Actual                                                 
Agonist                 21          13                0
Agonist (partial)        2           2                0
Antagonist               8          37                0
Inverse agonist          0           6                2


In [63]:
len(X_train)

273

In [64]:
len(X_test)

91