In [1]:
import torch

import dgl
import dgllife
from torch.utils.data import DataLoader
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
import torch.optim as optim
import numpy as np
import random 
import pandas as pd

from torch.nn.utils import clip_grad_norm

Using backend: pytorch


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

seed = 9
random.seed(seed) 
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  
np.random.seed(seed)  
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    dataset = []
    with open(file_name) as f:
        for line in f.readlines():
            line = line.replace('\n','').split('\t')
            g = smiles_to_bigraph(smiles=line[0], 
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order= False)
            dataset.append((g,float(line[1])))
            
    if split_ratio == False:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        return dataloader
    else:
        random.shuffle(dataset)
        length = len(dataset)
        dataloader_list = []
        for i in split_ratio:
            num = round(length * i)
            dataset_part = dataset[:num]
            dataset = dataset[num:]
            dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
            dataloader_list.append(dataloader)
        dataset_part = dataset
        dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        dataloader_list.append(dataloader)
        return dataloader_list

In [4]:
batch_size = 1024
learning_rate = 0.0003
weight_decay = 0.0003

In [5]:
layer_num = 6
num_timesteps = 1

In [6]:
model = dgllife.model.model_zoo.attentivefp_predictor.AttentiveFPPredictor(node_feat_size = 74,
                                                                           edge_feat_size = 12,
                                                                           num_layers= layer_num,
                                                                           num_timesteps= num_timesteps,
                                                                           graph_feat_size = 200,
                                                                           dropout=0.2).to(device)

In [7]:
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [8]:
train_loader = load_data('./Dataset/basic_train_0.70_smiles.txt',batch_size = batch_size)
val_loader = load_data('./Dataset/basic_val_0.15_smiles.txt',batch_size = batch_size)
test_loader = load_data('./Dataset/basic_test_0.15_smiles.txt',batch_size = batch_size)
print(len(train_loader.dataset))

5905


In [9]:
task_name = 'attentive_fp_basic_ramdom_split_9'

In [10]:
train_RMSE_lis = []
train_MAE_lis = []
val_RMSE_lis = []
val_MAE_lis = []
test_RMSE_lis = []
test_MAE_lis = []

file_name = './Logger/{}.txt'.format(task_name)
header = 'epoch:\ttrain_RMSD:\ttrain_MAE:\tval_RMSD:\tval_MAE:\ttest_RMSD:\ttest_MAE:'
print(header)

with open(file_name,'w+') as f:
    f.write(header)
    f.write('\n')

for epoch in range(1000):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N


        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N
        
        
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(test_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(test_loader.dataset)
        test_RMSE = (SSE.item()/ N)**0.5
        test_MAE = SAE.item()/N
        

        
    log = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4),round(test_RMSE,4),round(test_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)
    test_RMSE_lis.append(test_RMSE)
    test_MAE_lis.append(test_MAE)

    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

epoch:	train_RMSD:	train_MAE:	val_RMSD:	val_MAE:	test_RMSD:	test_MAE:




0	5.7482	5.0744	5.9024	5.2629	5.8653	5.2021
1	3.9941	3.1763	4.0643	3.252	3.9686	3.1895
2	3.9394	3.1362	3.9562	3.149	3.8765	3.1235
3	3.5999	2.8795	3.6902	2.9741	3.625	2.9237
4	3.2271	2.5816	3.2528	2.6041	3.1879	2.5877
5	3.0444	2.5002	3.0973	2.5352	3.0197	2.5338
6	3.012	2.3941	3.0197	2.3798	2.9267	2.3661
7	2.9122	2.3869	2.9567	2.4126	2.8637	2.4001
8	2.735	2.155	2.7477	2.1555	2.6532	2.1466
9	2.5458	1.9666	2.5519	1.9664	2.4742	1.9573
10	2.4529	1.8556	2.4477	1.835	2.3866	1.8267
11	2.4121	1.7991	2.3972	1.7776	2.3509	1.7701
12	2.376	1.8003	2.3667	1.782	2.3261	1.7752
13	2.3712	1.8201	2.3678	1.8045	2.3285	1.798
14	2.3653	1.8197	2.3634	1.8036	2.3229	1.7977
15	2.3475	1.756	2.3317	1.7327	2.2884	1.7281
16	2.3379	1.7473	2.3251	1.7251	2.28	1.7207
17	2.3265	1.7606	2.3231	1.7416	2.2761	1.7376
18	2.3126	1.738	2.3063	1.7192	2.2582	1.7152
19	2.3008	1.7125	2.2898	1.6929	2.2399	1.6885
20	2.2831	1.7322	2.2865	1.7199	2.2316	1.7141
21	2.263	1.7163	2.2736	1.7103	2.2131	1.7017
22	2.2325	1.6573	2.2388	1.6436	2.17

