In [None]:
import os
import pickle as pk
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from utils import load_metr_la_data, normalize_data, split_dataset
from stgcn import STGCN
from train_epoch import train_epoch

# Seed PyTorch's default generator first so every later cell is reproducible.
torch.manual_seed(7)

# Window lengths (in time steps) handed to split_dataset and the STGCN constructor.
num_timesteps_input, num_timesteps_output = 12, 3

# Training schedule.
epochs, batch_size = 10, 50

# Hardware toggle.
# NOTE(review): no visible cell reads this flag -- confirm whether it is still needed.
use_gpu = False

In [2]:
# Load the raw adjacency matrix `A` and node signals `X`, then normalize them.
# `means` / `stds` are kept so that predictions can be mapped back to original units.
A, X = load_metr_la_data()
A_wave, X, means, stds = normalize_data(A, X)

# Cut the signal into (input window, target window) pairs for each split.
data = split_dataset(X, num_timesteps_input, num_timesteps_output)

# Shapes recorded by the original author:
#   train_input ([20549, 207, 12, 2])   train_target ([20549, 207, 3])
#   val_input   ([6840, 207, 12, 2])    val_target   ([6840, 207, 3])
#   test_input  ([6841, 207, 12, 2])    test_target  ([6841, 207, 3])
(train_input, train_target,
 val_input, val_target,
 test_input, test_target) = (
    data['train_input'], data['train_target'],
    data['val_input'], data['val_target'],
    data['test_input'], data['test_target'],
)

In [3]:
# Pick CUDA whenever it is available, otherwise fall back to the CPU.
# NOTE(review): the `use_gpu` flag from the config cell is not consulted here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The adjacency matrix must live on the same device as the model.
A_wave = A_wave.to(device)

# Build the network from the data dimensions and the configured window lengths.
model = STGCN(
    A_wave.shape[0],        # nodes
    train_input.shape[3],   # features
    num_timesteps_input,
    num_timesteps_output,
).to(device)

# Adam on all model parameters, trained against mean-squared error.
loss_criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Train for `epochs` epochs, evaluating on the validation split after each one.
#
# Produces three per-epoch histories (training loss, validation loss, validation MAE
# in un-normalized units) and rewrites them to checkpoints/losses.pk after every epoch
# so that progress survives an interrupted run.
training_losses = []
validation_losses = []
validation_maes = []

# Create the output directory once up front instead of re-checking it every epoch.
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
losses_file = os.path.join(checkpoint_path, "losses.pk")

for epoch in range(epochs):
    # The validation pass below leaves the model in eval mode; switch back before
    # training. This is a harmless no-op if train_epoch already does it itself.
    model.train()

    # 1 epoch training takes ~5 minutes
    loss = train_epoch(A_wave=A_wave, model=model, loss_criterion=loss_criterion, optimizer=optimizer,
                       train_input=train_input, train_target=train_target,
                       batch_size=batch_size, device=device)  # batch averaged loss of one epoch
    training_losses.append(loss)

    # Run validation
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for evaluation
        # Use separate names for the device copies so the notebook-level
        # val_input / val_target are never left on the GPU if this cell is interrupted.
        val_input_dev = val_input.to(device=device)
        val_target_dev = val_target.to(device=device)

        prediction = model(A_wave, val_input_dev)
        val_loss = loss_criterion(prediction, val_target_dev)
        validation_losses.append(val_loss.item())

        # Undo the normalization with the first entry of stds / means before computing MAE.
        prediction_unnormalized = prediction.detach().cpu().numpy() * stds[0] + means[0]
        target_unnormalized = val_target_dev.detach().cpu().numpy() * stds[0] + means[0]
        mae = np.mean(np.absolute(prediction_unnormalized - target_unnormalized))
        # Store a plain float so the pickle holds only builtin numbers.
        validation_maes.append(float(mae))

        # Drop the large temporaries before the next training epoch.
        prediction_unnormalized = None
        del val_input_dev, val_target_dev

    # Report on every 10th epoch and on the final one. The original tested and printed
    # the constant `epochs`, so the condition never varied and the number was wrong.
    if epoch % 10 == 0 or epoch == epochs - 1:
        print("epoch: ", epoch,
              "Training loss: {}".format(training_losses[-1]),
              "Validation loss: {}".format(validation_losses[-1]))

    # Persist the histories after every epoch.
    with open(losses_file, "wb") as fd:
        pk.dump((training_losses, validation_losses, validation_maes), fd)

In [None]:
# Draw the per-epoch training and validation loss curves on one set of axes.
for curve, curve_label in (
    (training_losses, "training loss"),
    (validation_losses, "validation loss"),
):
    plt.plot(curve, label=curve_label)

# Label the axes and show the legend so the figure stands on its own.
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()