In [1]:
from My_Pka_Model import Pka_basic,Pka_acidic
import torch

import dgl
import dgllife
from torch.utils.data import DataLoader
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
import torch.optim as optim
import numpy as np
import random 
import pandas as pd

from torch.nn.utils import clip_grad_norm

Using backend: pytorch


In [2]:
# Reproducibility: give Python, NumPy and PyTorch (CPU and CUDA) the same seed.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and switch its autotuner off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
def collate(samples):
    """Combine ``(graph, label)`` samples into a single mini-batch.

    The graphs are merged with ``dgl.batch`` and the labels are packed into
    one tensor, both in the order the samples were given.

    Returns a ``(batched_graph, label_tensor)`` tuple.
    """
    graph_seq, label_seq = zip(*samples)
    return dgl.batch(list(graph_seq)), torch.tensor(list(label_seq))

# Module-level featurizers used by load_data. Both are told to store their
# features under the key "h", the same key the training cells read back
# through bg.ndata['h'] and bg.edata['h'].
node_featurizer = CanonicalAtomFeaturizer(atom_data_field="h")
edge_featurizer = CanonicalBondFeaturizer(bond_data_field="h")

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    """Build DataLoader(s) of molecular graphs from a tab-separated file.

    Every non-blank line of ``file_name`` is split on tabs. Column 0 is handed
    to ``smiles_to_bigraph`` together with the module-level featurizers and
    column 1 is parsed with ``float`` as the label. Blank lines (for example
    a trailing empty line) are skipped.

    Parameters
    ----------
    file_name : str
        Path of the tab-separated input file.
    batch_size : int
        Batch size given to every DataLoader created here.
    shuffle : bool
        ``shuffle`` flag given to every DataLoader created here, including
        each part of a split.
    split_ratio : False or iterable of float
        A falsy scalar (``False``, ``None``, ``0``) means "no split". Any
        other value is iterated: the samples are shuffled once with
        ``random.shuffle`` and, for each ratio, the next
        ``round(len(dataset) * ratio)`` samples form one part. Whatever is
        left afterwards always forms one additional final part, even when it
        is empty.

    Returns
    -------
    DataLoader
        When no split is requested.
    list of DataLoader
        One loader per ratio plus one for the remainder, when splitting.

    Raises
    ------
    IndexError
        If a non-blank line has no second tab-separated column.
    ValueError
        If a label cannot be parsed as a float, or if ``smiles_to_bigraph``
        returns ``None`` for a line.
    """
    def make_loader(samples):
        # Single place where the loader options are applied.
        return DataLoader(samples, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)

    dataset = []
    with open(file_name) as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line_no, raw_line in enumerate(f, start=1):
            record = raw_line.rstrip('\r\n')
            if not record.strip():
                # A blank line would otherwise fail on fields[1] below.
                continue
            fields = record.split('\t')
            g = smiles_to_bigraph(smiles=fields[0],
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order= False)
            if g is None:
                # NOTE(review): some dgllife releases appear to return None for
                # SMILES they cannot parse - confirm for the installed version.
                # Failing here names the offending line instead of crashing
                # later while batching.
                raise ValueError(
                    'Could not build a graph from SMILES {!r} ({} line {})'.format(
                        fields[0], file_name, line_no))
            dataset.append((g,float(fields[1])))

    # Only scalar falsy values disable splitting; any iterable (list, tuple,
    # array, even an empty one) takes the split path.
    no_split = split_ratio is None or (
        isinstance(split_ratio, (bool, int, float)) and not split_ratio)
    if no_split:
        return make_loader(dataset)

    random.shuffle(dataset)
    length = len(dataset)
    dataloader_list = []
    remaining = dataset
    for ratio in split_ratio:
        num = round(length * ratio)
        dataloader_list.append(make_loader(remaining[:num]))
        remaining = remaining[num:]
    # Everything not claimed by the ratios becomes the last part.
    dataloader_list.append(make_loader(remaining))
    return dataloader_list

In [4]:
# Hyper-parameters shared by the cells below.
batch_size = 1024        # samples per mini-batch passed to load_data
epoch_num = 1500         # length of the first training loop
layer_num = 6            # forwarded to the model as num_layers
learning_rate = 3e-4     # Adam learning rate for the first training loop
weight_decay = 3e-4      # Adam weight_decay

In [5]:
# Build the network and move it to the selected device.
# NOTE(review): 74 / 12 are presumably the feature widths produced by the
# canonical atom / bond featurizers - confirm against their feat_size().
model = Pka_basic(
    node_feat_size=74,
    edge_feat_size=12,
    output_size=1,
    num_layers=layer_num,
    graph_feat_size=200,
    dropout=0.2,
)
model = model.to(device)

