In [None]:
# Install PyG if you haven't
# NOTE(review): versions are not pinned; the install log below resolved torch_geometric 2.6.1.
# torch_sparse / torch_scatter are compiled extensions — TODO confirm the installed wheels
# match your local torch / CUDA build before relying on them.
%pip install torch_geometric
%pip install torch_sparse
%pip install torch_scatter

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting aiohttp (from torch_geometric)
  Downloading aiohttp-3.11.14-cp311-cp311-win_amd64.whl.metadata (8.0 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->torch_geometric)
  Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->torch_geometric)
  Downloading aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting attrs>=17.3.0 (from aiohttp->torch_geometric)
  Downloading attrs-25.3.0-py3-none-any.whl.metadata (10 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->torch_geometric)
  Downloading frozenlist-1.5.0-cp311-cp311-win_amd64.whl.metadata (14 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->torch_geometric)
  Downloading multidict-6.2.0-cp311-cp311-win_amd64.whl.metadata (5.1 kB)
Collecting propcache>=0.2.0 (from aiohttp->torch_geometric)
  Downloading propcache-0.3.1-cp311-cp311-win_amd64.whl.metadata (11 

### Import PyG and PyTorch

In this tutorial, we are going to introduce how to implement GraphSAGE to perform semi-supervised learning on a node classification task.

- We will first demonstrate how to implement and train a GraphSAGE model for node classification without neighbor sampling.
- Then, we will show how to use PyTorch Geometric's NeighborLoader to enable neighbor sampling for training and testing the GraphSAGE model.

We will demonstrate this using the PyTorch Geometric package. However, feel free to explore other packages such as DGL.

First, load PyTorch, PyTorch Geometric, and other necessary packages (we will also use NumPy).

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import SAGEConv

In [2]:
import numpy as np
import warnings
# NOTE(review): this silences *every* warning (including deprecation notices) for the
# rest of the kernel session; consider narrowing it to the specific categories you
# actually want to hide.
warnings.filterwarnings('ignore')

### Prepare the PubMed dataset
We use the PubMed citation network for demonstration. Each node in the citation network is a paper, and each edge represents a citation between two papers.

This dataset has 19,717 papers and 88,648 directed edges (PyG's `edge_index` stores each citation link in both directions). Each paper has a sparse bag-of-words feature vector and a class label.

In [3]:
# Load the PubMed citation graph; the dataset contains a single graph.
dataset = Planetoid(root='/tmp/Pubmed', name='Pubmed')
data = dataset[0]

features = data.x               # sparse bag-of-words features, one row per paper
labels = data.y                 # class label of every paper
n_classes = dataset.num_classes # number of distinct node classes
in_feats = features.shape[1]    # input feature size

# Summarise the graph
for summary_line in (
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Number of features: {features.shape[1]}',
    f'Number of classes: {n_classes}',
    f'Input feature size: {in_feats}',
):
    print(summary_line)

Number of nodes: 19717
Number of edges: 88648
Number of features: 500
Number of classes: 3
Input feature size: 500


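Let's inspect the `data` object, which holds the node features (`x`), the edges (`edge_index`), the labels (`y`) and the train/validation/test masks. (No self-loop removal is performed in this notebook; the cell below only displays the graph.)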

In [6]:
data  # rich display: shapes of x, edge_index, y and the train/val/test masks

Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])

### Implement the GNN model

Essentially, given a graph structure, GNNs (GCN, GraphSAGE, GAT, etc.) are used to learn meaningful node representations (in this case, the embeddings, or vectors).
Once these embeddings are properly learnt, we may perform downstream tasks such as node classification, graph classification, and link prediction.

PyG provides two ways of implementing a GNN model:

- using the nn module, which contains many commonly used GNN modules.
- using the message passing interface to implement a GNN model from scratch.

For simplicity, we implement the GraphSAGE model in the tutorial with the nn module.

If you are interested in implementing a GNN model with the message passing interface, see https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html.

![GNN overview](https://raw.githubusercontent.com/dglai/WWW20-Hands-on-Tutorial/master/images/GNN.png)

The GraphSAGE model has multiple layers. In each layer, a node aggregates information from its direct neighbors. When we stack $k$ layers in a model, a node $v$ accesses neighbors within $k$ hops. The output of the GraphSAGE model is **node embeddings** that represent the nodes and all information in their $k$-hop neighborhood.

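If you want to learn about the details of the `SAGEConv` layer, see its official documentation at https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html (the higher-level `GraphSAGE` model is documented at https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html).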

In [5]:
class GraphSAGEModel(nn.Module):
    """Full-graph GraphSAGE encoder built from PyG ``SAGEConv`` layers.

    Layer stack: one input layer (``in_feats -> n_hidden``), ``n_layers - 1`` hidden
    layers (``n_hidden -> n_hidden``) and one output layer (``n_hidden -> out_dim``).
    With ``n_layers >= 1`` that is ``n_layers + 1`` SAGEConv layers in total.

    Args:
        in_feats: size of the input node features.
        n_hidden: width of every hidden representation.
        out_dim: size of the output (here: number of classes).
        n_layers: controls the number of hidden layers (see above).
        activation: callable applied after every non-final layer, or ``None`` to skip it.
        dropout: dropout probability applied after every non-final layer.
        aggregator_type: neighbor aggregation passed to ``SAGEConv`` as ``aggr``.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 out_dim,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super().__init__()
        convs = [SAGEConv(in_feats, n_hidden, aggr=aggregator_type)]
        convs += [SAGEConv(n_hidden, n_hidden, aggr=aggregator_type)
                  for _ in range(n_layers - 1)]
        convs.append(SAGEConv(n_hidden, out_dim, aggr=aggregator_type))
        self.layers = nn.ModuleList(convs)

        self.activation = activation
        self.dropout = dropout

    def forward(self, data):
        """Return one output vector per node of the graph ``data`` (uses ``x`` and ``edge_index``)."""
        h, edge_index = data.x, data.edge_index
        *inner_layers, final_layer = self.layers
        for conv in inner_layers:
            h = conv(h, edge_index)
            if self.activation is not None:
                h = self.activation(h)
            # Dropout is only active while the module is in training mode
            h = nn.functional.dropout(h, p=self.dropout, training=self.training)
        # The final layer has no activation/dropout: it produces raw logits
        return final_layer(h, edge_index)

### Node classification (semi-supervised)
Let us perform node classification in a semi-supervised setting. In this setting, we have the entire graph structure and all node features. We only have labels on some of the nodes. We want to predict the labels on other nodes. Even though some of the nodes do not have labels, they connect with nodes with labels. Thus, we train the model with both labeled nodes and unlabeled nodes. Semi-supervised learning can usually improve performance.

![semisupervised](https://raw.githubusercontent.com/dglai/WWW20-Hands-on-Tutorial/master/images/node_classify1.png)

This dependency graph gives a clearer view of how labeled and unlabeled nodes are used in training.

![dependency](https://raw.githubusercontent.com/dglai/WWW20-Hands-on-Tutorial/master/images/node_classify2.png)

In [6]:
# Hyperparameters
n_hidden = 64
# NOTE: GraphSAGEModel builds an input layer, n_layers - 1 hidden layers and an output
# layer, so n_layers = 2 yields 3 SAGEConv layers (a 3-hop receptive field).
n_layers = 2
dropout = 0.5
aggregator_type = 'mean'

# Fix the global torch RNG so that weight initialisation (and the dropout masks drawn
# later during training) are reproducible under Restart & Run All.
torch.manual_seed(0)

gconv_model = GraphSAGEModel(in_feats,
                             n_hidden,
                             n_classes,
                             n_layers,
                             F.relu,
                             dropout,
                             aggregator_type)

Now we create the node classification model based on the GraphSAGE model. The GraphSAGE model takes a `Data` object (containing `edge_index` and the node features `x`) as input and computes node embeddings as output. On top of these embeddings (here, one logit per class), we use a cross-entropy loss to train the node classification model.

In [7]:
class NodeClassification(nn.Module):
    """Full-graph node-classification wrapper around a GNN encoder.

    Calling the module runs ``gconv_model`` on the whole graph and returns the
    cross-entropy loss restricted to the nodes selected by ``data.train_mask``.
    """

    def __init__(self, gconv_model):
        super().__init__()
        self.gconv_model = gconv_model
        self.loss_fcn = torch.nn.CrossEntropyLoss()

    def forward(self, data):
        # Score every node, then keep only the labelled training nodes for the loss
        selected = data.train_mask
        node_scores = self.gconv_model(data)
        return self.loss_fcn(node_scores[selected], data.y[selected])

After defining a model for node classification, we define the evaluation, training and test functions.

In [8]:
def NCEvaluate(model, data, mask):
    """Accuracy of ``model.gconv_model`` on the nodes selected by ``mask``.

    Runs the encoder on the full graph in eval mode without gradient tracking and
    compares each node's arg-max prediction with ``data.y``. The model is left in
    eval mode; callers that keep training must call ``model.train()`` again.
    """
    model.eval()
    with torch.no_grad():
        predictions = model.gconv_model(data)[mask].argmax(dim=1)
        targets = data.y[mask]
        n_correct = (predictions == targets).sum().item()
    return n_correct * 1.0 / len(targets)

def Train(model, data, optimizer, n_epochs):
    """Full-batch training: one optimizer step on the whole graph per epoch.

    ``model(data)`` must return the training loss. After each step the validation
    accuracy on ``data.val_mask`` is computed with ``NCEvaluate`` and printed
    together with the loss.
    """
    template = "Epoch {:05d} | Loss {:.4f} | Validation Accuracy {:.4f}"
    for epoch in range(n_epochs):
        # NCEvaluate switches the model to eval mode, so re-enable training mode each epoch
        model.train()
        optimizer.zero_grad()
        loss = model(data)
        loss.backward()
        optimizer.step()

        val_acc = NCEvaluate(model, data, data.val_mask)
        print(template.format(epoch, loss.item(), val_acc))

def Test(model, data):
    """Print the accuracy of ``model`` on the nodes selected by ``data.test_mask``."""
    print('Testing Accuracy:', NCEvaluate(model, data, data.test_mask))

Prepare data for semi-supervised node classification

In [9]:
# Split masks for semi-supervised node classification
train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask

# (Fixed: the header previously ended with a stray apostrophe.)
print("""----Data statistics------
      #Classes {}
      #Train samples {}
      #Val samples {}
      #Test samples {}""".format(
          data.y.max().item() + 1,  # Assuming y contains class labels starting from 0
          train_mask.sum().item(),
          val_mask.sum().item(),
          test_mask.sum().item()))


----Data statistics------'
      #Classes 3
      #Train samples 60
      #Val samples 500
      #Test samples 1000


After defining the model and evaluation function, we can put everything into the training loop to train the model.

In [10]:
# Wrap the GraphSAGE encoder in the loss-computing node-classification module
model = NodeClassification(gconv_model)

# Training hyperparameters
weight_decay, n_epochs, lr = 5e-4, 150, 1e-3

# Adam optimizer; weight_decay adds L2 regularisation on all parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)

# Train on the full graph, then report accuracy on the test split
Train(model, data, optimizer, n_epochs)
Test(model, data)


Epoch 00000 | Loss 1.1010 | Validation Accuracy 0.1960
Epoch 00001 | Loss 1.0986 | Validation Accuracy 0.1960
Epoch 00002 | Loss 1.0977 | Validation Accuracy 0.1960
Epoch 00003 | Loss 1.0981 | Validation Accuracy 0.1960
Epoch 00004 | Loss 1.0919 | Validation Accuracy 0.1960
Epoch 00005 | Loss 1.0963 | Validation Accuracy 0.1960
Epoch 00006 | Loss 1.0928 | Validation Accuracy 0.1960
Epoch 00007 | Loss 1.0897 | Validation Accuracy 0.1960
Epoch 00008 | Loss 1.0764 | Validation Accuracy 0.1960
Epoch 00009 | Loss 1.0846 | Validation Accuracy 0.1960
Epoch 00010 | Loss 1.0846 | Validation Accuracy 0.2280
Epoch 00011 | Loss 1.0771 | Validation Accuracy 0.3500
Epoch 00012 | Loss 1.0755 | Validation Accuracy 0.4840
Epoch 00013 | Loss 1.0795 | Validation Accuracy 0.5720
Epoch 00014 | Loss 1.0652 | Validation Accuracy 0.6260
Epoch 00015 | Loss 1.0644 | Validation Accuracy 0.6640
Epoch 00016 | Loss 1.0621 | Validation Accuracy 0.6760
Epoch 00017 | Loss 1.0636 | Validation Accuracy 0.7000
Epoch 0001

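The above example runs without neighbor sampling. Now, let's look at how to add neighbor sampling in PyTorch Geometric (PyG).

PyG provides support for neighbor sampling through the following utilities:

- `torch_geometric.loader.NeighborLoader` (used below), which samples a fixed number of neighbors per hop and yields mini-batches of sampled subgraphs.
- `torch_geometric.loader.NeighborSampler`, an older interface for the same purpose.

(`torch_geometric.loader.DataLoader`, imported at the top, batches whole graphs and is not needed for neighbor sampling.)

Note that the GraphSAGE structure does not change; only the training approach changes:

- We switch to batched (mini-batch) training.
- Each node within a batch is updated with a portion of randomly sampled neighbors instead of all its neighbors.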

In [11]:
import torch_sparse
import torch_scatter
from torch_geometric.loader import NeighborLoader

# Define the batch size and fan-out
batch_size = 1024
# Number of neighbors sampled per hop (index 0 = direct neighbors of the seed nodes).
# Its length should equal the number of message-passing layers of the model trained on
# these loaders: the SAGE model below uses 2 SAGEConv layers, so we sample 2 hops.
# A third hop would only add nodes that cannot influence the seed nodes' outputs.
fan_out = [10, 20]


def get_dataloader_with_sampling(data, mask, batch_size=32, shuffle=False, num_neighbors=None):
    """Build a NeighborLoader that yields sampled subgraphs around the nodes in ``mask``.

    Args:
        data: the full graph (PyG ``Data``).
        mask: boolean node mask selecting the seed nodes to iterate over.
        batch_size: number of seed nodes per mini-batch.
        shuffle: whether to shuffle the seed nodes every epoch.
        num_neighbors: per-hop fan-out; defaults to the module-level ``fan_out``.

    Returns:
        The ``NeighborLoader`` (an iterable of sampled mini-batch subgraphs).
    """
    if num_neighbors is None:
        num_neighbors = fan_out
    return NeighborLoader(data, input_nodes=mask, num_neighbors=num_neighbors,
                          batch_size=batch_size, shuffle=shuffle)


# One loader per split; only the training loader is shuffled
train_sampler = get_dataloader_with_sampling(data, data.train_mask, batch_size, shuffle=True)
val_sampler = get_dataloader_with_sampling(data, data.val_mask, batch_size, shuffle=False)
test_sampler = get_dataloader_with_sampling(data, data.test_mask, batch_size, shuffle=False)

The model structure remains the same.
The only difference is in the **forward** function, which now receives the node features `x` and the `edge_index` of each batched, neighborhood-sampled subgraph instead of a whole `Data` object.

In [12]:
class SAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers for mini-batch (neighbor-sampled) training.

    Args:
        in_channels: input feature size.
        hidden_channels: width of every hidden layer.
        out_channels: output size (number of classes here).
        num_layers: total number of SAGEConv layers (default 2, the originally
            hard-coded depth). Should generally equal the number of hops sampled
            by the NeighborLoader (``len(num_neighbors)``).
        dropout: dropout probability applied after each non-final layer (default 0.5).
        aggr: neighbor aggregation passed to ``SAGEConv`` (default ``'mean'``).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5, aggr='mean'):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout = dropout
        # Channel sizes along the stack, e.g. [in, hidden, out] for two layers
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            SAGEConv(d_in, d_out, aggr=aggr) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x, edge_index):
        """Return logits for every node of the (sampled) subgraph given by ``x``/``edge_index``."""
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = x.relu_()
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

In [13]:
class NodeClassification(nn.Module):
    """Mini-batch node-classification wrapper around a GNN encoder.

    NOTE: this redefinition shadows the full-graph ``NodeClassification`` defined
    earlier in the notebook. Unlike that version, ``forward`` does NOT compute a
    loss: it returns the encoder's raw logits for every node of the sampled
    subgraph, and the caller slices out the seed nodes and computes the loss.
    """

    def __init__(self, gconv_model):
        super().__init__()
        self.gconv_model = gconv_model
        # Kept for interface parity with the full-graph wrapper; not used by forward().
        self.loss_fcn = torch.nn.CrossEntropyLoss()

    def forward(self, x, edge_index):
        """Return the encoder's logits for all nodes in ``x``."""
        return self.gconv_model(x, edge_index)

We reorganize the training and evaluation functions so that they consume mini-batches from the neighbor loaders.

In [14]:
def Train(model, train_dataloader, val_dataloader, optimizer, n_epochs):
    """Mini-batch training loop over neighbor-sampled subgraphs.

    In each batch only the first ``batch.batch_size`` nodes (the seed nodes) contribute
    to the loss and accuracy. After every epoch the model is evaluated on
    ``val_dataloader`` with ``Evaluate`` and a summary line is printed.

    The reported loss is the per-seed-node average over the epoch, and the accuracy is
    computed over the seed nodes actually seen. Neither depends on a global graph object.
    """
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        total_correct = total_examples = 0
        for batch in train_dataloader:
            n_seed = batch.batch_size
            y = batch.y[:n_seed]
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)[:n_seed]
            loss = F.cross_entropy(out, y)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a small final batch does not skew the epoch loss
            total_loss += float(loss) * n_seed
            total_correct += int(out.argmax(dim=-1).eq(y).sum())
            total_examples += n_seed
        # max(..., 1) guards against division by zero for an empty loader
        epoch_loss = total_loss / max(total_examples, 1)
        approx_acc = total_correct / max(total_examples, 1)
        val_acc = Evaluate(model, val_dataloader)
        print(f"Epoch {epoch:05d} | Loss {epoch_loss:.4f} | Accuracy {approx_acc:.4f} | Val Accuracy {val_acc:.4f}")

def Evaluate(model, eval_dataloader):
    """Return the accuracy of ``model`` on the seed nodes of every batch in ``eval_dataloader``.

    The first ``batch.batch_size`` nodes of each batch are scored. Predictions from all
    batches are concatenated before the accuracy is computed, so every seed node counts
    equally. The model is switched to eval mode (and left there), and autograd is disabled.
    """
    model.eval()
    with torch.no_grad():
        pieces = [
            (model(b.x, b.edge_index)[:b.batch_size], b.y[:b.batch_size])
            for b in eval_dataloader
        ]
        predictions = torch.cat([logits for logits, _ in pieces]).argmax(1)
        targets = torch.cat([labels for _, labels in pieces])
        return (predictions == targets).float().mean().item()

def Test(model, eval_dataloader):
    """Evaluate ``model`` on ``eval_dataloader`` and print the resulting accuracy."""
    print('Testing Accuracy:', Evaluate(model, eval_dataloader))

Let's try training a model in this way...

In [15]:
# Hyperparameters
n_hidden = 64
# NOTE: SAGE is constructed below with only (in, hidden, out) channels, so it uses its
# own configuration: 2 SAGEConv layers, dropout 0.5 and SAGEConv's default aggregation.
# The three values below are kept for reference only and are NOT passed to the model.
n_layers = 2
dropout = 0.5
aggregator_type = 'mean'

# Fix the global torch RNG so weight initialisation and dropout (and, presumably,
# the NeighborLoader shuffling) are reproducible under Restart & Run All.
torch.manual_seed(0)
gconv_model = SAGE(in_feats,
                   n_hidden,
                   n_classes)

# Node classification task
model = NodeClassification(gconv_model)

In [16]:
# Training hyperparameters
weight_decay = 5e-4
n_epochs = 150
lr = 1e-3

# Create the Adam optimizer from the variables above, so editing `lr` or
# `weight_decay` actually takes effect (they were previously duplicated as literals).
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# Mini-batch training with neighbor sampling, then evaluation on the test split
Train(model, train_sampler, val_sampler, optimizer, n_epochs)
Test(model, test_sampler)

Epoch 00000 | Loss 1.1022 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00001 | Loss 1.0973 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00002 | Loss 1.0932 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00003 | Loss 1.0887 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00004 | Loss 1.0849 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00005 | Loss 1.0804 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00006 | Loss 1.0784 | Accuracy 0.3333 | Val Accuracy 0.4160
Epoch 00007 | Loss 1.0699 | Accuracy 0.3667 | Val Accuracy 0.4160
Epoch 00008 | Loss 1.0670 | Accuracy 0.3500 | Val Accuracy 0.4180
Epoch 00009 | Loss 1.0642 | Accuracy 0.3833 | Val Accuracy 0.4300
Epoch 00010 | Loss 1.0579 | Accuracy 0.4333 | Val Accuracy 0.4520
Epoch 00011 | Loss 1.0501 | Accuracy 0.4833 | Val Accuracy 0.4800
Epoch 00012 | Loss 1.0464 | Accuracy 0.5667 | Val Accuracy 0.5040
Epoch 00013 | Loss 1.0412 | Accuracy 0.6333 | Val Accuracy 0.5120
Epoch 00014 | Loss 1.0314 | Accuracy 0.6833 | Val Accuracy 0.5240
Epoch 0001

Generally, the results on this dataset should be very similar to those of the full-graph (non-sampled) training above.