In [1]:
# importing library
import random
import torch.nn as nn
import torch
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error 

from load_data import *
from utils import *

import math
import torch.nn.init as init
import torch.nn.functional as F

In [2]:
# setting a seed
torch.manual_seed(2333)
torch.cuda.manual_seed(2333)
np.random.seed(2333)
random.seed(2333)
torch.backends.cudnn.deterministic = True

In [3]:
# setting gpu
device = torch.device('cuda:0')
device

device(type='cuda', index=0)

In [6]:
# data path
matrix_path = "dataset/ad.csv"
data_path = "dataset/df_speed.csv"
data_path3 = "dataset/df_brake.csv"
save_path = "save/model.pt"

In [7]:
# setting hyperparameter 
day_slot = 144 # 24 * 60 = 1,440  10 miniute = unit
n_train, n_val, n_test = 49, 6, 6 
n_route = 15 # the number of road links
Ks, Kt = 3, 3
blocks = [[1, 32, 64], [64, 32, 128]]
drop_prob = 0.5
batch_size = 256
epochs = 200
lr = 1e-3

In [10]:
# setting matrix for GCN
# degree - adjacency = laplacian matrix
W = load_matrix(matrix_path)
L = scaled_laplacian(W) 
Lk = cheb_poly(L, Ks)
Lk = torch.Tensor(Lk.astype(np.float32)).to(device)

In [11]:
# setting data
train, val, test = load_data(data_path, n_train * day_slot, n_val * day_slot)
train3, val3, test3 = load_data(data_path3, n_train * day_slot, n_val * day_slot)

scaler =  MinMaxScaler()
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

scaler3 =  MinMaxScaler()
train3 = scaler3.fit_transform(train3)
val3 = scaler3.transform(val3)
test3 = scaler3.transform(test3)

In [12]:
# custom CNN layer for hybrid model
class custom_layer(nn.Module):
    def __init__(self):
        super(custom_layer, self).__init__()
        self.keep_prob = 0.5
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2))

        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2))

        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        torch.nn.Dropout(p=1 - self.keep_prob))

        self.fc = torch.nn.Linear(256, 625, bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)

        self.layer4 = torch.nn.Sequential(
            self.fc,
            torch.nn.ReLU(),
            torch.nn.Dropout(p=1 - self.keep_prob))

        self.fc2 = torch.nn.Linear(625, 15, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)  
        out = self.layer4(out)
        return self.fc2(out)

In [13]:
class align(nn.Module):
    def __init__(self, c_in, c_out):
        super(align, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        if c_in > c_out:
            self.conv1x1 = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        if self.c_in > self.c_out:
            return self.conv1x1(x)
        if self.c_in < self.c_out:
            return F.pad(x, [0, 0, 0, 0, 0, self.c_out - self.c_in, 0, 0])
        return x

class temporal_conv_layer(nn.Module):
    def __init__(self, kt, c_in, c_out, act="relu"):
        super(temporal_conv_layer, self).__init__()
        self.kt = kt
        self.act = act
        self.c_out = c_out
        self.align = align(c_in, c_out)
        if self.act == "GLU":
            self.conv = nn.Conv2d(c_in, c_out * 2, (kt, 1), 1)
        else:
            self.conv = nn.Conv2d(c_in, c_out, (kt, 1), 1)

    def forward(self, x):
        x_in = self.align(x)[:, :, self.kt - 1:, :]
        if self.act == "GLU":
            x_conv = self.conv(x)
            return (x_conv[:, :self.c_out, :, :] + x_in) * torch.sigmoid(x_conv[:, self.c_out:, :, :])
        if self.act == "sigmoid":
            return torch.sigmoid(self.conv(x) + x_in)
        return torch.relu(self.conv(x) + x_in)

class spatio_conv_layer(nn.Module):
    def __init__(self, ks, c, Lk):
        super(spatio_conv_layer, self).__init__()
        self.Lk = Lk
        self.theta = nn.Parameter(torch.FloatTensor(c, c, ks))
        self.b = nn.Parameter(torch.FloatTensor(1, c, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.theta)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.b, -bound, bound)

    def forward(self, x):
        x_c = torch.einsum("knm,bitm->bitkn", self.Lk, x)
        x_gc = torch.einsum("iok,bitkn->botn", self.theta, x_c) + self.b
        return torch.relu(x_gc + x)

class st_conv_block(nn.Module):
    def __init__(self, ks, kt, n, c, p, Lk):
        super(st_conv_block, self).__init__()
        self.tconv1 = temporal_conv_layer(kt, c[0], c[1], "GLU")
        self.sconv = spatio_conv_layer(ks, c[1], Lk)
        self.tconv2 = temporal_conv_layer(kt, c[1], c[2])
        self.ln = nn.LayerNorm([n, c[2]])
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        x_t1 = self.tconv1(x)
        x_s = self.sconv(x_t1)
        x_t2 = self.tconv2(x_s)
        x_ln = self.ln(x_t2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)        
        return self.dropout(x_ln)

class fully_conv_layer(nn.Module):
    def __init__(self, c):
        super(fully_conv_layer, self).__init__()
        self.conv = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        return self.conv(x)

class output_layer(nn.Module):
    def __init__(self, c, T, n):
        super(output_layer, self).__init__()
        self.tconv1 = temporal_conv_layer(T, c, c, "GLU")
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = temporal_conv_layer(1, c, c, "sigmoid")
        self.fc = fully_conv_layer(c)

    def forward(self, x):
        x_t1 = self.tconv1(x)
        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)

