In [1]:
from My_Pka_Model import Pka_basic,Pka_acidic
import torch

import dgl
import dgllife
from torch.utils.data import DataLoader
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
import torch.optim as optim
import numpy as np
import random 
import pandas as pd

from torch.nn.utils import clip_grad_norm

Using backend: pytorch


In [2]:
!nvidia-smi

Sun Nov 22 04:26:36 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:06:00.0 Off |                    0 |
| N/A   37C    P0    34W / 250W |   3027MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE...  Off  | 00000000:2F:00.0 Off |                    0 |
| N/A   38C    P0    39W / 250W |   4125MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-PCIE...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   

In [3]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

seed = 0
random.seed(seed) 
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  
np.random.seed(seed)  
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    dataset = []
    with open(file_name) as f:
        for line in f.readlines():
            line = line.replace('\n','').split('\t')
            g = smiles_to_bigraph(smiles=line[0], 
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order= False)
            dataset.append((g,float(line[1])))
            
    if split_ratio == False:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        return dataloader
    else:
        random.shuffle(dataset)
        length = len(dataset)
        dataloader_list = []
        for i in split_ratio:
            num = round(length * i)
            dataset_part = dataset[:num]
            dataset = dataset[num:]
            dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
            dataloader_list.append(dataloader)
        dataset_part = dataset
        dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        dataloader_list.append(dataloader)
        return dataloader_list

In [5]:
batch_size = 1024
epoch_num = 1000
layer_num = 7
learning_rate = 0.0003
weight_decay = 0.0003

In [6]:
model = Pka_acidic(node_feat_size = 74,
                   edge_feat_size = 12,
                   output_size = 1,
                   num_layers= layer_num,
                   graph_feat_size=200,
                   dropout=0.2).to(device)

In [7]:
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [8]:
train_loader = load_data('./Dataset/acidic_non_B_smiles_3.txt',batch_size = batch_size)
val_loader = load_data('./Dataset/Phenylboronic_acid_smiles(22).txt',batch_size = 512)


In [9]:
print(len(train_loader.dataset))
print(len(val_loader.dataset))

9013
22


In [10]:
train_RMSE_lis = []
train_MAE_lis = []
val_RMSE_lis = []
val_MAE_lis = []
test_RMSE_lis = []
test_MAE_lis = []
test_2_RMSE_lis = []
test_2_MAE_lis = []
cur_rmse_lis = []
min_val_rmse = 100

file_name = './Logger/try_B.txt'

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:\tval_RMSD:\tval_MAE:'
print(header)


for epoch in range(epoch_num):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N
        #print(N,train_RMSD)

        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N

        
        
        log = '{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4))
        print(log)
        
        train_RMSE_lis.append(train_RMSE)
        train_MAE_lis.append(train_MAE)
        val_RMSE_lis.append(val_RMSE)
        val_MAE_lis.append(val_MAE)
        

            
            
            

#         #save model
#         if val_RMSE <= min_val_rmse:
#             min_val_rmse = val_RMSE
#             if epoch >= 300:
#                 torch.save(model.state_dict(), './Trained_model/ramdom_split_acidic_4.pkl')
#                 print('saved')

#         with open(file_name,'a+') as f:
#             f.write(log)
#             f.write('\n')

epoch:	train_RMSD:	train_MAE:	val_RMSD:	val_MAE:




