In [1]:
import math
import numpy as np
import wandb

import torch
import torch_geometric
from torch_geometric.data import Data

import sys
import os
from tqdm import tqdm

# Put the parent directory on the import path (once), so the project-local
# modules imported below can be found.
scripts_path = os.path.abspath('..')
already_on_path = scripts_path in sys.path
if not already_on_path:
    sys.path.append(scripts_path)
    
import joblib

# Now you can import the gnn_io module
import gnn_io as gio

import gnn_architectures as garch

## 1. Define model and parameters

In [2]:
# Define parameters
num_epochs = 1000
batch_size = 32
# The run label names the saved test dataloader and model files, so it is
# derived from batch_size instead of repeating the number by hand; this keeps
# the label from drifting when the batch size is changed.
unique_model_description = f"mse_loss_batch_size_{batch_size}"
project_name = "test_different_parameters"
path_to_save_dataloader = "../../data/data_created_during_training_needed_for_testing/"
path_to_train_data = '../../data/train_data/dataset_1pm_0-3100.pt'
indices_of_datasets_to_use = [0, 1, 2, 3]

loss_fct = torch.nn.MSELoss()

# Architecture options forwarded to garch.MyGnn
output_layer_parameter = 'gat'
hidden_size_parameter = 16
gat_layer_parameter = 0
gcn_layer_parameter = 0
in_channels = len(indices_of_datasets_to_use) + 2 # dimensions of the x vector + 2 (pos)
out_channels = 1 # we are predicting one value

# Optimisation options
lr = 0.001
early_stopping_patience = 10

# Raw training data: an iterable of dicts holding the 'x', 'edge_index', 'pos'
# and 'y' entries that are turned back into Data objects in the next section.
data_dict_list = torch.load(path_to_train_data)

## 2. Load data

In [3]:
# Turn every stored dict back into a torch_geometric Data object, copying
# exactly the four stored fields.
graph_fields = ('x', 'edge_index', 'pos', 'y')
datalist = [
    Data(**{field: graph_dict[field] for field in graph_fields})
    for graph_dict in data_dict_list
]

# Optional feature selection, currently disabled:
# dataset_only_relevant_dimensions = gio.cut_dimensions(dataset=datalist, indices_of_dimensions_to_keep=indices_of_datasets_to_use)

# No pre-fitted scalers are passed in (all None); the directory used for the
# saved test dataloader is handed over as directory_path.
dataset_normalized = gio.normalize_dataset(
    datalist,
    y_scalar=None,
    x_scalar_list=None,
    pos_scalar=None,
    directory_path=path_to_save_dataloader,
)

In [4]:
# Compute and print two baseline losses on the normalized dataset, using the
# same loss function that is later passed to training.
baseline_kwargs = dict(dataset=dataset_normalized, loss_fct=loss_fct)

# Baseline 1: "no policies"
baseline_error = gio.compute_baseline_of_no_policies(**baseline_kwargs)
print(f'Baseline error no policies: {baseline_error}')

# Baseline 2: mean of the target
baseline_error = gio.compute_baseline_of_mean_target(**baseline_kwargs)
print(f'Baseline error mean: {baseline_error}')

Baseline error no policies: 0.3216274082660675
Baseline error mean: 0.0032576550729572773


## 3. Train the model

We first find a good model for one batch. 

In [5]:
# Split the normalized dataset 70 / 15 / 15 into train, validation and test loaders.
split_ratios = dict(train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)
train_dl, valid_dl, test_dl = gio.create_dataloaders(
    batch_size=batch_size,
    dataset=dataset_normalized,
    **split_ratios,
)

# Store the test loader and its parameters under the run label so they can be
# picked up again after training.
test_dl_file = f"{path_to_save_dataloader}test_dl_{unique_model_description}.pt"
test_dl_params_file = f"{path_to_save_dataloader}test_loader_params_{unique_model_description}.json"
gio.save_dataloader(test_dl, test_dl_file)
gio.save_dataloader_params(test_dl, test_dl_params_file)

Total dataset length: 3079
Training subset length: 2155
Validation subset length: 461
Test subset length: 463


In [None]:
# Report the visible CUDA devices and select the compute device.
print(f"Running with {torch.cuda.device_count()} GPUS")
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Name is ", torch.cuda.get_device_name())
else:
    # Fall back to the CPU so that `device` is always defined after this cell,
    # in line with the device selection made right before training.
    device = torch.device('cpu')

