In [1]:
# importing library
import random
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from load_data import *
from utils import *
from stgcn import *
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error 

In [2]:
# setting a seed
torch.manual_seed(2333)
torch.cuda.manual_seed(2333)
np.random.seed(2333)
random.seed(2333)
torch.backends.cudnn.deterministic = True

In [3]:
# setting gpu
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# data path
matrix_path = "dataset/ad.csv"
data_path = "dataset/df_speed.csv"
data_path2 = "dataset/df_tsr.csv"
data_path3 = "dataset/df_brake.csv"
save_path = "save/model.pt"

In [6]:
# setting hyperparameter 
n_route = 15 
day_slot = 144 # 24 * 60 = 1,440  10 miniute = unit
n_train, n_val, n_test = 49, 6, 6 
Ks, Kt = 3, 3
blocks = [[1, 32, 64], [64, 32, 128]]
drop_prob = 0.5
batch_size = 256
epochs = 200
lr = 1e-3

In [8]:
W = load_matrix(matrix_path)
L = scaled_laplacian(W) 
Lk = cheb_poly(L, Ks)
Lk = torch.Tensor(Lk.astype(np.float32)).to(device)

In [9]:
train, val, test = load_data(data_path, n_train * day_slot, n_val * day_slot)
train2, val2, test2 = load_data(data_path2, n_train * day_slot, n_val * day_slot)
train3, val3, test3 = load_data(data_path3, n_train * day_slot, n_val * day_slot)

scaler =  MinMaxScaler()
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

scaler2 =  MinMaxScaler()
train2 = scaler2.fit_transform(train2)
val2 = scaler2.transform(val2)
test2 = scaler2.transform(test2)

scaler3 =  MinMaxScaler()
train3 = scaler3.fit_transform(train3)
val3 = scaler3.transform(val3)
test3 = scaler3.transform(test3)

In [10]:
MAE = [None] * 11
MAPE = [None] * 11
RMSE = [None] * 11 