184	0.9968	0.711	1.1844	0.7796	1.1165	0.7632
185	1.0464	0.7736	1.217	0.8242	1.1556	0.8215
186	1.0128	0.7313	1.2027	0.7996	1.1391	0.7832
187	0.9887	0.7052	1.1652	0.7676	1.1012	0.757
188	0.982	0.7	1.1689	0.7653	1.1031	0.7531
189	0.9891	0.7041	1.1732	0.7637	1.1052	0.7571
190	0.9986	0.7122	1.1716	0.7766	1.1047	0.7664
191	0.9992	0.7184	1.1797	0.7847	1.1158	0.774
192	0.9802	0.7033	1.1741	0.771	1.1075	0.7603
193	0.975	0.6982	1.1661	0.7678	1.096	0.752
194	0.9788	0.7004	1.1701	0.7713	1.0972	0.7513
195	0.9714	0.6943	1.1713	0.7619	1.1041	0.7488
196	0.977	0.6955	1.1617	0.761	1.0864	0.746
197	1.019	0.7366	1.2054	0.7839	1.1283	0.7848
198	1.0115	0.736	1.1999	0.8037	1.142	0.7986
199	1.0087	0.739	1.2039	0.7951	1.129	0.7899
200	0.9727	0.6981	1.1637	0.7656	1.1006	0.754
201	0.9863	0.7123	1.1677	0.7703	1.0966	0.7629
202	0.9785	0.7039	1.1805	0.7721	1.1134	0.7593
203	1.023	0.7536	1.1997	0.8012	1.1298	0.8014
204	0.9766	0.7028	1.1733	0.7751	1.1036	0.7598
205	0.9566	0.6805	1.1582	0.7454	1.0856	0.7359
206	0.9569

365	0.8496	0.6182	1.0906	0.7005	1.0579	0.7086
366	0.8441	0.6111	1.0869	0.7045	1.0635	0.706
367	0.8189	0.5824	1.0571	0.6677	1.0302	0.6677
368	0.8276	0.5905	1.071	0.6754	1.0317	0.679
369	0.8401	0.6106	1.0728	0.7069	1.0491	0.7065
370	0.8238	0.5853	1.0798	0.6771	1.0451	0.68
371	0.817	0.5824	1.0883	0.6789	1.0502	0.6814
372	0.821	0.5867	1.0814	0.6805	1.0425	0.6782
373	0.8023	0.5669	1.0645	0.661	1.0286	0.6599
374	0.8031	0.5674	1.0613	0.6648	1.0189	0.663
375	0.8141	0.5809	1.0951	0.679	1.0515	0.6751
376	0.8014	0.5653	1.076	0.6589	1.0212	0.6587
377	0.8014	0.5679	1.0791	0.6628	1.0274	0.6621
378	0.8001	0.5666	1.0597	0.6632	1.024	0.6596
379	0.8001	0.5626	1.0602	0.6608	1.0265	0.6585
380	0.7996	0.5682	1.0489	0.6628	1.0158	0.657
381	0.8153	0.5792	1.0963	0.6714	1.0561	0.6737
382	0.8081	0.5775	1.0578	0.6711	1.0322	0.6742
383	0.803	0.569	1.0832	0.674	1.0429	0.6633
384	0.8007	0.566	1.071	0.6693	1.0332	0.6623
385	0.8059	0.5699	1.0768	0.6738	1.0474	0.6666
386	0.8222	0.5944	1.0747	0.6927	1.0434	0.6842
387	0.

