In [1]:
#  Copyright 2022 Institute of Advanced Research in Artificial Intelligence (IARAI) GmbH.
#  IARAI licenses this file to You under the Apache License, Version 2.0
#  (the "License"); you may not use this file except in compliance with
#  the License. You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

In [2]:
import os
import sys

In [3]:
# Make the repo root importable: prepend the parent of the current working directory
# (normally this notebook's folder) to sys.path so `import t4c22` resolves.
# Alternatively, to make the module imports work, set PYTHONPATH=$PWD before launching the notebook server from the repo root folder.
sys.path.insert(0, os.path.abspath("../"))  # noqa:E402

![t4c20logo](../t4c20logo.png)

In [4]:
import statistics
from collections import defaultdict

import pandas as pd
import torch
import torch.nn.functional as F
import torch_geometric
import tqdm
from IPython.core.display import HTML
from IPython.display import display
from torch import nn
from torch_geometric.nn import MessagePassing
from pathlib import Path
import numpy as np

import t4c22
from t4c22.metric.masked_crossentropy import get_weights_from_class_fractions
from t4c22.misc.t4c22_logging import t4c_apply_basic_logging_config
from t4c22.t4c22_config import class_fractions
from t4c22.t4c22_config import load_basedir
from t4c22.dataloading.t4c22_dataset_geometric import T4c22GeometricDataset
from t4c22.plotting.plot_congestion_classification import plot_segment_classifications_simple
from t4c22.misc.notebook_helpers import restartkernel  # noqa:F401


In [5]:
%matplotlib inline
%load_ext autoreload
%load_ext time
%autoreload 2
%autosave 60
display(HTML("<style>.container { width:80% !important; }</style>"))

The time module is not an IPython extension.


Autosaving every 60 seconds


In [6]:
# Configure logging through the t4c22 helper with DEBUG verbosity.
t4c_apply_basic_logging_config(loglevel="DEBUG")

In [7]:
# Load BASEDIR (the data root) from t4c22_config.json, looked up relative to the t4c22 package;
# edit that file to point at your local data root.
BASEDIR = load_basedir(fn="t4c22_config.json", pkg=t4c22)

In [8]:
# City to train on; uncomment one of the alternatives to switch.
city = "london"
# city = "melbourne"
# city = "madrid"

## Torch Geometric Dataset

In [9]:
%%time
# Full "train" split for `city`; processed samples are cached under `cachedir` (see the timing cell below).
# NOTE(review): the same cachedir is used for every city — confirm the dataset keys its cache files by city.
dataset = T4c22GeometricDataset(root=BASEDIR, city=city, split="train", cachedir=Path("/tmp/processed"))
# Debug variant, presumably loading only the first `limit` samples into a separate cache dir:
# train_dataset = T4c22GeometricDataset(root=BASEDIR, city=city, split="train", cachedir=Path("/tmp/processed5"), limit=1000)

CPU times: user 927 ms, sys: 67.9 ms, total: 995 ms
Wall time: 964 ms


In [10]:
# Number of samples in the full "train" split.
len(dataset)

7040

In [11]:
%%time
# Timing check for one sample: measured at ~2.41s originally vs ~2.35ms once it is served from cachedir.
dataset.get(0)

CPU times: user 1.51 ms, sys: 0 ns, total: 1.51 ms
Wall time: 1.27 ms


Data(x=[59110, 4], edge_index=[2, 132414], y=[132414])

