In [1]:
import torch

import dgl
import dgllife
from torch.utils.data import DataLoader
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
import torch.optim as optim
import numpy as np
import random 
import pandas as pd

from torch.nn.utils import clip_grad_norm

Using backend: pytorch


In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

seed = 4
random.seed(seed) 
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  
np.random.seed(seed)  
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    dataset = []
    with open(file_name) as f:
        for line in f.readlines():
            line = line.replace('\n','').split('\t')
            g = smiles_to_bigraph(smiles=line[0], 
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order= False)
            dataset.append((g,float(line[1])))
            
    if split_ratio == False:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        return dataloader
    else:
        random.shuffle(dataset)
        length = len(dataset)
        dataloader_list = []
        for i in split_ratio:
            num = round(length * i)
            dataset_part = dataset[:num]
            dataset = dataset[num:]
            dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
            dataloader_list.append(dataloader)
        dataset_part = dataset
        dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        dataloader_list.append(dataloader)
        return dataloader_list

In [4]:
batch_size = 1024
learning_rate = 0.0003
weight_decay = 0.0003

In [5]:
layer_num = 6
num_timesteps = 1

In [6]:
model = dgllife.model.model_zoo.attentivefp_predictor.AttentiveFPPredictor(node_feat_size = 74,
                                                                           edge_feat_size = 12,
                                                                           num_layers= layer_num,
                                                                           num_timesteps= num_timesteps,
                                                                           graph_feat_size = 200,
                                                                           dropout=0.2).to(device)

In [7]:
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [8]:
train_loader = load_data('./Dataset/acidic_train_0.70_smiles.txt',batch_size = batch_size)
val_loader = load_data('./Dataset/acidic_val_0.15_smiles.txt',batch_size = batch_size)
test_loader = load_data('./Dataset/acidic_test_0.15_smiles.txt',batch_size = batch_size)
print(len(train_loader.dataset))

6337


In [9]:
task_name = 'attentive_fp_acidic_ramdom_split_5'

In [10]:
train_RMSE_lis = []
train_MAE_lis = []
val_RMSE_lis = []
val_MAE_lis = []
test_RMSE_lis = []
test_MAE_lis = []

file_name = './Logger/{}.txt'.format(task_name)
header = 'epoch:\ttrain_RMSD:\ttrain_MAE:\tval_RMSD:\tval_MAE:\ttest_RMSD:\ttest_MAE:'
print(header)

with open(file_name,'w+') as f:
    f.write(header)
    f.write('\n')

for epoch in range(1000):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N


        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N
        
        
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(test_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(test_loader.dataset)
        test_RMSE = (SSE.item()/ N)**0.5
        test_MAE = SAE.item()/N
        

        
    log = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4),round(test_RMSE,4),round(test_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)
    test_RMSE_lis.append(test_RMSE)
    test_MAE_lis.append(test_MAE)

    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

epoch:	train_RMSD:	train_MAE:	val_RMSD:	val_MAE:	test_RMSD:	test_MAE:




0	5.6871	4.5019	5.72	4.4638	5.7131	4.5257
1	4.4515	3.4131	4.4627	3.2877	4.7385	3.5269
2	4.1572	3.2884	4.1913	3.1738	4.3557	3.371
3	3.8191	2.9001	3.8767	2.8175	3.8833	2.92
4	3.5629	2.9552	3.6515	2.8691	3.6163	2.96
5	3.4988	2.8861	3.6107	2.8488	3.5222	2.8705
6	3.487	2.8866	3.598	2.8532	3.505	2.8685
7	3.4774	2.9111	3.58	2.864	3.5011	2.8983
8	3.4706	2.8611	3.5741	2.82	3.4852	2.8425
9	3.4526	2.8625	3.5545	2.8197	3.464	2.8412
10	3.4358	2.8335	3.5373	2.7925	3.4453	2.8126
11	3.3889	2.8189	3.4881	2.7709	3.4004	2.8012
12	3.3825	2.6901	3.4886	2.6484	3.3972	2.691
13	3.2657	2.5966	3.3676	2.5464	3.2842	2.5997
14	3.2126	2.5679	3.3088	2.5166	3.2379	2.5721
15	3.1697	2.4517	3.2705	2.4078	3.1888	2.4679
16	3.0227	2.3313	3.1253	2.2932	3.0465	2.3441
17	2.9638	2.2981	3.0627	2.2616	3.0086	2.3168
18	2.8441	2.0007	2.9758	1.9774	2.8735	2.0257
19	2.8283	1.9537	2.9701	1.9261	2.8538	1.97
20	2.6196	1.7734	2.7619	1.7515	2.6393	1.7951
21	2.5386	1.7088	2.6758	1.6761	2.561	1.7301
22	2.5309	1.7217	2.6634	1.6818	2.5412	1.

