In [1]:
import math
import numpy as np
import wandb

import torch
import torch_geometric
from torch_geometric.data import Data

import sys
import os
from tqdm import tqdm

# Add the parent directory ('..', resolved to an absolute path) to the Python
# import path, once. NOTE(review): presumably this is the directory that holds
# the project-local modules imported below (gnn_io, gnn_architectures) - confirm.
scripts_path = os.path.abspath(os.path.join('..'))
if scripts_path not in sys.path:
    sys.path.append(scripts_path)
    
import joblib

# Now you can import the gnn_io module
import gnn_io as gio

import gnn_architectures as garch

## 1. Define model and parameters

In [2]:
# Run configuration: every value below is a notebook-level global that later
# cells read by name.

# Upper bound on the number of training epochs (logged to wandb as "epochs").
num_epochs = 1000
# Identifier of this experiment; used as the wandb project name and as a
# suffix in the file names of the saved test dataloader and model.
unique_model_description = "mse_just_on_highways"
project_name = unique_model_description
# Directory prefix (note the trailing slash: it is string-concatenated with
# file names later on).
path_to_save_dataloader = "../../data/data_created_during_training_needed_for_testing/"
# Passed to gio.cut_dimensions as `indices_of_dimensions_to_keep`; its length
# also determines the model's input size.
indices_of_datasets_to_use = [0, 1, 2, 3]
batch_size = 16

# List of dictionaries with the keys 'x', 'edge_index', 'pos' and 'y' (these
# are the keys read when the Data objects are rebuilt).
# NOTE(review): torch.load unpickles the file - only load files you trust.
data_dict_list = torch.load('../../data/train_data/dataset_1pm_0-2600_16_07.pt')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Early-stopping helper that is called with the validation loss after every
# epoch and exposes an `early_stop` flag.
early_stopping = gio.EarlyStopping(patience=30, verbose=True)
# Print tensors with 4 decimals and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

## 2. Load data

In [3]:
# Rebuild one torch_geometric `Data` graph per stored dictionary, copying the
# four stored fields over unchanged.
datalist = [
    Data(**{field: sample[field] for field in ('x', 'edge_index', 'pos', 'y')})
    for sample in data_dict_list
]

# Restrict the dataset to the configured dimensions, then normalise it.
dataset_only_relevant_dimensions = gio.cut_dimensions(
    dataset=datalist,
    indices_of_dimensions_to_keep=indices_of_datasets_to_use,
)
dataset_normalized = gio.normalize_dataset(dataset_only_relevant_dimensions)

In [4]:
# Reference errors on the normalised dataset, printed for comparison with the
# validation loss reported during training.
# NOTE(review): presumably the first baseline is the error of predicting
# "no policy effect" and the second the error of always predicting the mean
# target - confirm against gnn_io.
baseline_error = gio.compute_baseline_of_no_policies(dataset_normalized)
print(f'Baseline error no policies: {baseline_error}')

# `baseline_error` is overwritten here; after this cell it holds the
# mean-target baseline only.
baseline_error = gio.compute_baseline_of_mean_target(dataset_normalized)
print(f'Baseline error mean: {baseline_error}')

Baseline error no policies: 0.32162734866142273
Baseline error mean: 0.0032576548401266336


## 3. Train the model

We first find a good model for one batch. 

In [5]:
# Split the normalised dataset 70 % / 15 % / 15 % into train / validation /
# test dataloaders.
train_dl, valid_dl, test_dl = gio.create_dataloaders(batch_size = batch_size, dataset=dataset_normalized, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)
# Persist the test dataloader and its parameters; the file names carry the
# experiment identifier. NOTE(review): presumably so that a separate testing
# step can evaluate on exactly this split - confirm.
gio.save_dataloader(test_dl, path_to_save_dataloader + 'test_dl' + unique_model_description + '.pt')
gio.save_dataloader_params(test_dl, path_to_save_dataloader + 'test_loader_params' + unique_model_description+ '.json')

Total dataset length: 2580
Training subset length: 1805
Validation subset length: 387
Test subset length: 388