In [12]:
# Train/validation split index: 80% of the samples, rounded down to an even number.
spl = 2 * int(0.8 * len(dataset) // 2)
spl

5632

In [13]:
# Contiguous index split: the first `spl` samples are used for training, the rest for validation.
train_dataset = torch.utils.data.Subset(dataset, range(spl))
val_dataset = torch.utils.data.Subset(dataset, range(spl, len(dataset)))
train_dataset, val_dataset

(<torch.utils.data.dataset.Subset at 0x7f0dae03cb50>,
 <torch.utils.data.dataset.Subset at 0x7f0e053d2730>)

In [14]:
# Sanity check: the first training sample and the first validation sample differ.
# NaNs are replaced by a sentinel first; otherwise NaN entries always compare unequal and the check would pass trivially.
assert not (train_dataset[0].x.nan_to_num(-33) == val_dataset[0].x.nan_to_num(-33)).all()

In [15]:
# (day, t) entry of the first training sample in the underlying dataset.
train_dataset.dataset.day_t[train_dataset.indices[0]]

('2019-12-04', 24)

In [16]:
# (day, t) entry of the first validation sample in the underlying dataset.
val_dataset.dataset.day_t[val_dataset.indices[0]]

('2019-08-03', 24)

In [17]:
# Both Subsets wrap the same full dataset, so this shows its length twice — not the split sizes
# (use len(train_dataset) / len(val_dataset) for those).
len(train_dataset.dataset.day_t), len(val_dataset.dataset.day_t)

(7040, 7040)

In [18]:
# Number of training samples (== spl).
len(train_dataset)

5632

In [19]:
# Number of validation samples (== len(dataset) - spl).
len(val_dataset)

1408

## GNN Model

In [38]:
class Swish(nn.Module):
    """Swish activation ``x * sigmoid(beta * x)``.

    Parameters
    ----------
    beta : float
        Scale applied to the input inside the sigmoid; ``beta=1`` is the plain SiLU.
    """

    def __init__(self, beta=1):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        gate = torch.sigmoid(x * self.beta)
        return gate * x


class GNN_Layer(MessagePassing):
    """Residual message-passing layer with mean aggregation.

    For every edge a message ``message_net([x_i, x_j])`` is computed, where ``x_i`` / ``x_j`` are the
    features of the edge's target / source node (PyG's default flow). The messages are mean-aggregated
    per node, and every node is updated as ``x + update_net([x, aggregated_message])``.

    Parameters
    ----------
    in_features : int
        Dimensionality of input features.
    out_features : int
        Dimensionality of output features. Because of the residual connection in ``update`` this
        must currently equal ``in_features``.
    hidden_features : int
        Width of the hidden layer inside ``message_net`` and ``update_net``.
    """

    def __init__(self, in_features, out_features, hidden_features):
        super(GNN_Layer, self).__init__(node_dim=-2, aggr="mean")

        self.message_net = nn.Sequential(
            nn.Linear(2 * in_features, hidden_features),
            Swish(),
            nn.BatchNorm1d(hidden_features),
            nn.Linear(hidden_features, out_features),
            Swish(),
        )
        # update_net consumes [x, aggregated message]; messages have `out_features` channels
        # (the last Linear of message_net), so the input width is in_features + out_features.
        self.update_net = nn.Sequential(
            nn.Linear(in_features + out_features, hidden_features),
            Swish(),
            nn.Linear(hidden_features, out_features),
            Swish(),
        )

    def forward(self, x, edge_index, batch):
        """Propagate messages along edges.

        ``batch`` is accepted for interface compatibility with the caller but is not used.
        """
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Message for each edge from the concatenated target and source features."""
        return self.message_net(torch.cat((x_i, x_j), dim=-1))

    def update(self, message, x):
        """Node update: residual on top of ``x`` using the mean-aggregated ``message``.

        Out-of-place on purpose: ``x`` is the caller's tensor, and an in-place add would mutate it
        (and fails outright on leaf tensors that require grad).
        """
        return x + self.update_net(torch.cat((x, message), dim=-1))


class CongestioNN(torch.nn.Module):
    """Node encoder: a linear embedding followed by ``hidden_layer`` stacked ``GNN_Layer`` blocks.

    Parameters
    ----------
    in_features : int
        Number of raw node features (columns of ``data.x``).
    out_features : int
        Width of the node embeddings produced by the embedding and by every GNN layer.
    hidden_features : int
        Hidden width inside each GNN layer and inside the two heads.
    hidden_layer : int
        Number of stacked GNN layers.

    ``head_pre_pool`` and ``head_post_pool`` are constructed (and therefore part of the state dict)
    but are not used by ``forward``.
    """

    def __init__(self, in_features=4, out_features=32, hidden_features=32, hidden_layer=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.hidden_layer = hidden_layer

        # Every GNN layer maps out_features -> out_features: layer input and output widths must match for now.
        # Submodules are created in the same order as before, so initialisation and state-dict keys are unchanged.
        gnn_layers = [GNN_Layer(out_features, out_features, hidden_features) for _ in range(hidden_layer)]
        self.cgnn = torch.nn.ModuleList(gnn_layers)

        self.head_pre_pool = nn.Sequential(
            nn.Linear(out_features, hidden_features),
            Swish(),
            nn.Linear(hidden_features, hidden_features),
        )
        self.head_post_pool = nn.Sequential(
            nn.Linear(hidden_features, hidden_features),
            Swish(),
            nn.Linear(hidden_features, 1),
        )

        self.embedding_mlp = nn.Sequential(nn.Linear(in_features, out_features))

    def forward(self, data):
        """Embed ``data.x`` and pass it through every GNN layer in order; returns the node embeddings."""
        node_repr = self.embedding_mlp(data.x)
        for gnn_layer in self.cgnn:
            node_repr = gnn_layer(node_repr, data.edge_index, data.batch)
        return node_repr


class LinkPredictor(torch.nn.Module):
    """Edge classifier: elementwise product of the two endpoint embeddings followed by an MLP.

    Parameters
    ----------
    in_channels : int
        Width of each endpoint embedding.
    hidden_channels : int
        Width of the hidden layers.
    out_channels : int
        Number of output logits per edge.
    num_layers : int
        Number of Linear layers; values below 2 still yield two layers (in -> hidden -> out).
    dropout : float
        Dropout probability applied after every hidden activation while training.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__()

        # Layer widths: in -> hidden -> (hidden -> hidden) * (num_layers - 2) -> out.
        dims = [in_channels, hidden_channels] + [hidden_channels] * max(num_layers - 2, 0) + [out_channels]
        self.lins = torch.nn.ModuleList(torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

        self.swish = Swish()

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every Linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return raw logits for each edge given the embeddings of its two endpoints."""
        h = x_i * x_j
        *hidden_layers, output_layer = self.lins
        for layer in hidden_layers:
            h = F.dropout(self.swish(layer(h)), p=self.dropout, training=self.training)
        return output_layer(h)

## Training

In [39]:
# Fractions of the green / yellow / red congestion classes for this city, from t4c22_config.
city_class_fractions = class_fractions[city]
city_class_fractions

{'green': 0.5367906303432076,
 'yellow': 0.35138063340805714,
 'red': 0.11182873624873524}

In [40]:
# Loss weights for classes in the order green=0, yellow=1, red=2.
# NOTE(review): this order presumably has to match the label encoding in data.y — confirm.
city_class_weights = torch.tensor(
    get_weights_from_class_fractions([city_class_fractions[colour] for colour in ("green", "yellow", "red")])
).float()
city_class_weights

tensor([0.6210, 0.9486, 2.9807])

In [41]:
def train(model, predictor, dataset, optimizer, batch_size, device, class_weights=None, num_workers=16):
    """Train ``model`` and ``predictor`` for one epoch over ``dataset``.

    Parameters
    ----------
    model : torch.nn.Module
        Node encoder, called as ``model(data)``; its output is indexed by ``data.edge_index``.
    predictor : torch.nn.Module
        Edge classifier, called on the embeddings of both endpoints of every edge.
    dataset : torch Dataset of torch_geometric ``Data`` samples
        Samples with ``x``, ``edge_index`` and per-edge labels ``y`` (NaN = no label).
    optimizer : torch.optim.Optimizer
        Optimizer over the parameters of ``model`` and ``predictor``.
    batch_size : int
        Number of graphs per batch.
    device : torch.device or str
        Device the batches are moved to.
    class_weights : torch.Tensor, optional
        Per-class cross-entropy weights; defaults to the notebook-global ``city_class_weights``.
    num_workers : int, optional
        Number of DataLoader worker processes (default 16).

    Returns
    -------
    list of float
        The loss of every batch, in iteration order.
    """
    if class_weights is None:
        class_weights = city_class_weights

    # Both modules must be in training mode; a previous evaluation may have left them in eval mode.
    model.train()
    predictor.train()

    # Unlabelled edges are encoded as -1 below and excluded from the loss.
    loss_f = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)
    loader = torch_geometric.loader.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    losses = []
    optimizer.zero_grad()

    # len(loader) includes the final partial batch, unlike len(dataset) // batch_size.
    for data in tqdm.notebook.tqdm(loader, "train", total=len(loader)):
        data = data.to(device)
        # Missing node features (NaN) are replaced by -1 before they reach the model.
        data.x = data.x.nan_to_num(-1)

        h = model(data)
        assert not h.isnan().any(), f"model produced {int(h.isnan().sum())} NaN values"

        # Gather the embeddings of both endpoints of every edge.
        x_i = torch.index_select(h, 0, data.edge_index[0])
        x_j = torch.index_select(h, 0, data.edge_index[1])
        y_hat = predictor(x_i, x_j)

        y = data.y.nan_to_num(-1).long()
        loss = loss_f(y_hat, y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())

    return losses

In [42]:
@torch.no_grad()
def test(model, predictor, validation_dataset, batch_size, device, class_weights=None):
    """Evaluate ``model`` + ``predictor`` on ``validation_dataset`` and return the weighted cross-entropy.

    The result equals ``CrossEntropyLoss(weight=class_weights, ignore_index=-1)`` applied to the
    concatenated edge predictions of all samples (a class-weighted mean over every labelled edge).
    It is accumulated sample by sample, so the logits of the whole validation set are never held
    in memory at once.

    Parameters
    ----------
    model : torch.nn.Module
        Node encoder, called as ``model(data)``.
    predictor : torch.nn.Module
        Edge classifier, called on the embeddings of both endpoints of every edge.
    validation_dataset : indexable dataset of torch_geometric ``Data`` samples
        Evaluated one sample at a time.
    batch_size : int
        Unused (samples are evaluated one at a time); kept for interface compatibility.
    device : torch.device or str
        Device the samples are moved to.
    class_weights : torch.Tensor, optional
        Per-class cross-entropy weights; defaults to the notebook-global ``city_class_weights``.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the validation loss (NaN if no edge carries a label).
    """
    if class_weights is None:
        class_weights = city_class_weights

    # Put both modules in eval mode so dropout / batch-norm layers use inference behaviour.
    model.eval()
    predictor.eval()

    # With reduction="sum" the loss is sum_i w[y_i] * nll_i over labelled edges. Dividing by
    # sum_i w[y_i] at the end reproduces the default "mean" reduction over all samples combined.
    loss_f = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1, reduction="sum")
    weighted_nll_sum = torch.zeros((), device=device)
    weight_sum = torch.zeros((), device=device)

    for data in tqdm.notebook.tqdm(validation_dataset, "test", total=len(validation_dataset)):
        data = data.to(device)
        data.x = data.x.nan_to_num(-1)
        h = model(data)

        x_i = torch.index_select(h, 0, data.edge_index[0])
        x_j = torch.index_select(h, 0, data.edge_index[1])
        y_hat = predictor(x_i, x_j)

        # Unlabelled edges (NaN) become -1 and are ignored by both the loss and the normaliser.
        y = data.y.nan_to_num(-1).long()
        weighted_nll_sum += loss_f(y_hat, y)
        weight_sum += class_weights[y[y != -1]].sum()

    total_loss = weighted_nll_sum / weight_sum
    print(f"total losses {total_loss}")
    return total_loss

In [43]:
hidden_channels = 256  # node-embedding width; also the hidden width of the GNN layers and the predictor
num_layers = 10  # number of GNN layers in CongestioNN and of Linear layers in LinkPredictor
batch_size = 2  # graphs per training batch
eval_steps = 1  # validate (and save checkpoints) every `eval_steps` epochs
epochs = 20  # epochs per run
runs = 1  # number of training runs
dropout = 0.0  # dropout probability between the LinkPredictor layers
num_edge_classes = 3  # green / yellow / red
num_features = 4  # node feature dimension (data.x has 4 columns)

In [88]:
gpu_index = 0
device = torch.device(f"cuda:{gpu_index}" if torch.cuda.is_available() else "cpu")

# The loss in train()/test() uses this tensor, so it must live on the same device as the logits.
city_class_weights = city_class_weights.to(device)

train_losses = defaultdict(list)  # (run, epoch) -> list of per-batch training losses
val_losses = defaultdict(lambda: -1)  # (run, epoch) -> validation loss; -1 if not evaluated

for run in tqdm.notebook.tqdm(range(runs), desc="runs", total=runs):
    # Seed per run so initialisation and shuffling are reproducible.
    torch.manual_seed(run)

    # Fresh model and predictor for every run so runs are independent: CongestioNN has no
    # reset_parameters(), so re-instantiating is the only way to reset it.
    model = CongestioNN(num_features, hidden_channels, hidden_channels, num_layers).to(device)
    predictor = LinkPredictor(hidden_channels, hidden_channels, num_edge_classes, num_layers, dropout).to(device)

    optimizer = torch.optim.AdamW(
        [{"params": model.parameters()}, {"params": predictor.parameters()}],
        lr=5e-4,
        weight_decay=0.001,
    )

    for epoch in tqdm.notebook.tqdm(range(1, 1 + epochs), "epochs", total=epochs):
        losses = train(model, predictor, dataset=train_dataset, optimizer=optimizer, batch_size=batch_size, device=device)
        train_losses[(run, epoch)] = losses
        print(statistics.mean(losses))

        if epoch % eval_steps == 0:
            val_loss = test(model, predictor, validation_dataset=val_dataset, batch_size=batch_size, device=device)
            val_losses[(run, epoch)] = val_loss
            print(f"val_loss={val_loss} after epoch {epoch} of run {run}")

            # Keep the historical checkpoint names for a single run; with several runs, tag the
            # files with the run so later runs do not overwrite earlier ones.
            tag = f"{epoch:03d}" if runs == 1 else f"run{run:02d}_{epoch:03d}"
            torch.save(model.state_dict(), f"GNN_model_{tag}.pt")
            torch.save(predictor.state_dict(), f"GNN_predictor_{tag}.pt")

runs:   0%|          | 0/1 [00:00<?, ?it/s]

epochs:   0%|          | 0/19 [00:00<?, ?it/s]

train:   0%|          | 0/2816 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Mean training loss for every (run, epoch) key.
for run_epoch, batch_losses in train_losses.items():
    print(run_epoch)
    print(statistics.mean(batch_losses))

In [None]:
# Validation loss recorded for every evaluated (run, epoch) key.
for run_epoch, epoch_val_loss in val_losses.items():
    print(run_epoch)
    print(epoch_val_loss)

In [None]:
# Uncomment to free the kernel's resources (e.g. GPU memory held by the model) by restarting the kernel.
# restartkernel()