In [16]:
import json

import pandas as pd
import numpy as np
import networkx as nx

import torch
from torch_geometric.nn import GATv2Conv, global_max_pool
from torch_geometric.data import Data
from torch_geometric.loader import GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler
from matplotlib import pyplot as plt

from train_utils import *
from product_graph import *
from tqdm.notebook import tqdm
from torch_geometric.utils import to_dense_adj
from torch.nn import MSELoss
from product_graph import generate_parametric_product_graph
import networkx as nx
from torch_geometric.utils import from_networkx
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import from_scipy_sparse_matrix

In [39]:
class GATv3Conv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, concat = True, heads=1) -> None:
        super().__init__()
        self.beta = torch.nn.Parameter(torch.tensor(0.5))
        self.conv = GATv2Conv(in_channels, out_channels, heads, concat, add_self_loops=False)

    def forward(self, x, edge_index, edge_weights):
        H, C = self.conv.heads, self.conv.out_channels

        if isinstance(x, torch.Tensor):
            assert x.dim() == 2
            x_l = self.conv.lin_l(x).view(-1, H, C)
            if self.conv.share_weights:
                x_r = x_l
            else:
                x_r = self.conv.lin_r(x).view(-1, H, C)
        else:
            raise TypeError("x must be a Tensor")

        assert x_l is not None
        assert x_r is not None

        # edge_updater_type: (x: PairTensor, edge_attr: OptTensor)
        alpha = self.conv.edge_updater(edge_index, x=(x_l, x_r), edge_attr=None)
        
        alpha = (1-self.beta) * alpha + self.beta * edge_weights.view(edge_weights.shape[0],1)
        # propagate_type: (x: PairTensor, alpha: Tensor)
        out = self.conv.propagate(edge_index, x=(x_l, x_r), alpha=alpha)

        if self.conv.concat:
            out = out.view(-1, self.conv.heads * self.conv.out_channels)
        else:
            out = out.mean(dim=1)

        if self.conv.bias is not None:
            out = out + self.conv.bias

        return out
    

class GATNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_size, out_dim, in_head=8, out_head=1, p=0.25) -> None:
        super().__init__()
        self.hid = hidden_size
        self.in_head = in_head
        self.out_head = out_head
        self.p = p
        
        self.conv1 = GATv3Conv(in_channels=in_dim, 
                               out_channels=self.hid, 
                               heads=self.in_head)
        
        self.conv2 = GATv3Conv(in_channels=self.hid*self.in_head, 
                               out_channels=self.hid, 
                               heads=self.out_head, 
                               concat=False)
        
        self.lin = nn.Linear(self.hid, out_dim)

    def forward(self, x, edge_index, edge_weight):
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.p, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        
        batch = torch.zeros(x.size(0), dtype=torch.long)
        x = global_max_pool(x, batch, size=x.shape[0]//4)

        x = self.lin(x)
        return x

In [28]:
dynamic_data = torch.tensor(np.load("data/preprocessed/dynamic_data.npy", allow_pickle=True))
S = torch.tensor(np.load("data/adjacency/coords_features.npy", allow_pickle=False))

calendar_df = pd.read_csv('data/preprocessed/calendar.csv')
data = Data(x = dynamic_data)
edge_index = torch.nonzero(torch.tensor(S), as_tuple=False).t().contiguous()
edge_weight = S[edge_index[0], edge_index[1]]
data.edge_index = edge_index
data.edge_weight = edge_weight

  edge_index = torch.nonzero(torch.tensor(S), as_tuple=False).t().contiguous()


In [42]:
criterion = MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GATNN(in_dim=1, hidden_size=32, out_dim=1).cpu()

loader = GraphSAINTNodeSampler (
    data,
    batch_size=200,
    num_steps=6,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

temporal_adj = np.array([[0, 0, 0, 0],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0]])

In [33]:
totalCount = 0

for param in model.parameters():
    totalCount += param.nelement()

totalCount

1451

In [43]:
train_loss = []
for epoch in range(20):
    model.train()
    total_loss = 0
    for i, batch in enumerate(loader):
        batch_loss = 0
        dataset = create_forecasting_dataset(batch.x.T,
                                        splits = None,
                                        pred_horizen= 1,
                                        obs_window= 4,
                                        verbose = 0)
        batch_sample = Data(x = torch.tensor(dataset['trn']['data']), y = torch.tensor(dataset['trn']['labels']), 
                            edge_index= batch.edge_index, edge_weight = batch.edge_weight)

        batch_adj = to_dense_adj(batch_sample.edge_index, edge_attr=batch_sample.edge_weight).squeeze(dim = 0).numpy()
        
        product_graph = generate_parametric_product_graph(s00 = 0, s01 = 1, s10 = 1, s11 = 1, A_T = temporal_adj, A_N = batch_adj, spatial_graph = None)
        product_edge_index, product_edge_weight = from_scipy_sparse_matrix(product_graph)
        
        batch_x = batch_sample.x.reshape(batch_sample.x.shape[0], batch_sample.x.shape[1]* batch_sample.x.shape[2], 1)
        batch_y = batch_sample.y.reshape(batch_sample.y.shape[0], batch_sample.y.shape[1])
        
        for i in tqdm(range(batch_sample.x.shape[0])):
            optimizer.zero_grad()
            out = model(batch_x[i].float(), product_edge_index,product_edge_weight.float())
            # print(f'out shape: {out.shape} \n y shape {batch_y[i].shape}')
            # print(out)
            loss = criterion(out, batch_y[i].float())
            batch_loss += loss
        
            loss.backward()
            optimizer.step()
            
        total_loss += batch_loss
        print(f'Epoch: {epoch} Batch Loss: {batch_loss}')
    train_loss.append(total_loss)

  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0 Batch Loss: 33260036.0


  0%|          | 0/275 [00:00<?, ?it/s]

Epoch: 0 Batch Loss: 16121438.0


  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0 Batch Loss: 16728816.0


  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0 Batch Loss: 22029212.0


  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0 Batch Loss: 25327138.0


  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0 Batch Loss: 18342316.0


  0%|          | 0/275 [00:00<?, ?it/s]

Epoch: 1 Batch Loss: 16614172.0


  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 1 Batch Loss: 16264415.0


  0%|          | 0/275 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 1 Batch Loss: 17461024.0


  0%|          | 0/275 [00:00<?, ?it/s]

Epoch: 1 Batch Loss: 18469132.0


  0%|          | 0/275 [00:00<?, ?it/s]

Epoch: 1 Batch Loss: 17200608.0


  0%|          | 0/275 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [114]:
print(out[:10])
print(batch_y[0])

tensor([[ 2.2100e-01],
        [ 1.3513e+03],
        [ 5.9633e+02],
        [ 7.5958e+02],
        [ 4.6309e+02],
        [ 1.1660e+03],
        [ 1.0380e+03],
        [ 7.7342e+02],
        [-7.0580e+01],
        [ 2.1796e+02]], grad_fn=<SliceBackward0>)
tensor([125, 300, 129,  80, 295,  60, 199,  81, 225,  85, 299,  80,  36,  99,
         50, 167, 289, 469,  80, 200, 354, 295, 228, 299, 140, 150, 125, 125,
        139, 225, 199, 249, 145, 100, 450, 300, 170, 137, 116, 150, 600, 149,
        197, 260, 200, 155, 155,  99, 139, 228, 199, 129, 155, 220, 110, 251,
        199, 160, 150, 195, 228, 231, 150, 279, 135, 189, 225, 200, 300, 265,
        299, 207, 305, 189, 186, 150, 199, 255, 130, 169, 139, 266, 225,  90,
        224, 194, 300,  60,  82,  78, 100, 200, 249, 149, 200, 457, 571, 177,
        150, 199, 275, 129, 179, 175, 242, 157, 160, 159, 299, 165, 115, 139,
        230, 499, 179, 690, 537, 150, 259, 375,  91, 259, 199, 139, 250, 185,
        239, 329, 231, 231, 125, 179, 109

In [76]:
print(out.shape[0]/ 4)

193.0


In [None]:
model.eval()
_, pred = model(x_test).max(dim=1)
correct = float(pred.eq(y_test).sum().item())
acc = correct / x_test.sum().item()
print('Accuracy: {:.4f}'.format(acc))

In [201]:
x = torch.tensor([[100], 
                  [1],
                  [2],
                  [3]]).float()

edge_index = torch.tensor([[0, 1, 1, 2, 3, 2], 
                           [1, 0, 2, 1, 2, 3]])

edge_weights = torch.tensor([[0.3],
                             [0.3],
                             [0.9],
                             [0.9],
                             [0.1],
                             [0.1]])

gat_net(x, edge_index, edge_weights)

tensor([[-5.4292],
        [-0.9369],
        [-4.6689],
        [-0.5501]], grad_fn=<AddBackward0>)