184	1.0649	0.7473	1.4329	0.8218	1.2437	0.8007
185	1.0672	0.7464	1.3995	0.7999	1.2097	0.7938
186	1.0399	0.7253	1.397	0.7955	1.1915	0.7732
187	1.0575	0.7512	1.4106	0.8262	1.2121	0.8055
188	1.0445	0.7306	1.3913	0.7931	1.1955	0.7782
189	1.069	0.7606	1.4183	0.8371	1.222	0.8161
190	1.0503	0.7487	1.4228	0.8263	1.222	0.7991
191	1.0345	0.7321	1.3926	0.8019	1.1982	0.7821
192	1.0341	0.7264	1.3823	0.7963	1.1958	0.7866
193	1.0964	0.7908	1.4418	0.8745	1.2559	0.8529
194	1.0642	0.7477	1.4131	0.8077	1.2211	0.7869
195	1.0488	0.7381	1.3927	0.8113	1.2003	0.7948
196	1.023	0.7146	1.3725	0.7793	1.1711	0.7573
197	1.2132	0.8896	1.5291	0.9396	1.347	0.9155
198	1.13	0.8034	1.4529	0.8491	1.2661	0.8336
199	1.1217	0.7961	1.435	0.86	1.2524	0.8468
200	1.0397	0.7452	1.3982	0.8248	1.2018	0.8013
201	1.0365	0.7373	1.3898	0.8147	1.1971	0.7913
202	1.0161	0.7084	1.3698	0.7787	1.1803	0.7639
203	1.0244	0.7162	1.3603	0.7738	1.1791	0.7692
204	1.0104	0.7037	1.369	0.7821	1.1721	0.7637
205	1.013	0.7126	1.3831	0.7854	1.188	0.7665
20

365	0.8519	0.5996	1.2186	0.6849	1.0467	0.672
366	0.8607	0.6117	1.2428	0.7158	1.0724	0.7105
367	0.871	0.6074	1.2267	0.689	1.0646	0.6806
368	0.859	0.6024	1.2315	0.6969	1.059	0.6975
369	0.8555	0.5976	1.2133	0.6795	1.0534	0.691
370	0.8402	0.5875	1.2103	0.6737	1.0428	0.6781
371	0.8839	0.6267	1.2457	0.7101	1.0812	0.701
372	0.8355	0.5859	1.2185	0.6847	1.0434	0.6798
373	0.8309	0.5782	1.198	0.6649	1.0311	0.6691
374	0.8369	0.5859	1.2092	0.6768	1.0327	0.6692
375	0.8519	0.5956	1.2187	0.6865	1.0519	0.6904
376	0.825	0.5741	1.2064	0.668	1.0342	0.663
377	0.8462	0.5903	1.2171	0.6807	1.0484	0.6738
378	0.8503	0.5999	1.2186	0.6864	1.0517	0.6746
379	0.842	0.5886	1.203	0.6777	1.0498	0.6881
380	0.8189	0.5684	1.1916	0.6586	1.0307	0.6607
381	0.8233	0.5761	1.2042	0.6682	1.0352	0.6629
382	0.8849	0.6383	1.2537	0.7388	1.0838	0.7372
383	0.8333	0.5875	1.1984	0.6716	1.0321	0.6654
384	0.8365	0.5818	1.1998	0.6733	1.0395	0.6767
385	0.8262	0.5776	1.2031	0.6749	1.0339	0.6666
386	0.8236	0.5744	1.1934	0.6634	1.0312	0.6658
3

