In [None]:
!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

In [None]:
import dgl
import dgl.nn as dgl_nn
import dgl.function as dgl_f 
import torch 
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx 
from metra import *
from utils import *
from tqdm.notebook import tqdm
from IPython.display import clear_output
import scipy.sparse as sparse
import math

import os
import ssl
from six.moves import urllib
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Seed PyTorch's global random number generator.
torch.manual_seed(123)

## Dataset

In [None]:
def download_file(dataset):
    """Download one file from the public DGL dataset bucket into ``./data``.

    Args:
        dataset (str): File name inside the bucket (for example
            ``"metr_la_train.npz"``). It is appended to the bucket URL and is
            also used as the local file name.

    The payload is streamed to ``./data/<dataset>.part`` and renamed once the
    transfer finished, so an interrupted download never leaves a truncated
    file that later looks like a valid cache entry.
    """
    import shutil  # only needed here; keeps the notebook's import cell untouched

    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset)
    print("Start Downloading File....")
    # NOTE(review): certificate verification is disabled here; kept as in the
    # original, but consider ssl.create_default_context() instead.
    context = ssl._create_unverified_context()

    os.makedirs("./data", exist_ok=True)
    target = "./data/{}".format(dataset)
    partial = target + ".part"
    # Both the HTTP response and the output file are closed on every path.
    with urllib.request.urlopen(url, context=context) as response, \
            open(partial, "wb") as handle:
        shutil.copyfileobj(response, handle)
    os.replace(partial, target)


