In [None]:
import json

import pandas as pd
import numpy as np
import networkx as nx

import torch
from torch_geometric.nn import GATv2Conv, global_max_pool
from torch_geometric.data import Data
from torch_geometric.loader import GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler
from matplotlib import pyplot as plt

from train_utils import *
from product_graph import *
from tqdm.notebook import tqdm
from torch_geometric.utils import to_dense_adj
from torch.nn import MSELoss
from product_graph import generate_parametric_product_graph
import networkx as nx
from torch_geometric.utils import from_networkx
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import from_scipy_sparse_matrix
from sklearn.preprocessing import StandardScaler
import os

In [None]:
# Debug aid: with CUDA_LAUNCH_BLOCKING=1 CUDA kernel launches run synchronously,
# so a failing kernel is reported at the Python line that launched it.
os.environ["CUDA_LAUNCH_BLOCKING"] = str(1)

In [None]:
class GATv3Conv(torch.nn.Module):
    """GATv2 layer whose attention coefficients are blended with given edge weights.

    The layer owns a ``GATv2Conv`` (created with ``add_self_loops=False``) and a
    learnable scalar ``beta`` initialised to 0.5. In ``forward`` the attention
    computed by the wrapped layer is replaced by
    ``(1 - beta) * attention + beta * edge_weight`` before message passing.

    Args:
        in_channels: Input feature size, forwarded to ``GATv2Conv``.
        out_channels: Per-head output feature size, forwarded to ``GATv2Conv``.
        concat: If true the head outputs are concatenated, otherwise averaged.
        heads: Number of attention heads.
    """

    def __init__(self, in_channels, out_channels, concat = True, heads=1) -> None:
        super().__init__()
        # Mixing coefficient between learned attention and the supplied edge weights.
        self.beta = torch.nn.Parameter(torch.tensor(0.5))
        self.conv = GATv2Conv(
            in_channels,
            out_channels,
            heads=heads,
            concat=concat,
            add_self_loops=False,
        )

    def forward(self, x, edge_index, edge_weights):
        """Run one attention layer with edge-weight-blended attention.

        Args:
            x: 2-D node feature tensor (one row per node).
            edge_index: Edge index tensor passed to the wrapped layer.
            edge_weights: One weight per edge; reshaped to ``(num_edges, 1)`` so
                it broadcasts over the attention heads.

        Returns:
            Node embeddings: heads concatenated along the feature axis when the
            wrapped layer has ``concat=True``, otherwise averaged over heads.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        gat = self.conv
        num_heads, head_dim = gat.heads, gat.out_channels

        if not isinstance(x, torch.Tensor):
            raise TypeError("x must be a Tensor")
        assert x.dim() == 2

        # Source / target projections, split per attention head.
        x_l = gat.lin_l(x).view(-1, num_heads, head_dim)
        x_r = x_l if gat.share_weights else gat.lin_r(x).view(-1, num_heads, head_dim)
        projected = (x_l, x_r)

        # Attention coefficients as computed by the wrapped GATv2 layer.
        attention = gat.edge_updater(edge_index, x=projected, edge_attr=None)

        # Blend the learned attention with the fixed edge weights.
        per_edge_weight = edge_weights.view(edge_weights.shape[0], 1)
        alpha = (1 - self.beta) * attention + self.beta * per_edge_weight

        out = gat.propagate(edge_index, x=projected, alpha=alpha)

        # Merge the head dimension: concatenate or average.
        out = out.view(-1, gat.heads * gat.out_channels) if gat.concat else out.mean(dim=1)

        if gat.bias is not None:
            out = out + gat.bias
        return out
    

class GATNN(torch.nn.Module):
    """Two ``GATv3Conv`` layers followed by max pooling and a linear read-out.

    Args:
        in_dim: Number of input features per node.
        hidden_size: Per-head hidden size of both attention layers.
        out_dim: Output size of the final linear layer.
        in_head: Number of heads of the first layer (outputs are concatenated).
        out_head: Number of heads of the second layer (outputs are averaged).
        p: Dropout probability applied after the first layer during training.
    """

    def __init__(self, in_dim, hidden_size, out_dim, in_head=8, out_head=1, p=0.25) -> None:
        super().__init__()
        self.hid = hidden_size
        self.in_head = in_head
        self.out_head = out_head
        self.p = p
        
        self.conv1 = GATv3Conv(in_channels=in_dim, 
                               out_channels=self.hid, 
                               heads=self.in_head)
        
        # The first layer concatenates its heads, hence hid * in_head input features.
        self.conv2 = GATv3Conv(in_channels=self.hid*self.in_head, 
                               out_channels=self.hid, 
                               heads=self.out_head, 
                               concat=False)
        
        self.lin = nn.Linear(self.hid, out_dim)

    def forward(self, x, edge_index, edge_weight):
        """Compute predictions for a graph.

        Args:
            x: Node feature tensor with ``in_dim`` columns.
            edge_index: Edge index tensor of the graph.
            edge_weight: One weight per edge, blended into the attention.

        Returns:
            Tensor with ``x.shape[0] // 4`` rows and ``out_dim`` columns.
        """
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.p, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        
        # Build the pooling assignment on the same device as the features;
        # a CPU-only index tensor breaks pooling once the model runs on a GPU.
        # NOTE(review): every node is assigned to group 0 while `size` requests
        # N // 4 groups, so only output row 0 can receive pooled node features.
        # Presumably each group should hold the 4 time copies of one spatial
        # node; the right assignment depends on the node ordering produced by
        # the product-graph builder - TODO confirm and fix.
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = global_max_pool(x, batch, size=x.shape[0]//4)

        x = self.lin(x)
        return x

In [None]:
# Load the signal matrix and the spatial adjacency matrix from disk.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which
# can execute code; keep it only if this file is trusted and really needs it.
dynamic_data = torch.tensor(np.load("data/preprocessed/dynamic_data.npy", allow_pickle=True))
S = torch.tensor(np.load("data/adjacency/coords_features.npy", allow_pickle=False))

# NOTE(review): the scaler is fitted on the full data before it is split into
# train/val/test below, so validation/test statistics leak into normalisation.
scaler = StandardScaler()
data_normalized = scaler.fit_transform(dynamic_data)

# Windowed forecasting dataset: 4 observed steps, 1 predicted step, 80/10/10 split.
data = create_forecasting_dataset(data_normalized.T,
                                      splits = [0.8, 0.1, 0.1],
                                      pred_horizen= 1,
                                      obs_window= 4,
                                      verbose = 0)

# Sparse (COO) view of the spatial adjacency. `S` is already a tensor, so it is
# used directly; wrapping it in torch.tensor() again copied it and triggered
# PyTorch's copy-construct UserWarning.
edge_index = torch.nonzero(S, as_tuple=False).t().contiguous()
edge_weight = S[edge_index[0], edge_index[1]]

criterion = MSELoss()
# NOTE(review): `device` is computed but neither the model nor the data are
# moved to it anywhere in this notebook, so everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GATNN(in_dim=1, hidden_size=32, out_dim=1)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

# Directed path over 4 time steps (entry [t + 1, t] = 1); its size matches
# obs_window above.
temporal_adj = np.array([[0, 0, 0, 0],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0]])

In [None]:

# Train for up to 20 epochs. After every epoch the model is validated, and
# training stops early once the summed validation loss has not improved for
# `patience` consecutive epochs.
train_samples = data['trn']['data'].shape[0]
val_samples = data['val']['data'].shape[0]
train_losses = []  # mean training loss per batch, one float per epoch
val_losses = []    # mean validation loss per batch, one float per epoch

patience = 5
best_val_loss = float('inf')
counter = 0   # epochs since the validation loss last improved
flag = False  # True once early stopping has written its checkpoint


def _product_graph_inputs(batch):
    """Build the model inputs for one sampled spatial sub-graph.

    Returns ``(x, y, edge_index, edge_weight)``, all floating tensors being
    float32: the flattened features as a single column, the targets as a
    column, and the edges of the product graph built from `temporal_adj` and
    the sub-graph's dense adjacency.
    """
    # max_num_nodes keeps the adjacency aligned with batch.x even when the
    # highest-numbered sampled nodes have no edges inside the sub-graph
    # (to_dense_adj would otherwise infer a smaller node count from edge_index).
    batch_adj = to_dense_adj(batch.edge_index,
                             edge_attr=batch.edge_weight,
                             max_num_nodes=batch.num_nodes).squeeze(dim=0).numpy()

    product_graph = generate_parametric_product_graph(s00 = 0, s01 = 1, s10 = 1, s11 = 1,
                                                      A_T = temporal_adj, A_N = batch_adj,
                                                      spatial_graph = None)
    product_edge_index, product_edge_weight = from_scipy_sparse_matrix(product_graph)

    # One row per (node, time step) pair, with a single feature column.
    batch_x = batch.x.reshape(batch.x.shape[0] * batch.x.shape[1], 1)
    batch_y = batch.y.unsqueeze(dim = 1)
    return batch_x.float(), batch_y.float(), product_edge_index, product_edge_weight.float()


def _save_checkpoint(path, epoch):
    """Write the current model weights and the loss history to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'best_val_loss': best_val_loss,
        'train_losses': train_losses,
        'val_losses': val_losses,
    }, path)