546	0.7602	0.5371	1.1249	0.6469	0.9849	0.6429
547	0.7547	0.528	1.1367	0.6425	0.9838	0.6456
548	0.7411	0.5175	1.1194	0.6261	0.9801	0.6383
549	0.8093	0.5854	1.1525	0.6801	1.0282	0.6821
550	0.7925	0.5655	1.1609	0.6741	1.0146	0.6771
551	0.7421	0.5131	1.1104	0.6204	0.9729	0.6287
552	0.7685	0.5481	1.1324	0.6515	0.9944	0.6503
553	0.7824	0.5602	1.1511	0.6716	1.0094	0.6826
554	0.7669	0.5392	1.1337	0.6425	1.0014	0.6474
555	0.7554	0.5321	1.1367	0.6443	0.9923	0.6555
556	0.739	0.5164	1.1289	0.6354	0.9793	0.6367
557	0.7645	0.5414	1.1241	0.6447	0.9927	0.6465
558	0.8117	0.5852	1.1781	0.6987	1.0442	0.7081
559	0.7511	0.5299	1.1202	0.6406	0.9859	0.6398
560	0.7366	0.5144	1.1262	0.6348	0.9819	0.6341
561	0.7894	0.5629	1.1497	0.6706	1.0115	0.6827
562	0.762	0.5401	1.1339	0.6515	0.9908	0.6445
563	0.7729	0.5573	1.1579	0.6779	1.0103	0.6826
564	0.7233	0.5	1.1075	0.6148	0.9661	0.6195
565	0.7408	0.52	1.1219	0.633	0.9789	0.6311
566	0.8432	0.6327	1.1968	0.7434	1.0651	0.7564
567	0.7412	0.5229	1.1243	0.637	0.9844	0.636

727	0.6628	0.4637	1.0492	0.6	0.936	0.6058
728	0.6683	0.4695	1.0514	0.601	0.9325	0.604
729	0.7485	0.5384	1.1048	0.6583	0.9986	0.6722
730	0.6743	0.4759	1.0472	0.6037	0.9369	0.6071
731	0.6643	0.4621	1.0378	0.6011	0.9263	0.6055
732	0.6612	0.4624	1.043	0.5945	0.93	0.6062
733	0.6903	0.4924	1.0704	0.6216	0.9602	0.6335
734	0.6912	0.494	1.0652	0.6193	0.9559	0.6338
735	0.686	0.4865	1.0732	0.6221	0.9526	0.6303
736	0.6632	0.4647	1.039	0.5922	0.925	0.5979
737	0.6577	0.4605	1.0413	0.5964	0.9268	0.5995
738	0.7081	0.508	1.0802	0.6386	0.9719	0.6446
739	0.6568	0.4595	1.0534	0.6014	0.9279	0.5967
740	0.6563	0.4565	1.0542	0.597	0.9362	0.6015
741	0.6485	0.4497	1.0428	0.5883	0.9275	0.5957
742	0.6629	0.4649	1.0473	0.6034	0.9333	0.5985
743	0.6489	0.4514	1.0368	0.5902	0.9263	0.5989
744	0.6578	0.4608	1.0571	0.6041	0.9401	0.6079
745	0.697	0.5032	1.0634	0.6299	0.9556	0.6275
746	0.6803	0.477	1.045	0.5992	0.9402	0.6153
747	0.6499	0.4541	1.0478	0.5978	0.9324	0.6059
748	0.6518	0.4553	1.0362	0.5947	0.9253	0.6027
749	0.

908	0.5973	0.418	0.9868	0.5714	0.9058	0.5816
909	0.6151	0.4351	0.9967	0.584	0.9091	0.5935
910	0.6107	0.4343	1.0058	0.5936	0.9146	0.5907
911	0.6009	0.4245	0.9906	0.5732	0.9018	0.588
912	0.6396	0.4606	1.0286	0.6151	0.9285	0.6217
913	0.6038	0.4239	0.9987	0.5801	0.9025	0.5819
914	0.6113	0.4338	0.985	0.5828	0.9021	0.584
915	0.6051	0.4315	0.9805	0.5799	0.8997	0.5843
916	0.6125	0.4341	1.0022	0.5942	0.9111	0.5972
917	0.6013	0.4241	0.9789	0.569	0.8979	0.5845
918	0.6059	0.4269	0.9917	0.5822	0.9078	0.589
919	0.5973	0.4201	0.9948	0.5804	0.9058	0.5843
920	0.5999	0.4213	0.9878	0.5709	0.8945	0.5781
921	0.6018	0.4228	0.9868	0.5738	0.901	0.5807
922	0.5994	0.4213	0.9963	0.5768	0.9011	0.5797
923	0.6218	0.4426	0.9991	0.5909	0.9086	0.5962
924	0.603	0.4275	0.9973	0.581	0.9062	0.5894
925	0.626	0.4501	1.0101	0.6033	0.9237	0.6101
926	0.7477	0.5752	1.091	0.7041	1.0137	0.7106
927	0.7184	0.5302	1.059	0.6548	0.9765	0.6549
928	0.6139	0.4331	1.0007	0.5853	0.9115	0.5944
929	0.6244	0.4493	1.0031	0.5901	0.9269	0.611
93

