In [1]:
import os
import random

import dgl
import dgllife
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from torch.nn.utils import clip_grad_norm
from torch.utils.data import DataLoader

from My_Pka_Model import Pka_basic,Pka_acidic

Using backend: pytorch


In [2]:
# Preferred CUDA device index for this run.
gpu_index = 2
if torch.cuda.is_available():
    # Clamp to the highest index that actually exists so that hosts with fewer
    # than three GPUs do not fail later with an invalid device ordinal.
    device = torch.device("cuda:{}".format(min(gpu_index, torch.cuda.device_count() - 1)))
else:
    device = torch.device("cpu")

# Seed every random number generator used in this notebook with the same value.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Disable cuDNN autotuning and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def collate(samples):
    """Merge a list of ``(graph, label)`` samples into one mini-batch.

    Returns a 2-tuple: the graphs combined by ``dgl.batch`` and a tensor
    built from the labels, both in the original sample order.
    """
    graph_seq, label_seq = zip(*samples)
    merged_graph = dgl.batch(list(graph_seq))
    targets = torch.tensor(list(label_seq))
    return merged_graph, targets

# Atom and bond featurizers shared by every graph built in this notebook;
# both are configured to use the data field name 'h'.
node_featurizer = CanonicalAtomFeaturizer(
    atom_data_field='h',
)
edge_featurizer = CanonicalBondFeaturizer(
    bond_data_field='h',
)

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    """Read a tab-separated SMILES/label file and wrap it in DataLoader(s).

    Every non-blank line must carry a SMILES string in its first column and a
    numeric label in its second column; further columns are ignored.

    Args:
        file_name: Path of the text file to read.
        batch_size: Batch size given to every ``DataLoader``.
        shuffle: ``shuffle`` flag given to every ``DataLoader``.
        split_ratio: ``False`` (default) to return one loader over the whole
            file, or an iterable of fractions of the full dataset. The samples
            are shuffled once, each fraction yields one loader, and whatever
            is left over forms one extra, final loader.

    Returns:
        A single ``DataLoader`` when no split is requested; otherwise a list
        holding ``len(split_ratio) + 1`` loaders.

    Raises:
        ValueError: If a non-blank line has fewer than two columns, or if
            ``smiles_to_bigraph`` returns ``None`` for its SMILES string.
    """
    def make_loader(samples):
        # Single place where loaders are configured, so all splits agree.
        return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)

    dataset = []
    with open(file_name) as f:
        # Stream the file instead of materialising every line with readlines().
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip('\r\n')
            if not line.strip():
                # Blank lines (e.g. a trailing newline at EOF) carry no sample.
                continue
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    '{}: line {} does not contain a SMILES and a label column'.format(file_name, line_no))
            g = smiles_to_bigraph(smiles=fields[0],
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  canonical_atom_order=False)
            if g is None:
                # Fail here with a clear message rather than later, inside the
                # collate function, when the graphs are batched.
                raise ValueError(
                    '{}: line {} has a SMILES that could not be converted to a graph: {!r}'.format(
                        file_name, line_no, fields[0]))
            dataset.append((g, float(fields[1])))

    # ``False``/``0`` (the legacy sentinel) and ``None`` all mean "no split".
    # A type check replaces ``== False`` so that array-like ratios do not
    # trigger an ambiguous element-wise comparison.
    if split_ratio is None or (isinstance(split_ratio, (int, float)) and not split_ratio):
        return make_loader(dataset)

    random.shuffle(dataset)
    length = len(dataset)
    remaining = dataset
    dataloader_list = []
    for ratio in split_ratio:
        # Each ratio is a fraction of the FULL dataset, not of the remainder.
        num = round(length * ratio)
        dataloader_list.append(make_loader(remaining[:num]))
        remaining = remaining[num:]
    # Whatever the ratios did not consume becomes the last loader.
    dataloader_list.append(make_loader(remaining))
    return dataloader_list

In [4]:
# Hyper-parameters for this run.
batch_size = 1024        # mini-batch size handed to load_data
epoch_num = 1500         # number of epochs in the first training stage
layer_num = 6            # forwarded to the model as num_layers
learning_rate = 3e-4     # initial Adam learning rate
weight_decay = 3e-4      # weight_decay argument of Adam

In [5]:
# Instantiate the acidic-site network, then move it to the selected device.
# NOTE(review): 74 and 12 presumably match the feature widths produced by the
# canonical atom/bond featurizers - confirm if the featurizers ever change.
model = Pka_acidic(
    node_feat_size=74,
    edge_feat_size=12,
    graph_feat_size=200,
    num_layers=layer_num,
    output_size=1,
    dropout=0.2,
)
model = model.to(device)

