In [1]:
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_norm
from torch.utils.data import DataLoader

import dgl
import dgllife
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer

from My_Pka_Model import Pka_basic, Pka_acidic

Using backend: pytorch


In [2]:
# Prefer the second GPU when the machine has more than one (the original setup),
# but fall back to the first GPU on single-GPU machines instead of failing later
# on `.to(device)`; use the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:{}".format(gpu_index))
else:
    device = torch.device("cpu")

# Seed every RNG the notebook uses: Python `random` (dataset shuffling in
# `load_data`), NumPy, and torch on CPU and all GPUs.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # seeds every visible GPU, including the one chosen above
# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def collate(samples):
    """Collate a list of ``(graph, label)`` pairs into one mini-batch.

    Returns the DGL-batched graph and a 1-D tensor of the labels, both in the
    order the samples were given.
    """
    graph_list, label_list = (list(column) for column in zip(*samples))
    merged_graph = dgl.batch(graph_list)
    label_tensor = torch.tensor(label_list)
    return merged_graph, label_tensor

# Featurizers used by `load_data`. Both store their features under the key "h",
# which the training cells read back via bg.ndata['h'] / bg.edata['h'].
node_featurizer = CanonicalAtomFeaturizer(atom_data_field="h")
edge_featurizer = CanonicalBondFeaturizer(bond_data_field="h")

def load_data(file_name, batch_size=128, shuffle=True, split_ratio=False):
    """Read a tab-separated ``SMILES<TAB>label`` file into DGL graph DataLoaders.

    Each non-blank line is turned into a molecular graph with the module-level
    ``node_featurizer`` / ``edge_featurizer`` and paired with its float label.

    Args:
        file_name: path to the text file; only the first two tab-separated
            columns are used.
        batch_size: mini-batch size for every returned DataLoader.
        shuffle: whether each DataLoader reshuffles its samples every epoch.
        split_ratio: ``False`` (or ``None``) returns a single DataLoader over the
            whole file. Otherwise an iterable of fractions such as ``[0.8, 0.1]``:
            the dataset is shuffled once with Python's ``random`` module, one
            DataLoader is built per fraction (``round(len(dataset) * fraction)``
            consecutive samples each), and a final DataLoader holds the remainder.

    Returns:
        A DataLoader, or a list of ``len(split_ratio) + 1`` DataLoaders in split mode.

    Raises:
        ValueError: if a non-blank line has fewer than two columns, its label is not
            a float, or its SMILES could not be converted to a graph.
    """
    dataset = []
    with open(file_name) as f:
        for line_no, raw_line in enumerate(f, start=1):
            # Strip the line terminator (also tolerates '\r\n'); skip blank lines such
            # as a trailing empty line instead of crashing on a missing column.
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                continue
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError('{}:{}: expected "SMILES<TAB>label", got {!r}'.format(
                    file_name, line_no, line))
            g = smiles_to_bigraph(smiles=fields[0],
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  canonical_atom_order=False)
            # dgllife may return None for a SMILES RDKit cannot parse; report the
            # offending line here rather than failing later inside dgl.batch.
            if g is None:
                raise ValueError('{}:{}: could not build a graph from SMILES {!r}'.format(
                    file_name, line_no, fields[0]))
            dataset.append((g, float(fields[1])))

    def make_loader(samples):
        # All loaders share the same batching/shuffling settings and collate_fn.
        return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)

    if split_ratio is None or split_ratio is False:
        return make_loader(dataset)

    # Shuffle once with the global `random` RNG so the partitions are random but
    # reproducible under a fixed seed.
    random.shuffle(dataset)
    length = len(dataset)
    dataloader_list = []
    start = 0
    for ratio in split_ratio:
        stop = start + round(length * ratio)
        dataloader_list.append(make_loader(dataset[start:stop]))
        start = stop
    # Whatever was not claimed by an explicit ratio forms the last split.
    dataloader_list.append(make_loader(dataset[start:]))
    return dataloader_list