In [6]:
# Scaler used by the training loop to map the values of feature column 3 back
# to their original (un-normalised) highway codes.
# NOTE(review): the path is relative to the current working directory, and
# joblib.load unpickles the file - only load files you trust. Presumably the
# file is written by the normalisation step - confirm.
scaler_highways = joblib.load('x_scaler_dim_3.pkl')

In [7]:
# Integer code per road type, written as (code, road types sharing that code).
# The insertion order of the resulting dictionary follows this listing.
highway_mapping = {
    road_type: code
    for code, road_types in (
        (0, ('trunk', 'trunk_link', 'motorway_link')),
        (1, ('primary', 'primary_link')),
        (2, ('secondary', 'secondary_link')),
        (3, ('tertiary', 'tertiary_link')),
        (4, ('residential',)),
        (5, ('living_street',)),
        (6, ('pedestrian',)),
        (7, ('service',)),
        (8, ('construction',)),
        (9, ('unclassified',)),
        (-1, ('np.nan',)),
    )
    for road_type in road_types
}

# Road types the training loss is restricted to.
higher_order_roads = [
    'primary', 'primary_link',
    'secondary', 'secondary_link',
    'tertiary', 'tertiary_link',
]

# Set of integer codes belonging to the road types above (names that are not
# in the mapping contribute nothing).
higher_order_values = {
    code
    for road_type, code in highway_mapping.items()
    if road_type in higher_order_roads
}

In [8]:
def get_indices_of_higher_order_roads(input_tensor, scaler, values_to_filter_for, tolerance=1e-3):
    """Return the row indices whose de-normalised highway code is wanted.

    Column 3 of ``input_tensor`` is mapped back to its original scale with
    ``scaler.inverse_transform`` and every row whose value matches one of
    ``values_to_filter_for`` is selected.

    Args:
        input_tensor: 2-D tensor with at least four columns; column 3 holds
            the scaled highway code.
        scaler: Object with an ``inverse_transform`` method that accepts an
            array of shape ``(n, 1)`` (e.g. a fitted scikit-learn scaler).
        values_to_filter_for: Iterable of highway codes to select.
        tolerance: Maximum absolute difference for a de-normalised value to
            count as equal to a wanted code. The inverse transform is a
            floating-point operation, so a code such as ``1`` can come back
            as ``0.99999994``; an exact comparison would silently drop it.

    Returns:
        1-D ``torch.int64`` tensor of matching row indices, in ascending
        order, located on the same device as ``input_tensor`` so that it can
        be passed straight to ``torch.index_select``.
    """
    # Nothing to transform; most scalers reject zero-sample input.
    if input_tensor.shape[0] == 0:
        return torch.empty(0, dtype=torch.long, device=input_tensor.device)

    scaled_column = input_tensor[:, 3].detach().cpu().numpy().reshape(-1, 1)
    highway_codes = np.asarray(scaler.inverse_transform(scaled_column), dtype=np.float64).reshape(-1)
    wanted_codes = np.asarray(list(values_to_filter_for), dtype=np.float64)

    # Tolerant membership test: compare every row against every wanted code.
    # With no wanted codes the comparison matrix has zero columns and no row
    # is selected.
    membership_mask = (np.abs(highway_codes[:, None] - wanted_codes[None, :]) <= tolerance).any(axis=1)

    indices = torch.nonzero(torch.from_numpy(membership_mask)).reshape(-1)
    return indices.to(input_tensor.device)