In [6]:
# No split_ratio is given, so load_data returns a single DataLoader.
train_loader = load_data(
    './Dataset/site_basic_smiles.txt',
    batch_size=batch_size,
)
# Number of samples that were loaded.
print(len(train_loader.dataset))

8170


In [7]:
# Run identifier; it is substituted into the log-file and checkpoint paths.
task_name = "site_basic_1"

In [8]:
# Mean-squared-error objective ('mean' is MSELoss's default reduction).
loss_func = torch.nn.MSELoss(reduction='mean')

# Adam over all model parameters, configured from the hyper-parameter cell.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [9]:
# Per-epoch metric history; one value is appended after every epoch.
train_RMSE_lis = []
train_MAE_lis = []

file_name = './Logger/{}.txt'.format(task_name)

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:'
print(header)

# Start a fresh log file for this run (any previous content is truncated).
with open(file_name,'w') as f:
    f.write(header)
    f.write('\n')

for epoch in range(epoch_num):
    # ---- optimisation pass over the training set ----
    model.train()
    # Iterate the loader directly: the old `enumerate(...)` index was unused
    # and its name `iter` shadowed the builtin in the notebook namespace.
    for bg, label in train_loader:
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated
        # clip_grad_norm; it is reached through the already-imported torch.
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    # ---- metrics on the same training set, in eval mode without gradients ----
    model.eval()
    SSE = 0.0  # running sum of squared errors
    SAE = 0.0  # running sum of absolute errors
    with torch.no_grad():
        for bg, label in train_loader:
            bg = bg.to(device)
            label = label.reshape(-1).to(device)
            # reshape(-1) keeps a 1-D tensor even for a batch of one graph,
            # where squeeze() would collapse to 0-d. The model is built with
            # output_size = 1, so this yields one value per graph.
            prediction = model(bg,bg.ndata['h'], bg.edata['h']).reshape(-1)
            error = prediction-label
            # torch.sum reduces on-device; builtin sum() iterated per element.
            SSE += torch.sum(error**2).item()
            SAE += torch.sum(torch.abs(error)).item()

    N = len(train_loader.dataset)
    train_RMSE = (SSE / N)**0.5
    train_MAE = SAE / N
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4))
    print(log)

    # Append after every epoch so the log survives an interrupted run.
    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

epoch:	train_RMSD:	train_MAE:




0	6.9227	6.3043
1	6.5571	5.925
2	4.9385	4.2855
3	4.0379	3.4454
4	3.758	3.1989
5	3.5754	3.0409
6	3.4318	2.9148
7	3.3184	2.8122
8	3.2332	2.7292
9	3.1799	2.6704
10	3.1521	2.6322
11	3.1402	2.6096
12	3.1361	2.5973
13	3.135	2.5914
14	3.1348	2.5895
15	3.1348	2.5892
16	3.1348	2.5896
17	3.1349	2.5908
18	3.1349	2.5914
19	3.1351	2.5925
20	3.135	2.5925
21	3.1349	2.5927
22	3.1344	2.591
23	3.1339	2.5907
24	3.1331	2.591
25	3.1313	2.5888
26	3.1282	2.587
27	3.1223	2.5823
28	3.1124	2.5736
29	3.0975	2.5595
30	3.0789	2.54
31	3.0565	2.5209
32	3.0222	2.489
33	2.9805	2.4503
34	2.9342	2.4045
35	2.8826	2.3524
36	2.8097	2.2815
37	2.7308	2.2066
38	2.6091	2.0973
39	2.4928	1.9862
40	2.3796	1.8768
41	2.3033	1.7864
42	2.2464	1.7105
43	2.2054	1.6805
44	2.1778	1.6303
45	2.1571	1.6334
46	2.1329	1.6052
47	2.118	1.5905
48	2.1189	1.6013
49	2.1001	1.5745
50	2.092	1.5672
51	2.0849	1.5633
52	2.0774	1.5493
53	2.0849	1.5678
54	2.0689	1.5458
55	2.069	1.5515
56	2.062	1.5426
57	2.0624	1.5489
58	2.0488	1.5273
59	2.0425	1.5121
60	2