In [4]:
# Training hyper-parameters.
batch_size = 1024      # graphs per mini-batch
epoch_num = 1500       # epochs of the first phase; the fine-tune loop starts counting from here
layer_num = 6          # passed to the model as num_layers
learning_rate = 3e-4   # Adam learning rate for the first phase
weight_decay = 3e-4    # Adam weight decay (both phases)

In [5]:
# Input widths are taken from the featurizers that build the graphs in `load_data`
# (74 atom / 12 bond features for the canonical featurizers, the values previously
# hard-coded here), so the model stays in sync if the featurizers change.
model = Pka_acidic(node_feat_size=node_featurizer.feat_size('h'),
                   edge_feat_size=edge_featurizer.feat_size('h'),
                   output_size=1,
                   num_layers=layer_num,
                   graph_feat_size=200,
                   dropout=0.2).to(device)

In [6]:
# Training-set loader; load_data shuffles by default (shuffle=True).
train_loader = load_data(file_name='./Dataset/acidic_train_0.70_smiles.txt',
                         batch_size=batch_size)
print(len(train_loader.dataset))

6337


In [7]:
# Stem of the log file (./Logger/) and the saved weights (./Trained_model/).
# NOTE: "ramdom" is a typo, kept so existing log/model file names still match.
task_name = "acidic_ramdom_split_3"

In [8]:
# Mean-squared-error regression loss on the float labels read by load_data.
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [9]:
# Per-epoch training metrics kept in memory (e.g. for plotting), alongside the log file.
train_RMSE_lis = []
train_MAE_lis = []

file_name = './Logger/{}.txt'.format(task_name)
# Create the log directory up front instead of failing on open().
os.makedirs(os.path.dirname(file_name), exist_ok=True)

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:'
print(header)

with open(file_name, 'w+') as f:
    f.write(header)
    f.write('\n')

for epoch in range(epoch_num):
    # --- one pass of mini-batch training ---
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        # (batch, 1) column, assumed to match the model's output shape.
        label = label.reshape(-1, 1).to(device)
        prediction = model(bg, bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the non-deprecated, in-place form of clip_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        optimizer.step()

    # --- RMSE / MAE over the whole training set ---
    model.eval()
    with torch.no_grad():
        SSE = 0.0
        SAE = 0.0
        for bg, label in train_loader:
            bg = bg.to(device)
            prediction = model(bg, bg.ndata['h'], bg.edata['h'])
            # reshape(-1) rather than squeeze(): a final batch of size 1 would
            # otherwise become a 0-d tensor that cannot be summed element-wise.
            error = prediction.reshape(-1) - label.to(device).reshape(-1)
            # Tensor reductions instead of Python's builtin sum(), which iterated
            # over the tensor one element at a time.
            SSE += (error ** 2).sum()
            SAE += error.abs().sum()
        N = len(train_loader.dataset)
        train_RMSE = (float(SSE) / N) ** 0.5
        train_MAE = float(SAE) / N
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch, round(train_RMSE, 4), round(train_MAE, 4))
    print(log)

    # Append per epoch so the log survives an interrupted run.
    with open(file_name, 'a+') as f:
        f.write(log)
        f.write('\n')

epoch:	train_RMSD:	train_MAE:




