<a href="https://colab.research.google.com/github/dp457/Graph-Neural-Network/blob/main/Creating_the_Message_Passing_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The convolution operator can be expressed as a neighbourhood aggregation or message passing scheme. With $\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F$ denoting the features of node $i$ in layer $(k-1)$ and $\mathbf{e}_{j,i} \in \mathbb{R}^D$ denoting the (optional) edge features from node $j$ to node $i$, a message passing GNN can be described as

$\mathbf{x}_i^{(k)}
= \gamma^{(k)} \left(
    \mathbf{x}_i^{(k-1)},
    \bigoplus_{j \in \mathcal{N}(i)}
        \phi^{(k)} \left(
            \mathbf{x}_i^{(k-1)},
            \mathbf{x}_j^{(k-1)},
            \mathbf{e}_{j,i}
        \right)
\right),$

Here $\bigoplus$ denotes a differentiable, permutation-invariant function, e.g. sum, mean or max, and $\gamma$ and $\phi$ denote differentiable functions such as MLPs.
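
To make the scheme concrete, here is a minimal plain-PyTorch sketch of one message passing step with sum aggregation. The function name `message_passing_step` and the toy choices for $\phi$ and $\gamma$ are illustrative assumptions, not part of PyG, and edge features are omitted for brevity.

```python
import torch

def message_passing_step(x, edge_index, phi, gamma):
    """One message passing step with sum aggregation (illustrative sketch).

    x:          [N, F] node features
    edge_index: [2, E], edge_index[0] = source nodes j, edge_index[1] = target nodes i
    phi:        message function, phi(x_i, x_j) -> [E, F_msg]
    gamma:      update function, gamma(x, aggregated) -> [N, F_out]
    """
    src, dst = edge_index
    messages = phi(x[dst], x[src])  # one message per edge j -> i
    aggregated = torch.zeros(x.size(0), messages.size(1),
                             dtype=messages.dtype, device=messages.device)
    aggregated.index_add_(0, dst, messages)  # sum over j in N(i)
    return gamma(x, aggregated)

# Toy graph with edges 0 -> 1, 2 -> 1 and 1 -> 0.
x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 2, 1],
                           [1, 1, 0]])

# Assumed toy choices: the message is the neighbour feature,
# and the update adds the aggregated messages to x_i.
out = message_passing_step(
    x, edge_index,
    phi=lambda x_i, x_j: x_j,
    gamma=lambda x_i, agg: x_i + agg,
)
print(out)
```

Node 2 has no incoming edges in this toy graph, so its aggregated message is zero and its feature is left unchanged by the update.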


## Message Passing Base Class

*   Define the aggregation scheme to use ("add", "mean" or "max") and the flow direction of message passing ("source_to_target" or "target_to_source"). The `node_dim` attribute indicates along which axis to propagate.
*   **.propagate()** - The initial call that starts propagating messages; it internally calls `message()`, `aggregate()` and `update()`.
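
The effect of the aggregation scheme can be illustrated without PyG. The sketch below reduces the same set of per-edge messages with sum, mean and max in plain PyTorch. The helper `aggregate` is a hypothetical name that only mimics what the base class does conceptually; it is not PyG's actual implementation.

```python
import torch

def aggregate(messages, index, num_nodes, aggr="add"):
    """Reduce per-edge messages [E, F] into per-node features [N, F]."""
    # scatter_reduce expects an index tensor with the same shape as the source.
    idx = index.view(-1, 1).expand_as(messages)
    out = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype)
    reduce = {"add": "sum", "mean": "mean", "max": "amax"}[aggr]
    # include_self=False keeps the zero initialisation out of the reduction.
    return out.scatter_reduce(0, idx, messages, reduce=reduce, include_self=False)

# Three messages: two arrive at node 0, one arrives at node 1.
messages = torch.tensor([[1.0], [5.0], [2.0]])
target = torch.tensor([0, 0, 1])

agg_add = aggregate(messages, target, num_nodes=2, aggr="add")
agg_mean = aggregate(messages, target, num_nodes=2, aggr="mean")
agg_max = aggregate(messages, target, num_nodes=2, aggr="max")
print(agg_add, agg_mean, agg_max)
```

With the "source_to_target" flow, messages are collected at the nodes listed in `edge_index[1]`; with "target_to_source" they are collected at `edge_index[0]` instead.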



In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


## Implementing the GCN layer

The layer is mathematically defined as:
$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}}  \cdot \left( \mathbf{W}^\top \cdot \mathbf{x}_j^{(k-1)} \right)  + \mathbf{b}  $

This formula can be divided into the following steps:
*   Add self-loops to the adjacency matrix.
*   Linearly transform the node feature matrix.
*   Compute the normalization coefficients.
*   Normalize the node features in $\phi$.
*   Sum up the neighbouring node features.
*   Apply the final bias vector.

Steps 1-3 are typically computed before message passing takes place. Steps 4-5 can be implemented using the **MessagePassing** base class. The implementation is given as follows:

In [2]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
  def __init__(self, in_channels, out_channels):
    super().__init__(aggr='add')
    self.lin = Linear(in_channels, out_channels, bias=False)
    self.bias = Parameter(torch.empty(out_channels))

    self.reset_parameters()

  def reset_parameters(self):
    self.lin.reset_parameters()
    self.bias.data.zero_()

  def forward(self, x, edge_index):
    # x has shape [N, in_channels]
    # edge_index has shape [2,E].

    # Step 1. Add self-loops to the adjacency matrix
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

    # Step 2. Linearly transform the feature matrix
    x = self.lin(x)

    # Step 3. Compute normalization
    row, col = edge_index
    deg = degree(col, x.size(0), dtype=x.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row]*deg_inv_sqrt[col]

    # Step 4-5 : Start propagating the messages
    out = self.propagate(edge_index, x=x, norm=norm) # internally calls message(), aggregate() and update()

    # Step 6: Apply a final bias vector
    out = out + self.bias

    return out

  def message(self, x_j, norm):
    # x_j has shape [E, out_channels]

    #Step 4: Normalize the node features
    return norm.view(-1,1)*x_j
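
As a sanity check on steps 1-5, the edge-wise normalization can be compared against the dense form $\hat{\mathbf{D}}^{-1/2} (\mathbf{A} + \mathbf{I}) \hat{\mathbf{D}}^{-1/2} \mathbf{X}$ on a toy graph. The sketch below uses plain PyTorch only (no PyG) and skips the linear layer and the bias. The graph and all variable names are illustrative assumptions.

```python
import torch

# Undirected path graph 0 - 1 - 2, stored as directed edges in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
N = x.size(0)

# Step 1: add self-loops.
loops = torch.arange(N).repeat(2, 1)
ei = torch.cat([edge_index, loops], dim=1)

# Step 3: symmetric normalization 1 / sqrt(deg(i) * deg(j)).
row, col = ei
deg = torch.bincount(col, minlength=N).to(x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

# Steps 4-5: scale the messages x_j and sum them at the target nodes.
out_edges = torch.zeros_like(x).index_add_(0, col, norm.view(-1, 1) * x[row])

# Dense reference: A_hat[i, j] = 1 if there is an edge j -> i (self-loops included).
A_hat = torch.zeros(N, N)
A_hat[col, row] = 1.0
D_inv_sqrt = torch.diag(A_hat.sum(dim=1).pow(-0.5))
out_dense = D_inv_sqrt @ A_hat @ D_inv_sqrt @ x

print(torch.allclose(out_edges, out_dense))
```

Because every node has a self-loop after step 1, all degrees are at least one here, so the `inf` guard used in `GCNConv.forward` is not needed in this toy case.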


In [5]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

In [6]:
# GCN Model
import torch.nn.functional as F
from torch import nn

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x


In [7]:
#Utilities
import random
import numpy as np

def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [8]:
def accuracy(logits, y, mask):
    preds = logits.argmax(dim=-1)
    correct = int((preds[mask] == y[mask]).sum())
    total = int(mask.sum())
    return correct / max(total, 1)

In [9]:
# ----------------------------
# Training / Evaluation
# ----------------------------
def train(model, data, optimizer, weight_decay=5e-4):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    # Manual L2 penalty on all parameters (Kipf and Welling penalize only the first layer's weights)
    l2 = sum((p**2).sum() for p in model.parameters())
    loss = loss + weight_decay * l2
    loss.backward()
    optimizer.step()
    return float(loss.item())

@torch.no_grad()
def evaluate(model, data):
    model.eval()
    logits = model(data.x, data.edge_index)
    loss_val = F.cross_entropy(logits[data.val_mask], data.y[data.val_mask]).item()
    loss_test = F.cross_entropy(logits[data.test_mask], data.y[data.test_mask]).item()
    acc_train = accuracy(logits, data.y, data.train_mask)
    acc_val = accuracy(logits, data.y, data.val_mask)
    acc_test = accuracy(logits, data.y, data.test_mask)
    return {
        "val_loss": loss_val,
        "test_loss": loss_test,
        "train_acc": acc_train,
        "val_acc": acc_val,
        "test_acc": acc_test,
    }
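
The manual penalty in `train` adds `weight_decay * sum(p**2)` to the loss, so each parameter receives an extra gradient of `2 * weight_decay * p`. The following check, a sketch with made-up parameter values, confirms this with autograd.

```python
import torch

weight_decay = 5e-4
p = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)  # made-up parameter values

# Same form as the penalty in train(): weight_decay * sum of squared parameters.
penalty = weight_decay * (p ** 2).sum()
penalty.backward()

expected = 2 * weight_decay * p.detach()
print(p.grad, expected)
```

Note that the `weight_decay` argument of `torch.optim.Adam` adds `weight_decay * p` to the gradient instead, so the two approaches differ by a factor of two for the same coefficient.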

In [11]:
import argparse
from torch.optim import Adam

def main():
    parser = argparse.ArgumentParser(description="GCN with custom GCNConv on Planetoid datasets")
    parser.add_argument("--dataset", type=str, default="Cora", choices=["Cora", "CiteSeer", "PubMed"])
    parser.add_argument("--root", type=str, default="./data")
    parser.add_argument("--hidden", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience on val loss")
    args, _ = parser.parse_known_args()

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = Planetoid(root=args.root, name=args.dataset, transform=NormalizeFeatures())
    data = dataset[0].to(device)

    model = GCN(dataset.num_features, args.hidden, dataset.num_classes, dropout=args.dropout).to(device)
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=0.0)  # manual L2 already added

    best_val = float("inf")
    best_state = None
    patience = args.patience
    epochs_no_improve = 0

    for epoch in range(1, args.epochs + 1):
        loss = train(model, data, optimizer, weight_decay=args.weight_decay)
        metrics = evaluate(model, data)

        improved = metrics["val_loss"] < best_val - 1e-6
        if improved:
            best_val = metrics["val_loss"]
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epoch % 10 == 0 or epoch == 1:
            print(f"Epoch {epoch:03d} | "
                  f"train_acc={metrics['train_acc']:.4f} "
                  f"val_acc={metrics['val_acc']:.4f} "
                  f"test_acc={metrics['test_acc']:.4f} "
                  f"val_loss={metrics['val_loss']:.4f}")

        if patience and epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}. Best val_loss={best_val:.4f}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    final = evaluate(model, data)
    print("Final metrics:",
          {k: round(v, 4) for k, v in final.items()})

if __name__ == "__main__":
    main()


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Epoch 001 | train_acc=0.2071 val_acc=0.1440 test_acc=0.1600 val_loss=1.9414
Epoch 010 | train_acc=0.7786 val_acc=0.5680 test_acc=0.5820 val_loss=1.8956
Epoch 020 | train_acc=0.8714 val_acc=0.6960 test_acc=0.7170 val_loss=1.7849
Epoch 030 | train_acc=0.9357 val_acc=0.7480 test_acc=0.7860 val_loss=1.6156
Epoch 040 | train_acc=0.9500 val_acc=0.7820 test_acc=0.8140 val_loss=1.4052
Epoch 050 | train_acc=0.9643 val_acc=0.7900 test_acc=0.8160 val_loss=1.2276
Epoch 060 | train_acc=0.9643 val_acc=0.7960 test_acc=0.8160 val_loss=1.0943
Epoch 070 | train_acc=0.9786 val_acc=0.7940 test_acc=0.8150 val_loss=1.0185
Epoch 080 | train_acc=0.9786 val_acc=0.7960 test_acc=0.8100 val_loss=0.9682
Epoch 090 | train_acc=0.9857 val_acc=0.7900 test_acc=0.8150 val_loss=0.9354
Epoch 100 | train_acc=0.9857 val_acc=0.7940 test_acc=0.8100 val_loss=0.9077
Epoch 110 | train_acc=0.9857 val_acc=0.7860 test_acc=0.8100 val_loss=0.8857
Epoch 120 | train_acc=0.9857 val_acc=0.7880 test_acc=0.8000 val_loss=0.8767
Epoch 130 | 

## Implementing the Edge Convolution

The edge convolutional layer processes graphs or point clouds and is mathematically defined as

$\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\Theta} \left ( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right ) $

Here $h_{\Theta}$ denotes an MLP. In analogy to the GCN layer, we can use the **MessagePassing** base class to implement this layer, this time with "max" aggregation.

In [3]:
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)
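
To see what `message` and the "max" aggregation compute, the following plain-PyTorch sketch evaluates the formula on a toy graph. For readability it assumes $h_{\Theta}$ is the identity, which is an assumption for illustration only (the layer above uses an MLP). The output is then the element-wise maximum of $[\mathbf{x}_i, \mathbf{x}_j - \mathbf{x}_i]$ over the neighbours of each node.

```python
import torch

x = torch.tensor([[0.0], [1.0], [4.0]])
# Edges j -> i: node 0 receives from nodes 1 and 2,
# node 1 receives from node 2, node 2 receives from node 0.
edge_index = torch.tensor([[1, 2, 2, 0],
                           [0, 0, 1, 2]])
src, dst = edge_index

x_i, x_j = x[dst], x[src]
messages = torch.cat([x_i, x_j - x_i], dim=1)  # [E, 2 * in_channels]

# Element-wise max over the incoming messages of each target node.
idx = dst.view(-1, 1).expand_as(messages)
out = torch.full((x.size(0), messages.size(1)), float("-inf"))
out = out.scatter_reduce(0, idx, messages, reduce="amax")
print(out)
```

Every node in this toy graph has at least one incoming edge, so no `-inf` entries remain in the result.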

Edge convolution is a dynamic convolution: it recomputes the graph for each layer using nearest neighbours in the feature space. This can be implemented on top of `EdgeConv` with the `knn_graph` function as follows:

In [4]:
from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)
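
Conceptually, `knn_graph` connects every node to its `k` nearest neighbours in feature space. A rough plain-PyTorch equivalent for a single graph is sketched below. `knn_graph_dense` is a hypothetical helper that uses a dense $O(N^2)$ distance matrix and ignores `batch`; it is not how PyG computes the graph internally.

```python
import torch

def knn_graph_dense(x, k):
    """Return edge_index [2, N * k] with edges from each neighbour j to node i."""
    dist = torch.cdist(x, x)           # [N, N] pairwise Euclidean distances
    dist.fill_diagonal_(float("inf"))  # exclude self-loops (loop=False)
    nbrs = dist.topk(k, dim=1, largest=False).indices  # [N, k] nearest neighbours
    target = torch.arange(x.size(0), device=x.device).repeat_interleave(k)
    source = nbrs.reshape(-1)
    return torch.stack([source, target], dim=0)

# Four points on a line: {0, 1} and {5, 6} form two tight pairs.
x = torch.tensor([[0.0], [1.0], [5.0], [6.0]])
edge_index = knn_graph_dense(x, k=1)
print(edge_index)
```

Since the neighbours are recomputed from the current features on every call, stacking several such layers yields a different graph per layer, which is the dynamic behaviour described above.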