class SnapShotDataset(Dataset):
    """Map-style dataset backed by an ``.npz`` archive with arrays ``x`` and ``y``.

    Args:
        path (str): Directory that holds (or should hold) the archive.
        npz_file (str): File name of the archive inside ``path``.

    If the archive is missing, ``path`` is created (including parents) and
    ``download_file(npz_file)`` is invoked before loading.
    NOTE(review): ``download_file`` receives only the file name, so where the
    file ends up is decided by that function, not by ``path`` -- confirm this
    is intended for directories other than ``data``.
    """

    def __init__(self, path, npz_file):
        archive_path = os.path.join(path, npz_file)
        if not os.path.exists(archive_path):
            # makedirs handles nested directories such as "data/Buffalogrove".
            os.makedirs(path, exist_ok=True)
            download_file(npz_file)
        # Indexing the archive materialises each array, so the file handle can
        # be released as soon as both arrays have been read.
        with np.load(archive_path) as archive:
            self.x = archive['x']
            self.y = archive['y']

    def __len__(self):
        """Number of samples, i.e. the length of ``x`` along its first axis."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair at ``idx``; tensor indices are converted to Python values."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.x[idx, ...], self.y[idx, ...]


def METR_LAGraphDataset():
    """Return the first graph stored in ``data/graph_la.bin``, downloading the file if it is absent."""
    graph_file = 'data/graph_la.bin'
    needs_download = not os.path.exists(graph_file)
    if needs_download and not os.path.exists('data'):
        os.mkdir('data')
    if needs_download:
        download_file('graph_la.bin')
    graphs, _ = dgl.load_graphs(graph_file)
    return graphs[0]


class METR_LATrainDataset(SnapShotDataset):
    """Training split read from ``data/metr_la_train.npz``.

    Also stores ``mean`` and ``std`` of ``x`` reduced over axes 0, 1 and 2.
    """

    def __init__(self):
        super().__init__('data', 'metr_la_train.npz')
        print(self.x.shape)
        reduce_axes = (0, 1, 2)
        self.mean = self.x.mean(axis=reduce_axes)
        self.std = self.x.std(axis=reduce_axes)


class METR_LATestDataset(SnapShotDataset):
    """Test split read from ``data/metr_la_test.npz``."""

    def __init__(self):
        super().__init__('data', 'metr_la_test.npz')


class METR_LAValidDataset(SnapShotDataset):
    """Validation split read from ``data/metr_la_valid.npz``."""

    def __init__(self):
        super().__init__('data', 'metr_la_valid.npz')


def PEMS_BAYGraphDataset():
    """Return the first graph stored in ``data/graph_bay.bin``, downloading the file if it is absent."""
    graph_file = 'data/graph_bay.bin'
    needs_download = not os.path.exists(graph_file)
    if needs_download and not os.path.exists('data'):
        os.mkdir('data')
    if needs_download:
        download_file('graph_bay.bin')
    graphs, _ = dgl.load_graphs(graph_file)
    return graphs[0]


class PEMS_BAYTrainDataset(SnapShotDataset):
    """Training split read from ``data/pems_bay_train.npz``.

    Also stores ``mean`` and ``std`` of ``x`` reduced over axes 0, 1 and 2.
    """

    def __init__(self):
        super().__init__('data', 'pems_bay_train.npz')
        reduce_axes = (0, 1, 2)
        self.mean = self.x.mean(axis=reduce_axes)
        self.std = self.x.std(axis=reduce_axes)


class PEMS_BAYTestDataset(SnapShotDataset):
    """Test split read from ``data/pems_bay_test.npz``."""

    def __init__(self):
        super().__init__('data', 'pems_bay_test.npz')


class PEMS_BAYValidDataset(SnapShotDataset):
    """Validation split read from ``data/pems_bay_valid.npz``."""

    def __init__(self):
        super().__init__('data', 'pems_bay_valid.npz')

In [None]:
def BufflogroveGraphDataset():
    """Return the first graph stored in ``data/Buffalogrove/graph_buffalo.bin`` (no download fallback)."""
    graphs, _ = dgl.load_graphs('data/Buffalogrove/graph_buffalo.bin')
    first_graph = graphs[0]
    return first_graph

class BuffalogroveTrainDataset(SnapShotDataset):
    """Buffalo Grove training split for a given ``num_steps``.

    Also stores ``mean`` and ``std`` of ``x`` reduced over axes 0, 1 and 2.
    """

    def __init__(self, num_steps = 3):
        archive_name = f"buffalogrove_train_data_{num_steps}_steps.npz"
        super().__init__("data/Buffalogrove", archive_name)
        print(self.x.shape, self.y.shape)
        reduce_axes = (0, 1, 2)
        self.mean = self.x.mean(axis=reduce_axes)
        self.std = self.x.std(axis=reduce_axes)
        
class BuffalogroveValDataset(SnapShotDataset):
    """Buffalo Grove validation split for a given ``num_steps``."""

    def __init__(self, num_steps = 3):
        archive_name = f"buffalogrove_val_data_{num_steps}_steps.npz"
        super().__init__("data/Buffalogrove", archive_name)

class BuffalogroveTestDataset(SnapShotDataset):
    """Buffalo Grove test split for one ``date`` string and a given ``num_steps``."""

    def __init__(self, date, num_steps = 3):
        prefix = "buffalogrove_test_"
        suffix = f"_data_{num_steps}_steps.npz"
        super().__init__("data/Buffalogrove", prefix + date + suffix)

def GurneeGraphDataset():
    """Return the first graph stored in ``data/Gurnee/graph_gurnee.bin`` (no download fallback)."""
    graphs, _ = dgl.load_graphs('data/Gurnee/graph_gurnee.bin')
    first_graph = graphs[0]
    return first_graph

class GurneeTrainDataset(SnapShotDataset):
    """Gurnee training split for a given ``num_steps``.

    Also stores ``mean`` and ``std`` of ``x`` reduced over axes 0, 1 and 2.
    """

    def __init__(self, num_steps = 3):
        archive_name = f"gurnee_train_data_{num_steps}_steps.npz"
        super().__init__("data/Gurnee", archive_name)
        print(self.x.shape)
        reduce_axes = (0, 1, 2)
        self.mean = self.x.mean(axis=reduce_axes)
        self.std = self.x.std(axis=reduce_axes)
        
class GurneeValDataset(SnapShotDataset):
    """Gurnee validation split for a given ``num_steps``."""

    def __init__(self, num_steps = 3):
        archive_name = f"gurnee_val_data_{num_steps}_steps.npz"
        super().__init__("data/Gurnee", archive_name)

class GurneeTestDataset(SnapShotDataset):
    """Gurnee test split for one ``date`` string and a given ``num_steps``."""

    def __init__(self, date, num_steps =3 ):
        prefix = "gurnee_test_"
        suffix = f"_data_{num_steps}_steps.npz"
        super().__init__("data/Gurnee", prefix + date + suffix)

## Model 

In [None]:
from typing import Optional

import torch
import torch.nn as nn
class STGPCN(nn.Module):
    r"""Three stacked Conv3d/MaxPool3d blocks followed by a prediction convolution.

    Every block applies conv -> permute -> pool three times (with a layer norm
    after the second pooling) and then a fourth convolution, so the last axis
    goes ``len_input -> 256 -> 128 -> 64 -> 32 -> len_input``. The output of a
    block is concatenated on the channel axis with an earlier tensor before it
    enters the next block.

    Args:
        in_channels (int): Number of input features. Accepted for interface
            compatibility; the layers below do not read it.
        n_filter (int): Number of filters. Accepted but not read.
        time_strides (int): Time strides. Accepted but not read.
        num_for_predict (int): Number of predicted steps (output channels of
            the final convolution).
        len_input (int): Length of the input sequence (size of the last axis
            of the input).
        bias (bool): Accepted but not read.
    """

    def __init__(
        self,
        in_channels: int,
        n_filter: int,
        time_strides: int,
        num_for_predict: int,
        len_input: int,
        bias: bool = True,
    ):

        # The original call named a different class (D3TGCN3), which fails on a
        # fresh kernel; the zero-argument form always refers to this class.
        super().__init__()
        print('Start')

        # Kernel widths 128 / 64 / 32 equal the last-axis size left by the
        # preceding pooling step (256 -> 128 -> 64 -> 32), so each convolution
        # collapses the last axis to length 1.
        self._conv1 = nn.Conv3d(in_channels=1, out_channels=256, kernel_size=(1, 1, len_input), groups=1)
        self._pool1 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self._conv2 = nn.Conv3d(in_channels=1, out_channels=128, kernel_size=(1, 1, 128), groups=1)
        self._pool2 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self.layer_norm = nn.LayerNorm([64])
        self._conv3 = nn.Conv3d(in_channels=1, out_channels=64, kernel_size=(1, 1, 64), groups=1)
        self._pool3 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self._conv4 = nn.Conv3d(in_channels=1, out_channels=len_input, kernel_size=(1, 1, 32), groups=1)

        # Blocks two and three take 2 input channels: the previous block's
        # output concatenated with an earlier tensor.
        self._conv1_1 = nn.Conv3d(in_channels=2, out_channels=256, kernel_size=(1, 1, len_input), groups=1)
        self._pool1_1 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self._conv2_1 = nn.Conv3d(in_channels=1, out_channels=128, kernel_size=(1, 1, 128), groups=1)
        self._pool2_1 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self.layer_norm_1 = nn.LayerNorm([64])
        self._conv3_1 = nn.Conv3d(in_channels=1, out_channels=64, kernel_size=(1, 1, 64), groups=1)
        self._pool3_1 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self._conv4_1 = nn.Conv3d(in_channels=1, out_channels=len_input, kernel_size=(1, 1, 32), groups=1)

        self._conv1_2 = nn.Conv3d(in_channels=2, out_channels=256, kernel_size=(1, 1, len_input), groups=1)
        self._pool1_2 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self._conv2_2 = nn.Conv3d(in_channels=1, out_channels=128, kernel_size=(1, 1, 128), groups=1)
        self._pool2_2 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self.layer_norm_2 = nn.LayerNorm([64])
        self._conv3_2 = nn.Conv3d(in_channels=1, out_channels=64, kernel_size=(1, 1, 64), groups=1)
        self._pool3_2 = nn.MaxPool3d((1, 1, 2), stride=(1, 1, 2))
        self._conv4_2 = nn.Conv3d(in_channels=1, out_channels=len_input, kernel_size=(1, 1, 32), groups=1)

        self._final_conv = nn.Conv3d(in_channels=2, out_channels=num_for_predict, kernel_size=(1, 1, len_input), groups=1)

        self._reset_parameters()

    def _reset_parameters(self):
        """
        Re-initialise every parameter: Xavier-uniform for tensors with more
        than one dimension, ``nn.init.uniform_`` (default range) for the rest.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def _apply_block(self, X, conv1, pool1, conv2, pool2, norm, conv3, pool3, conv4):
        """Run one conv/pool block; each permute moves the conv channels onto the last axis."""
        X = conv1(X).permute(0, 4, 2, 3, 1)  # last axis: 256
        X = pool1(X)                         # last axis: 128
        X = conv2(X).permute(0, 4, 2, 3, 1)  # last axis: 128
        X = pool2(X)                         # last axis: 64
        X = norm(X)
        X = conv3(X).permute(0, 4, 2, 3, 1)  # last axis: 64
        X = pool3(X)                         # last axis: 32
        X = conv4(X).permute(0, 4, 2, 3, 1)  # last axis: len_input
        return X

    def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
        """
        Making a forward pass.

        Arg types:
            * **X** (PyTorch FloatTensor) - Node features for T time periods, with shape (B, N_nodes, F_in, T_in),
              where T_in must equal ``len_input``.

        Return types:
            * **X** (PyTorch FloatTensor)* - Predictions with shape (B, N_nodes, F_in, num_for_predict).
        """
        X = torch.unsqueeze(X, 1)  # (B, 1, N, F, T): add the Conv3d channel axis
        block_input = X

        # NOTE(review): all three blocks normalise with ``self.layer_norm``;
        # ``layer_norm_1`` and ``layer_norm_2`` are created but never applied.
        # Left as is so that existing checkpoints keep producing the same
        # outputs -- confirm whether the per-block norms were intended.
        X = self._apply_block(
            X, self._conv1, self._pool1, self._conv2, self._pool2,
            self.layer_norm, self._conv3, self._pool3, self._conv4)
        first_block_out = X
        X = torch.cat((X, block_input), 1)  # (B, 2, N, F, T)

        # Second block
        X = self._apply_block(
            X, self._conv1_1, self._pool1_1, self._conv2_1, self._pool2_1,
            self.layer_norm, self._conv3_1, self._pool3_1, self._conv4_1)
        X = torch.cat((X, first_block_out), 1)

        # Third block
        X = self._apply_block(
            X, self._conv1_2, self._pool1_2, self._conv2_2, self._pool2_2,
            self.layer_norm, self._conv3_2, self._pool3_2, self._conv4_2)
        # NOTE(review): this concatenates the FIRST block's output again, not
        # the second block's -- kept as in the original; confirm the intent.
        X = torch.cat((X, first_block_out), 1)

        X = self._final_conv(X)    # (B, num_for_predict, N, F, 1)
        X = torch.squeeze(X, 4)    # (B, num_for_predict, N, F)
        X = X.permute(0, 2, 3, 1)  # (B, N, F, num_for_predict)
        return X

