In [1]:
import torch

import dgl
import dgllife
from torch.utils.data import DataLoader
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
import torch.optim as optim
import numpy as np
import random 
import pandas as pd

from torch.nn.utils import clip_grad_norm

Using backend: pytorch


In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

seed = 0
random.seed(seed) 
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  
np.random.seed(seed)  
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

node_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='h')

def load_data(file_name,batch_size = 128,shuffle = True,split_ratio = False):
    dataset = []
    with open(file_name) as f:
        for line in f.readlines():
            line = line.replace('\n','').split('\t')
            g = smiles_to_bigraph(smiles=line[0], 
                          node_featurizer=node_featurizer,
                          edge_featurizer=edge_featurizer,
                          canonical_atom_order= False)
            dataset.append((g,float(line[1])))
            
    if split_ratio == False:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        return dataloader
    else:
        random.shuffle(dataset)
        length = len(dataset)
        dataloader_list = []
        for i in split_ratio:
            num = round(length * i)
            dataset_part = dataset[:num]
            dataset = dataset[num:]
            dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
            dataloader_list.append(dataloader)
        dataset_part = dataset
        dataloader = DataLoader(dataset_part, batch_size=batch_size, shuffle=shuffle,collate_fn=collate)
        dataloader_list.append(dataloader)
        return dataloader_list

In [4]:
batch_size = 1024
learning_rate = 0.0003
weight_decay = 0.0003

In [5]:
layer_num = 6
num_timesteps = 3

In [6]:
# model = dgllife.model.model_zoo.attentivefp_predictor.AttentiveFPPredictor(node_feat_size = 74,
#                                                                            edge_feat_size = 12,
#                                                                            num_layers= layer_num,
#                                                                            num_timesteps= num_timesteps,
#                                                                            graph_feat_size = 200,
#                                                                            dropout=0.2).to(device)

In [None]:
hidden_feats = [128,128,128]
classifier_hidden_feats = 128
predictor_hidden_feats=128

In [None]:
model = dgllife.model.model_zoo.gcn_predictor.GCNPredictor(in_feats = 74,
                                                           hidden_feats = [128,128,128],
                                                           classifier_hidden_feats=128, 
                                                           classifier_dropout=0.0, 
                                                           n_tasks=1, 
                                                           predictor_hidden_feats=128, 
                                                           predictor_dropout=0.0).to(device)

In [7]:
loss_func = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [8]:
train_loader = load_data('./Dataset/basic_train_0.70_smiles.txt',batch_size = batch_size)
val_loader = load_data('./Dataset/basic_val_0.15_smiles.txt',batch_size = batch_size)
test_loader = load_data('./Dataset/basic_test_0.15_smiles.txt',batch_size = batch_size)
print(len(train_loader.dataset))

5905


In [9]:
epoch_num = 1000

In [10]:
train_RMSE_lis = []
train_MAE_lis = []
val_RMSE_lis = []
val_MAE_lis = []
test_RMSE_lis = []
test_MAE_lis = []


file_name = './Logger/Attentive_FP_search_layer_num_{}_num_timesteps_{}.txt'.format(layer_num,num_timesteps)

header = 'epoch:\ttrain_RMSD:\ttrain_MAE:\tval_RMSD:\tval_MAE:\ttest_RMSD:\ttest_MAE:'
print(header)

with open(file_name,'w+') as f:
    f.write(header)
    f.write('\n')

for epoch in range(epoch_num):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N


        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N
        
        
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(test_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(test_loader.dataset)
        test_RMSE = (SSE.item()/ N)**0.5
        test_MAE = SAE.item()/N
        

        
    log = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4),round(test_RMSE,4),round(test_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)
    test_RMSE_lis.append(test_RMSE)
    test_MAE_lis.append(test_MAE)

    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

epoch:	train_RMSD:	train_MAE:	val_RMSD:	val_MAE:	test_RMSD:	test_MAE:




0	6.8288	6.229	6.9825	6.4016	6.9579	6.3651
1	5.5782	4.8867	5.7314	5.0739	5.6905	5.0099
2	4.1841	3.367	4.1416	3.3149	4.0817	3.3146
3	3.2262	2.5818	3.2239	2.5663	3.1629	2.5585
4	3.1403	2.6176	3.224	2.6868	3.1455	2.6747
5	3.0734	2.463	3.088	2.4494	3.0044	2.4456
6	3.0156	2.4664	3.0611	2.4887	2.9748	2.4825
7	2.9653	2.4206	3.0094	2.4444	2.9227	2.4377
8	2.8578	2.2799	2.8698	2.2726	2.7843	2.2714
9	2.666	2.1098	2.6758	2.1084	2.6014	2.1079
10	2.4718	1.8792	2.4446	1.8523	2.3954	1.8514
11	2.4119	1.8528	2.3889	1.8303	2.3671	1.8346
12	2.3994	1.8218	2.3724	1.7907	2.3486	1.7856
13	2.3767	1.803	2.3636	1.7783	2.3234	1.7669
14	2.359	1.7656	2.3588	1.7438	2.3008	1.7295
15	2.3495	1.7586	2.3589	1.7394	2.2917	1.7258
16	2.3422	1.7843	2.364	1.7741	2.2942	1.7571
17	2.3211	1.7491	2.3295	1.7307	2.2662	1.7204
18	2.3053	1.7225	2.3038	1.702	2.2453	1.6932
19	2.2923	1.7452	2.3047	1.734	2.2405	1.7218
20	2.2638	1.7049	2.2716	1.6935	2.207	1.6829
21	2.2382	1.6619	2.2467	1.6503	2.1758	1.6388
22	2.2059	1.6455	2.222	1.6408	2.

183	0.9874	0.7102	1.1401	0.7577	1.0989	0.7642
184	0.9795	0.6993	1.1303	0.7468	1.088	0.7513
185	0.9889	0.7096	1.1411	0.7578	1.1036	0.7621
186	0.9772	0.7005	1.1307	0.7455	1.0901	0.7545
187	0.974	0.696	1.1331	0.7441	1.0853	0.7477
188	0.9762	0.7011	1.128	0.746	1.0874	0.752
189	0.9772	0.6988	1.1293	0.7481	1.0867	0.7534
190	1.0635	0.7913	1.1998	0.82	1.161	0.8396
191	1.0404	0.7621	1.1844	0.8115	1.1506	0.8156
192	0.9751	0.6997	1.132	0.7448	1.0904	0.7547
193	0.983	0.7029	1.1264	0.752	1.0792	0.7524
194	0.9652	0.688	1.1304	0.7368	1.0858	0.7419
195	0.9837	0.7128	1.152	0.7651	1.106	0.7701
196	0.962	0.6912	1.1298	0.7417	1.0822	0.747
197	0.9595	0.6833	1.1291	0.7317	1.0877	0.7404
198	0.9687	0.6896	1.1145	0.7361	1.0689	0.7428
199	0.9532	0.6761	1.1291	0.7237	1.0855	0.7333
200	0.9656	0.6936	1.1329	0.7511	1.0824	0.7528
201	0.9543	0.6815	1.1346	0.7326	1.0807	0.7394
202	0.9526	0.6816	1.125	0.7346	1.0789	0.7419
203	1.0343	0.7739	1.1964	0.82	1.1448	0.8312
204	0.9778	0.7077	1.1401	0.7622	1.0996	0.7638
205	0.96

364	0.7811	0.5502	1.0713	0.633	0.9927	0.6396
365	0.7727	0.5388	1.076	0.6275	0.9997	0.6317
366	0.7731	0.5391	1.0684	0.6278	0.9887	0.6308
367	0.7694	0.5384	1.0835	0.6342	1.0022	0.6336
368	0.7681	0.5356	1.0818	0.6332	0.9949	0.6277
369	0.7762	0.545	1.0771	0.6347	0.9981	0.6337
370	0.7714	0.5415	1.0708	0.6317	0.9981	0.6366
371	0.7669	0.5391	1.071	0.6274	0.9906	0.6297
372	0.7684	0.5361	1.0646	0.6284	0.975	0.6292
373	0.7769	0.549	1.0746	0.6373	0.9925	0.6393
374	0.7992	0.5662	1.1074	0.661	1.0324	0.6583
375	0.7946	0.5669	1.1189	0.6671	1.0392	0.6645
376	0.7967	0.5628	1.0898	0.6541	1.0148	0.648
377	0.7665	0.537	1.0684	0.6296	0.9833	0.6315
378	0.7849	0.5555	1.0778	0.6438	0.9936	0.6425
379	0.7878	0.5597	1.0832	0.6508	1.0082	0.6513
380	0.8165	0.5883	1.0813	0.6645	0.9983	0.6659
381	0.8289	0.6029	1.1172	0.6959	1.0427	0.6893
382	0.8067	0.5858	1.0912	0.6694	1.0084	0.674
383	0.8196	0.5839	1.1151	0.6716	1.033	0.6643
384	0.7979	0.576	1.1026	0.6674	1.0184	0.6676
385	0.7653	0.5351	1.0801	0.6315	0.9935	0.6312