In [9]:
def train(model, config=None, loss_fct=None, optimizer=None, train_dl=None, valid_dl=None, device=None, early_stopping=None):
    """Train ``model`` with the loss restricted to higher-order road rows.

    For every training batch, the rows whose de-normalised highway code is in
    the notebook-level ``higher_order_values`` are selected with
    ``get_indices_of_higher_order_roads`` (using the notebook-level
    ``scaler_highways``), and ``loss_fct`` is evaluated on those rows only.
    The R^2 printed and logged per epoch is computed over *all* rows of the
    training batches, not only the selected ones. After each epoch the
    validation loss is obtained from ``garch.validate_model_pos_features``
    and passed to ``early_stopping``.

    Side effects: logs to the active wandb run, writes the best validation
    loss and its epoch to ``wandb.summary`` and finishes the run.

    Args:
        model: Network called as ``model(data)`` on a batch.
        config: Object with an ``epochs`` attribute (e.g. ``wandb.config``).
        loss_fct: Loss callable ``loss_fct(predicted, target)``.
        optimizer: Optimizer over the model parameters.
        train_dl: Iterable of training batches exposing ``x``, ``y`` and
            ``to(device)``.
        valid_dl: Validation dataloader handed to the validation helper.
        device: Device the batches are moved to; should be the model's device.
        early_stopping: Callable taking the validation loss and exposing an
            ``early_stop`` flag.

    Returns:
        tuple: ``(best_val_loss, best_epoch)`` - the lowest validation loss
        seen and the epoch it occurred in. ``(inf, -1)`` if no epoch was run.
        The model itself keeps the weights of the last epoch; no checkpoint
        is restored here.
    """
    best_val_loss = float("inf")
    best_epoch = -1

    for epoch in range(config.epochs):
        model.train()
        actual_vals = []
        predictions = []
        # Wrap the dataloader (not the enumerate object) so tqdm knows the total.
        for idx, data in enumerate(tqdm(train_dl)):
            # Move the whole batch, since the model consumes `data`, not just x/y.
            data = data.to(device)
            input_node_features, targets = data.x, data.y
            optimizer.zero_grad()

            # Rows of this batch that belong to higher-order roads.
            indices = get_indices_of_higher_order_roads(input_tensor=input_node_features, scaler=scaler_highways, values_to_filter_for = higher_order_values)

            # Forward pass
            predicted = model(data)
            # index_select needs the index on the same device as its input.
            indices = indices.to(predicted.device)

            # Keep one array per batch; .cpu() is required before .numpy()
            # when training on a GPU.
            actual_vals.append(targets.detach().cpu().numpy())
            predictions.append(predicted.detach().cpu().numpy())

            filtered_predicted = torch.index_select(predicted, 0, indices)
            filtered_actual = torch.index_select(targets, 0, indices)

            # Backward pass
            train_loss = loss_fct(filtered_predicted, filtered_actual)

            train_loss.backward()
            optimizer.step()

            wandb.log({"train_loss": train_loss.item(), "epoch": epoch, "step": idx})

        # Calculate R^2 over all rows seen in this epoch.
        if actual_vals:
            actual_vals = np.concatenate(actual_vals, axis=0)
            predictions = np.concatenate(predictions, axis=0)
            sst = ((actual_vals - actual_vals.mean()) ** 2).sum()
            ssr = ((actual_vals - predictions) ** 2).sum()
            # R^2 is undefined when the targets have no variance.
            r2 = 1 - ssr / sst if sst != 0 else float("nan")
        else:
            r2 = float("nan")

        val_loss = garch.validate_model_pos_features(model, valid_dl, loss_fct, device)
        print(f"epoch: {epoch}, validation loss: {val_loss}, R^2: {r2}")
        wandb.log({"loss": val_loss, "epoch": epoch, "r2": r2})

        # Remember the best epoch so that what is reported as "best" really is.
        if val_loss < best_val_loss:
            best_val_loss, best_epoch = val_loss, epoch

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered. Stopping training.")
            break

    print("Best validation loss: ", best_val_loss)
    wandb.summary["val_loss"] = best_val_loss
    wandb.summary["best_epoch"] = best_epoch
    wandb.finish()
    return best_val_loss, best_epoch

In [10]:
# Currently best architecture
output_layer_parameter = 'gat'
hidden_size_parameter = 16
gat_layer_parameter = 0
gcn_layer_parameter = 0
lr = 0.001
in_channels = len(indices_of_datasets_to_use) + 2 # dimensions of the x vector + 2 (pos)
# Single source of truth for the patience: used for the early-stopping object
# below AND for the value logged to wandb, so the two cannot drift apart.
early_stopping_patience = 30

# The early-stopping helper keeps state between calls (its counter and the
# `early_stop` flag), so build a fresh one for every run of this cell instead
# of reusing an object that a previous run may already have triggered.
early_stopping = gio.EarlyStopping(patience=early_stopping_patience, verbose=True)

wandb.login()

wandb.init(
    project=project_name,
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "lr": lr,
        "early_stopping_patience": early_stopping_patience,
        "hidden_layer_size": hidden_size_parameter,
        "gat_layers": gat_layer_parameter,
        "gcn_layers": gcn_layer_parameter,
        "output_layer": output_layer_parameter,
        # "dropout": 0.15,
    }
)
config = wandb.config

print("output_layer: ", output_layer_parameter)
print("hidden_size: ", hidden_size_parameter)
print("gat_layers: ", gat_layer_parameter)
print("gcn_layers: ", gcn_layer_parameter)

# Build the network and move it to the training device.
gnn_instance = garch.MyGnn(in_channels=in_channels, out_channels=1, hidden_size=hidden_size_parameter, gat_layers=gat_layer_parameter, gcn_layers=gcn_layer_parameter, output_layer=output_layer_parameter)
model = gnn_instance.to(device)

# Run the training loop; it logs to the wandb run opened above and finishes it.
best_val_loss, best_epoch = train(model, config=config, 
                                loss_fct=torch.nn.MSELoss(), 
                                optimizer=torch.optim.Adam(model.parameters(), lr=lr),
                                train_dl=train_dl, valid_dl=valid_dl,
                                device=device, early_stopping=early_stopping)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


output_layer:  gat
hidden_size:  16
gat_layers:  0
gcn_layers:  0
Model initialized
MyGnn(
  (pointLayer): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=6, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
  ), global_nn=Sequential(
    (0): Linear(in_features=16, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
  ))
  (output_layer): GATConv(16, 1, heads=1)
)


113it [02:23,  1.27s/it]


epoch: 0, validation loss: 0.003281985428184271, R^2: -8.60151481628418


113it [01:53,  1.00s/it]


epoch: 1, validation loss: 0.0032810638658702374, R^2: -0.0076198577880859375


113it [02:36,  1.39s/it]


epoch: 2, validation loss: 0.0032808151841163635, R^2: -0.0070950984954833984


113it [02:41,  1.43s/it]


epoch: 3, validation loss: 0.003294027671217918, R^2: -0.009050846099853516
EarlyStopping counter: 1 out of 30


113it [02:24,  1.28s/it]


epoch: 4, validation loss: 0.0032963852863758802, R^2: -0.011550545692443848
EarlyStopping counter: 2 out of 30


113it [02:13,  1.18s/it]


epoch: 5, validation loss: 0.003303424697369337, R^2: -0.012486696243286133
EarlyStopping counter: 3 out of 30


113it [02:20,  1.24s/it]


epoch: 6, validation loss: 0.0033206885773688554, R^2: -0.016606450080871582
EarlyStopping counter: 4 out of 30


113it [02:49,  1.50s/it]


epoch: 7, validation loss: 0.0033386405650526283, R^2: -0.022196292877197266
EarlyStopping counter: 5 out of 30


113it [02:16,  1.21s/it]


epoch: 8, validation loss: 0.003354023890569806, R^2: -0.027212977409362793
EarlyStopping counter: 6 out of 30


113it [02:04,  1.10s/it]


epoch: 9, validation loss: 0.0032959196250885725, R^2: -0.029925823211669922
EarlyStopping counter: 7 out of 30


113it [02:29,  1.32s/it]


epoch: 10, validation loss: 0.0033279075287282467, R^2: -0.01275944709777832
EarlyStopping counter: 8 out of 30


113it [02:18,  1.23s/it]


epoch: 11, validation loss: 0.003290270669385791, R^2: -0.014904022216796875
EarlyStopping counter: 9 out of 30


113it [02:39,  1.42s/it]


epoch: 12, validation loss: 0.0032363417651504276, R^2: -0.0031855106353759766


113it [02:13,  1.18s/it]


epoch: 13, validation loss: 0.0032415923196822405, R^2: 0.007543802261352539
EarlyStopping counter: 1 out of 30


113it [02:32,  1.35s/it]


epoch: 14, validation loss: 0.0032036192901432514, R^2: 0.011121034622192383


113it [01:57,  1.04s/it]


epoch: 15, validation loss: 0.0032418053597211838, R^2: 0.01250988245010376
EarlyStopping counter: 1 out of 30


113it [02:23,  1.27s/it]