## Load Data

In [None]:
def calculate_lap_pos(g, k):
    """Return the real part of the ``k``-dimensional Laplacian positional encoding of graph ``g``.

    The adjacency matrix is taken from ``g.adj`` in SciPy COO format, made
    dense, and passed through ``calculate_normalized_laplacian`` and
    ``laplacian_positional_encoding``.
    """
    coo_adjacency = g.adj(scipy_fmt = 'coo')
    dense_adjacency = np.array(coo_adjacency.todense())
    laplacian = calculate_normalized_laplacian(dense_adjacency)
    encoding = laplacian_positional_encoding(laplacian, k)
    return np.real(encoding)

In [None]:
# Select the dataset for this run. Exactly one group (graph + train/test/valid)
# should be active; the commented groups are the alternative datasets.
scaler = ZScaler()
g = METR_LAGraphDataset()
train_data = METR_LATrainDataset()
test_data = METR_LATestDataset()
valid_data = METR_LAValidDataset()

# g = PEMS_BAYGraphDataset()
# train_data = PEMS_BAYTrainDataset()
# test_data = PEMS_BAYTestDataset()
# valid_data = PEMS_BAYValidDataset()

# NOTE: the Buffalogrove/Gurnee test datasets take a required `date` argument,
# so the bare calls below need one before they can be enabled.
# g = BufflogroveGraphDataset()
# train_data = BuffalogroveTrainDataset(num_steps = 3)
# # test_data = BuffalogroveTestDataset()
# valid_data = BuffalogroveValDataset(num_steps = 3)

