# Graph Neural Networks (GNN)

This week, you learned about Graph Neural Networks. GNNs are simply neural networks for graphs, so if you are familiar with graphs, the rest is straightforward.


A graph is a set of nodes (vertices) connected to each other by edges. The lectures showed several examples of data that can be represented as graphs.

In order to analyze a graph using a neural network, we need to represent each node as a usual neural network input, such as a feature vector. Edges can also be represented as feature vectors, but that adds complexity to the model. In the simplest setting, the edges only indicate which nodes can pass a message to which other nodes. Some edges are directed, and information flows only in the specified direction.

Most supervised learning tasks on graphs involve predicting a target either for the whole graph or for each node. Here, we go through a node classification task and a graph regression task.

In [1]:
import os
from tqdm import tqdm
from typing import Sequence

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# torch_geometric for Graph Neural Networks:
try:
    import torch_geometric as gtorch
except ImportError:
    # Recent torch_geometric releases no longer require torch-scatter, whose source build can fail
    os.system('pip install torch_geometric -qq')
    import torch_geometric as gtorch
import torch_geometric.nn as gnn
import torch_geometric.data as gdata
import torch_geometric.datasets as gdatasets
from torch_geometric.loader import DataLoader as gDataLoader
import torch_geometric.transforms as gtransforms

# rdkit for cheminformatics:
try:
    import rdkit
except ImportError:
    os.system('pip install rdkit -qq') # the rdkit-pypi package is deprecated in favor of rdkit


if torch.cuda.is_available():
    Device = 'cuda'
elif torch.backends.mps.is_available():
    Device = 'mps'
else:
    Device = 'cpu'

print(f'Using {Device} device')


def print_tensor_info(
        name: str, 
        tensor, # torch.Tensor
        ):
    print(f'{name}')
    print(20*'-')
    if not isinstance(tensor, torch.Tensor):
        print(f'It is {type(tensor).__name__}!')
        print(20*'='+'\n')
        return
    # print shape, dtype, device, requires_grad
    print(f'shape: {tensor.shape}')
    print(f'dtype: {tensor.dtype}')
    print(f'device: {tensor.device}')
    print(f'requires_grad: {tensor.requires_grad}')
    print(20*'='+'\n')

Using cpu device


## Graph Data Structure

We use [`torch_geometric`](https://pytorch-geometric.readthedocs.io/en/stable/), which we import as `gtorch`, to represent data as graphs and to define graph neural networks that can analyze them. A graph is formally defined by a set of vertices $V$ (also called nodes) and a set of edges $E$ connecting those vertices; $|V|$ and $|E|$ denote the number of vertices and edges, respectively. In `gtorch`, a graph is described by these matrices:

- $X \in \mathbb{R}^{|V|\times d_v}$ (the `x` field) containing vertex features. Each row represents a vertex (node) as a feature vector of size $d_v$.

- $I \in \{0, 1, ..., |V|-1\}^{2\times |E|}$ (the `edge_index` field) containing the indices of the nodes at the two ends of each edge. Each column corresponds to one edge: the first element is the index of the source node and the second is the index of the target node. The source node is a neighbor of the target node, since it can send messages to it. An undirected edge is stored as two columns, one per direction.

- Optional: $X_E \in \mathbb{R}^{|E|\times d_e}$ (the `edge_attr` field) containing edge attributes. Each row is the feature vector of size $d_e$ of one edge, in the same order as the columns of $I$.
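To make these matrices concrete, here is a small sketch in plain PyTorch (no `gtorch` needed) for a hypothetical three-node path graph with the undirected edges 0-1 and 1-2. Because every column of the edge index is a directed edge, each undirected connection appears twice. The helper `neighbors_of` is hypothetical and only shows how to read the edge index; in `gtorch` you would pass these tensors as the `x`, `edge_index`, and `edge_attr` fields of `gdata.Data`.

```python
import torch

# X: 3 nodes, each with a feature vector of size d_v = 4
x = torch.randn(3, 4)

# I (edge_index): column k is the edge (source, target).
# The undirected edges 0-1 and 1-2 are stored in both directions.
edge_index = torch.tensor([
    [0, 1, 1, 2],  # source nodes
    [1, 0, 2, 1],  # target nodes
], dtype=torch.long)

# Optional edge attributes: one row per column of edge_index, d_e = 2
edge_attr = torch.randn(edge_index.size(1), 2)


def neighbors_of(i: int, edge_index: torch.Tensor) -> list[int]:
    """Hypothetical helper: the nodes that can send a message to node i."""
    src, dst = edge_index
    return sorted(src[dst == i].tolist())


print(x.shape, edge_index.shape, edge_attr.shape)
print(neighbors_of(1, edge_index))  # nodes 0 and 2 send messages to node 1
```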

Let's take a look at a graph dataset. We are going to use the [Cora](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset, a citation network of 2708 scientific publications, each classified into one of seven classes. The dataset contains a single sample: one big graph whose nodes are the publications (each represented by a bag-of-words feature vector) and whose edges are the citations between them.


In [None]:
cora_dataset = gdatasets.Planetoid(
    root = 'week5-data',
    name = 'Cora',
    transform = gtransforms.NormalizeFeatures(),
)

if isinstance(cora_dataset, Dataset):
    n = len(cora_dataset)
    print(f'Number of samples in the dataset: {n}')
    sample = cora_dataset[0]
    print_tensor_info('Sample', sample)

In [None]:
isinstance(cora_dataset, gdata.Dataset)

In [None]:
sample

This is `Data`, the data type `gtorch` uses to represent homogeneous graphs (graphs with a single type of node and edge). You can read more about it, as well as the other data classes, [here](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html).

In [None]:
print_tensor_info('x', sample.x)
print_tensor_info('edge_index', sample.edge_index)
print_tensor_info('y', sample.y)
print_tensor_info('train_mask', sample.train_mask)
print_tensor_info('val_mask', sample.val_mask)
print_tensor_info('test_mask', sample.test_mask)

In [None]:
# Getting information about the dataset that we need for defining the model:

num_classes = cora_dataset.num_classes
num_features = cora_dataset.num_features

print(f'num_classes: {num_classes}')
print(f'num_features: {num_features}')

## Message Passing

In graph neural networks, the exchange of information between connected nodes (vertices) is commonly known as message passing. Let's denote the feature vector of node $i$ by $\mathbf{x}_i$. A message passing layer updates each node as follows:

$$
\mathbf{x}_i' = \gamma_\Theta \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \phi_\Theta \left( \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{j,i} \right) \right)
$$

Here, $\mathcal{N}(i)$ is the set of neighbors of node $i$, i.e., the source nodes of the edges that point to $i$. The update consists of three steps. First, $\phi_\Theta$ computes a message from every neighbor $j$ of node $i$ based on the features of both nodes, as well as the optional features $\mathbf{e}_{j,i}$ of the edge from $j$ to $i$. Then, the messages from all neighbors are aggregated using a permutation-invariant operation $\bigoplus$ (min, max, mean, sum, etc.), so the result does not depend on the order of the neighbors. Finally, $\gamma_\Theta$ computes the updated value of the node for the next layer from its current value and the aggregated message. You can subclass the base class [`MessagePassing`](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing) to define your own custom message passing layer.
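The three steps can be sketched from scratch in plain PyTorch. This is a minimal illustration under simplifying assumptions, not how `MessagePassing` is implemented internally: $\phi_\Theta$ is a linear map of the neighbor's features only (no edge features), $\bigoplus$ is a sum, and $\gamma_\Theta$ adds a linear map of $\mathbf{x}_i$ to the aggregated message and applies ReLU. The function and weight names are hypothetical.

```python
import torch


def simple_message_passing(
        x: torch.Tensor,           # (N_nodes, d_in)
        edge_index: torch.Tensor,  # (2, N_edges), row 0 = source j, row 1 = target i
        W_msg: torch.Tensor,       # (d_in, d_out), parameters of phi
        W_self: torch.Tensor,      # (d_in, d_out), parameters of gamma
        ) -> torch.Tensor:         # (N_nodes, d_out)
    src, dst = edge_index
    # phi: one message per edge, computed from the source node x_j
    messages = x[src] @ W_msg
    # (+): sum all messages that arrive at the same target node i
    aggregated = torch.zeros(x.size(0), W_msg.size(1), dtype=x.dtype, device=x.device)
    aggregated.index_add_(0, dst, messages)
    # gamma: combine the node's own features with the aggregated message
    return torch.relu(x @ W_self + aggregated)


x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 2, 1],   # sources
                           [1, 1, 0]])  # targets: edges 0->1, 2->1, 1->0
W = torch.eye(1)
out = simple_message_passing(x, edge_index, W, W)
print(out)  # node 0: 1 + 2, node 1: 2 + (1 + 3), node 2: 3 + 0 -> [[3.], [6.], [3.]]
```

Because the sum does not depend on the order of the incoming edges, shuffling the columns of `edge_index` leaves the output unchanged, which is the permutation invariance required of $\bigoplus$.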

## Graph Convolution Operation
Any model that fits the message passing definition can be used as a layer of a graph neural network. A prominent example is the graph convolution layer [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html). You can learn more about it in the online documentation or the original paper. Like `torch`, `gtorch` has its own `nn` module (which we imported as `gnn`); you can browse the modules it offers and how to use them in the [online documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html). Let's use `GCNConv` from `gnn.conv`. If you wonder why it is called a convolution, it is because it shares the central property of the classic convolution: **parameter sharing**. The same operation with the same parameters is applied to every node and its neighbors.
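For reference, the GCN layer of Kipf and Welling computes $X' = \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} X \Theta$, where $\hat{A}$ is the adjacency matrix with self-loops added and $\hat{D}$ is the diagonal degree matrix of $\hat{A}$. The dense sketch below implements that formula to show the parameter sharing: a single weight matrix $\Theta$ (`weight`) is applied to every node. It is an illustration only, under stated assumptions: it builds a dense $|V|\times|V|$ matrix (which does not scale), omits the bias that `GCNConv` adds by default, and assumes an undirected graph with both directions listed in `edge_index`. The name `dense_gcn` is hypothetical.

```python
import torch


def dense_gcn(
        x: torch.Tensor,           # (N, d_in)
        edge_index: torch.Tensor,  # (2, E), undirected: both directions listed
        weight: torch.Tensor,      # (d_in, d_out), shared by all nodes
        ) -> torch.Tensor:         # (N, d_out)
    n = x.size(0)
    # A[i, j] = 1 if there is an edge from j (source) to i (target)
    adj = torch.zeros(n, n, dtype=x.dtype)
    adj[edge_index[1], edge_index[0]] = 1.0
    adj_hat = adj + torch.eye(n, dtype=x.dtype)  # add self-loops
    deg = adj_hat.sum(dim=1)                      # diagonal of D_hat
    d_inv_sqrt = deg.pow(-0.5)
    # D^-1/2 A_hat D^-1/2 via broadcasting instead of explicit diagonal matrices
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]
    return norm_adj @ (x @ weight)


# Path graph 0-1-2, stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.ones(3, 1)
out = dense_gcn(x, edge_index, torch.eye(1))
print(out.squeeze())  # approximately tensor([0.9082, 1.1498, 0.9082])
```

Nodes 0 and 2 get identical outputs because they have identical neighborhoods and are transformed with the same weights.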

In [None]:
gcn_layer = gnn.conv.GCNConv(
    in_channels = num_features,
    out_channels = 142,
    cached = True, # If you have only one graph, use cached=True for better performance.
).to(Device)

output = gcn_layer(sample.x.to(Device), sample.edge_index.to(Device))

print_tensor_info('output', output)

print(f'Is gcn_layer an instance of nn.Module? {isinstance(gcn_layer, nn.Module)}')
print(f'Is gcn_layer an instance of gnn.conv.MessagePassing? {isinstance(gcn_layer, gnn.conv.MessagePassing)}')
print(gcn_layer)

In [8]:
class GNN_node(nn.Module):

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_dims: Sequence[int],
            activation: str = 'relu', # name of a function in torch.nn.functional
            dropout: float = 0.0,
            ):
        super().__init__()

        self.layers = nn.ModuleList()
        dims = [in_features] + list(hidden_dims) # list() so that tuples also work
        n_graph_layers = len(hidden_dims)
        for i in range(n_graph_layers):
            self.layers.append(
                gnn.conv.GCNConv(
                    in_channels = dims[i],
                    out_channels = dims[i+1],
                    cached = True, # Use this if the whole data is one big graph
                    )
                )

        self.act = getattr(F, activation) # e.g. 'relu' -> F.relu
        self.dropout = dropout

        self.out_layer = nn.Linear(dims[-1], num_classes)

    def forward(
            self, 
            x: torch.FloatTensor, # (N_nodes, in_features)
            edge_index: torch.LongTensor, # (2, N_edges)
            ) -> torch.FloatTensor: # (N_nodes, num_classes)
        
        x = x.to(torch.float)
        edge_index = edge_index.to(torch.long)

        for layer in self.layers:
            x = layer(x, edge_index)
            x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training) # make sure to pass training=self.training when using F.dropout
            
        x = self.out_layer(x)
        return x

In [None]:
model = GNN_node(
    in_features = num_features,
    num_classes = num_classes,
    hidden_dims = [64, 128, 64],
    activation = 'relu',
    dropout = 0.1,
).to(Device)
model

In [10]:
@torch.enable_grad()
def train(
        model: nn.Module,
        dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        loss_fn: nn.Module,
        epochs: int,
        device: str = Device,
        ):
    
    epoch_pbar = tqdm(range(epochs), leave=True)
    model.to(device)
    sample = dataset[0].to(device) # the whole dataset is a single graph
    target = sample.y
    for epoch in epoch_pbar:
        # One optimization step using only the training nodes
        model.train()
        optimizer.zero_grad()
        output = model(sample.x, sample.edge_index)
        train_loss = loss_fn(output[sample.train_mask], target[sample.train_mask])
        train_loss.backward()
        optimizer.step()

        # Evaluate with a fresh forward pass in eval mode, so dropout is disabled
        model.eval()
        with torch.no_grad():
            output = model(sample.x, sample.edge_index)
            stats = []
            for split in ('train', 'val', 'test'):
                mask = getattr(sample, f'{split}_mask')
                loss = loss_fn(output[mask], target[mask]).item()
                acc = (output[mask].argmax(dim=1) == target[mask]).float().mean().item()
                stats.append(f'{split}: loss {loss:.3f}, acc {acc:.3f}')

        epoch_pbar.set_postfix_str(' | '.join(stats))


In [None]:
model = GNN_node(
    in_features = num_features,
    num_classes = num_classes,
    hidden_dims = [64, 128, 64],
    activation = 'leaky_relu',
    dropout = 0.0,
)
model

optimizer = torch.optim.Adam(model.parameters())
loss = nn.CrossEntropyLoss()

train(model, cora_dataset, optimizer, loss, epochs=500)

The split of train/val/test here is actually a bit problematic. Can you tell why?

## Graph Regression
Now let's try a graph regression task, i.e., predicting a real value for a whole graph. For this, we will use the `Lipo` (lipophilicity) dataset from [`MoleculeNet`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.MoleculeNet.html), where each graph is a molecule, with atoms as nodes and chemical bonds as edges.

In [None]:
moleculenet_dataset = gdatasets.MoleculeNet(
    root = 'week5-data',
    name = 'Lipo',
)

if isinstance(moleculenet_dataset, Dataset):
    n = len(moleculenet_dataset)
    print(f'Number of samples in the dataset: {n}')
    sample = moleculenet_dataset[0]
    print_tensor_info('Sample', sample)

In [None]:
moleculenet_dataset[0]

In [None]:
moleculenet_dataset[1]

This time, we have many small graphs instead of one huge graph. How do we create mini-batches of graphs? Each graph has a different number of nodes and edges, so we cannot simply stack them along a new batch dimension as we would with fixed-size inputs such as images!

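The answer is simple: a mini-batch of several graphs is just one big graph made of those graphs as disconnected subgraphs. No edge connects nodes from different subgraphs, and the node indices in each graph's `edge_index` are shifted by the total number of nodes in the graphs that come before it.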
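The sketch below mimics, in plain PyTorch, what `gDataLoader` does when it merges graphs into one mini-batch: node features are concatenated, each graph's edge indices are shifted by the number of nodes that come before it, a `batch` vector records which graph each node belongs to, and `ptr` marks where each graph's nodes start. The function `collate_graphs` is hypothetical and ignores edge attributes and targets, which `gtorch` also concatenates.

```python
import torch


def collate_graphs(graphs):
    """Merge a list of (x, edge_index) pairs into one disconnected graph."""
    xs, edge_indices, batch = [], [], []
    ptr = [0]
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node ids so they point at the right rows of the merged x
        edge_indices.append(edge_index + ptr[-1])
        # Every node of this graph gets the same graph id
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        ptr.append(ptr[-1] + x.size(0))
    return (
        torch.cat(xs, dim=0),
        torch.cat(edge_indices, dim=1),
        torch.cat(batch),
        torch.tensor(ptr),
    )


g0 = (torch.randn(2, 3), torch.tensor([[0], [1]]))        # 2 nodes, edge 0->1
g1 = (torch.randn(3, 3), torch.tensor([[0, 1], [2, 2]]))  # 3 nodes, edges 0->2, 1->2
x, edge_index, batch, ptr = collate_graphs([g0, g1])
print(edge_index)  # values: [[0, 2, 3], [1, 4, 4]]
print(batch)       # values: [0, 0, 1, 1, 1]
print(ptr)         # values: [0, 2, 5]
```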


In [15]:
batch_size = 3
gdata_loader = gDataLoader(
    dataset = moleculenet_dataset,
    batch_size = batch_size,
    shuffle = False,
)

In [None]:
batch0 = next(iter(gdata_loader))

for i in range(batch_size):
    print(f'sample {i}')
    print(moleculenet_dataset[i])
    print(20*'-')

print()
print('The first mini-batch containing the first 3 graphs')
print(batch0)
print(50*'='+'\n')
print('There are new fields in the batched graph that help you recover the subgraphs:\n')

print_tensor_info('batch0.batch', batch0.batch)
print_tensor_info('batch0.ptr', batch0.ptr)
print(f'batch0.ptr: {batch0.ptr}')

print('Take a look at the actual features of the nodes:')
print_tensor_info('batch0.x', batch0.x)

## Global Pooling
The final component of our toolkit for graph-level prediction is a way to aggregate the feature vectors of all the nodes in a graph into a single vector. Such an operation is called global pooling. You can find the pooling layers that `gtorch` offers [here](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#pooling-layers). We'll use [`global_max_pool`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.global_max_pool.html#torch_geometric.nn.pool.global_max_pool) for our example. If you have batched graphs, be careful to pass the `batch` vector, which tells the function which subgraph each node belongs to; otherwise, all the nodes are pooled together into a single vector.
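To see what global pooling computes, here is a plain-PyTorch sketch that takes, for every graph in the batch, the element-wise maximum over that graph's nodes, with a mean-pooling variant for comparison. The loop is for illustration only and the function names are hypothetical; `gtorch`'s pooling functions perform the same reduction with a vectorized scatter operation. The sketch assumes every graph id in `batch` has at least one node.

```python
import torch


def global_max_pool_loop(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """x: (N_nodes, F), batch: (N_nodes,) graph ids -> (N_graphs, F)."""
    num_graphs = int(batch.max()) + 1
    return torch.stack([x[batch == g].max(dim=0).values for g in range(num_graphs)])


def global_mean_pool_loop(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Same interface, but averages the node features of each graph."""
    num_graphs = int(batch.max()) + 1
    return torch.stack([x[batch == g].mean(dim=0) for g in range(num_graphs)])


x = torch.tensor([[1.0, 5.0],
                  [3.0, 2.0],   # graph 0: rows 0-1
                  [0.0, 7.0],
                  [4.0, 1.0],
                  [2.0, 2.0]])  # graph 1: rows 2-4
batch = torch.tensor([0, 0, 1, 1, 1])
print(global_max_pool_loop(x, batch))   # [[3., 5.], [4., 7.]]
print(global_mean_pool_loop(x, batch))  # [[2., 3.5], [2., 3.3333]]
```

Either way, the output has one row per graph regardless of how many nodes each graph has, which is what lets fully connected layers with a fixed input size follow the pooling step.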

In [None]:
pool_function = gnn.pool.global_max_pool # It is a function, not a Module


output = pool_function(
    x = batch0.x,
    batch = batch0.batch, # so the function knows which subgraph each node belongs to
    )
print_tensor_info('output after global max pooling', output)

Now let's define a graph neural network for graph regression. It is essentially the same as the node-level model, except that a global pooling step turns the node features into one vector per graph, followed by a few fully connected layers.

In [22]:
class GNN_graph(nn.Module):

    def __init__(
            self,

            node_features: int,
            node_hidden_dims: Sequence[int],
            graph_hidden_dims: Sequence[int],
            out_features: int,
            activation: str = 'relu',
            pool: str = 'max', # max, mean, add
            ):
        super().__init__()
            
        self.act = getattr(F, activation) # e.g. 'leaky_relu' -> F.leaky_relu

        self.node_layers = nn.ModuleList()
        n_node_layers = len(node_hidden_dims)
        node_dims = [node_features] + list(node_hidden_dims) # list() so that tuples also work
        for i in range(n_node_layers):
            self.node_layers.append(
                gnn.conv.GCNConv(
                    in_channels = node_dims[i],
                    out_channels = node_dims[i+1],
                    )
                )
            
        self.pool = getattr(gnn.pool, f'global_{pool}_pool') # e.g. 'max' -> global_max_pool

        self.graph_layers = nn.ModuleList()
        n_graph_layers = len(graph_hidden_dims)
        graph_dims = [node_dims[-1]] + list(graph_hidden_dims)
        for i in range(n_graph_layers):
            self.graph_layers.append(
                nn.Linear(
                    in_features = graph_dims[i],
                    out_features = graph_dims[i+1],
                    )
                )

        self.out_layer = nn.Linear(graph_dims[-1], out_features)

    def forward(
            self,
            x: torch.FloatTensor, # Batched input
            edge_index: torch.LongTensor,
            batch: torch.LongTensor,
            ) -> torch.FloatTensor:
        
        # Always make sure dtypes are correct
        x = x.to(torch.float)
        edge_index = edge_index.to(torch.long)
        batch = batch.to(torch.long)

        # Node level layers. x is of shape [N_nodes, node_feature] in these layers
        for layer in self.node_layers:
            x = layer(x, edge_index)
            x = self.act(x)

        # The pooling layer extracts one feature vector per graph
        x = self.pool(x, batch)
        # Now, x is of shape [N_graphs, node_hidden_dims[-1]]

        # The remaining layers are classic fully connected (feed-forward) layers
        for layer in self.graph_layers:
            x = layer(x)
            x = self.act(x)
        
        x = self.out_layer(x)

        return x
        

In [23]:
@torch.enable_grad()
def train_epoch(
        model: nn.Module,
        loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        loss_fn: nn.Module,
        device = Device,
        ):
    
    model.train().to(device)
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        output = model(batch.x, batch.edge_index, batch.batch)
        target = batch.y
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()


@torch.inference_mode()
def eval_epoch(
        model: nn.Module,
        loader: DataLoader,
        loss_fn: nn.Module,
        device = Device,
        ):
    n = len(loader.dataset)
    model.eval().to(device)
    total_loss = 0.
    for batch in loader:
        batch = batch.to(device)
        output = model(batch.x, batch.edge_index, batch.batch)
        target = batch.y
        loss = loss_fn(output, target)
        if loss_fn.reduction == 'sum':
            total_loss += loss.item()
        elif loss_fn.reduction == 'mean':
            # Undo the averaging over the graphs in this mini-batch
            total_loss += loss.item() * batch.num_graphs
        else:
            raise ValueError(f'Unsupported reduction: {loss_fn.reduction}')
    return total_loss / n


def train(
        model: nn.Module,
        train_dataset: Dataset,
        val_dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        loss_fn: nn.Module,
        batch_size: int,
        epochs: int,
        device = Device,
        ):
    
    train_loader = gDataLoader(
        dataset = train_dataset,
        batch_size = batch_size,
        shuffle = True,
        )
    
    val_loader = gDataLoader(
        dataset = val_dataset,
        batch_size = batch_size,
        shuffle = False,
        )
    
    epoch_pbar = tqdm(range(epochs))
    for epoch in epoch_pbar:
        train_epoch(
            model = model, 
            loader = train_loader, 
            optimizer = optimizer, 
            loss_fn = loss_fn, 
            device = device,
            )
        train_loss = eval_epoch(
            model = model, 
            loader = train_loader, 
            loss_fn = loss_fn,
            device = device,
            )
        val_loss = eval_epoch(
            model = model, 
            loader = val_loader, 
            loss_fn = loss_fn,
            device = device,
            )
        postfix_str = f"train loss {train_loss:.4f} | val loss {val_loss:.4f}"
        epoch_pbar.set_postfix_str(postfix_str)

In [None]:
N = len(moleculenet_dataset)

# Hold out the last 20% of the graphs for validation
train_dataset = moleculenet_dataset[:-round(N/5)]
val_dataset = moleculenet_dataset[-round(N/5):]

node_features = moleculenet_dataset.num_node_features
graph_features = 1 # one regression target per molecule

model = GNN_graph(
    node_features = node_features,
    node_hidden_dims = [64, 128],
    graph_hidden_dims = [64, 32],
    out_features = graph_features,
    activation = 'leaky_relu',
    pool = 'max',
)

optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.MSELoss()

train(
    model = model,
    train_dataset = train_dataset,
    val_dataset = val_dataset,
    optimizer = optimizer,
    loss_fn = loss_fn,
    batch_size = 32,
    epochs = 100,
    )