In [6]:
# Start a Weights & Biases run and record the hyperparameters of this experiment.
wandb.login()
wandb.init(
    project=project_name,
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "lr": lr,
        # Log the very value that is handed to EarlyStopping below, so the
        # tracked config cannot diverge from the patience actually used.
        "early_stopping_patience": early_stopping_patience,
        "hidden_layer_size": hidden_size_parameter,
        "gat_layers": gat_layer_parameter,
        "gcn_layers": gcn_layer_parameter,
        "output_layer": output_layer_parameter,
        # "dropout": 0.15,
    }
)
config = wandb.config

print("output_layer: ", output_layer_parameter)
print("hidden_size: ", hidden_size_parameter)
print("gat_layers: ", gat_layer_parameter)
print("gcn_layers: ", gcn_layer_parameter)

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
early_stopping = gio.EarlyStopping(patience=early_stopping_patience, verbose=True)

# Build the network from the architecture options and move it to the device.
gnn_instance = garch.MyGnn(
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_size=hidden_size_parameter,
    gat_layers=gat_layer_parameter,
    gcn_layers=gcn_layer_parameter,
    output_layer=output_layer_parameter,
)
model = gnn_instance.to(device)

# Adam without weight decay, using the learning rate logged above.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)

# NOTE(review): in the recorded run the final "Best validation loss" line equals
# the loss of the last epoch (0.003109) although epoch 63 reached 0.003084 --
# verify how garch.train / gio.EarlyStopping keep track of the best value.
best_val_loss, best_epoch = garch.train(
    model,
    config=config,
    loss_fct=loss_fct,
    optimizer=optimizer,
    train_dl=train_dl,
    valid_dl=valid_dl,
    device=device,
    early_stopping=early_stopping,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


output_layer:  gat
hidden_size:  16
gat_layers:  0
gcn_layers:  0
Model initialized
MyGnn(
  (pointLayer): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=6, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
  ), global_nn=Sequential(
    (0): Linear(in_features=16, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
  ))
  (output_layer): GATConv(16, 1, heads=1)
)


68it [02:31,  2.23s/it]


epoch: 0, validation loss: 0.003323109323779742, R^2: -3.9493894577026367


68it [02:33,  2.26s/it]


epoch: 1, validation loss: 0.0032674474796901146, R^2: -0.00541377067565918


68it [03:20,  2.95s/it]


epoch: 2, validation loss: 0.0032648660552998384, R^2: -0.0026427507400512695


68it [02:13,  1.97s/it]


epoch: 3, validation loss: 0.0032609107283254465, R^2: -0.001674652099609375


68it [02:03,  1.82s/it]


epoch: 4, validation loss: 0.003254208977644642, R^2: -0.00012481212615966797


68it [02:52,  2.53s/it]


epoch: 5, validation loss: 0.0032450135486821333, R^2: 0.002370297908782959


68it [03:04,  2.72s/it]


epoch: 6, validation loss: 0.003232051121691863, R^2: 0.005751430988311768


68it [03:13,  2.84s/it]


epoch: 7, validation loss: 0.0032245160546153784, R^2: 0.009233832359313965


68it [02:57,  2.61s/it]


epoch: 8, validation loss: 0.0032216908720632395, R^2: 0.010629713535308838


68it [03:22,  2.98s/it]


epoch: 9, validation loss: 0.003219052916392684, R^2: 0.011434972286224365


68it [02:54,  2.57s/it]


epoch: 10, validation loss: 0.003216142083207766, R^2: 0.012291193008422852


68it [02:37,  2.31s/it]


epoch: 11, validation loss: 0.003211499626437823, R^2: 0.013523697853088379


68it [02:45,  2.44s/it]


epoch: 12, validation loss: 0.003208871263389786, R^2: 0.014580190181732178


68it [02:44,  2.42s/it]


epoch: 13, validation loss: 0.0032061389492203793, R^2: 0.015369296073913574


68it [02:17,  2.02s/it]


epoch: 14, validation loss: 0.0032038688970108826, R^2: 0.016164839267730713


68it [02:20,  2.07s/it]


epoch: 15, validation loss: 0.0032018546480685472, R^2: 0.016820251941680908


68it [02:44,  2.42s/it]


epoch: 16, validation loss: 0.0031996707276751597, R^2: 0.017463326454162598


68it [02:56,  2.59s/it]


epoch: 17, validation loss: 0.003197759405399362, R^2: 0.018095076084136963


68it [02:40,  2.36s/it]


epoch: 18, validation loss: 0.003195862052962184, R^2: 0.018681585788726807


68it [03:08,  2.78s/it]


epoch: 19, validation loss: 0.003194424556568265, R^2: 0.01919788122177124


68it [03:12,  2.83s/it]