epoch: 16, validation loss: 0.0032417068909853697, R^2: 0.010861873626708984
EarlyStopping counter: 2 out of 30


113it [02:50,  1.50s/it]


epoch: 17, validation loss: 0.0032354299910366535, R^2: 0.008889853954315186
EarlyStopping counter: 3 out of 30


113it [02:52,  1.53s/it]


epoch: 18, validation loss: 0.003253371436148882, R^2: 0.005366683006286621
EarlyStopping counter: 4 out of 30


113it [02:56,  1.56s/it]


epoch: 19, validation loss: 0.0032509691175073383, R^2: 0.004323720932006836
EarlyStopping counter: 5 out of 30


113it [02:35,  1.37s/it]


epoch: 20, validation loss: 0.0032521323300898075, R^2: -0.0007826089859008789
EarlyStopping counter: 6 out of 30


113it [02:05,  1.11s/it]


epoch: 21, validation loss: 0.003300392087548971, R^2: -0.004260897636413574
EarlyStopping counter: 7 out of 30


113it [01:59,  1.06s/it]


epoch: 22, validation loss: 0.0032572818361222746, R^2: -0.004711151123046875
EarlyStopping counter: 8 out of 30


113it [01:59,  1.05s/it]


epoch: 23, validation loss: 0.0032649263739585876, R^2: -0.0036742687225341797
EarlyStopping counter: 9 out of 30


113it [02:47,  1.48s/it]


epoch: 24, validation loss: 0.0032894092332571745, R^2: -0.004610657691955566
EarlyStopping counter: 10 out of 30


113it [02:21,  1.25s/it]


epoch: 25, validation loss: 0.00324463402852416, R^2: -0.004969358444213867
EarlyStopping counter: 11 out of 30


113it [03:08,  1.67s/it]


epoch: 26, validation loss: 0.0032453157380223276, R^2: -0.002275705337524414
EarlyStopping counter: 12 out of 30


113it [03:01,  1.61s/it]


epoch: 27, validation loss: 0.0032323764357715845, R^2: 0.002006232738494873
EarlyStopping counter: 13 out of 30


113it [02:31,  1.34s/it]


epoch: 28, validation loss: 0.003202410191297531, R^2: 0.010120987892150879


113it [02:23,  1.27s/it]


epoch: 29, validation loss: 0.0031648948788642883, R^2: 0.01764535903930664


113it [02:23,  1.27s/it]


epoch: 30, validation loss: 0.003126250570639968, R^2: 0.026110827922821045


113it [02:39,  1.41s/it]


epoch: 31, validation loss: 0.0031390692852437496, R^2: 0.03420901298522949
EarlyStopping counter: 1 out of 30


113it [02:37,  1.39s/it]


epoch: 32, validation loss: 0.003112494284287095, R^2: 0.04070371389389038


113it [02:28,  1.31s/it]


epoch: 33, validation loss: 0.0031253353226929903, R^2: 0.04372209310531616
EarlyStopping counter: 1 out of 30


113it [02:17,  1.22s/it]


epoch: 34, validation loss: 0.003151489654555917, R^2: 0.04562664031982422
EarlyStopping counter: 2 out of 30


113it [02:33,  1.36s/it]


epoch: 35, validation loss: 0.003090897109359503, R^2: 0.04713559150695801


113it [02:22,  1.26s/it]


epoch: 36, validation loss: 0.0030989712104201317, R^2: 0.0480959415435791
EarlyStopping counter: 1 out of 30


113it [02:34,  1.36s/it]


epoch: 37, validation loss: 0.0031098737567663193, R^2: 0.0488470196723938
EarlyStopping counter: 2 out of 30


113it [02:27,  1.31s/it]


epoch: 38, validation loss: 0.0030900144577026365, R^2: 0.05023854970932007


113it [02:07,  1.13s/it]


epoch: 39, validation loss: 0.0030855271127074955, R^2: 0.05151355266571045


113it [02:24,  1.28s/it]


epoch: 40, validation loss: 0.0031031116377562284, R^2: 0.05010414123535156
EarlyStopping counter: 1 out of 30


113it [02:20,  1.25s/it]


