In [1]:
from functools import partial
from pathlib import Path

import dgl
import torch
from torch.utils.data import DataLoader

from graph_traffic.config import project_path
from graph_traffic.dataloading import graph_dataset, npzDataset
from graph_traffic.dcrnn import DiffConv
from graph_traffic.model import GraphRNN
# NOTE: this `eval` shadows the Python builtin for the rest of the notebook.
from graph_traffic.train import train, eval
from graph_traffic.utils import NormalizationLayer, masked_mae_loss

  from .autonotebook import tqdm as notebook_tqdm


## 0. Define training parameters

In [2]:
# Every tunable knob for this run lives in this cell.
n_points = 10000        # forwarded to npzDataset for each split
dataset_name = "madrid"
batch_size = 64
diffsteps = 2           # diffusion steps, passed to DiffConv as `k`
decay_steps = 2000      # forwarded to GraphRNN
lr = 0.01
minimum_lr = 2e-6       # forwarded to train()
epochs = 100
max_grad_norm = 5.0
num_workers = 0
model = "dcrnn"         # architecture name: "dcrnn" (the "gaan" branch is not available)
gpu = -1                # -1 selects the CPU, otherwise a CUDA device index
num_heads = 2 # relevant for model="gaan"
out_feats = 256
num_layers = 2
seed = 42

# Seed torch's default generator so weight initialisation and DataLoader
# shuffling are the same on every Restart & Run All.
torch.manual_seed(seed)

In [3]:
# Choose where tensors live: gpu == -1 means CPU, any other value is a CUDA index.
device = torch.device('cpu') if gpu == -1 else torch.device(f"cuda:{gpu}")

## 1. Load data

In [4]:
# Graph for the configured dataset plus the three sample splits. n_points is
# forwarded to every split (presumably a cap on samples per split — TODO confirm).
g = graph_dataset(dataset_name)
train_data, test_data, valid_data = (
    npzDataset(dataset_name, split, n_points) for split in ("train", "test", "valid")
)

# Model dimensions come from the training tensor: axis 1 is used as the input
# sequence length and the last axis as the per-node input feature count.
seq_len, in_feats = train_data.x.shape[1], train_data.x.shape[-1]

In [5]:
# Raw training tensor shape. The previous cell reads axis 1 as seq_len and the last
# axis as in_feats; axis 0 is presumably the sample index and axis 2 the number of
# graph nodes — TODO confirm against npzDataset.
train_data.x.shape

(4993, 12, 5, 2)