epoch: 20, validation loss: 0.0031929502729326487, R^2: 0.01964282989501953


68it [02:41,  2.38s/it]


epoch: 21, validation loss: 0.00319167737228175, R^2: 0.02004474401473999


68it [02:28,  2.19s/it]


epoch: 22, validation loss: 0.003190187970176339, R^2: 0.020476460456848145


68it [03:32,  3.13s/it]


epoch: 23, validation loss: 0.003188674089809259, R^2: 0.02093404531478882


68it [03:24,  3.01s/it]


epoch: 24, validation loss: 0.0031871513773997625, R^2: 0.02140265703201294


68it [02:22,  2.09s/it]


epoch: 25, validation loss: 0.0031852211803197862, R^2: 0.021914541721343994


68it [02:08,  1.88s/it]


epoch: 26, validation loss: 0.0031834801038106283, R^2: 0.02250385284423828


68it [02:12,  1.95s/it]


epoch: 27, validation loss: 0.00318185289700826, R^2: 0.02300816774368286


68it [02:03,  1.82s/it]


epoch: 28, validation loss: 0.003180139713610212, R^2: 0.023525595664978027


68it [02:55,  2.58s/it]


epoch: 29, validation loss: 0.003178440996756156, R^2: 0.024049341678619385


68it [02:02,  1.80s/it]


epoch: 30, validation loss: 0.0031765950688471397, R^2: 0.024589896202087402


68it [02:06,  1.86s/it]


epoch: 31, validation loss: 0.003174522425979376, R^2: 0.025183677673339844


68it [02:32,  2.24s/it]


epoch: 32, validation loss: 0.003172103812297185, R^2: 0.025891423225402832


68it [02:24,  2.13s/it]


epoch: 33, validation loss: 0.0031698255334049463, R^2: 0.026587486267089844


68it [02:28,  2.18s/it]


epoch: 34, validation loss: 0.0031673623869816463, R^2: 0.02734076976776123


68it [02:13,  1.97s/it]


epoch: 35, validation loss: 0.003164690670867761, R^2: 0.02811497449874878


68it [02:01,  1.78s/it]


epoch: 36, validation loss: 0.003162029633919398, R^2: 0.02894139289855957


68it [02:20,  2.07s/it]


epoch: 37, validation loss: 0.003158428504442175, R^2: 0.029939234256744385


68it [01:59,  1.75s/it]


epoch: 38, validation loss: 0.0031553348526358604, R^2: 0.030917465686798096


68it [02:09,  1.91s/it]


epoch: 39, validation loss: 0.0031520352233201264, R^2: 0.031912803649902344


68it [02:02,  1.81s/it]


epoch: 40, validation loss: 0.0031485645876576504, R^2: 0.032928287982940674


68it [02:08,  1.88s/it]


epoch: 41, validation loss: 0.003146366433550914, R^2: 0.033883750438690186


68it [02:05,  1.85s/it]


epoch: 42, validation loss: 0.003142958258589109, R^2: 0.034695565700531006


68it [02:07,  1.88s/it]


epoch: 43, validation loss: 0.0031429524378230175, R^2: 0.03557199239730835


68it [02:02,  1.80s/it]


epoch: 44, validation loss: 0.003137087899570664, R^2: 0.03646647930145264


68it [02:11,  1.94s/it]


epoch: 45, validation loss: 0.0031340702281643946, R^2: 0.03742772340774536


68it [02:03,  1.82s/it]


epoch: 46, validation loss: 0.0031313172386338312, R^2: 0.038255274295806885


68it [02:08,  1.89s/it]


epoch: 47, validation loss: 0.0031347579633196196, R^2: 0.03864628076553345
EarlyStopping counter: 1 out of 10


68it [02:03,  1.81s/it]


epoch: 48, validation loss: 0.0031263914269705614, R^2: 0.03968387842178345


68it [02:12,  1.95s/it]


epoch: 49, validation loss: 0.00312458958166341, R^2: 0.040609776973724365


68it [02:04,  1.83s/it]


epoch: 50, validation loss: 0.0031223176202426354, R^2: 0.03958940505981445


68it [02:12,  1.94s/it]


epoch: 51, validation loss: 0.003119280794635415, R^2: 0.041932880878448486


68it [02:02,  1.81s/it]


epoch: 52, validation loss: 0.0031172239842514196, R^2: 0.04278844594955444


68it [02:06,  1.87s/it]


epoch: 53, validation loss: 0.003114877035841346, R^2: 0.043418169021606445


68it [01:57,  1.73s/it]


epoch: 54, validation loss: 0.0031138938075552383, R^2: 0.04410499334335327


