In [69]:
import math
import random
import pickle

import numpy as np
import pandas as pd
import geopandas as gpd
import tqdm
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, Subset
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import LineGraph
from shapely.geometry import LineString
import gnn_io as gio

# Abstract

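Here we train and validate a single-layer `PointNetConv` graph neural network (`GnnWithPos`) that receives the node positions in addition to the node features, and track the run with Weights & Biases.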

In [70]:
# --- Experiment configuration ---------------------------------------------
# Weights & Biases project under which the runs of this notebook are stored.
project_name = "with_pos_features"

# Optimisation hyperparameters; they are copied into the W&B run config later.
num_epochs, batch_size = 40, 20
lr = 1e-3

# Train/validation split ratio forwarded to gio.create_dataloader.
train_ratio = 0.8

# Authenticate with Weights & Biases before a run is started.
wandb.login()

True

## 1. Load data and create the dataset

In [71]:
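# NOTE: Everything in this cell is disabled reference code (SAModule / GlobalSAModule / Net,
# built on PointNetConv with fps + radius sampling). It is not used by the model trained below.
# from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius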

# class SAModule(torch.nn.Module):
#     def __init__(self, ratio, r, nn):
#         super().__init__()
#         self.ratio = ratio
#         self.r = r
#         self.conv = PointNetConv(nn, add_self_loops=False)

#     def forward(self, x, pos, batch):
#         idx = fps(pos, batch, ratio=self.ratio)
#         row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64)
#         edge_index = torch.stack([col, row], dim=0)
#         x_dst = None if x is None else x[idx]
#         x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
#         pos, batch = pos[idx], batch[idx]
#         return x, pos, batch

# class GlobalSAModule(torch.nn.Module):
#     def __init__(self, nn):
#         super().__init__()
#         self.nn = nn

#     def forward(self, x, pos, batch):
#         x = self.nn(torch.cat([x, pos], dim=1))
#         x = global_max_pool(x, batch)
#         return x

# class Net(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.sa_module = SAModule(0.5, 0.2, MLP([3, 64, 128]))
#         self.global_sa_module = GlobalSAModule(MLP([128 + 3, 256, 512]))
#         self.mlp = MLP([512, 256, 1], dropout=0.5, norm=None)

#     def forward(self, data):
#         x, pos, batch = data.x, data.pos, data.batch
#         x, pos, batch = self.sa_module(x, pos, batch)
#         x = self.global_sa_module(x, pos, batch)
#         return self.mlp(x)

In [72]:
import torch
import torch.nn.functional as F
from torch_cluster import knn_graph
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import PointNetConv

class GnnWithPos(torch.nn.Module):
    """Single-layer PointNetConv network that combines node features and positions.

    Args:
        in_channels: Input width of the first linear layer of the local MLP.
            PointNetConv hands the local MLP the neighbour features
            concatenated with the relative positions, so this value has to be
            ``feature_dim + pos_dim`` (the notebook uses 1 + 2 = 3).
        out_channels: Width of the per-node output of the global MLP.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Fixed seed so the weight initialisation is repeatable. Be aware that
        # this reseeds torch's global RNG every time a model is constructed.
        torch.manual_seed(12345)

        # The local MLP is built before the global one, as in the layer's
        # (local_nn, global_nn) argument order.
        self.conv1 = PointNetConv(
            local_nn=self._build_local_mlp(in_channels),
            global_nn=self._build_global_mlp(out_channels),
        )

        # A second PointNetConv layer (conv2, with ReLU and dropout between the
        # two layers) was experimented with here and is currently disabled.

    @staticmethod
    def _build_local_mlp(in_channels):
        """Return the MLP applied to every (neighbour feature, relative position) pair."""
        return nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )

    @staticmethod
    def _build_global_mlp(out_channels):
        """Return the MLP applied to the aggregated 64-dimensional node embedding."""
        return nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, out_channels),
        )

    def forward(self, x, pos, edge_index):
        """Apply the PointNetConv layer and return its per-node output unchanged."""
        return self.conv1(x=x, pos=pos, edge_index=edge_index)

In [73]:
# The file stores one plain dict per graph instead of Data objects.
# NOTE: torch.load unpickles the file, so only load files from a trusted source.
data_dict_list = torch.load('../results/dataset_1pm_0-1382.pt')

# Rebuild a torch_geometric Data object from every stored dict.
datalist = [
    Data(**{key: entry[key] for key in ('x', 'edge_index', 'pos', 'y')})
    for entry in data_dict_list
]

# Wrap the graphs in the project's dataset class ...
dataset = gio.MyGeometricDataset(datalist)

# ... and normalise it; the result is what the dataloaders are built from.
dataset_normalized = gio.normalize_dataset(dataset)

In [74]:
# Inspect the first graph: besides x / pos / y it should now carry the
# normalized_x / normalized_pos / normalized_y attributes used for training.
dataset_normalized[0]

Data(x=[31216, 1], edge_index=[2, 59135], y=[31216, 1], pos=[31216, 2], normalized_x=[31216, 1], normalized_pos=[31216, 2], normalized_y=[31216, 1])

## 2. Load model and loss function

In [75]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Start the W&B run and record the hyperparameters of this experiment.
wandb.init(
    project=project_name,
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "lr": lr,
        # 'early_stopping_patience': 10,
        # "dropout": 0.15,
    },
)
# From here on hyperparameters are read from the run config, so that the
# values logged to W&B are exactly the values that are used.
config = wandb.config

# in_channels = 3: one node feature plus two position coordinates
# (x is [N, 1] and pos is [N, 2] in this dataset); out_channels = 1 target.
model = GnnWithPos(3, 1).to(device)

# Define loss and optimizer.
# The learning rate is taken from `config` (not the bare `lr` global), like
# epochs and batch_size, so a config override cannot diverge from what is logged.
# Alternative with L2 regularisation:
# optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
loss_fct = torch.nn.MSELoss()

## 3. Split into train and validation set

In [76]:
# Arguments shared by both loaders; only `is_train` differs between them.
loader_kwargs = dict(
    dataset=dataset_normalized,
    batch_size=config.batch_size,
    train_ratio=train_ratio,
)
train_dl = gio.create_dataloader(is_train=True, **loader_kwargs)
valid_dl = gio.create_dataloader(is_train=False, **loader_kwargs)
# NOTE(review): the recorded output shows 1100 training + 260 validation
# samples out of 1382, i.e. 22 samples end up in neither split. It looks like
# each subset is trimmed to a multiple of batch_size - confirm in gio.

# Number of batches needed to go through the training subset once.
n_train_samples = len(train_dl.dataset)
n_steps_per_epoch = math.ceil(n_train_samples / config.batch_size)
print(n_steps_per_epoch)

Total dataset length: 1382
Training subset length: 1100
Total dataset length: 1382
Validation subset length: 260
55


In [77]:
# Check that samples reached through the training loader still expose the
# normalized_* attributes next to the raw ones.
train_dl.dataset[0]

Data(x=[31216, 1], edge_index=[2, 59135], y=[31216, 1], pos=[31216, 2], normalized_x=[31216, 1], normalized_pos=[31216, 2], normalized_y=[31216, 1])

## 4. Train the model

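We first check that the model can learn from a single batch: every epoch performs one optimisation step on one batch drawn from `train_dl`, followed by a full pass over the validation set.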

In [78]:
def validate_model(model, valid_dl, loss_func, device):
    """Return the mean per-batch loss of ``model`` over ``valid_dl``.

    The model is switched to eval mode (and left in it) and the whole pass
    runs under ``torch.inference_mode()``, so no gradients are recorded.

    Args:
        model: Callable as ``model(x, pos, edge_index)``.
        valid_dl: Iterable of batches exposing ``normalized_x``,
            ``normalized_pos``, ``edge_index`` and ``normalized_y``.
        loss_func: Callable ``loss_func(predicted, targets)`` returning a
            scalar tensor.
        device: Device every input and target tensor is moved to before the
            forward pass; it must be the device the model lives on.

    Returns:
        The sum of the batch losses divided by the number of batches, or 0
        when ``valid_dl`` yields no batch.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.inference_mode():
        for data in valid_dl:
            # Move *everything* the forward pass and the loss touch onto the
            # model's device. Passing the un-moved batch attributes to the
            # model fails as soon as the model is on a GPU.
            x = data.normalized_x.to(device)
            pos = data.normalized_pos.to(device)
            edge_index = data.edge_index.to(device)
            targets = data.normalized_y.to(device)

            predicted = model(x, pos, edge_index)
            total_loss += loss_func(predicted, targets).item()
            num_batches += 1
    return total_loss / num_batches if num_batches > 0 else 0


# Stop once the validation loss has not improved for 10 consecutive epochs.
early_stopping = gio.EarlyStopping(patience=10, verbose=True)

for epoch in range(config.epochs):
    model.train()

    # One optimisation step per epoch on a single batch.
    # NOTE(review): a new iterator is created every epoch, so if `train_dl`
    # shuffles this is a *different* batch each time. Hoist this line above
    # the loop if the goal is to fit one fixed batch.
    data = next(iter(train_dl))
    # Full-epoch variant (currently disabled):
    # for idx, data in enumerate(train_dl):

    # Move every tensor used by the forward pass and the loss to the model's
    # device; feeding the un-moved batch attributes breaks on a GPU.
    input_node_features = data.normalized_x.to(device)
    node_positions = data.normalized_pos.to(device)
    edge_index = data.edge_index.to(device)
    targets = data.normalized_y.to(device)

    predicted = model(input_node_features, node_positions, edge_index)
    train_loss = loss_fct(predicted, targets)

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    # wandb.log({"train_loss": train_loss.item(), "epoch": epoch, "step": idx})
    print(f"epoch: {epoch}, train loss: {train_loss.item()}")

    # Full pass over the validation loader after every step.
    val_loss = validate_model(model, valid_dl, loss_fct, device)
    wandb.log({"val_loss": val_loss})
    print(f"epoch: {epoch}, val_loss: {val_loss}")

    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training.")
        break

# Store the last validation loss as the run summary and close the W&B run.
wandb.summary["val_loss"] = val_loss
wandb.finish()

epoch: 0, train loss: 0.006164898630231619
epoch: 0, val_loss: 0.005560705008415075
epoch: 1, train loss: 0.005566890817135572
epoch: 1, val_loss: 0.005820847044770534
EarlyStopping counter: 1 out of 10
epoch: 2, train loss: 0.0058210864663124084
epoch: 2, val_loss: 0.005818090903071256
EarlyStopping counter: 2 out of 10
epoch: 3, train loss: 0.005862414371222258
epoch: 3, val_loss: 0.00564519025815221
EarlyStopping counter: 3 out of 10
epoch: 4, train loss: 0.0057152011431753635
epoch: 4, val_loss: 0.005556888555964598
epoch: 5, train loss: 0.005566077772527933
epoch: 5, val_loss: 0.005594104850808015
EarlyStopping counter: 1 out of 10
epoch: 6, train loss: 0.005610513035207987
epoch: 6, val_loss: 0.0056577338837087154
EarlyStopping counter: 2 out of 10
epoch: 7, train loss: 0.005657061468809843
epoch: 7, val_loss: 0.0056619365842869645
EarlyStopping counter: 3 out of 10
epoch: 8, train loss: 0.005642305128276348
epoch: 8, val_loss: 0.005612973840190814
EarlyStopping counter: 4 out of

VBox(children=(Label(value='0.019 MB of 0.019 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
val_loss,▁██▃▁▂▄▄▂▁▁▂▂▂▁

0,1
val_loss,0.00557


In [79]:
# `targets` is left over from the last training step: one row per node of the
# whole batch (20 graphs x 31216 nodes = 624320 rows in the recorded run).
targets.shape

torch.Size([624320, 1])

In [80]:
# Display the module structure (local and global MLP of the PointNetConv layer).
model

GnnWithPos(
  (conv1): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ), global_nn=Sequential(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=1, bias=True)
  ))
)