0	7.0502	6.1522	7.7404	7.6334
1	6.2966	5.2832	6.8837	6.7632
2	4.2628	3.04	4.1922	3.9913
3	3.8003	2.7962	3.326	3.0875
4	3.621	2.7786	2.8691	2.6766
5	3.5254	2.7939	2.5293	2.3603
6	3.4737	2.8221	2.2342	2.0735
7	3.4588	2.8472	2.0507	1.8859
8	3.4571	2.8637	1.9528	1.7815
9	3.4575	2.8696	1.9207	1.7503
10	3.4579	2.8724	1.9062	1.7361
11	3.4573	2.868	1.9294	1.7587
12	3.4571	2.8654	1.9433	1.7722
13	3.4572	2.867	1.9348	1.764
14	3.4571	2.8642	1.9501	1.7788
15	3.4571	2.8645	1.9482	1.777
16	3.457	2.8636	1.953	1.7817
17	3.457	2.8638	1.9519	1.7805
18	3.4571	2.8641	1.9506	1.7793
19	3.4571	2.8653	1.9437	1.7727
20	3.457	2.8646	1.9476	1.7764
21	3.457	2.8641	1.9499	1.7787
22	3.4567	2.8628	1.9569	1.786
23	3.4563	2.8616	1.9626	1.7922
24	3.4558	2.8616	1.9605	1.7899
25	3.455	2.8602	1.9657	1.7957
26	3.4541	2.8608	1.9585	1.7879
27	3.4526	2.8589	1.9624	1.7922
28	3.4502	2.8572	1.961	1.7908
29	3.4465	2.8551	1.9544	1.7837
30	3.441	2.8448	1.985	1.8169
31	3.4332	2.8394	1.9762	1.8076
32	3.4223	2.8273	1.9892	1.8219
33	3.

264	0.9958	0.6356	0.5648	0.4256
265	0.9922	0.636	0.5944	0.4501
266	0.9982	0.641	0.716	0.5939
267	0.9865	0.6286	0.4998	0.3793
268	0.9897	0.638	0.5201	0.3737
269	0.9927	0.638	0.6893	0.555
270	0.9895	0.6426	0.5351	0.4168
271	1.0466	0.6921	0.7801	0.6517
272	1.0656	0.7075	0.8258	0.7118
273	1.0276	0.6715	0.8419	0.7219
274	0.9999	0.6425	0.5793	0.4458
275	0.9777	0.6281	0.5257	0.4057
276	0.979	0.6282	0.541	0.385
277	0.9843	0.6382	0.5942	0.494
278	0.9901	0.6477	0.5207	0.4037
279	0.9896	0.6527	0.5085	0.3976
280	0.999	0.6608	0.6073	0.5088
281	0.9669	0.6242	0.5052	0.3838
282	0.9623	0.618	0.5231	0.3867
283	0.9724	0.6357	0.5231	0.4147
284	0.9859	0.6567	0.5439	0.4184
285	1.0102	0.668	0.7396	0.6225
286	0.9627	0.6233	0.5315	0.4256
287	0.953	0.6149	0.5115	0.3652
288	0.9637	0.6299	0.5102	0.39
289	0.9661	0.6281	0.5695	0.4792
290	0.9604	0.6198	0.5516	0.3931
291	0.9563	0.6155	0.6396	0.4951
292	0.9579	0.6257	0.494	0.3703
293	0.9424	0.6105	0.5485	0.4074
294	0.9593	0.6274	0.5374	0.4448
295	0.9407	0.6071	0.5232	

524	0.7016	0.4866	0.4618	0.3641
525	0.7027	0.4841	0.4901	0.3918
526	0.7059	0.4929	0.5399	0.4559
527	0.7131	0.4937	0.4696	0.3474
528	0.7014	0.4865	0.4377	0.3394
529	0.7752	0.5472	0.679	0.5661
530	0.6979	0.4837	0.4612	0.3686
531	0.6957	0.4813	0.4482	0.3302
532	0.7117	0.5031	0.4553	0.3464
533	0.7069	0.4935	0.4706	0.346
534	0.6965	0.4854	0.4765	0.3705
535	0.6957	0.4804	0.4072	0.3059
536	0.7198	0.5071	0.6977	0.6102
537	0.7015	0.4891	0.5267	0.4326
538	0.7158	0.5045	0.5117	0.4086
539	0.6925	0.4793	0.4689	0.3503
540	0.6894	0.4783	0.5	0.3719
541	0.6947	0.4798	0.4748	0.3424
542	0.6913	0.482	0.462	0.3397
543	0.6911	0.4801	0.4729	0.3695
544	0.709	0.5002	0.4952	0.3629
545	0.7042	0.4999	0.5177	0.4195
546	0.7221	0.5095	0.6938	0.6094
547	0.6961	0.4835	0.4907	0.369
548	0.6925	0.4824	0.4625	0.3643
549	0.7157	0.5032	0.5423	0.4629
550	0.7386	0.53	0.4608	0.3708
551	0.6821	0.4736	0.4755	0.3775
552	0.7104	0.4945	0.5148	0.3755
553	0.6861	0.4756	0.4586	0.364
554	0.6892	0.4808	0.4379	0.3102
555	0.6849	0.478	0.4