In [14]:
class STGCN(nn.Module):
    def __init__(self, ks, kt, bs, T, n, Lk, p): 
        super(STGCN, self).__init__()
        self.st_conv1 = st_conv_block(ks, kt, n, bs[0], p, Lk)
        self.st_conv2 = st_conv_block(ks, kt, n, bs[1], p, Lk)
        
        self.custom_layer1 = custom_layer()
        self.custom_layer2 = custom_layer()
        
        self.output1 = output_layer(bs[1][2], T - 4 * (kt - 1), n)
        
        self.fc = torch.nn.Linear(30, 625, bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)
        
        self.layer13 = torch.nn.Sequential(
            self.fc,
            torch.nn.ReLU(),
            torch.nn.Dropout(p))

        self.fc2 = torch.nn.Linear(625, n, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x, y):
        x_st1 = self.st_conv1(x)
        x_st2 = self.st_conv2(x_st1)
        x_st3 = self.output1(x_st2)
        
        y_st1 = self.custom_layer1(y)
        x_st3 = x_st3[:,0,0,:]

        z = torch.cat((x_st3, y_st1), 1)
        z = z.contiguous().view(z.size(0), -1)
        z = self.layer13(z)
        z = self.fc2(z)
        return z

In [15]:
MAE = [None] * 11
MAPE = [None] * 11
RMSE = [None] * 11 

for j in range(1, 11):
    n_his = 12 
    n_pred = j
    
    x_train, y_train = data_transform(train, n_his, n_pred, day_slot, device)
    x_val, y_val = data_transform(val, n_his, n_pred, day_slot, device)
    x_test, y_test = data_transform(test, n_his, n_pred, day_slot, device)
   
    x_train3, y_train3 = data_transform(train3, n_his, n_pred, day_slot, device)
    x_val3, y_val3 = data_transform(val3, n_his, n_pred, day_slot, device)
    x_test3, y_test3 = data_transform(test3, n_his, n_pred, day_slot, device)
    
    train_data = torch.utils.data.TensorDataset(x_train, y_train, x_train3)
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle = False)

    val_data = torch.utils.data.TensorDataset(x_val, y_val, x_val3)
    val_iter = torch.utils.data.DataLoader(val_data, batch_size, shuffle = False)

    test_data = torch.utils.data.TensorDataset(x_test, y_test, x_test3)
    test_iter = torch.utils.data.DataLoader(test_data, batch_size, shuffle = False)    
    
    criterion = nn.MSELoss()
    model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    

    class EarlyStopping:
        def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
            self.patience = patience
            self.verbose = verbose
            self.counter = 0
            self.best_score = None
            self.early_stop = False
            self.val_loss_min = np.Inf
            self.delta = delta
            self.path = path

        def __call__(self, val_loss, model):
            score = -val_loss
            if self.best_score is None:
                self.best_score = score
                self.save_checkpoint(val_loss, model)
            elif score < self.best_score + self.delta:
                self.counter += 1
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    self.early_stop = True
            else:
                self.best_score = score
                self.save_checkpoint(val_loss, model)
                self.counter = 0

        def save_checkpoint(self, val_loss, model):
            if self.verbose:
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            torch.save(model.state_dict(), self.path)
            self.val_loss_min = val_loss


    def train_model(model, batch_size, patience, n_epochs):
        train_losses = []
        valid_losses = []
        avg_train_losses = []
        avg_valid_losses = []

        early_stopping = EarlyStopping(patience = patience, verbose = True)

        for epoch in range(1, n_epochs + 1):

            model.train() 
            for batch, (data, targets, data2) in enumerate(train_iter, 1):
                optimizer.zero_grad()    
                output = model(data, data2)
                loss = criterion(output, targets)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            model.eval()
            for data, targets, data2 in val_iter :
                output = model(data, data2)
                loss = criterion(output, targets)
                valid_losses.append(loss.item())

            train_loss = np.average(train_losses)
            valid_loss = np.average(valid_losses)
            avg_train_losses.append(train_loss)
            avg_valid_losses.append(valid_loss)

            epoch_len = len(str(n_epochs))

            train_losses = []
            valid_losses = []
            early_stopping(valid_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        model.load_state_dict(torch.load('checkpoint.pt'))
        return  model, avg_train_losses, avg_valid_losses
    
    patience = 100
    n_epochs = 200
    model, train_loss, valid_loss = train_model(model, batch_size, patience, n_epochs)   

    model.eval()
    with torch.no_grad():
        valid_tensor = x_test
        valid_tensor3 = x_test3
        
        predict = model(valid_tensor, valid_tensor3)
    predict = predict.cpu().data.numpy()
    actual_predictions = predict
    actual_predictions = scaler.inverse_transform(actual_predictions)
    
    groun = y_test.cpu().data.numpy()
    groun = scaler.inverse_transform(groun)

    
    groun2 = pd.DataFrame(groun)
    groun2.to_csv('CNSTGCN_results2/ground_' + str(j) + '.csv')
    actual_predictions2 = pd.DataFrame(actual_predictions)
    actual_predictions2.to_csv('CNSTGCN_results2/predictions_' + str(j) + '.csv')
    torch.save(model.state_dict(), 'CNSTGCN_results2/model_' + str(j) + '.pt')
    
    MAE[j-1] = mean_absolute_error(groun, actual_predictions)
    MSE = mean_squared_error(groun, actual_predictions) 
    RMSE[j-1] = np.sqrt(MSE)
    print(MAE)
    print(RMSE)

Validation loss decreased (inf --> 0.035846).  Saving model ...
Validation loss decreased (0.035846 --> 0.032495).  Saving model ...
Validation loss decreased (0.032495 --> 0.032457).  Saving model ...
Validation loss decreased (0.032457 --> 0.031938).  Saving model ...
Validation loss decreased (0.031938 --> 0.031226).  Saving model ...
Validation loss decreased (0.031226 --> 0.030966).  Saving model ...
Validation loss decreased (0.030966 --> 0.030803).  Saving model ...
Validation loss decreased (0.030803 --> 0.030466).  Saving model ...
Validation loss decreased (0.030466 --> 0.029752).  Saving model ...
Validation loss decreased (0.029752 --> 0.029231).  Saving model ...
Validation loss decreased (0.029231 --> 0.028550).  Saving model ...
Validation loss decreased (0.028550 --> 0.028427).  Saving model ...
Validation loss decreased (0.028427 --> 0.028085).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.028085 --> 0.027825).  Saving model ...
Val