for epoch in range(20):
    # ---- training ----
    model.train()
    total_loss = 0.0
    train_batches = 0
    for i in range(train_samples):
        # Create a torch geometric data over each graph
        outer_batch = Data(x = torch.tensor(data['trn']['data'][i]),
                           y = torch.tensor(data['trn']['labels'][i].squeeze()),
                           edge_index = edge_index, edge_weight = edge_weight)

        train_loader = GraphSAINTNodeSampler(outer_batch, batch_size=100, num_steps=6)
        for inner_batch in train_loader:
            batch_x, batch_y, product_edge_index, product_edge_weight = _product_graph_inputs(inner_batch)

            optimizer.zero_grad()
            out = model(batch_x, product_edge_index, product_edge_weight)
            loss = criterion(out, batch_y)
            loss.backward()
            optimizer.step()

            # .item() detaches the value; summing the loss tensors themselves
            # would keep autograd history alive for the whole epoch.
            total_loss += loss.item()
            train_batches += 1

    # Counting batches avoids relying on the loader variable after the loop,
    # which is undefined when a split is empty.
    train_loss = total_loss / max(train_batches, 1)
    train_losses.append(train_loss)

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():
        for i in range(val_samples):
            # Create a torch geometric data over each graph
            outer_batch = Data(x = torch.tensor(data['val']['data'][i]),
                               y = torch.tensor(data['val']['labels'][i].squeeze()),
                               edge_index = edge_index, edge_weight = edge_weight)

            val_loader = GraphSAINTNodeSampler(outer_batch, batch_size=200, num_steps=6)
            for val_batch in val_loader:
                batch_x, batch_y, product_edge_index, product_edge_weight = _product_graph_inputs(val_batch)
                out = model(batch_x, product_edge_index, product_edge_weight)
                val_loss += criterion(out, batch_y).item()
                val_batches += 1

    mean_val_loss = val_loss / max(val_batches, 1)
    val_losses.append(mean_val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss}, Validation Loss: {mean_val_loss}')

    # ---- early stopping on the summed validation loss ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
    else:
        counter += 1
    if counter >= patience:
        # NOTE(review): this stores the weights of the current epoch, not of
        # the epoch that achieved `best_val_loss`.
        _save_checkpoint(f'model_epoch_{epoch}.pt', epoch)
        flag = True
        break

