In [52]:
import torch 
import json,pickle,math
import pandas as pd
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


In [53]:
full_df = pd.read_csv(open('../kiba_all_pairs.csv','r'))

In [54]:
all_3_folds={}
for i in [0,1,2]:
    file_name = 'fold' +str(i)

    temp = open('../data/kiba/KIBA_3_FOLDS/' + file_name +'.pkl', 'rb')
    new_df = pd.read_pickle(temp)
    all_3_folds.update({file_name:new_df})
    temp.close()
        

In [55]:
# all_3_folds['fold2']

In [56]:
def create_davis_test_train(test_fold_number,all_3_folds):
    
    test_set = pd.DataFrame(columns = full_df.columns)
    train_set = pd.DataFrame(columns= full_df.columns)
    for i in [0,1,2]:
        fold_name = 'fold' + str(i) 
        df = all_3_folds[fold_name]

        if str(i) == test_fold_number:
            test_set = df.copy()

        if str(i) != test_fold_number:
            train_set = pd.concat([train_set, df.copy()], ignore_index=True)

                
    return train_set, test_set


# Create train test split on these 3 folds
## fold_number is the id of fold. For example, test = fold0, train = fold 1,2

In [57]:
fold_number = '0'

In [58]:
train, test = create_davis_test_train(test_fold_number=fold_number, all_3_folds=all_3_folds)

In [59]:
test = test[['SMILES','Target Sequence','Label']]
train = train[['SMILES','Target Sequence','Label']]

# train =train.sample(100)

In [60]:
# test

# Creating similarity matrices for this fold

In [61]:
import rdkit
from rdkit.Chem import AllChem as Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs import FingerprintSimilarity as fs
from rdkit.Chem.Fingerprints import FingerprintMols
from Bio import pairwise2

In [62]:
train_targets = list(set(list(train['Target Sequence'])))
train_smiles = list(set(list(train['SMILES'])))

def computeLigandSimilarity(smiles):
    fingerprints = {}
    for smile in smiles:
        mol = AllChem.MolFromSmiles(smile)
        if mol == None:
            mol = AllChem.MolFromSmiles(smile, sanitize=False)
        fp = FingerprintMols.FingerprintMol(mol)
        fingerprints[smile] = fp
    
    n = len(smiles)
    sims = np.zeros((n, n))
    for i in range(n):
        for j in range(i+1):
            fpi = fingerprints[smiles[i]]
            fpj = fingerprints[smiles[j]]
            sim = fs(fpi, fpj)
            sims[i, j] = sims[j, i] = sim
    return sims

def computeProteinSimilarity(targets):
    n = len(targets)
    mat = np.zeros((n,n))
    mat_i = np.zeros(n)
    for i in range(n):
        seq = targets[i]
        s = pairwise2.align.localxx(seq,seq, score_only=True)
        mat_i[i] = s
        
    for i in range(n):
        print(i)
        for j in range(n):
            if mat[i][j] == 0 :
                s1 = targets[i]
                s2 = targets[j]
                sw_ij = pairwise2.align.localxx(s1,s2,score_only=True)
                normalized_score = sw_ij /math.sqrt(mat_i[i]*mat_i[j])
                mat[i][j] = mat[j][i] = normalized_score
    
    return mat

In [63]:
ligand_similarity_matrix = computeLigandSimilarity(train_smiles)

In [64]:
np.shape(ligand_similarity_matrix)

(2060, 2060)

In [65]:
# ligand_similarity_matrix

In [66]:
print(len(train_targets))
protein_similarity_matrix = computeProteinSimilarity(train_targets)

140
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139


In [67]:
np.shape(protein_similarity_matrix)

(140, 140)

In [68]:
LSM = ligand_similarity_matrix
PSM = protein_similarity_matrix

# Creating similarity matrcies for test set

In [69]:
test_targets = list(set(list(test['Target Sequence'])))
test_smiles = list(set(list(test['SMILES'])))

In [70]:
test_PSM = np.zeros((len(test_targets), len(train_targets)))
np.shape(test_PSM)

(89, 140)

In [71]:
len(test_targets)

89

In [72]:
s_train_PSM = np.zeros(len(train_targets))
s_test_PSM = np.zeros(len(test_targets))