546	0.7194	0.5062	1.0334	0.6357	0.9887	0.6251
547	0.7246	0.5087	1.0412	0.6347	0.9912	0.6258
548	0.73	0.5156	1.0424	0.6311	0.9899	0.6234
549	0.7248	0.506	1.0754	0.643	1.0152	0.6337
550	0.7102	0.4952	1.0487	0.6287	0.9868	0.6157
551	0.719	0.5034	1.0631	0.636	1.0	0.6155
552	0.7194	0.5025	1.0647	0.6367	1.0121	0.6277
553	0.713	0.4978	1.0573	0.6268	1.002	0.6189
554	0.7092	0.4922	1.0689	0.632	1.0096	0.6165
555	0.719	0.5046	1.0357	0.6351	1.0005	0.6273
556	0.7084	0.4927	1.0545	0.6264	1.0004	0.6184
557	0.7278	0.5145	1.0577	0.6516	1.0089	0.6323
558	0.7054	0.491	1.0495	0.6234	0.9949	0.6148
559	0.7428	0.5329	1.0875	0.6628	1.0272	0.6532
560	0.7175	0.505	1.0502	0.6386	0.9976	0.6261
561	0.7421	0.525	1.0855	0.6606	1.0292	0.6475
562	0.7174	0.5019	1.069	0.6396	1.0088	0.625
563	0.7159	0.5055	1.043	0.634	0.9879	0.6237
564	0.7263	0.5127	1.0842	0.6458	1.0266	0.6385
565	0.7438	0.5301	1.0483	0.656	1.0014	0.6387
566	0.7298	0.5092	1.0974	0.6485	1.04	0.6342
567	0.7171	0.507	1.0592	0.6414	1.0016	0.6261
568	0.7158	0

727	0.6571	0.4563	1.0486	0.6101	0.9774	0.5966
728	0.6898	0.4848	1.0897	0.6431	1.0242	0.6196
729	0.6682	0.4688	1.0322	0.6145	0.9775	0.6055
730	0.6565	0.4557	1.0424	0.6107	0.9823	0.5954
731	0.6585	0.4588	1.0563	0.6112	0.9922	0.6028
732	0.6554	0.4558	1.0447	0.6077	0.9772	0.5944
733	0.6616	0.46	1.0699	0.625	1.0024	0.6016
734	0.6569	0.4573	1.0409	0.6093	0.974	0.5947
735	0.655	0.4566	1.0323	0.6097	0.97	0.5942
736	0.658	0.4571	1.0472	0.6188	0.9895	0.6013
737	0.6509	0.4508	1.0403	0.6058	0.9689	0.5894
738	0.6559	0.4565	1.0649	0.6144	1.0003	0.6027
739	0.648	0.446	1.0343	0.604	0.9743	0.5915
740	0.7191	0.5193	1.0944	0.6705	1.0375	0.6473
741	0.6578	0.4584	1.0537	0.6113	0.9854	0.5975
742	0.6595	0.4581	1.0518	0.6125	0.984	0.598
743	0.6493	0.4509	1.054	0.6126	0.9879	0.5943
744	0.6552	0.4533	1.042	0.6079	0.979	0.5927
745	0.6732	0.474	1.0633	0.6312	0.9961	0.6192
746	0.6814	0.4818	1.0562	0.6342	0.9959	0.6177
747	0.68	0.4799	1.0767	0.6321	1.0057	0.6146
748	0.6905	0.4919	1.0577	0.6309	0.9928	0.624
749	0.68

