In [13]:
##Import Required Library
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

In [14]:
#define a simple MPNN architecture
class TwoStageAttentionGNN(nn.Module):
    def __init__(self, num_nodes, input_dim, hidden_dim, output_dim):
        super(TwoStageAttentionGNN, self).__init__()

        # Adjust the number of heads and dimensions based on your requirements
        self.gat1 = GATConv(input_dim, hidden_dim, heads=4)
        self.gat2 = GATConv(hidden_dim * 4, output_dim, heads=1)  # Multiply by the number of heads in the previous layer

    def forward(self, x, edge_index):
        # First GAT layer with attention mechanism
        x = F.relu(self.gat1(x, edge_index))

        # Second GAT layer with attention mechanism
        x = F.relu(self.gat2(x, edge_index))

        return x

In [15]:
# Architecture Input and Output declaration
num_nodes = 10
input_dim = 64
hidden_dim = 32
output_dim = 1

# Sample input features and adjacency matrix (edge_index)
x = torch.rand((num_nodes, input_dim))
edge_index = torch.tensor([(0, 1, 1, 2, 2, 3), (1, 0, 2, 1, 3, 2)], dtype=torch.long)

In [16]:
# Instantiate the model
model = TwoStageAttentionGNN(num_nodes, input_dim, hidden_dim, output_dim)

In [17]:
# Forward pass
output = model(x, edge_index)
print(output)


tensor([[0.1633],
        [0.1302],
        [0.0711],
        [0.0422],
        [0.0993],
        [0.0000],
        [0.0158],
        [0.0000],
        [0.0000],
        [0.0000]], grad_fn=<ReluBackward0>)