In [11]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate/10,weight_decay=weight_decay)

In [12]:
for epoch in range(1000,1500):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N


        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N
        
        
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(test_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(test_loader.dataset)
        test_RMSE = (SSE.item()/ N)**0.5
        test_MAE = SAE.item()/N
        

        
    log = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4),round(test_RMSE,4),round(test_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)
    test_RMSE_lis.append(test_RMSE)
    test_MAE_lis.append(test_MAE)

    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1000	0.562	0.3921	0.9651	0.5621	0.8899	0.5671
1001	0.5623	0.3925	0.9642	0.5626	0.8904	0.5665
1002	0.5588	0.389	0.9586	0.5576	0.886	0.5628
1003	0.5581	0.3882	0.9593	0.558	0.8873	0.564
1004	0.558	0.388	0.9601	0.5589	0.8869	0.561
1005	0.5614	0.3925	0.9632	0.5609	0.8907	0.5676
1006	0.5585	0.3885	0.9578	0.5577	0.8857	0.5602
1007	0.5571	0.3876	0.9613	0.5581	0.8876	0.5627
1008	0.5562	0.3865	0.9587	0.5561	0.8869	0.5609
1009	0.5554	0.3861	0.9579	0.5574	0.8864	0.5598
1010	0.5552	0.3865	0.9609	0.5585	0.889	0.5622
1011	0.5568	0.3879	0.9613	0.5582	0.8901	0.5666
1012	0.555	0.3852	0.9562	0.5549	0.8854	0.5598
1013	0.5577	0.388	0.9578	0.5568	0.8854	0.5597
1014	0.5598	0.3904	0.9637	0.5601	0.8886	0.5656
1015	0.5543	0.3849	0.9575	0.5555	0.8838	0.5583
1016	0.5541	0.3851	0.9588	0.5556	0.8855	0.561
1017	0.5551	0.3861	0.9556	0.554	0.8838	0.56
1018	0.5545	0.386	0.9574	0.5556	0.8855	0.5605
1019	0.5542	0.3859	0.954	0.5547	0.883	0.559
1020	0.5539	0.3852	0.9575	0.5554	0.8848	0.5617
1021	0.5551	0.3872	0.9616	0.5582

1177	0.5482	0.3846	0.9538	0.5582	0.8897	0.564
1178	0.5385	0.374	0.9475	0.5499	0.884	0.555
1179	0.5382	0.3738	0.9466	0.5504	0.8809	0.5538
1180	0.5404	0.3772	0.9497	0.5541	0.8826	0.5564
1181	0.5421	0.3775	0.9504	0.5536	0.8836	0.5585
1182	0.5386	0.3734	0.9436	0.5482	0.8809	0.5554
1183	0.5398	0.3758	0.9466	0.5532	0.8832	0.5551
1184	0.5425	0.3783	0.9523	0.5564	0.884	0.5578
1185	0.5382	0.3738	0.9471	0.5504	0.8801	0.5531
1186	0.5377	0.3731	0.9458	0.5478	0.88	0.5529
1187	0.5383	0.3738	0.9465	0.5496	0.8809	0.5539
1188	0.5389	0.3736	0.9448	0.5483	0.8782	0.5531
1189	0.539	0.375	0.9478	0.5512	0.8804	0.5572
1190	0.5372	0.374	0.9466	0.5512	0.8811	0.554
1191	0.5391	0.3751	0.9458	0.5511	0.8819	0.5545
1192	0.5379	0.3734	0.9435	0.5488	0.8809	0.5537
1193	0.5373	0.3729	0.9431	0.5478	0.8806	0.5547
1194	0.5392	0.3752	0.9445	0.5503	0.8825	0.5548
1195	0.5393	0.3751	0.9473	0.5523	0.8826	0.5556
1196	0.5375	0.373	0.9439	0.5483	0.8813	0.5549
1197	0.5366	0.3725	0.9452	0.5489	0.8824	0.5554
1198	0.5364	0.3718	0.9453