909	0.6238	0.4374	1.0677	0.6099	0.9839	0.5942
910	0.6188	0.4265	1.0676	0.6088	0.9811	0.5807
911	0.6235	0.4339	1.0514	0.6126	0.9769	0.5917
912	0.6441	0.4575	1.0825	0.6239	1.0041	0.6098
913	0.6177	0.4266	1.0416	0.6033	0.9676	0.5855
914	0.6171	0.428	1.0608	0.6091	0.9778	0.5861
915	0.6365	0.4483	1.0469	0.6051	0.964	0.5925
916	0.6362	0.4529	1.0502	0.6259	0.9698	0.606
917	0.6202	0.4335	1.051	0.6126	0.9641	0.5933
918	0.6124	0.4241	1.0432	0.5944	0.9587	0.5797
919	0.6273	0.4388	1.0518	0.6192	0.9795	0.5978
920	0.6241	0.4372	1.0606	0.6078	0.9847	0.5932
921	0.6093	0.4206	1.044	0.5955	0.9614	0.5768
922	0.6175	0.4291	1.052	0.6113	0.9732	0.5894
923	0.6043	0.4135	1.0527	0.5957	0.9781	0.5816
924	0.6007	0.411	1.0508	0.5931	0.9687	0.5726
925	0.6535	0.4598	1.1042	0.6446	1.0268	0.6144
926	0.6383	0.4558	1.0788	0.6326	0.9963	0.6091
927	0.6147	0.4233	1.0757	0.6044	0.9917	0.5795
928	0.6364	0.4487	1.0749	0.636	0.9979	0.6084
929	0.6424	0.4612	1.0594	0.6171	0.9731	0.6047
930	0.6113	0.4203	1.0623	0.5948	0.9797	0.5

In [11]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate/10,weight_decay=weight_decay)

In [12]:
for epoch in range(1000,1500):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N


        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N
        
        
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(test_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(test_loader.dataset)
        test_RMSE = (SSE.item()/ N)**0.5
        test_MAE = SAE.item()/N
        

        
    log = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4),round(test_RMSE,4),round(test_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)
    test_RMSE_lis.append(test_RMSE)
    test_MAE_lis.append(test_MAE)

    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1000	0.6051	0.4185	1.0633	0.6113	0.9855	0.5874
1001	0.5922	0.4104	1.0391	0.5909	0.9581	0.5776
1002	0.5925	0.4053	1.0564	0.6013	0.9764	0.5761
1003	0.5847	0.4002	1.0396	0.5884	0.9562	0.5705
1004	0.5842	0.3982	1.0466	0.5918	0.9618	0.5693
1005	0.5821	0.396	1.0536	0.5893	0.9652	0.567
1006	0.5824	0.3967	1.0471	0.5901	0.9619	0.5681
1007	0.581	0.3958	1.0457	0.5882	0.9607	0.5674
1008	0.5811	0.3961	1.0441	0.5874	0.9601	0.5674
1009	0.5829	0.3963	1.0463	0.5902	0.9627	0.5679
1010	0.5828	0.3986	1.0422	0.586	0.9559	0.5686
1011	0.5821	0.3954	1.051	0.591	0.963	0.5677
1012	0.5801	0.3952	1.0457	0.587	0.9558	0.5662
1013	0.5808	0.3955	1.0498	0.5894	0.9598	0.5663
1014	0.5801	0.394	1.0476	0.5879	0.9612	0.5651
1015	0.5808	0.3955	1.0519	0.5905	0.964	0.5679
1016	0.583	0.3985	1.0582	0.5935	0.9681	0.5708
1017	0.5791	0.3931	1.0522	0.5879	0.9629	0.5655
1018	0.5805	0.3944	1.0548	0.5907	0.9652	0.5677
1019	0.581	0.3961	1.0562	0.5898	0.965	0.5692
1020	0.5803	0.3937	1.059	0.5908	0.97	0.5681
1021	0.578	0.3931	1.0501	0.58

