In [1]:
import os
import random

# Kept ahead of the torch/dgl imports to preserve the original import order.
from My_Pka_Model import Pka_basic,Pka_acidic

import dgl
import dgllife
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from torch.nn.utils import clip_grad_norm
from torch.utils.data import DataLoader

Using backend: pytorch


In [2]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Seed every random number generator used in this notebook so runs are repeatable.
seed = 0
for seed_fn in (random.seed,
                np.random.seed,
                torch.manual_seed,
                torch.cuda.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(seed)

# Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    """Build DataLoader(s) of (DGLGraph, label) pairs from a tab-separated file.

    Each non-blank line of ``file_name`` holds a SMILES string in the first
    tab-separated column and a numeric label in the second; extra columns are
    ignored. Each SMILES is converted with ``smiles_to_bigraph`` using the
    module-level ``node_featurizer`` / ``edge_featurizer`` and
    ``canonical_atom_order=False`` (atoms keep the order given in the SMILES).

    Args:
        file_name: path of the tab-separated input file.
        batch_size: mini-batch size of every returned DataLoader.
        shuffle: passed to every DataLoader (reshuffle samples each epoch).
        split_ratio: ``False`` (default) or ``None`` returns one DataLoader
            over the whole file. Otherwise an iterable of fractions: the
            samples are shuffled once with ``random.shuffle``, then one
            DataLoader is built per fraction, in order, each holding
            ``round(total * fraction)`` samples, followed by one final
            DataLoader holding whatever is left over.

    Returns:
        A single DataLoader, or a list of ``len(split_ratio) + 1`` DataLoaders.

    Raises:
        ValueError: if a line has no label column, its label is not a number,
            or ``smiles_to_bigraph`` returned ``None`` for its SMILES.
    """
    def make_loader(samples):
        # All loaders share the same batching/shuffling/collation settings.
        return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)

    dataset = []
    with open(file_name) as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                continue  # skip blank lines, e.g. an empty last line
            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError('{}:{}: expected "<SMILES>\\t<label>", got {!r}'.format(
                    file_name, line_no, line))
            g = smiles_to_bigraph(smiles=fields[0],
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  canonical_atom_order=False)
            if g is None:
                # Fail here with the offending line instead of later inside dgl.batch.
                raise ValueError('{}:{}: could not build a graph from SMILES {!r}'.format(
                    file_name, line_no, fields[0]))
            dataset.append((g, float(fields[1])))

    # '== False' (rather than 'not split_ratio') keeps the original meaning:
    # an empty split list still produces a one-element list of loaders.
    if split_ratio is None or split_ratio == False:
        return make_loader(dataset)

    random.shuffle(dataset)
    length = len(dataset)
    dataloader_list = []
    for ratio in split_ratio:
        num = round(length * ratio)
        dataloader_list.append(make_loader(dataset[:num]))
        dataset = dataset[num:]
    dataloader_list.append(make_loader(dataset))  # remainder after all ratios
    return dataloader_list

In [4]:
# --- training schedule ---
batch_size = 1024      # graphs per mini-batch
epoch_num = 1500       # epochs in the first (full learning-rate) phase

# --- model / optimiser ---
layer_num = 6          # passed to the model as num_layers
learning_rate = 3e-4   # Adam learning rate
weight_decay = 3e-4    # Adam weight_decay

In [5]:
# Input feature sizes are read from the featurizers that build the graphs
# (74 atom / 12 bond features for the canonical featurizers), so they cannot
# drift out of sync if the featurizer configuration changes.
model = Pka_basic(node_feat_size = node_featurizer.feat_size('h'),
                   edge_feat_size = edge_featurizer.feat_size('h'),
                   output_size = 1,
                   num_layers= layer_num,
                   graph_feat_size=200,
                   dropout=0.2).to(device)

In [6]:
train_loader = load_data('./Dataset/basic_train_0.70_smiles.txt',
                         batch_size=batch_size)
# Report how many molecules ended up in the training set.
print(len(train_loader.dataset))

5905


In [7]:
# Run identifier: used to name the training log (./Logger/<task_name>.txt)
# and the saved weights (./Trained_model/<task_name>.pkl).
# NOTE(review): 'ramdom' looks like a typo for 'random'; left unchanged so the
# file names stay consistent with artifacts already produced under this name.
task_name = 'basic_ramdom_split_3'

In [8]:
# Adam over all model parameters, using the learning rate and weight decay from the config cell.
optimizer = optim.Adam(params=model.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)
# Mean-squared-error regression objective.
loss_func = torch.nn.MSELoss()

In [9]:
# Per-epoch history of the training-set metrics, filled in by the loop below.
train_RMSE_lis = []
train_MAE_lis = []

file_name = './Logger/{}.txt'.format(task_name)
# Create ./Logger if it does not exist yet so opening the log cannot fail.
os.makedirs(os.path.dirname(file_name), exist_ok=True)

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:'
print(header)

# 'w+' truncates any log left over from an earlier run of this task.
with open(file_name,'w+') as f:
    f.write(header)
    f.write('\n')

for epoch in range(epoch_num):
    # ---- one optimisation pass over the training set ----
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        # Labels become a column vector to line up with the model output
        # (assumed (batch, 1) given output_size=1 -- TODO confirm).
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the supported in-place API; clip_grad_norm
        # (no underscore) is its deprecated alias.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        optimizer.step()

    # ---- evaluate RMSE / MAE over the whole training set ----
    model.eval()
    with torch.no_grad():
        SSE = 0.0
        SAE = 0.0
        for bg, label in train_loader:
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            # reshape(-1) keeps a 1-D tensor even for a single-graph batch,
            # where squeeze() would produce a non-iterable 0-d tensor.
            error = prediction.reshape(-1) - label.to(device)
            # torch.sum reduces on the device in one call per batch.
            SSE += torch.sum(error ** 2).item()
            SAE += torch.sum(torch.abs(error)).item()
        N = len(train_loader.dataset)
        train_RMSE = (SSE / N) ** 0.5
        train_MAE = SAE / N
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4))
    print(log)

    # Append one line per epoch so the history survives an interrupted run.
    with open(file_name,'a+') as f:
        f.write(log)
        f.write('\n')

epoch:	train_RMSD:	train_MAE:




0	6.919	6.303
1	6.7705	6.1485
2	6.2629	5.6204
3	4.9035	4.2453
4	4.1254	3.5188
5	3.8371	3.2618
6	3.6699	3.1172
7	3.5408	3.0058
8	3.4332	2.9112
9	3.3427	2.8303
10	3.2681	2.7614
11	3.2102	2.7044
12	3.1699	2.661
13	3.1453	2.6302
14	3.1315	2.6085
15	3.1249	2.5947
16	3.122	2.5861
17	3.1209	2.581
18	3.1206	2.5783
19	3.1206	2.578
20	3.1207	2.5794
21	3.1207	2.5793
22	3.1208	2.5799
23	3.1208	2.5801
24	3.1208	2.5803
25	3.1209	2.5808
26	3.1208	2.5803
27	3.1209	2.5806
28	3.1208	2.5804
29	3.1208	2.5804
30	3.1208	2.5805
31	3.1208	2.5805
32	3.1207	2.5797
33	3.1208	2.5802
34	3.1207	2.5798
35	3.1207	2.5792
36	3.1207	2.579
37	3.1208	2.5801
38	3.1209	2.5805
39	3.1208	2.5805
40	3.1208	2.5803
41	3.1208	2.5803
42	3.1209	2.5805
43	3.121	2.5812
44	3.1209	2.5809
45	3.1208	2.5804
46	3.1209	2.5806
47	3.1209	2.5809
48	3.1209	2.5806
49	3.1209	2.5809
50	3.121	2.5813
51	3.1209	2.5807
52	3.1209	2.5807
53	3.1209	2.5806
54	3.1209	2.5806
55	3.1208	2.5804
56	3.1207	2.5795
57	3.1208	2.5798
58	3.1208	2.5798
59	3.1208	2.58
6

466	0.7958	0.5673
467	0.7914	0.5684
468	0.773	0.5442
469	0.795	0.5693
470	0.7713	0.544
471	0.7727	0.5444
472	0.7727	0.5451
473	0.765	0.5389
474	0.7697	0.5411
475	0.7655	0.5404
476	0.7655	0.5394
477	0.7791	0.5526
478	0.7645	0.5355
479	0.7634	0.5391
480	0.7595	0.5348
481	0.7649	0.5412
482	0.7635	0.5402
483	0.7738	0.5497
484	0.7621	0.5364
485	0.7589	0.5358
486	0.7576	0.5332
487	0.7584	0.5353
488	0.757	0.5324
489	0.7607	0.5372
490	0.7621	0.5383
491	0.7553	0.531
492	0.7656	0.5443
493	0.7563	0.5353
494	0.7561	0.5332
495	0.7525	0.529
496	0.7567	0.5354
497	0.7538	0.5311
498	0.7542	0.5346
499	0.7519	0.5274
500	0.7507	0.5324
501	0.7489	0.526
502	0.7526	0.5326
503	0.7618	0.5412
504	0.7657	0.5449
505	0.8151	0.5929
506	0.7835	0.5689
507	0.8071	0.5904
508	0.7517	0.5329
509	0.748	0.5253
510	0.7462	0.5271
511	0.7567	0.5313
512	0.7524	0.5332
513	0.7449	0.5222
514	0.7436	0.5233
515	0.7474	0.5275
516	0.7429	0.5238
517	0.7484	0.5283
518	0.7396	0.5206
519	0.7389	0.5197
520	0.7392	0.5205
521	0.7414	0.5241
5

927	0.6069	0.4282
928	0.5891	0.4079
929	0.591	0.4103
930	0.5862	0.4054
931	0.588	0.4071
932	0.5952	0.4165
933	0.6609	0.4919
934	0.6039	0.4273
935	0.5901	0.4094
936	0.5869	0.4076
937	0.5995	0.4179
938	0.5945	0.4145
939	0.5901	0.4108
940	0.5878	0.4098
941	0.5883	0.4071
942	0.5981	0.417
943	0.5923	0.4168
944	0.6097	0.4305
945	0.6028	0.4244
946	0.6285	0.4555
947	0.6092	0.4316
948	0.592	0.4152
949	0.5884	0.4094
950	0.5996	0.4236
951	0.5913	0.4141
952	0.6146	0.4395
953	0.5904	0.4142
954	0.6012	0.4174
955	0.5978	0.4228
956	0.5925	0.4152
957	0.5863	0.4095
958	0.5867	0.408
959	0.612	0.4358
960	0.5885	0.4118
961	0.5886	0.4083
962	0.5798	0.4035
963	0.5859	0.4096
964	0.5831	0.4083
965	0.5955	0.4229
966	0.5805	0.4044
967	0.5842	0.4095
968	0.5849	0.4061
969	0.5833	0.4051
970	0.5806	0.4037
971	0.5839	0.4055
972	0.5908	0.4146
973	0.5838	0.4066
974	0.6151	0.4415
975	0.5957	0.4197
976	0.5889	0.4137
977	0.5788	0.4018
978	0.5795	0.4007
979	0.5975	0.4272
980	0.5753	0.3969
981	0.5937	0.4198
982	0.5886	0.415

1367	0.4962	0.344
1368	0.5011	0.3528
1369	0.4943	0.3426
1370	0.4894	0.3372
1371	0.4981	0.3475
1372	0.4877	0.3359
1373	0.486	0.334
1374	0.4884	0.337
1375	0.4925	0.3402
1376	0.5196	0.376
1377	0.5009	0.3517
1378	0.5092	0.361
1379	0.4903	0.3405
1380	0.4898	0.3387
1381	0.4958	0.3485
1382	0.4942	0.3437
1383	0.4879	0.3364
1384	0.5141	0.3708
1385	0.4912	0.3403
1386	0.4869	0.3395
1387	0.4865	0.3345
1388	0.5027	0.3552
1389	0.4943	0.3397
1390	0.4913	0.3425
1391	0.4865	0.3361
1392	0.497	0.3467
1393	0.5019	0.355
1394	0.494	0.3397
1395	0.5013	0.3515
1396	0.4952	0.3443
1397	0.5186	0.37
1398	0.4947	0.3413
1399	0.4938	0.3424
1400	0.4913	0.3436
1401	0.4994	0.3513
1402	0.4866	0.3365
1403	0.4832	0.3328
1404	0.4974	0.3466
1405	0.4858	0.3365
1406	0.4914	0.3451
1407	0.5007	0.3496
1408	0.5028	0.359
1409	0.4842	0.3343
1410	0.4827	0.3305
1411	0.4952	0.3491
1412	0.4796	0.3299
1413	0.4867	0.3372
1414	0.4872	0.3412
1415	0.5399	0.4006
1416	0.4899	0.3402
1417	0.4955	0.3468
1418	0.4848	0.338
1419	0.4855	0.3366
1420	0

In [10]:
# Second phase: restart Adam with a 10x smaller learning rate.
# NOTE: a fresh optimizer discards the Adam moment estimates from phase one.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate / 10,
    weight_decay=weight_decay,
)

In [11]:
# Fine-tuning phase: 500 extra epochs. Epoch numbering continues after the first
# phase, and results are appended to the same log file.
for epoch in range(epoch_num, epoch_num + 500):
    # ---- one optimisation pass over the training set ----
    model.train()
    for bg, label in train_loader:
        bg = bg.to(device)
        # Labels become a column vector to line up with the model output
        # (assumed (batch, 1) given output_size=1 -- TODO confirm).
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the supported in-place API; clip_grad_norm
        # (no underscore) is its deprecated alias.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        optimizer.step()

    # ---- evaluate RMSE / MAE over the whole training set ----
    model.eval()
    with torch.no_grad():
        SSE = 0.0
        SAE = 0.0
        for bg, label in train_loader:
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            # reshape(-1) keeps a 1-D tensor even for a single-graph batch,
            # where squeeze() would produce a non-iterable 0-d tensor.
            error = prediction.reshape(-1) - label.to(device)
            # torch.sum reduces on the device in one call per batch.
            SSE += torch.sum(error ** 2).item()
            SAE += torch.sum(torch.abs(error)).item()
        N = len(train_loader.dataset)
        train_RMSE = (SSE / N) ** 0.5
        train_MAE = SAE / N
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4))
    print(log)

    # Append one line per epoch so the history survives an interrupted run.
    with open(file_name,'a+') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1500	0.4652	0.3182
1501	0.4691	0.3229
1502	0.4605	0.3124
1503	0.4623	0.3165
1504	0.4624	0.3159
1505	0.4607	0.3141
1506	0.4598	0.3131
1507	0.4591	0.3123
1508	0.4616	0.3163
1509	0.4574	0.3106
1510	0.4592	0.3112
1511	0.4595	0.3121
1512	0.4599	0.3129
1513	0.4587	0.3117
1514	0.4575	0.3111
1515	0.4629	0.3184
1516	0.4578	0.3117
1517	0.4584	0.3126
1518	0.4599	0.3138
1519	0.4574	0.3096
1520	0.4582	0.3125
1521	0.4641	0.32
1522	0.4566	0.3097
1523	0.4633	0.32
1524	0.4594	0.3152
1525	0.4566	0.3102
1526	0.4593	0.3134
1527	0.457	0.3097
1528	0.4597	0.3113
1529	0.4645	0.3215
1530	0.4558	0.309
1531	0.4557	0.31
1532	0.4584	0.3137
1533	0.455	0.3088
1534	0.4569	0.3114
1535	0.4553	0.3089
1536	0.4555	0.3094
1537	0.4581	0.3132
1538	0.4552	0.3089
1539	0.4629	0.3187
1540	0.4563	0.3099
1541	0.4544	0.3082
1542	0.4577	0.313
1543	0.4559	0.31
1544	0.4545	0.3087
1545	0.4571	0.3129
1546	0.4546	0.3091
1547	0.4543	0.3074
1548	0.4584	0.3132
1549	0.4539	0.3079
1550	0.4556	0.31
1551	0.4537	0.3076
1552	0.4602	0.3156
1553	0.

1936	0.4395	0.2962
1937	0.4445	0.3038
1938	0.4405	0.2968
1939	0.4407	0.2983
1940	0.4495	0.3106
1941	0.4431	0.3007
1942	0.4414	0.2976
1943	0.4407	0.2981
1944	0.4471	0.3068
1945	0.4404	0.2986
1946	0.4405	0.2971
1947	0.4425	0.301
1948	0.444	0.3018
1949	0.4438	0.3007
1950	0.4527	0.313
1951	0.4422	0.3005
1952	0.4436	0.3005
1953	0.4473	0.3061
1954	0.4408	0.2984
1955	0.4406	0.2978
1956	0.4451	0.3039
1957	0.4401	0.2967
1958	0.4407	0.2974
1959	0.4431	0.3017
1960	0.4416	0.2995
1961	0.4409	0.2992
1962	0.4401	0.2966
1963	0.4404	0.2972
1964	0.4385	0.2961
1965	0.4438	0.3037
1966	0.4429	0.3026
1967	0.4409	0.2985
1968	0.441	0.2986
1969	0.4399	0.2978
1970	0.4394	0.2966
1971	0.4533	0.3151
1972	0.4431	0.2995
1973	0.4398	0.2972
1974	0.444	0.303
1975	0.4404	0.2988
1976	0.4388	0.2969
1977	0.4409	0.2999
1978	0.4438	0.3042
1979	0.4399	0.2969
1980	0.4423	0.3014
1981	0.4482	0.3092
1982	0.4405	0.2985
1983	0.4418	0.3013
1984	0.4438	0.3014
1985	0.4392	0.2964
1986	0.4443	0.3031
1987	0.4402	0.298
1988	0.4416	0.3002


In [12]:
# Save only the learned parameters (state_dict); to reuse them, rebuild the model
# with the same constructor arguments and call load_state_dict().
model_path = './Trained_model/{}.pkl'.format(task_name)
# Create the output folder if needed so a finished run is not lost to a missing directory.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)