1354	0.5305	0.3698	0.9461	0.5524	0.8869	0.5614
1355	0.5264	0.365	0.94	0.5478	0.8811	0.5521
1356	0.526	0.3653	0.9418	0.5483	0.8838	0.5535
1357	0.5275	0.3666	0.9405	0.5475	0.8813	0.5528
1358	0.5333	0.373	0.9438	0.5531	0.8845	0.5619
1359	0.5281	0.3666	0.9341	0.5445	0.8778	0.5519
1360	0.5325	0.3713	0.9343	0.5472	0.8815	0.5541
1361	0.5259	0.365	0.9326	0.5439	0.8784	0.5503
1362	0.5267	0.3674	0.9396	0.5473	0.8847	0.5543
1363	0.5317	0.3721	0.9432	0.5514	0.8882	0.559
1364	0.5251	0.3651	0.9387	0.548	0.8831	0.5557
1365	0.528	0.3674	0.9356	0.5468	0.8815	0.553
1366	0.5257	0.3648	0.9364	0.5447	0.8811	0.5531
1367	0.5241	0.3637	0.9363	0.5452	0.8792	0.5518
1368	0.5347	0.3756	0.9476	0.5575	0.8894	0.5643
1369	0.5241	0.3638	0.9352	0.5448	0.8819	0.5523
1370	0.5273	0.3672	0.9335	0.5463	0.8802	0.5512
1371	0.5277	0.367	0.9352	0.5453	0.8803	0.551
1372	0.5247	0.3641	0.936	0.5445	0.8797	0.5503
1373	0.5256	0.3646	0.9334	0.5433	0.8789	0.5487
1374	0.5254	0.3656	0.9401	0.5483	0.8828	0.5543
1375	0.5283	0.3675	0.9379	

In [22]:
torch.save(model.state_dict(), './Trained_model/{}.pkl'.format(task_name))

In [23]:
import numpy as np
import pandas as pd
import pickle
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina

In [24]:
def Hierarchical_mae(similarity_list,d_list):
    data=pd.DataFrame({'similarity':similarity_list,'d':d_list})
    print(np.mean(data[(data['similarity']>=0.8)]['d']))
    print(np.mean(data[(data['similarity']>=0.7)&(data['similarity']<0.8)]['d']))
    print(np.mean(data[(data['similarity']>=0.6)&(data['similarity']<0.7)]['d']))
    print(np.mean(data[(data['similarity']>=0.5)&(data['similarity']<0.6)]['d']))
    print(np.mean(data[(data['similarity']<0.5)]['d']))

In [25]:
train_smiles_list = []
with open('./Dataset/acidic_train_0.70_smiles.txt') as f:
    for line in f.readlines():
        line = line.replace('\n','').split('./t')
        train_smiles_list.append(line[0])

train_ms = [Chem.MolFromSmiles(i) for i in train_smiles_list]
train_fps_list = [AllChem.GetMorganFingerprintAsBitVect(x,2,1024) for x in train_ms]

test_smiles_list = []
with open('./Dataset/acidic_test_0.15_smiles.txt') as f:
    for line in f.readlines():
        line = line.replace('\n','').split('./t')
        test_smiles_list.append(line[0])
        
test_ms = [Chem.MolFromSmiles(i) for i in test_smiles_list]
test_fps_list = [AllChem.GetMorganFingerprintAsBitVect(x,2,1024) for x in test_ms]

In [26]:
similarity_list = []
for test_fps in test_fps_list:
    sims = DataStructs.BulkTanimotoSimilarity(test_fps,train_fps_list)
    similarity_list.append(max(sims))

In [27]:
print(np.sum(np.array(similarity_list) >= 0.8))
print(np.sum(np.array(similarity_list) >= 0.7))
print(np.sum(np.array(similarity_list) >= 0.6))
print(np.sum(np.array(similarity_list) >= 0.5))
print(np.sum(np.array(similarity_list) <0.5))

258
461
789
1101
257


In [64]:
task_name = 'attentive_fp_acidic_ramdom_split_5'

In [65]:
from My_Pka_Model import Pka_basic_view,Pka_acidic_view
import torch
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from sklearn.metrics import r2_score,mean_squared_error,mean_absolute_error