468	0.8135	0.5664
469	0.7943	0.553
470	0.8056	0.5624
471	0.7993	0.5591
472	0.819	0.5754
473	0.7994	0.5579
474	0.7857	0.5376
475	0.796	0.5549
476	0.7848	0.5388
477	0.7961	0.5497
478	0.8014	0.5565
479	0.7815	0.5343
480	0.7752	0.5294
481	0.7911	0.5468
482	0.7692	0.5244
483	0.7927	0.553
484	0.776	0.5316
485	0.7733	0.5284
486	0.7932	0.5528
487	0.8179	0.5782
488	0.7835	0.5417
489	0.7636	0.5187
490	0.7696	0.5227
491	0.7647	0.5227
492	0.7699	0.5256
493	0.7732	0.5315
494	0.779	0.528
495	0.7674	0.5283
496	0.7692	0.5291
497	0.7693	0.5226
498	0.761	0.517
499	0.7693	0.5236
500	0.758	0.5178
501	0.7641	0.5206
502	0.7638	0.5205
503	0.7677	0.5258
504	0.8338	0.5978
505	0.7684	0.5314
506	0.7694	0.5345
507	0.7654	0.5274
508	0.7891	0.5535
509	0.7698	0.5334
510	0.7605	0.525
511	0.7675	0.5307
512	0.7575	0.5212
513	0.7575	0.5158
514	0.7556	0.5137
515	0.7534	0.5123
516	0.7523	0.5141
517	0.751	0.5132
518	0.7541	0.5172
519	0.7603	0.5267
520	0.7784	0.5433
521	0.7667	0.5324
522	0.7792	0.5436
523	0.7638	0.5287
524	

928	0.5866	0.4083
929	0.5888	0.4117
930	0.5698	0.3887
931	0.5868	0.4054
932	0.5746	0.3944
933	0.5991	0.4231
934	0.6136	0.4344
935	0.607	0.4316
936	0.6085	0.4272
937	0.5673	0.3859
938	0.5766	0.4017
939	0.5742	0.3921
940	0.5759	0.3965
941	0.5658	0.3853
942	0.5962	0.4187
943	0.5769	0.3953
944	0.5622	0.3848
945	0.5783	0.3989
946	0.5683	0.3882
947	0.5696	0.3914
948	0.5849	0.4117
949	0.5694	0.39
950	0.5648	0.3837
951	0.5634	0.3818
952	0.5624	0.3827
953	0.5696	0.3916
954	0.5949	0.417
955	0.5749	0.3947
956	0.5624	0.3831
957	0.5677	0.3884
958	0.5661	0.3897
959	0.5665	0.3877
960	0.5771	0.3997
961	0.5736	0.3982
962	0.6362	0.4671
963	0.5885	0.4092
964	0.5742	0.3949
965	0.5633	0.3831
966	0.5706	0.3944
967	0.5811	0.4066
968	0.56	0.3808
969	0.5747	0.3979
970	0.5628	0.3822
971	0.5584	0.3795
972	0.5615	0.3837
973	0.5636	0.3852
974	0.5639	0.3878
975	0.5652	0.3864
976	0.5574	0.3757
977	0.5657	0.3859
978	0.5736	0.3948
979	0.5637	0.3847
980	0.5601	0.3835
981	0.5952	0.4233
982	0.5913	0.4208
983	0.5631	0.386

1369	0.5162	0.3599
1370	0.5035	0.3494
1371	0.4923	0.3353
1372	0.5092	0.3557
1373	0.5008	0.3401
1374	0.4839	0.3254
1375	0.5122	0.3589
1376	0.4884	0.3334
1377	0.4964	0.3408
1378	0.4782	0.3185
1379	0.4783	0.3189
1380	0.4872	0.3316
1381	0.4795	0.3213
1382	0.4822	0.3237
1383	0.4874	0.3283
1384	0.4811	0.325
1385	0.4947	0.3396
1386	0.488	0.3329
1387	0.4868	0.3305
1388	0.483	0.3246
1389	0.4916	0.3328
1390	0.5185	0.3703
1391	0.487	0.3288
1392	0.5051	0.3497
1393	0.4913	0.3342
1394	0.481	0.3213
1395	0.4825	0.3256
1396	0.4855	0.3293
1397	0.4872	0.3306
1398	0.4937	0.3389
1399	0.4784	0.3196
1400	0.4963	0.342
1401	0.4948	0.3399
1402	0.4966	0.3401
1403	0.4825	0.3252
1404	0.484	0.3292
1405	0.4894	0.3379
1406	0.484	0.3287
1407	0.4987	0.3437
1408	0.4779	0.3204
1409	0.4811	0.3256
1410	0.4813	0.3245
1411	0.4812	0.3237
1412	0.5137	0.3659
1413	0.4975	0.3431
1414	0.4927	0.3363
1415	0.4842	0.327
1416	0.4986	0.3446
1417	0.5152	0.3629
1418	0.4893	0.3352
1419	0.5051	0.3504
1420	0.4939	0.3384
1421	0.4939	0.3381
14

In [10]:
# Second stage: replace the optimizer with a new Adam instance that runs at
# one tenth of the base learning rate (weight decay unchanged).
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate/10,
    weight_decay=weight_decay,
)