0	7.0241	6.1321
1	6.7996	5.8762
2	5.7916	4.7034
3	4.5039	3.2705
4	4.104	2.9504
5	3.9183	2.8518
6	3.7846	2.8069
7	3.6771	2.7892
8	3.5908	2.7887
9	3.5246	2.7979
10	3.4806	2.813
11	3.4558	2.8303
12	3.4444	2.8455
13	3.4403	2.8574
14	3.4391	2.8656
15	3.4391	2.8714
16	3.4393	2.8752
17	3.4395	2.877
18	3.4396	2.8772
19	3.4395	2.877
20	3.4393	2.8751
21	3.4392	2.874
22	3.4391	2.8721
23	3.439	2.8707
24	3.439	2.8706
25	3.439	2.8716
26	3.4391	2.8724
27	3.4391	2.873
28	3.4389	2.8717
29	3.4387	2.8697
30	3.4386	2.8695
31	3.4384	2.8688
32	3.4382	2.8686
33	3.438	2.8688
34	3.4378	2.8704
35	3.4377	2.8732
36	3.4374	2.8733
37	3.4367	2.8698
38	3.4361	2.8702
39	3.4348	2.8683
40	3.4328	2.8659
41	3.4303	2.8633
42	3.4265	2.8569
43	3.4214	2.8507
44	3.4142	2.8435
45	3.4048	2.8348
46	3.3927	2.8241
47	3.3749	2.8042
48	3.3469	2.768
49	3.3035	2.7156
50	3.2441	2.6463
51	3.1448	2.5057
52	3.0117	2.3182
53	2.8593	2.1028
54	2.717	1.9091
55	2.6144	1.7942
56	2.5516	1.7298
57	2.5036	1.6854
58	2.4756	1.6462
59	2.457	1.627
60	2

468	0.9194	0.6135
469	0.9184	0.6147
470	0.9217	0.6077
471	0.9391	0.6368
472	0.9324	0.6316
473	0.9131	0.6097
474	0.9086	0.6056
475	0.8976	0.5988
476	0.8989	0.5911
477	0.8948	0.5891
478	0.897	0.5896
479	0.8987	0.5944
480	0.909	0.611
481	0.9031	0.6042
482	0.9424	0.6292
483	0.8927	0.5914
484	0.8842	0.5865
485	0.8936	0.5896
486	0.8837	0.5792
487	0.8845	0.5881
488	0.9124	0.6079
489	0.9392	0.6283
490	0.8995	0.6026
491	0.9058	0.6047
492	0.9064	0.6075
493	0.8836	0.5862
494	0.8777	0.576
495	0.8782	0.5799
496	0.8797	0.5831
497	0.8793	0.5826
498	0.8825	0.5867
499	0.8722	0.5752
500	0.8799	0.584
501	0.873	0.5794
502	0.8738	0.5775
503	0.8881	0.5914
504	0.8685	0.5767
505	0.8685	0.5756
506	0.887	0.5932
507	0.8704	0.5764
508	0.8907	0.5929
509	0.8682	0.5713
510	0.871	0.5728
511	0.8647	0.5736
512	0.924	0.6314
513	0.8746	0.578
514	0.8773	0.5892
515	0.8746	0.59
516	0.8627	0.5746
517	0.8565	0.5666
518	0.8737	0.5789
519	0.8665	0.5753
520	0.8882	0.6008
521	0.8645	0.578
522	0.8553	0.5695
523	0.8567	0.5692
524	0

930	0.6277	0.4377
931	0.6256	0.4368
932	0.6431	0.4539
933	0.6438	0.4554
934	0.626	0.4412
935	0.6347	0.4486
936	0.6252	0.437
937	0.6232	0.4376
938	0.6776	0.4881
939	0.6844	0.5042
940	0.6623	0.4769
941	0.7019	0.5178
942	0.6318	0.448
943	0.628	0.4433
944	0.6371	0.4506
945	0.6318	0.4433
946	0.6239	0.4412
947	0.6314	0.4454
948	0.6253	0.4411
949	0.6204	0.438
950	0.625	0.4401
951	0.618	0.4331
952	0.6345	0.4527
953	0.6386	0.4547
954	0.6315	0.4452
955	0.6264	0.4408
956	0.6239	0.44
957	0.6124	0.4291
958	0.6176	0.4336
959	0.6103	0.4285
960	0.6177	0.4338
961	0.6126	0.4295
962	0.662	0.4681
963	0.647	0.4598
964	0.671	0.4921
965	0.6534	0.4686
966	0.6334	0.4523
967	0.643	0.459
968	0.6394	0.4589
969	0.6395	0.4606
970	0.6238	0.4403
971	0.6398	0.452
972	0.6803	0.503
973	0.6247	0.4477
974	0.6237	0.4409
975	0.6117	0.4298
976	0.6104	0.4286
977	0.6154	0.4328
978	0.6043	0.4228
979	0.6376	0.4551
980	0.6893	0.5045
981	0.6564	0.4728
982	0.6247	0.4413
983	0.6292	0.4485
984	0.6066	0.4263
985	0.6067	0.4273
986	0.61

