In [11]:
import math
import numpy as np
import wandb

import torch
import torch_geometric
from torch_geometric.data import Data

from gnn_architectures import GnnMultipleInputFeatures
from gnn_architectures import GnnBasic
from gnn_architectures import GnnWithPos

import gnn_io as gio
import gnn_architectures as garch

import pprint

## 1. Define model and parameters

In [12]:
wandb.login()

# --- Hyperparameters -------------------------------------------------------
# Everything a reader might want to tune lives here and is mirrored into the
# wandb run config below, so the logged config is the single source of truth.
num_epochs = 40
batch_size = 20
lr = 0.001
early_stopping_patience = 10
seed = 42
project_name = 'multiple_features'
train_ratio = 0.8

# Seed the RNGs up front so weight initialisation (and any random shuffling
# done downstream) is repeatable after "Restart Kernel -> Run All".
torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
wandb.init(
        project=project_name,
        config={
            "epochs": num_epochs,
            "batch_size": batch_size,
            "lr": lr,
            'early_stopping_patience': early_stopping_patience,
            "seed": seed,
            # "dropout": 0.15,
            })
# From here on, read hyperparameters from `config` rather than the raw
# variables above: if the run config is overridden (e.g. by a sweep), the
# optimizer and early stopping then follow the values that are actually logged.
config = wandb.config

# Pick a model
gnn_instance = GnnWithPos(in_channels=3, out_channels=1, hidden_size=32, gat_layers=2, heads=1, gcn_layers=1, output_layer='gcn', graph_layers_before=False)

model = gnn_instance.to(device)

# Define loss and optimizer

# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
loss_fct = torch.nn.MSELoss()

# Patience comes from the run config so it matches what wandb records.
early_stopping = gio.EarlyStopping(patience=config.early_stopping_patience, verbose=True)



0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██████████████
loss,▁
step,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██▁▁▂▂▂▂▃▃▃▄▄▄▄▅
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,1.0
loss,0.00707
step,28.0
train_loss,0.0061


<class 'int'>
<class 'float'>
Model initialized
GnnWithPos(
  (pointLayer): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=3, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
  ), global_nn=Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
  ))
  (graph_layers): Sequential(
    (0) - GATConv(32, 32, heads=1): x, edge_index -> x
    (1) - ReLU(inplace=True): x -> x
    (2) - GATConv(32, 32, heads=1): x, edge_index -> x
    (3) - ReLU(inplace=True): x -> x
    (4) - GCNConv(32, 32): x, edge_index -> x
    (5) - ReLU(inplace=True): x -> x
  )
  (output_layer): GCNConv(32, 1)
)
Sequential(
  (0) - GATConv(32, 32, heads=1): x, edge_index -> x
  (1) - ReLU(inplace=True): x -> x
  (2) - GATConv(32, 32, heads=1): x, edge_index -> x
  (3) - ReLU(inplace=T

## 2. Load data

In [13]:
# Read the serialized samples from disk: a list of dicts, one dict per graph.
data_dict_list = torch.load('../data/dataset_1pm_0-1382.pt')
# Alternative file that was used when `model_is_basic` was set:
#   '../data/dataset_1pm_0-1382_with_more_infos.pt'


def _dict_to_data(sample):
    """Rebuild a PyG `Data` object from one stored dict (keys: x, edge_index, pos, y)."""
    return Data(
        x=sample['x'],
        edge_index=sample['edge_index'],
        pos=sample['pos'],
        y=sample['y'],
    )


datalist = list(map(_dict_to_data, data_dict_list))

# Normalize first, then compute the baseline -- same order as before, since
# `normalize_dataset` may modify the objects in `datalist` (TODO confirm).
dataset_normalized = gio.normalize_dataset(datalist)

baseline_error = gio.compute_baseline_error(datalist)
print(f'Baseline error: {baseline_error}')

Baseline error: 0.005565311759710312


## 3. Split into train and validation set

In [14]:
# Both loaders are built from the same normalized dataset with the same batch
# size and split ratio; only `is_train` differs between them.
_loader_kwargs = dict(
    dataset=dataset_normalized,
    batch_size=config.batch_size,
    train_ratio=train_ratio,
)
train_dl = gio.create_dataloader(is_train=True, **_loader_kwargs)
valid_dl = gio.create_dataloader(is_train=False, **_loader_kwargs)

# Number of optimisation steps per epoch, derived from the training subset size.
n_train_samples = len(train_dl.dataset)
n_steps_per_epoch = math.ceil(n_train_samples / config.batch_size)
print(n_steps_per_epoch)

Total dataset length: 1382
Training subset length: 1100
Total dataset length: 1382
Validation subset length: 260
55


## 4. Train the model

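We first looked for a model that fits a single batch well. The training loop below iterates over every batch in `train_dl` and evaluates on `valid_dl` after each epoch.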

In [15]:
def train(model, config=None):
    """Train `model` on `train_dl`, validate after every epoch, and log to wandb.

    Relies on notebook-level globals defined in earlier cells: `train_dl`,
    `valid_dl`, `optimizer`, `loss_fct`, `early_stopping`, `device`, and an
    active wandb run.

    Args:
        model: The network to optimise; it is called as `model(batch)`.
        config: Object exposing an `epochs` attribute. Defaults to
            `wandb.config` of the active run when omitted.

    Side effects:
        Logs the per-step training loss and the per-epoch validation loss to
        wandb, stores the last validation loss in the run summary, and
        finishes the wandb run.
    """
    if config is None:
        # The signature advertises None as a default, so fall back to the
        # active run's config instead of failing on `None.epochs`.
        config = wandb.config

    # Stays None when no epoch runs (e.g. epochs == 0), so the summary write
    # below cannot hit an unbound variable.
    last_val_loss = None

    for epoch in range(config.epochs):
        model.train()
        for idx, data in enumerate(train_dl):
            # Move the whole batch to the model's device; the model consumes
            # `data` itself, so copying individual tensors is not enough.
            data = data.to(device)
            targets = data.normalized_y
            optimizer.zero_grad()

            # Forward pass
            # NOTE(review): assumes `predicted` and `targets` have matching
            # shapes; MSELoss would otherwise broadcast -- confirm.
            predicted = model(data)
            train_loss = loss_fct(predicted, targets)

            # Backward pass
            train_loss.backward()
            optimizer.step()

            step_loss = train_loss.item()
            wandb.log({"train_loss": step_loss, "epoch": epoch, "step": idx})
            print(f"epoch: {epoch}, step: {idx}, loss: {step_loss}")

        avg_loss = garch.validate_model_pos_features(model, valid_dl, loss_fct, device)
        # Log a plain Python float rather than a tensor-like object.
        last_val_loss = avg_loss.item()
        print(f"epoch: {epoch}, validation loss: {last_val_loss}")
        wandb.log({"loss": last_val_loss, "epoch": epoch})

        early_stopping(avg_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered. Stopping training.")
            break

    if last_val_loss is not None:
        wandb.summary["val_loss"] = last_val_loss
    wandb.finish()

In [16]:
# epoch: 0, step: 52, loss: 0.07759632170200348
# epoch: 0, step: 53, loss: 0.06423051655292511
# epoch: 0, step: 54, loss: 0.04370930418372154

In [17]:
# Start the training run with the model and the wandb run config defined in
# the earlier cells.
train(model=model, config=config)

epoch: 0, step: 0, loss: 22.222427368164062
epoch: 0, step: 1, loss: 26.298938751220703
epoch: 0, step: 2, loss: 21.839475631713867
epoch: 0, step: 3, loss: 10.804563522338867
epoch: 0, step: 4, loss: 1.8818180561065674
epoch: 0, step: 5, loss: 0.9747294783592224
epoch: 0, step: 6, loss: 3.6631290912628174
epoch: 0, step: 7, loss: 3.910353660583496
epoch: 0, step: 8, loss: 1.997797966003418