In [6]:
# Load the whole acidic-site file as one loader (load_data shuffles by default)
# and report how many samples it holds.
train_loader = load_data(
    './Dataset/site_acidic_smiles.txt',
    batch_size=batch_size,
)
print(len(train_loader.dataset))

8836


In [7]:
# Run identifier; it names both the training log file and the saved weights.
task_name = 'site_acidic_new_2'

In [8]:
# Mean-squared-error objective, optimised with Adam over all model parameters.
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [9]:
# Per-epoch training metrics, kept in memory in addition to the log file.
train_RMSE_lis = []
train_MAE_lis = []

file_name = './Logger/{}.txt'.format(task_name)
# Make sure the log directory exists before the first write.
os.makedirs(os.path.dirname(file_name), exist_ok=True)

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:'
print(header)

# 'w+' truncates any log left over from a previous run with the same task name.
with open(file_name,'w+') as f:
    f.write(header)
    f.write('\n')

for epoch in range(epoch_num):
    # ---- optimisation pass over the training set ----
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        # Labels become a column so they line up with the model output in the loss.
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # In-place variant; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    # ---- evaluation pass (same data, dropout disabled via eval mode) ----
    model.eval()
    SSE = 0.0
    SAE = 0.0
    with torch.no_grad():
        for bg, label in train_loader:
            bg = bg.to(device)
            # Flatten with reshape(-1) rather than squeeze so that a batch of
            # one sample stays 1-D instead of collapsing to a 0-d tensor.
            prediction = model(bg,bg.ndata['h'], bg.edata['h']).reshape(-1)
            error = prediction - label.reshape(-1).to(device)
            # torch.sum reduces on the device; builtin sum() would iterate
            # element by element in Python.
            SSE += torch.sum(error ** 2).item()
            SAE += torch.sum(torch.abs(error)).item()
    N = len(train_loader.dataset)
    train_RMSE = (SSE / N) ** 0.5
    train_MAE = SAE / N
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4))
    print(log)

    # Append after every epoch so the log survives an interrupted run.
    with open(file_name,'a+') as f:
        f.write(log)
        f.write('\n')
                  

epoch:	train_RMSD:	train_MAE:




0	7.0257	6.1157
1	6.44	5.4419
2	4.6517	3.3953
3	4.1168	2.9531
4	3.9019	2.848
5	3.749	2.8109
6	3.6326	2.8046
7	3.5493	2.815
8	3.5008	2.8338
9	3.4804	2.853
10	3.4741	2.8686
11	3.4732	2.8786
12	3.4733	2.8829
13	3.4734	2.8841
14	3.4734	2.8833
15	3.4734	2.8833
16	3.4733	2.8821
17	3.4731	2.8793
18	3.4731	2.8793
19	3.4732	2.8818
20	3.4732	2.8817
21	3.4731	2.8802
22	3.473	2.8805
23	3.4727	2.8797
24	3.4725	2.8806
25	3.4722	2.8796
26	3.4718	2.8799
27	3.4713	2.8778
28	3.4708	2.8774
29	3.47	2.8779
30	3.4685	2.8747
31	3.4661	2.8724
32	3.4619	2.867
33	3.4558	2.8606
34	3.4468	2.8528
35	3.4349	2.8411
36	3.4167	2.8214
37	3.3896	2.7898
38	3.3385	2.7226
39	3.2562	2.6158
40	3.1002	2.3998
41	2.9086	2.1329
42	2.7346	1.9085
43	2.6133	1.758
44	2.5571	1.7013
45	2.5163	1.6607
46	2.4976	1.641
47	2.4716	1.6219
48	2.4436	1.6034
49	2.3948	1.5675
50	2.3829	1.561
51	2.3684	1.5448
52	2.3619	1.5443
53	2.3463	1.5277
54	2.3245	1.5104
55	2.3073	1.4882
56	2.3048	1.4988
57	2.2867	1.4729
58	2.2713	1.4606
59	2.2577	1.4519
60	