1371	0.5257	0.3744
1372	0.5411	0.3849
1373	0.5152	0.3602
1374	0.5194	0.3666
1375	0.52	0.3643
1376	0.5441	0.3916
1377	0.5232	0.3701
1378	0.5189	0.3666
1379	0.5154	0.3604
1380	0.5407	0.3849
1381	0.5219	0.367
1382	0.5188	0.3691
1383	0.5266	0.3764
1384	0.5124	0.3571
1385	0.5222	0.3726
1386	0.5126	0.3593
1387	0.518	0.3638
1388	0.5157	0.3614
1389	0.5128	0.3599
1390	0.5175	0.3611
1391	0.513	0.3576
1392	0.5128	0.359
1393	0.5221	0.365
1394	0.534	0.3761
1395	0.5192	0.3643
1396	0.5183	0.3644
1397	0.5256	0.3728
1398	0.5083	0.3536
1399	0.5346	0.379
1400	0.5734	0.4173
1401	0.5155	0.3668
1402	0.5162	0.3629
1403	0.5166	0.3647
1404	0.5226	0.3666
1405	0.5146	0.3638
1406	0.5101	0.3551
1407	0.5164	0.3655
1408	0.5099	0.358
1409	0.5235	0.3727
1410	0.5359	0.3785
1411	0.5317	0.3784
1412	0.5506	0.3967
1413	0.5378	0.3836
1414	0.5404	0.3877
1415	0.5134	0.3616
1416	0.5344	0.3802
1417	0.5347	0.381
1418	0.5415	0.3865
1419	0.5469	0.4006
1420	0.5758	0.4295
1421	0.5367	0.3824
1422	0.5564	0.4001
1423	0.5563	0.3981
1424

In [10]:
# Second phase: a fresh Adam optimizer (its moment estimates start empty again)
# with the learning rate cut tenfold and the same weight decay.
optimizer = optim.Adam(params=model.parameters(),
                       weight_decay=weight_decay,
                       lr=learning_rate / 10)

In [11]:
# Second phase: continue from epoch `epoch_num` for `finetune_epochs` more epochs
# with the reduced-learning-rate optimizer, appending to the same log file.
finetune_epochs = 500
for epoch in range(epoch_num, epoch_num + finetune_epochs):
    # --- one pass of mini-batch training ---
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        # (batch, 1) column, assumed to match the model's output shape.
        label = label.reshape(-1, 1).to(device)
        prediction = model(bg, bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the non-deprecated, in-place form of clip_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        optimizer.step()

    # --- RMSE / MAE over the whole training set ---
    model.eval()
    with torch.no_grad():
        SSE = 0.0
        SAE = 0.0
        for bg, label in train_loader:
            bg = bg.to(device)
            prediction = model(bg, bg.ndata['h'], bg.edata['h'])
            # reshape(-1) rather than squeeze(): a final batch of size 1 would
            # otherwise become a 0-d tensor that cannot be summed element-wise.
            error = prediction.reshape(-1) - label.to(device).reshape(-1)
            # Tensor reductions instead of Python's builtin sum().
            SSE += (error ** 2).sum()
            SAE += error.abs().sum()
        N = len(train_loader.dataset)
        train_RMSE = (float(SSE) / N) ** 0.5
        train_MAE = float(SAE) / N
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch, round(train_RMSE, 4), round(train_MAE, 4))
    print(log)

    # Append per epoch so the log survives an interrupted run.
    with open(file_name, 'a+') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1500	0.4902	0.3395