68it [02:14,  1.98s/it]


epoch: 55, validation loss: 0.0031111822618792454, R^2: 0.044733524322509766


68it [02:02,  1.80s/it]


epoch: 56, validation loss: 0.0031083209595332544, R^2: 0.045394062995910645


68it [02:17,  2.02s/it]


epoch: 57, validation loss: 0.0031063090699414413, R^2: 0.04607963562011719


68it [02:02,  1.80s/it]


epoch: 58, validation loss: 0.0031033859432985384, R^2: 0.04679465293884277


68it [02:05,  1.85s/it]


epoch: 59, validation loss: 0.0031002128962427378, R^2: 0.047637879848480225


68it [02:06,  1.85s/it]


epoch: 60, validation loss: 0.003097946724543969, R^2: 0.04867422580718994


68it [02:09,  1.90s/it]


epoch: 61, validation loss: 0.0030931397496412197, R^2: 0.0497514009475708


68it [02:03,  1.81s/it]


epoch: 62, validation loss: 0.00308938369465371, R^2: 0.050909459590911865


68it [02:27,  2.17s/it]


epoch: 63, validation loss: 0.0030842740243921677, R^2: 0.05251669883728027


68it [02:54,  2.57s/it]


epoch: 64, validation loss: 0.003167106971765558, R^2: 0.029584288597106934
EarlyStopping counter: 1 out of 10


68it [02:21,  2.08s/it]


epoch: 65, validation loss: 0.00315715583662192, R^2: 0.029389917850494385
EarlyStopping counter: 2 out of 10


68it [01:57,  1.73s/it]


epoch: 66, validation loss: 0.0031493837169061104, R^2: 0.03202575445175171
EarlyStopping counter: 3 out of 10


68it [02:04,  1.83s/it]


epoch: 67, validation loss: 0.0031418907456099987, R^2: 0.034368276596069336
EarlyStopping counter: 4 out of 10


68it [02:26,  2.16s/it]


epoch: 68, validation loss: 0.003135247156023979, R^2: 0.036553382873535156
EarlyStopping counter: 5 out of 10


68it [02:17,  2.02s/it]


epoch: 69, validation loss: 0.0031298347438375156, R^2: 0.03843027353286743
EarlyStopping counter: 6 out of 10


68it [01:58,  1.74s/it]


epoch: 70, validation loss: 0.003124853844443957, R^2: 0.03998440504074097
EarlyStopping counter: 7 out of 10


68it [02:31,  2.22s/it]


epoch: 71, validation loss: 0.0031200635712593794, R^2: 0.04147493839263916
EarlyStopping counter: 8 out of 10


68it [02:11,  1.94s/it]


epoch: 72, validation loss: 0.0031150498272230227, R^2: 0.04297232627868652
EarlyStopping counter: 9 out of 10


68it [02:28,  2.18s/it]


epoch: 73, validation loss: 0.0031094498001039026, R^2: 0.04460328817367554
EarlyStopping counter: 10 out of 10
Early stopping triggered. Stopping training.
Best validation loss:  0.0031094498001039026


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
loss,█▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▃▃▂▂▂
r2,▁███████████████████████████████████████
step,█▇▆▅▅▄▃▂▁█▇▆▅▄▃▂▁█▇▆▅▅▄▃▂▁█▇▆▅▄▃▂▁█▇▆▅▅▄
train_loss,██▇▇▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▄▃▃▂▂

0,1
epoch,73.0
loss,0.00311
r2,0.0446
step,67.0
train_loss,0.00311
val_loss,0.00311


In [8]:
# Display the layer structure of the trained model (the cell output is its repr).
model

MyGnn(
  (pointLayer): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=6, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
  ), global_nn=Sequential(
    (0): Linear(in_features=16, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
  ))
  (output_layer): GATConv(16, 1, heads=1)
)

In [7]:
# Checkpoint file for this run, named after the run label.
model_path = '../../data/trained_models/model_' + unique_model_description + '.pth'

# Save the model state dictionary together with the keyword arguments that
# garch.MyGnn was constructed with.
torch.save({
    'state_dict': model.state_dict(),
    'config': {
        'in_channels': model.in_channels,
        'out_channels': model.out_channels,
        'hidden_size': model.hidden_size,
        'gat_layers': model.gat_layers,
        'gcn_layers': model.gcn_layers,
        # model.output_layer is the layer module itself (the model repr shows
        # "(output_layer): GATConv(16, 1, heads=1)"), so store the option
        # string that was passed to the constructor rather than pickling a
        # whole layer into the configuration.
        'output_layer': output_layer_parameter
    }
}, model_path)