1177	0.5681	0.3866	1.056	0.5898	0.966	0.5651
1178	0.5678	0.3852	1.0572	0.5903	0.9681	0.5644
1179	0.5667	0.3834	1.0572	0.5878	0.9659	0.5622
1180	0.5663	0.3853	1.0508	0.5864	0.958	0.5624
1181	0.567	0.386	1.0554	0.5864	0.9625	0.5632
1182	0.5682	0.3856	1.0484	0.5867	0.9605	0.5637
1183	0.5673	0.3842	1.0556	0.5884	0.9659	0.5637
1184	0.5711	0.391	1.0563	0.5886	0.963	0.5681
1185	0.5661	0.3835	1.0608	0.586	0.9654	0.5603
1186	0.5695	0.3867	1.0644	0.5919	0.969	0.5641
1187	0.5664	0.3846	1.0576	0.5887	0.964	0.5637
1188	0.5681	0.3881	1.0552	0.5878	0.9626	0.5674
1189	0.5662	0.3845	1.0562	0.587	0.9625	0.5638
1190	0.5705	0.3887	1.0595	0.5929	0.9661	0.5665
1191	0.567	0.384	1.0646	0.5888	0.9703	0.5609
1192	0.567	0.3855	1.0605	0.5847	0.9674	0.5625
1193	0.5665	0.3841	1.0581	0.5869	0.9669	0.5629
1194	0.5668	0.3846	1.052	0.5871	0.961	0.5635
1195	0.5666	0.3833	1.0572	0.5874	0.9635	0.5618
1196	0.5656	0.3831	1.0556	0.5851	0.9602	0.5606
1197	0.5654	0.3834	1.0522	0.5857	0.9585	0.5616
1198	0.5683	0.3847	1.0642	0.5

1354	0.5623	0.385	1.0598	0.5873	0.9575	0.565
1355	0.5586	0.3787	1.0562	0.5873	0.9574	0.561
1356	0.5647	0.3845	1.0687	0.5964	0.9718	0.5674
1357	0.567	0.3894	1.0566	0.5979	0.9642	0.5726
1358	0.5652	0.3882	1.0586	0.5931	0.9634	0.5695
1359	0.562	0.3849	1.0549	0.5856	0.9604	0.5644
1360	0.561	0.3803	1.0656	0.5918	0.9729	0.5633
1361	0.563	0.384	1.0532	0.5915	0.9595	0.5665
1362	0.563	0.3859	1.0597	0.5897	0.9607	0.5677
1363	0.5628	0.3849	1.0603	0.586	0.966	0.5675
1364	0.5575	0.3772	1.0629	0.586	0.9691	0.5615
1365	0.558	0.3775	1.0582	0.5864	0.9622	0.5608
1366	0.5569	0.3768	1.0561	0.585	0.9595	0.5596
1367	0.5579	0.3776	1.0621	0.5855	0.9663	0.5613
1368	0.5568	0.3774	1.0511	0.5835	0.9568	0.5608
1369	0.5569	0.3784	1.0517	0.582	0.9558	0.5606
1370	0.5567	0.3772	1.0614	0.5852	0.9649	0.5603
1371	0.5574	0.3773	1.0615	0.5871	0.9676	0.5619
1372	0.5565	0.3774	1.0564	0.5852	0.9633	0.5623
1373	0.5573	0.3776	1.0594	0.5852	0.9664	0.5621
1374	0.5566	0.3774	1.0591	0.5833	0.9631	0.5609
1375	0.561	0.381	1.0678	0.59

In [13]:
torch.save(model.state_dict(), './Trained_model/{}.pkl'.format(task_name))

In [14]:
import numpy as np
import pandas as pd
import pickle
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina

