## Temporal Graph Networks


So far, we have focused on *static* graphs, where the graph structure and the node features are *fixed* over time. In many domains, however, the graph changes over time.

Temporal graphs can be divided into two categories:
- **Static graphs with temporal signals:** The underlying graph structure does not change over time, but features and labels evolve over time.
        <center><img src="images/static_structure_dynamic_features.png" width="500"></center>
    
A typical example is traffic forecasting, where graphs are built from traffic sensor data (e.g., the PeMS dataset): each sensor is a node, and edges are road connections. The geographical distribution of sensors in PeMS is shown below:
        
<center><img src="images/pems.ppm" width="400"></center>

- **Dynamic graphs with temporal signals:** The topology of the graph (the presence of nodes and edges), features, and labels evolve over time.
<center><img src="images/dynamic_structure_dynamic_features.png" width="500"></center>
A typical example is a social network: new edges are added when people become friends, existing edges are removed when they stop being friends, and node features change as people update their attributes (e.g., a career change, assuming career is one of the node features).
<center><img src="images/Dynamic_Graphs.png" width="500"></center>
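To make the distinction concrete, here is a minimal PyTorch sketch of how the two categories could be stored. The sizes, edges, and random features are toy values chosen only for illustration, and for simplicity the node set of the dynamic graph is kept fixed. A static graph with temporal signals needs one `edge_index` plus a feature tensor of shape `(T, N, F)`, while a dynamic graph needs its own edge set and features at every time step.

```python
import torch

T, N, F = 4, 5, 2  # toy sizes: time steps, nodes, features per node

# Static graph with temporal signals: one fixed edge set (e.g., road connections
# between sensors) and one node-feature matrix per time step.
static_edge_index = torch.tensor([[0, 1, 2, 3],
                                  [1, 2, 3, 4]])   # shape (2, num_edges)
static_features = torch.randn(T, N, F)             # X_t = static_features[t]

# Dynamic graph with temporal signals: edges and features both change,
# so each time step carries its own (edge_index, features) pair.
dynamic_graph = [
    (torch.tensor([[0, 1], [1, 2]]), torch.randn(N, F)),        # t = 0
    (torch.tensor([[0, 1, 3], [1, 2, 4]]), torch.randn(N, F)),  # t = 1: edge 3->4 added
    (torch.tensor([[1, 3], [2, 4]]), torch.randn(N, F)),        # t = 2: edge 0->1 removed
]

for t, (edge_index, x) in enumerate(dynamic_graph):
    print(f"t={t}: {edge_index.size(1)} edges, features {tuple(x.shape)}")
```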

> Note:\
> Dynamic graphs can also be divided into *discrete-time* and *continuous-time* categories.

- A discrete-time dynamic graph (DTDG) is a sequence $[G^{(1)}, G^{(2)},...,G^{(\tau)}]$ of graph snapshots where each $G^{(t)} = \left(V^{(t)},A^{(t)},X^{(t)}\right)$ has vertices $V^{(t)}$, adjacency matrix $A^{(t)}$, and feature matrix $X^{(t)}$. DTDGs mainly appear in applications where data is captured at regularly spaced intervals.

<center><img src="images/DTDG.png" width="700"></center>
<center><small>Image from https://graph-neural-networks.github.io/static/file/chapter15.pdf</small></center> 

- A continuous-time dynamic graph (CTDG) is a pair $\left(G^{(t_0)},O\right)$, where $G^{(t_0)}=\left(V^{(t_0)},A^{(t_0)},X^{(t_0)}\right)$ is the initial static graph at time $t_0$ and $O$ is a sequence of temporal observations (events). Each observation is a tuple of the form *(event, event type, timestamp)*, where *event type* can be a node or edge addition, a node or edge deletion, a node feature update, etc.; *event* describes what actually happened; and *timestamp* is the time at which the event occurred:

<center><img src="images/CTDG.png" width="400"></center>
<center><small>Image from https://arxiv.org/pdf/2404.18211v1</small></center>
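The relation between the two views can be illustrated with a small plain-Python sketch: replaying a CTDG's events up to a time $t$ gives the graph at that time, and doing so at regularly spaced times yields a DTDG. The function name `snapshot_at` and the event encoding (edge tuples with `"edge_add"`/`"edge_delete"` types) are hypothetical choices for illustration; only edge events are handled.

```python
# Each observation follows the (event, event type, timestamp) form described above.
initial_edges = {(0, 1), (1, 2)}  # edges of the initial graph G^(t0)
events = [
    ((2, 3), "edge_add", 1.5),
    ((0, 1), "edge_delete", 2.2),
    ((3, 4), "edge_add", 2.7),
]


def snapshot_at(initial_edges, events, t):
    """Replay all events with timestamp <= t on top of the initial graph."""
    edges = set(initial_edges)
    for event, event_type, timestamp in sorted(events, key=lambda e: e[2]):
        if timestamp > t:
            break
        if event_type == "edge_add":
            edges.add(event)
        elif event_type == "edge_delete":
            edges.discard(event)
    return edges


# Sampling the CTDG at the regularly spaced times t = 1, 2, 3 yields a DTDG.
sample_times = (1, 2, 3)
dtdg = [snapshot_at(initial_edges, events, t) for t in sample_times]
for t, edges in zip(sample_times, dtdg):
    print(t, sorted(edges))
```

Note that this discretization loses the exact timing and ordering of events that fall within the same interval.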

We focus on DTDGs in this tutorial.

## Combining GNNs with sequence models

DTDGs are made up of several snapshots arranged in order over time, which can be treated as sequential data. Temporal patterns in DTDGs are identified by looking at the relationships between these snapshots. Recurrent Neural Networks (RNNs) are often combined with GNNs to create dynamic models for DTDGs. These combinations are generally grouped into two types: stacked architectures and integrated architectures.

- **Stacked dynamic GNNs:** The most straightforward way to model a discrete dynamic graph is to have a separate GNN handle each snapshot of the graph and feed the output of each GNN to a time-series component, such as an RNN. This is illustrated in the following figure:

<center><img src="images/stacked_DTDG.png" width="400"></center>
<center><small>Image from https://arxiv.org/pdf/2404.18211v1</small></center>

One of the most well-known approaches in this category is Waterfall Dynamic-GCN. In this architecture, a GCN is stacked with an LSTM per node. More specifically, GCNs with shared parameters first process each snapshot of the graph separately, and the output of each GCN is then fed sequentially to an LSTM. A separate LSTM is used per node, although the weights are shared across these LSTMs. The architecture is illustrated in the following figure:
<center><img src="images/waterfall.png" width="700"></center>
<center><small>Image from https://arxiv.org/pdf/2005.07496</small></center>

The figure shows a network operating on sequences of four snapshots of a graph composed of five vertices. The first GCN layer acts as four copies of a regular GCN layer, each working on one snapshot of the sequence. The output of this first layer is processed by the LSTM layer, which acts as five copies of an LSTM, each working on one node of the graph.
The final fully connected (FC) layer produces the $C$-class probability vector for each node of every snapshot in the sequence, so it can be seen as 5 x 4 copies of an FC layer.
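The weight sharing described above can be sketched in plain PyTorch. This is a simplified illustration, not the reference implementation from the paper: it assumes dense adjacency matrices, uses a single GCN layer, and the names `WaterfallDynamicGCN` and `normalize_adjacency` are our own. One GCN weight matrix is shared across snapshots, one `nn.LSTM` is shared across nodes by treating each node as a separate batch element, and one linear layer is shared across all (snapshot, node) pairs.

```python
import torch
import torch.nn as nn


def normalize_adjacency(adj):
    """GCN normalization D^{-1/2} (A + I) D^{-1/2} of a dense (N, N) adjacency."""
    adj = adj + torch.eye(adj.size(0))
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]


class WaterfallDynamicGCN(nn.Module):
    def __init__(self, in_features, gcn_hidden, lstm_hidden, num_classes):
        super().__init__()
        # One GCN weight matrix shared by every snapshot
        self.gcn_weight = nn.Linear(in_features, gcn_hidden, bias=False)
        # One LSTM shared by every node (each node is its own sequence)
        self.lstm = nn.LSTM(gcn_hidden, lstm_hidden, batch_first=True)
        # One FC layer shared by every (snapshot, node) pair
        self.fc = nn.Linear(lstm_hidden, num_classes)

    def forward(self, adjs, xs):
        # adjs: (T, N, N) adjacency per snapshot; xs: (T, N, F) node features
        a_hat = torch.stack([normalize_adjacency(a) for a in adjs])
        h = torch.relu(a_hat @ self.gcn_weight(xs))  # (T, N, gcn_hidden)
        h, _ = self.lstm(h.transpose(0, 1))          # (N, T, lstm_hidden)
        logits = self.fc(h).transpose(0, 1)          # (T, N, num_classes)
        return logits.softmax(dim=-1)                # C-class probabilities


torch.manual_seed(0)
T, N, F, C = 4, 5, 3, 2  # four snapshots of a five-vertex graph, as in the figure
adjs = (torch.rand(T, N, N) > 0.5).float()
adjs = ((adjs + adjs.transpose(1, 2)) > 0).float() * (1 - torch.eye(N))  # undirected, no self-loops
model = WaterfallDynamicGCN(F, 16, 8, C)
probs = model(adjs, torch.randn(T, N, F))
print(probs.shape)  # torch.Size([4, 5, 2])
```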

- **Integrated dynamic GNNs:** Integrated DGNNs combine GNNs and RNNs in a single layer, and thus model both the spatial and the temporal domain within that one layer.

One major breakthrough in this category is <i>**EvolveGCN**</i>. <u>EvolveGCN applies temporal neural networks such as RNNs to the *GCN parameters* themselves.</u> Note that it is the GCN parameters, not the node embeddings, that are treated as temporal: the GCN evolves over time to produce relevant temporal node embeddings. The following figure gives a high-level view of EvolveGCN's architecture for producing node embeddings for a static or dynamic graph with temporal signal:

<center><img src="images/Evolvegcn.png" width="700"></center>
<center><small>Image from Labonne, Maxime. "Hands-On Graph Neural Networks Using Python: Practical techniques and architectures for building powerful graph and deep learning apps with PyTorch". Packt Publishing Ltd, 2023.</small></center>

But how can an RNN-based model update the GCN parameters at each time step? EvolveGCN proposes two closely related architectures, EvolveGCN-H and EvolveGCN-O; we introduce only one of them, namely EvolveGCN-H. The main idea is shown below:

<center><img src="images/Evolvegcn-h.png" width="700"></center>

EvolveGCN-H uses a Gated Recurrent Unit (GRU) in place of a standard RNN. The GRU, a simplified variant of the Long Short-Term Memory (LSTM) unit, typically offers similar performance with fewer parameters. In this architecture, the hidden state of the GRU is the weight matrix of the GCN.
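To make the "fewer parameters" point concrete, here is a quick count using PyTorch's built-in `nn.LSTM` and `nn.GRU` (the sizes are arbitrary). The LSTM has four gate blocks and the GRU three, each with an input matrix, a hidden matrix, and, in PyTorch, two bias vectors, so the GRU has three quarters of the LSTM's parameters.

```python
import torch.nn as nn


def num_params(module):
    return sum(p.numel() for p in module.parameters())


input_size, hidden_size = 10, 20
lstm = nn.LSTM(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)

# Parameters of one gate block: hidden x (input + hidden) weights plus two bias vectors
per_gate = hidden_size * (input_size + hidden_size) + 2 * hidden_size
print(num_params(lstm), 4 * per_gate)  # 2560 2560
print(num_params(gru), 3 * per_gate)   # 1920 1920
```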

Let:
- $H_t^{(l)}$ denote the node embeddings at layer $l$ and time step $t$ (note that $H_t^{(0)}=X_t$, the node features at time $t$);
- $W_{t-1}^{(l)}$ be the GCN weight matrix at layer $l$ and the previous time step $t-1$.

At each time step $t$, the GRU takes the node embeddings $H^{(l)}_t$ (the output of the previous layer) as input and uses the GCN weight matrix $W^{(l)}_{t-1}$ as its hidden state. It then updates the weight matrix of layer $l$ at time $t$ as follows:

\begin{equation*}
W_t^{(l)} = GRU(H_t^{(l)}, W_{t-1}^{(l)})
\end{equation*}

The updated weight matrix is then used to compute the node embeddings of layer $l+1$:

\begin{equation*}
H^{(l+1)}_t = GCN(A_t, H_t^{(l)}, W_t^{(l)})
\end{equation*}
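The two equations can be sketched as one EvolveGCN-H layer in plain PyTorch. This is a simplified sketch with our own names (`MatrixGRU`, `EvolveGCNHLayer`), not the authors' code. One practical detail: $H_t^{(l)}$ has one row per node, while $W^{(l)}_{t-1}$ has one row per input feature, so their shapes differ. Following our reading of the EvolveGCN paper, the sketch first "summarizes" $H_t^{(l)}$ by scoring nodes with a learnable vector `p` and keeping the top-$k$ rows ($k$ = number of columns of $W$); treat this step as an assumption. The GRU is applied to matrices, with gate weights that multiply from the left.

```python
import torch
import torch.nn as nn


class MatrixGRU(nn.Module):
    """GRU cell whose input X and hidden state W are both (rows, cols) matrices."""

    def __init__(self, rows, cols):
        super().__init__()

        def param(*shape):
            return nn.Parameter(0.1 * torch.randn(*shape))

        self.W_z, self.U_z, self.B_z = param(rows, rows), param(rows, rows), param(rows, cols)
        self.W_r, self.U_r, self.B_r = param(rows, rows), param(rows, rows), param(rows, cols)
        self.W_h, self.U_h, self.B_h = param(rows, rows), param(rows, rows), param(rows, cols)

    def forward(self, x, w_prev):
        z = torch.sigmoid(self.W_z @ x + self.U_z @ w_prev + self.B_z)  # update gate
        r = torch.sigmoid(self.W_r @ x + self.U_r @ w_prev + self.B_r)  # reset gate
        w_cand = torch.tanh(self.W_h @ x + self.U_h @ (r * w_prev) + self.B_h)
        return (1 - z) * w_prev + z * w_cand                             # new weights W_t


class EvolveGCNHLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.p = nn.Parameter(torch.randn(in_features))  # scoring vector used by summarize
        self.gru = MatrixGRU(in_features, out_features)
        self.w0 = nn.Parameter(0.1 * torch.randn(in_features, out_features))  # initial W

    def summarize(self, h, k):
        """Pick the k top-scoring rows of h, scaled by tanh(score): shape (k, in_features)."""
        scores = h @ self.p / self.p.norm()
        top = torch.topk(scores, k).indices
        return h[top] * torch.tanh(scores[top]).unsqueeze(1)

    def forward(self, a_hats, hs):
        # a_hats: normalized adjacency per time step; hs: embeddings H_t^{(l)} per time step
        w, outputs = self.w0, []
        for a_hat, h in zip(a_hats, hs):
            x = self.summarize(h, w.size(1)).t()       # same shape as W: (in, out)
            w = self.gru(x, w)                         # W_t = GRU(H_t, W_{t-1})
            outputs.append(torch.relu(a_hat @ h @ w))  # H_t^{(l+1)} = GCN(A_t, H_t, W_t)
        return outputs


torch.manual_seed(0)
N, F_in, F_out, T = 6, 4, 3, 3
a = (torch.rand(N, N) > 0.5).float()
a = ((a + a.t()) > 0).float() * (1 - torch.eye(N)) + torch.eye(N)  # undirected + self-loops
d = a.sum(dim=1).pow(-0.5)
a_hat = d[:, None] * a * d[None, :]                                # D^{-1/2} (A + I) D^{-1/2}
layer = EvolveGCNHLayer(F_in, F_out)
outs = layer([a_hat] * T, [torch.randn(N, F_in) for _ in range(T)])
print([tuple(o.shape) for o in outs])  # [(6, 3), (6, 3), (6, 3)]
```

Because the evolved `w` stays in the autograd graph, the loss computed from the outputs trains the GRU parameters rather than the GCN weights directly.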





### Example: Using Temporal Graphs for Skeleton-Based Action Recognition

The skeleton joints in the UTD-MHAD dataset are stored in the following order (0-based indices):

0. head
1. shoulder_center
2. spine
3. hip_center
4. left_shoulder
5. left_elbow
6. left_wrist
7. left_hand
8. right_shoulder
9. right_elbow
10. right_wrist
11. right_hand
12. left_hip
13. left_knee
14. left_ankle
15. left_foot
16. right_hip
17. right_knee
18. right_ankle
19. right_foot
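Because the graph's edges are expressed as joint indices, building the bone list from the joint names above is less error-prone than typing indices by hand. In the sketch below, the bone connectivity (a tree along the spine, with the arms attached to `shoulder_center` and the legs to `hip_center`) is our assumption based on the joint names, not something read from the dataset files; `JOINTS`, `INDEX`, and `BONES` are illustrative names.

```python
JOINTS = ["head", "shoulder_center", "spine", "hip_center",
          "left_shoulder", "left_elbow", "left_wrist", "left_hand",
          "right_shoulder", "right_elbow", "right_wrist", "right_hand",
          "left_hip", "left_knee", "left_ankle", "left_foot",
          "right_hip", "right_knee", "right_ankle", "right_foot"]
INDEX = {name: i for i, name in enumerate(JOINTS)}

# Bones as (parent, child) joint names: the spine first, then each arm and leg.
BONES = [("head", "shoulder_center"), ("shoulder_center", "spine"), ("spine", "hip_center")]
for side in ("left", "right"):
    BONES += [("shoulder_center", f"{side}_shoulder"), (f"{side}_shoulder", f"{side}_elbow"),
              (f"{side}_elbow", f"{side}_wrist"), (f"{side}_wrist", f"{side}_hand"),
              ("hip_center", f"{side}_hip"), (f"{side}_hip", f"{side}_knee"),
              (f"{side}_knee", f"{side}_ankle"), (f"{side}_ankle", f"{side}_foot")]

skeleton_edges = [(INDEX[a], INDEX[b]) for a, b in BONES]
print(len(skeleton_edges), skeleton_edges[:4])  # 19 [(0, 1), (1, 2), (2, 3), (1, 4)]
```

A tree over 20 joints has exactly 19 edges, which is a quick sanity check that no bone is missing.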

The UTD-MHAD dataset consists of 27 different actions:

1. right arm swipe to the left
2. right arm swipe to the right
3. right hand wave
4. two hand front clap
5. right arm throw
6. cross arms in the chest
7. basketball shoot
8. right hand draw x
9. right hand draw circle (clockwise)
10. right hand draw circle (counter clockwise)
11. draw triangle
12. bowling (right hand)
13. front boxing
14. baseball swing from right
15. tennis right hand forehand swing
16. arm curl (two arms)
17. tennis serve
18. two hand push
19. right hand knock on door
20. right hand catch an object
21. right hand pick up and throw
22. jogging in place
23. walking in place
24. sit to stand
25. stand to sit
26. forward lunge (left foot forward)
27. squat (two arms stretch out)
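The training code relies on each file name starting with the action (`a<k>`) and the subject (`s<k>`). A minimal sketch of parsing such names, assuming the `a<action>_s<subject>_t<trial>_skeleton.mat` pattern (the helper name `parse_utd_mhad_name` is hypothetical), converts the action number into a 0-based class label and supports a cross-subject split:

```python
def parse_utd_mhad_name(filename):
    """'a12_s3_t2_skeleton.mat' -> (0-based action label, subject, trial)."""
    action, subject, trial = filename.split("_")[:3]
    return int(action[1:]) - 1, int(subject[1:]), int(trial[1:])


files = ["a1_s1_t1_skeleton.mat", "a27_s8_t4_skeleton.mat", "a12_s3_t2_skeleton.mat"]
parsed = [parse_utd_mhad_name(f) for f in files]
print(parsed)  # [(0, 1, 1), (26, 8, 4), (11, 3, 2)]

# Cross-subject split used in this example: even-numbered subjects train, odd ones test
train_files = [f for f, (_, subject, _) in zip(files, parsed) if subject % 2 == 0]
print(train_files)  # ['a27_s8_t4_skeleton.mat']
```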

Each skeleton sequence is stored as a 20 x 3 x num_frames array: in each frame, row $j$ holds the three spatial coordinates (x, y, z) of joint $j$.
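A short sketch of this layout: moving the frame axis to the front turns one sequence into a stack of per-frame node-feature matrices, one graph snapshot per frame. The random array is a stand-in for a loaded `d_skel` array, with an arbitrary 45 frames.

```python
import numpy as np
import torch

# Stand-in for one loaded sequence: 20 joints x 3 coordinates x 45 frames
d_skel = np.random.rand(20, 3, 45)

frames = torch.from_numpy(d_skel).float().permute(2, 0, 1)  # (num_frames, 20, 3)
print(frames.shape)  # torch.Size([45, 20, 3])
x_first = frames[0]  # node-feature matrix of the first snapshot, one row per joint
```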

```python
import os
from random import shuffle

import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
from torch_geometric.data import Data
from torch_geometric.nn import global_max_pool
from torch_geometric.utils import to_undirected


class LSTMMatrixCell(nn.Module):
    """LSTM cell whose input and states are matrices.

    x:    (input_size1, input_size2), here the (num_joints x 3) node-feature matrix
    h, c: (hidden_size, input_size2); h is used as the GCN weight matrix
    """

    def __init__(self, input_size1, input_size2, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size1 = input_size1
        self.input_size2 = input_size2

        # Input gate components
        self.W_ii = nn.Parameter(torch.empty(hidden_size, input_size1))
        self.W_hi = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.empty(hidden_size, input_size2))

        # Forget gate components
        self.W_if = nn.Parameter(torch.empty(hidden_size, input_size1))
        self.W_hf = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.empty(hidden_size, input_size2))

        # Cell gate components
        self.W_ig = nn.Parameter(torch.empty(hidden_size, input_size1))
        self.W_hg = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_g = nn.Parameter(torch.empty(hidden_size, input_size2))

        # Output gate components
        self.W_io = nn.Parameter(torch.empty(hidden_size, input_size1))
        self.W_ho = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.empty(hidden_size, input_size2))

        self.init_weights()

    def init_weights(self):
        for param in self.parameters():
            nn.init.uniform_(param, -0.1, 0.1)

    def forward(self, x, hidden):
        h_prev, c_prev = hidden
        i_t = torch.sigmoid(self.W_ii @ x + self.W_hi @ h_prev + self.b_i)  # input gate
        f_t = torch.sigmoid(self.W_if @ x + self.W_hf @ h_prev + self.b_f)  # forget gate
        g_t = torch.tanh(self.W_ig @ x + self.W_hg @ h_prev + self.b_g)     # candidate cell
        o_t = torch.sigmoid(self.W_io @ x + self.W_ho @ h_prev + self.b_o)  # output gate

        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


def normalized_adjacency(edge_index, num_nodes):
    """Dense GCN propagation matrix D^{-1/2} (A + I) D^{-1/2}."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj = adj + torch.eye(num_nodes)  # add self-loops
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]


# Note: the recurrent cell reads the node features and its hidden state is the GCN
# weight matrix, which is closer to EvolveGCN-H (with an LSTM instead of a GRU).
# The class name is kept from the original notebook.
class EvolveGCN_O(nn.Module):
    def __init__(self, in_channels1, in_channels2, hidden_channels, out_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.in_channels1 = in_channels1  # number of nodes (joints)
        self.in_channels2 = in_channels2  # node feature size (x, y, z)
        # Matrix LSTM that evolves the GCN weight matrix over time
        self.lstm = LSTMMatrixCell(in_channels1, in_channels2, hidden_channels)
        self.gcn_bias = nn.Parameter(torch.zeros(hidden_channels))
        # Classification head
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)

    def forward(self, snapshots):
        # Initial GCN weights (LSTM hidden state) and LSTM cell state
        gcn_weights = torch.zeros(self.hidden_channels, self.in_channels2)
        c_t = torch.zeros(self.hidden_channels, self.in_channels2)
        for graph in snapshots:
            x_t = graph.x  # node features at time t: (num_nodes, 3)
            a_hat = normalized_adjacency(graph.edge_index, x_t.size(0))
            # Evolve the GCN weights. They stay in the autograd graph, so the LSTM is
            # trained; wrapping them in a new nn.Parameter would detach them.
            gcn_weights, c_t = self.lstm(x_t, (gcn_weights, c_t))
            # GCN layer with the evolved weights (A_hat X W^T + b), followed by ReLU
            h_t = torch.relu(a_hat @ x_t @ gcn_weights.t() + self.gcn_bias)

        # Classify from the node embeddings of the last snapshot
        batch = torch.zeros(h_t.size(0), dtype=torch.long)  # all nodes belong to one graph
        h_t = global_max_pool(h_t, batch)                   # (1, hidden_channels)
        h_t = torch.relu(self.fc1(h_t))                     # nonlinearity between the FC layers
        return self.fc2(h_t)                                # (1, out_channels) logits


# Example training setup
model = EvolveGCN_O(in_channels1=20, in_channels2=3, hidden_channels=64, out_channels=27)  # 27 action classes in UTD-MHAD
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# File names start with "a<action>_s<subject>"; subjects s2, s4, s6, s8 are used for training
train, test = [], []
for file in os.listdir("Skeleton/"):
    if file.endswith(".mat"):
        action, subject = file.split("_")[:2]
        sample = (loadmat(os.path.join("Skeleton", file))["d_skel"], action)
        if subject in {"s2", "s4", "s6", "s8"}:
            train.append(sample)
        else:
            test.append(sample)

# Bone connections, using the joint order listed above (0 = head, ..., 19 = right_foot)
skeleton_edges = [
    (0, 1),                                  # head - shoulder_center
    (1, 2),                                  # shoulder_center - spine
    (2, 3),                                  # spine - hip_center
    (1, 4), (4, 5), (5, 6), (6, 7),          # left arm: shoulder, elbow, wrist, hand
    (1, 8), (8, 9), (9, 10), (10, 11),       # right arm: shoulder, elbow, wrist, hand
    (3, 12), (12, 13), (13, 14), (14, 15),   # left leg: hip, knee, ankle, foot
    (3, 16), (16, 17), (17, 18), (18, 19),   # right leg: hip, knee, ankle, foot
]
# Messages should flow both ways along each bone, so make the edges undirected
edge_index = to_undirected(torch.tensor(skeleton_edges, dtype=torch.long).t().contiguous())


# Each skeleton sequence is a 20 x 3 x num_frames array
def create_graphs_from_tuple(in_tuple):
    """Turn one (skeleton array, 'a<k>') pair into a list of per-frame graph snapshots."""
    frames, action_str = in_tuple
    action = int(action_str[1:]) - 1  # 'a1'..'a27' -> 0..26
    snapshots = []
    for frame_num in range(frames.shape[2]):
        node_features = torch.tensor(frames[:, :, frame_num], dtype=torch.float)  # (20, 3)
        snapshots.append(Data(x=node_features, edge_index=edge_index, y=action))
    return snapshots


train_graph_snapshots = [create_graphs_from_tuple(seq) for seq in train]
test_graph_snapshots = [create_graphs_from_tuple(seq) for seq in test]

train_loss = []
for epoch in range(100):
    model.train()
    shuffle(train_graph_snapshots)  # reshuffle the sequences every epoch
    losses = []
    for graph_snapshots in train_graph_snapshots:
        optimizer.zero_grad()
        out = model(graph_snapshots)  # logits for one action sequence
        labels = torch.tensor([graph_snapshots[0].y])
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    train_loss.append(np.mean(losses))
    print(f"Epoch {epoch + 1}, Loss: {train_loss[-1]:.4f}")

model.eval()
correct = 0
with torch.no_grad():
    for graph_snapshots in test_graph_snapshots:
        pred = torch.argmax(model(graph_snapshots), dim=1).item()  # predicted class
        correct += int(pred == graph_snapshots[0].y)
print(f"Test accuracy: {correct / len(test_graph_snapshots):.4f}")
```


