# Selective Long-Range Connections in Message Passing Neural Networks
In this notebook, we show that message passing neural networks (MPNNs) can be improved on problems with a large problem radius by using a final layer in which selected nodes are connected over long distances to nodes that would otherwise be unreachable.

## Imports

In [4]:
#!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
#!pip install torchmetrics
#!wget https://raw.githubusercontent.com/max-seeli/selective-long-range-connection-gnn/main/slrc.py

In [5]:
from itertools import combinations
from tqdm import tqdm

import torch
from torch import nn
from torch.nn import functional as F
# Fix the global torch RNG seed so repeated runs start from the same state.
torch.manual_seed(0)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

from torch_geometric import nn as gnn
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader

from torchmetrics import MeanSquaredError, MeanAbsoluteError

from sklearn.model_selection import KFold

import slrc

Device: cpu


## Data

In [6]:
from torch_geometric.datasets import ZINC

def data_preprocessing(data):
    """Prepare a single graph before it is stored in the dataset.

    Casts the node features to float, attaches the edge index of the graph
    returned by ``slrc.create_k_hop_graph(data, k=3)`` and moves the graph to
    the module-level ``device``.

    Args:
        data: Graph object exposing ``x``; it is modified in place.

    Returns:
        The same ``data`` object that was passed in.
    """
    data.x = data.x.float()

    # Note: the attribute name must contain 'index' for graph mini-batching
    # to handle it correctly.
    k_hop_graph = slrc.create_k_hop_graph(data, k=3)
    data.k_hop_edge_index = k_hop_graph.edge_index

    # The return value of `to` is discarded, so this relies on `to` moving the
    # graph in place -- NOTE(review): confirm. Since this function is used as a
    # `pre_transform`, the moved tensors presumably end up in the on-disk cache.
    data.to(device)
    return data

# Load the train / val / test splits of the ZINC subset, in that order, all
# sharing the same root directory and the same preprocessing function.
dataset, val_dataset, test_dataset = [
    ZINC(root='/tmp/ZINC', subset=True, split=split_name, pre_transform=data_preprocessing)
    for split_name in ('train', 'val', 'test')
]

Processing...
Processing train dataset:  75%|███████▌  | 7516/10000 [01:00<00:15, 165.49it/s]

KeyboardInterrupt: 

## Model