for i in range(len(train_targets)):
    seq = train_targets[i]
    s_train_PSM[i] = pairwise2.align.localxx(seq,seq, score_only=True)
    
for i in range(len(test_targets)):
    seq = test_targets[i]
    s_test_PSM[i] = pairwise2.align.localxx(seq,seq, score_only=True)
    
for i in range(len(test_targets)):
    print(i)
    for j in range(len(train_targets)):
        seq1 = test_targets[i]
        seq2 = train_targets[j]
        s_ij=pairwise2.align.localxx(seq1, seq2, score_only=True)
        N_S = s_ij / math.sqrt(s_train_PSM[j] * s_test_PSM[i])
        test_PSM[i][j] = N_S

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88


In [73]:
test_LSM = np.zeros((len(test_smiles), len(train_smiles)))
np.shape(test_LSM)

(1867, 2060)

In [74]:
test_smi_fp=[]
train_smi_fp = []

for smi in test_smiles:
    mol1 = AllChem.MolFromSmiles(smi)
    if mol1 == None:
        mol1= AllChem.MolFromSmiles(smi, sanitize=False)
    fp1 = FingerprintMols.FingerprintMol(mol1)
    test_smi_fp.append(fp1)

for smi in train_smiles:
    mol1 = AllChem.MolFromSmiles(smi)
    if mol1 == None:
        mol1= AllChem.MolFromSmiles(smi, sanitize=False)
    fp1 = FingerprintMols.FingerprintMol(mol1)
    train_smi_fp.append(fp1)

In [75]:
# len(train_smi_fp)

In [76]:
for i in range(len(test_smiles)):
    print(i)
    for j in range(len(train_smiles)):
        smi1 = test_smiles[i]
        smi2 = train_smiles[j]
        
        test_LSM[i][j] = fs(test_smi_fp[i], train_smi_fp[j])
#         if i==j:
#             print(i)
#             print(test_LSM[i][j])

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [77]:
# 
test_LSM

array([[1.        , 0.2812983 , 0.36468984, ..., 0.34758995, 0.30672926,
        0.32313433],
       [0.2812983 , 1.        , 0.30551627, ..., 0.29166667, 0.26377295,
        0.30825243],
       [0.34575163, 0.29897611, 0.404375  , ..., 0.40605296, 0.30716724,
        0.38030096],
       ...,
       [0.31189948, 0.26484375, 0.36699164, ..., 0.39198856, 0.25558122,
        0.27781872],
       [0.35151934, 0.28826896, 0.35673624, ..., 0.34939759, 0.30337886,
        0.30874317],
       [0.32313433, 0.30825243, 0.34869326, ..., 0.33884298, 0.28368794,
        1.        ]])

In [78]:
# train_smiles[-1]

In [79]:
# test_smiles[-1]

In [80]:
# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
# Hyper parameters
num_epochs = 20
# num_classes = 10
batch_size = 24
learning_rate = 0.001

In [81]:
class custom_dataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, smiles,smi_index ,targets, target_index, LSM,PSM,transform=None):
        self.df = dataframe
#         self.root_dir = root_dir
        self.smiles =smiles
        self.targets = targets
        self.LSM = LSM
        self.PSM = PSM
        self.transform = transform
        self.smi_index=smi_index
        self.target_index=target_index
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        smi = self.df.iloc[idx]['SMILES']
        seq = self.df.iloc[idx]['Target Sequence']
#         s_i = self.smiles.index(smi)
#         t_i = self.targets.index(seq)
        s_i = self.smi_index[smi]
        t_i = self.target_index[seq]
        
        ki=self.LSM[s_i]
        kj=self.PSM[t_i]
        
        ki_x_kj = np.outer(ki,kj)
        ki_x_kj = torch.tensor([ki_x_kj])
        output = {'outer_product': ki_x_kj , 'Label':self.df.iloc[idx]['Label']}
        return output

In [82]:
test_index_smi={}
for i in test_smiles:
    test_index_smi[i]=test_smiles.index(i)
train_index_smi={}
for i in train_smiles:
    train_index_smi[i]=train_smiles.index(i)

In [83]:
test_index_seq={}
for i in test_targets:
    test_index_seq[i]=test_targets.index(i)

