In [103]:
!pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter torch-sparse torch-geometric -f https://data.pyg.org/whl/torch-1.10.1+cu113.html
!pip install ../tsl
%load_ext tensorboard
%tensorboard --logdir logs

Processing /home/cerchio/Documents/code/GDL_2022/tsl
Building wheels for collected packages: torch-spatiotemporal
  Building wheel for torch-spatiotemporal (setup.py) ... [?25ldone
[?25h  Created wheel for torch-spatiotemporal: filename=torch_spatiotemporal-0.1.1e-py3-none-any.whl size=159933 sha256=aca47987e35a315da21e66d3ed8ef6e2bd1f23416efd2352253e0443bc1899c3
  Stored in directory: /tmp/pip-ephem-wheel-cache-hszoqtsh/wheels/4a/16/06/49d24bb60326d9f13e0316f97aa48f7ac973d9f710885c9d33
Failed to build torch-spatiotemporal
Installing collected packages: torch-spatiotemporal
  Attempting uninstall: torch-spatiotemporal
    Found existing installation: torch-spatiotemporal 0.1.1e
    Uninstalling torch-spatiotemporal-0.1.1e:
      Successfully uninstalled torch-spatiotemporal-0.1.1e
    Running setup.py install for torch-spatiotemporal ... [?25ldone
[?25h[33m  DEPRECATION: torch-spatiotemporal was installed using the legacy 'setup.py install' method, because a wheel could not be bui

Reusing TensorBoard on port 6006 (pid 32114), started 11:17:12 ago. (Use '!kill 32114' to kill it.)

In [104]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import tsl
from tsl.data import SpatioTemporalDataset, SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler
from tsl.datasets import MetrLA, PemsBay, Portland
from tsl.nn.metrics.metrics import MaskedMAE, MaskedMSE
from tsl.utils.neptune_utils import TslNeptuneLogger
import torch
import numpy as np
from tsl.predictors import Predictor
import tsl
import torch

In [105]:
np.set_printoptions(suppress=True)
tsl.logger.disabled = True
print(f"tsl version  : {tsl.__version__}")
print(f"torch version: {torch.__version__}")

tsl version  : 0.1.1e
torch version: 1.10.1+cu111


In [106]:
# Datasets initialization

portland_dataset = Portland()
metr_la_dataset = MetrLA()
pems_dataset = PemsBay()

portland_adj = portland_dataset.get_connectivity(threshold=0,
                                                 include_self=False,
                                                 normalize_axis=1,
                                                 layout="edge_index")


metr_la_adj = metr_la_dataset.get_connectivity(threshold=0,
                                               include_self=False,
                                               normalize_axis=1,
                                               layout="edge_index")


pems_adj = pems_dataset.get_connectivity(threshold=0,
                                         include_self=False,
                                         normalize_axis=1,
                                         layout="edge_index")

portland_edge_index, portland_edge_weight = portland_adj
metr_la_edge_index, metr_la_edge_weight = metr_la_adj
pems_edge_index, pems_edge_weight = pems_adj


In [107]:
# encode time of the day and use it as exogenous variable.
portland_exog_vars = portland_dataset.datetime_encoded('day').values
portland_exog_vars = {'global_u': portland_exog_vars}
# encode time of the day and use it as exogenous variable.
metr_la_exog_vars = metr_la_dataset.datetime_encoded('day').values
metr_la_exog_vars = {'global_u': metr_la_exog_vars}
# encode time of the day and use it as exogenous variable.
pems_exog_vars = pems_dataset.datetime_encoded('day').values
pems_exog_vars = {'global_u': pems_exog_vars}

In [108]:
from tsl.data import SpatioTemporalDataset

portland_torch = SpatioTemporalDataset(*portland_dataset.numpy(return_idx=True),
                                       connectivity=portland_adj,
                                       horizon=12,
                                       window=12,
                                       mask=portland_dataset.mask,
                                       exogenous=portland_exog_vars
                                       )

metr_la_torch = SpatioTemporalDataset(*metr_la_dataset.numpy(return_idx=True),
                                      connectivity=metr_la_adj,
                                      horizon=12,
                                      window=12,
                                      mask=metr_la_dataset.mask,
                                      exogenous=metr_la_exog_vars
                                      )

pems_torch = SpatioTemporalDataset(*pems_dataset.numpy(return_idx=True),
                                   connectivity=pems_adj,
                                   horizon=12,
                                   window=12,
                                   mask=pems_dataset.mask,
                                   exogenous=pems_exog_vars
                                   )


In [109]:
from tsl.data import SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler

scalers = {'data': StandardScaler(axis=(0, 1))}

portland_splitter = portland_dataset.get_splitter(val_len=0.1, test_len=0.2)
metr_la_splitter = metr_la_dataset.get_splitter(val_len=0.1, test_len=0.2)
pems_splitter = pems_dataset.get_splitter(val_len=0.1, test_len=0.2)


portland_dm = SpatioTemporalDataModule(
    dataset=portland_torch,
    scalers=scalers,
    splitter=portland_splitter,
    batch_size=64,
)

metr_la_dm = SpatioTemporalDataModule(
    dataset=metr_la_torch,
    scalers=scalers,
    splitter=metr_la_splitter,
    batch_size=64,
)


pems_dm = SpatioTemporalDataModule(
    dataset=pems_torch,
    scalers=scalers,
    splitter=pems_splitter,
    batch_size=64,
)

In [110]:
# Execute preprocessing
portland_dm.setup()
metr_la_dm.setup()
pems_dm.setup()

In [111]:
# Setting loss functions and metrics
loss_function = MaskedMAE()

metrics = {
    # Mean absolute error
    "mae": MaskedMAE(compute_on_step=False),
    "mae_at_15": MaskedMAE(compute_on_step=False, at=2),
    "mae_at_30": MaskedMAE(compute_on_step=False, at=5),
    "mae_at_60": MaskedMAE(compute_on_step=False, at=11),
    # Mean Square Error
    "mse": MaskedMSE(compute_on_step=False),
    "mse_at_15": MaskedMSE(compute_on_step=False, at=2),
    "mse_at_30": MaskedMSE(compute_on_step=False, at=5),
    "mse_at_60": MaskedMSE(compute_on_step=False, at=11),
}


In [112]:
# Model hyperparameters
model_kwargs = {
    "input_size": portland_torch.n_channels,
    "exog_size": portland_torch.input_map.u.n_channels,
    "hidden_size": 32,
    "ff_size": 512,
    "output_size": portland_torch.n_channels,
    "n_layers": 8,
    "horizon": 12,
    "temporal_kernel_size": 1,
    "spatial_kernel_size": 2,
    "learned_adjacency": True,
    "n_nodes": portland_dataset.n_nodes,
    "emb_size": 10,
    "dilation": 1,
    "dilation_mod": 2,
    "norm": "batch",
    "dropout": 0.3,
}

# Predictior settings
predictor = Predictor(
    model_class=tsl.nn.models.stgn.GraphWaveNetModel,
    model_kwargs=model_kwargs,
    optim_class=torch.optim.Adam,
    optim_kwargs={"lr": 0.001, "weight_decay": 0.0001},
    loss_fn=loss_function,
    metrics=metrics,
)


  rank_zero_warn(


In [16]:
# Neptune logging initialization
npt_logger = TslNeptuneLogger(api_key="API_KEY",
                              project_name="username/project",
                              experiment_name="experiment_name",
                              tags=[],
                              params=model_kwargs,
                              upload_stdout=False)


https://app.neptune.ai/matteo-maggiolo/graph/e/GRAP-29
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [113]:
# Callback to save model parameters
checkpoint_callback = ModelCheckpoint(
    dirpath="logs",
    save_top_k=1,
    monitor="val_mae",
    mode="min",
)

# Early stopping
early_stop_callback = EarlyStopping(
    monitor='val_mae',
    patience=20,
    mode='min'
)

trainer = pl.Trainer(
    max_epochs=100,
    gpus=1,
    limit_train_batches=100,
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=npt_logger,
)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [114]:
# Begin training
trainer.fit(predictor, datamodule=portland_dm)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | loss_fn       | MaskedMAE         | 0     
1 | train_metrics | MetricCollection  | 0     
2 | val_metrics   | MetricCollection  | 0     
3 | test_metrics  | MetricCollection  | 0     
4 | model         | GraphWaveNetModel | 749 K 
----------------------------------------------------
749 K     Trainable params
0         Non-trainable params
749 K     Total params
2.998     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [15]:
# Loading best model and test it
predictor.load_model(checkpoint_callback.best_model_path)
predictor.freeze()
performance = trainer.test(predictor, datamodule=pems_dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           1.7305392026901245
        test_mae            1.7305386066436768
     test_mae_at_15          1.448764443397522
     test_mae_at_30         1.7953178882598877
     test_mae_at_60         2.1269426345825195
        test_mse            14.747586250305176
     test_mse_at_15          9.119132995605469
     test_mse_at_30         15.373496055603027
     test_mse_at_60          22.11382293701172
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [17]:
# Log validation metrics
npt_logger.finalize('success')
performance[0]["df"] = "pems"
tsl.logger.info(performance)