In [6]:
# Only the training split needs shuffling. Valid/test are iterated in a fixed
# order so their per-epoch losses are reproducible and comparable across epochs
# (shuffling evaluation data changes nothing useful but makes batch composition,
# including the short final batch, vary from run to run).
# With the 4993 training samples shown above, 4993 = 64 * 78 + 1, so the last
# training batch holds a single sample; presumably train()/eval() pad it up to
# batch_size — TODO confirm.
train_loader = DataLoader(
    train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
valid_loader = DataLoader(
    valid_data, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers, shuffle=False)

# Normalisation statistics come from the training split only, so nothing is
# estimated from valid/test data.
normalizer = NormalizationLayer(train_data.mean, train_data.std)

## 2. Define the model

In [7]:
# `net` is a partially-applied graph-convolution constructor handed to GraphRNN.
# For DCRNN the diffusion graphs are precomputed once on a batched graph made of
# `batch_size` copies of `g` (presumably train()/eval() pad a short final batch
# up to that size — TODO confirm).
if model == "dcrnn":
    batch_g = dgl.batch([g] * batch_size).to(device)
    out_gs, in_gs = DiffConv.attach_graph(batch_g, diffsteps)
    net = partial(DiffConv, k=diffsteps, in_graph_list=in_gs, out_graph_list=out_gs)
elif model == "gaan":
    # Not wired up yet (it would use `num_heads`). Fail here rather than falling
    # through to GraphRNN with `net` undefined, or silently reusing a `net` left
    # over from an earlier run in the same kernel.
    raise NotImplementedError('model="gaan" is not available yet')
else:
    raise ValueError(f'Unknown model {model!r}; expected "dcrnn" or "gaan"')

# The variable keeps the name `dcrnn` because the training and save cells use it.
dcrnn = GraphRNN(in_feats=in_feats,
                 out_feats=out_feats,
                 seq_len=seq_len,
                 num_layers=num_layers,
                 net=net,
                 decay_steps=decay_steps).to(device)

## 3. Define learning parameters

In [8]:
# Objective: masked_mae_loss from graph_traffic.utils.
loss_fn = masked_mae_loss

# Adam over all model parameters. ExponentialLR multiplies the learning rate by
# 0.99 on each scheduler.step(); train() also receives minimum_lr.
optimizer = torch.optim.Adam(dcrnn.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

## 4. Train model

After the sigmoid -> tanh change (*Después de hacer el cambio sigmoid -> tanh*). This run was interrupted manually after epoch 13.

In [None]:
# Training run recorded after the sigmoid -> tanh change.
# NOTE(review): the "before" cell further down repeats this exact loop on the same
# `dcrnn`, optimizer and scheduler, so Restart & Run All trains for 2 * epochs.
for epoch in range(epochs):
    train(dcrnn, g, train_loader, optimizer, scheduler, normalizer, loss_fn,
          device, batch_size, max_grad_norm, minimum_lr)
    # Score every split (including train) with the eval() helper after each epoch.
    train_loss, valid_loss, test_loss = (
        eval(dcrnn, g, loader, normalizer, loss_fn, device, batch_size)
        for loader in (train_loader, valid_loader, test_loader)
    )
    print("Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(
        epoch, train_loss, valid_loss, test_loss))

Epoch: 8 Train Loss: 266.5638197057833 Valid Loss: 262.79227326094036 Test Loss: 268.1165119622273
Epoch: 9 Train Loss: 266.5960756501041 Valid Loss: 262.6397631446307 Test Loss: 268.3928483442049
Epoch: 10 Train Loss: 265.90474781037557 Valid Loss: 262.5918209236464 Test Loss: 269.1274005343653
Epoch: 11 Train Loss: 267.8067512259211 Valid Loss: 262.5564204427837 Test Loss: 268.3231428251254
Epoch: 12 Train Loss: 266.7422842779866 Valid Loss: 262.5454366798092 Test Loss: 269.5108677077794
Epoch: 13 Train Loss: 266.65345807973296 Valid Loss: 262.5086409722154 Test Loss: 267.63029021662504
Batch:  78

KeyboardInterrupt: 

Before the sigmoid -> tanh change (*Antes de hacer el cambio sigmoid -> tanh*). This run was interrupted manually after epoch 0.

In [9]:
# Training run recorded before the sigmoid -> tanh change.
# NOTE(review): identical to the loop above. Run after it, this continues training
# the existing `dcrnn` (same optimizer/scheduler state) rather than starting fresh.
for epoch in range(epochs):
    train(dcrnn, g, train_loader, optimizer, scheduler, normalizer, loss_fn,
          device, batch_size, max_grad_norm, minimum_lr)
    # Score every split (including train) with the eval() helper after each epoch.
    train_loss, valid_loss, test_loss = (
        eval(dcrnn, g, loader, normalizer, loss_fn, device, batch_size)
        for loader in (train_loader, valid_loader, test_loader)
    )
    print("Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(
        epoch, train_loss, valid_loss, test_loss))



Batch:  0
Batch:  1
Batch:  2
Batch:  3
Batch:  4
Batch:  5
Batch:  6
Batch:  7
Batch:  8
Batch:  9
Batch:  10
Batch:  11
Batch:  12
Batch:  13
Batch:  14
Batch:  15
Batch:  16
Batch:  17
Batch:  18
Batch:  19
Batch:  20
Batch:  21
Batch:  22
Batch:  23
Batch:  24
Batch:  25
Batch:  26
Batch:  27
Batch:  28
Batch:  29
Batch:  30
Batch:  31
Batch:  32
Batch:  33
Batch:  34
Batch:  35
Batch:  36
Batch:  37
Batch:  38
Batch:  39
Batch:  40
Batch:  41
Batch:  42
Batch:  43
Batch:  44
Batch:  45
Batch:  46
Batch:  47
Batch:  48
Batch:  49
Batch:  50
Batch:  51
Batch:  52
Batch:  53
Batch:  54
Batch:  55
Batch:  56
Batch:  57
Batch:  58
Batch:  59
Batch:  60
Batch:  61
Batch:  62
Batch:  63
Batch:  64
Batch:  65
Batch:  66
Batch:  67
Batch:  68
Batch:  69
Batch:  70
Batch:  71
Batch:  72
Batch:  73
Batch:  74
Batch:  75
Batch:  76
Batch:  77
Batch:  78
Epoch: 0 Train Loss: 417.7223258760343 Valid Loss: 411.6834200720523 Test Loss: 434.8402419671191
Batch:  0
Batch:  1
Batch:  2
Batch:  3
Bat

KeyboardInterrupt: 

## 5. Save model

In [None]:
# Persist the trained weights (state_dict only). Create `models/` first so a fresh
# checkout without that directory does not fail here. The file is named after the
# configured architecture, which gives models/dcrnn.pt for the default settings.
model_dir = Path(project_path) / "models"
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(dcrnn.state_dict(), model_dir / f"{model}.pt")