train_index_seq={}
for i in train_targets:
    train_index_seq[i]=train_targets.index(i)

In [84]:
train_dataset = custom_dataset(dataframe=train, smiles=train_smiles, smi_index=train_index_smi, targets = train_targets, target_index=train_index_seq, LSM=LSM,PSM=PSM)


In [85]:
test_dataset = custom_dataset(dataframe=test, smiles=test_smiles,smi_index=test_index_smi, targets = test_targets, target_index=test_index_seq, LSM=test_LSM,PSM=test_PSM)


In [86]:
# 68117/32

In [87]:
train_loader= torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader= torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [88]:
print(len(train_loader)*batch_size +  len(test_loader)*batch_size)

118296


In [89]:
# for i in test_loader:
#     a = i['outer_product']
#     b= i['Label']
#     break
# # print(a)
# conv1 = nn.Conv2d(1,32,5).double()
# pool = nn.MaxPool2d(2,2).double()
# conv2 = nn.Conv2d(32,18,3).double()
# fc1 = nn.Linear(18*513*33, 128).double()
# fc2 = nn.Linear(128,1).double()
# dropout = nn.Dropout(0.1).double()
# x= conv1(a)
# print(x.shape)
# x = pool(x)
# print(x.shape)
# x= conv2(x)
# print(x.shape)
# x = pool(x)
# print(x.shape)
# x = x.view(-1,18*513*33)
# print(x.shape)
# x = dropout(x)
# print(x.shape)
# x = fc1(x)
# print(x.shape)
# x = fc2(x)
# print(x.shape)

In [90]:
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1,32, 5).double()
        self.pool1 = nn.MaxPool2d(2,2).double()
        self.conv2 = nn.Conv2d(32,18,3).double()
        self.pool2 = nn.MaxPool2d(2,2).double()
        self.fc1 = nn.Linear(18*513*33, 128).double()
        self.fc2 = nn.Linear(128,1).double()
        self.dropout = nn.Dropout(0.1).double()
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1,18*513*33)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        
        return x
    

In [91]:
model = ConvNet().to(device)

In [92]:
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [93]:
def rmse(y,f):
    rmse = math.sqrt(((y - f)**2).mean(axis=0))
    return rmse
def mse(y,f):
    mse = ((y - f)**2).mean(axis=0)
    return mse
def pearson(y,f):
    rp = np.corrcoef(y, f)[0,1]
    return rp
from lifelines.utils import concordance_index
def ci(y,f):
    return concordance_index(y,f)

In [94]:
def predicting(model, device, test_loader):
    model.eval()
    total_preds = np.array([])
    total_labels = np.array([])
    with torch.no_grad():
        correct = 0
        total = 0
        c=0
        for i in test_loader:
            print(c)
            c=c+1
            if(c>100):
                break
            images = i['outer_product']
            labels = i['Label']
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images) 
            outputs = outputs.cpu().detach().numpy().flatten()
            labels =labels.cpu().detach().numpy().flatten()
            total_preds = np.concatenate([total_preds, outputs])
            total_labels = np.concatenate([total_labels, labels])
    model.train()   
    return total_labels, total_preds

In [95]:
model_file_name = 'best_sim-CNN-DTA_kiba_fold' + fold_number + "NEW"+ '.model'
result_file_name = 'best_result_sim-CNNDTA_kiba_fold'+fold_number + "NEW"+ '.csv'

In [96]:
# torch.cuda.empty_cache()

In [97]:
# for i in train_loader:
#     a = i['outer_product']
#     b= i['Label']
#     o = model(a.to(device))
    
#     break

In [98]:
# G,P = predicting(model, device, train_loader)

In [99]:
# test_loader

In [100]:
# Train the model
best_mse = 1000
best_ci = 0
# model_file_name = 'best_sim-CNN-DTA_kiba.model'
# result_file_name = 'best_result_sim-CNNDTA_kiba.csv'
total_step = len(train_loader)
for epoch in range(num_epochs):
    c=0
    for i in train_loader:
        c=c+1
        if (c > 100):
            break
        images = i['outer_product']
        labels = i['Label']
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs.flatten(), labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
           
        print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
               .format(epoch+1, num_epochs, c, total_step, loss.item()))
    