# Training ran through all epochs without early stopping: save the final model.
if not flag:
    _save_checkpoint('model_trained.pt', epoch)
    
        
        


In [None]:
# Convert the per-epoch training losses to plain Python floats for plotting.
# float() works for 0-d NumPy arrays, 0-d tensors and plain numbers alike,
# whereas calling `.detach()` raises AttributeError on NumPy values, which is
# what the training loop's `.detach().numpy()` stores in `train_losses`.
train_losses1 = [float(loss) for loss in train_losses]

In [None]:
# Training vs. validation loss, one point per epoch.
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_xlabel("Epochs")
ax.set_ylabel("loss function")
for curve, curve_label in ((train_losses1, "train loss"), (val_losses, "validation loss")):
    ax.plot(curve, label=curve_label)
ax.legend()
plt.show()

#### Inference code

In [None]:
# Evaluate the trained model on the held-out test split.
test_loss = 0.0
test_batches = 0
test_samples = data['tst']['data'].shape[0]

model.eval()
with torch.no_grad():
    for i in range(test_samples):
        # Create a torch geometric data over each graph.
        # The test split ('tst') is used here; reading 'val' reported the
        # validation loss as test loss and could index past the end of it.
        outer_batch = Data(x = torch.tensor(data['tst']['data'][i]),
                           y = torch.tensor(data['tst']['labels'][i].squeeze()),
                           edge_index = edge_index, edge_weight = edge_weight)

        test_loader = GraphSAINTNodeSampler(outer_batch, batch_size=200, num_steps=6)
        for test_batch in test_loader:
            # max_num_nodes keeps the adjacency aligned with test_batch.x even
            # when the highest-numbered sampled nodes have no edges.
            batch_adj = to_dense_adj(test_batch.edge_index,
                                     edge_attr=test_batch.edge_weight,
                                     max_num_nodes=test_batch.num_nodes).squeeze(dim = 0)
            batch_adj = batch_adj.numpy()

            product_graph = generate_parametric_product_graph(s00 = 0, s01 = 1, s10 = 1, s11 = 1,
                                                              A_T = temporal_adj, A_N = batch_adj,
                                                              spatial_graph = None)
            product_edge_index, product_edge_weight = from_scipy_sparse_matrix(product_graph)

            # One row per (node, time step) pair, with a single feature column.
            batch_x = test_batch.x.reshape(test_batch.x.shape[0] * test_batch.x.shape[1], 1)
            batch_y = test_batch.y.unsqueeze(dim = 1)

            out = model(batch_x.float(), product_edge_index, product_edge_weight.float())
            test_loss += criterion(out, batch_y.float()).item()
            test_batches += 1

print(f'Total test loss: {test_loss}')
# Average over every evaluated batch (all samples x sampler steps), matching how
# the validation loss is averaged; dividing by len(test_loader) alone only
# accounted for the batches of a single sample.
print(f'Test Loss per Batch: {test_loss / max(test_batches, 1)}')