In [1]:
import torch 
import json,pickle,math
import pandas as pd
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


In [2]:
full_df = pd.read_csv(open('../davis_all_pairs.csv','r'))

In [3]:
all_9_folds={}
for i in [0,1,2]:
    for j in [0,1,2]:
        file_name = 'fold' +str(i) +str(j) 
        
        temp = open('../data/davis/DAVIS_9_FOLDS/' + file_name +'.pkl', 'rb')
        new_df = pd.read_pickle(temp)
        all_9_folds.update({file_name:new_df})
        temp.close()
        

In [4]:
def create_davis_test_train(test_fold_number,all_9_folds):
    test_protein_fold_id = test_fold_number[0]
    test_ligand_fold_id = test_fold_number[1]
    test_set = pd.DataFrame(columns = full_df.columns)
    train_set = pd.DataFrame(columns= full_df.columns)
    for i in [0,1,2]:
        for j in [0,1,2]:
            fold_name = 'fold' + str(i) + str(j)
            df = all_9_folds[fold_name]
            
            if str(i) == test_protein_fold_id and str(j) == test_ligand_fold_id:
                test_set = df.copy()
                
            if str(i) != test_protein_fold_id and str(j) != test_ligand_fold_id:
                train_set = pd.concat([train_set, df.copy()], ignore_index=True)
                
                
    return train_set, test_set


# Create train test split on these 9 folds
## fold_number is the id of fold. For example, test = fold00, train = fold 11,22,12,21

In [5]:
fold_number = '10'

In [6]:
train, test = create_davis_test_train(test_fold_number=fold_number, all_9_folds=all_9_folds)

In [7]:
train

