In [1]:
# importing library
import os
import random
import torch.nn as nn
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error 
from load_data import *
from utils import *
import math
import torch.nn.init as init
import torch.nn.functional as F

In [2]:
# setting a seed
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.backends.cudnn.deterministic = True

In [5]:
# setting gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

Device: cuda
Current cuda device: 0
Count of using GPUs: 2


In [6]:
# data path
data_path = "dataset/df_speed.csv"
data_path2 = "dataset/df_tsr.csv"
data_path3 = "dataset/df_brake.csv"
save_path = "save/model.pt"

In [8]:
# setting hyperparameter 
n_link = 15 # the number of road links
day_slot = 144 # 24 * 60 = 1,440  10 miniute = unit
n_train, n_val, n_test = 49, 6, 6  
batch_size = 128
epochs = 200
lr = 1e-3

In [11]:
# model define
class CNN(torch.nn.Module):

    def __init__(self, size):
        super(CNN, self).__init__()
        self.size = size
        self.keep_prob = 0.5
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2))

        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2))

        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=1))

        self.fc = torch.nn.Linear(self.size, 625, bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)
        
        self.layer4 = torch.nn.Sequential(
            self.fc,
            torch.nn.ReLU(),
            torch.nn.Dropout(p=1 - self.keep_prob))

        self.fc2 = torch.nn.Linear(625, n_link, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)   # Flatten them for FC
        out = self.layer4(out)
        out = self.fc2(out)
        return out

In [12]:
# setting data
train, val, test = load_data(data_path, n_train * day_slot, n_val * day_slot)
train2, val2, test2 = load_data(data_path2, n_train * day_slot, n_val * day_slot)
train3, val3, test3 = load_data(data_path3, n_train * day_slot, n_val * day_slot)

# data scaling 
scaler =  MinMaxScaler()
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

scaler2 =  MinMaxScaler()
train2 = scaler2.fit_transform(train2)
val2 = scaler2.transform(val2)
test2 = scaler2.transform(test2)

scaler3 =  MinMaxScaler()
train3 = scaler3.fit_transform(train3)
val3 = scaler3.transform(val3)
test3 = scaler3.transform(test3)

In [13]:
# model training and saving results
MAE = [None] * 10
MAPE = [None] * 10
RMSE = [None] * 10 