545	0.6746	0.464	1.0695	0.6104	0.9347	0.5863
546	0.6977	0.4919	1.1095	0.6411	0.9694	0.6099
547	0.6948	0.4944	1.0902	0.6341	0.955	0.6084
548	0.7003	0.4843	1.1208	0.6326	0.9839	0.5995
549	0.6784	0.4677	1.0671	0.6097	0.934	0.5822
550	0.705	0.4994	1.0998	0.6336	0.9828	0.6143
551	0.6851	0.4753	1.0457	0.6073	0.916	0.5867
552	0.6925	0.4803	1.0809	0.6238	0.9574	0.6012
553	0.6807	0.4715	1.0556	0.6078	0.9238	0.5843
554	0.683	0.4755	1.0731	0.6153	0.9601	0.6012
555	0.6752	0.4644	1.0624	0.6061	0.94	0.5857
556	0.6699	0.4614	1.0814	0.6099	0.9467	0.5824
557	0.6876	0.4736	1.0904	0.6212	0.9715	0.6016
558	0.6689	0.458	1.0656	0.6064	0.9439	0.5817
559	0.6814	0.4771	1.0762	0.618	0.9529	0.5919
560	0.6721	0.461	1.0755	0.6137	0.9455	0.5849
561	0.6812	0.4712	1.079	0.6186	0.9533	0.5971
562	0.6811	0.4713	1.0829	0.618	0.9538	0.5971
563	0.6657	0.4561	1.0806	0.6099	0.9378	0.5797
564	0.7399	0.5247	1.1154	0.662	0.9883	0.6321
565	0.6785	0.4688	1.0887	0.6168	0.9585	0.5928
566	0.6767	0.4696	1.0712	0.6102	0.939	0.5889
567

726	0.6093	0.4177	1.0954	0.6069	0.9434	0.572
727	0.6041	0.4168	1.0708	0.6077	0.9065	0.5663
728	0.6048	0.4132	1.0719	0.6052	0.9226	0.5645
729	0.6064	0.4146	1.0782	0.6032	0.931	0.5698
730	0.6452	0.4645	1.0894	0.6384	0.9464	0.6087
731	0.6238	0.4333	1.0953	0.6131	0.9425	0.5769
732	0.6169	0.4283	1.0681	0.6055	0.9238	0.5778
733	0.6051	0.4163	1.0765	0.6037	0.928	0.5743
734	0.6127	0.4273	1.0817	0.6117	0.9162	0.576
735	0.5954	0.4068	1.0772	0.6014	0.918	0.5594
736	0.6193	0.4292	1.0778	0.614	0.9204	0.5811
737	0.6271	0.4493	1.0985	0.6378	0.9391	0.5927
738	0.6108	0.4242	1.0879	0.6164	0.9216	0.5721
739	0.624	0.4348	1.1084	0.627	0.9513	0.5853
740	0.6209	0.4302	1.0811	0.6168	0.9233	0.5875
741	0.6504	0.4522	1.1484	0.6373	0.9848	0.5881
742	0.6527	0.4698	1.1034	0.6549	0.9434	0.6099
743	0.6484	0.465	1.1314	0.6494	0.9807	0.6083
744	0.6151	0.4297	1.0818	0.6158	0.9344	0.5815
745	0.617	0.4363	1.0822	0.6182	0.9329	0.584
746	0.6116	0.4269	1.0716	0.6034	0.9221	0.5744
747	0.6184	0.4342	1.0718	0.6167	0.9294	0.5901

