In [1]:
import math
import random
import pickle

import numpy as np
import pandas as pd
import geopandas as gpd
import tqdm
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, Subset

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import LineGraph

from shapely.geometry import LineString

import gnn_io as gio

# Abstract

Here we investigate the data.

In [2]:
# Define parameters
num_epochs = 20
batch_size = 20
lr = 0.001
wandb_name = 'gnn_decrease_model_for_one_batch'
train_ratio = 0.8
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

## 1. Load data and create the dataset

In [3]:
class GnnModel(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(12345)
        self.conv1 = torch_geometric.nn.GATConv(1, 16)
        # self.conv2 = torch_geometric.nn.GATConv(16, 16)
        self.conv3 = torch_geometric.nn.GATConv(16, 1)
        # self.conv3 = torch_geometric.nn.GCNConv(16, 1)
        # self.gat1 = torch_geometric.nn.GATConv(16, 16)
        # self.conv4 = torch_geometric.nn.GCNConv(16, 1)
                
        # self.convWithPos = torch_geometric.nn.conv.PointNetConv(1, 16, 3)
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        # x = F.relu(x)
        # x = F.dropout(x, training=self.training)
        # x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        # x = F.relu(x)
        # x = F.dropout(x, training=self.training)
        # x = self.conv3(x, edge_index)
        # x = F.relu(x)
        # x = F.dropout(x, training=self.training)
        # x = self.gat1(x, edge_index)
        # x = F.relu(x)
        # x = F.dropout(x, training=self.training)
        # x = self.conv4(x, edge_index)
        return x

In [4]:
# Load the list of dictionaries
data_dict_list = torch.load('../results/dataset_1pm_0-1382.pt')

# Reconstruct the Data objects
datalist = [Data(x=d['x'], edge_index=d['edge_index'], pos=d['pos'], y=d['y']) for d in data_dict_list]

# Recreate the dataset
dataset = gio.MyGeometricDataset(datalist)

# Apply normalization to your dataset
dataset_normalized = gio.normalize_dataset(dataset)

### Approximate MSE - baseline error

In [5]:
y_values_normalized = np.concatenate([data.normalized_y for data in dataset_normalized.data_list])

# Compute the mean and standard deviation
mean_y_normalized = np.mean(y_values_normalized)
std_y_normalized = np.std(y_values_normalized)

print(f"Mean of y: {mean_y_normalized}")
print(f"Standard Deviation of y: {std_y_normalized}")

# Plot the distribution of y values
# plt.figure(figsize=(10, 6))
# n, bins, patches = plt.hist(y_values_normalized, bins=30, edgecolor='k', alpha=0.7)

# # Add bin labels
# bin_centers = 0.5 * (bins[:-1] + bins[1:])
# # for count, x in zip(n, bin_centers):
# #     plt.text(x, count, str(int(count)), ha='center', va='bottom')

# # Set the x-axis ticks and labels
# plt.xticks(bins, rotation=45)
# plt.title('Distribution of Normalized y Values')
# plt.xlabel('Normalized y')
# plt.ylabel('Frequency')
# plt.grid(True)
# plt.show()

mean_y_normalized = np.mean(y_values_normalized)

# Convert numpy arrays to torch tensors
y_values_normalized_tensor = torch.tensor(y_values_normalized, dtype=torch.float32)
mean_y_normalized_tensor = torch.tensor(mean_y_normalized, dtype=torch.float32)

# Create the target tensor with the same shape as y_values_normalized_tensor
target_tensor = mean_y_normalized_tensor * torch.ones_like(y_values_normalized_tensor)

# Instantiate the MSELoss function
mse_loss = torch.nn.MSELoss()

# Compute the MSE 
mse = mse_loss(y_values_normalized_tensor, target_tensor)

# Print the MSE value
print("Baseline error is: " + str(mse.item()))

Mean of y: 0.030937498435378075
Standard Deviation of y: 0.074600949883461
Baseline error is: 0.005565311759710312


## 2. Load model and loss function

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
wandb.init(
        project="check_errors",
        config={
            "epochs": num_epochs,
            "batch_size": batch_size,
            "lr": lr,
            'early_stopping_patience': 10,
            # "dropout": 0.15,
            })
config = wandb.config
model = GnnModel().to(device)

# Define loss and optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

loss_fct = torch.nn.MSELoss()

## 3. Split into train and test set

In [7]:
train_dl = create_dataloader(dataset=dataset_normalized, is_train=True, batch_size=config.batch_size, train_ratio=train_ratio)
valid_dl = create_dataloader(dataset=dataset_normalized, is_train=False, batch_size=config.batch_size, train_ratio=train_ratio)
n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)
print(n_steps_per_epoch)

NameError: name 'create_dataloader' is not defined

## 4. Train the model

We first find a good model for one batch. 

In [None]:
class EarlyStopping:
    def __init__(self, patience=5, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0


In [None]:
# find the average per batch

# Train the model
# mse_loss = 0
# counter = 0
early_stopping = EarlyStopping(patience=5, verbose=True)

for epoch in range(config.epochs):
    model.train()
    data = next(iter(train_dl))
    for idx in range(len(train_dl)):
        
    # for idx, data in enumerate(train_dl):
        input_node_features, targets = data.normalized_x.to(device), data.normalized_y.to(device)
        predicted = model(data)
        train_loss = loss_fct(predicted, targets)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        wandb.log({"train_loss": train_loss.item(), "epoch": epoch, "step": idx})
        # print(f"epoch: {epoch}, step: {idx}, loss: {train_loss.item()}")
        
        
    val_loss = validate_model(model, valid_dl, loss_fct, device)
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training.")
        break
    
    wandb.log({"val_loss": val_loss})
    print(f"epoch: {epoch}, val_loss: {val_loss}")
        
wandb.summary["val_loss"] = val_loss
wandb.finish()

epoch: 0, val_loss: 36074.207331730766
epoch: 1, val_loss: 24183.85486778846
epoch: 2, val_loss: 78698.20673076923
epoch: 3, val_loss: 102636.74699519231
epoch: 4, val_loss: 101859.34975961539
epoch: 5, val_loss: 90070.79627403847
epoch: 6, val_loss: 76151.65204326923
epoch: 7, val_loss: 63441.45823317308
epoch: 8, val_loss: 52378.12319711538
epoch: 9, val_loss: 61294.62319711538
epoch: 10, val_loss: 45354.18028846154
epoch: 11, val_loss: 34868.485877403844
epoch: 12, val_loss: 27600.67022235577
epoch: 13, val_loss: 22172.89287860577
epoch: 14, val_loss: 18098.572415865383
epoch: 15, val_loss: 14953.845552884615
epoch: 16, val_loss: 12339.248422475961
epoch: 17, val_loss: 10288.269756610576
epoch: 18, val_loss: 8633.09990985577
epoch: 19, val_loss: 7207.927396334135
epoch: 20, val_loss: 6059.212552584135
epoch: 21, val_loss: 5042.471529447115
epoch: 22, val_loss: 4238.04052734375
epoch: 23, val_loss: 3513.580866887019
epoch: 24, val_loss: 2917.111328125
epoch: 25, val_loss: 2395.178297

[E thread_pool.cpp:130] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 