784	0.6013	0.4249	0.5352	0.3862
785	0.5997	0.4207	0.4954	0.3638
786	0.5943	0.415	0.5379	0.4353
787	0.6039	0.4192	0.4954	0.3524
788	0.6096	0.4267	0.4972	0.3302
789	0.6004	0.4174	0.5421	0.445
790	0.5987	0.4192	0.522	0.4139
791	0.614	0.4318	0.7356	0.6207
792	0.5987	0.4163	0.4827	0.3365
793	0.5947	0.4138	0.4941	0.3439
794	0.6019	0.4204	0.5348	0.3905
795	0.5911	0.4106	0.5513	0.4399
796	0.6074	0.4265	0.4958	0.3435
797	0.6132	0.4316	0.4889	0.4033
798	0.6417	0.4603	0.5732	0.4952
799	0.5961	0.4157	0.5067	0.3913
800	0.5957	0.4166	0.6629	0.5714
801	0.5991	0.4187	0.5089	0.3567
802	0.5918	0.4129	0.5037	0.4169
803	0.5977	0.418	0.4405	0.3284
804	0.5936	0.4113	0.4712	0.315
805	0.592	0.4103	0.4692	0.3549
806	0.5978	0.4161	0.4979	0.3943
807	0.5881	0.4098	0.4478	0.29
808	0.597	0.4156	0.4659	0.3373
809	0.6004	0.4205	0.5033	0.3452
810	0.6117	0.4335	0.4904	0.3901
811	0.5893	0.412	0.4305	0.2769
812	0.5984	0.4139	0.6345	0.513
813	0.592	0.4135	0.4894	0.3877
814	0.5879	0.4104	0.4822	0.3505
815	0.5959	0.416	0.49

In [11]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate/10,weight_decay=weight_decay)

In [12]:
for epoch in range(1000,1500):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N
        #print(N,train_RMSD)

        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N

        
        
        log = '{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4))
        print(log)
        
        train_RMSE_lis.append(train_RMSE)
        train_MAE_lis.append(train_MAE)
        val_RMSE_lis.append(val_RMSE)
        val_MAE_lis.append(val_MAE)

  # Remove the CWD from sys.path while we load stuff.


1000	0.5383	0.372	0.4417	0.3052
1001	0.536	0.3687	0.4359	0.3038
1002	0.5325	0.3664	0.4286	0.2955
1003	0.5319	0.3664	0.4454	0.3044
1004	0.5338	0.37	0.4585	0.3185
1005	0.5299	0.3636	0.4485	0.3063
1006	0.5314	0.3648	0.4397	0.3073
1007	0.5289	0.363	0.4355	0.295
1008	0.5282	0.3621	0.4346	0.3082
1009	0.5282	0.3615	0.4381	0.3015
1010	0.5296	0.3645	0.4536	0.3195
1011	0.529	0.3629	0.4474	0.3171
1012	0.5269	0.3609	0.4448	0.3113
1013	0.5265	0.3617	0.448	0.3211
1014	0.5274	0.362	0.4468	0.3072
1015	0.5293	0.3632	0.4535	0.3087
1016	0.5272	0.3621	0.4542	0.3106
1017	0.5266	0.3618	0.4467	0.3039
1018	0.5274	0.3624	0.4443	0.3048
1019	0.5279	0.363	0.4539	0.3112
1020	0.526	0.3596	0.4474	0.3128
1021	0.5259	0.3606	0.4451	0.3204
1022	0.5269	0.3622	0.4607	0.319
1023	0.5244	0.3589	0.4535	0.3169
1024	0.5243	0.3594	0.4493	0.3198
1025	0.5252	0.3608	0.4565	0.3275
1026	0.5258	0.3613	0.4602	0.3231
1027	0.5267	0.3614	0.4568	0.3289
1028	0.5262	0.3609	0.4613	0.3223
1029	0.5248	0.3597	0.4507	0.3113
1030	0.5233	0.3578	0.4