In [15]:
def Hierarchical_mae(similarity_list,d_list):
    data=pd.DataFrame({'similarity':similarity_list,'d':d_list})
    print(np.mean(data[(data['similarity']>=0.8)]['d']))
    print(np.mean(data[(data['similarity']>=0.7)&(data['similarity']<0.8)]['d']))
    print(np.mean(data[(data['similarity']>=0.6)&(data['similarity']<0.7)]['d']))
    print(np.mean(data[(data['similarity']>=0.5)&(data['similarity']<0.6)]['d']))
    print(np.mean(data[(data['similarity']<0.5)]['d']))

In [16]:
train_smiles_list = []
with open('./Dataset/basic_train_0.70_smiles.txt') as f:
    for line in f.readlines():
        line = line.replace('\n','').split('./t')
        train_smiles_list.append(line[0])

train_ms = [Chem.MolFromSmiles(i) for i in train_smiles_list]
train_fps_list = [AllChem.GetMorganFingerprintAsBitVect(x,2,1024) for x in train_ms]

test_smiles_list = []
with open('./Dataset/basic_test_0.15_smiles.txt') as f:
    for line in f.readlines():
        line = line.replace('\n','').split('./t')
        test_smiles_list.append(line[0])
        
test_ms = [Chem.MolFromSmiles(i) for i in test_smiles_list]
test_fps_list = [AllChem.GetMorganFingerprintAsBitVect(x,2,1024) for x in test_ms]

In [17]:
similarity_list = []
for test_fps in test_fps_list:
    sims = DataStructs.BulkTanimotoSimilarity(test_fps,train_fps_list)
    similarity_list.append(max(sims))

In [18]:
print(np.sum(np.array(similarity_list) >= 0.8))
print(np.sum(np.array(similarity_list) >= 0.7))
print(np.sum(np.array(similarity_list) >= 0.6))
print(np.sum(np.array(similarity_list) >= 0.5))
print(np.sum(np.array(similarity_list) <0.5))

304
479
743
1027
239


In [19]:
from My_Pka_Model import Pka_basic_view,Pka_acidic_view
import torch
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from sklearn.metrics import r2_score,mean_squared_error,mean_absolute_error

In [37]:
task_name = 'attentive_fp_basic_ramdom_split_6'

In [38]:
def predict(smiles,model_view):

    node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
    edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')
    bg = smiles_to_bigraph(smiles= smiles, 
                  node_featurizer=node_featurizer,
                  edge_featurizer=edge_featurizer,canonical_atom_order= False)

    with torch.no_grad():
        model_view.eval()
        molecule_pka = model_view(bg,bg.ndata['h'], bg.edata['h'])
        
    return molecule_pka

In [39]:
base_pred = dgllife.model.model_zoo.attentivefp_predictor.AttentiveFPPredictor(node_feat_size = 74,
                                                                           edge_feat_size = 12,
                                                                           num_layers= layer_num,
                                                                           num_timesteps= num_timesteps,
                                                                           graph_feat_size = 200,
                                                                           dropout=0.2)


base_pred.load_state_dict(torch.load('./Trained_model/{}.pkl'.format(task_name),map_location='cuda:0'))

with open('./Dataset/basic_test_0.15_smiles.txt') as f: #acidic_test_0.15_smiles.txt,SAMPL7_acidic_smiles.txt
    pred = []
    label = []
    for line in f.readlines():
        line = line.replace('\n','').split('\t')
        molecule_pka = predict(line[0],base_pred)
        pred.append(molecule_pka)
        label.append(float(line[1]))

d_list = []
for i,j in zip(pred,label):
    d_list.append(abs(i-j))

    
print(r2_score(label,pred))
    
print(np.mean(d_list))
print('')

Hierarchical_mae(similarity_list,d_list)
print('')

0.9132477475221112
0.533315

0.35499934146278783
0.39333971296037945
0.45729642925840436
0.6017072973117022
0.8653197986810277