EarlyStopping counter: 5 out of 100
EarlyStopping counter: 6 out of 100
EarlyStopping counter: 7 out of 100
EarlyStopping counter: 8 out of 100
EarlyStopping counter: 9 out of 100
EarlyStopping counter: 10 out of 100
EarlyStopping counter: 11 out of 100
EarlyStopping counter: 12 out of 100
EarlyStopping counter: 13 out of 100
EarlyStopping counter: 14 out of 100
EarlyStopping counter: 15 out of 100
EarlyStopping counter: 16 out of 100
EarlyStopping counter: 17 out of 100
EarlyStopping counter: 18 out of 100
EarlyStopping counter: 19 out of 100
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 21 out of 100
EarlyStopping counter: 22 out of 100
EarlyStopping counter: 23 out of 100
EarlyStopping counter: 24 out of 100
EarlyStopping counter: 25 out of 100
EarlyStopping counter: 26 out of 100
EarlyStopping counter: 27 out of 100
EarlyStopping counter: 28 out of 100
EarlyStopping counter: 29 out of 100
EarlyStopping counter: 30 out of 100
EarlyStopping counter: 31 out of 100
EarlyS

EarlyStopping counter: 45 out of 100
EarlyStopping counter: 46 out of 100
EarlyStopping counter: 47 out of 100
EarlyStopping counter: 48 out of 100
EarlyStopping counter: 49 out of 100
EarlyStopping counter: 50 out of 100
EarlyStopping counter: 51 out of 100
EarlyStopping counter: 52 out of 100
EarlyStopping counter: 53 out of 100
EarlyStopping counter: 54 out of 100
EarlyStopping counter: 55 out of 100
EarlyStopping counter: 56 out of 100
EarlyStopping counter: 57 out of 100
EarlyStopping counter: 58 out of 100
EarlyStopping counter: 59 out of 100
EarlyStopping counter: 60 out of 100
EarlyStopping counter: 61 out of 100
EarlyStopping counter: 62 out of 100
EarlyStopping counter: 63 out of 100
EarlyStopping counter: 64 out of 100
EarlyStopping counter: 65 out of 100
EarlyStopping counter: 66 out of 100
EarlyStopping counter: 67 out of 100
EarlyStopping counter: 68 out of 100
EarlyStopping counter: 69 out of 100
EarlyStopping counter: 70 out of 100
EarlyStopping counter: 71 out of 100
E

