# Selective Long-Range Connections in Message Passing Neural Networks
In this notebook, we will show that the message passing neural network (MPNN) can be improved upon for problems with a high problem-radius by using a last layer where some select nodes are connected over long distances to nodes otherwise unreachable.

## Imports

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
torch.manual_seed(0)

from torch_geometric import nn as gnn
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader

from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
# Load graph data
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

# Split data into train and test sets
train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)

## Model

In [3]:
# Write a model class using torch geometric
class SimpleGNN(nn.Module):
    
    def __init__(self, num_node_features, hidden_channels, dense_input, num_classes):
        super(SimpleGNN, self).__init__()
        
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(num_node_features, hidden_channels[0]))
        for i in range(1, len(hidden_channels)):
            self.convs.append(GCNConv(hidden_channels[i-1], hidden_channels[i]))
        self.convs.append(GCNConv(hidden_channels[-1], dense_input))

        self.dense = nn.Linear(dense_input, num_classes)


    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)

        x = self.convs[-1](x, edge_index)
        x = F.relu(x)
        x = gnn.global_mean_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.dense(x)
        return x

In [12]:
# Create the model
model = SimpleGNN(dataset.num_node_features, [64, 64], 64, dataset.num_classes)

# Dataloader for random sampling and batching
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Create the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Create the loss function
criterion = nn.CrossEntropyLoss()

# Create the training loop
def train(epoch):
    model.train()

    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()

# Create the testing loop
def test():
    model.eval()

    correct = 0
    loss = 0
    for batch in test_loader:
        out = model(batch)
        loss += criterion(out, batch.y).item()
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())

    return correct / len(test_loader.dataset), loss / len(test_loader.dataset)
                                              
# Run the training loop
for epoch in range(1, 201):
    train(epoch)
    test_acc, test_loss = test()
    print(f'Epoch: {epoch:03d}, Test Acc: {test_acc:.4f}, Test Loss: {test_loss:.4f}')

# Save the model
torch.save(model.state_dict(), 'model.pt')



Epoch: 001, Test Acc: 0.3158, Test Loss: 0.0372
Epoch: 002, Test Acc: 0.6842, Test Loss: 0.0363
Epoch: 003, Test Acc: 0.6842, Test Loss: 0.0352
Epoch: 004, Test Acc: 0.6842, Test Loss: 0.0341
Epoch: 005, Test Acc: 0.6842, Test Loss: 0.0331
Epoch: 006, Test Acc: 0.6842, Test Loss: 0.0325
Epoch: 007, Test Acc: 0.6842, Test Loss: 0.0324
Epoch: 008, Test Acc: 0.6842, Test Loss: 0.0324
Epoch: 009, Test Acc: 0.6842, Test Loss: 0.0323
Epoch: 010, Test Acc: 0.6842, Test Loss: 0.0323
Epoch: 011, Test Acc: 0.6842, Test Loss: 0.0323
Epoch: 012, Test Acc: 0.6842, Test Loss: 0.0323
Epoch: 013, Test Acc: 0.6842, Test Loss: 0.0322
Epoch: 014, Test Acc: 0.6842, Test Loss: 0.0321
Epoch: 015, Test Acc: 0.6842, Test Loss: 0.0320
Epoch: 016, Test Acc: 0.6842, Test Loss: 0.0319
Epoch: 017, Test Acc: 0.6842, Test Loss: 0.0318
Epoch: 018, Test Acc: 0.6842, Test Loss: 0.0316
Epoch: 019, Test Acc: 0.6842, Test Loss: 0.0315
Epoch: 020, Test Acc: 0.6842, Test Loss: 0.0313
Epoch: 021, Test Acc: 0.6842, Test Loss: