In [1]:
import torch
import os
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from sklearn import preprocessing

import torch
from torch.nn import ModuleList, ModuleDict
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric_temporal.nn.hetero import HeteroGCLSTM
from torch_geometric_temporal.signal import temporal_signal_split

from airpollution_trf_graph_loader import AirpollutionDatasetLoader

# Hide every GPU from this process so the device selection below falls back to CPU.
# NOTE(review): this only takes effect if it runs before CUDA is first initialised;
# the `device(type='cpu')` output of the next cell suggests it does here.
os.environ.update({'CUDA_VISIBLE_DEVICES': ''})

In [2]:
n_layers = 2  # number of stacked recurrent (HeteroGCLSTM) layers in the model
SEED = 0
# Seed torch so weight initialisation is deterministic across Restart & Run All.
torch.manual_seed(SEED)
# With CUDA_VISIBLE_DEVICES='' set in the imports cell, CUDA is expected to be
# unavailable and this resolves to CPU (as the cell output shows).
device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_

device(type='cpu')

## Load dataset

In [3]:
# Build the temporal heterogeneous-graph dataset for Madrid.
loader = AirpollutionDatasetLoader('madrid')
# NOTE(review): the meaning of T is defined by the loader (window / lag length?) — confirm.
dataset = loader.get_dataset(T=6)
# Node type -> input feature width, e.g. {'trf': 4, 'ap0': 2, ...}.
feature_dim = loader.get_feature_dim()
feature_dim

{'trf': 4, 'ap0': 2, 'ap1': 5, 'ap2': 2, 'ap3': 1}

## Define GNN

In [4]:
class HeteroGNN(torch.nn.Module):
    """Stacked ``HeteroGCLSTM`` encoder with a per-node-type linear read-out.

    Args:
        in_channels_dict: node type -> input feature width (e.g. ``feature_dim``).
        out_channels: accepted for call-site compatibility but not used.
            NOTE(review): the read-out projects every node type back to its own
            input width rather than to ``out_channels`` — confirm this is intended.
        metadata: heterogeneous-graph metadata handed to every ``HeteroGCLSTM``
            layer (the notebook passes ``dataset[0].metadata()``).
        nlayers: number of stacked recurrent layers. The first layer is always
            built; ``nlayers - 1`` further layers are stacked on top.
        hidden_channels: hidden width of every recurrent layer (default 128,
            the previously hard-coded value).
    """

    def __init__(self, in_channels_dict, out_channels, metadata, nlayers=2, hidden_channels=128):
        super().__init__()
        # Read-out heads: hidden_channels -> each node type's own input width.
        # Built before the convs to keep the original parameter-init order.
        self.linears = ModuleDict({
            node_type: torch.nn.Linear(hidden_channels, width)
            for node_type, width in in_channels_dict.items()
        })
        self.n_conv_layers = nlayers

        self.convs = ModuleList()
        # Layer 0 consumes the raw per-type features ...
        self.convs.append(HeteroGCLSTM(in_channels_dict=in_channels_dict,
                                       out_channels=hidden_channels, metadata=metadata))
        # ... deeper layers consume the previous layer's hidden_channels-wide output.
        hidden_in = {node_type: hidden_channels for node_type in in_channels_dict}
        for _ in range(nlayers - 1):
            self.convs.append(HeteroGCLSTM(in_channels_dict=hidden_in,
                                           out_channels=hidden_channels, metadata=metadata))

    def forward(self, x_dict, edge_index_dict, h_dict_lst=None, c_dict_lst=None):
        """Advance every recurrent layer by one snapshot.

        Args:
            x_dict: node type -> feature tensor of the current snapshot.
            edge_index_dict: edge type -> edge index of the current snapshot.
            h_dict_lst: per-layer hidden-state dicts from the previous call
                (only the first ``n_conv_layers`` entries are read), or ``None``
                / ``None`` entries to let a layer start from its default state.
            c_dict_lst: per-layer cell-state dicts, same conventions as above.

        Returns:
            ``(new_h_lst, new_c_lst)``. ``new_h_lst`` holds each layer's raw
            hidden-state dict followed by one extra entry: the linear read-out
            per node type (the prediction, ``new_h_lst[-1]``). ``new_c_lst``
            holds each layer's cell-state dict. Both can be fed straight back in
            on the next call.
        """
        if h_dict_lst is None:
            h_dict_lst = [None] * self.n_conv_layers
        if c_dict_lst is None:
            c_dict_lst = [None] * self.n_conv_layers

        new_h_lst = []
        new_c_lst = []
        x = x_dict
        for layer in range(self.n_conv_layers):
            # Feed the previous step's state back in — without this the LSTM
            # restarts from scratch at every snapshot.
            h, c = self.convs[layer](x, edge_index_dict, h_dict_lst[layer], c_dict_lst[layer])
            # Keep the raw hidden state as recurrent state; ReLU only feeds the next layer.
            new_h_lst.append(h)
            new_c_lst.append(c)
            x = {node_type: val.relu() for node_type, val in h.items()}

        out = {node_type: self.linears[node_type](emb) for node_type, emb in x.items()}
        new_h_lst.append(out)
        return new_h_lst, new_c_lst

# Forwarded as out_channels; HeteroGNN does not currently use it.
embedding_dim = 1
model = HeteroGNN(
    in_channels_dict=feature_dim,
    out_channels=embedding_dim,
    metadata=dataset[0].metadata(),
    nlayers=n_layers,
).to(device_)

# Chronological split: the first 99% of snapshots train, the remaining 1% validate.
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.99)

## Train GNN

In [6]:
# Training parameters
n_epochs = 600
# Snapshots per optimizer step. The training loop resets the recurrent state
# after every step, so this is also the truncated-BPTT window:
# 24 hourly snapshots x 1 day.
batch_size = 24 * 1
opt = torch.optim.Adam(model.parameters(), lr=0.001)

def calculate_loss(y_hat_dict, y_dict, ignore_types=('trf',)):
    """Sum of per-node-type mean-squared errors between predictions and targets.

    Args:
        y_hat_dict: node type -> prediction tensor (e.g. the model read-out).
        y_dict: node type -> target tensor; must contain every scored type.
        ignore_types: node types left out of the loss. Defaults to ``('trf',)``,
            the type the original implementation always skipped.

    Returns:
        The summed MSE as a scalar tensor, or ``0`` when no type is scored.
    """
    loss_ = 0
    for node_type, y_hat in y_hat_dict.items():
        if node_type in ignore_types:
            continue
        # Sanitises the *predictions* only (NaN -> 0, +/-inf -> finite extremes).
        # NOTE(review): NaNs in the targets are not masked and would make the loss NaN.
        y_hat = torch.nan_to_num(y_hat)
        loss_ = loss_ + torch.mean((y_hat - y_dict[node_type]) ** 2)  # MSE
    return loss_

class EarlyStopper:
    """Signal when a monitored validation loss has stopped improving.

    A call counts as a "strike" only when the loss exceeds the best loss seen
    so far by more than ``min_delta``; a new best resets the strike count.
    Stopping is signalled once ``patience`` strikes have accumulated.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                      # strikes since the last new best
        self.min_validation_loss = np.inf     # best loss observed so far

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        # Within tolerance of the best (or not comparable): neither a strike nor a reset.
        if not (validation_loss > self.min_validation_loss + self.min_delta):
            return False
        self.counter += 1
        return self.counter >= self.patience

early_stopper = EarlyStopper(patience=5, min_delta=0.01)

for epoch in tqdm(range(n_epochs), desc='Training epochs...'):
    # ---------------- train ----------------
    model.train()
    # Clears any stale gradients, e.g. from a previously interrupted run of this cell.
    opt.zero_grad()
    # None entries let every recurrent layer start from its default state.
    h_lst = [None] * n_layers
    c_lst = [None] * n_layers
    batch_loss = 0         # graph-attached loss accumulated since the last optimizer step
    n_pending = 0          # snapshots contributing to batch_loss
    train_loss_sum = 0.0   # detached running total, for reporting only
    n_train = 0

    for train_snapshot in tqdm(train_dataset, desc='Train snapshots...', leave=False):
        h_lst, c_lst = model(train_snapshot.x_dict, train_snapshot.edge_index_dict, h_lst, c_lst)
        snap_train_loss = calculate_loss(h_lst[-1], train_snapshot.y_dict)

        batch_loss = batch_loss + snap_train_loss
        n_pending += 1
        # float() detaches the value so the epoch total does not keep every
        # snapshot's autograd graph alive for the whole epoch.
        train_loss_sum += float(snap_train_loss)
        n_train += 1

        if n_pending == batch_size:
            # The recurrent state is reset right after this step, so no later
            # backward pass reaches into this graph: retain_graph is not needed.
            (batch_loss / n_pending).backward()
            opt.step()
            opt.zero_grad()
            batch_loss, n_pending = 0, 0
            h_lst = [None] * n_layers
            c_lst = [None] * n_layers

    # Flush the trailing partial window instead of silently discarding its gradient.
    if n_pending:
        (batch_loss / n_pending).backward()
        opt.step()
        opt.zero_grad()

    train_epoch_cost = train_loss_sum / max(n_train, 1)

    # ---------------- validate ----------------
    model.eval()
    eval_loss_sum = 0.0
    n_eval = 0
    with torch.no_grad():
        eval_h_lst = [None] * n_layers
        eval_c_lst = [None] * n_layers
        for test_snapshot in tqdm(test_dataset, desc='Eval snapshots...', leave=False):
            eval_h_lst, eval_c_lst = model(test_snapshot.x_dict, test_snapshot.edge_index_dict, eval_h_lst, eval_c_lst)
            # Score the read-out against this snapshot's targets, exactly as in training.
            eval_loss_sum += float(calculate_loss(eval_h_lst[-1], test_snapshot.y_dict))
            n_eval += 1
    eval_epoch_cost = eval_loss_sum / max(n_eval, 1)

    if early_stopper.early_stop(eval_epoch_cost):
        print(f'EARLY STOP  AT epoch {epoch} - MSE (train): {train_epoch_cost}  - MSE (test): {eval_epoch_cost}')
        break
    print(f'Epoch {epoch} - MSE (train): {train_epoch_cost}  - MSE (test): {eval_epoch_cost}')

Training epochs...:   0%|          | 0/30 [00:00<?, ?it/s]

Train snapshots...: 0it [00:00, ?it/s]

KeyboardInterrupt: 