Unnamed: 0,SMILES,Target Sequence,Label,drug_encoding,target_encoding
0,CN1CCC(C(C1)O)C2=C(C=C(C3=C2OC(=CC3=O)C4=CC=CC...,MSDVAIVKEGWLHKRGEYIKTWRPRYFLLKNDGTFIGYKERPQDVD...,5.000000,"[C, N, 1, C, C, C, (, C, (, C, 1, ), O, ), C, ...","[M, S, D, V, A, I, V, K, E, G, W, L, H, K, R, ..."
1,CCN(CCCOC1=CC2=C(C=C1)C(=NC=N2)NC3=NNC(=C3)CC(...,MSDVAIVKEGWLHKRGEYIKTWRPRYFLLKNDGTFIGYKERPQDVD...,5.000000,"[C, C, N, (, C, C, C, O, C, 1, =, C, C, 2, =, ...","[M, S, D, V, A, I, V, K, E, G, W, L, H, K, R, ..."
2,CN1CCN(CC1)CCCOC2=C(C=C3C(=C2)N=CC(=C3NC4=CC(=...,MSDVAIVKEGWLHKRGEYIKTWRPRYFLLKNDGTFIGYKERPQDVD...,5.000000,"[C, N, 1, C, C, N, (, C, C, 1, ), C, C, C, O, ...","[M, S, D, V, A, I, V, K, E, G, W, L, H, K, R, ..."
3,CC1=CC2=C(N1)C=CC(=C2F)OC3=NC=NN4C3=C(C(=C4)OC...,MSDVAIVKEGWLHKRGEYIKTWRPRYFLLKNDGTFIGYKERPQDVD...,5.000000,"[C, C, 1, =, C, C, 2, =, C, (, N, 1, ), C, =, ...","[M, S, D, V, A, I, V, K, E, G, W, L, H, K, R, ..."
4,C=CC(=O)NC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC(=C(C=C...,MSDVAIVKEGWLHKRGEYIKTWRPRYFLLKNDGTFIGYKERPQDVD...,5.000000,"[C, =, C, C, (, =, O, ), N, C, 1, =, C, (, C, ...","[M, S, D, V, A, I, V, K, E, G, W, L, H, K, R, ..."
...,...,...,...,...,...
13270,CN1C=NC2=C1C=C(C(=C2F)NC3=C(C=C(C=C3)Br)Cl)C(=...,MDDKDIDKELRQKLNFSYCEETEIEGQKKVEESREASSQTPEKGEV...,5.000000,"[C, N, 1, C, =, N, C, 2, =, C, 1, C, =, C, (, ...","[M, D, D, K, D, I, D, K, E, L, R, Q, K, L, N, ..."
13271,CC(C)S(=O)(=O)C1=CC=CC=C1NC2=NC(=NC=C2Cl)NC3=C...,MDDKDIDKELRQKLNFSYCEETEIEGQKKVEESREASSQTPEKGEV...,5.000000,"[C, C, (, C, ), S, (, =, O, ), (, =, O, ), C, ...","[M, D, D, K, D, I, D, K, E, L, R, Q, K, L, N, ..."
13272,C1=CC(=CC(=C1)O)C2=NC3=C(N=C2C4=CC(=CC=C4)O)N=...,MDDKDIDKELRQKLNFSYCEETEIEGQKKVEESREASSQTPEKGEV...,5.000000,"[C, 1, =, C, C, (, =, C, C, (, =, C, 1, ), O, ...","[M, D, D, K, D, I, D, K, E, L, R, Q, K, L, N, ..."
13273,CC1=CN=C(N=C1NC2=CC(=CC=C2)S(=O)(=O)NC(C)(C)C)...,MDDKDIDKELRQKLNFSYCEETEIEGQKKVEESREASSQTPEKGEV...,5.221849,"[C, C, 1, =, C, N, =, C, (, N, =, C, 1, N, C, ...","[M, D, D, K, D, I, D, K, E, L, R, Q, K, L, N, ..."


In [8]:
test

Unnamed: 0,SMILES,Target Sequence,Label,drug_encoding,target_encoding
0,CCC1C(=O)N(C2=CN=C(N=C2N1C3CCCC3)NC4=C(C=C(C=C...,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,5.000000,"[C, C, C, 1, C, (, =, O, ), N, (, C, 2, =, C, ...","[M, V, D, G, V, M, I, L, P, V, L, I, M, I, A, ..."
1,CN1C2=C(C=C(C=C2)OC3=CC(=NC=C3)C4=NC=C(N4)C(F)...,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,5.000000,"[C, N, 1, C, 2, =, C, (, C, =, C, (, C, =, C, ...","[M, V, D, G, V, M, I, L, P, V, L, I, M, I, A, ..."
2,C1CC1CONC(=O)C2=C(C(=C(C=C2)F)F)NC3=C(C=C(C=C3...,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,5.000000,"[C, 1, C, C, 1, C, O, N, C, (, =, O, ), C, 2, ...","[M, V, D, G, V, M, I, L, P, V, L, I, M, I, A, ..."
3,CC1=CC2=C(N1)C=CC(=C2F)OC3=NC=NC4=CC(=C(C=C43)...,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,5.468521,"[C, C, 1, =, C, C, 2, =, C, (, N, 1, ), C, =, ...","[M, V, D, G, V, M, I, L, P, V, L, I, M, I, A, ..."
4,CC1=C(C(=CC=C1)Cl)NC(=O)C2=CN=C(S2)NC3=NC(=NC(...,MVDGVMILPVLIMIALPSPSMEDEKPKVNPKLYMCVCEGLSCGNED...,6.207608,"[C, C, 1, =, C, (, C, (, =, C, C, =, C, 1, ), ...","[M, V, D, G, V, M, I, L, P, V, L, I, M, I, A, ..."
...,...,...,...,...,...
3376,CNC(=O)C1=NC=CC(=C1)OC2=CC=C(C=C2)NC(=O)NC3=CC...,MPDPAAHLPFFYGSISRAEAEEHLKLAGMADGLFLLRQCLRSLGGY...,5.000000,"[C, N, C, (, =, O, ), C, 1, =, N, C, =, C, C, ...","[M, P, D, P, A, A, H, L, P, F, F, Y, G, S, I, ..."
3377,CCN(CC)CCNC(=O)C1=C(NC(=C1C)C=C2C3=C(C=CC(=C3)...,MPDPAAHLPFFYGSISRAEAEEHLKLAGMADGLFLLRQCLRSLGGY...,5.000000,"[C, C, N, (, C, C, ), C, C, N, C, (, =, O, ), ...","[M, P, D, P, A, A, H, L, P, F, F, Y, G, S, I, ..."
3378,CC(C)OC1=CC=C(C=C1)NC(=O)N2CCN(CC2)C3=NC=NC4=C...,MPDPAAHLPFFYGSISRAEAEEHLKLAGMADGLFLLRQCLRSLGGY...,5.000000,"[C, C, (, C, ), O, C, 1, =, C, C, =, C, (, C, ...","[M, P, D, P, A, A, H, L, P, F, F, Y, G, S, I, ..."
3379,CC1=CC(=NN1)NC2=NC(=NC(=C2)N3CCN(CC3)C)SC4=CC=...,MPDPAAHLPFFYGSISRAEAEEHLKLAGMADGLFLLRQCLRSLGGY...,5.000000,"[C, C, 1, =, C, C, (, =, N, N, 1, ), N, C, 2, ...","[M, P, D, P, A, A, H, L, P, F, F, Y, G, S, I, ..."


# To ensure that there are no common targets or drugs in train and test


In [9]:
test_smiles = list(test['SMILES'])
test_targets = list(test['Target Sequence'])
train_smiles = list(train['SMILES'])
train_targets = list(train['Target Sequence'])

for i in test_smiles:
    if i in train_smiles:
        print("common entity present")
for i in test_targets:
    if i in train_targets:
        print("common entity present")


# Creating similarity matrices for this fold

In [10]:
import rdkit
from rdkit.Chem import AllChem as Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs import FingerprintSimilarity as fs
from rdkit.Chem.Fingerprints import FingerprintMols
from Bio import pairwise2

In [11]:
train_targets = list(set(list(train['Target Sequence'])))
train_smiles = list(set(list(train['SMILES'])))

def computeLigandSimilarity(smiles):
    fingerprints = {}
    for smile in smiles:
        mol = AllChem.MolFromSmiles(smile)
        if mol == None:
            mol = AllChem.MolFromSmiles(smile, sanitize=False)
        fp = FingerprintMols.FingerprintMol(mol)
        fingerprints[smile] = fp
    
    n = len(smiles)
    sims = np.zeros((n, n))
    for i in range(n):
        for j in range(i+1):
            fpi = fingerprints[smiles[i]]
            fpj = fingerprints[smiles[j]]
            sim = fs(fpi, fpj)
            sims[i, j] = sims[j, i] = sim
    return sims

def computeProteinSimilarity(targets):
    n = len(targets)
    mat = np.zeros((n,n))
    mat_i = np.zeros(n)
    for i in range(n):
        seq = targets[i]
        s = pairwise2.align.localxx(seq,seq, score_only=True)
        mat_i[i] = s
        
    for i in range(n):
        print(i)
        for j in range(n):
            if mat[i][j] == 0 :
                s1 = targets[i]
                s2 = targets[j]
                sw_ij = pairwise2.align.localxx(s1,s2,score_only=True)
                normalized_score = sw_ij /math.sqrt(mat_i[i]*mat_i[j])
                mat[i][j] = mat[j][i] = normalized_score
    
    return mat

In [12]:
ligand_similarity_matrix = computeLigandSimilarity(train_smiles)

In [13]:
np.shape(ligand_similarity_matrix)

(45, 45)

In [14]:
print(len(train_targets))
protein_similarity_matrix = computeProteinSimilarity(train_targets)

248
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247


In [15]:
np.shape(protein_similarity_matrix)

(248, 248)

In [16]:
LSM = ligand_similarity_matrix
PSM = protein_similarity_matrix

# creating outer products for train set

In [17]:
outer_train_prods = []
for i,row in train.iterrows():
#     print(i)
    smi = row['SMILES']
    seq = row['Target Sequence']
    target_id = train_targets.index(seq)
    smi_id = train_smiles.index(smi)
    ki=LSM[smi_id]
    kj=PSM[target_id]
    ki_x_kj = np.outer(ki,kj)
    outer_train_prods.append([ki_x_kj])
outer_train_prods = np.array(outer_train_prods)
print(np.shape(outer_train_prods))

(13275, 1, 45, 248)


# Creating similarity matrcies for test set

In [18]:
test_targets = list(set(list(test['Target Sequence'])))
test_smiles = list(set(list(test['SMILES'])))

In [19]:
test_PSM = np.zeros((len(test_targets), len(train_targets)))
np.shape(test_PSM)

(131, 248)

In [20]:
s_train_PSM = np.zeros(len(train_targets))
s_test_PSM = np.zeros(len(test_targets))

for i in range(len(train_targets)):
    seq = train_targets[i]
    s_train_PSM[i] = pairwise2.align.localxx(seq,seq, score_only=True)
    
for i in range(len(test_targets)):
    seq = test_targets[i]
    s_test_PSM[i] = pairwise2.align.localxx(seq,seq, score_only=True)
    
for i in range(len(test_targets)):
    print(i)
    for j in range(len(train_targets)):
        seq1 = test_targets[i]
        seq2 = train_targets[j]
        s_ij=pairwise2.align.localxx(seq1, seq2, score_only=True)
        N_S = s_ij / math.sqrt(s_train_PSM[j] * s_test_PSM[i])
        test_PSM[i][j] = N_S

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130


In [21]:
test_LSM = np.zeros((len(test_smiles), len(train_smiles)))
np.shape(test_LSM)

(23, 45)

In [22]:
for i in range(len(test_smiles)):
    print(i)
    for j in range(len(train_smiles)):
        smi1 = test_smiles[i]
        smi2 = train_smiles[j]
        
        mol1 = AllChem.MolFromSmiles(smi1)
        if mol1 == None:
            mol1= AllChem.MolFromSmiles(smi1, sanitize=False)
        fp1 = FingerprintMols.FingerprintMol(mol1)
        
        mol2 = AllChem.MolFromSmiles(smi2)
        if mol2 == None:
            mol2= AllChem.MolFromSmiles(smi2, sanitize=False)
        fp2 = FingerprintMols.FingerprintMol(mol2)
        
        test_LSM[i][j] = fs(fp1,fp2)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22


# Creating outer products for test set

In [23]:
outer_test_prods = []
for i,row in test.iterrows():
#     print(i)
    smi = row['SMILES']
    seq = row['Target Sequence']
    target_id = test_targets.index(seq)
    smi_id = test_smiles.index(smi)
    ki=test_LSM[smi_id]
    kj=test_PSM[target_id]
    ki_x_kj = np.outer(ki,kj)
    outer_test_prods.append([ki_x_kj])
outer_test_prods = np.array(outer_test_prods)
print(np.shape(outer_test_prods))

(3381, 1, 45, 248)


In [24]:
# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
num_epochs = 20
# num_classes = 10
batch_size = 32
learning_rate = 0.001

In [25]:
class custom_dataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, outer_prods, transform=None):
#         self.df = pd.read_csv(open(csv_file))
        self.df = dataframe
#         self.root_dir = root_dir
        self.transform = transform
        self.outer_prods = outer_prods
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        output = {'outer_product': self.outer_prods[idx] , 'Label':self.df.iloc[idx]['Label']}
        return output

In [26]:
train_dataset = custom_dataset(dataframe = train, outer_prods = outer_train_prods)
test_dataset = custom_dataset(dataframe = test, outer_prods = outer_test_prods)


In [27]:
train_loader= torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader= torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [28]:
print(len(train_loader)*32, len(test_loader)*32)

13280 3392


In [30]:
for i in test_loader:
    a = i['outer_product']
    b= i['Label']
    break
conv1 = nn.Conv2d(1,32,5).double()
pool = nn.MaxPool2d(2,2).double()
conv2 = nn.Conv2d(32,18,3).double()
fc1 = nn.Linear(18*9*60, 128).double()
fc2 = nn.Linear(128,1).double()
dropout = nn.Dropout(0.1).double()
x= conv1(a)
print(x.shape)
x = pool(x)
print(x.shape)
x= conv2(x)
print(x.shape)
x = pool(x)
print(x.shape)
x = x.view(-1,18*9*60)
print(x.shape)
x = dropout(x)
print(x.shape)
x = fc1(x)
print(x.shape)
x = fc2(x)
print(x.shape)

torch.Size([32, 32, 41, 244])
torch.Size([32, 32, 20, 122])
torch.Size([32, 18, 18, 120])
torch.Size([32, 18, 9, 60])
torch.Size([32, 9720])
torch.Size([32, 9720])
torch.Size([32, 128])
torch.Size([32, 1])


In [31]:
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1,32, 5).double()
        self.pool1 = nn.MaxPool2d(2,2).double()
        self.conv2 = nn.Conv2d(32,18,3).double()
        self.pool2 = nn.MaxPool2d(2,2).double()
        self.fc1 = nn.Linear(18*9*60, 128).double()
        self.fc2 = nn.Linear(128,1).double()
        self.dropout = nn.Dropout(0.1).double()
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1,18*9*60)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        
        return x
    

In [32]:
model = ConvNet().to(device)

In [33]:
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# Evaluation metrics

In [34]:
def rmse(y,f):
    rmse = math.sqrt(((y - f)**2).mean(axis=0))
    return rmse
def mse(y,f):
    mse = ((y - f)**2).mean(axis=0)
    return mse
def pearson(y,f):
    rp = np.corrcoef(y, f)[0,1]
    return rp
from lifelines.utils import concordance_index
def ci(y,f):
    return concordance_index(y,f)

In [35]:
def predicting(model, device, test_loader):
    model.eval()
    total_preds = np.array([])
    total_labels = np.array([])
    with torch.no_grad():
        for i in test_loader:
            images = i['outer_product']
            labels = i['Label']
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images) 
            outputs = outputs.cpu().detach().numpy().flatten()
            labels =labels.cpu().detach().numpy().flatten()
            total_preds = np.concatenate([total_preds, outputs])
            total_labels = np.concatenate([total_labels, labels])
    
    model.train()
    return total_labels, total_preds

# Train the model


In [36]:
model_file_name = 'best_sim-CNN-DTA_davis_fold' + fold_number +  '.model'
result_file_name = 'best_result_sim-CNNDTA_davis_fold'+fold_number + '.csv'

In [37]:
# Train the model
best_mse = 1000
best_ci = 0

total_step = len(train_loader)
for epoch in range(num_epochs):
    c=0
    for i in train_loader:
        c=c+1
        images = i['outer_product']
        labels = i['Label']
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs.flatten(), labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
           
        print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
               .format(epoch+1, num_epochs, c, total_step, loss.item()))
    
    # taking best model so far
    G,P = predicting(model, device, test_loader)
    ret = [rmse(G, P), mse(G, P), pearson(G, P), ci(G, P)]
    if ret[1] < best_mse:
        torch.save(model.state_dict(), model_file_name)
        with open(result_file_name, 'w') as f:
            f.write(','.join(map(str, ret)))
        best_epoch = epoch+1
        best_mse = ret[1]
        best_ci = ret[-1]
        best_r = ret[2]
        
        print('rmse improved at epoch ', best_epoch,
                      '; best_mse,best_ci,best_r:', best_mse, best_ci,best_r)
        
        

Epoch [1/20], Step [1/415], Loss: 26.8676
Epoch [1/20], Step [2/415], Loss: 21.9561
Epoch [1/20], Step [3/415], Loss: 13.6543
Epoch [1/20], Step [4/415], Loss: 3.5953
Epoch [1/20], Step [5/415], Loss: 2.3186
Epoch [1/20], Step [6/415], Loss: 10.0193
Epoch [1/20], Step [7/415], Loss: 9.1870
Epoch [1/20], Step [8/415], Loss: 3.7300
Epoch [1/20], Step [9/415], Loss: 1.2445
Epoch [1/20], Step [10/415], Loss: 1.1412
Epoch [1/20], Step [11/415], Loss: 2.2488
Epoch [1/20], Step [12/415], Loss: 4.1949
Epoch [1/20], Step [13/415], Loss: 4.1535
Epoch [1/20], Step [14/415], Loss: 2.8398
Epoch [1/20], Step [15/415], Loss: 3.7854
Epoch [1/20], Step [16/415], Loss: 1.5485
Epoch [1/20], Step [17/415], Loss: 0.5138
Epoch [1/20], Step [18/415], Loss: 1.1486
Epoch [1/20], Step [19/415], Loss: 1.1497
Epoch [1/20], Step [20/415], Loss: 1.3090
Epoch [1/20], Step [21/415], Loss: 1.8630
Epoch [1/20], Step [22/415], Loss: 2.1153
Epoch [1/20], Step [23/415], Loss: 1.9055
Epoch [1/20], Step [24/415], Loss: 1.05

Epoch [1/20], Step [199/415], Loss: 1.6196
Epoch [1/20], Step [200/415], Loss: 0.7477
Epoch [1/20], Step [201/415], Loss: 0.7237
Epoch [1/20], Step [202/415], Loss: 0.1814
Epoch [1/20], Step [203/415], Loss: 0.1369
Epoch [1/20], Step [204/415], Loss: 1.6977
Epoch [1/20], Step [205/415], Loss: 0.5049
Epoch [1/20], Step [206/415], Loss: 0.3155
Epoch [1/20], Step [207/415], Loss: 1.1025
Epoch [1/20], Step [208/415], Loss: 0.4155
Epoch [1/20], Step [209/415], Loss: 0.3220
Epoch [1/20], Step [210/415], Loss: 0.6876
Epoch [1/20], Step [211/415], Loss: 0.6404
Epoch [1/20], Step [212/415], Loss: 0.3236
Epoch [1/20], Step [213/415], Loss: 0.5974
Epoch [1/20], Step [214/415], Loss: 0.4427
Epoch [1/20], Step [215/415], Loss: 0.6800
Epoch [1/20], Step [216/415], Loss: 0.2626
Epoch [1/20], Step [217/415], Loss: 1.0517
Epoch [1/20], Step [218/415], Loss: 0.5121
Epoch [1/20], Step [219/415], Loss: 0.7245
Epoch [1/20], Step [220/415], Loss: 0.8961
Epoch [1/20], Step [221/415], Loss: 0.7982
Epoch [1/20

Epoch [1/20], Step [391/415], Loss: 0.1861
Epoch [1/20], Step [392/415], Loss: 0.5624
Epoch [1/20], Step [393/415], Loss: 1.4930
Epoch [1/20], Step [394/415], Loss: 0.4413
Epoch [1/20], Step [395/415], Loss: 1.0929
Epoch [1/20], Step [396/415], Loss: 0.5909
Epoch [1/20], Step [397/415], Loss: 0.3318
Epoch [1/20], Step [398/415], Loss: 0.3468
Epoch [1/20], Step [399/415], Loss: 1.4450
Epoch [1/20], Step [400/415], Loss: 0.8067
Epoch [1/20], Step [401/415], Loss: 0.8375
Epoch [1/20], Step [402/415], Loss: 0.8601
Epoch [1/20], Step [403/415], Loss: 0.3207
Epoch [1/20], Step [404/415], Loss: 1.1123
Epoch [1/20], Step [405/415], Loss: 0.3893
Epoch [1/20], Step [406/415], Loss: 0.4442
Epoch [1/20], Step [407/415], Loss: 0.6874
Epoch [1/20], Step [408/415], Loss: 1.0672
Epoch [1/20], Step [409/415], Loss: 0.3708
Epoch [1/20], Step [410/415], Loss: 0.8118
Epoch [1/20], Step [411/415], Loss: 0.7143
Epoch [1/20], Step [412/415], Loss: 0.6608
Epoch [1/20], Step [413/415], Loss: 0.5672
Epoch [1/20

Epoch [2/20], Step [168/415], Loss: 0.3168
Epoch [2/20], Step [169/415], Loss: 0.6374
Epoch [2/20], Step [170/415], Loss: 0.1454
Epoch [2/20], Step [171/415], Loss: 0.6274
Epoch [2/20], Step [172/415], Loss: 0.5168
Epoch [2/20], Step [173/415], Loss: 0.7241
Epoch [2/20], Step [174/415], Loss: 0.4581
Epoch [2/20], Step [175/415], Loss: 1.1424
Epoch [2/20], Step [176/415], Loss: 0.3792
Epoch [2/20], Step [177/415], Loss: 0.8678
Epoch [2/20], Step [178/415], Loss: 0.7846
Epoch [2/20], Step [179/415], Loss: 1.0302
Epoch [2/20], Step [180/415], Loss: 0.6704
Epoch [2/20], Step [181/415], Loss: 0.5977
Epoch [2/20], Step [182/415], Loss: 0.2538
Epoch [2/20], Step [183/415], Loss: 0.2462
Epoch [2/20], Step [184/415], Loss: 0.2235
Epoch [2/20], Step [185/415], Loss: 0.7138
Epoch [2/20], Step [186/415], Loss: 0.7859
Epoch [2/20], Step [187/415], Loss: 0.3881
Epoch [2/20], Step [188/415], Loss: 0.4186
Epoch [2/20], Step [189/415], Loss: 0.5852
Epoch [2/20], Step [190/415], Loss: 0.7786
Epoch [2/20

Epoch [2/20], Step [359/415], Loss: 0.3089
Epoch [2/20], Step [360/415], Loss: 0.2768
Epoch [2/20], Step [361/415], Loss: 0.7622
Epoch [2/20], Step [362/415], Loss: 0.8841
Epoch [2/20], Step [363/415], Loss: 0.5174
Epoch [2/20], Step [364/415], Loss: 1.0347
Epoch [2/20], Step [365/415], Loss: 0.8505
Epoch [2/20], Step [366/415], Loss: 0.4912
Epoch [2/20], Step [367/415], Loss: 0.3436
Epoch [2/20], Step [368/415], Loss: 0.5696
Epoch [2/20], Step [369/415], Loss: 0.7141
Epoch [2/20], Step [370/415], Loss: 0.2336
Epoch [2/20], Step [371/415], Loss: 0.6811
Epoch [2/20], Step [372/415], Loss: 0.1933
Epoch [2/20], Step [373/415], Loss: 1.2546
Epoch [2/20], Step [374/415], Loss: 0.4547
Epoch [2/20], Step [375/415], Loss: 0.7674
Epoch [2/20], Step [376/415], Loss: 0.8633
Epoch [2/20], Step [377/415], Loss: 0.4078
Epoch [2/20], Step [378/415], Loss: 0.5601
Epoch [2/20], Step [379/415], Loss: 0.4999
Epoch [2/20], Step [380/415], Loss: 0.7065
Epoch [2/20], Step [381/415], Loss: 0.4022
Epoch [2/20

Epoch [3/20], Step [140/415], Loss: 0.1871
Epoch [3/20], Step [141/415], Loss: 0.7696
Epoch [3/20], Step [142/415], Loss: 0.7871
Epoch [3/20], Step [143/415], Loss: 0.5667
Epoch [3/20], Step [144/415], Loss: 0.8490
Epoch [3/20], Step [145/415], Loss: 0.7821
Epoch [3/20], Step [146/415], Loss: 0.8663
Epoch [3/20], Step [147/415], Loss: 0.5092
Epoch [3/20], Step [148/415], Loss: 0.4566
Epoch [3/20], Step [149/415], Loss: 0.4833
Epoch [3/20], Step [150/415], Loss: 1.8511
Epoch [3/20], Step [151/415], Loss: 1.2709
Epoch [3/20], Step [152/415], Loss: 1.2205
Epoch [3/20], Step [153/415], Loss: 2.5248
Epoch [3/20], Step [154/415], Loss: 0.7263
Epoch [3/20], Step [155/415], Loss: 0.3622
Epoch [3/20], Step [156/415], Loss: 0.4525
Epoch [3/20], Step [157/415], Loss: 0.8993
Epoch [3/20], Step [158/415], Loss: 1.8713
Epoch [3/20], Step [159/415], Loss: 0.3208
Epoch [3/20], Step [160/415], Loss: 0.2885
Epoch [3/20], Step [161/415], Loss: 0.7541
Epoch [3/20], Step [162/415], Loss: 0.6870
Epoch [3/20

Epoch [3/20], Step [333/415], Loss: 0.3612
Epoch [3/20], Step [334/415], Loss: 0.6699
Epoch [3/20], Step [335/415], Loss: 0.8102
Epoch [3/20], Step [336/415], Loss: 0.6229
Epoch [3/20], Step [337/415], Loss: 0.3851
Epoch [3/20], Step [338/415], Loss: 0.7757
Epoch [3/20], Step [339/415], Loss: 0.3890
Epoch [3/20], Step [340/415], Loss: 0.3922
Epoch [3/20], Step [341/415], Loss: 0.1837
Epoch [3/20], Step [342/415], Loss: 0.2061
Epoch [3/20], Step [343/415], Loss: 0.4986
Epoch [3/20], Step [344/415], Loss: 0.2899
Epoch [3/20], Step [345/415], Loss: 0.2545
Epoch [3/20], Step [346/415], Loss: 0.5994
Epoch [3/20], Step [347/415], Loss: 0.3530
Epoch [3/20], Step [348/415], Loss: 1.0991
Epoch [3/20], Step [349/415], Loss: 0.4548
Epoch [3/20], Step [350/415], Loss: 0.2186
Epoch [3/20], Step [351/415], Loss: 0.3799
Epoch [3/20], Step [352/415], Loss: 0.8516
Epoch [3/20], Step [353/415], Loss: 0.9667
Epoch [3/20], Step [354/415], Loss: 0.5308
Epoch [3/20], Step [355/415], Loss: 0.2627
Epoch [3/20

Epoch [4/20], Step [112/415], Loss: 0.2634
Epoch [4/20], Step [113/415], Loss: 0.4339
Epoch [4/20], Step [114/415], Loss: 0.5223
Epoch [4/20], Step [115/415], Loss: 0.6668
Epoch [4/20], Step [116/415], Loss: 0.4549
Epoch [4/20], Step [117/415], Loss: 0.8023
Epoch [4/20], Step [118/415], Loss: 0.3336
Epoch [4/20], Step [119/415], Loss: 0.8051
Epoch [4/20], Step [120/415], Loss: 0.3277
Epoch [4/20], Step [121/415], Loss: 0.5439
Epoch [4/20], Step [122/415], Loss: 0.9904
Epoch [4/20], Step [123/415], Loss: 0.4147
Epoch [4/20], Step [124/415], Loss: 0.5279
Epoch [4/20], Step [125/415], Loss: 0.3351
Epoch [4/20], Step [126/415], Loss: 0.3581
Epoch [4/20], Step [127/415], Loss: 0.3337
Epoch [4/20], Step [128/415], Loss: 0.8526
Epoch [4/20], Step [129/415], Loss: 0.4901
Epoch [4/20], Step [130/415], Loss: 0.3675
Epoch [4/20], Step [131/415], Loss: 0.1705
Epoch [4/20], Step [132/415], Loss: 0.5568
Epoch [4/20], Step [133/415], Loss: 0.8387
Epoch [4/20], Step [134/415], Loss: 0.2800
Epoch [4/20

Epoch [4/20], Step [308/415], Loss: 0.9366
Epoch [4/20], Step [309/415], Loss: 0.7290
Epoch [4/20], Step [310/415], Loss: 0.8443
Epoch [4/20], Step [311/415], Loss: 0.4709
Epoch [4/20], Step [312/415], Loss: 0.3037
Epoch [4/20], Step [313/415], Loss: 0.7309
Epoch [4/20], Step [314/415], Loss: 1.0779
Epoch [4/20], Step [315/415], Loss: 0.6723
Epoch [4/20], Step [316/415], Loss: 0.7596
Epoch [4/20], Step [317/415], Loss: 0.7774
Epoch [4/20], Step [318/415], Loss: 0.3966
Epoch [4/20], Step [319/415], Loss: 1.1197
Epoch [4/20], Step [320/415], Loss: 0.5862
Epoch [4/20], Step [321/415], Loss: 0.2879
Epoch [4/20], Step [322/415], Loss: 0.2313
Epoch [4/20], Step [323/415], Loss: 0.4193
Epoch [4/20], Step [324/415], Loss: 0.2534
Epoch [4/20], Step [325/415], Loss: 0.7983
Epoch [4/20], Step [326/415], Loss: 0.3979
Epoch [4/20], Step [327/415], Loss: 0.9619
Epoch [4/20], Step [328/415], Loss: 0.6342
Epoch [4/20], Step [329/415], Loss: 0.7544
Epoch [4/20], Step [330/415], Loss: 0.6236
Epoch [4/20

Epoch [5/20], Step [92/415], Loss: 0.3850
Epoch [5/20], Step [93/415], Loss: 0.3283
Epoch [5/20], Step [94/415], Loss: 0.8380
Epoch [5/20], Step [95/415], Loss: 0.3208
Epoch [5/20], Step [96/415], Loss: 0.2910
Epoch [5/20], Step [97/415], Loss: 0.2790
Epoch [5/20], Step [98/415], Loss: 0.2296
Epoch [5/20], Step [99/415], Loss: 0.6054
Epoch [5/20], Step [100/415], Loss: 0.8274
Epoch [5/20], Step [101/415], Loss: 0.4459
Epoch [5/20], Step [102/415], Loss: 0.3631
Epoch [5/20], Step [103/415], Loss: 0.4913
Epoch [5/20], Step [104/415], Loss: 0.3034
Epoch [5/20], Step [105/415], Loss: 0.3356
Epoch [5/20], Step [106/415], Loss: 0.4155
Epoch [5/20], Step [107/415], Loss: 0.1034
Epoch [5/20], Step [108/415], Loss: 0.4380
Epoch [5/20], Step [109/415], Loss: 0.0843
Epoch [5/20], Step [110/415], Loss: 0.3093
Epoch [5/20], Step [111/415], Loss: 0.2890
Epoch [5/20], Step [112/415], Loss: 0.6621
Epoch [5/20], Step [113/415], Loss: 0.6306
Epoch [5/20], Step [114/415], Loss: 0.1611
Epoch [5/20], Step 

Epoch [5/20], Step [288/415], Loss: 0.5703
Epoch [5/20], Step [289/415], Loss: 0.3196
Epoch [5/20], Step [290/415], Loss: 0.4255
Epoch [5/20], Step [291/415], Loss: 0.5489
Epoch [5/20], Step [292/415], Loss: 0.7409
Epoch [5/20], Step [293/415], Loss: 0.4645
Epoch [5/20], Step [294/415], Loss: 0.4397
Epoch [5/20], Step [295/415], Loss: 0.6883
Epoch [5/20], Step [296/415], Loss: 0.4478
Epoch [5/20], Step [297/415], Loss: 0.4815
Epoch [5/20], Step [298/415], Loss: 0.5945
Epoch [5/20], Step [299/415], Loss: 0.7925
Epoch [5/20], Step [300/415], Loss: 0.3722
Epoch [5/20], Step [301/415], Loss: 0.7854
Epoch [5/20], Step [302/415], Loss: 0.2466
Epoch [5/20], Step [303/415], Loss: 0.6322
Epoch [5/20], Step [304/415], Loss: 0.6632
Epoch [5/20], Step [305/415], Loss: 0.5075
Epoch [5/20], Step [306/415], Loss: 0.2959
Epoch [5/20], Step [307/415], Loss: 0.2182
Epoch [5/20], Step [308/415], Loss: 0.7916
Epoch [5/20], Step [309/415], Loss: 0.5785
Epoch [5/20], Step [310/415], Loss: 0.2626
Epoch [5/20

Epoch [6/20], Step [72/415], Loss: 0.5850
Epoch [6/20], Step [73/415], Loss: 0.3066
Epoch [6/20], Step [74/415], Loss: 0.4576
Epoch [6/20], Step [75/415], Loss: 0.4863
Epoch [6/20], Step [76/415], Loss: 0.2681
Epoch [6/20], Step [77/415], Loss: 0.3472
Epoch [6/20], Step [78/415], Loss: 0.5710
Epoch [6/20], Step [79/415], Loss: 0.3058
Epoch [6/20], Step [80/415], Loss: 0.2898
Epoch [6/20], Step [81/415], Loss: 0.2103
Epoch [6/20], Step [82/415], Loss: 0.6769
Epoch [6/20], Step [83/415], Loss: 0.4403
Epoch [6/20], Step [84/415], Loss: 0.2661
Epoch [6/20], Step [85/415], Loss: 0.4611
Epoch [6/20], Step [86/415], Loss: 0.5348
Epoch [6/20], Step [87/415], Loss: 0.3880
Epoch [6/20], Step [88/415], Loss: 0.2174
Epoch [6/20], Step [89/415], Loss: 0.3323
Epoch [6/20], Step [90/415], Loss: 0.4405
Epoch [6/20], Step [91/415], Loss: 0.4336
Epoch [6/20], Step [92/415], Loss: 0.7795
Epoch [6/20], Step [93/415], Loss: 0.3785
Epoch [6/20], Step [94/415], Loss: 0.4709
Epoch [6/20], Step [95/415], Loss:

Epoch [6/20], Step [264/415], Loss: 0.7435
Epoch [6/20], Step [265/415], Loss: 0.5165
Epoch [6/20], Step [266/415], Loss: 0.3598
Epoch [6/20], Step [267/415], Loss: 0.2326
Epoch [6/20], Step [268/415], Loss: 0.7052
Epoch [6/20], Step [269/415], Loss: 0.4138
Epoch [6/20], Step [270/415], Loss: 0.3694
Epoch [6/20], Step [271/415], Loss: 0.3499
Epoch [6/20], Step [272/415], Loss: 0.1152
Epoch [6/20], Step [273/415], Loss: 0.2226
Epoch [6/20], Step [274/415], Loss: 0.2393
Epoch [6/20], Step [275/415], Loss: 0.4547
Epoch [6/20], Step [276/415], Loss: 0.3263
Epoch [6/20], Step [277/415], Loss: 0.3595
Epoch [6/20], Step [278/415], Loss: 0.2965
Epoch [6/20], Step [279/415], Loss: 0.3378
Epoch [6/20], Step [280/415], Loss: 0.5256
Epoch [6/20], Step [281/415], Loss: 0.5293
Epoch [6/20], Step [282/415], Loss: 0.2188
Epoch [6/20], Step [283/415], Loss: 0.7945
Epoch [6/20], Step [284/415], Loss: 0.1512
Epoch [6/20], Step [285/415], Loss: 0.5410
Epoch [6/20], Step [286/415], Loss: 0.7505
Epoch [6/20

Epoch [7/20], Step [44/415], Loss: 0.8008
Epoch [7/20], Step [45/415], Loss: 0.3319
Epoch [7/20], Step [46/415], Loss: 0.4742
Epoch [7/20], Step [47/415], Loss: 0.5538
Epoch [7/20], Step [48/415], Loss: 0.5521
Epoch [7/20], Step [49/415], Loss: 0.5990
Epoch [7/20], Step [50/415], Loss: 0.5709
Epoch [7/20], Step [51/415], Loss: 0.8045
Epoch [7/20], Step [52/415], Loss: 0.2286
Epoch [7/20], Step [53/415], Loss: 0.4940
Epoch [7/20], Step [54/415], Loss: 0.3888
Epoch [7/20], Step [55/415], Loss: 0.6974
Epoch [7/20], Step [56/415], Loss: 0.3748
Epoch [7/20], Step [57/415], Loss: 0.5037
Epoch [7/20], Step [58/415], Loss: 0.5230
Epoch [7/20], Step [59/415], Loss: 0.6690
Epoch [7/20], Step [60/415], Loss: 0.4543
Epoch [7/20], Step [61/415], Loss: 0.3265
Epoch [7/20], Step [62/415], Loss: 0.5184
Epoch [7/20], Step [63/415], Loss: 0.6911
Epoch [7/20], Step [64/415], Loss: 0.5563
Epoch [7/20], Step [65/415], Loss: 0.4556
Epoch [7/20], Step [66/415], Loss: 0.3340
Epoch [7/20], Step [67/415], Loss:

Epoch [7/20], Step [238/415], Loss: 0.6643
Epoch [7/20], Step [239/415], Loss: 0.3148
Epoch [7/20], Step [240/415], Loss: 0.2470
Epoch [7/20], Step [241/415], Loss: 0.8497
Epoch [7/20], Step [242/415], Loss: 0.3532
Epoch [7/20], Step [243/415], Loss: 0.3654
Epoch [7/20], Step [244/415], Loss: 0.1464
Epoch [7/20], Step [245/415], Loss: 0.4853
Epoch [7/20], Step [246/415], Loss: 0.0876
Epoch [7/20], Step [247/415], Loss: 0.3654
Epoch [7/20], Step [248/415], Loss: 0.4415
Epoch [7/20], Step [249/415], Loss: 0.4530
Epoch [7/20], Step [250/415], Loss: 0.4298
Epoch [7/20], Step [251/415], Loss: 0.1914
Epoch [7/20], Step [252/415], Loss: 0.4304
Epoch [7/20], Step [253/415], Loss: 0.2739
Epoch [7/20], Step [254/415], Loss: 0.2309
Epoch [7/20], Step [255/415], Loss: 0.1503
Epoch [7/20], Step [256/415], Loss: 0.5571
Epoch [7/20], Step [257/415], Loss: 0.6813
Epoch [7/20], Step [258/415], Loss: 0.3379
Epoch [7/20], Step [259/415], Loss: 0.8090
Epoch [7/20], Step [260/415], Loss: 0.2680
Epoch [7/20

Epoch [8/20], Step [14/415], Loss: 0.1672
Epoch [8/20], Step [15/415], Loss: 0.4958
Epoch [8/20], Step [16/415], Loss: 0.5792
Epoch [8/20], Step [17/415], Loss: 0.1608
Epoch [8/20], Step [18/415], Loss: 0.3420
Epoch [8/20], Step [19/415], Loss: 0.2350
Epoch [8/20], Step [20/415], Loss: 0.1614
Epoch [8/20], Step [21/415], Loss: 0.6379
Epoch [8/20], Step [22/415], Loss: 0.2253
Epoch [8/20], Step [23/415], Loss: 0.6473
Epoch [8/20], Step [24/415], Loss: 0.2982
Epoch [8/20], Step [25/415], Loss: 0.2190
Epoch [8/20], Step [26/415], Loss: 0.1537
Epoch [8/20], Step [27/415], Loss: 0.2927
Epoch [8/20], Step [28/415], Loss: 0.4485
Epoch [8/20], Step [29/415], Loss: 0.1557
Epoch [8/20], Step [30/415], Loss: 0.3420
Epoch [8/20], Step [31/415], Loss: 0.6553
Epoch [8/20], Step [32/415], Loss: 0.1594
Epoch [8/20], Step [33/415], Loss: 1.0332
Epoch [8/20], Step [34/415], Loss: 0.3139
Epoch [8/20], Step [35/415], Loss: 0.4941
Epoch [8/20], Step [36/415], Loss: 0.6962
Epoch [8/20], Step [37/415], Loss:

Epoch [8/20], Step [209/415], Loss: 0.4433
Epoch [8/20], Step [210/415], Loss: 0.7911
Epoch [8/20], Step [211/415], Loss: 0.2931
Epoch [8/20], Step [212/415], Loss: 0.4731
Epoch [8/20], Step [213/415], Loss: 0.2590
Epoch [8/20], Step [214/415], Loss: 0.1748
Epoch [8/20], Step [215/415], Loss: 0.5811
Epoch [8/20], Step [216/415], Loss: 0.4119
Epoch [8/20], Step [217/415], Loss: 0.3988
Epoch [8/20], Step [218/415], Loss: 0.4224
Epoch [8/20], Step [219/415], Loss: 0.5642
Epoch [8/20], Step [220/415], Loss: 0.1865
Epoch [8/20], Step [221/415], Loss: 0.3534
Epoch [8/20], Step [222/415], Loss: 0.6807
Epoch [8/20], Step [223/415], Loss: 0.1020
Epoch [8/20], Step [224/415], Loss: 0.1780
Epoch [8/20], Step [225/415], Loss: 0.2575
Epoch [8/20], Step [226/415], Loss: 0.1790
Epoch [8/20], Step [227/415], Loss: 0.4212
Epoch [8/20], Step [228/415], Loss: 0.5633
Epoch [8/20], Step [229/415], Loss: 1.0080
Epoch [8/20], Step [230/415], Loss: 0.2406
Epoch [8/20], Step [231/415], Loss: 0.3775
Epoch [8/20

Epoch [8/20], Step [406/415], Loss: 0.3283
Epoch [8/20], Step [407/415], Loss: 0.1163
Epoch [8/20], Step [408/415], Loss: 0.3850
Epoch [8/20], Step [409/415], Loss: 0.8078
Epoch [8/20], Step [410/415], Loss: 0.3018
Epoch [8/20], Step [411/415], Loss: 0.2617
Epoch [8/20], Step [412/415], Loss: 0.4703
Epoch [8/20], Step [413/415], Loss: 0.4362
Epoch [8/20], Step [414/415], Loss: 0.5389
Epoch [8/20], Step [415/415], Loss: 0.4196
Epoch [9/20], Step [1/415], Loss: 0.5036
Epoch [9/20], Step [2/415], Loss: 0.4763
Epoch [9/20], Step [3/415], Loss: 0.1877
Epoch [9/20], Step [4/415], Loss: 0.2965
Epoch [9/20], Step [5/415], Loss: 0.1690
Epoch [9/20], Step [6/415], Loss: 0.5558
Epoch [9/20], Step [7/415], Loss: 0.2624
Epoch [9/20], Step [8/415], Loss: 0.6757
Epoch [9/20], Step [9/415], Loss: 0.4828
Epoch [9/20], Step [10/415], Loss: 0.3114
Epoch [9/20], Step [11/415], Loss: 0.1346
Epoch [9/20], Step [12/415], Loss: 0.3749
Epoch [9/20], Step [13/415], Loss: 0.2521
Epoch [9/20], Step [14/415], Loss

Epoch [9/20], Step [189/415], Loss: 0.1586
Epoch [9/20], Step [190/415], Loss: 0.3406
Epoch [9/20], Step [191/415], Loss: 0.2435
Epoch [9/20], Step [192/415], Loss: 0.2577
Epoch [9/20], Step [193/415], Loss: 0.1683
Epoch [9/20], Step [194/415], Loss: 0.3086
Epoch [9/20], Step [195/415], Loss: 0.1190
Epoch [9/20], Step [196/415], Loss: 0.5363
Epoch [9/20], Step [197/415], Loss: 0.3121
Epoch [9/20], Step [198/415], Loss: 0.6373
Epoch [9/20], Step [199/415], Loss: 0.3764
Epoch [9/20], Step [200/415], Loss: 0.5844
Epoch [9/20], Step [201/415], Loss: 0.2273
Epoch [9/20], Step [202/415], Loss: 0.4166
Epoch [9/20], Step [203/415], Loss: 0.3385
Epoch [9/20], Step [204/415], Loss: 0.2067
Epoch [9/20], Step [205/415], Loss: 0.3360
Epoch [9/20], Step [206/415], Loss: 0.4330
Epoch [9/20], Step [207/415], Loss: 0.6095
Epoch [9/20], Step [208/415], Loss: 0.1537
Epoch [9/20], Step [209/415], Loss: 0.4321
Epoch [9/20], Step [210/415], Loss: 0.3985
Epoch [9/20], Step [211/415], Loss: 0.3891
Epoch [9/20

Epoch [9/20], Step [382/415], Loss: 0.3097
Epoch [9/20], Step [383/415], Loss: 0.2356
Epoch [9/20], Step [384/415], Loss: 0.1849
Epoch [9/20], Step [385/415], Loss: 0.6476
Epoch [9/20], Step [386/415], Loss: 0.5098
Epoch [9/20], Step [387/415], Loss: 0.2728
Epoch [9/20], Step [388/415], Loss: 0.4456
Epoch [9/20], Step [389/415], Loss: 0.3190
Epoch [9/20], Step [390/415], Loss: 0.6109
Epoch [9/20], Step [391/415], Loss: 0.4002
Epoch [9/20], Step [392/415], Loss: 0.2070
Epoch [9/20], Step [393/415], Loss: 0.0885
Epoch [9/20], Step [394/415], Loss: 0.5188
Epoch [9/20], Step [395/415], Loss: 0.7525
Epoch [9/20], Step [396/415], Loss: 0.5470
Epoch [9/20], Step [397/415], Loss: 0.2434
Epoch [9/20], Step [398/415], Loss: 0.3748
Epoch [9/20], Step [399/415], Loss: 0.3361
Epoch [9/20], Step [400/415], Loss: 0.3881
Epoch [9/20], Step [401/415], Loss: 0.2961
Epoch [9/20], Step [402/415], Loss: 0.2440
Epoch [9/20], Step [403/415], Loss: 0.6650
Epoch [9/20], Step [404/415], Loss: 0.3499
Epoch [9/20

Epoch [10/20], Step [161/415], Loss: 0.8138
Epoch [10/20], Step [162/415], Loss: 0.6068
Epoch [10/20], Step [163/415], Loss: 0.2354
Epoch [10/20], Step [164/415], Loss: 0.4100
Epoch [10/20], Step [165/415], Loss: 0.3800
Epoch [10/20], Step [166/415], Loss: 0.2712
Epoch [10/20], Step [167/415], Loss: 0.2562
Epoch [10/20], Step [168/415], Loss: 0.3607
Epoch [10/20], Step [169/415], Loss: 0.2800
Epoch [10/20], Step [170/415], Loss: 0.5236
Epoch [10/20], Step [171/415], Loss: 0.2911
Epoch [10/20], Step [172/415], Loss: 0.2265
Epoch [10/20], Step [173/415], Loss: 0.3011
Epoch [10/20], Step [174/415], Loss: 0.4675
Epoch [10/20], Step [175/415], Loss: 0.3015
Epoch [10/20], Step [176/415], Loss: 0.4285
Epoch [10/20], Step [177/415], Loss: 0.6721
Epoch [10/20], Step [178/415], Loss: 0.2827
Epoch [10/20], Step [179/415], Loss: 0.3327
Epoch [10/20], Step [180/415], Loss: 0.1981
Epoch [10/20], Step [181/415], Loss: 0.2382
Epoch [10/20], Step [182/415], Loss: 0.1730
Epoch [10/20], Step [183/415], L

Epoch [10/20], Step [348/415], Loss: 0.6641
Epoch [10/20], Step [349/415], Loss: 0.5384
Epoch [10/20], Step [350/415], Loss: 0.1982
Epoch [10/20], Step [351/415], Loss: 0.8546
Epoch [10/20], Step [352/415], Loss: 0.2239
Epoch [10/20], Step [353/415], Loss: 1.2023
Epoch [10/20], Step [354/415], Loss: 0.4465
Epoch [10/20], Step [355/415], Loss: 0.3857
Epoch [10/20], Step [356/415], Loss: 0.1695
Epoch [10/20], Step [357/415], Loss: 0.7067
Epoch [10/20], Step [358/415], Loss: 0.4682
Epoch [10/20], Step [359/415], Loss: 0.2169
Epoch [10/20], Step [360/415], Loss: 0.8069
Epoch [10/20], Step [361/415], Loss: 0.3074
Epoch [10/20], Step [362/415], Loss: 0.1895
Epoch [10/20], Step [363/415], Loss: 0.7592
Epoch [10/20], Step [364/415], Loss: 0.1821
Epoch [10/20], Step [365/415], Loss: 0.4083
Epoch [10/20], Step [366/415], Loss: 0.4001
Epoch [10/20], Step [367/415], Loss: 0.4006
Epoch [10/20], Step [368/415], Loss: 0.6316
Epoch [10/20], Step [369/415], Loss: 0.4746
Epoch [10/20], Step [370/415], L

Epoch [11/20], Step [125/415], Loss: 0.2842
Epoch [11/20], Step [126/415], Loss: 0.2100
Epoch [11/20], Step [127/415], Loss: 0.2750
Epoch [11/20], Step [128/415], Loss: 0.2609
Epoch [11/20], Step [129/415], Loss: 0.6702
Epoch [11/20], Step [130/415], Loss: 0.1672
Epoch [11/20], Step [131/415], Loss: 0.4813
Epoch [11/20], Step [132/415], Loss: 0.2858
Epoch [11/20], Step [133/415], Loss: 0.2892
Epoch [11/20], Step [134/415], Loss: 0.1255
Epoch [11/20], Step [135/415], Loss: 0.2033
Epoch [11/20], Step [136/415], Loss: 0.3128
Epoch [11/20], Step [137/415], Loss: 0.2397
Epoch [11/20], Step [138/415], Loss: 0.2227
Epoch [11/20], Step [139/415], Loss: 0.2137
Epoch [11/20], Step [140/415], Loss: 0.4299
Epoch [11/20], Step [141/415], Loss: 0.4008
Epoch [11/20], Step [142/415], Loss: 0.1308
Epoch [11/20], Step [143/415], Loss: 0.4534
Epoch [11/20], Step [144/415], Loss: 0.1535
Epoch [11/20], Step [145/415], Loss: 0.3916
Epoch [11/20], Step [146/415], Loss: 0.5421
Epoch [11/20], Step [147/415], L

Epoch [11/20], Step [314/415], Loss: 0.4536
Epoch [11/20], Step [315/415], Loss: 0.4308
Epoch [11/20], Step [316/415], Loss: 0.3011
Epoch [11/20], Step [317/415], Loss: 0.3569
Epoch [11/20], Step [318/415], Loss: 0.2795
Epoch [11/20], Step [319/415], Loss: 0.3969
Epoch [11/20], Step [320/415], Loss: 0.7330
Epoch [11/20], Step [321/415], Loss: 0.1157
Epoch [11/20], Step [322/415], Loss: 0.4995
Epoch [11/20], Step [323/415], Loss: 0.3467
Epoch [11/20], Step [324/415], Loss: 0.4949
Epoch [11/20], Step [325/415], Loss: 0.1610
Epoch [11/20], Step [326/415], Loss: 0.9397
Epoch [11/20], Step [327/415], Loss: 0.2913
Epoch [11/20], Step [328/415], Loss: 0.4511
Epoch [11/20], Step [329/415], Loss: 0.2807
Epoch [11/20], Step [330/415], Loss: 0.1777
Epoch [11/20], Step [331/415], Loss: 1.1284
Epoch [11/20], Step [332/415], Loss: 0.1536
Epoch [11/20], Step [333/415], Loss: 0.1766
Epoch [11/20], Step [334/415], Loss: 0.1270
Epoch [11/20], Step [335/415], Loss: 0.1919
Epoch [11/20], Step [336/415], L

Epoch [12/20], Step [92/415], Loss: 0.2622
Epoch [12/20], Step [93/415], Loss: 0.1475
Epoch [12/20], Step [94/415], Loss: 0.2827
Epoch [12/20], Step [95/415], Loss: 0.0713
Epoch [12/20], Step [96/415], Loss: 0.1056
Epoch [12/20], Step [97/415], Loss: 0.4239
Epoch [12/20], Step [98/415], Loss: 0.3265
Epoch [12/20], Step [99/415], Loss: 0.3992
Epoch [12/20], Step [100/415], Loss: 0.1795
Epoch [12/20], Step [101/415], Loss: 0.2687
Epoch [12/20], Step [102/415], Loss: 0.1600
Epoch [12/20], Step [103/415], Loss: 0.4302
Epoch [12/20], Step [104/415], Loss: 0.3228
Epoch [12/20], Step [105/415], Loss: 0.1528
Epoch [12/20], Step [106/415], Loss: 0.2305
Epoch [12/20], Step [107/415], Loss: 0.1114
Epoch [12/20], Step [108/415], Loss: 0.1594
Epoch [12/20], Step [109/415], Loss: 0.1007
Epoch [12/20], Step [110/415], Loss: 0.2371
Epoch [12/20], Step [111/415], Loss: 0.1519
Epoch [12/20], Step [112/415], Loss: 0.5386
Epoch [12/20], Step [113/415], Loss: 0.1750
Epoch [12/20], Step [114/415], Loss: 0.2

Epoch [12/20], Step [281/415], Loss: 0.4253
Epoch [12/20], Step [282/415], Loss: 0.2632
Epoch [12/20], Step [283/415], Loss: 0.1036
Epoch [12/20], Step [284/415], Loss: 0.1720
Epoch [12/20], Step [285/415], Loss: 0.2324
Epoch [12/20], Step [286/415], Loss: 0.4509
Epoch [12/20], Step [287/415], Loss: 0.1311
Epoch [12/20], Step [288/415], Loss: 0.3883
Epoch [12/20], Step [289/415], Loss: 0.2564
Epoch [12/20], Step [290/415], Loss: 0.5604
Epoch [12/20], Step [291/415], Loss: 0.7540
Epoch [12/20], Step [292/415], Loss: 0.3339
Epoch [12/20], Step [293/415], Loss: 0.1164
Epoch [12/20], Step [294/415], Loss: 0.4893
Epoch [12/20], Step [295/415], Loss: 0.1197
Epoch [12/20], Step [296/415], Loss: 0.2655
Epoch [12/20], Step [297/415], Loss: 0.2159
Epoch [12/20], Step [298/415], Loss: 0.2698
Epoch [12/20], Step [299/415], Loss: 0.2835
Epoch [12/20], Step [300/415], Loss: 0.1184
Epoch [12/20], Step [301/415], Loss: 0.3788
Epoch [12/20], Step [302/415], Loss: 0.4316
Epoch [12/20], Step [303/415], L

Epoch [13/20], Step [57/415], Loss: 0.2691
Epoch [13/20], Step [58/415], Loss: 0.1755
Epoch [13/20], Step [59/415], Loss: 0.1853
Epoch [13/20], Step [60/415], Loss: 0.4314
Epoch [13/20], Step [61/415], Loss: 0.5178
Epoch [13/20], Step [62/415], Loss: 0.6422
Epoch [13/20], Step [63/415], Loss: 0.5564
Epoch [13/20], Step [64/415], Loss: 0.2985
Epoch [13/20], Step [65/415], Loss: 0.1539
Epoch [13/20], Step [66/415], Loss: 0.3284
Epoch [13/20], Step [67/415], Loss: 0.2154
Epoch [13/20], Step [68/415], Loss: 0.4564
Epoch [13/20], Step [69/415], Loss: 0.2218
Epoch [13/20], Step [70/415], Loss: 0.1775
Epoch [13/20], Step [71/415], Loss: 0.5000
Epoch [13/20], Step [72/415], Loss: 0.1224
Epoch [13/20], Step [73/415], Loss: 0.2932
Epoch [13/20], Step [74/415], Loss: 0.1913
Epoch [13/20], Step [75/415], Loss: 0.2555
Epoch [13/20], Step [76/415], Loss: 0.2547
Epoch [13/20], Step [77/415], Loss: 0.1669
Epoch [13/20], Step [78/415], Loss: 0.2639
Epoch [13/20], Step [79/415], Loss: 0.1965
Epoch [13/2

Epoch [13/20], Step [246/415], Loss: 0.2716
Epoch [13/20], Step [247/415], Loss: 0.2091
Epoch [13/20], Step [248/415], Loss: 0.3304
Epoch [13/20], Step [249/415], Loss: 0.1909
Epoch [13/20], Step [250/415], Loss: 0.3799
Epoch [13/20], Step [251/415], Loss: 0.2837
Epoch [13/20], Step [252/415], Loss: 0.2844
Epoch [13/20], Step [253/415], Loss: 0.1693
Epoch [13/20], Step [254/415], Loss: 0.3066
Epoch [13/20], Step [255/415], Loss: 0.4132
Epoch [13/20], Step [256/415], Loss: 0.1614
Epoch [13/20], Step [257/415], Loss: 0.3812
Epoch [13/20], Step [258/415], Loss: 0.2321
Epoch [13/20], Step [259/415], Loss: 0.2114
Epoch [13/20], Step [260/415], Loss: 0.3090
Epoch [13/20], Step [261/415], Loss: 0.3328
Epoch [13/20], Step [262/415], Loss: 0.1917
Epoch [13/20], Step [263/415], Loss: 0.2479
Epoch [13/20], Step [264/415], Loss: 0.1051
Epoch [13/20], Step [265/415], Loss: 0.5515
Epoch [13/20], Step [266/415], Loss: 0.1486
Epoch [13/20], Step [267/415], Loss: 0.3921
Epoch [13/20], Step [268/415], L

Epoch [14/20], Step [22/415], Loss: 0.3575
Epoch [14/20], Step [23/415], Loss: 0.1622
Epoch [14/20], Step [24/415], Loss: 0.1701
Epoch [14/20], Step [25/415], Loss: 0.1253
Epoch [14/20], Step [26/415], Loss: 0.3745
Epoch [14/20], Step [27/415], Loss: 0.2980
Epoch [14/20], Step [28/415], Loss: 0.2022
Epoch [14/20], Step [29/415], Loss: 0.3270
Epoch [14/20], Step [30/415], Loss: 0.1710
Epoch [14/20], Step [31/415], Loss: 0.2336
Epoch [14/20], Step [32/415], Loss: 0.5168
Epoch [14/20], Step [33/415], Loss: 0.3695
Epoch [14/20], Step [34/415], Loss: 0.1901
Epoch [14/20], Step [35/415], Loss: 0.1075
Epoch [14/20], Step [36/415], Loss: 0.5151
Epoch [14/20], Step [37/415], Loss: 0.1891
Epoch [14/20], Step [38/415], Loss: 0.4287
Epoch [14/20], Step [39/415], Loss: 0.1032
Epoch [14/20], Step [40/415], Loss: 0.1175
Epoch [14/20], Step [41/415], Loss: 0.1435
Epoch [14/20], Step [42/415], Loss: 0.4176
Epoch [14/20], Step [43/415], Loss: 0.1838
Epoch [14/20], Step [44/415], Loss: 0.3903
Epoch [14/2

Epoch [14/20], Step [212/415], Loss: 0.2574
Epoch [14/20], Step [213/415], Loss: 0.1350
Epoch [14/20], Step [214/415], Loss: 0.2260
Epoch [14/20], Step [215/415], Loss: 0.5436
Epoch [14/20], Step [216/415], Loss: 0.3684
Epoch [14/20], Step [217/415], Loss: 0.2183
Epoch [14/20], Step [218/415], Loss: 0.2032
Epoch [14/20], Step [219/415], Loss: 0.2780
Epoch [14/20], Step [220/415], Loss: 0.2893
Epoch [14/20], Step [221/415], Loss: 0.1742
Epoch [14/20], Step [222/415], Loss: 0.1265
Epoch [14/20], Step [223/415], Loss: 0.2643
Epoch [14/20], Step [224/415], Loss: 0.1291
Epoch [14/20], Step [225/415], Loss: 0.5370
Epoch [14/20], Step [226/415], Loss: 0.3852
Epoch [14/20], Step [227/415], Loss: 0.3355
Epoch [14/20], Step [228/415], Loss: 0.4969
Epoch [14/20], Step [229/415], Loss: 0.4837
Epoch [14/20], Step [230/415], Loss: 0.5378
Epoch [14/20], Step [231/415], Loss: 0.4386
Epoch [14/20], Step [232/415], Loss: 0.2660
Epoch [14/20], Step [233/415], Loss: 0.2566
Epoch [14/20], Step [234/415], L

Epoch [14/20], Step [401/415], Loss: 0.3229
Epoch [14/20], Step [402/415], Loss: 0.3142
Epoch [14/20], Step [403/415], Loss: 0.1010
Epoch [14/20], Step [404/415], Loss: 0.1477
Epoch [14/20], Step [405/415], Loss: 0.2606
Epoch [14/20], Step [406/415], Loss: 0.2032
Epoch [14/20], Step [407/415], Loss: 0.2409
Epoch [14/20], Step [408/415], Loss: 0.1665
Epoch [14/20], Step [409/415], Loss: 0.2472
Epoch [14/20], Step [410/415], Loss: 0.2788
Epoch [14/20], Step [411/415], Loss: 0.3830
Epoch [14/20], Step [412/415], Loss: 0.1663
Epoch [14/20], Step [413/415], Loss: 0.5075
Epoch [14/20], Step [414/415], Loss: 0.3448
Epoch [14/20], Step [415/415], Loss: 0.1959
Epoch [15/20], Step [1/415], Loss: 0.2136
Epoch [15/20], Step [2/415], Loss: 0.3605
Epoch [15/20], Step [3/415], Loss: 0.4642
Epoch [15/20], Step [4/415], Loss: 0.3202
Epoch [15/20], Step [5/415], Loss: 0.1579
Epoch [15/20], Step [6/415], Loss: 0.2308
Epoch [15/20], Step [7/415], Loss: 0.2441
Epoch [15/20], Step [8/415], Loss: 0.3072
Epoc

Epoch [15/20], Step [176/415], Loss: 0.1613
Epoch [15/20], Step [177/415], Loss: 0.1413
Epoch [15/20], Step [178/415], Loss: 0.1444
Epoch [15/20], Step [179/415], Loss: 0.1568
Epoch [15/20], Step [180/415], Loss: 0.2609
Epoch [15/20], Step [181/415], Loss: 0.0889
Epoch [15/20], Step [182/415], Loss: 0.3671
Epoch [15/20], Step [183/415], Loss: 0.4362
Epoch [15/20], Step [184/415], Loss: 0.5002
Epoch [15/20], Step [185/415], Loss: 0.2462
Epoch [15/20], Step [186/415], Loss: 0.3129
Epoch [15/20], Step [187/415], Loss: 0.2582
Epoch [15/20], Step [188/415], Loss: 0.3148
Epoch [15/20], Step [189/415], Loss: 0.6479
Epoch [15/20], Step [190/415], Loss: 0.1352
Epoch [15/20], Step [191/415], Loss: 0.0778
Epoch [15/20], Step [192/415], Loss: 0.1711
Epoch [15/20], Step [193/415], Loss: 0.2486
Epoch [15/20], Step [194/415], Loss: 0.1934
Epoch [15/20], Step [195/415], Loss: 0.0798
Epoch [15/20], Step [196/415], Loss: 0.0712
Epoch [15/20], Step [197/415], Loss: 0.4730
Epoch [15/20], Step [198/415], L

Epoch [15/20], Step [365/415], Loss: 0.1729
Epoch [15/20], Step [366/415], Loss: 0.0895
Epoch [15/20], Step [367/415], Loss: 0.2175
Epoch [15/20], Step [368/415], Loss: 0.2508
Epoch [15/20], Step [369/415], Loss: 0.2860
Epoch [15/20], Step [370/415], Loss: 0.2723
Epoch [15/20], Step [371/415], Loss: 0.3608
Epoch [15/20], Step [372/415], Loss: 0.2299
Epoch [15/20], Step [373/415], Loss: 0.2669
Epoch [15/20], Step [374/415], Loss: 0.2279
Epoch [15/20], Step [375/415], Loss: 0.0775
Epoch [15/20], Step [376/415], Loss: 0.3739
Epoch [15/20], Step [377/415], Loss: 0.3615
Epoch [15/20], Step [378/415], Loss: 0.1873
Epoch [15/20], Step [379/415], Loss: 0.0672
Epoch [15/20], Step [380/415], Loss: 0.2386
Epoch [15/20], Step [381/415], Loss: 0.2663
Epoch [15/20], Step [382/415], Loss: 0.1679
Epoch [15/20], Step [383/415], Loss: 0.2550
Epoch [15/20], Step [384/415], Loss: 0.4148
Epoch [15/20], Step [385/415], Loss: 0.1470
Epoch [15/20], Step [386/415], Loss: 0.4950
Epoch [15/20], Step [387/415], L

Epoch [16/20], Step [141/415], Loss: 0.0889
Epoch [16/20], Step [142/415], Loss: 0.3009
Epoch [16/20], Step [143/415], Loss: 0.2088
Epoch [16/20], Step [144/415], Loss: 0.1961
Epoch [16/20], Step [145/415], Loss: 0.1823
Epoch [16/20], Step [146/415], Loss: 0.2133
Epoch [16/20], Step [147/415], Loss: 0.1332
Epoch [16/20], Step [148/415], Loss: 0.1738
Epoch [16/20], Step [149/415], Loss: 0.2919
Epoch [16/20], Step [150/415], Loss: 0.4602
Epoch [16/20], Step [151/415], Loss: 0.2569
Epoch [16/20], Step [152/415], Loss: 0.1350
Epoch [16/20], Step [153/415], Loss: 0.2008
Epoch [16/20], Step [154/415], Loss: 0.1759
Epoch [16/20], Step [155/415], Loss: 0.1328
Epoch [16/20], Step [156/415], Loss: 0.0812
Epoch [16/20], Step [157/415], Loss: 0.2816
Epoch [16/20], Step [158/415], Loss: 0.2355
Epoch [16/20], Step [159/415], Loss: 0.1320
Epoch [16/20], Step [160/415], Loss: 0.2860
Epoch [16/20], Step [161/415], Loss: 0.2157
Epoch [16/20], Step [162/415], Loss: 0.4524
Epoch [16/20], Step [163/415], L

Epoch [16/20], Step [330/415], Loss: 0.2155
Epoch [16/20], Step [331/415], Loss: 0.1353
Epoch [16/20], Step [332/415], Loss: 0.2110
Epoch [16/20], Step [333/415], Loss: 0.1883
Epoch [16/20], Step [334/415], Loss: 0.1290
Epoch [16/20], Step [335/415], Loss: 0.2688
Epoch [16/20], Step [336/415], Loss: 0.3267
Epoch [16/20], Step [337/415], Loss: 0.4453
Epoch [16/20], Step [338/415], Loss: 0.1468
Epoch [16/20], Step [339/415], Loss: 0.1890
Epoch [16/20], Step [340/415], Loss: 0.4680
Epoch [16/20], Step [341/415], Loss: 0.1584
Epoch [16/20], Step [342/415], Loss: 0.3042
Epoch [16/20], Step [343/415], Loss: 0.3611
Epoch [16/20], Step [344/415], Loss: 0.0952
Epoch [16/20], Step [345/415], Loss: 0.4610
Epoch [16/20], Step [346/415], Loss: 0.4550
Epoch [16/20], Step [347/415], Loss: 0.4953
Epoch [16/20], Step [348/415], Loss: 0.1230
Epoch [16/20], Step [349/415], Loss: 0.3641
Epoch [16/20], Step [350/415], Loss: 0.2150
Epoch [16/20], Step [351/415], Loss: 0.1774
Epoch [16/20], Step [352/415], L

Epoch [17/20], Step [105/415], Loss: 0.1759
Epoch [17/20], Step [106/415], Loss: 0.1980
Epoch [17/20], Step [107/415], Loss: 0.1464
Epoch [17/20], Step [108/415], Loss: 0.1961
Epoch [17/20], Step [109/415], Loss: 0.0666
Epoch [17/20], Step [110/415], Loss: 0.2858
Epoch [17/20], Step [111/415], Loss: 0.2650
Epoch [17/20], Step [112/415], Loss: 0.3164
Epoch [17/20], Step [113/415], Loss: 0.1694
Epoch [17/20], Step [114/415], Loss: 0.2228
Epoch [17/20], Step [115/415], Loss: 0.3381
Epoch [17/20], Step [116/415], Loss: 0.1247
Epoch [17/20], Step [117/415], Loss: 0.2676
Epoch [17/20], Step [118/415], Loss: 0.2948
Epoch [17/20], Step [119/415], Loss: 0.3618
Epoch [17/20], Step [120/415], Loss: 0.0917
Epoch [17/20], Step [121/415], Loss: 0.1379
Epoch [17/20], Step [122/415], Loss: 0.0815
Epoch [17/20], Step [123/415], Loss: 0.2437
Epoch [17/20], Step [124/415], Loss: 0.2926
Epoch [17/20], Step [125/415], Loss: 0.2418
Epoch [17/20], Step [126/415], Loss: 0.2696
Epoch [17/20], Step [127/415], L

Epoch [17/20], Step [294/415], Loss: 0.2292
Epoch [17/20], Step [295/415], Loss: 0.4376
Epoch [17/20], Step [296/415], Loss: 0.1910
Epoch [17/20], Step [297/415], Loss: 0.4733
Epoch [17/20], Step [298/415], Loss: 0.3012
Epoch [17/20], Step [299/415], Loss: 0.1964
Epoch [17/20], Step [300/415], Loss: 0.0719
Epoch [17/20], Step [301/415], Loss: 0.3345
Epoch [17/20], Step [302/415], Loss: 0.2172
Epoch [17/20], Step [303/415], Loss: 0.2211
Epoch [17/20], Step [304/415], Loss: 0.1420
Epoch [17/20], Step [305/415], Loss: 0.0762
Epoch [17/20], Step [306/415], Loss: 0.1675
Epoch [17/20], Step [307/415], Loss: 0.3395
Epoch [17/20], Step [308/415], Loss: 0.1520
Epoch [17/20], Step [309/415], Loss: 0.1720
Epoch [17/20], Step [310/415], Loss: 0.2680
Epoch [17/20], Step [311/415], Loss: 0.2397
Epoch [17/20], Step [312/415], Loss: 0.2057
Epoch [17/20], Step [313/415], Loss: 0.2877
Epoch [17/20], Step [314/415], Loss: 0.3214
Epoch [17/20], Step [315/415], Loss: 0.1068
Epoch [17/20], Step [316/415], L

Epoch [18/20], Step [71/415], Loss: 0.1592
Epoch [18/20], Step [72/415], Loss: 0.2411
Epoch [18/20], Step [73/415], Loss: 0.2487
Epoch [18/20], Step [74/415], Loss: 0.2476
Epoch [18/20], Step [75/415], Loss: 0.1374
Epoch [18/20], Step [76/415], Loss: 0.3109
Epoch [18/20], Step [77/415], Loss: 0.1241
Epoch [18/20], Step [78/415], Loss: 0.2281
Epoch [18/20], Step [79/415], Loss: 0.0771
Epoch [18/20], Step [80/415], Loss: 0.1751
Epoch [18/20], Step [81/415], Loss: 0.1605
Epoch [18/20], Step [82/415], Loss: 0.1419
Epoch [18/20], Step [83/415], Loss: 0.0858
Epoch [18/20], Step [84/415], Loss: 0.1310
Epoch [18/20], Step [85/415], Loss: 0.1448
Epoch [18/20], Step [86/415], Loss: 0.1463
Epoch [18/20], Step [87/415], Loss: 0.1809
Epoch [18/20], Step [88/415], Loss: 0.2489
Epoch [18/20], Step [89/415], Loss: 0.6542
Epoch [18/20], Step [90/415], Loss: 0.1655
Epoch [18/20], Step [91/415], Loss: 0.2656
Epoch [18/20], Step [92/415], Loss: 0.1454
Epoch [18/20], Step [93/415], Loss: 0.1953
Epoch [18/2

Epoch [18/20], Step [260/415], Loss: 0.1222
Epoch [18/20], Step [261/415], Loss: 0.1835
Epoch [18/20], Step [262/415], Loss: 0.1957
Epoch [18/20], Step [263/415], Loss: 0.1789
Epoch [18/20], Step [264/415], Loss: 0.1003
Epoch [18/20], Step [265/415], Loss: 0.3154
Epoch [18/20], Step [266/415], Loss: 0.0917
Epoch [18/20], Step [267/415], Loss: 0.3150
Epoch [18/20], Step [268/415], Loss: 0.1968
Epoch [18/20], Step [269/415], Loss: 0.1846
Epoch [18/20], Step [270/415], Loss: 0.1903
Epoch [18/20], Step [271/415], Loss: 0.4785
Epoch [18/20], Step [272/415], Loss: 0.1613
Epoch [18/20], Step [273/415], Loss: 0.5052
Epoch [18/20], Step [274/415], Loss: 0.1102
Epoch [18/20], Step [275/415], Loss: 0.1077
Epoch [18/20], Step [276/415], Loss: 0.1243
Epoch [18/20], Step [277/415], Loss: 0.2157
Epoch [18/20], Step [278/415], Loss: 0.1735
Epoch [18/20], Step [279/415], Loss: 0.3933
Epoch [18/20], Step [280/415], Loss: 0.6820
Epoch [18/20], Step [281/415], Loss: 0.1682
Epoch [18/20], Step [282/415], L

Epoch [19/20], Step [36/415], Loss: 0.1900
Epoch [19/20], Step [37/415], Loss: 0.2881
Epoch [19/20], Step [38/415], Loss: 0.1500
Epoch [19/20], Step [39/415], Loss: 0.1652
Epoch [19/20], Step [40/415], Loss: 0.1611
Epoch [19/20], Step [41/415], Loss: 0.2302
Epoch [19/20], Step [42/415], Loss: 0.1162
Epoch [19/20], Step [43/415], Loss: 0.0661
Epoch [19/20], Step [44/415], Loss: 0.2703
Epoch [19/20], Step [45/415], Loss: 0.2622
Epoch [19/20], Step [46/415], Loss: 0.3110
Epoch [19/20], Step [47/415], Loss: 0.1409
Epoch [19/20], Step [48/415], Loss: 0.0776
Epoch [19/20], Step [49/415], Loss: 0.3988
Epoch [19/20], Step [50/415], Loss: 0.2578
Epoch [19/20], Step [51/415], Loss: 0.1811
Epoch [19/20], Step [52/415], Loss: 0.2428
Epoch [19/20], Step [53/415], Loss: 0.0656
Epoch [19/20], Step [54/415], Loss: 0.0886
Epoch [19/20], Step [55/415], Loss: 0.1196
Epoch [19/20], Step [56/415], Loss: 0.2878
Epoch [19/20], Step [57/415], Loss: 0.1440
Epoch [19/20], Step [58/415], Loss: 0.1552
Epoch [19/2

Epoch [19/20], Step [226/415], Loss: 0.2571
Epoch [19/20], Step [227/415], Loss: 0.4549
Epoch [19/20], Step [228/415], Loss: 0.2315
Epoch [19/20], Step [229/415], Loss: 0.1833
Epoch [19/20], Step [230/415], Loss: 0.2294
Epoch [19/20], Step [231/415], Loss: 0.2448
Epoch [19/20], Step [232/415], Loss: 0.2237
Epoch [19/20], Step [233/415], Loss: 0.4418
Epoch [19/20], Step [234/415], Loss: 0.4199
Epoch [19/20], Step [235/415], Loss: 0.1527
Epoch [19/20], Step [236/415], Loss: 0.1139
Epoch [19/20], Step [237/415], Loss: 0.1565
Epoch [19/20], Step [238/415], Loss: 0.0593
Epoch [19/20], Step [239/415], Loss: 0.1305
Epoch [19/20], Step [240/415], Loss: 0.0932
Epoch [19/20], Step [241/415], Loss: 0.2873
Epoch [19/20], Step [242/415], Loss: 0.2463
Epoch [19/20], Step [243/415], Loss: 0.1232
Epoch [19/20], Step [244/415], Loss: 0.2741
Epoch [19/20], Step [245/415], Loss: 0.2464
Epoch [19/20], Step [246/415], Loss: 0.0887
Epoch [19/20], Step [247/415], Loss: 0.1566
Epoch [19/20], Step [248/415], L

Epoch [20/20], Step [1/415], Loss: 0.1084
Epoch [20/20], Step [2/415], Loss: 0.2188
Epoch [20/20], Step [3/415], Loss: 0.1888
Epoch [20/20], Step [4/415], Loss: 0.1304
Epoch [20/20], Step [5/415], Loss: 0.3698
Epoch [20/20], Step [6/415], Loss: 0.3004
Epoch [20/20], Step [7/415], Loss: 0.2356
Epoch [20/20], Step [8/415], Loss: 0.4009
Epoch [20/20], Step [9/415], Loss: 0.1831
Epoch [20/20], Step [10/415], Loss: 0.3372
Epoch [20/20], Step [11/415], Loss: 0.2297
Epoch [20/20], Step [12/415], Loss: 0.3112
Epoch [20/20], Step [13/415], Loss: 0.1104
Epoch [20/20], Step [14/415], Loss: 0.1550
Epoch [20/20], Step [15/415], Loss: 0.1588
Epoch [20/20], Step [16/415], Loss: 0.2677
Epoch [20/20], Step [17/415], Loss: 0.0776
Epoch [20/20], Step [18/415], Loss: 0.1527
Epoch [20/20], Step [19/415], Loss: 0.1093
Epoch [20/20], Step [20/415], Loss: 0.2686
Epoch [20/20], Step [21/415], Loss: 0.2470
Epoch [20/20], Step [22/415], Loss: 0.2116
Epoch [20/20], Step [23/415], Loss: 0.0849
Epoch [20/20], Step 

Epoch [20/20], Step [190/415], Loss: 0.1533
Epoch [20/20], Step [191/415], Loss: 0.1015
Epoch [20/20], Step [192/415], Loss: 0.1961
Epoch [20/20], Step [193/415], Loss: 0.4451
Epoch [20/20], Step [194/415], Loss: 0.1312
Epoch [20/20], Step [195/415], Loss: 0.2337
Epoch [20/20], Step [196/415], Loss: 0.1848
Epoch [20/20], Step [197/415], Loss: 0.3569
Epoch [20/20], Step [198/415], Loss: 0.1052
Epoch [20/20], Step [199/415], Loss: 0.0946
Epoch [20/20], Step [200/415], Loss: 0.2940
Epoch [20/20], Step [201/415], Loss: 0.1912
Epoch [20/20], Step [202/415], Loss: 0.2064
Epoch [20/20], Step [203/415], Loss: 0.1959
Epoch [20/20], Step [204/415], Loss: 0.1970
Epoch [20/20], Step [205/415], Loss: 0.2012
Epoch [20/20], Step [206/415], Loss: 0.1611
Epoch [20/20], Step [207/415], Loss: 0.1682
Epoch [20/20], Step [208/415], Loss: 0.1521
Epoch [20/20], Step [209/415], Loss: 0.1948
Epoch [20/20], Step [210/415], Loss: 0.1507
Epoch [20/20], Step [211/415], Loss: 0.2137
Epoch [20/20], Step [212/415], L

Epoch [20/20], Step [379/415], Loss: 0.1057
Epoch [20/20], Step [380/415], Loss: 0.1363
Epoch [20/20], Step [381/415], Loss: 0.1025
Epoch [20/20], Step [382/415], Loss: 0.0792
Epoch [20/20], Step [383/415], Loss: 0.2120
Epoch [20/20], Step [384/415], Loss: 0.1594
Epoch [20/20], Step [385/415], Loss: 0.1930
Epoch [20/20], Step [386/415], Loss: 0.2036
Epoch [20/20], Step [387/415], Loss: 0.1401
Epoch [20/20], Step [388/415], Loss: 0.1283
Epoch [20/20], Step [389/415], Loss: 0.1646
Epoch [20/20], Step [390/415], Loss: 0.1530
Epoch [20/20], Step [391/415], Loss: 0.2948
Epoch [20/20], Step [392/415], Loss: 0.3176
Epoch [20/20], Step [393/415], Loss: 0.2342
Epoch [20/20], Step [394/415], Loss: 0.1028
Epoch [20/20], Step [395/415], Loss: 0.1768
Epoch [20/20], Step [396/415], Loss: 0.1721
Epoch [20/20], Step [397/415], Loss: 0.1305
Epoch [20/20], Step [398/415], Loss: 0.1859
Epoch [20/20], Step [399/415], Loss: 0.0896
Epoch [20/20], Step [400/415], Loss: 0.1726
Epoch [20/20], Step [401/415], L

In [38]:
best_epoch

16

In [39]:
model.eval()
# eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
total_preds = np.array([])
total_labels = np.array([])
with torch.no_grad():
    correct = 0
    total = 0
    for i in test_loader:
        images = i['outer_product']
        labels = i['Label']
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images) 
        outputs = outputs.cpu().detach().numpy().flatten()
        labels =labels.cpu().detach().numpy().flatten()
        total_preds = np.concatenate([total_preds, outputs])
        total_labels = np.concatenate([total_labels, labels])
#         total_preds = torch.cat(total_preds, outputs.cpu(), 0 )
#         total_labels = torch.cat(total_labels, labels.cpu(), 0)
#         break

In [40]:
G,P = total_labels, total_preds

In [41]:
rmse(G,P)

1.1205611430600004

In [42]:
mse(G,P)

1.2556572753359345

In [43]:
pearson(G,P)

-0.1970601399702593

In [44]:
ci(G,P)

0.38590745774380225

In [45]:
model = ConvNet().to(device)

In [47]:
model.load_state_dict(torch.load('./best_sim-CNN-DTA_davis_fold10.model'))

<All keys matched successfully>

In [48]:
model.eval()


ConvNet(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 18, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=9720, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=1, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [49]:
total_preds = np.array([])
total_labels = np.array([])
with torch.no_grad():
    correct = 0
    total = 0
    for i in test_loader:
        images = i['outer_product']
        labels = i['Label']
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images) 
        outputs = outputs.cpu().detach().numpy().flatten()
        labels =labels.cpu().detach().numpy().flatten()
        total_preds = np.concatenate([total_preds, outputs])
        total_labels = np.concatenate([total_labels, labels])
#         total_preds = torch.cat(total_preds, outputs.cpu(), 0 )
#         total_labels = torch.cat(total_labels, labels.cpu(), 0)
#         break

In [50]:
G,P = total_labels, total_preds

In [51]:
rmse(G,P)

1.1001544055398707

In [52]:
print(pearson(G,P),ci(G,P),rmse(G,P),mse(G,P))

0.12577732589737653 0.5580237318437462 1.1001544055398707 1.210339716028786