# g = GurneeGraphDataset()
# train_data = GurneeTrainDataset(num_steps = 6)
# test_data = GurneeTestDataset()
# valid_data = GurneeValDataset(num_steps = 6)

num_node = g.num_nodes()
num_edges = g.num_edges()
# Normalisation statistics come from the training split only.
mean = train_data.mean
std = train_data.std

print("Mean: ", mean)
print("Std: ", std)
print("Num node", num_node)
print("Num edge: ", num_edges)

(23974, 12, 207, 2)
Mean:  [54.4059283   0.49721458]
Std:  [19.49373927  0.28892871]
Num node 207
Num edge:  1722


In [None]:
# Force both statistics to a flat length-2 vector; index 0 is the one the
# train/eval/test functions use for scaling.
mean = mean.reshape(2)
std = std.reshape(2)
print("Mean: ", mean)
print("Std: ", std)

Mean:  [54.4059283   0.49721458]
Std:  [19.49373927  0.28892871]


In [None]:
# Number of samples in each split.
print("Train size: ", len(train_data))
print("Valid size: ", len(valid_data))
print("Test size: ", len(test_data))

In [None]:
# Compute a 100-dimensional Laplacian positional encoding and attach it to
# the graph as the node feature 'lap_pos'.
lap_pos = calculate_lap_pos(g, 100)
lap_pos = torch.tensor(lap_pos).float()
g.ndata['lap_pos'] = lap_pos

