In [1]:
from My_Pka_Model import Pka_basic,Pka_acidic
import torch

import dgl
import dgllife
from torch.utils.data import DataLoader
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
import torch.optim as optim
import numpy as np
import random 
import pandas as pd

from torch.nn.utils import clip_grad_norm

Using backend: pytorch


In [2]:
# Shell escape: print the NVIDIA driver/CUDA versions and per-GPU memory use and
# utilisation, so a free device can be chosen before the model is moved to it.
!nvidia-smi

Sat Nov 21 11:16:25 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:06:00.0 Off |                    0 |
| N/A   37C    P0    34W / 250W |   5037MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE...  Off  | 00000000:2F:00.0 Off |                    0 |
| N/A   36C    P0    39W / 250W |     12MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-PCIE...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   

In [3]:
# Pick the compute device. GPU index 1 is preferred, but fall back to GPU 0 on
# single-GPU hosts (where "cuda:1" would be an invalid device ordinal) and to
# the CPU when CUDA is not available at all.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Seed every random number generator this notebook relies on (Python, NumPy,
# PyTorch CPU and all CUDA devices) so that repeated runs start from the same state.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into one mini-batch.

    Used as the ``collate_fn`` of the ``DataLoader`` objects built in this
    notebook.

    Args:
        samples: sequence of ``(graph, label)`` tuples.

    Returns:
        A 2-tuple ``(batched_graph, labels)``: the graphs merged with
        ``dgl.batch`` and the labels packed into a single tensor.
    """
    graph_seq, label_seq = zip(*samples)
    merged_graph = dgl.batch(list(graph_seq))
    label_tensor = torch.tensor(list(label_seq))
    return merged_graph, label_tensor

# Featurizers shared by every call to load_data(). Atom and bond features are
# both stored under the key 'h', which is how the training cells read them
# back (bg.ndata['h'] / bg.edata['h']).
# NOTE(review): the model is built with node_feat_size=74 and edge_feat_size=12,
# presumably the output widths of these canonical featurizers - confirm.
node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    """Read a tab-separated SMILES/label file and wrap it in DataLoader(s).

    Each non-blank line of the file must contain at least two tab-separated
    fields: a SMILES string (column 0) and a numeric label (column 1). Every
    molecule is converted to a DGL bigraph with the module-level
    ``node_featurizer`` / ``edge_featurizer``.

    Args:
        file_name: path of the text file to read.
        batch_size: batch size used for every returned DataLoader.
        shuffle: passed through to every returned DataLoader.
        split_ratio: ``False`` (default) or ``None`` to keep the data in one
            piece, or an iterable of fractions of the full dataset. With
            fractions, the data is shuffled once with ``random.shuffle`` and
            cut into ``len(split_ratio) + 1`` consecutive parts; the last part
            holds whatever is left after the listed fractions.

    Returns:
        A single ``DataLoader`` when no split is requested, otherwise a list
        of ``DataLoader`` objects, one per part.

    Raises:
        ValueError: if a line has fewer than two fields, its label is not a
            number, or its SMILES string cannot be turned into a graph. The
            message names the file and line number.
    """
    dataset = []
    with open(file_name) as f:
        # Iterate lazily rather than loading the whole file with readlines().
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    '{}:{}: expected "<SMILES>\\t<label>", got {!r}'.format(
                        file_name, line_no, line))
            try:
                label = float(fields[1])
            except ValueError:
                raise ValueError(
                    '{}:{}: label {!r} is not a number'.format(
                        file_name, line_no, fields[1]))
            g = smiles_to_bigraph(smiles=fields[0],
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  canonical_atom_order=False)
            if g is None:
                # Fail here with a useful location instead of later, inside
                # dgl.batch(), when the None reaches collate().
                raise ValueError(
                    '{}:{}: could not build a graph from SMILES {!r}'.format(
                        file_name, line_no, fields[0]))
            dataset.append((g, label))

    # "No split" covers the default False plus None/0. Lists and tuples always
    # mean "split", even when empty, matching the historical behaviour.
    no_split = not isinstance(split_ratio, (list, tuple)) and not split_ratio
    if no_split:
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)

    random.shuffle(dataset)
    length = len(dataset)
    dataloader_list = []
    remaining = dataset
    for ratio in split_ratio:
        # Every fraction is relative to the full dataset size, not the remainder.
        num = round(length * ratio)
        dataset_part = remaining[:num]
        remaining = remaining[num:]
        dataloader_list.append(
            DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate))
    # Whatever is left after the listed fractions becomes the final part.
    dataloader_list.append(
        DataLoader(remaining, batch_size=batch_size, shuffle=shuffle,collate_fn=collate))
    return dataloader_list

In [5]:
# Training hyperparameters used by the cells below.
batch_size = 1024       # batch size of the training DataLoader
epoch_num = 1000        # number of epochs run by the first training loop
layer_num = 6           # passed to the model as num_layers
learning_rate = 0.0003  # initial Adam learning rate
weight_decay = 0.0003   # Adam weight_decay

In [6]:
# Build the acidic-pKa model and move it to the selected device.
# NOTE(review): node_feat_size=74 / edge_feat_size=12 presumably match the widths
# produced by the canonical atom/bond featurizers - confirm against My_Pka_Model.
# output_size=1: one predicted value per molecule.
model = Pka_acidic(node_feat_size = 74,
                   edge_feat_size = 12,
                   output_size = 1,
                   num_layers= layer_num,
                   graph_feat_size=200,
                   dropout=0.2).to(device)

In [7]:
# Mean-squared-error training loss, and an Adam optimizer over all model
# parameters using the learning rate and weight decay configured above.
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [8]:
# Both calls omit split_ratio, so each returns one DataLoader (shuffled, since
# load_data defaults to shuffle=True). The model is trained on the first file
# and evaluated after every epoch on the second, separate file.
# NOTE(review): judging by the file names, training uses acids without boron and
# validation uses phenylboronic acids - confirm against the dataset description.
train_loader = load_data('./Dataset/acidic_non_B_smiles.txt',batch_size = batch_size)
val_loader = load_data('./Dataset/Phenylboronic_acid_smiles.txt',batch_size = 512)


In [9]:
# Sanity check: number of molecules in the training and validation sets.
print(len(train_loader.dataset))
print(len(val_loader.dataset))

9015
24


In [10]:
# Per-epoch metric histories. The test_* lists, cur_rmse_lis and min_val_rmse are
# not filled by this cell; they are kept so that other cells can still rely on
# them being defined.
train_RMSE_lis = []
train_MAE_lis = []
val_RMSE_lis = []
val_MAE_lis = []
test_RMSE_lis = []
test_MAE_lis = []
test_2_RMSE_lis = []
test_2_MAE_lis = []
cur_rmse_lis = []
min_val_rmse = 100

# Log file used only by the (currently disabled) logging block at the end of the loop.
file_name = './Logger/try_B.txt'

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:\tval_RMSD:\tval_MAE:'
print(header)


for epoch in range(epoch_num):
    # ---- optimisation pass over the training set ----
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    # ---- evaluation pass: RMSE and MAE on the training and validation sets ----
    model.eval()
    epoch_metrics = []
    with torch.no_grad():
        for loader in (train_loader, val_loader):
            sse = 0.0  # running sum of squared errors
            sae = 0.0  # running sum of absolute errors
            for bg, label in loader:
                bg = bg.to(device)
                label = label.to(device)
                # Flatten both sides to 1-D so the subtraction is element-wise
                # even when the last batch contains a single molecule.
                prediction = model(bg,bg.ndata['h'], bg.edata['h']).reshape(-1)
                error = prediction - label.reshape(-1)
                # Tensor reductions; Python's builtin sum() would iterate the
                # tensor element by element.
                sse += (error ** 2).sum().item()
                sae += error.abs().sum().item()
            n = len(loader.dataset)
            epoch_metrics.append(((sse / n) ** 0.5, sae / n))
    (train_RMSE, train_MAE), (val_RMSE, val_MAE) = epoch_metrics

    log = '{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)

    # Disabled: checkpoint on best validation RMSE, and append the log line to file_name.
#     # save model
#     if val_RMSE <= min_val_rmse:
#         min_val_rmse = val_RMSE
#         if epoch >= 300:
#             torch.save(model.state_dict(), './Trained_model/ramdom_split_acidic_4.pkl')
#             print('saved')

#     with open(file_name,'a+') as f:
#         f.write(log)
#         f.write('\n')

epoch:	train_RMSD:	train_MAE:	val_RMSD:	val_MAE:




0	6.9881	6.0815	7.6751	7.5718
1	6.4058	5.4114	7.0194	6.9065
2	4.6211	3.3636	4.75	4.5814
3	4.0864	2.9207	3.8875	3.6795
4	3.873	2.8176	3.4792	3.248
5	3.7219	2.7827	3.1394	2.9284
6	3.6076	2.7797	2.8243	2.6388
7	3.5274	2.7934	2.5329	2.3637
8	3.4821	2.8149	2.2905	2.127
9	3.4634	2.8355	2.122	1.9562
10	3.4579	2.8509	2.0177	1.847
11	3.4569	2.8607	1.9589	1.7838
12	3.4571	2.8666	1.9258	1.7493
13	3.4573	2.8688	1.9139	1.7376
14	3.4572	2.8679	1.9189	1.7425
15	3.4572	2.8678	1.9192	1.7428
16	3.457	2.8652	1.9334	1.7568
17	3.4569	2.8646	1.9371	1.7605
18	3.4569	2.8651	1.9342	1.7576
19	3.4569	2.865	1.9347	1.7582
20	3.4569	2.8645	1.9371	1.7605
21	3.4568	2.8621	1.9511	1.7753
22	3.4566	2.8633	1.9435	1.7671
23	3.4565	2.8644	1.9368	1.7603
24	3.4563	2.8641	1.9378	1.7613
25	3.4559	2.8619	1.9493	1.7734
26	3.4557	2.8625	1.9446	1.7684
27	3.4554	2.8637	1.9368	1.7603
28	3.4549	2.8636	1.9357	1.7592
29	3.4542	2.8617	1.9433	1.7671
30	3.453	2.8594	1.9521	1.7768
31	3.451	2.8571	1.9564	1.7816
32	3.4475	2.8511	1.976	1.803


263	1.1232	0.7327	1.1621	1.0184
264	1.0724	0.6788	0.7058	0.6005
265	1.0713	0.6698	0.7232	0.5205
266	1.0643	0.6664	0.708	0.4848
267	1.0637	0.6718	0.67	0.5062
268	1.0492	0.6547	0.6886	0.4857
269	1.052	0.6574	0.7367	0.5304
270	1.0468	0.6542	0.7553	0.5408
271	1.049	0.6638	0.679	0.4614
272	1.0529	0.6641	0.6939	0.5881
273	1.0442	0.6564	0.6829	0.5788
274	1.0475	0.6581	0.6565	0.4609
275	1.0624	0.6765	0.8582	0.7716
276	1.0436	0.6587	0.7609	0.6611
277	1.0563	0.6771	0.6842	0.5804
278	1.0403	0.6579	0.6695	0.4703
279	1.0314	0.6499	0.6702	0.4726
280	1.0619	0.6826	0.6686	0.562
281	1.0484	0.6698	0.8493	0.6671
282	1.0906	0.7135	1.0614	0.9029
283	1.0366	0.6588	0.739	0.5295
284	1.0215	0.6422	0.6513	0.4611
285	1.0327	0.6616	0.6442	0.4537
286	1.0214	0.6452	0.6513	0.4497
287	1.0181	0.6413	0.6785	0.4856
288	1.0304	0.653	0.7193	0.5218
289	1.0191	0.6404	0.6928	0.5232
290	1.0128	0.6424	0.6972	0.5083
291	1.0201	0.6526	0.6792	0.4651
292	1.0658	0.7019	0.9729	0.8233
293	1.0073	0.6373	0.624	0.4499
294	1.0114	0.6418	

523	0.7404	0.5065	0.7142	0.4941
524	0.7381	0.5019	0.694	0.5387
525	0.7629	0.5309	0.8172	0.6986
526	0.7657	0.5324	0.8476	0.7075
527	0.7428	0.5083	0.7034	0.5656
528	0.744	0.5093	0.7367	0.5586
529	0.7423	0.5104	0.7098	0.4887
530	0.7369	0.5026	0.7658	0.5835
531	0.7391	0.507	0.6932	0.5382
532	0.7421	0.5145	0.6935	0.5581
533	0.7426	0.5123	0.7945	0.7278
534	0.7399	0.5072	0.7333	0.5653
535	0.7411	0.508	0.7072	0.4876
536	0.7431	0.5121	0.683	0.5037
537	0.7549	0.5282	0.71	0.4952
538	0.7513	0.5224	0.6774	0.4559
539	0.735	0.5027	0.6751	0.5206
540	0.7606	0.5276	0.7949	0.6292
541	0.7531	0.5238	0.7224	0.6003
542	0.7344	0.5065	0.6985	0.4865
543	0.7438	0.5139	0.6661	0.5267
544	0.8347	0.6072	0.7812	0.7048
545	0.7878	0.5547	0.7917	0.7059
546	0.7422	0.5173	0.7618	0.594
547	0.7377	0.5054	0.7892	0.6658
548	0.7306	0.5044	0.7441	0.5759
549	0.7412	0.5139	0.6619	0.4735
550	0.7288	0.5007	0.6588	0.4865
551	0.7457	0.5138	0.6094	0.4689
552	0.7296	0.5001	0.6947	0.486
553	0.7231	0.497	0.6658	0.4953
554	0.7382	0.5081	0

783	0.6802	0.4836	0.8846	0.6798
784	0.6633	0.4697	0.8502	0.6846
785	0.6327	0.4393	0.7721	0.6076
786	0.6297	0.4376	0.6995	0.5574
787	0.6236	0.4338	0.7901	0.5962
788	0.6369	0.4463	0.8615	0.6587
789	0.6233	0.4318	0.7916	0.6321
790	0.6298	0.4366	0.7216	0.5765
791	0.6259	0.4344	0.7198	0.579
792	0.6436	0.4563	0.6803	0.5595
793	0.6372	0.447	0.6615	0.5224
794	0.6313	0.4424	0.6437	0.507
795	0.6202	0.4287	0.724	0.5881
796	0.6319	0.442	0.6505	0.5321
797	0.626	0.435	0.6882	0.5533
798	0.6891	0.4995	0.7544	0.6287
799	0.6642	0.478	0.7445	0.6119
800	0.6204	0.4302	0.6911	0.5791
801	0.6211	0.433	0.6802	0.5529
802	0.6265	0.4374	0.7065	0.5762
803	0.6195	0.433	0.7002	0.5517
804	0.6189	0.4273	0.6062	0.4725
805	0.6191	0.4281	0.7214	0.5853
806	0.6302	0.4422	0.5972	0.4731
807	0.6288	0.4401	0.6199	0.4908
808	0.6164	0.4285	0.6139	0.4722
809	0.6305	0.4435	0.6638	0.5217
810	0.6378	0.4455	0.6181	0.4601
811	0.6207	0.4325	0.7168	0.5483
812	0.6325	0.4407	0.623	0.4733
813	0.624	0.4333	0.6953	0.5729
814	0.6148	0.4251	0.

In [11]:
# Continue training with a 10x smaller learning rate. This builds a brand-new
# Adam optimizer, so the state accumulated by the previous optimizer object is
# not carried over.
optimizer = optim.Adam(model.parameters(), lr=learning_rate/10,weight_decay=weight_decay)

In [12]:
# Fine-tuning stage: 500 more epochs with the reduced learning rate. Epoch
# numbering continues where the first training loop (epoch_num epochs) stopped.
for epoch in range(epoch_num, epoch_num + 500):
    # ---- optimisation pass over the training set ----
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    # ---- evaluation pass: RMSE and MAE on the training and validation sets ----
    model.eval()
    epoch_metrics = []
    with torch.no_grad():
        for loader in (train_loader, val_loader):
            sse = 0.0  # running sum of squared errors
            sae = 0.0  # running sum of absolute errors
            for bg, label in loader:
                bg = bg.to(device)
                label = label.to(device)
                # Flatten both sides to 1-D so the subtraction is element-wise
                # even when the last batch contains a single molecule.
                prediction = model(bg,bg.ndata['h'], bg.edata['h']).reshape(-1)
                error = prediction - label.reshape(-1)
                # Tensor reductions; Python's builtin sum() would iterate the
                # tensor element by element.
                sse += (error ** 2).sum().item()
                sae += error.abs().sum().item()
            n = len(loader.dataset)
            epoch_metrics.append(((sse / n) ** 0.5, sae / n))
    (train_RMSE, train_MAE), (val_RMSE, val_MAE) = epoch_metrics

    log = '{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)

  # Remove the CWD from sys.path while we load stuff.


1000	0.5558	0.3827	0.7329	0.5799
1001	0.5519	0.379	0.671	0.5207
1002	0.5522	0.3779	0.7043	0.5524
1003	0.5508	0.3767	0.6715	0.5178
1004	0.55	0.3762	0.7127	0.5416
1005	0.5501	0.3754	0.687	0.5311
1006	0.5498	0.3755	0.6788	0.5192
1007	0.5492	0.3751	0.6425	0.5065
1008	0.5494	0.3761	0.7122	0.562
1009	0.5478	0.3738	0.6673	0.517
1010	0.5478	0.374	0.702	0.5458
1011	0.5466	0.3728	0.6654	0.5249
1012	0.5473	0.3744	0.6932	0.539
1013	0.5475	0.3743	0.687	0.5332
1014	0.5548	0.3807	0.6493	0.5086
1015	0.5475	0.3731	0.6657	0.5231
1016	0.5468	0.3738	0.6945	0.5454
1017	0.5469	0.3737	0.7252	0.5794
1018	0.5453	0.3721	0.7421	0.5839
1019	0.5464	0.373	0.7004	0.5481
1020	0.548	0.3743	0.6695	0.5195
1021	0.5449	0.3717	0.711	0.555
1022	0.5477	0.3746	0.7084	0.5644
1023	0.5463	0.3735	0.7069	0.5631
1024	0.547	0.3753	0.7082	0.5613
1025	0.5445	0.3711	0.6716	0.5188
1026	0.5453	0.372	0.6923	0.5439
1027	0.5475	0.3737	0.6453	0.5081
1028	0.5436	0.3705	0.7098	0.5539
1029	0.5443	0.3714	0.6843	0.5344
1030	0.5441	0.3712	0.6975	0

1252	0.5306	0.3622	0.6831	0.5418
1253	0.5301	0.3616	0.6865	0.548
1254	0.5319	0.3642	0.7205	0.5832
1255	0.5305	0.362	0.6447	0.5229
1256	0.5319	0.3628	0.6566	0.5316
1257	0.5328	0.3638	0.6536	0.5295
1258	0.5337	0.3645	0.6361	0.5171
1259	0.5314	0.3623	0.6534	0.5293
1260	0.5321	0.3628	0.6308	0.5171
1261	0.5302	0.3613	0.658	0.528
1262	0.532	0.3647	0.6963	0.5606
1263	0.5312	0.362	0.6571	0.5305
1264	0.5321	0.3631	0.6428	0.5245
1265	0.5304	0.3614	0.6834	0.5495
1266	0.5308	0.3617	0.6638	0.5376
1267	0.5297	0.3607	0.6614	0.5348
1268	0.5305	0.3617	0.6987	0.5604
1269	0.5299	0.3617	0.6976	0.563
1270	0.5302	0.3613	0.6535	0.5318
1271	0.5296	0.3608	0.6441	0.5269
1272	0.53	0.3619	0.6733	0.5418
1273	0.5301	0.3618	0.661	0.5348
1274	0.5303	0.3624	0.7109	0.5768
1275	0.5346	0.3679	0.724	0.5926
1276	0.533	0.3669	0.7311	0.6007
1277	0.5301	0.3616	0.7089	0.576
1278	0.5299	0.3617	0.6987	0.5694
1279	0.5289	0.3609	0.7022	0.571
1280	0.5299	0.3615	0.6901	0.5629
1281	0.5301	0.3618	0.7493	0.6163
1282	0.5286	0.3606	0.705

In [13]:
# Persist the weights as they stand after the last epoch (state_dict only, not
# the whole model object). The ./Trained_model directory must already exist.
torch.save(model.state_dict(), './Trained_model/non_B_try_5.pkl')

In [None]:
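# Does the model usually only start learning smalp6 late in training?

# This run is very stable, but the values seem somewhat high?
# Can the error not be brought down to 0.5?


# This one seems reasonably stable now.


# Check the model's generalization ability first, then decide whether to proactively do clustering.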

In [None]:
# Tests that appear unstable are best avoided where possible.