In [28]:
from preprocess import dataset,data_deal
from sklearn.preprocessing import MinMaxScaler
from models import AE_MLP,get_models
from torch.utils.data import DataLoader,Dataset
from sklearn.model_selection import KFold
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
import random
from TrainandTest import train_MLP,test_MLP
import pandas as pd


def same_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Config

In [29]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
file = './data/total 12 data.xlsx'
data = pd.read_excel(file,engine="openpyxl")
data = np.array(data)
l = len(data)
para_path = './save/parameter'
batch_size = 300
plot = True
np.set_printoptions(threshold=np.sys.maxsize)
times = 0
best_seed1= best_seed2=0
best_rmse = rmse = 5

# Train

In [30]:
with open('config.json','r',encoding='utf-8')as f:
    config_list = json.load(f)
flag =  3
config = config_list[3]
Type = config['Type']
N_size = config['N_size']
lr = config['lr']
k = config['k']
alpha = config['alpha']
input_size = config['input_size']
epsilon = config['epsilon']
seed1,seed2 = config['seed1'],config['seed2']
while rmse >=0.2937:
    rmse = 0
    total_loss = 0

    seed1 = random.randint(0, 4294967295)
    seed2 = random.randint(0, 4294967295)

    kf = KFold(n_splits=k ,shuffle=True,random_state=seed1)

    total_data = dataset(data,Type,N_size,input_size)
    total_data = np.array(total_data,dtype = object)

    X = total_data[:,0]
    y = total_data[:,1].astype(np.float32)

    temp_loss = 0
    l1= X[0].shape[0]
    l2 =len(X)
    results_df = pd.DataFrame()
    for fold, (train_index, test_index) in enumerate(kf.split(X)):
        same_seeds(seed2)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        model = get_models(config['model'],input_size,epsilon)
        scaler = MinMaxScaler()
        X_train,X_test,y_train,y_test,scaler = data_deal(X_train,X_test,y_train,y_test,scaler,input_size,flag)
        train_MLP(X_train,y_train,model,lr,alpha,fold,para_path,flag)
        loss,results_df = test_MLP(X_test,y_test,model,fold,para_path,scaler,results_df,l1,flag)
        total_loss += loss
        temp_loss += loss

    rmse = (temp_loss/l2)**0.5
    rmse = float(rmse)
    times += 1
    
    if best_rmse>=rmse:
        best_rmse=rmse
        best_seed1=seed1
        best_seed2=seed2
        
    print(f'Times:{times},RMSE:{rmse:.10f},seed1:{seed1},seed2:{seed2},{best_rmse},{best_seed1},{best_seed2}')


Times:1,RMSE:0.3759529889,seed1:1818265470,seed2:3755615731,0.37595298886299133,1818265470,3755615731
Times:2,RMSE:0.3397589028,seed1:343457569,seed2:2749207437,0.33975890278816223,343457569,2749207437
Times:3,RMSE:0.3580758274,seed1:3737404468,seed2:3792315240,0.33975890278816223,343457569,2749207437
Times:4,RMSE:0.3384208381,seed1:3015551513,seed2:1224842713,0.3384208381175995,3015551513,1224842713
Times:5,RMSE:0.3418034017,seed1:745616444,seed2:254532726,0.3384208381175995,3015551513,1224842713
Times:6,RMSE:0.3668203056,seed1:3558210728,seed2:2433911995,0.3384208381175995,3015551513,1224842713
Times:7,RMSE:0.3693297505,seed1:432668591,seed2:763373117,0.3384208381175995,3015551513,1224842713
Times:8,RMSE:0.3681012392,seed1:2056190780,seed2:3792590504,0.3384208381175995,3015551513,1224842713
Times:9,RMSE:0.4004748166,seed1:1501079486,seed2:1813733013,0.3384208381175995,3015551513,1224842713
Times:10,RMSE:0.4017687142,seed1:3645590306,seed2:2039803819,0.3384208381175995,3015551513,1224

Times:82,RMSE:0.3257442117,seed1:1981230013,seed2:1143836477,0.3257442116737366,1981230013,1143836477
Times:83,RMSE:0.3738317490,seed1:958022201,seed2:871751451,0.3257442116737366,1981230013,1143836477
Times:84,RMSE:0.3697520792,seed1:2336610535,seed2:3312851000,0.3257442116737366,1981230013,1143836477
Times:85,RMSE:0.4017309248,seed1:1420557383,seed2:950136919,0.3257442116737366,1981230013,1143836477
Times:86,RMSE:0.3532688916,seed1:378597957,seed2:26157991,0.3257442116737366,1981230013,1143836477
Times:87,RMSE:0.3654958904,seed1:1979573590,seed2:1263600015,0.3257442116737366,1981230013,1143836477
Times:88,RMSE:0.3732536435,seed1:1729089035,seed2:1205723232,0.3257442116737366,1981230013,1143836477
Times:89,RMSE:0.3843051195,seed1:2436092636,seed2:1746461732,0.3257442116737366,1981230013,1143836477
Times:90,RMSE:0.3479238451,seed1:3745055141,seed2:3978789464,0.3257442116737366,1981230013,1143836477
Times:91,RMSE:0.3738873601,seed1:3192882552,seed2:383721485,0.3257442116737366,198123001

KeyboardInterrupt: 