In [None]:
# Batch size shared by the train/valid/test loaders below.
batch_size = 128

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_data, batch_size= batch_size, num_workers= 2, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size= batch_size, num_workers= 2, shuffle=False)
test_loader  = DataLoader(test_data,  batch_size= batch_size, num_workers= 2, shuffle=False)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [None]:
def masked_mae_loss(y_pred, y_true):
    """Mean absolute error that gives zero weight to positions where ``y_true == 0``.

    The 0/1 mask is divided by its own mean, so the result is the average
    error over the non-zero targets. NaN entries (e.g. when every target is
    zero and the mask mean is 0) are replaced by 0 before averaging.
    """
    weights = (y_true != 0).float()
    weights = weights / weights.mean()
    weighted_err = (y_pred - y_true).abs() * weights
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    weighted_err[torch.isnan(weighted_err)] = 0
    return weighted_err.mean()

def masked_mse_loss(y_pred, y_true):
    """Mean squared error that gives zero weight to positions where ``y_true == 0``.

    The 0/1 mask is divided by its own mean, so the result is the average
    squared error over the non-zero targets. NaN entries are replaced by 0
    before averaging.
    """
    weights = (y_true != 0).float()
    weights = weights / weights.mean()
    residual = y_pred - y_true
    weighted_err = residual ** 2 * weights
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    weighted_err[torch.isnan(weighted_err)] = 0
    return weighted_err.mean()

def masked_mape_loss(y_pred, y_true):
    """Mean absolute percentage error (as a fraction) over positions where ``y_true != 0``.

    Dividing by a zero target yields inf/NaN; multiplied by the zero mask this
    becomes NaN, which is then replaced by 0 before averaging.
    """
    weights = (y_true != 0).float()
    weights = weights / weights.mean()
    relative_err = (y_true - y_pred).abs() / y_true.abs()
    weighted_err = relative_err * weights
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    weighted_err[torch.isnan(weighted_err)] = 0
    return weighted_err.mean()

## Load Model

In [None]:
# Both the input window and the prediction horizon are 12 steps.
num_decode_steps = seq_len = 12
num_encode_steps = 12
# Drop any previous model object before building a new one.
model = None
# NOTE: STGPCN's constructor does not read n_filter, in_channels or
# time_strides; only num_for_predict and len_input shape the layers.
model = STGPCN(n_filter=num_encode_steps,
               in_channels = 2, time_strides = 1, num_for_predict = num_decode_steps, len_input = num_encode_steps
               ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3, amsgrad = True)