Validation loss decreased (inf --> 0.034477).  Saving model ...
EarlyStopping counter: 1 out of 100
Validation loss decreased (0.034477 --> 0.033262).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
Validation loss decreased (0.033262 --> 0.032968).  Saving model ...
Validation loss decreased (0.032968 --> 0.032468).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
Validation loss decreased (0.032468 --> 0.031855).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
EarlyStopping counter: 4 out of 100
EarlyStopping counter: 5 out of 100
EarlyStopping counter: 6 out of 100
EarlyStopping counter: 7 out of 100
EarlyStopping counter: 8 out of 100
EarlyStopping counter: 9 out of 100
EarlyStopping counter: 10 out of 100
EarlyStopping counter: 11 out of 100
EarlyStopping counter: 12 out of 100
EarlyStop

EarlyStopping counter: 40 out of 100
EarlyStopping counter: 41 out of 100
EarlyStopping counter: 42 out of 100
EarlyStopping counter: 43 out of 100
EarlyStopping counter: 44 out of 100
EarlyStopping counter: 45 out of 100
EarlyStopping counter: 46 out of 100
EarlyStopping counter: 47 out of 100
EarlyStopping counter: 48 out of 100
EarlyStopping counter: 49 out of 100
EarlyStopping counter: 50 out of 100
EarlyStopping counter: 51 out of 100
EarlyStopping counter: 52 out of 100
EarlyStopping counter: 53 out of 100
EarlyStopping counter: 54 out of 100
EarlyStopping counter: 55 out of 100
EarlyStopping counter: 56 out of 100
EarlyStopping counter: 57 out of 100
EarlyStopping counter: 58 out of 100
EarlyStopping counter: 59 out of 100
EarlyStopping counter: 60 out of 100
EarlyStopping counter: 61 out of 100
EarlyStopping counter: 62 out of 100
EarlyStopping counter: 63 out of 100
EarlyStopping counter: 64 out of 100
EarlyStopping counter: 65 out of 100
EarlyStopping counter: 66 out of 100
E

Validation loss decreased (0.031792 --> 0.031731).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
Validation loss decreased (0.031731 --> 0.031369).  Saving model ...
EarlyStopping counter: 1 out of 100
EarlyStopping counter: 2 out of 100
EarlyStopping counter: 3 out of 100
EarlyStopping counter: 4 out of 100
EarlyStopping counter: 5 out of 100
EarlyStopping counter: 6 out of 100
EarlyStopping counter: 7 out of 100
EarlyStopping counter: 8 out of 100
EarlyStopping counter: 9 out of 100
EarlyStopping counter: 10 out of 100
EarlyStopping counter: 11 out of 100
EarlyStopping counter: 12 out of 100
EarlyStopping counter: 13 out of 100
EarlyStopping counter: 14 out of 100
EarlyStopping counter: 15 out of 100
EarlyStopping counter: 16 out of 100
EarlyStopping counter: 17 out of 100
EarlyStopping counter: 18 out of 100
EarlyStopping counter: 19 out of 100
EarlyStopping counter: 20 out of 100
EarlyStopping counter: 

EarlyStopping counter: 44 out of 100
EarlyStopping counter: 45 out of 100
EarlyStopping counter: 46 out of 100
EarlyStopping counter: 47 out of 100
EarlyStopping counter: 48 out of 100
EarlyStopping counter: 49 out of 100
EarlyStopping counter: 50 out of 100
EarlyStopping counter: 51 out of 100
EarlyStopping counter: 52 out of 100
EarlyStopping counter: 53 out of 100
EarlyStopping counter: 54 out of 100
EarlyStopping counter: 55 out of 100
EarlyStopping counter: 56 out of 100
EarlyStopping counter: 57 out of 100
EarlyStopping counter: 58 out of 100
EarlyStopping counter: 59 out of 100
EarlyStopping counter: 60 out of 100
EarlyStopping counter: 61 out of 100
EarlyStopping counter: 62 out of 100
EarlyStopping counter: 63 out of 100
EarlyStopping counter: 64 out of 100
EarlyStopping counter: 65 out of 100
EarlyStopping counter: 66 out of 100
EarlyStopping counter: 67 out of 100
EarlyStopping counter: 68 out of 100
EarlyStopping counter: 69 out of 100
EarlyStopping counter: 70 out of 100
E

In [16]:
MAE

[14.166544,
 16.144623,
 16.87106,
 17.003078,
 17.027254,
 17.040333,
 17.242882,
 17.089788,
 17.072144,
 17.333147,
 None]

In [17]:
RMSE

[17.908945,
 19.920412,
 20.750677,
 20.8829,
 20.877731,
 20.889994,
 20.974121,
 20.979414,
 20.822517,
 21.065914,
 None]