#     taking best model so far
    G,P = predicting(model, device, train_loader)
    ret = [rmse(G, P), mse(G, P), pearson(G, P), ci(G, P)]
    print(ret)
    if ret[1] < best_mse:
#         torch.save(model.state_dict(), model_file_name)
#         with open(result_file_name, 'w') as f:
#             f.write(','.join(map(str, ret)))
        best_epoch = epoch+1
        best_mse = ret[1]
        best_ci = ret[-1]
        best_r = ret[2]
        
        print('rmse improved at epoch ', best_epoch,
                      '; best_mse,best_ci,best_r:', best_mse, best_ci,best_r)
        
        

Epoch [1/20], Step [1/2839], Loss: 148.7818
Epoch [1/20], Step [2/2839], Loss: 50.5890
Epoch [1/20], Step [3/2839], Loss: 1.2489
Epoch [1/20], Step [4/2839], Loss: 10.4442
Epoch [1/20], Step [5/2839], Loss: 9.2366
Epoch [1/20], Step [6/2839], Loss: 1.8275
Epoch [1/20], Step [7/2839], Loss: 1.3997
Epoch [1/20], Step [8/2839], Loss: 4.0028
Epoch [1/20], Step [9/2839], Loss: 3.8050
Epoch [1/20], Step [10/2839], Loss: 1.4119
Epoch [1/20], Step [11/2839], Loss: 1.4588
Epoch [1/20], Step [12/2839], Loss: 3.2799
Epoch [1/20], Step [13/2839], Loss: 2.2418
Epoch [1/20], Step [14/2839], Loss: 1.3534
Epoch [1/20], Step [15/2839], Loss: 0.6803
Epoch [1/20], Step [16/2839], Loss: 1.9371
Epoch [1/20], Step [17/2839], Loss: 1.1957
Epoch [1/20], Step [18/2839], Loss: 1.3460
Epoch [1/20], Step [19/2839], Loss: 0.8396
Epoch [1/20], Step [20/2839], Loss: 1.5649
Epoch [1/20], Step [21/2839], Loss: 1.2979
Epoch [1/20], Step [22/2839], Loss: 1.6117
Epoch [1/20], Step [23/2839], Loss: 0.6143
Epoch [1/20], St

Epoch [2/20], Step [81/2839], Loss: 0.6739
Epoch [2/20], Step [82/2839], Loss: 0.7514
Epoch [2/20], Step [83/2839], Loss: 1.5576
Epoch [2/20], Step [84/2839], Loss: 0.8383
Epoch [2/20], Step [85/2839], Loss: 0.5686
Epoch [2/20], Step [86/2839], Loss: 0.4469
Epoch [2/20], Step [87/2839], Loss: 0.4008
Epoch [2/20], Step [88/2839], Loss: 0.4348
Epoch [2/20], Step [89/2839], Loss: 0.4503
Epoch [2/20], Step [90/2839], Loss: 0.7225
Epoch [2/20], Step [91/2839], Loss: 0.9589
Epoch [2/20], Step [92/2839], Loss: 1.0722
Epoch [2/20], Step [93/2839], Loss: 0.7872
Epoch [2/20], Step [94/2839], Loss: 1.0511
Epoch [2/20], Step [95/2839], Loss: 0.5604
Epoch [2/20], Step [96/2839], Loss: 0.7770
Epoch [2/20], Step [97/2839], Loss: 0.6975
Epoch [2/20], Step [98/2839], Loss: 0.4644
Epoch [2/20], Step [99/2839], Loss: 1.2653
Epoch [2/20], Step [100/2839], Loss: 1.1162
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

