# Selective Long-Range Connections in Message Passing Neural Networks
In this notebook, we show that message passing neural networks (MPNNs) can be improved on problems with a large problem radius by adding a final layer in which selected nodes are connected to distant nodes that would otherwise be unreachable.

## Imports

In [None]:
from itertools import combinations
from tqdm import tqdm

import torch
from torch import nn
from torch.nn import functional as F
torch.manual_seed(0)

from torch_geometric import nn as gnn
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader

from torchmetrics import AveragePrecision, Accuracy, Precision, Recall, F1Score

from sklearn.model_selection import train_test_split

## Data

In [None]:
from torch_geometric.datasets import TUDataset, LRGBDataset

# PascalVOC-SP from the Long Range Graph Benchmark. To switch datasets,
# comment this line out and uncomment the MUTAG line below.
# NOTE(review): LRGBDataset is constructed without a `split=` argument, so only
# its default split is loaded and then re-split below — confirm this is intended.
dataset = LRGBDataset(root='/tmp/PascalVOC-SP', name='PascalVOC-SP')
# dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

# Materialise the graphs as a plain list; the float cast is applied in place
# to these list entries, which are what the split below consumes.
dataset_list = list(dataset)
for graph in dataset_list:
    # Cast node features to float32 for the GCN layers.
    graph.x = graph.x.float()

# Reproducible 80/20 split of the graphs into train and test sets.
train_dataset, test_dataset = train_test_split(
    dataset_list, test_size=0.2, random_state=42
)

In [None]:
# Quick look at the dataset as a whole and at one sample graph.
print(f"Dataset: {dataset}:")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")

# Fetch the sample graph once and reuse it for every shape printout.
sample_graph = dataset[2]
print(f"Example nodes shape: {sample_graph.x.shape}")
print(f"Example edges shape: {sample_graph.edge_index.shape}")
print(f"Example edge features shape: {sample_graph.edge_attr.shape}")
print(f"Example target shape: {sample_graph.y.shape}")

## Model

In [None]:
class SimpleGNN(nn.Module):
    """Stack of GCN layers that produces per-node class logits.

    Each hidden layer is ``GCNConv -> ReLU -> dropout``. The final ``GCNConv``
    maps directly to ``num_classes`` with no activation, so ``forward`` returns
    raw logits (as expected by ``nn.CrossEntropyLoss``).

    Args:
        num_node_features: Width of the input node features.
        hidden_channels: Sequence of hidden layer widths, one entry per hidden
            GCN layer. Must contain at least one entry.
        dense_input: Unused. Kept so existing call sites keep working; it is
            reserved for an optional dense graph-level head
            (``nn.Linear(dense_input, num_classes)``) that this node-level
            model does not build.
        num_classes: Number of output classes per node.
        dropout: Dropout probability applied after each hidden layer while
            the module is in training mode.

    Raises:
        ValueError: If ``hidden_channels`` is empty.
    """

    def __init__(self, num_node_features, hidden_channels, dense_input, num_classes, dropout):
        super().__init__()

        hidden_channels = list(hidden_channels)
        if not hidden_channels:
            raise ValueError("hidden_channels must contain at least one layer width")

        # Layer widths in order: input -> hidden[0] -> ... -> hidden[-1] -> classes.
        # Building the convs from consecutive pairs keeps the same module order
        # (and therefore the same state_dict keys) as appending them one by one.
        widths = [num_node_features, *hidden_channels, num_classes]
        self.convs = nn.ModuleList(
            GCNConv(in_channels, out_channels)
            for in_channels, out_channels in zip(widths[:-1], widths[1:])
        )

        self.dropout = dropout

    def forward(self, data):
        """Return per-node logits for ``data``.

        Args:
            data: Graph (or mini-batch of graphs) exposing ``x`` and
                ``edge_index``.

        Returns:
            Output of the last ``GCNConv``: one row of ``num_classes`` logits
            per node.
        """
        x, edge_index = data.x, data.edge_index

        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # No activation on the output layer. For a graph-level task, pool these
        # node outputs per graph (e.g. gnn.global_mean_pool(x, data.batch)).
        return self.convs[-1](x, edge_index)

In [None]:
# Baseline model: two hidden GCN layers of width 10, no dropout.
model = SimpleGNN(
    num_node_features=dataset.num_node_features,
    hidden_channels=[10, 10],
    dense_input=64,
    num_classes=dataset.num_classes,
    dropout=0.0,
)

# Only the training graphs are shuffled; evaluation order stays fixed.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Adam with L2 weight decay; cross-entropy over the model's raw logits.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``train_loader``.

    Args:
        epoch: Current epoch number. Not used by the body; it is accepted so
            the existing call signature is unchanged.

    Returns:
        Mean training loss over the epoch, weighted by the number of targets
        in each batch, or ``nan`` if the loader yields no batches. Callers
        that ignore the return value are unaffected.
    """
    model.train()

    total_loss = 0.0
    num_targets = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()

        # The criterion is built with its default reduction (a mean over the
        # batch), so re-weight by batch size to get a per-target epoch mean.
        batch_targets = batch.y.size(0)
        total_loss += loss.item() * batch_targets
        num_targets += batch_targets

    return total_loss / num_targets if num_targets else float('nan')

@torch.no_grad()
def test():
    """Evaluate ``model`` on ``test_loader`` and return classification metrics.

    Runs with gradient tracking disabled. Otherwise every forward pass would
    build an autograd graph, and because all outputs are concatenated before
    the metrics are computed, those graphs would all stay alive at once.

    Returns:
        Tuple ``(average_precision, accuracy, precision, recall, f1)``. Each
        entry is computed by a torchmetrics ``multiclass`` metric over every
        prediction in the test set.
    """
    model.eval()

    y_pred = []
    y_true = []
    for batch in test_loader:
        y_pred.append(model(batch))
        y_true.append(batch.y)

    y_pred = torch.cat(y_pred, dim=0)
    y_true = torch.cat(y_true, dim=0)

    # NOTE(review): no `average=` is passed, so the task wrappers use their
    # default averaging. Precision/recall/F1 may then coincide with accuracy;
    # confirm which averaging is intended.
    metric_kwargs = {'num_classes': dataset.num_classes, 'task': 'multiclass'}
    average_precision = AveragePrecision(**metric_kwargs)
    accuracy = Accuracy(**metric_kwargs)
    precision = Precision(**metric_kwargs)
    recall = Recall(**metric_kwargs)
    f1 = F1Score(**metric_kwargs)

    return (
        average_precision(y_pred, y_true),
        accuracy(y_pred, y_true),
        precision(y_pred, y_true),
        recall(y_pred, y_true),
        f1(y_pred, y_true),
    )
    

# Train for a fixed number of epochs, evaluating on the held-out graphs after each.
num_epochs = 200
for epoch in range(1, num_epochs + 1):
    train(epoch)
    ap, acc, prec, rec, f1 = test()
    print(f'Epoch: {epoch:03d}, AP: {ap:.4f}, Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}')

# Persist the weights from the final epoch only (no best-epoch selection).
checkpoint_path = 'model.pt'
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Number of mini-batches the test loader yields per pass (displayed as this cell's output).
len(test_loader)