908	0.5347	0.3691	1.0866	0.6049	0.8968	0.5567
909	0.5727	0.4078	1.0827	0.6211	0.9122	0.5778
910	0.566	0.4044	1.0966	0.6302	0.9107	0.5792
911	0.6193	0.4547	1.096	0.6645	0.927	0.6195
912	0.5615	0.3937	1.0809	0.6114	0.8948	0.5626
913	0.5543	0.3885	1.0772	0.6072	0.9019	0.5716
914	0.5524	0.3883	1.0873	0.6128	0.8908	0.5662
915	0.5495	0.3858	1.0706	0.6125	0.8961	0.5665
916	0.5355	0.3707	1.0785	0.6016	0.8943	0.5537
917	0.5439	0.379	1.0779	0.6068	0.8838	0.5572
918	0.5403	0.3754	1.0854	0.6106	0.9115	0.5619
919	0.5673	0.4065	1.0903	0.6251	0.9132	0.5795
920	0.5381	0.3719	1.0675	0.601	0.8835	0.5553
921	0.5385	0.3742	1.0738	0.603	0.893	0.5601
922	0.5355	0.3719	1.0667	0.6061	0.8976	0.5661
923	0.5623	0.4018	1.0795	0.6261	0.9059	0.5853
924	0.5344	0.3692	1.0779	0.6035	0.8977	0.5588
925	0.5352	0.3728	1.0769	0.6125	0.8978	0.5639
926	0.5822	0.4218	1.1191	0.6482	0.9466	0.5961
927	0.5488	0.3864	1.0817	0.6139	0.8991	0.5708
928	0.5461	0.383	1.0806	0.612	0.9023	0.5643
929	0.5309	0.3681	1.0878	0.6055	0.9084	0.56

In [11]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate/10,weight_decay=weight_decay)