In [None]:
class SimpleGNN(nn.Module):
    """Stack of GCN layers, pooled with ``gnn.global_mean_pool``, then a dense head.

    The edge index used by the *last* convolution is obtained from
    ``get_last_layer_edge_index`` so subclasses can swap it out.

    Args:
        num_node_features: Input size of the first convolution.
        hidden_channels: Output sizes of the hidden convolutions (may be empty).
        transfer_size: Output size of the last convolution and input size of
            the dense head.
        dense_layers: Output sizes of the hidden linear layers (may be empty).
        num_classes: Output size of the final linear layer.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(self, num_node_features, hidden_channels, transfer_size, dense_layers, num_classes, dropout):
        super().__init__()

        # Consecutive entries are the (in, out) sizes of successive layers.
        # Building the layers from one size list also covers empty
        # `hidden_channels` / `dense_layers`, which used to raise IndexError.
        conv_sizes = [num_node_features, *hidden_channels, transfer_size]
        self.convs = nn.ModuleList(
            [GCNConv(n_in, n_out) for n_in, n_out in zip(conv_sizes[:-1], conv_sizes[1:])]
        )

        dense_sizes = [transfer_size, *dense_layers, num_classes]
        self.dense = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(dense_sizes[:-1], dense_sizes[1:])]
        )

        self.dropout = dropout

    def forward(self, data):
        """Compute one prediction per graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``, plus
                whatever ``get_last_layer_edge_index`` reads.

        Returns:
            Output of the final linear layer with a trailing size-1 dimension
            squeezed away.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # All but the last convolution use the regular edge index.
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # The last convolution may use a different edge set; no dropout here.
        x = self.convs[-1](x, self.get_last_layer_edge_index(data))
        x = F.relu(x)
        x = gnn.global_mean_pool(x, batch)

        for dense in self.dense[:-1]:
            x = dense(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Final layer is left without activation.
        x = self.dense[-1](x)

        return x.squeeze(-1)

    def get_last_layer_edge_index(self, data):
        """Return the edge index for the last convolution (here: ``data.edge_index``)."""
        return data.edge_index

In [None]:
class SlrcGNN(SimpleGNN):
    """``SimpleGNN`` variant whose last convolution uses ``data.k_hop_edge_index``."""

    def get_last_layer_edge_index(self, data):
        """Return the k-hop edge index stored on ``data``."""
        long_range_edge_index = data.k_hop_edge_index
        return long_range_edge_index

In [None]:
def train(model, train_loader, optimizer, criterion, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module called as ``model(batch)``; put into training mode.
        train_loader: Iterable of batches exposing ``y`` and ``to(device)``.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss called as ``criterion(prediction, target)``.
        epoch: Current epoch number. Unused; kept so existing callers keep
            working.

    Returns:
        None. The model is updated in place.
    """
    model.train()

    # Take the device from the model itself, so the loop does not depend on
    # the batches having been moved to the right device somewhere upstream.
    model_device = next(model.parameters()).device

    for batch in train_loader:
        batch = batch.to(model_device)
        optimizer.zero_grad()
        out = model(batch)
        # Targets are cast to float to match the float predictions.
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()

def test(model, test_loader):
    model.eval()

    mse = MeanSquaredError().to(device)
    mae = MeanAbsoluteError().to(device)
    for batch in test_loader:
        out = model(batch)
        mse(out, batch.y)
        mae(out, batch.y)
    return mse.compute(), mae.compute()


def eval(dataset, isSelective, params, k=5, epochs=200, batch_size=32, lr=0.01, weight_decay=5e-4):
    """Cross-validate one model variant and report its validation MAE per fold.

    Note: this notebook-level function shadows Python's built-in ``eval``.

    For every fold a fresh model is trained for ``epochs`` epochs with Adam and
    an MSE loss. After each epoch the held-out part of the fold is evaluated,
    and the lowest MAE seen over all epochs is kept as the fold's score.

    Args:
        dataset: Indexable dataset to split into folds.
        isSelective: If true, train ``SlrcGNN``; otherwise ``SimpleGNN``.
        params: Keyword arguments forwarded to the model constructor.
        k: Number of cross-validation folds.
        epochs: Training epochs per fold.
        batch_size: Batch size of the train and validation loaders.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        Tuple ``(best, mean, per_fold)`` of the per-fold best MAEs, as plain
        Python floats (``per_fold`` is a list of floats).
    """
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    model_cls = SlrcGNN if isSelective else SimpleGNN

    fold_maes = []
    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'@ Fold {fold}')
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        # A fresh model and optimizer per fold, so folds do not share weights.
        fold_model = model_cls(**params)
        fold_model.to(device)

        optimizer = torch.optim.Adam(fold_model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.MSELoss()

        best_mae = float('inf')
        for epoch in range(1, epochs + 1):
            train(fold_model, train_loader, optimizer, criterion, epoch)

            with torch.no_grad():
                val_mse, val_mae = test(fold_model, val_loader)

            # Convert to plain floats so the returned and printed results are
            # ordinary numbers rather than 0-dim tensors.
            val_mse, val_mae = float(val_mse), float(val_mae)
            best_mae = min(best_mae, val_mae)
            print(f'Epoch: {epoch:03d}, Test MSE: {val_mse:.4f}, Test MAE: {val_mae:.4f}')

            torch.cuda.empty_cache()

        fold_maes.append(best_mae)
        print()

    return min(fold_maes), sum(fold_maes) / len(fold_maes), fold_maes

In [None]:
# Constructor arguments shared by both model variants, so the comparison
# differs only in the edge set used by the last convolution.
params = {
    'num_node_features': dataset.num_node_features,  # taken from the loaded dataset
    'hidden_channels': [64, 64],                     # two hidden GCN layers
    'transfer_size': 64,                             # width handed to the dense head
    'dense_layers': [64],                            # one hidden linear layer
    'num_classes': 1,                                # single output value per graph
    'dropout': 0.0,                                  # dropout disabled
}

In [None]:
# Cross-validate the selective long-range variant (SlrcGNN).
# `eval` is the notebook's cross-validation function, not the Python built-in.
best_slrc, avg_slrc, all_slrc = eval(dataset, isSelective=True, params=params)
print('-' * 64)
print(f'Best: {best_slrc}, Avg: {avg_slrc}, Per fold: {all_slrc}')

In [None]:
# Cross-validate the baseline (SimpleGNN) with the same hyper-parameters.
# `eval` is the notebook's cross-validation function, not the Python built-in.
best_simple, avg_simple, all_simple = eval(dataset, isSelective=False, params=params)
print('-' * 64)
print(f'Best: {best_simple}, Avg: {avg_simple}, Per fold: {all_simple}')