In [66]:
def predict(smiles,model_view):

    node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
    edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')
    bg = smiles_to_bigraph(smiles= smiles, 
                  node_featurizer=node_featurizer,
                  edge_featurizer=edge_featurizer,canonical_atom_order= False)

    with torch.no_grad():
        model_view.eval()
        molecule_pka = model_view(bg,bg.ndata['h'], bg.edata['h'])
        
    return molecule_pka

In [67]:
acid_pred = dgllife.model.model_zoo.attentivefp_predictor.AttentiveFPPredictor(node_feat_size = 74,
                                                                           edge_feat_size = 12,
                                                                           num_layers= layer_num,
                                                                           num_timesteps= num_timesteps,
                                                                           graph_feat_size = 200,
                                                                           dropout=0.2)


acid_pred.load_state_dict(torch.load('./Trained_model/{}.pkl'.format(task_name),map_location='cuda:1'))

with open('./Dataset/acidic_test_0.15_smiles.txt') as f: #acidic_test_0.15_smiles.txt,SAMPL7_acidic_smiles.txt
    pred = []
    label = []
    for line in f.readlines():
        line = line.replace('\n','').split('\t')
        molecule_pka = predict(line[0],acid_pred)
        pred.append(molecule_pka)
        label.append(float(line[1]))

d_list = []
for i,j in zip(pred,label):
    d_list.append(abs(i-j))


print(r2_score(label,pred))
    
print(np.mean(d_list))
print('')

Hierarchical_mae(similarity_list,d_list)
print('')

0.935677117647524
0.55201435

0.3415940159050993
0.46638834535194734
0.4893402471774962
0.5815651722443409
0.8750017218088826



In [None]:
# 'attentive_fp_acidic_ramdom_split_1'

# 0.9322493488631574
# 0.57189924

# 0.3448248870613039
# 0.4497933974994227
# 0.47963975115520197
# 0.642538804274339
# 0.928296575286509

In [None]:
'attentive_fp_acidic_ramdom_split_2'

0.9301949221218309
0.5718746

0.35184461016987645
0.4599577805091595
0.4820726906381002
0.6328718723394932
0.9217219965003344

In [None]:
'attentive_fp_acidic_ramdom_split_3'

0.9310545712535644
0.57769036

0.34964175556981286
0.48165720671855755
0.4916661890541635
0.6217847481752053
0.9387404742407891

In [None]:
'attentive_fp_acidic_ramdom_split_4'

0.9308790718875417
0.5703063

0.33233355736547665
0.4831363320937885
0.48163451218023534
0.6266873677571615
0.9227813245721365

In [None]:
# 'attentive_fp_acidic_ramdom_split_5'

# 0.935677117647524
# 0.55201435

# 0.3415940159050993
# 0.46638834535194734
# 0.4893402471774962
# 0.5815651722443409
# 0.8750017218088826

In [None]:
0.57189924

0.3448248870613039
0.4497933974994227
0.47963975115520197
0.642538804274339
0.928296575286509

In [None]:
0.5718746

0.35184461016987645
0.4599577805091595
0.4820726906381002
0.6328718723394932
0.9217219965003344

In [None]:
0.57769036

0.34964175556981286
0.48165720671855755
0.4916661890541635
0.6217847481752053
0.9387404742407891


In [None]:
0.5703063

0.33233355736547665
0.4831363320937885
0.48163451218023534
0.6266873677571615
0.9227813245721365

In [None]:
0.55201435

0.3415940159050993
0.46638834535194734
0.4893402471774962
0.5815651722443409
0.8750017218088826


In [13]:
with open(file_name,'a') as f:
    for a,b,c,d,e,f in zip(train_RMSE_lis,train_MAE_lis,val_RMSE_lis,val_MAE_lis,test_RMSE_lis,test_MAE_lis):
        f.write(a)
        f.write('\t')
        f.write(b)
        f.write('\t')
        f.write(c)
        f.write('\t')
        f.write(d)
        f.write('\t')
        f.write(e)
        f.write('\t')
        f.write(f)
        f.write('\n')

SyntaxError: invalid syntax (<ipython-input-13-20f7fc593900>, line 2)

In [None]:
6比5好比4好

In [None]:
7比6差 6的时候是0.558(2层)

In [None]:
attentve FP 最好的应该也是6层

In [None]:
在搜索下步数