# Selective Long-Range Connections in Message Passing Neural Networks
In this notebook, we show that message passing neural networks (MPNNs) can be improved on problems with a large problem radius by adding a final layer in which selected nodes are connected over long distances to nodes they could not otherwise reach.

## Imports

In [None]:
from itertools import combinations
from tqdm import tqdm

import torch
from torch import nn
from torch.nn import functional as F
# Fix the global RNG seed so weight initialisation and shuffling are repeatable.
torch.manual_seed(0)
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _use_cuda else torch.device('cpu')
print("Device:", device)

from torch_geometric import nn as gnn
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader

from torchmetrics import MeanSquaredError

## Data

In [None]:
from torch_geometric.datasets import TUDataset, LRGBDataset, ZINC

def node_feats_to_float(data):
    """Dataset transform: cast node features to float and move the sample to ``device``.

    ``data.x`` is replaced by its ``.float()`` copy (the attribute on
    ``data`` itself is reassigned). The sample is then moved with
    ``data.to(device)`` and returned.
    """
    data.x = data.x.float()
    return data.to(device)

# All three ZINC splits share the same root, subset flag and transform;
# only the split name differs.
_zinc_common = dict(root='/tmp/ZINC', subset=True, transform=node_feats_to_float)
train_dataset = ZINC(split='train', **_zinc_common)
val_dataset = ZINC(split='val', **_zinc_common)
test_dataset = ZINC(split='test', **_zinc_common)

In [None]:
# Print the summary reported by get_summary() for each split, labelled by split.
for _split_label, _split_ds in (('Train', train_dataset), ('Val', val_dataset), ('Test', test_dataset)):
    print(f'{_split_label}: {_split_ds.get_summary()}')

## Model

In [None]:
class SimpleGNN(nn.Module):
    """Graph-level model: stacked GCN layers, mean pooling, then a linear head.

    Args:
        num_node_features: Size of each input node feature vector.
        hidden_channels: Output widths of the intermediate GCN layers, in
            order. May be empty, in which case a single GCN layer maps the
            input features straight to ``dense_input``.
        dense_input: Output width of the final GCN layer, i.e. the size of
            the pooled graph embedding fed to the linear head.
        num_classes: Number of outputs of the linear head.
        dropout: Dropout probability, applied after every GCN layer except
            the last and once more after pooling (only in training mode).
    """

    def __init__(self, num_node_features, hidden_channels, dense_input, num_classes, dropout):
        super().__init__()

        # Channel widths along the stack: input -> hidden... -> dense_input.
        # Consecutive pairs give each layer's (in, out); this also works when
        # hidden_channels is empty, where indexing hidden_channels[0] would fail.
        widths = [num_node_features, *hidden_channels, dense_input]
        self.convs = nn.ModuleList(
            [GCNConv(in_ch, out_ch) for in_ch, out_ch in zip(widths[:-1], widths[1:])]
        )

        self.dense = nn.Linear(dense_input, num_classes)

        self.dropout = dropout

    def forward(self, data):
        """Return one prediction row per graph in ``data``.

        Args:
            data: A batch object exposing ``x`` (node features),
                ``edge_index`` and ``batch`` (node-to-graph assignment).

        Returns:
            Tensor of shape ``(num_graphs, num_classes)``, with the last
            dimension squeezed away when ``num_classes == 1``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Every conv is followed by ReLU; dropout is skipped after the last
        # conv because it is applied once after pooling instead.
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, edge_index)
        x = F.relu(x)
        # Collapse node embeddings into one embedding per graph.
        x = gnn.global_mean_pool(x, batch)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.dense(x)
        # squeeze(-1) only drops the last dim when it has size 1 (regression).
        return x.squeeze(-1)

In [None]:
model = SimpleGNN(train_dataset.num_node_features, [10] * 2, 12, 1, 0.2).to(device)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Validation split drives per-epoch monitoring so the test split is only
# evaluated once, after training.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.MSELoss()


def train(epoch):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    ``epoch`` is accepted for call-site symmetry but is not used.
    """
    model.train()

    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()


@torch.no_grad()
def test(loader=test_loader):
    """Return the mean squared error of ``model`` over ``loader``.

    Args:
        loader: Data loader to evaluate on; defaults to ``test_loader``.

    Returns:
        The accumulated MSE as computed by ``MeanSquaredError.compute()``.
    """
    model.eval()

    # The metric keeps its running sums as tensors; they must sit on the
    # same device as the model outputs, otherwise updates fail on GPU.
    mse = MeanSquaredError().to(device)
    for batch in loader:
        out = model(batch)
        mse.update(out, batch.y)
    return mse.compute()


for epoch in range(1, 201):
    train(epoch)

    val_mse = test(val_loader)
    print(f'Epoch: {epoch:03d}, Val MSE: {val_mse:.4f}')
    torch.cuda.empty_cache()

test_mse = test()
print(f'Final Test MSE: {test_mse:.4f}')

torch.save(model.state_dict(), 'model.pt')