for j in range(1, 11):
    n_his = 12 
    n_pred = j
    
    x_train, y_train = data_transform(train, n_his, n_pred, day_slot, device)
    x_val, y_val = data_transform(val, n_his, n_pred, day_slot, device)
    x_test, y_test = data_transform(test, n_his, n_pred, day_slot, device)

    x_train2, y_train2 = data_transform(train2, n_his, n_pred, day_slot, device)
    x_val2, y_val2 = data_transform(val2, n_his, n_pred, day_slot, device)
    x_test2, y_test2 = data_transform(test2, n_his, n_pred, day_slot, device)
    
    x_train3, y_train3 = data_transform(train3, n_his, n_pred, day_slot, device)
    x_val3, y_val3 = data_transform(val3, n_his, n_pred, day_slot, device)
    x_test3, y_test3 = data_transform(test3, n_his, n_pred, day_slot, device)
    
    train_data = torch.utils.data.TensorDataset(x_train, y_train, x_train2, x_train3)
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle = False)

    val_data = torch.utils.data.TensorDataset(x_val, y_val, x_val2, x_val3)
    val_iter = torch.utils.data.DataLoader(val_data, batch_size, shuffle = False)

    test_data = torch.utils.data.TensorDataset(x_test, y_test, x_test2, x_test3)
    test_iter = torch.utils.data.DataLoader(test_data, batch_size, shuffle = False)    
    
    criterion = nn.MSELoss()
    model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    

    class EarlyStopping:
        def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
         
            self.patience = patience
            self.verbose = verbose
            self.counter = 0
            self.best_score = None
            self.early_stop = False
            self.val_loss_min = np.Inf
            self.delta = delta
            self.path = path

        def __call__(self, val_loss, model):

            score = -val_loss

            if self.best_score is None:
                self.best_score = score
                self.save_checkpoint(val_loss, model)
            elif score < self.best_score + self.delta:
                self.counter += 1
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    self.early_stop = True
            else:
                self.best_score = score
                self.save_checkpoint(val_loss, model)
                self.counter = 0

        def save_checkpoint(self, val_loss, model):
            if self.verbose:
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            torch.save(model.state_dict(), self.path)
            self.val_loss_min = val_loss


    def train_model(model, batch_size, patience, n_epochs):
        train_losses = []
        valid_losses = []
        avg_train_losses = []
        avg_valid_losses = []

        early_stopping = EarlyStopping(patience = patience, verbose = True)

        for epoch in range(1, n_epochs + 1):
            model.train() 
            for batch, (data, targets, data2, data3) in enumerate(train_iter, 1):
                optimizer.zero_grad()    
                output = model(data, data2, data3)
                loss = criterion(output, targets)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())

            model.eval() 
            for data, targets, data2, data3 in val_iter :
                output = model(data, data2, data3)
                loss = criterion(output, targets)
                valid_losses.append(loss.item())

            train_loss = np.average(train_losses)
            valid_loss = np.average(valid_losses)
            avg_train_losses.append(train_loss)
            avg_valid_losses.append(valid_loss)

            epoch_len = len(str(n_epochs))


            print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                         f'train_loss: {train_loss:.5f} ' +
                         f'valid_loss: {valid_loss:.5f}')

            print(print_msg)

            train_losses = []
            valid_losses = []
            early_stopping(valid_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        model.load_state_dict(torch.load('checkpoint.pt'))

        return  model, avg_train_losses, avg_valid_losses
    
    patience = 100
    n_epochs = 200
    model, train_loss, valid_loss = train_model(model, batch_size, patience, n_epochs)   

    model.eval()
    with torch.no_grad():
        valid_tensor = x_test
        valid_tensor2 = x_test2
        valid_tensor3 = x_test3
        
        predict = model(valid_tensor, valid_tensor2, valid_tensor3)
    predict = predict.cpu().data.numpy()
    actual_predictions = predict
    actual_predictions = scaler.inverse_transform(actual_predictions)
    
    groun = y_test.cpu().data.numpy()
    groun = scaler.inverse_transform(groun)

    
    groun2 = pd.DataFrame(groun)
    groun2.to_csv('results/ground_' + str(j) + '.csv')
    actual_predictions2 = pd.DataFrame(actual_predictions)
    actual_predictions2.to_csv('results/predictions_' + str(j) + '.csv')
    torch.save(model.state_dict(), 'results/model_' + str(j) + '.pt')
    
    MAE[j-1] = mean_absolute_error(groun, actual_predictions)
    MSE = mean_squared_error(groun, actual_predictions) 
    RMSE[j-1] = np.sqrt(MSE)

[  1/200] train_loss: 0.06079 valid_loss: 0.03605
Validation loss decreased (inf --> 0.036053).  Saving model ...
[  2/200] train_loss: 0.03989 valid_loss: 0.03236
Validation loss decreased (0.036053 --> 0.032357).  Saving model ...
[  3/200] train_loss: 0.03583 valid_loss: 0.03234
Validation loss decreased (0.032357 --> 0.032337).  Saving model ...
[  4/200] train_loss: 0.03462 valid_loss: 0.03370
EarlyStopping counter: 1 out of 100
[  5/200] train_loss: 0.03305 valid_loss: 0.03166
Validation loss decreased (0.032337 --> 0.031662).  Saving model ...
[  6/200] train_loss: 0.03209 valid_loss: 0.03046
Validation loss decreased (0.031662 --> 0.030462).  Saving model ...
[  7/200] train_loss: 0.03125 valid_loss: 0.03074
EarlyStopping counter: 1 out of 100
[  8/200] train_loss: 0.03038 valid_loss: 0.02978
Validation loss decreased (0.030462 --> 0.029782).  Saving model ...
[  9/200] train_loss: 0.03028 valid_loss: 0.02935
Validation loss decreased (0.029782 --> 0.029354).  Saving model ...


[ 90/200] train_loss: 0.02491 valid_loss: 0.02802
EarlyStopping counter: 65 out of 100
[ 91/200] train_loss: 0.02461 valid_loss: 0.02799
EarlyStopping counter: 66 out of 100
[ 92/200] train_loss: 0.02449 valid_loss: 0.02800
EarlyStopping counter: 67 out of 100
[ 93/200] train_loss: 0.02438 valid_loss: 0.02813
EarlyStopping counter: 68 out of 100
[ 94/200] train_loss: 0.02434 valid_loss: 0.02828
EarlyStopping counter: 69 out of 100
[ 95/200] train_loss: 0.02435 valid_loss: 0.02839
EarlyStopping counter: 70 out of 100
[ 96/200] train_loss: 0.02441 valid_loss: 0.02824
EarlyStopping counter: 71 out of 100
[ 97/200] train_loss: 0.02440 valid_loss: 0.02819
EarlyStopping counter: 72 out of 100
[ 98/200] train_loss: 0.02429 valid_loss: 0.02819
EarlyStopping counter: 73 out of 100
[ 99/200] train_loss: 0.02428 valid_loss: 0.02841
EarlyStopping counter: 74 out of 100
[100/200] train_loss: 0.02420 valid_loss: 0.02816
EarlyStopping counter: 75 out of 100
[101/200] train_loss: 0.02422 valid_loss: 0

[ 57/200] train_loss: 0.02931 valid_loss: 0.03267
EarlyStopping counter: 41 out of 100
[ 58/200] train_loss: 0.02942 valid_loss: 0.03327
EarlyStopping counter: 42 out of 100
[ 59/200] train_loss: 0.02964 valid_loss: 0.03284
EarlyStopping counter: 43 out of 100
[ 60/200] train_loss: 0.02936 valid_loss: 0.03286
EarlyStopping counter: 44 out of 100
[ 61/200] train_loss: 0.02905 valid_loss: 0.03294
EarlyStopping counter: 45 out of 100
[ 62/200] train_loss: 0.02916 valid_loss: 0.03275
EarlyStopping counter: 46 out of 100
[ 63/200] train_loss: 0.02888 valid_loss: 0.03285
EarlyStopping counter: 47 out of 100
[ 64/200] train_loss: 0.02901 valid_loss: 0.03281
EarlyStopping counter: 48 out of 100
[ 65/200] train_loss: 0.02893 valid_loss: 0.03301
EarlyStopping counter: 49 out of 100
[ 66/200] train_loss: 0.02890 valid_loss: 0.03328
EarlyStopping counter: 50 out of 100
[ 67/200] train_loss: 0.02891 valid_loss: 0.03305
EarlyStopping counter: 51 out of 100
[ 68/200] train_loss: 0.02880 valid_loss: 0

[ 32/200] train_loss: 0.03088 valid_loss: 0.03433
EarlyStopping counter: 6 out of 100
[ 33/200] train_loss: 0.03073 valid_loss: 0.03455
EarlyStopping counter: 7 out of 100
[ 34/200] train_loss: 0.03049 valid_loss: 0.03536
EarlyStopping counter: 8 out of 100
[ 35/200] train_loss: 0.03062 valid_loss: 0.03509
EarlyStopping counter: 9 out of 100
[ 36/200] train_loss: 0.03034 valid_loss: 0.03516
EarlyStopping counter: 10 out of 100
[ 37/200] train_loss: 0.03031 valid_loss: 0.03550
EarlyStopping counter: 11 out of 100
[ 38/200] train_loss: 0.03045 valid_loss: 0.03514
EarlyStopping counter: 12 out of 100
[ 39/200] train_loss: 0.03020 valid_loss: 0.03518
EarlyStopping counter: 13 out of 100
[ 40/200] train_loss: 0.03032 valid_loss: 0.03441
EarlyStopping counter: 14 out of 100
[ 41/200] train_loss: 0.03024 valid_loss: 0.03367
EarlyStopping counter: 15 out of 100
[ 42/200] train_loss: 0.03010 valid_loss: 0.03397
EarlyStopping counter: 16 out of 100
[ 43/200] train_loss: 0.02996 valid_loss: 0.034

[126/200] train_loss: 0.02653 valid_loss: 0.03462
EarlyStopping counter: 70 out of 100
[127/200] train_loss: 0.02650 valid_loss: 0.03524
EarlyStopping counter: 71 out of 100
[128/200] train_loss: 0.02652 valid_loss: 0.03492
EarlyStopping counter: 72 out of 100
[129/200] train_loss: 0.02641 valid_loss: 0.03491
EarlyStopping counter: 73 out of 100
[130/200] train_loss: 0.02647 valid_loss: 0.03460
EarlyStopping counter: 74 out of 100
[131/200] train_loss: 0.02642 valid_loss: 0.03474
EarlyStopping counter: 75 out of 100
[132/200] train_loss: 0.02660 valid_loss: 0.03507
EarlyStopping counter: 76 out of 100
[133/200] train_loss: 0.02641 valid_loss: 0.03486
EarlyStopping counter: 77 out of 100
[134/200] train_loss: 0.02652 valid_loss: 0.03479
EarlyStopping counter: 78 out of 100
[135/200] train_loss: 0.02640 valid_loss: 0.03520
EarlyStopping counter: 79 out of 100
[136/200] train_loss: 0.02658 valid_loss: 0.03600
EarlyStopping counter: 80 out of 100
[137/200] train_loss: 0.02665 valid_loss: 0

[ 61/200] train_loss: 0.02965 valid_loss: 0.03469
EarlyStopping counter: 1 out of 100
[ 62/200] train_loss: 0.02963 valid_loss: 0.03491
EarlyStopping counter: 2 out of 100
[ 63/200] train_loss: 0.02958 valid_loss: 0.03496
EarlyStopping counter: 3 out of 100
[ 64/200] train_loss: 0.02954 valid_loss: 0.03480
EarlyStopping counter: 4 out of 100
[ 65/200] train_loss: 0.02955 valid_loss: 0.03473
EarlyStopping counter: 5 out of 100
[ 66/200] train_loss: 0.02943 valid_loss: 0.03468
EarlyStopping counter: 6 out of 100
[ 67/200] train_loss: 0.02950 valid_loss: 0.03471
EarlyStopping counter: 7 out of 100
[ 68/200] train_loss: 0.02936 valid_loss: 0.03449
Validation loss decreased (0.034538 --> 0.034485).  Saving model ...
[ 69/200] train_loss: 0.02950 valid_loss: 0.03454
EarlyStopping counter: 1 out of 100
[ 70/200] train_loss: 0.02939 valid_loss: 0.03450
EarlyStopping counter: 2 out of 100
[ 71/200] train_loss: 0.02925 valid_loss: 0.03491
EarlyStopping counter: 3 out of 100
[ 72/200] train_loss:

[155/200] train_loss: 0.02611 valid_loss: 0.03625
EarlyStopping counter: 87 out of 100
[156/200] train_loss: 0.02613 valid_loss: 0.03644
EarlyStopping counter: 88 out of 100
[157/200] train_loss: 0.02619 valid_loss: 0.03646
EarlyStopping counter: 89 out of 100
[158/200] train_loss: 0.02595 valid_loss: 0.03718
EarlyStopping counter: 90 out of 100
[159/200] train_loss: 0.02611 valid_loss: 0.03701
EarlyStopping counter: 91 out of 100
[160/200] train_loss: 0.02596 valid_loss: 0.03691
EarlyStopping counter: 92 out of 100
[161/200] train_loss: 0.02609 valid_loss: 0.03609
EarlyStopping counter: 93 out of 100
[162/200] train_loss: 0.02612 valid_loss: 0.03677
EarlyStopping counter: 94 out of 100
[163/200] train_loss: 0.02601 valid_loss: 0.03754
EarlyStopping counter: 95 out of 100
[164/200] train_loss: 0.02606 valid_loss: 0.03748
EarlyStopping counter: 96 out of 100
[165/200] train_loss: 0.02619 valid_loss: 0.03769
EarlyStopping counter: 97 out of 100
[166/200] train_loss: 0.02617 valid_loss: 0

[ 76/200] train_loss: 0.02894 valid_loss: 0.03198
EarlyStopping counter: 40 out of 100
[ 77/200] train_loss: 0.02899 valid_loss: 0.03194
EarlyStopping counter: 41 out of 100
[ 78/200] train_loss: 0.02904 valid_loss: 0.03148
EarlyStopping counter: 42 out of 100
[ 79/200] train_loss: 0.02893 valid_loss: 0.03149
EarlyStopping counter: 43 out of 100
[ 80/200] train_loss: 0.02863 valid_loss: 0.03152
EarlyStopping counter: 44 out of 100
[ 81/200] train_loss: 0.02848 valid_loss: 0.03124
EarlyStopping counter: 45 out of 100
[ 82/200] train_loss: 0.02847 valid_loss: 0.03134
EarlyStopping counter: 46 out of 100
[ 83/200] train_loss: 0.02835 valid_loss: 0.03154
EarlyStopping counter: 47 out of 100
[ 84/200] train_loss: 0.02839 valid_loss: 0.03157
EarlyStopping counter: 48 out of 100
[ 85/200] train_loss: 0.02823 valid_loss: 0.03137
EarlyStopping counter: 49 out of 100
[ 86/200] train_loss: 0.02834 valid_loss: 0.03144
EarlyStopping counter: 50 out of 100
[ 87/200] train_loss: 0.02827 valid_loss: 0

[ 30/200] train_loss: 0.03068 valid_loss: 0.03128
EarlyStopping counter: 7 out of 100
[ 31/200] train_loss: 0.03057 valid_loss: 0.03228
EarlyStopping counter: 8 out of 100
[ 32/200] train_loss: 0.03068 valid_loss: 0.03106
EarlyStopping counter: 9 out of 100
[ 33/200] train_loss: 0.03062 valid_loss: 0.03101
EarlyStopping counter: 10 out of 100
[ 34/200] train_loss: 0.03064 valid_loss: 0.03095
EarlyStopping counter: 11 out of 100
[ 35/200] train_loss: 0.03037 valid_loss: 0.03093
EarlyStopping counter: 12 out of 100
[ 36/200] train_loss: 0.03045 valid_loss: 0.03099
EarlyStopping counter: 13 out of 100
[ 37/200] train_loss: 0.03033 valid_loss: 0.03093
EarlyStopping counter: 14 out of 100
[ 38/200] train_loss: 0.03013 valid_loss: 0.03115
EarlyStopping counter: 15 out of 100
[ 39/200] train_loss: 0.03018 valid_loss: 0.03113
EarlyStopping counter: 16 out of 100
[ 40/200] train_loss: 0.03011 valid_loss: 0.03127
EarlyStopping counter: 17 out of 100
[ 41/200] train_loss: 0.03006 valid_loss: 0.03

[  2/200] train_loss: 0.03912 valid_loss: 0.03416
Validation loss decreased (0.034661 --> 0.034161).  Saving model ...
[  3/200] train_loss: 0.03614 valid_loss: 0.03365
Validation loss decreased (0.034161 --> 0.033651).  Saving model ...
[  4/200] train_loss: 0.03472 valid_loss: 0.03463
EarlyStopping counter: 1 out of 100
[  5/200] train_loss: 0.03405 valid_loss: 0.03439
EarlyStopping counter: 2 out of 100
[  6/200] train_loss: 0.03340 valid_loss: 0.03376
EarlyStopping counter: 3 out of 100
[  7/200] train_loss: 0.03295 valid_loss: 0.03360
Validation loss decreased (0.033651 --> 0.033598).  Saving model ...
[  8/200] train_loss: 0.03275 valid_loss: 0.03507
EarlyStopping counter: 1 out of 100
[  9/200] train_loss: 0.03248 valid_loss: 0.03357
Validation loss decreased (0.033598 --> 0.033567).  Saving model ...
[ 10/200] train_loss: 0.03248 valid_loss: 0.03155
Validation loss decreased (0.033567 --> 0.031553).  Saving model ...
[ 11/200] train_loss: 0.03214 valid_loss: 0.03184
EarlyStoppi

[ 93/200] train_loss: 0.02776 valid_loss: 0.03176
EarlyStopping counter: 59 out of 100
[ 94/200] train_loss: 0.02781 valid_loss: 0.03150
EarlyStopping counter: 60 out of 100
[ 95/200] train_loss: 0.02788 valid_loss: 0.03129
EarlyStopping counter: 61 out of 100
[ 96/200] train_loss: 0.02755 valid_loss: 0.03172
EarlyStopping counter: 62 out of 100
[ 97/200] train_loss: 0.02767 valid_loss: 0.03199
EarlyStopping counter: 63 out of 100
[ 98/200] train_loss: 0.02764 valid_loss: 0.03163
EarlyStopping counter: 64 out of 100
[ 99/200] train_loss: 0.02763 valid_loss: 0.03127
EarlyStopping counter: 65 out of 100
[100/200] train_loss: 0.02775 valid_loss: 0.03163
EarlyStopping counter: 66 out of 100
[101/200] train_loss: 0.02773 valid_loss: 0.03136
EarlyStopping counter: 67 out of 100
[102/200] train_loss: 0.02775 valid_loss: 0.03166
EarlyStopping counter: 68 out of 100
[103/200] train_loss: 0.02749 valid_loss: 0.03180
EarlyStopping counter: 69 out of 100
[104/200] train_loss: 0.02745 valid_loss: 0

[ 50/200] train_loss: 0.02971 valid_loss: 0.03099
EarlyStopping counter: 29 out of 100
[ 51/200] train_loss: 0.02961 valid_loss: 0.03142
EarlyStopping counter: 30 out of 100
[ 52/200] train_loss: 0.02958 valid_loss: 0.03144
EarlyStopping counter: 31 out of 100
[ 53/200] train_loss: 0.02946 valid_loss: 0.03136
EarlyStopping counter: 32 out of 100
[ 54/200] train_loss: 0.02936 valid_loss: 0.03153
EarlyStopping counter: 33 out of 100
[ 55/200] train_loss: 0.02946 valid_loss: 0.03148
EarlyStopping counter: 34 out of 100
[ 56/200] train_loss: 0.02934 valid_loss: 0.03129
EarlyStopping counter: 35 out of 100
[ 57/200] train_loss: 0.02927 valid_loss: 0.03146
EarlyStopping counter: 36 out of 100
[ 58/200] train_loss: 0.02923 valid_loss: 0.03149
EarlyStopping counter: 37 out of 100
[ 59/200] train_loss: 0.02909 valid_loss: 0.03164
EarlyStopping counter: 38 out of 100
[ 60/200] train_loss: 0.02920 valid_loss: 0.03194
EarlyStopping counter: 39 out of 100
[ 61/200] train_loss: 0.02913 valid_loss: 0

[ 18/200] train_loss: 0.03145 valid_loss: 0.03156
EarlyStopping counter: 3 out of 100
[ 19/200] train_loss: 0.03144 valid_loss: 0.03146
EarlyStopping counter: 4 out of 100
[ 20/200] train_loss: 0.03130 valid_loss: 0.03197
EarlyStopping counter: 5 out of 100
[ 21/200] train_loss: 0.03131 valid_loss: 0.03129
EarlyStopping counter: 6 out of 100
[ 22/200] train_loss: 0.03109 valid_loss: 0.03205
EarlyStopping counter: 7 out of 100
[ 23/200] train_loss: 0.03146 valid_loss: 0.03121
EarlyStopping counter: 8 out of 100
[ 24/200] train_loss: 0.03128 valid_loss: 0.03093
Validation loss decreased (0.030934 --> 0.030933).  Saving model ...
[ 25/200] train_loss: 0.03089 valid_loss: 0.03103
EarlyStopping counter: 1 out of 100
[ 26/200] train_loss: 0.03094 valid_loss: 0.03090
Validation loss decreased (0.030933 --> 0.030903).  Saving model ...
[ 27/200] train_loss: 0.03122 valid_loss: 0.03103
EarlyStopping counter: 1 out of 100
[ 28/200] train_loss: 0.03072 valid_loss: 0.03079
Validation loss decrease

[111/200] train_loss: 0.02685 valid_loss: 0.03210
EarlyStopping counter: 68 out of 100
[112/200] train_loss: 0.02675 valid_loss: 0.03252
EarlyStopping counter: 69 out of 100
[113/200] train_loss: 0.02684 valid_loss: 0.03266
EarlyStopping counter: 70 out of 100
[114/200] train_loss: 0.02685 valid_loss: 0.03284
EarlyStopping counter: 71 out of 100
[115/200] train_loss: 0.02667 valid_loss: 0.03271
EarlyStopping counter: 72 out of 100
[116/200] train_loss: 0.02674 valid_loss: 0.03214
EarlyStopping counter: 73 out of 100
[117/200] train_loss: 0.02664 valid_loss: 0.03206
EarlyStopping counter: 74 out of 100
[118/200] train_loss: 0.02674 valid_loss: 0.03312
EarlyStopping counter: 75 out of 100
[119/200] train_loss: 0.02662 valid_loss: 0.03275
EarlyStopping counter: 76 out of 100
[120/200] train_loss: 0.02656 valid_loss: 0.03245
EarlyStopping counter: 77 out of 100
[121/200] train_loss: 0.02644 valid_loss: 0.03257
EarlyStopping counter: 78 out of 100
[122/200] train_loss: 0.02651 valid_loss: 0

[ 58/200] train_loss: 0.02915 valid_loss: 0.03069
EarlyStopping counter: 23 out of 100
[ 59/200] train_loss: 0.02912 valid_loss: 0.03062
EarlyStopping counter: 24 out of 100
[ 60/200] train_loss: 0.02927 valid_loss: 0.03054
EarlyStopping counter: 25 out of 100
[ 61/200] train_loss: 0.02918 valid_loss: 0.03074
EarlyStopping counter: 26 out of 100
[ 62/200] train_loss: 0.02916 valid_loss: 0.03077
EarlyStopping counter: 27 out of 100
[ 63/200] train_loss: 0.02904 valid_loss: 0.03077
EarlyStopping counter: 28 out of 100
[ 64/200] train_loss: 0.02917 valid_loss: 0.03105
EarlyStopping counter: 29 out of 100
[ 65/200] train_loss: 0.02908 valid_loss: 0.03132
EarlyStopping counter: 30 out of 100
[ 66/200] train_loss: 0.02905 valid_loss: 0.03106
EarlyStopping counter: 31 out of 100
[ 67/200] train_loss: 0.02884 valid_loss: 0.03116
EarlyStopping counter: 32 out of 100
[ 68/200] train_loss: 0.02896 valid_loss: 0.03103
EarlyStopping counter: 33 out of 100
[ 69/200] train_loss: 0.02882 valid_loss: 0

In [11]:
MAE

[14.201457,
 16.536694,
 17.080729,
 17.257412,
 17.1137,
 17.20898,
 17.245981,
 17.11227,
 17.239424,
 17.146719,
 None]

In [12]:
RMSE

[18.002491,
 20.237745,
 20.877827,
 21.074883,
 20.90935,
 21.00921,
 20.981134,
 20.831465,
 20.963219,
 20.896502,
 None]