epoch: 41, validation loss: 0.0031201464403420687, R^2: 0.04955023527145386
EarlyStopping counter: 2 out of 30


113it [02:38,  1.40s/it]


epoch: 42, validation loss: 0.0031391219049692154, R^2: 0.048820436000823975
EarlyStopping counter: 3 out of 30


113it [02:01,  1.08s/it]


epoch: 43, validation loss: 0.003089974634349346, R^2: 0.04834699630737305
EarlyStopping counter: 4 out of 30


113it [02:16,  1.21s/it]


epoch: 44, validation loss: 0.003080454869195819, R^2: 0.04624456167221069


113it [02:10,  1.16s/it]


epoch: 45, validation loss: 0.0031060317996889353, R^2: 0.04412466287612915
EarlyStopping counter: 1 out of 30


113it [02:01,  1.08s/it]


epoch: 46, validation loss: 0.003117981133982539, R^2: 0.042471885681152344
EarlyStopping counter: 2 out of 30


113it [01:54,  1.01s/it]


epoch: 47, validation loss: 0.0031443524733185766, R^2: 0.04017436504364014
EarlyStopping counter: 3 out of 30


113it [01:53,  1.00s/it]


epoch: 48, validation loss: 0.0031462144199758766, R^2: 0.03293102979660034
EarlyStopping counter: 4 out of 30


113it [02:43,  1.45s/it]


epoch: 49, validation loss: 0.0032022811844944956, R^2: 0.026658356189727783
EarlyStopping counter: 5 out of 30


113it [02:21,  1.25s/it]


epoch: 50, validation loss: 0.003216008907184005, R^2: 0.02265620231628418
EarlyStopping counter: 6 out of 30


113it [02:24,  1.28s/it]


epoch: 51, validation loss: 0.003159933714196086, R^2: 0.02014625072479248
EarlyStopping counter: 7 out of 30


113it [02:10,  1.16s/it]


epoch: 52, validation loss: 0.003233073726296425, R^2: 0.016425669193267822
EarlyStopping counter: 8 out of 30


113it [01:51,  1.01it/s]


epoch: 53, validation loss: 0.003244029339402914, R^2: 0.01116710901260376
EarlyStopping counter: 9 out of 30


55it [00:52,  1.13it/s]wandb: Network error (ReadTimeout), entering retry loop.
113it [01:52,  1.00it/s]


epoch: 54, validation loss: 0.0032402009144425394, R^2: 0.006402134895324707
EarlyStopping counter: 10 out of 30


113it [01:49,  1.03it/s]


epoch: 55, validation loss: 0.0032286748662590983, R^2: 0.0025429725646972656
EarlyStopping counter: 11 out of 30


113it [01:52,  1.00it/s]


epoch: 56, validation loss: 0.0033197645097970963, R^2: -0.0015673637390136719
EarlyStopping counter: 12 out of 30


113it [01:48,  1.04it/s]


epoch: 57, validation loss: 0.003286795262247324, R^2: -0.007788181304931641
EarlyStopping counter: 13 out of 30


113it [01:50,  1.02it/s]


epoch: 58, validation loss: 0.0033140592370182274, R^2: -0.013729453086853027
EarlyStopping counter: 14 out of 30


113it [02:08,  1.14s/it]


epoch: 59, validation loss: 0.003321602726355195, R^2: -0.018412470817565918
EarlyStopping counter: 15 out of 30


34it [00:43,  1.28s/it]


KeyboardInterrupt: 

In [None]:
model_path = '../../data/trained_models/model_' + unique_model_description + '.pth'

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Save the model state dictionary together with the constructor arguments
# needed to rebuild the network.
# The config is taken from the values that were passed to garch.MyGnn, not
# from attributes of the model: `model.output_layer` is the output *layer
# module* (a GATConv in the printed model), not the string 'gat', so storing
# it would make the saved config unusable for re-instantiating the model.
torch.save({
    'state_dict': model.state_dict(),
    'config': {
        'in_channels': in_channels,
        'out_channels': 1,
        'hidden_size': hidden_size_parameter,
        'gat_layers': gat_layer_parameter,
        'gcn_layers': gcn_layer_parameter,
        'output_layer': output_layer_parameter
    }
}, model_path)