468	0.8106	0.5524
469	0.8161	0.5484
470	0.81	0.5552
471	0.823	0.5697
472	0.8078	0.5518
473	0.8118	0.5546
474	0.8007	0.5431
475	0.797	0.5392
476	0.818	0.5531
477	0.7841	0.5264
478	0.7797	0.5221
479	0.7827	0.5258
480	0.7922	0.5351
481	0.8033	0.5491
482	0.7811	0.5263
483	0.7832	0.5277
484	0.8004	0.5505
485	0.7919	0.543
486	0.7947	0.5449
487	0.7896	0.5334
488	0.7726	0.5222
489	0.7776	0.5285
490	0.7747	0.5247
491	0.7861	0.5375
492	0.7749	0.5266
493	0.7739	0.5253
494	0.7797	0.5319
495	0.7707	0.5255
496	0.7904	0.5403
497	0.7754	0.5318
498	0.7918	0.5439
499	0.7674	0.5221
500	0.7915	0.5492
501	0.7776	0.5327
502	0.8078	0.5648
503	0.7672	0.5204
504	0.7848	0.54
505	0.7644	0.5199
506	0.773	0.527
507	0.8144	0.5763
508	0.7959	0.5591
509	0.8312	0.5909
510	0.7735	0.5325
511	0.7636	0.5182
512	0.7615	0.5194
513	0.7616	0.517
514	0.7611	0.5193
515	0.7774	0.537
516	0.7577	0.5193
517	0.7547	0.5116
518	0.7756	0.5311
519	0.7714	0.5307
520	0.7557	0.5139
521	0.7505	0.5132
522	0.8036	0.5673
523	0.7529	0.5115
524	

929	0.6157	0.434
930	0.605	0.4255
931	0.6135	0.4311
932	0.5893	0.4089
933	0.588	0.4067
934	0.5868	0.4062
935	0.6053	0.4254
936	0.6121	0.431
937	0.5876	0.4081
938	0.587	0.4076
939	0.6036	0.4251
940	0.589	0.4086
941	0.5945	0.4125
942	0.5921	0.4102
943	0.5875	0.4081
944	0.5958	0.4137
945	0.5896	0.4117
946	0.5854	0.4062
947	0.5827	0.404
948	0.5909	0.4113
949	0.5891	0.4107
950	0.5937	0.4167
951	0.5942	0.4149
952	0.6009	0.4202
953	0.599	0.4233
954	0.5909	0.4104
955	0.5965	0.4197
956	0.5941	0.4149
957	0.5839	0.4057
958	0.5898	0.4116
959	0.5771	0.3988
960	0.5802	0.401
961	0.6031	0.4247
962	0.6349	0.4567
963	0.6199	0.4401
964	0.6312	0.4511
965	0.5998	0.4198
966	0.5861	0.4096
967	0.5972	0.4196
968	0.5909	0.4103
969	0.5954	0.4144
970	0.5883	0.4076
971	0.5859	0.4099
972	0.5849	0.4067
973	0.6017	0.422
974	0.6	0.421
975	0.5843	0.4058
976	0.5857	0.4079
977	0.6004	0.4284
978	0.591	0.4128
979	0.5819	0.403
980	0.5819	0.4039
981	0.5768	0.4012
982	0.5765	0.4
983	0.587	0.4079
984	0.5846	0.4096
985	0.6165	0

1370	0.5203	0.3591
1371	0.5092	0.351
1372	0.506	0.3484
1373	0.5221	0.3687
1374	0.5134	0.3533
1375	0.5046	0.3485
1376	0.5275	0.3719
1377	0.5299	0.3695
1378	0.5221	0.3634
1379	0.5164	0.3558
1380	0.5396	0.3771
1381	0.5179	0.3606
1382	0.5412	0.3869
1383	0.5217	0.3621
1384	0.5162	0.3599
1385	0.5032	0.3445
1386	0.5066	0.3477
1387	0.505	0.3487
1388	0.5072	0.3516
1389	0.5059	0.349
1390	0.512	0.3537
1391	0.5586	0.4061
1392	0.5404	0.3855
1393	0.5079	0.3499
1394	0.5102	0.3531
1395	0.5052	0.3465
1396	0.5051	0.349
1397	0.5074	0.348
1398	0.5047	0.3475
1399	0.4999	0.3416
1400	0.5042	0.3469
1401	0.5009	0.3451
1402	0.5104	0.3559
1403	0.5074	0.3515
1404	0.5073	0.3504
1405	0.496	0.3385
1406	0.4977	0.3416
1407	0.4981	0.3399
1408	0.5402	0.3887
1409	0.5356	0.3821
1410	0.498	0.3419
1411	0.5138	0.3552
1412	0.5047	0.3486
1413	0.5253	0.3723
1414	0.5077	0.3524
1415	0.5009	0.3431
1416	0.4994	0.3441
1417	0.5108	0.3545
1418	0.502	0.3457
1419	0.5167	0.3611
1420	0.502	0.3449
1421	0.5002	0.3443
1422	0.4986	0.34
1423	0

In [10]:
# Replace the optimizer with a new Adam instance whose learning rate is one
# tenth of the initial value; the weight decay is unchanged.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate / 10,
    weight_decay=weight_decay,
)