1252	0.5112	0.3507	0.4668	0.3199
1253	0.5139	0.3517	0.493	0.3592
1254	0.5125	0.3505	0.4549	0.3198
1255	0.5124	0.3506	0.4631	0.3204
1256	0.5141	0.3519	0.4495	0.3182
1257	0.5152	0.3539	0.437	0.3075
1258	0.5114	0.3494	0.4651	0.3213
1259	0.511	0.3482	0.4586	0.3092
1260	0.5116	0.3494	0.4665	0.3184
1261	0.5119	0.3502	0.468	0.3182
1262	0.5105	0.3487	0.4599	0.316
1263	0.5122	0.3506	0.4502	0.3186
1264	0.5102	0.3483	0.4571	0.3109
1265	0.5103	0.348	0.463	0.3147
1266	0.51	0.3483	0.4685	0.3204
1267	0.5098	0.3476	0.4714	0.3221
1268	0.5091	0.3479	0.4702	0.3295
1269	0.5118	0.3495	0.5062	0.3876
1270	0.5115	0.3503	0.4648	0.3174
1271	0.5092	0.3482	0.4613	0.3143
1272	0.51	0.3477	0.4519	0.3089
1273	0.5104	0.3482	0.4612	0.3142
1274	0.5112	0.3493	0.4576	0.3219
1275	0.5089	0.348	0.4698	0.3234
1276	0.5092	0.3484	0.4723	0.3269
1277	0.5078	0.3464	0.4663	0.3268
1278	0.5114	0.3501	0.4808	0.3362
1279	0.5102	0.3492	0.4765	0.3311
1280	0.5142	0.3521	0.4684	0.3457
1281	0.5102	0.3476	0.4646	0.3241
1282	0.5089	0.3481	0.4

In [13]:
torch.save(model.state_dict(), './Trained_model/non_B3_try_layer_7_1.pkl')

In [15]:
from My_Pka_Model import Pka_basic_view,Pka_acidic_view
from sklearn.metrics import r2_score,mean_squared_error,mean_absolute_error

In [16]:
def predict(smiles,model_view):

    node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
    edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')
    bg = smiles_to_bigraph(smiles= smiles, 
                  node_featurizer=node_featurizer,
                  edge_featurizer=edge_featurizer,canonical_atom_order= False)

    with torch.no_grad():
        model_view.eval()
        molecule_pka,atom_pka = model_view(bg,bg.ndata['h'], bg.edata['h'])
        
    return molecule_pka,atom_pka

In [17]:
acid_pred = Pka_acidic_view(node_feat_size = 74,
                            edge_feat_size = 12,
                            output_size = 1,
                            num_layers= 6,
                            graph_feat_size=200,
                            dropout=0.2)

In [19]:
acid_pred.load_state_dict(torch.load('./Trained_model/non_B3_try_2.pkl',map_location='cuda:0'))

with open('./Dataset/Phenylboronic_acid_smiles(22).txt') as f: #acidic_test_0.15_smiles.txt,SAMPL6_test_acidic.txt
    pred = []
    label = []
    for line in f.readlines():
        line = line.replace('\n','').split('\t')
        molecule_pka,atom_pka = predict(line[0],acid_pred)
        #print(molecule_pka,float(line[1]))
        pred.append(molecule_pka)
        label.append(float(line[1]))

print(mean_absolute_error(pred,label))
print(mean_squared_error(pred,label)**0.5)
print(r2_score(pred,label))

0.3665171753276478
0.5355764761738165
0.8487965419847198


In [None]:
如果RMSE能够稳在0.72一下，将很好

In [None]:
# 一般到后面才开始学smalp6?

这份，非常稳定，但感觉值偏大？
不能降到0.5？


这个感觉还算稳定了


先看看模型的泛化能力再决定要不要主动做聚类

In [None]:
感觉不稳定的测试还是尽量避免