In [11]:
# Continue for 500 more epochs; epoch numbering carries on from the first loop
# so the appended log lines stay in sequence.
for epoch in range(epoch_num,epoch_num+500):
    # ---- optimisation pass over the training set ----
    model.train()
    # Iterate the loader directly: the old `enumerate(...)` index was unused
    # and its name `iter` shadowed the builtin in the notebook namespace.
    for bg, label in train_loader:
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated
        # clip_grad_norm; it is reached through the already-imported torch.
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    # ---- metrics on the same training set, in eval mode without gradients ----
    model.eval()
    SSE = 0.0  # running sum of squared errors
    SAE = 0.0  # running sum of absolute errors
    with torch.no_grad():
        for bg, label in train_loader:
            bg = bg.to(device)
            label = label.reshape(-1).to(device)
            # reshape(-1) keeps a 1-D tensor even for a batch of one graph,
            # where squeeze() would collapse to 0-d. The model is built with
            # output_size = 1, so this yields one value per graph.
            prediction = model(bg,bg.ndata['h'], bg.edata['h']).reshape(-1)
            error = prediction-label
            # torch.sum reduces on-device; builtin sum() iterated per element.
            SSE += torch.sum(error**2).item()
            SAE += torch.sum(torch.abs(error)).item()

    N = len(train_loader.dataset)
    train_RMSE = (SSE / N)**0.5
    train_MAE = SAE / N
    # Keep the metric history going in the lists created with the first loop.
    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)

    log = '{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4))
    print(log)

    # Append after every epoch so the log survives an interrupted run.
    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1500	0.4647	0.3126
1501	0.4561	0.3019
1502	0.4537	0.3001
1503	0.4528	0.2988
1504	0.4523	0.299
1505	0.4523	0.2984
1506	0.4514	0.2971
1507	0.4518	0.2978
1508	0.4501	0.2959
1509	0.4507	0.2964
1510	0.45	0.2956
1511	0.4541	0.3011
1512	0.4498	0.2959
1513	0.4491	0.2947
1514	0.4496	0.2955
1515	0.4506	0.2969
1516	0.4482	0.2942
1517	0.4496	0.2965
1518	0.4497	0.2958
1519	0.4482	0.2944
1520	0.4503	0.2977
1521	0.4498	0.2973
1522	0.448	0.295
1523	0.4486	0.2954
1524	0.4492	0.2965
1525	0.448	0.294
1526	0.4497	0.2957
1527	0.4498	0.2966
1528	0.4471	0.2938
1529	0.4492	0.2958
1530	0.4495	0.2966
1531	0.4468	0.2926
1532	0.4473	0.2929
1533	0.4475	0.2947
1534	0.4472	0.2934
1535	0.4476	0.2944
1536	0.4461	0.2922
1537	0.4488	0.2957
1538	0.45	0.2976
1539	0.4473	0.2941
1540	0.4479	0.2939
1541	0.4454	0.292
1542	0.4456	0.292
1543	0.4459	0.2926
1544	0.4491	0.2961
1545	0.4475	0.2942
1546	0.45	0.2964
1547	0.4445	0.2909
1548	0.4468	0.2938
1549	0.4454	0.2923
1550	0.4454	0.2916
1551	0.4453	0.2927
1552	0.4469	0.2943
1553	0

1937	0.4309	0.2809
1938	0.4308	0.2814
1939	0.4323	0.2825
1940	0.4342	0.2858
1941	0.4365	0.2884
1942	0.4352	0.2864
1943	0.4307	0.2806
1944	0.4306	0.2814
1945	0.4328	0.2835
1946	0.4301	0.2802
1947	0.4315	0.2819
1948	0.4319	0.283
1949	0.4303	0.2805
1950	0.4298	0.2798
1951	0.4309	0.2816
1952	0.4319	0.2827
1953	0.4315	0.2823
1954	0.4294	0.2791
1955	0.4305	0.281
1956	0.4304	0.2813
1957	0.4301	0.2806
1958	0.4296	0.2807
1959	0.4319	0.2838
1960	0.4306	0.2813
1961	0.4292	0.2797
1962	0.4294	0.2802
1963	0.4303	0.2811
1964	0.4311	0.2817
1965	0.4324	0.2838
1966	0.4298	0.2801
1967	0.4301	0.2809
1968	0.4311	0.2824
1969	0.4305	0.2808
1970	0.432	0.2837
1971	0.4318	0.2832
1972	0.4304	0.2824
1973	0.4305	0.2823
1974	0.4304	0.2814
1975	0.4331	0.2851
1976	0.4338	0.2867
1977	0.4307	0.2829
1978	0.4308	0.2816
1979	0.4316	0.2836
1980	0.4334	0.2866
1981	0.429	0.2801
1982	0.4286	0.279
1983	0.4338	0.2858
1984	0.4302	0.2812
1985	0.4307	0.2817
1986	0.4303	0.2821
1987	0.4294	0.2794
1988	0.4291	0.28
1989	0.4323	0.2832


In [12]:
# Persist only the learned weights (the state_dict), named after the task.
torch.save(
    model.state_dict(),
    './Trained_model/{}.pkl'.format(task_name),
)