# Selective Long-Range Connections in Message Passing Neural Networks
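In this notebook, we show that a message passing neural network (MPNN) can be improved on problems with a large problem radius by using a final layer in which selected nodes are connected over long distances to nodes they could otherwise not reach. Concretely, we compare a plain GCN (`SimpleGNN`) with the same architecture whose last convolution runs over a 3-hop graph (`SlrcGNN`) on the NeighborsMatch tree dataset, using 5-fold cross-validation.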

## Imports

In [1]:
#!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
#!pip install torchmetrics
#!wget https://raw.githubusercontent.com/max-seeli/selective-long-range-connection-gnn/main/slrc.py

In [2]:
#!rm -rf /tmp/neighbors_match

In [3]:
from itertools import combinations
from tqdm import tqdm

import torch
from torch import nn
from torch.nn import functional as F
# Seed PyTorch's global RNG so weight initialisation and batch shuffling are repeatable.
torch.manual_seed(0)
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

import torch_geometric
from torch_geometric import nn as gnn
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset

from torchmetrics import Accuracy

from sklearn.model_selection import KFold

import slrc
import neighbors_match

import pickle

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


## Data

In [4]:
def data_preprocessing(data, k=3):
    """Pre-transform applied to every NeighborsMatch graph.

    Casts the node features to float, attaches the k-hop edge set that
    ``SlrcGNN`` uses in its last layer, and moves the graph to ``device``.

    Args:
        data: a PyG ``Data`` object (as produced by ``from_networkx``).
        k: hop radius passed to ``slrc.create_k_hop_graph`` (default 3, the
            value used throughout this notebook). To use another radius as a
            pre_transform, wrap it, e.g. ``functools.partial(data_preprocessing, k=4)``.

    Returns:
        ``data`` after modification.

    NOTE(review): as a pre_transform, the result is saved to disk by the
    dataset's ``process``. The device chosen here is therefore baked into the
    cache, so downstream code should not rely on it.
    """
    data.x = data.x.float()

    # The attribute name must contain 'index'. When PyG mini-batches graphs,
    # such attributes are offset by the node count of the preceding graphs and
    # concatenated along their last dimension, exactly like ``edge_index``.
    data.k_hop_edge_index = slrc.create_k_hop_graph(data, k=k).edge_index

    data.to(device)
    return data

class NeighborsMatch(InMemoryDataset):
    """In-memory PyG dataset of NeighborsMatch tree graphs.

    Args:
        root: dataset root directory (PyG keeps ``raw/`` and ``processed/`` below it).
        d: depth parameter forwarded to the ``neighbors_match`` generator.
        n: maximum number of graphs to keep; clipped to ``neighbors_match.num_graphs(d)``.
        transform, pre_transform, pre_filter: the standard ``InMemoryDataset`` hooks.
    """

    def __init__(self, root, d, n, transform=None, pre_transform=None, pre_filter=None):
        # These must be set before super().__init__, which resolves the file
        # names below and may call download()/process().
        self.d = d
        self.n = min(n, neighbors_match.num_graphs(d))

        super().__init__(root, transform, pre_transform, pre_filter)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Encode (d, n) in the name so a cache built with other settings is
        # regenerated instead of being silently reused.
        return [f'neighbors_match_d{self.d}_n{self.n}.pkl']

    @property
    def processed_file_names(self):
        # Same (d, n) encoding as the raw file, for the same reason.
        return [f'neighbors_match_d{self.d}_n{self.n}.pt']

    def download(self):
        """Generate the first ``self.n`` graphs for depth ``self.d`` and pickle them to ``raw_paths[0]``."""
        from itertools import islice

        neighbors_match_graphs = neighbors_match.create_all_tree_neighbors_match_graph(self.d)
        # islice takes exactly n graphs and draws nothing further from the
        # (possibly lazy) generator; n <= 0 yields an empty selection.
        selected_graphs = list(islice(neighbors_match_graphs, max(self.n, 0)))

        with open(self.raw_paths[0], 'wb') as f:
            pickle.dump(selected_graphs, f)

    def process(self):
        """Convert the pickled networkx graphs to PyG ``Data``, filter/transform them, and save the collated result."""
        # SECURITY NOTE: pickle.load can execute arbitrary code. The raw file is
        # written by download(), but a shared root such as /tmp may be writable
        # by other users -- only point `root` at a trusted directory.
        with open(self.raw_paths[0], 'rb') as f:
            nx_graphs = pickle.load(f)

        # Node attribute 'x' becomes the node feature matrix data.x.
        data_list = [
            torch_geometric.utils.from_networkx(G, group_node_attrs=['x'])
            for G in nx_graphs
        ]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.save(data_list, self.processed_paths[0])

# Depth-3 NeighborsMatch trees, capped at 20k graphs, with the k-hop edges precomputed.
dataset = NeighborsMatch(
    root='/tmp/neighbors_match',
    d=3,
    n=20000,
    pre_transform=data_preprocessing,
)

Generating Trees:   6%|▌         | 19999/322560 [00:41<10:30, 479.66it/s] 
Processing...
  data[key] = torch.tensor(value)
Done!


## Model

In [5]:
class SimpleGNN(nn.Module):
    """Graph-level GCN: a stack of GCN layers, global mean pooling, then an MLP head.

    Every GCN layer except the last propagates over ``data.edge_index``. The
    last one propagates over ``get_last_layer_edge_index(data)``, which
    subclasses override to change the final layer's connectivity.
    """

    def __init__(self, num_node_features, hidden_channels, transfer_size, dense_layers, num_classes, dropout):
        """Build the layers.

        Args:
            num_node_features: input feature width per node.
            hidden_channels: output widths of the GCN layers before the last one
                (may be empty, leaving a single GCN layer).
            transfer_size: output width of the last GCN layer, i.e. the pooled
                graph-embedding width fed to the head.
            dense_layers: widths of the hidden linear layers of the head (may be empty).
            num_classes: output width of the final linear layer.
            dropout: dropout probability applied (in training mode) after every
                GCN layer but the last and after every dense layer but the last.
        """
        super().__init__()

        # Consecutive pairs of widths give each layer's (in, out) sizes. This
        # yields the same modules, in the same order, as building them one by
        # one, and also handles empty width lists.
        conv_dims = [num_node_features, *hidden_channels, transfer_size]
        self.convs = nn.ModuleList([
            GCNConv(in_dim, out_dim)
            for in_dim, out_dim in zip(conv_dims[:-1], conv_dims[1:])
        ])

        dense_dims = [transfer_size, *dense_layers, num_classes]
        self.dense = nn.ModuleList([
            nn.Linear(in_dim, out_dim)
            for in_dim, out_dim in zip(dense_dims[:-1], dense_dims[1:])
        ])

        self.dropout = dropout

    def forward(self, data):
        """Compute per-graph outputs for a PyG mini-batch.

        Args:
            data: batch exposing ``x``, ``edge_index`` and ``batch`` (plus whatever
                ``get_last_layer_edge_index`` reads).

        Returns:
            One row per graph with ``num_classes`` columns; ``squeeze(-1)`` only
            drops the last dimension when ``num_classes == 1``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # The final GCN layer may use a different edge set (see SlrcGNN).
        x = self.convs[-1](x, self.get_last_layer_edge_index(data))
        x = F.relu(x)
        x = gnn.global_mean_pool(x, batch)

        for dense in self.dense[:-1]:
            x = dense(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.dense[-1](x)

        return x.squeeze(-1)

    def get_last_layer_edge_index(self, data):
        """Edge index for the last GCN layer; the baseline reuses the graph's own edges."""
        return data.edge_index

In [6]:
class SlrcGNN(SimpleGNN):
    """``SimpleGNN`` with selective long-range connections in its last layer.

    Only the final GCN layer differs: it propagates over
    ``data.k_hop_edge_index`` (attached by ``data_preprocessing``) instead of
    the graph's own ``edge_index``.
    """

    def get_last_layer_edge_index(self, data):
        """Return the precomputed k-hop edge index stored on ``data``."""
        long_range_edges = data.k_hop_edge_index
        return long_range_edges

In [7]:
def train(model, train_loader, optimizer, criterion, epoch):
    """Run one training epoch.

    Args:
        model: network called as ``model(batch)``; switched to training mode.
        train_loader: iterable of mini-batches exposing ``.y`` and ``.to(device)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss evaluated as ``criterion(model(batch), batch.y.float())``.
        epoch: current epoch number; unused, kept so existing calls keep working.

    Returns:
        Mean loss over the epoch's batches (``nan`` if the loader is empty).
        Callers that ignore the return value are unaffected.
    """
    model.train()
    # Move each batch to wherever the model lives, instead of relying on the
    # device that was baked into the cached dataset when it was processed.
    model_device = next(model.parameters()).device

    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        batch = batch.to(model_device)
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()

        # Accumulate detached so no graph is kept and no per-batch sync is forced.
        total_loss = total_loss + loss.detach()
        num_batches += 1

    return float(total_loss) / num_batches if num_batches else float('nan')

def test(model, test_loader, num_classes=None):
    """Compute the multiclass accuracy of ``model`` on ``test_loader``.

    The true class of each graph is ``argmax(batch.y, dim=1)`` and the
    prediction is ``argmax(model(batch), dim=1)``. This assumes ``batch.y`` is a
    (num_graphs, num_classes) one-hot/score matrix -- TODO confirm for other datasets.

    Args:
        model: network called as ``model(batch)``; switched to eval mode.
        test_loader: iterable of mini-batches exposing ``.y`` and ``.to(device)``.
        num_classes: number of classes for the metric. If None (default), it is
            taken from the target width of the first batch instead of being
            hard-coded.

    Returns:
        0-d tensor with the accuracy computed by torchmetrics ``Accuracy``.

    Raises:
        ValueError: if ``test_loader`` yields no batches (accuracy is undefined).
    """
    model.eval()
    # Keep batches and the metric on the model's device rather than relying on
    # the device baked into the cached dataset.
    model_device = next(model.parameters()).device

    acc = None
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(model_device)
            out = model(batch)

            if acc is None:
                n_classes = num_classes if num_classes is not None else batch.y.size(1)
                acc = Accuracy(task='multiclass', num_classes=n_classes).to(model_device)

            true = torch.argmax(batch.y, dim=1)
            pred = torch.argmax(out, dim=1)
            # update() only accumulates state; the per-batch value is not needed.
            acc.update(pred, true)

    if acc is None:
        raise ValueError('test_loader yielded no batches; accuracy is undefined')
    return acc.compute()


def eval(dataset, isSelective, params, k=5, epochs=200, lr=0.01, weight_decay=5e-4,
         batch_size=32, random_state=42):
    """K-fold cross-validate a ``SlrcGNN`` or ``SimpleGNN`` on ``dataset``.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook namespace.

    For every fold a fresh model is trained for ``epochs`` epochs with Adam and
    ``MSELoss`` on the targets. The fold's score is the highest validation
    accuracy reached during any epoch.

    Args:
        dataset: indexable dataset of PyG graphs.
        isSelective: if True train ``SlrcGNN`` (last layer on the k-hop edges),
            otherwise the ``SimpleGNN`` baseline.
        params: keyword arguments for the model constructor.
        k: number of cross-validation folds.
        epochs: training epochs per fold.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: mini-batch size for both loaders.
        random_state: seed for the fold shuffling.

    Returns:
        ``(best, mean, per_fold)``: ``per_fold`` holds each fold's best
        validation accuracy as a float; ``best`` is their maximum and ``mean``
        their average.
    """
    kfold = KFold(n_splits=k, shuffle=True, random_state=random_state)
    model_cls = SlrcGNN if isSelective else SimpleGNN

    fold_acc = []
    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'@ Fold {fold}')
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        fold_model = model_cls(**params)
        fold_model.to(device)

        optimizer = torch.optim.Adam(fold_model.parameters(), lr=lr, weight_decay=weight_decay)
        # Regression loss on the (one-hot) targets; accuracy is measured via argmax.
        criterion = nn.MSELoss()

        best_acc = float('-inf')
        for epoch in range(1, epochs + 1):
            train(fold_model, train_loader, optimizer, criterion, epoch)

            with torch.no_grad():
                # Plain float so results print cleanly and carry no device.
                test_acc = float(test(fold_model, val_loader))
            best_acc = max(best_acc, test_acc)
            print(f'Epoch: {epoch:03d}, Test ACC: {test_acc:.4f}')

            torch.cuda.empty_cache()

        fold_acc.append(best_acc)
        print()

    # Callers bind the first value as the "Best" fold score, hence max (not min).
    return max(fold_acc), sum(fold_acc) / len(fold_acc), fold_acc

In [8]:
# Constructor arguments shared by the SLRC model and the baseline.
params = dict(
    num_node_features=dataset.num_node_features,
    hidden_channels=[64, 64],  # two GCN layers before the final one
    transfer_size=64,
    dense_layers=[64],
    num_classes=dataset.num_classes,
    dropout=0.0,
)

In [9]:
# Cross-validate the model whose last layer uses the k-hop edges.
best_slrc, avg_slrc, all_slrc = eval(dataset, isSelective=True, params=params)
print(64 * '-')
print(f'Best: {best_slrc}, Avg: {avg_slrc}, Per fold: {all_slrc}')

@ Fold 0


Epoch: 001, Test ACC: 0.1280
Epoch: 002, Test ACC: 0.1280
Epoch: 003, Test ACC: 0.1233
Epoch: 004, Test ACC: 0.1233
Epoch: 005, Test ACC: 0.1233
Epoch: 006, Test ACC: 0.1255
Epoch: 007, Test ACC: 0.1318
Epoch: 008, Test ACC: 0.1280
Epoch: 009, Test ACC: 0.1287


In [None]:
# Cross-validate the baseline that only uses the original graph edges.
best_simple, avg_simple, all_simple = eval(dataset, isSelective=False, params=params)
print(64 * '-')
print(f'Best: {best_simple}, Avg: {avg_simple}, Per fold: {all_simple}')

@ Fold 0
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size([32, 8])
torch.Size([32, 8]) torch.Size(

KeyboardInterrupt: 