1501	0.4873	0.3365
1502	0.4833	0.3338
1503	0.4808	0.331
1504	0.4796	0.3308
1505	0.4788	0.329
1506	0.4769	0.3263
1507	0.4764	0.3262
1508	0.4778	0.3278
1509	0.4757	0.3258
1510	0.4764	0.3265
1511	0.4775	0.3278
1512	0.4786	0.3292
1513	0.4767	0.327
1514	0.478	0.3281
1515	0.4774	0.3287
1516	0.4762	0.3257
1517	0.4768	0.3268
1518	0.4754	0.3247
1519	0.4742	0.3236
1520	0.4739	0.3239
1521	0.4734	0.3244
1522	0.4735	0.3244
1523	0.4758	0.3261
1524	0.4734	0.3237
1525	0.4728	0.3235
1526	0.4763	0.3279
1527	0.4731	0.3238
1528	0.473	0.3234
1529	0.475	0.3254
1530	0.4728	0.3224
1531	0.4719	0.3223
1532	0.4723	0.3224
1533	0.4731	0.3237
1534	0.473	0.3237
1535	0.4743	0.3252
1536	0.4713	0.3218
1537	0.4711	0.3216
1538	0.4715	0.3214
1539	0.4725	0.3225
1540	0.4779	0.3306
1541	0.4743	0.3252
1542	0.4717	0.3229
1543	0.4718	0.3224
1544	0.4706	0.3206
1545	0.4711	0.3205
1546	0.4722	0.323
1547	0.4729	0.3231
1548	0.472	0.3232
1549	0.4703	0.3206
1550	0.4733	0.3233
1551	0.4728	0.3241
1552	0.4714	0.3216
15

1936	0.4571	0.312
1937	0.4577	0.3124
1938	0.4538	0.309
1939	0.4526	0.3074
1940	0.4528	0.3074
1941	0.4538	0.3085
1942	0.4544	0.3081
1943	0.4543	0.3088
1944	0.454	0.3086
1945	0.4535	0.3077
1946	0.4569	0.3123
1947	0.4537	0.3088
1948	0.4528	0.3075
1949	0.4531	0.3087
1950	0.4541	0.3092
1951	0.4545	0.309
1952	0.4529	0.3075
1953	0.4574	0.3132
1954	0.4547	0.3092
1955	0.4525	0.3072
1956	0.4525	0.307
1957	0.4528	0.3076
1958	0.4531	0.3086
1959	0.4531	0.3079
1960	0.4532	0.308
1961	0.456	0.3118
1962	0.4547	0.3107
1963	0.4533	0.3085
1964	0.4527	0.3077
1965	0.4533	0.3088
1966	0.4537	0.3086
1967	0.4527	0.3072
1968	0.4535	0.3085
1969	0.4531	0.308
1970	0.4536	0.3094
1971	0.453	0.3082
1972	0.4516	0.3061
1973	0.4534	0.3078
1974	0.4533	0.3076
1975	0.4527	0.3069
1976	0.4519	0.306
1977	0.4575	0.3137
1978	0.4546	0.3089
1979	0.4519	0.3069
1980	0.4525	0.3077
1981	0.4548	0.3106
1982	0.4552	0.3109
1983	0.4579	0.3128
1984	0.4531	0.3081
1985	0.4535	0.3087
1986	0.4545	0.3094
1987	0.4515	0.306
1988	0.4536	0.308
1989	

In [12]:
# Persist only the learned weights (state_dict). Create the output folder first so
# a long training run is not lost to a missing directory at the very last step.
model_path = './Trained_model/{}.pkl'.format(task_name)
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)