In [12]:
for epoch in range(1000,1500):
    model.train()
    for iter, (bg, label) in enumerate(train_loader):
        bg = bg.to(device)
        label = label.reshape(-1,1).to(device)
        prediction = model(bg,bg.ndata['h'], bg.edata['h'])
        loss = loss_func(prediction, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(),max_norm=20,norm_type=2)
        optimizer.step()

    with torch.no_grad():
        model.eval()
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(train_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(train_loader.dataset)
        train_RMSE = (SSE.item()/ N)**0.5
        train_MAE = SAE.item()/N


        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(val_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(val_loader.dataset)
        val_RMSE = (SSE.item()/ N)**0.5
        val_MAE = SAE.item()/N
        
        
        SSE = 0
        SAE = 0 
        for iter, (bg, label) in enumerate(test_loader):
            bg = bg.to(device)
            prediction = model(bg,bg.ndata['h'], bg.edata['h'])
            prediction = torch.squeeze(prediction)
            label = label.to(device)
            loss = prediction-label
            SSE += sum(loss**2)
            SAE += sum(torch.abs(loss))
        N = len(test_loader.dataset)
        test_RMSE = (SSE.item()/ N)**0.5
        test_MAE = SAE.item()/N
        

        
    log = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch,round(train_RMSE,4),round(train_MAE,4),round(val_RMSE,4),round(val_MAE,4),round(test_RMSE,4),round(test_MAE,4))
    print(log)

    train_RMSE_lis.append(train_RMSE)
    train_MAE_lis.append(train_MAE)
    val_RMSE_lis.append(val_RMSE)
    val_MAE_lis.append(val_MAE)
    test_RMSE_lis.append(test_RMSE)
    test_MAE_lis.append(test_MAE)

    with open(file_name,'a') as f:
        f.write(log)
        f.write('\n')

  # Remove the CWD from sys.path while we load stuff.


1000	0.5036	0.3471	1.0791	0.6006	0.9054	0.5544
1001	0.4953	0.3379	1.0741	0.5959	0.9053	0.5501
1002	0.4966	0.3407	1.0749	0.5966	0.9024	0.5516
1003	0.4952	0.3376	1.0826	0.597	0.9055	0.5492
1004	0.4905	0.3354	1.0754	0.5935	0.8931	0.5477
1005	0.4964	0.3421	1.0816	0.6022	0.8999	0.5549
1006	0.4938	0.3367	1.0774	0.5975	0.8967	0.5504
1007	0.4907	0.3338	1.0826	0.5971	0.904	0.5495
1008	0.4933	0.3385	1.0776	0.5984	0.9001	0.5537
1009	0.4933	0.3369	1.0752	0.5995	0.8978	0.5542
1010	0.493	0.3361	1.08	0.6007	0.9036	0.552
1011	0.488	0.3326	1.0754	0.595	0.8968	0.5481
1012	0.492	0.3355	1.0807	0.5999	0.9008	0.5503
1013	0.4888	0.3327	1.0701	0.5964	0.89	0.5494
1014	0.4887	0.3331	1.0792	0.5994	0.8971	0.5487
1015	0.4866	0.3316	1.0776	0.599	0.895	0.5491
1016	0.4875	0.3314	1.081	0.601	0.8995	0.5526
1017	0.4875	0.3318	1.0842	0.6003	0.9004	0.5515
1018	0.4876	0.3316	1.0819	0.5984	0.898	0.5501
1019	0.4863	0.3308	1.0818	0.5975	0.898	0.5485
1020	0.4876	0.3324	1.0835	0.5991	0.8999	0.5496
1021	0.4915	0.3353	1.088	0.604

1177	0.4696	0.3184	1.0867	0.6044	0.8985	0.5504
1178	0.4736	0.3235	1.0931	0.6065	0.9036	0.5534
1179	0.4713	0.3214	1.0924	0.6057	0.9013	0.5518
1180	0.4708	0.3195	1.0881	0.6051	0.8996	0.552
1181	0.4756	0.3244	1.0922	0.6106	0.9004	0.5539
1182	0.4732	0.3231	1.0982	0.6082	0.9046	0.5515
1183	0.476	0.3253	1.0969	0.6093	0.9061	0.5525
1184	0.4692	0.3197	1.0839	0.6044	0.8916	0.548
1185	0.4717	0.3226	1.0881	0.6082	0.8969	0.5529
1186	0.4708	0.3201	1.0913	0.6053	0.9026	0.5513
1187	0.4781	0.3283	1.0954	0.6102	0.9086	0.5538
1188	0.474	0.3246	1.0894	0.6069	0.8992	0.5515
1189	0.474	0.3246	1.0975	0.6099	0.906	0.5527
1190	0.4698	0.3201	1.0892	0.6075	0.8993	0.5509
1191	0.474	0.3242	1.0913	0.6084	0.9032	0.5518
1192	0.4735	0.3239	1.0916	0.6059	0.9024	0.5518
1193	0.4702	0.3201	1.0915	0.6044	0.9035	0.5526
1194	0.4704	0.3212	1.0841	0.6022	0.8981	0.5522
1195	0.4758	0.3264	1.0877	0.607	0.9059	0.5557
1196	0.4704	0.3207	1.0774	0.6017	0.8925	0.5526
1197	0.4733	0.3235	1.0931	0.6093	0.9086	0.5543
1198	0.469	0.3198	1.0

1354	0.457	0.3104	1.0831	0.6055	0.8924	0.5507
1355	0.4602	0.3137	1.092	0.6115	0.9049	0.5521
1356	0.4572	0.3108	1.0845	0.6046	0.8977	0.5485
1357	0.458	0.3122	1.0815	0.603	0.8944	0.5512
1358	0.458	0.3116	1.0869	0.6068	0.8969	0.551
1359	0.4568	0.3104	1.0897	0.6045	0.8942	0.5488
1360	0.4639	0.3166	1.0959	0.6114	0.9015	0.5528
1361	0.4595	0.3131	1.0882	0.6109	0.8932	0.5511
1362	0.4612	0.3153	1.0929	0.6111	0.8999	0.5536
1363	0.4554	0.3084	1.0857	0.6045	0.8925	0.5493
1364	0.4587	0.3122	1.0907	0.6091	0.9014	0.5509
1365	0.4629	0.3162	1.0931	0.6111	0.9058	0.554
1366	0.4591	0.3123	1.0944	0.609	0.9061	0.5502
1367	0.459	0.3139	1.089	0.6088	0.8942	0.5494
1368	0.4618	0.3153	1.0969	0.6121	0.9031	0.5528
1369	0.4553	0.3093	1.0917	0.6049	0.897	0.5491
1370	0.4569	0.3113	1.089	0.6053	0.898	0.5512
1371	0.4577	0.3121	1.0911	0.6063	0.902	0.5519
1372	0.4568	0.3112	1.0934	0.6066	0.9025	0.5503
1373	0.4616	0.3156	1.0908	0.6088	0.8997	0.5502
1374	0.4727	0.3278	1.0953	0.6177	0.9086	0.5601
1375	0.465	0.3196	1.0945	0.

In [13]:
with open(file_name,'a') as f:
    for a,b,c,d,e,f in zip(train_RMSE_lis,train_MAE_lis,val_RMSE_lis,val_MAE_lis,test_RMSE_lis,test_MAE_lis):
        f.write(a)
        f.write('\t')
        f.write(b)
        f.write('\t')
        f.write(c)
        f.write('\t')
        f.write(d)
        f.write('\t')
        f.write(e)
        f.write('\t')
        f.write(f)
        f.write('\n')

SyntaxError: invalid syntax (<ipython-input-13-20f7fc593900>, line 2)

In [None]:
6比5好比4好

In [None]:
7比6差 6的时候是0.558(2层)

In [None]:
attentve FP 最好的应该也是6层

In [None]:
在搜索下步数