In [1]:
import math
import random
import pickle

import numpy as np
import pandas as pd
import geopandas as gpd
import tqdm
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import LineGraph
from shapely.geometry import LineString
import gnn_io as gio

from gnn_architectures import GnnMultipleInputFeatures
from gnn_architectures import GnnBasic
from gnn_architectures import GnnWithPos

import gnn_architectures as garch

  Referenced from: <5AA8DD3D-A2CC-31CA-8060-88B4E9C18B09> /Users/elenanatterer/anaconda3/envs/ml_env/lib/python3.10/site-packages/torchvision/image.so
  warn(


In [2]:
# Define parameters
num_epochs = 40
batch_size = 20
lr = 0.001
project_name = 'multiple_features'
train_ratio = 0.8
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

## 1. Load data and create the dataset

In [3]:
# Load the list of dictionaries
data_dict_list = torch.load('../data/dataset_1pm_0-1382_with_more_infos.pt')

# Reconstruct the Data objects
datalist = [Data(x=d['x'], edge_index=d['edge_index'], pos=d['pos'], y=d['y']) for d in data_dict_list]

# # Apply normalization to your dataset
dataset_normalized = gio.normalize_dataset(datalist)

### Approximate MSE - baseline error

In [4]:
y_values_normalized = np.concatenate([data.normalized_y for data in dataset_normalized])

# Compute the mean and standard deviation
mean_y_normalized = np.mean(y_values_normalized)
std_y_normalized = np.std(y_values_normalized)

print(f"Mean of y: {mean_y_normalized}")
print(f"Standard Deviation of y: {std_y_normalized}")

# Convert numpy arrays to torch tensors
y_values_normalized_tensor = torch.tensor(y_values_normalized, dtype=torch.float32)
mean_y_normalized_tensor = torch.tensor(mean_y_normalized, dtype=torch.float32)

# Create the target tensor with the same shape as y_values_normalized_tensor
target_tensor = mean_y_normalized_tensor * torch.ones_like(y_values_normalized_tensor)

# Instantiate the MSELoss function
mse_loss = torch.nn.MSELoss()

# Compute the MSE 
mse = mse_loss(y_values_normalized_tensor, target_tensor)

# Print the MSE value
print("Baseline error is: " + str(mse.item()))

Mean of y: 0.030937498435378075
Standard Deviation of y: 0.074600949883461
Baseline error is: 0.005565311759710312


## 2. Load model and loss function

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
wandb.init(
        project=project_name,
        config={
            "epochs": num_epochs,
            "batch_size": batch_size,
            "lr": lr,
            'early_stopping_patience': 10,
            # "dropout": 0.15,
            })
config = wandb.config

# gnn_instance = GnnMultipleInputFeatures()
gnn_instance = GnnBasic()
gnn_instance = GnnWithPos(3, 1)
model = gnn_instance.to(device)

# Define loss and optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fct = torch.nn.MSELoss()

In [6]:
model

GnnBasic(
  (conv1): GCNConv(1, 16)
  (conv3): GCNConv(16, 1)
)

## 3. Split into train and test set

In [7]:
train_dl = gio.create_dataloader(dataset=dataset_normalized, is_train=True, batch_size=config.batch_size, train_ratio=train_ratio)
valid_dl = gio.create_dataloader(dataset=dataset_normalized, is_train=False, batch_size=config.batch_size, train_ratio=train_ratio)
n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)
print(n_steps_per_epoch)

Total dataset length: 1382
Training subset length: 1100
Total dataset length: 1382
Validation subset length: 260
55


## 4. Train the model

We first find a good model for one batch. 

In [8]:
early_stopping = gio.EarlyStopping(patience=5, verbose=True)

for epoch in range(config.epochs):
    model.train()
    data = next(iter(train_dl))
    # for idx in range(len(train_dl)):
    # for idx, data in enumerate(train_dl):
    input_node_features, targets = data.normalized_x.to(device), data.normalized_y.to(device)
    predicted = model(data)
    train_loss = loss_fct(predicted, targets)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    # wandb.log({"train_loss": train_loss.item(), "epoch": epoch, "step": idx})
        # print(f"epoch: {epoch}, step: {idx}, loss: {train_loss.item()}")
        
    val_loss = garch.validate_model_basic(model, valid_dl, loss_fct, device)
    wandb.log({"val_loss": val_loss})
    print(f"epoch: {epoch}, val_loss: {val_loss}")
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training.")
        break
    
wandb.summary["val_loss"] = val_loss
wandb.finish()

epoch: 0, val_loss: 0.006355380711074059
epoch: 1, val_loss: 0.006198488534069979
epoch: 2, val_loss: 0.006064380411631786
epoch: 3, val_loss: 0.005949619082877269
epoch: 4, val_loss: 0.005852257438863699
epoch: 5, val_loss: 0.005770632388213506
epoch: 6, val_loss: 0.005703596075853476
epoch: 7, val_loss: 0.005649861163244798
epoch: 8, val_loss: 0.005608262959867716
epoch: 9, val_loss: 0.0055776710550372414
epoch: 10, val_loss: 0.005556791590956541
epoch: 11, val_loss: 0.005544170844726837
epoch: 12, val_loss: 0.005538306019913692
epoch: 13, val_loss: 0.005537699693097518
epoch: 14, val_loss: 0.005540814608908617
EarlyStopping counter: 1 out of 5
epoch: 15, val_loss: 0.005546100616741639
EarlyStopping counter: 2 out of 5
epoch: 16, val_loss: 0.005552262246895295
EarlyStopping counter: 3 out of 5
epoch: 17, val_loss: 0.005558190437463613
EarlyStopping counter: 4 out of 5
epoch: 18, val_loss: 0.0055630795585994534
EarlyStopping counter: 5 out of 5
Early stopping triggered. Stopping train

0,1
val_loss,█▇▆▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
val_loss,0.00556