for j in range(1, 11):
    n_his = 12 
    n_pred = j
    
    x_train, y_train = data_transform(train, n_his, n_pred, day_slot, device)
    x_val, y_val = data_transform(val, n_his, n_pred, day_slot, device)
    x_test, y_test = data_transform(test, n_his, n_pred, day_slot, device)

    x_train2, y_train2 = data_transform(train2, n_his, n_pred, day_slot, device)
    x_val2, y_val2 = data_transform(val2, n_his, n_pred, day_slot, device)
    x_test2, y_test2 = data_transform(test2, n_his, n_pred, day_slot, device)
    
    x_train3, y_train3 = data_transform(train3, n_his, n_pred, day_slot, device)
    x_val3, y_val3 = data_transform(val3, n_his, n_pred, day_slot, device)
    x_test3, y_test3 = data_transform(test3, n_his, n_pred, day_slot, device)
    
    train_data = torch.utils.data.TensorDataset(x_train, y_train, x_train2, x_train3)
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle = False)

    val_data = torch.utils.data.TensorDataset(x_val, y_val, x_val2, x_val3)
    val_iter = torch.utils.data.DataLoader(val_data, batch_size, shuffle = False)

    test_data = torch.utils.data.TensorDataset(x_test, y_test, x_test2, x_test3)
    test_iter = torch.utils.data.DataLoader(test_data, batch_size, shuffle = False)    
    
    criterion = nn.MSELoss()
    
    test_inputs = train_data[0:256][0].to(device)
    
    layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
    layer1.to(device)

    layer2 = torch.nn.Sequential(
                torch.nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2, stride=2))
    layer2.to(device)

    layer3 = torch.nn.Sequential(
                torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1))
    layer3.to(device)
    
    ss = layer3(layer2(layer1(test_inputs))).shape[1] * layer3(layer2(layer1(test_inputs))).shape[2] * layer3(layer2(layer1(test_inputs))).shape[3]

    model = CNN(ss).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    

    class EarlyStopping:
        def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):

            self.patience = patience
            self.verbose = verbose
            self.counter = 0
            self.best_score = None
            self.early_stop = False
            self.val_loss_min = np.Inf
            self.delta = delta
            self.path = path

        def __call__(self, val_loss, model):

            score = -val_loss

            if self.best_score is None:
                self.best_score = score
                self.save_checkpoint(val_loss, model)
            elif score < self.best_score + self.delta:
                self.counter += 1
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    self.early_stop = True
            else:
                self.best_score = score
                self.save_checkpoint(val_loss, model)
                self.counter = 0

        def save_checkpoint(self, val_loss, model):
            if self.verbose:
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            torch.save(model.state_dict(), self.path)
            self.val_loss_min = val_loss


    def train_model(model, batch_size, patience, n_epochs):

        train_losses = []
        valid_losses = []
        avg_train_losses = []
        avg_valid_losses = []

        early_stopping = EarlyStopping(patience = patience, verbose = True)

        for epoch in range(1, n_epochs + 1):

            model.train() 
            for batch, (data, targets, data2, data3) in enumerate(train_iter, 1):
                optimizer.zero_grad()    
                output = model(data)
                loss = criterion(output, targets)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())


            model.eval() 
            for data, targets, data2, data3 in val_iter :
                output = model(data)
                loss = criterion(output, targets)
                valid_losses.append(loss.item())

            train_loss = np.average(train_losses)
            valid_loss = np.average(valid_losses)
            avg_train_losses.append(train_loss)
            avg_valid_losses.append(valid_loss)

            epoch_len = len(str(n_epochs))

            train_losses = []
            valid_losses = []

            early_stopping(valid_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        model.load_state_dict(torch.load('checkpoint.pt'))

        return  model, avg_train_losses, avg_valid_losses
    
    patience = 100
    n_epochs = 200
    model, train_loss, valid_loss = train_model(model, batch_size, patience, n_epochs)   

    model.eval()
    with torch.no_grad():
        valid_tensor = x_test
        predict = model(valid_tensor)
    predict = predict.cpu().data.numpy()
    actual_predictions = predict
    actual_predictions = scaler.inverse_transform(actual_predictions)
    
    groun = y_test.cpu().data.numpy()
    groun = scaler.inverse_transform(groun)
    
    groun2 = pd.DataFrame(groun)
    groun2.to_csv('results/ground_' + str(j) + '.csv')
    actual_predictions2 = pd.DataFrame(actual_predictions)
    actual_predictions2.to_csv('results/predictions_' + str(j) + '.csv')
    torch.save(model.state_dict(), 'results/model_' + str(j) + '.pt')
    
    MAE[j-1] = mean_absolute_error(groun, actual_predictions)
    MSE = mean_squared_error(groun, actual_predictions) 
    RMSE[j-1] = np.sqrt(MSE)
    print(MAE)
    print(RMSE)

Validation loss decreased (inf --> 0.031283).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
Validation loss decreased (0.031283 --> 0.030386).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.030386 --> 0.030114).  Saving model ...
Validation loss decreased (0.030114 --> 0.029664).  Saving model ...
Validation loss decreased (0.029664 --> 0.029392).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.029392 --> 0.029381).  Saving model ...
Validation loss decreased (0.029381 --> 0.029076).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
Validation loss decreased (0.029076 --> 0.028994).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.028994 --> 0.028882).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.028882 --> 0.028749).  Saving model ...
EarlyStopping counter: 1 ou

EarlyStopping counter: 33 out of 100
EarlyStopping counter: 34 out of 100
EarlyStopping counter: 35 out of 100
EarlyStopping counter: 36 out of 100
EarlyStopping counter: 37 out of 100
EarlyStopping counter: 38 out of 100
EarlyStopping counter: 39 out of 100
EarlyStopping counter: 40 out of 100
EarlyStopping counter: 41 out of 100
EarlyStopping counter: 42 out of 100
EarlyStopping counter: 43 out of 100
EarlyStopping counter: 44 out of 100
EarlyStopping counter: 45 out of 100
EarlyStopping counter: 46 out of 100
EarlyStopping counter: 47 out of 100
EarlyStopping counter: 48 out of 100
EarlyStopping counter: 49 out of 100
EarlyStopping counter: 50 out of 100
EarlyStopping counter: 51 out of 100
EarlyStopping counter: 52 out of 100
EarlyStopping counter: 53 out of 100
EarlyStopping counter: 54 out of 100
EarlyStopping counter: 55 out of 100
EarlyStopping counter: 56 out of 100
EarlyStopping counter: 57 out of 100
EarlyStopping counter: 58 out of 100
EarlyStopping counter: 59 out of 100
E

EarlyStopping counter: 6 out of 100
EarlyStopping counter: 7 out of 100
EarlyStopping counter: 8 out of 100
EarlyStopping counter: 9 out of 100
EarlyStopping counter: 10 out of 100
EarlyStopping counter: 11 out of 100
EarlyStopping counter: 12 out of 100
EarlyStopping counter: 13 out of 100
EarlyStopping counter: 14 out of 100
EarlyStopping counter: 15 out of 100
Validation loss decreased (0.032950 --> 0.032788).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.032788 --> 0.032763).  Saving model ...
Validation loss decreased (0.032763 --> 0.032682).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
EarlyStopping counter: 4 out of 100
EarlyStopping counter: 5 out of 100
EarlyStopping counter: 6 out of 100
EarlyStopping counter: 7 out of 100
EarlyStopping counter: 8 out of 100
EarlyStopping counter: 9 out of 100
Validation loss decreased (0.032682 --> 0.032502).  Saving model ..