In [None]:
'attentive_fp_basic_ramdom_split_1'

0.9108314604529422
0.54661095

0.37864353782252264
0.41122528076171877
0.46382499463630444
0.5952671749491087
0.8930202468169783

In [None]:
'attentive_fp_basic_ramdom_split_2'
0.913721733203913
0.5491998

0.38041473689832184
0.4342647443498884
0.49626708753181226
0.5797460314253686
0.87021876379037


In [None]:
'attentive_fp_basic_ramdom_split_3'

0.9087380761523143
0.556724

0.3869677091899671
0.4265252685546875
0.47582707260594226
0.5995750427246094
0.9064218688709467


In [None]:
# 'attentive_fp_basic_ramdom_split_4'

# 0.9163469734776615
# 0.5467407

# 0.3916906808551989
# 0.42055943080357144
# 0.46014947602243134
# 0.5891618862958021
# 0.881591158431943

In [None]:
# 'attentive_fp_basic_ramdom_split_6'

In [None]:
0.5623687

0.3965450086091694
0.43599875313895087
0.47306283315022785
0.6031838000660211
0.9159686994353099

In [None]:
0.5491998

0.38041473689832184
0.4342647443498884
0.49626708753181226
0.5797460314253686
0.87021876379037

In [None]:
0.556724

0.3869677091899671
0.4265252685546875
0.47582707260594226
0.5995750427246094
0.9064218688709467

In [None]:
0.5467407

0.3916906808551989
0.42055943080357144
0.46014947602243134
0.5891618862958021
0.881591158431943

In [None]:
0.53809017

0.3520124084071109
0.38675959995814735
0.45269659793738165
0.6018713292941241
0.9041172809680635

In [None]:
0.54776216

0.3785461877521716
0.42188729422433036
0.4662445241754705
0.5893427351830711
0.8958020070606695

In [None]:
0.5582808

0.3777780532836914
0.39223645891462056
0.47042676174279413
0.5953046234560685
0.9625036008188416

In [22]:
(0.54661095+0.5623687+0.5491998+0.556724+0.5467407+0.53809017)/6

0.5499557200000001

In [None]:
base_pred = Pka_basic_view(node_feat_size = 74,
                            edge_feat_size = 12,
                            output_size = 1,
                            num_layers= 6,
                            graph_feat_size=200,
                            dropout=0.2)

for i in range(1,4):
    base_pred.load_state_dict(torch.load('./Trained_model/basic_ramdom_split_{}.pkl'.format(i),map_location='cuda:1'))

    with open('./Dataset/basic_test_0.15_smiles.txt') as f: #acidic_test_0.15_smiles.txt,SAMPL7_acidic_smiles.txt
        pred = []
        label = []
        for line in f.readlines():
            line = line.replace('\n','').split('\t')
            molecule_pka,atom_pka = predict(line[0],base_pred)
            pred.append(molecule_pka)
            label.append(float(line[1]))

    d_list = []
    for i,j in zip(pred,label):
        d_list.append(abs(i-j))

    print(np.mean(d_list))
    print('')

    Hierarchical_mae(similarity_list,d_list)
    print('')

In [13]:
with open(file_name,'a') as f:
    for a,b,c,d,e,f in zip(train_RMSE_lis,train_MAE_lis,val_RMSE_lis,val_MAE_lis,test_RMSE_lis,test_MAE_lis):
        f.write(a)
        f.write('\t')
        f.write(b)
        f.write('\t')
        f.write(c)
        f.write('\t')
        f.write(d)
        f.write('\t')
        f.write(e)
        f.write('\t')
        f.write(f)
        f.write('\n')

SyntaxError: invalid syntax (<ipython-input-13-20f7fc593900>, line 2)

In [None]:
6比5好比4好

In [None]:
7比6差 6的时候是0.558(2层)

In [None]:
attentve FP 最好的应该也是6层

In [None]:
在搜索下步数