In [11]:
# Second training stage: 500 more epochs, numbered after the first stage so the
# log file keeps a continuous epoch column.
for epoch in range(epoch_num,epoch_num+500):
    # ---- optimisation pass over the training set ----
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        # Labels become a column so they line up with the model output in the loss.
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # In-place variant; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    # ---- evaluation pass (same data, dropout disabled via eval mode) ----
    model.eval()
    SSE = 0.0
    SAE = 0.0
    with torch.no_grad():
        for bg, label in train_loader:
            bg = bg.to(device)
            # Flatten with reshape(-1) rather than squeeze so that a batch of
            # one sample stays 1-D instead of collapsing to a 0-d tensor.
            prediction = model(bg,bg.ndata['h'], bg.edata['h']).reshape(-1)
            error = prediction - label.reshape(-1).to(device)
            # torch.sum reduces on the device; builtin sum() would iterate
            # element by element in Python.
            SSE += torch.sum(error ** 2).item()
            SAE += torch.sum(torch.abs(error)).item()
    N = len(train_loader.dataset)
    train_RMSE = (SSE / N) ** 0.5
    train_MAE = SAE / N
    # Keep the in-memory metric history in step with the log file.
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4))
    print(log)

    # Append after every epoch so the log survives an interrupted run.
    with open(file_name,'a+') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1500	0.4777	0.324
1501	0.4755	0.3208
1502	0.4725	0.319
1503	0.4727	0.3187
1504	0.4747	0.3203
1505	0.4723	0.3183
1506	0.471	0.3175
1507	0.472	0.318
1508	0.4715	0.3174
1509	0.4718	0.3177
1510	0.4737	0.3195
1511	0.4713	0.3172
1512	0.4699	0.316
1513	0.4698	0.3159
1514	0.4693	0.315
1515	0.4701	0.316
1516	0.4706	0.3167
1517	0.4687	0.3149
1518	0.4682	0.3146
1519	0.4681	0.3139
1520	0.4692	0.3158
1521	0.4725	0.3199
1522	0.4706	0.3177
1523	0.4698	0.3167
1524	0.4679	0.3138
1525	0.4687	0.3146
1526	0.4696	0.3153
1527	0.4711	0.3179
1528	0.4683	0.3151
1529	0.4682	0.3139
1530	0.4691	0.3147
1531	0.4702	0.3164
1532	0.4677	0.3137
1533	0.4685	0.3149
1534	0.4672	0.3142
1535	0.4686	0.3153
1536	0.4677	0.3138
1537	0.4674	0.3129
1538	0.4665	0.3123
1539	0.4669	0.3132
1540	0.4685	0.314
1541	0.4663	0.3124
1542	0.4669	0.313
1543	0.4697	0.3176
1544	0.4685	0.3158
1545	0.4668	0.3132
1546	0.4659	0.3126
1547	0.4664	0.3131
1548	0.4663	0.3129
1549	0.4667	0.3126
1550	0.4673	0.3138
1551	0.4678	0.3141
1552	0.4676	0.3143
155

1937	0.4529	0.3026
1938	0.4538	0.3028
1939	0.4526	0.3012
1940	0.4543	0.3037
1941	0.4539	0.3028
1942	0.4532	0.302
1943	0.4523	0.3014
1944	0.4517	0.3019
1945	0.4524	0.3019
1946	0.4522	0.3015
1947	0.4521	0.3017
1948	0.4518	0.3005
1949	0.4512	0.3005
1950	0.4519	0.3017
1951	0.4521	0.3018
1952	0.4526	0.3024
1953	0.4538	0.3032
1954	0.4539	0.3037
1955	0.4521	0.3023
1956	0.4528	0.3024
1957	0.4527	0.3032
1958	0.4524	0.3022
1959	0.4523	0.3004
1960	0.4557	0.3061
1961	0.4573	0.3075
1962	0.4532	0.3024
1963	0.4535	0.3034
1964	0.4534	0.3032
1965	0.4514	0.3013
1966	0.4515	0.3012
1967	0.4521	0.3017
1968	0.4529	0.303
1969	0.4593	0.3104
1970	0.4567	0.3072
1971	0.4519	0.3012
1972	0.4516	0.3013
1973	0.4523	0.3014
1974	0.4516	0.3014
1975	0.4522	0.3016
1976	0.4509	0.3001
1977	0.4562	0.3063
1978	0.4548	0.3044
1979	0.4523	0.3012
1980	0.4521	0.3009
1981	0.4511	0.3002
1982	0.4523	0.3021
1983	0.452	0.3024
1984	0.4527	0.3034
1985	0.4519	0.3022
1986	0.4538	0.3041
1987	0.4519	0.3031
1988	0.4522	0.3021
1989	0.454	0.30

In [12]:
# Persist only the learned weights (state_dict), named after the run.
model_path = './Trained_model/{}.pkl'.format(task_name)
# Create the target directory first: a missing folder would otherwise make the
# save fail after the whole training run has finished.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)