loss_fn = masked_mae_loss
print('Trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))

Start
Trainable parameters: 80016


In [None]:
# Run this cell when you already have a checkpoint
#checkpoint = torch.load('model_stt_week_day_sample.pt')
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

## Train 

In [None]:
# Run this cell to load the model that is saved during training
checkpoint = torch.load('best/model_metrla_stgpcn_window.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Restore the best validation loss and epoch counter from the loaded
# checkpoint. Only the two expected failures fall back to defaults: no
# checkpoint was loaded (NameError) or it lacks a key (KeyError). Anything
# else is a real error and is no longer swallowed by a bare ``except``.
try:
    best_loss = [checkpoint['loss']]
    epoch = checkpoint['epoch']
except (NameError, KeyError):
    best_loss = [100]
    epoch = 0

In [None]:
# Start-from-scratch override: running this cell discards any epoch/best-loss
# values restored from a checkpoint. Skip it when resuming training.
epoch = 0
best_loss = [100]

In [None]:
# Show the bookkeeping values the training loop will start from.
print("Best Loss: ", best_loss[0])
print("Epoch: ", epoch)

In [None]:
import numpy as np
def train_1(model, num_node, g, train_loader, lap_pos, scaler, loss_fn, optimizer, batch_size, num_samples, seq_len):
    """Train ``model`` for one pass over ``train_loader`` and return the mean batch loss.

    Relies on the notebook-level ``mean`` and ``std`` arrays: entry 0 scales
    the input and un-scales the prediction before the loss is computed.

    Args:
        model: Network called as ``model(x)``; its output is expected to have
            a singleton axis 2 (as ``STGPCN`` returns), which is removed.
        num_node, g, lap_pos: Unused; kept so existing calls stay valid.
        train_loader: Iterable of ``(x, y)`` batches, each a 4-D tensor whose
            axis 1 is moved last and whose last axis is indexed at 0.
        scaler: Object exposing ``scale(x, mean, std)``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        batch_size, num_samples: Only used to size the progress bar.
        seq_len: Number of target steps kept from ``y``.

    Returns:
        Mean of the per-batch loss values.
    """
    # eval_1/test_1 switch the model to eval mode; switch back explicitly.
    model.train()
    run_device = next(model.parameters()).device
    # Ceil division so the progress bar also counts the last partial batch.
    total_batches = -(-num_samples // batch_size)

    batch_losses = []
    for x, y in tqdm(train_loader, total = total_batches):
        # Move axis 1 last and keep feature 0, with a singleton feature axis.
        x = x.permute(0, 2, 3, 1)[:, :, 0, :].unsqueeze(2)
        # Targets: feature 0, axis 1 moved last, truncated to seq_len.
        y = y[..., 0].permute(0, 2, 1)[:, :, :seq_len]

        x = scaler.scale(x, mean[0], std[0])
        x = torch.as_tensor(x).float().to(run_device)
        y = torch.as_tensor(y).float().to(run_device)

        # Clear gradients first so nothing stale leaks into this step.
        optimizer.zero_grad()
        # squeeze(2) drops only the singleton feature axis; a bare squeeze()
        # would also drop the batch axis when a batch holds a single sample.
        out = model(x).squeeze(2)
        # Undo the z-scaling so the loss is measured in the original units.
        out = out * std[0] + mean[0]
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return np.mean(batch_losses)
   

def eval_1(model, num_node, g, val_loader, lap_pos, scaler, loss_fn, batch_size, num_samples, seq_len):
    """Evaluate ``model`` on ``val_loader`` without gradients and return the mean batch loss.

    Relies on the notebook-level ``mean`` and ``std`` arrays: entry 0 scales
    the input and un-scales the prediction before the loss is computed.

    Args:
        model: Network called as ``model(x)``; its output is expected to have
            a singleton axis 2 (as ``STGPCN`` returns), which is removed.
        num_node, g, lap_pos: Unused; kept so existing calls stay valid.
        val_loader: Iterable of ``(x, y)`` batches, each a 4-D tensor whose
            axis 1 is moved last and whose last axis is indexed at 0.
        scaler: Object exposing ``scale(x, mean, std)``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        batch_size, num_samples: Only used to size the progress bar.
        seq_len: Number of target steps kept from ``y``.

    Returns:
        Mean of the per-batch loss values. The model is left in eval mode.
    """
    model.eval()
    run_device = next(model.parameters()).device
    # Ceil division so the progress bar also counts the last partial batch.
    total_batches = -(-num_samples // batch_size)

    batch_losses = []
    with torch.no_grad():
        for x, y in tqdm(val_loader, total = total_batches):
            # Move axis 1 last and keep feature 0, with a singleton feature axis.
            x = x.permute(0, 2, 3, 1)[:, :, 0, :].unsqueeze(2)
            # Targets: feature 0, axis 1 moved last, truncated to seq_len.
            y = y[..., 0].permute(0, 2, 1)[:, :, :seq_len]

            x = scaler.scale(x, mean[0], std[0])
            x = torch.as_tensor(x).float().to(run_device)
            y = torch.as_tensor(y).float().to(run_device)

            # squeeze(2) drops only the singleton feature axis; a bare
            # squeeze() would also drop a batch axis of size one.
            out = model(x).squeeze(2)
            # Undo the z-scaling so the loss is measured in the original units.
            out = out * std[0] + mean[0]
            batch_losses.append(loss_fn(out, y).item())

    return np.mean(batch_losses)

def test_1(model, num_node, g, val_loader, lap_pos, scaler, metrics, batch_size, num_samples, num_steps, seq_len=None):
    """Evaluate ``model`` on ``val_loader`` and collect metrics plus sample predictions.

    Relies on the notebook-level ``mean`` and ``std`` arrays: entry 0 scales
    the input and un-scales the prediction before the metrics are computed.

    Args:
        model: Network called as ``model(x)``; its output is expected to have
            a singleton axis 2 (as ``STGPCN`` returns), which is removed.
        num_node, g, lap_pos: Unused; kept so existing calls stay valid.
        val_loader: Iterable of ``(x, y)`` batches, each a 4-D tensor whose
            axis 1 is moved last and whose last axis is indexed at 0.
        scaler: Object exposing ``scale(x, mean, std)``.
        metrics: Mapping with callables under ``'mae'``, ``'mse'`` and ``'mape'``.
        batch_size, num_samples: Only used to size the progress bar.
        num_steps: Fallback for ``seq_len`` when that is omitted.
        seq_len: Number of target steps kept from ``y``; defaults to ``num_steps``.

    Returns:
        ``(mean MAE, mean RMSE, mean MAPE, predictions, actuals)``. The three
        metrics are averaged over batches (RMSE is the mean of per-batch
        square roots). ``predictions``/``actuals`` are flat lists.
    """
    if seq_len is None:
        seq_len = num_steps

    model.eval()
    run_device = next(model.parameters()).device
    # Ceil division so the progress bar also counts the last partial batch.
    total_batches = -(-num_samples // batch_size)

    total_rmse = []
    total_mae = []
    total_mape = []
    predictions = []
    actuals = []
    with torch.no_grad():
        for x, y in tqdm(val_loader, total = total_batches):
            # Move axis 1 last and keep feature 0, with a singleton feature axis.
            x = x.permute(0, 2, 3, 1)[:, :, 0, :].unsqueeze(2)
            # Targets: feature 0, axis 1 moved last, truncated to seq_len.
            y = y[..., 0].permute(0, 2, 1)[:, :, :seq_len]

            x = scaler.scale(x, mean[0], std[0])
            x = torch.as_tensor(x).float().to(run_device)
            y = torch.as_tensor(y).float().to(run_device)

            # squeeze(2) keeps the output 3-D even for a batch of one, which
            # the indexing below requires.
            out = model(x).squeeze(2)
            # Undo the z-scaling so metrics are measured in the original units.
            out = out * std[0] + mean[0]

            total_mae.append(metrics['mae'](out, y).item())
            total_rmse.append(torch.sqrt(metrics['mse'](out, y)).item())
            total_mape.append(metrics['mape'](out, y).item())

            out = out.cpu().numpy()
            y = y.cpu().numpy()

            # NOTE(review): only the FIRST sample of each batch is recorded
            # (its last step, every index of axis 1). This equals the original
            # transpose(1, 2, 0)[:, -1, 0]; confirm it is the series wanted.
            predictions += out[0, :, -1].tolist()
            actuals += y[0, :, -1].tolist()

    return np.mean(total_mae), np.mean(total_rmse), np.mean(total_mape), predictions, actuals


In [None]:
# print("Best Loss", best_loss[0])
# torch.save does not create missing directories, so make sure both
# checkpoint folders exist before the first epoch finishes.
os.makedirs('weight', exist_ok=True)
os.makedirs('best', exist_ok=True)

# NOTE: `epoch` itself is not advanced here; it is only used as the offset for
# the epoch number stored in the checkpoints.
for e in range(1, 31):
    print("Epoch: ", epoch + e)
    print("Best Validation Loss", best_loss[0])
    # loss_fn = nn.L1Loss() # We should use L1Loss for STREETS because there is no nan and 0 is meaningful
    train_loss = train_1(model, num_node, g, train_loader, lap_pos, scaler, loss_fn, optimizer, batch_size, len(train_data), seq_len)
    val_loss = eval_1(model, num_node, g, valid_loader, lap_pos, scaler, loss_fn, batch_size, len(valid_data), seq_len)

    # One snapshot per epoch: always written as the latest checkpoint and
    # additionally as the best one when validation improved.
    state = {
        'epoch' : epoch + e,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss' : val_loss,
    }
    torch.save(state, 'weight/model_metrla_stgpcn_window.pt')
    print("Train Loss: {}, Val Loss: {}".format(train_loss, val_loss))

    if best_loss[0] > val_loss:
        best_loss[0] = val_loss
        torch.save(state, 'best/model_metrla_stgpcn_window.pt')
        print("------------- Saved model ------------")
    print()
    

In [None]:
# List the entries stored in the most recently loaded checkpoint.
for x in checkpoint:
  print(x)

## Test

In [None]:
# Run this cell to load the model that is saved during training
checkpoint = torch.load('best/model_metrla_stgpcn_window.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
try:
    best_loss[0] = checkpoint['loss']
    epoch = checkpoint['epoch']
except:
    best_loss[0] = 100
    epoch = 100

In [None]:
# Validation loss and epoch recorded in the checkpoint that was just loaded.
print(best_loss[0])
print(epoch)

In [None]:
# Metric callables looked up by test_1 under these exact keys.
metrics = {
    'mae' : masked_mae_loss,
    'mse' : masked_mse_loss,
    'mape': masked_mape_loss
}
# test_1 takes both num_steps and seq_len; the original call stopped after
# num_steps and therefore raised a TypeError for the missing seq_len.
mae_loss, rmse_loss, mape_loss, predictions, gts = test_1(
    model, num_node, g, test_loader, lap_pos, scaler, metrics,
    batch_size, len(test_data), num_steps=12, seq_len=seq_len)
print("MAE: {}".format(mae_loss))
print("RMSE: {}".format(rmse_loss))
print("MAPE: {}".format(mape_loss))

## Visualize

In [None]:
# Spot-check one prediction against its ground-truth value.
print(predictions[10])
print(gts[10])

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def visualize(preds, gts, start, end, data):
    """Plot predictions against ground truth for indices ``start`` to ``end`` and save the figure.

    Args:
        preds: Sequence of predicted values.
        gts: Sequence of ground-truth values, aligned with ``preds``.
        start (int): First index of the window (inclusive).
        end (int): Last index of the window (exclusive).
        data (str): Dataset label used in the output file name.

    The figure is written to ``report/<data>_<start>_<end>.png`` and shown.
    """
    pred_window = preds[start : end]
    gt_window = gts[start : end]
    # Size each x axis from the actual slice: near the end of the series a
    # slice can be shorter than ``end - start``, which made plotting fail.
    plt.plot(np.arange(len(pred_window)), pred_window, label = 'Prediction')
    plt.plot(np.arange(len(gt_window)), gt_window, label = 'Ground Truth')
    plt.legend(loc = "lower center")
    # savefig does not create directories.
    os.makedirs("report", exist_ok=True)
    plt.savefig(f"report/{data}_{start}_{end}.png")
    plt.show()

In [None]:
# Convert the collected lists to NumPy arrays for slicing and plotting.
predictions = np.array(predictions)
gts = np.array(gts)

In [None]:
# Plot entries 1200-1299 of the collected series; 'metr-la' labels the saved file.
visualize(predictions, gts, 1200, 1300, 'metr-la')

In [None]:
import random
# Print ten randomly chosen (actual, prediction) pairs.
for _ in range(10):
    # randrange excludes the upper bound; randint(0, len(gts)) could return
    # len(gts) itself and raise an IndexError.
    idx = random.randrange(len(gts))
    print("Actual: ", gts[idx])
    print("Prediction: ", predictions[idx])
    print()
    print()

# STREETS

In [None]:
# This box is for STREETS dataset

testing_dates = ['2019-7-15', '2019-7-16', '2019-7-17', '2019-7-18']
test_loaders = {}

for date in testing_dates:
    # test_data = BuffalogroveTestDataset(date, num_steps = 6)
    test_data = GurneeTestDataset(date, num_steps = 3)
    print(len(test_data))
    test_loader  = DataLoader(test_data,  batch_size= 16, num_workers= 2, shuffle=False)
    print(len(test_loader))
    test_loaders[date] = test_loader
metrics = {
    'mae' : masked_mae_loss,
    'mse' : masked_mse_loss,
    'mape': masked_mape_loss
}
results = {}
gts = {}
preds = {}
for date in testing_dates:
    test_loader = test_loaders[date]
    val_loss = eval_1(model, num_node, g, test_loader, lap_pos, scaler, loss_fn, batch_size, len(test_data), seq_len)
    results[date] = val_loss
for date in testing_dates:
    print(results[date])
    print()