EarlyStopping counter: 60 out of 100
EarlyStopping counter: 61 out of 100
EarlyStopping counter: 62 out of 100
EarlyStopping counter: 63 out of 100
EarlyStopping counter: 64 out of 100
EarlyStopping counter: 65 out of 100
EarlyStopping counter: 66 out of 100
EarlyStopping counter: 67 out of 100
EarlyStopping counter: 68 out of 100
EarlyStopping counter: 69 out of 100
EarlyStopping counter: 70 out of 100
EarlyStopping counter: 71 out of 100
EarlyStopping counter: 72 out of 100
EarlyStopping counter: 73 out of 100
EarlyStopping counter: 74 out of 100
EarlyStopping counter: 75 out of 100
EarlyStopping counter: 76 out of 100
EarlyStopping counter: 77 out of 100
EarlyStopping counter: 78 out of 100
EarlyStopping counter: 79 out of 100
EarlyStopping counter: 80 out of 100
EarlyStopping counter: 81 out of 100
EarlyStopping counter: 82 out of 100
EarlyStopping counter: 83 out of 100
EarlyStopping counter: 84 out of 100
EarlyStopping counter: 85 out of 100
EarlyStopping counter: 86 out of 100
E

EarlyStopping counter: 14 out of 100
EarlyStopping counter: 15 out of 100
EarlyStopping counter: 16 out of 100
EarlyStopping counter: 17 out of 100
EarlyStopping counter: 18 out of 100
EarlyStopping counter: 19 out of 100
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 21 out of 100
EarlyStopping counter: 22 out of 100
EarlyStopping counter: 23 out of 100
EarlyStopping counter: 24 out of 100
EarlyStopping counter: 25 out of 100
EarlyStopping counter: 26 out of 100
EarlyStopping counter: 27 out of 100
EarlyStopping counter: 28 out of 100
EarlyStopping counter: 29 out of 100
EarlyStopping counter: 30 out of 100
EarlyStopping counter: 31 out of 100
EarlyStopping counter: 32 out of 100
EarlyStopping counter: 33 out of 100
EarlyStopping counter: 34 out of 100
EarlyStopping counter: 35 out of 100
EarlyStopping counter: 36 out of 100
EarlyStopping counter: 37 out of 100
EarlyStopping counter: 38 out of 100
EarlyStopping counter: 39 out of 100
EarlyStopping counter: 40 out of 100
E

EarlyStopping counter: 98 out of 100
EarlyStopping counter: 99 out of 100
EarlyStopping counter: 100 out of 100
[15.51894, 16.557083, 16.83041, 16.90737, 16.960464, 17.1457, 17.14902, 17.256634, None, None]
[19.211773, 20.32164, 20.532719, 20.818008, 20.68599, 20.846935, 20.901573, 20.91474, None, None]
Validation loss decreased (inf --> 0.033417).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
EarlyStopping counter: 4 out of 100
EarlyStopping counter: 5 out of 100
Validation loss decreased (0.033417 --> 0.031806).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.031806 --> 0.031256).  Saving model ...
Validation loss decreased (0.031256 --> 0.031140).  Saving model ...
Validation loss decreased (0.031140 --> 0.030923).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
EarlyStopping counter: 4 out of

EarlyStopping counter: 58 out of 100
EarlyStopping counter: 59 out of 100
EarlyStopping counter: 60 out of 100
EarlyStopping counter: 61 out of 100
EarlyStopping counter: 62 out of 100
EarlyStopping counter: 63 out of 100
EarlyStopping counter: 64 out of 100
EarlyStopping counter: 65 out of 100
EarlyStopping counter: 66 out of 100
EarlyStopping counter: 67 out of 100
EarlyStopping counter: 68 out of 100
EarlyStopping counter: 69 out of 100
EarlyStopping counter: 70 out of 100
EarlyStopping counter: 71 out of 100
EarlyStopping counter: 72 out of 100
EarlyStopping counter: 73 out of 100
EarlyStopping counter: 74 out of 100
EarlyStopping counter: 75 out of 100
EarlyStopping counter: 76 out of 100
EarlyStopping counter: 77 out of 100
EarlyStopping counter: 78 out of 100
EarlyStopping counter: 79 out of 100
EarlyStopping counter: 80 out of 100
EarlyStopping counter: 81 out of 100
EarlyStopping counter: 82 out of 100
EarlyStopping counter: 83 out of 100
EarlyStopping counter: 84 out of 100
E