Epoch [4/20], Step [55/2839], Loss: 0.6671
Epoch [4/20], Step [56/2839], Loss: 0.4938
Epoch [4/20], Step [57/2839], Loss: 0.3611
Epoch [4/20], Step [58/2839], Loss: 0.7844
Epoch [4/20], Step [59/2839], Loss: 0.3818
Epoch [4/20], Step [60/2839], Loss: 0.7372
Epoch [4/20], Step [61/2839], Loss: 0.6673
Epoch [4/20], Step [62/2839], Loss: 0.7119
Epoch [4/20], Step [63/2839], Loss: 1.0946
Epoch [4/20], Step [64/2839], Loss: 1.0447
Epoch [4/20], Step [65/2839], Loss: 1.1792
Epoch [4/20], Step [66/2839], Loss: 1.1789
Epoch [4/20], Step [67/2839], Loss: 0.7353
Epoch [4/20], Step [68/2839], Loss: 0.3729
Epoch [4/20], Step [69/2839], Loss: 0.7063
Epoch [4/20], Step [70/2839], Loss: 0.4294
Epoch [4/20], Step [71/2839], Loss: 0.3868
Epoch [4/20], Step [72/2839], Loss: 0.8579
Epoch [4/20], Step [73/2839], Loss: 0.4596
Epoch [4/20], Step [74/2839], Loss: 0.5121
Epoch [4/20], Step [75/2839], Loss: 0.8457
Epoch [4/20], Step [76/2839], Loss: 1.1241
Epoch [4/20], Step [77/2839], Loss: 0.4664
Epoch [4/20

Epoch [6/20], Step [26/2839], Loss: 1.4679
Epoch [6/20], Step [27/2839], Loss: 1.0944
Epoch [6/20], Step [28/2839], Loss: 0.4898
Epoch [6/20], Step [29/2839], Loss: 1.4588
Epoch [6/20], Step [30/2839], Loss: 2.5016
Epoch [6/20], Step [31/2839], Loss: 0.9748
Epoch [6/20], Step [32/2839], Loss: 2.1162
Epoch [6/20], Step [33/2839], Loss: 0.9404
Epoch [6/20], Step [34/2839], Loss: 1.1732
Epoch [6/20], Step [35/2839], Loss: 1.6730
Epoch [6/20], Step [36/2839], Loss: 0.5384
Epoch [6/20], Step [37/2839], Loss: 1.1092
Epoch [6/20], Step [38/2839], Loss: 1.1617
Epoch [6/20], Step [39/2839], Loss: 0.6792
Epoch [6/20], Step [40/2839], Loss: 0.7689
Epoch [6/20], Step [41/2839], Loss: 0.7664
Epoch [6/20], Step [42/2839], Loss: 1.5068
Epoch [6/20], Step [43/2839], Loss: 0.9723
Epoch [6/20], Step [44/2839], Loss: 1.5603
Epoch [6/20], Step [45/2839], Loss: 0.4523
Epoch [6/20], Step [46/2839], Loss: 0.2850
Epoch [6/20], Step [47/2839], Loss: 2.0044
Epoch [6/20], Step [48/2839], Loss: 0.3103
Epoch [6/20

Epoch [8/20], Step [1/2839], Loss: 0.1582
Epoch [8/20], Step [2/2839], Loss: 0.4783
Epoch [8/20], Step [3/2839], Loss: 0.5714
Epoch [8/20], Step [4/2839], Loss: 0.8525
Epoch [8/20], Step [5/2839], Loss: 0.2202
Epoch [8/20], Step [6/2839], Loss: 0.9310
Epoch [8/20], Step [7/2839], Loss: 0.3587
Epoch [8/20], Step [8/2839], Loss: 0.5696
Epoch [8/20], Step [9/2839], Loss: 0.3710
Epoch [8/20], Step [10/2839], Loss: 0.2645
Epoch [8/20], Step [11/2839], Loss: 0.8095
Epoch [8/20], Step [12/2839], Loss: 0.5291
Epoch [8/20], Step [13/2839], Loss: 0.7059
Epoch [8/20], Step [14/2839], Loss: 0.9651
Epoch [8/20], Step [15/2839], Loss: 0.8489
Epoch [8/20], Step [16/2839], Loss: 0.5561
Epoch [8/20], Step [17/2839], Loss: 0.5845
Epoch [8/20], Step [18/2839], Loss: 1.5284
Epoch [8/20], Step [19/2839], Loss: 0.3151
Epoch [8/20], Step [20/2839], Loss: 1.0640
Epoch [8/20], Step [21/2839], Loss: 0.4668
Epoch [8/20], Step [22/2839], Loss: 1.1044
Epoch [8/20], Step [23/2839], Loss: 0.3912
Epoch [8/20], Step [

Epoch [9/20], Step [84/2839], Loss: 0.8092
Epoch [9/20], Step [85/2839], Loss: 0.5403
Epoch [9/20], Step [86/2839], Loss: 1.5851
Epoch [9/20], Step [87/2839], Loss: 0.7139
Epoch [9/20], Step [88/2839], Loss: 0.7878
Epoch [9/20], Step [89/2839], Loss: 1.2440
Epoch [9/20], Step [90/2839], Loss: 1.9621
Epoch [9/20], Step [91/2839], Loss: 0.5196
Epoch [9/20], Step [92/2839], Loss: 0.5737
Epoch [9/20], Step [93/2839], Loss: 1.0787
Epoch [9/20], Step [94/2839], Loss: 0.4319
Epoch [9/20], Step [95/2839], Loss: 1.2514
Epoch [9/20], Step [96/2839], Loss: 1.2167
Epoch [9/20], Step [97/2839], Loss: 0.7055
Epoch [9/20], Step [98/2839], Loss: 0.6870
Epoch [9/20], Step [99/2839], Loss: 0.4781
Epoch [9/20], Step [100/2839], Loss: 0.6431
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

Epoch [11/20], Step [54/2839], Loss: 0.4593
Epoch [11/20], Step [55/2839], Loss: 0.9285
Epoch [11/20], Step [56/2839], Loss: 0.6228
Epoch [11/20], Step [57/2839], Loss: 0.4587
Epoch [11/20], Step [58/2839], Loss: 1.2724
Epoch [11/20], Step [59/2839], Loss: 0.7749
Epoch [11/20], Step [60/2839], Loss: 0.8056
Epoch [11/20], Step [61/2839], Loss: 1.1432
Epoch [11/20], Step [62/2839], Loss: 1.3491
Epoch [11/20], Step [63/2839], Loss: 1.4813
Epoch [11/20], Step [64/2839], Loss: 0.9358
Epoch [11/20], Step [65/2839], Loss: 1.5011
Epoch [11/20], Step [66/2839], Loss: 0.7419
Epoch [11/20], Step [67/2839], Loss: 1.4602
Epoch [11/20], Step [68/2839], Loss: 0.8346
Epoch [11/20], Step [69/2839], Loss: 0.4111
Epoch [11/20], Step [70/2839], Loss: 0.6335
Epoch [11/20], Step [71/2839], Loss: 1.5743
Epoch [11/20], Step [72/2839], Loss: 0.5524
Epoch [11/20], Step [73/2839], Loss: 0.9104
Epoch [11/20], Step [74/2839], Loss: 0.8636
Epoch [11/20], Step [75/2839], Loss: 0.5918
Epoch [11/20], Step [76/2839], L

Epoch [13/20], Step [24/2839], Loss: 0.3890
Epoch [13/20], Step [25/2839], Loss: 0.7594
Epoch [13/20], Step [26/2839], Loss: 0.4418
Epoch [13/20], Step [27/2839], Loss: 0.7807
Epoch [13/20], Step [28/2839], Loss: 0.8639
Epoch [13/20], Step [29/2839], Loss: 0.9366
Epoch [13/20], Step [30/2839], Loss: 0.6926
Epoch [13/20], Step [31/2839], Loss: 0.4339
Epoch [13/20], Step [32/2839], Loss: 0.4861
Epoch [13/20], Step [33/2839], Loss: 0.7695
Epoch [13/20], Step [34/2839], Loss: 1.0630
Epoch [13/20], Step [35/2839], Loss: 0.4249
Epoch [13/20], Step [36/2839], Loss: 0.5573
Epoch [13/20], Step [37/2839], Loss: 2.3080
Epoch [13/20], Step [38/2839], Loss: 0.4349
Epoch [13/20], Step [39/2839], Loss: 1.0362
Epoch [13/20], Step [40/2839], Loss: 0.9050
Epoch [13/20], Step [41/2839], Loss: 0.3776
Epoch [13/20], Step [42/2839], Loss: 0.5371
Epoch [13/20], Step [43/2839], Loss: 0.3982
Epoch [13/20], Step [44/2839], Loss: 0.8552
Epoch [13/20], Step [45/2839], Loss: 0.2841
Epoch [13/20], Step [46/2839], L

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
[0.7391941334353079, 0.5464079669051758, 0.48633565472202694, 0.6832203484092103]
rmse improved at epoch  14 ; best_mse,best_ci,best_r: 0.5464079669051758 0.6832203484092103 0.48633565472202694
Epoch [15/20], Step [1/2839], Loss: 0.8161
Epoch [15/20], Step [2/2839], Loss: 0.5105
Epoch [15/20], Step [3/2839], Loss: 0.9766
Epoch [15/20], Step [4/2839], Loss: 0.6003
Epoch [15/20], Step [5/2839], Loss: 0.5610
Epoch [15/20], Step [6/2839], Loss: 0.4525
Epoch [15/20], Step [7/2839], Loss: 0.5167
Epoch [15/20], Step [8/2839], Loss: 0.7288
Epoch [15/20], Step [9/2839], Loss: 0.5648
Epoch [15/20], Step [10/2839], Loss: 0.6747
Epoch [15/20], Step [11/2839], Loss: 0.4072
Epoch [15/20], Step [12/2839], Loss: 0.3258
Epoch [15/20], Step [13/2839], 

Epoch [16/20], Step [69/2839], Loss: 0.4742
Epoch [16/20], Step [70/2839], Loss: 0.5058
Epoch [16/20], Step [71/2839], Loss: 0.9202
Epoch [16/20], Step [72/2839], Loss: 0.4232
Epoch [16/20], Step [73/2839], Loss: 0.4330
Epoch [16/20], Step [74/2839], Loss: 1.0039
Epoch [16/20], Step [75/2839], Loss: 0.7809
Epoch [16/20], Step [76/2839], Loss: 0.5247
Epoch [16/20], Step [77/2839], Loss: 1.1064
Epoch [16/20], Step [78/2839], Loss: 2.6633
Epoch [16/20], Step [79/2839], Loss: 0.5456
Epoch [16/20], Step [80/2839], Loss: 1.0323
Epoch [16/20], Step [81/2839], Loss: 1.2368
Epoch [16/20], Step [82/2839], Loss: 0.9595
Epoch [16/20], Step [83/2839], Loss: 0.3411
Epoch [16/20], Step [84/2839], Loss: 1.8437
Epoch [16/20], Step [85/2839], Loss: 0.7214
Epoch [16/20], Step [86/2839], Loss: 0.7851
Epoch [16/20], Step [87/2839], Loss: 0.9952
Epoch [16/20], Step [88/2839], Loss: 1.2354
Epoch [16/20], Step [89/2839], Loss: 0.7573
Epoch [16/20], Step [90/2839], Loss: 0.1678
Epoch [16/20], Step [91/2839], L

Epoch [18/20], Step [36/2839], Loss: 0.5620
Epoch [18/20], Step [37/2839], Loss: 0.4737
Epoch [18/20], Step [38/2839], Loss: 0.4954
Epoch [18/20], Step [39/2839], Loss: 0.7039
Epoch [18/20], Step [40/2839], Loss: 0.7808
Epoch [18/20], Step [41/2839], Loss: 1.0578
Epoch [18/20], Step [42/2839], Loss: 0.7608
Epoch [18/20], Step [43/2839], Loss: 0.6503
Epoch [18/20], Step [44/2839], Loss: 0.8963
Epoch [18/20], Step [45/2839], Loss: 0.4349
Epoch [18/20], Step [46/2839], Loss: 0.9476
Epoch [18/20], Step [47/2839], Loss: 0.2222
Epoch [18/20], Step [48/2839], Loss: 1.4604
Epoch [18/20], Step [49/2839], Loss: 0.4576
Epoch [18/20], Step [50/2839], Loss: 0.4455
Epoch [18/20], Step [51/2839], Loss: 0.4531
Epoch [18/20], Step [52/2839], Loss: 1.0076
Epoch [18/20], Step [53/2839], Loss: 1.1397
Epoch [18/20], Step [54/2839], Loss: 0.4643
Epoch [18/20], Step [55/2839], Loss: 0.4215
Epoch [18/20], Step [56/2839], Loss: 0.7645
Epoch [18/20], Step [57/2839], Loss: 0.3848
Epoch [18/20], Step [58/2839], L

Epoch [20/20], Step [3/2839], Loss: 0.4747
Epoch [20/20], Step [4/2839], Loss: 0.7698
Epoch [20/20], Step [5/2839], Loss: 0.3076
Epoch [20/20], Step [6/2839], Loss: 0.6290
Epoch [20/20], Step [7/2839], Loss: 0.6189
Epoch [20/20], Step [8/2839], Loss: 0.3992
Epoch [20/20], Step [9/2839], Loss: 0.2148
Epoch [20/20], Step [10/2839], Loss: 0.7674
Epoch [20/20], Step [11/2839], Loss: 0.8013
Epoch [20/20], Step [12/2839], Loss: 0.2933
Epoch [20/20], Step [13/2839], Loss: 0.2562
Epoch [20/20], Step [14/2839], Loss: 0.3423
Epoch [20/20], Step [15/2839], Loss: 0.4720
Epoch [20/20], Step [16/2839], Loss: 0.6024
Epoch [20/20], Step [17/2839], Loss: 0.3210
Epoch [20/20], Step [18/2839], Loss: 0.4159
Epoch [20/20], Step [19/2839], Loss: 0.1841
Epoch [20/20], Step [20/2839], Loss: 1.7063
Epoch [20/20], Step [21/2839], Loss: 0.5621
Epoch [20/20], Step [22/2839], Loss: 0.8905
Epoch [20/20], Step [23/2839], Loss: 0.6391
Epoch [20/20], Step [24/2839], Loss: 0.3622
Epoch [20/20], Step [25/2839], Loss: 0.

In [101]:
model = ConvNet().to(device)
model.load_state_dict(torch.load('./best_sim-CNN-DTA_kiba_fold1NEW.model'))
# model.eval()
total_preds = np.array([])
total_labels = np.array([])
with torch.no_grad():
    correct = 0
    total = 0
    c=0
    for i in train_loader:
        print(c)
        c=c+1
        if(c==100):
            break
        images = i['outer_product']
        labels = i['Label']
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images) 
        outputs = outputs.cpu().detach().numpy().flatten()
        labels =labels.cpu().detach().numpy().flatten()
        total_preds = np.concatenate([total_preds, outputs])
        total_labels = np.concatenate([total_labels, labels])
#         total_preds = torch.cat(total_preds, outputs.cpu(), 0 )
#         total_labels = torch.cat(total_labels, labels.cpu(), 0)
#         break

G,P = total_labels, total_preds

RuntimeError: Error(s) in loading state_dict for ConvNet:
	size mismatch for fc1.weight: copying a param with shape torch.Size([128, 342324]) from checkpoint, the shape in current model is torch.Size([128, 304722]).

In [None]:
print("MSE = ",mse(G,P),flush=True)
print("R = ",pearson(G,P),flush=True)
print("CI = ",ci(G,P),flush=True)
print("RMSE = ",rmse(G,P),flush=True)

In [None]:
# device

In [None]:
model = ConvNet().to(device)

In [None]:
model.load_state_dict(torch.load('./best_sim-CNN-DTA_kiba_fold0NEW.model'))

In [None]:
model.eval()


In [None]:
# G,P = predicting(model, device, test_loader)
# it changes to trainmode again()

In [None]:
total_preds = np.array([])
total_labels = np.array([])
with torch.no_grad():
    correct = 0
    total = 0
    c=0
    for i in test_loader:
        print(c)
        c=c+1
        images = i['outer_product']
        labels = i['Label']
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images) 
        outputs = outputs.cpu().detach().numpy().flatten()
        labels =labels.cpu().detach().numpy().flatten()
        total_preds = np.concatenate([total_preds, outputs])
        total_labels = np.concatenate([total_labels, labels])
#         total_preds = torch.cat(total_preds, outputs.cpu(), 0 )
#         total_labels = torch.cat(total_labels, labels.cpu(), 0)
#         break

In [None]:
G,P = total_labels, total_preds

In [None]:
print("MSE = ",mse(G,P))
print("R = ",pearson(G,P))
print("CI = ",ci(G,P))
print("RMSE = ",rmse(G,P))
