## **Defining Expressiveness**
- Neural networks are used to approximate functions, as justified by the universal approximation theorem, which states that a feedforward neural network with only one layer can approximate any smoooth function.
- The goal of graph neural networks is to produce the best node embeddings possible. To distinguish nodes, we compare node features and their neighbours. This problem is called graph isomorphism problem in graph theory.
- Two graphs are isomorphic if they have the same connections, and their only difference is a permutation of their nodes.
- The Weisfeiler-Lehman test (WL-test) aims to build a caconical form of a graph, and compares the canonical form of two graphs to check whether they are isomorphic.
1. At the beginning, each node in the graph receives the same colour.
2. Each node aggregates its own colour and the colours of its neighbours.
3. The result is fed to a hash function that produces a new colour.
4. Each node aggregates its new colour and the new colour of its neighbours.
5. The result is fed to a hash function that produces a new colour.
6. The steps are repeated until no more node changes colour.
- If two graphs do not share the same colours, they are not isomorphic. However, we cannot be sure if they are isomorphic if they share the same colours.
- A sum aggregator can discriminate more graph structures than a mean or max aggregator. This implies that the aggregators used so far (for GCN, GAT, etc) are suboptimal since they are less expressive than a sum.

## **Introducing GIN**
- GIN is designed to be as expressive as the WL-test. GIN consists of two functions.
1. Aggregate: The function $f$ selects the neighbouring nodes that the GNN considers.
2. The function $\phi$ combines the embeddings from the selected nodes to produce the new embeddings of the target nodes.

$$h_i' = \phi(h_i, f(\{h_j: j \in N_i\}))$$

- In the case of GCN, $f$ aggregates every neighbour of node $i$ and $\phi$ applies a mean aggregator. In the case of GraphSAGE, $f$ is a neighbourhood sampling function and $\phi$ can be a mean, max or LSTM aggregator.

- The functions for GIN are designed to be injective. If the functions were not injective, same output would be produced for different inputs and embeddings would be less valuable since they contain less information.
- Both functions can be learned using a single multi-layer perceptron, thanks to the universal approximation theorem. However, we should have more than one layer of MLP to distinguish specific graph structures.

$$h_i' = MLP((1+\epsilon)h_i + \sum_{j \in N_i} h_j)$$

## **Classifying Graphs using GIN**
- Graph classification is based on the node embeddings that a GNN produces. This operation is called global pooling. There are three ways to implement global pooling.
1. Mean global pooling
$$h_G = \frac{1}{N} \sum_{i=0}^{N}h_i$$
2. Max global pooling
$$h_G=max_{i=0}^N (h_i)$$
3. Sum global pooling
$$h_g = \sum_{i=0}^N h_i$$

- However, to consider all structural information, we need to consider embeddings produced by every layer of the GNN. Hence, we concatenate the sum of node embeddings by each layer of the GNN.

$$h_G = \sum_{i=0}^N h_i^0 || \cdots || \sum_{i=0}^N h_i^k$$

## **Implementing GIN**
- This dataset comprises 1,113 graphs representing proteins, where every node is an amino acid. An edge connects two nodes when their distance is lower than 0.6 nanometers. The goal of this dataset is to classify each protein as an enzyme.

In [1]:
import torch
!pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html

torch.manual_seed(11)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m52.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m65.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m60.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m994.8/994.8 kB[0m [31m47.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m565.0/565.0 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [2]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='.', name='PROTEINS').shuffle()

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-----------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {dataset[0].x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/PROTEINS.zip
Extracting ./PROTEINS/PROTEINS.zip
Processing...


Dataset: PROTEINS(1113)
-----------------------
Number of graphs: 1113
Number of nodes: 14
Number of features: 3
Number of classes: 2


Done!
  out = torch.load(self.processed_paths[0])


In [3]:
from torch_geometric.loader import DataLoader

# Create training, validation, and test sets
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset   = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset  = dataset[int(len(dataset)*0.9):]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

# Create mini-batches
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=True)

print('\nTrain loader:')
for i, batch in enumerate(train_loader):
    print(f' - Batch {i}: {batch}')

print('\nValidation loader:')
for i, batch in enumerate(val_loader):
    print(f' - Batch {i}: {batch}')

print('\nTest loader:')
for i, batch in enumerate(test_loader):
    print(f' - Batch {i}: {batch}')

Training set   = 890 graphs
Validation set = 111 graphs
Test set       = 112 graphs

Train loader:
 - Batch 0: DataBatch(edge_index=[2, 9274], x=[2468, 3], y=[64], batch=[2468], ptr=[65])
 - Batch 1: DataBatch(edge_index=[2, 8972], x=[2366, 3], y=[64], batch=[2366], ptr=[65])
 - Batch 2: DataBatch(edge_index=[2, 8820], x=[2350, 3], y=[64], batch=[2350], ptr=[65])
 - Batch 3: DataBatch(edge_index=[2, 9596], x=[2570, 3], y=[64], batch=[2570], ptr=[65])
 - Batch 4: DataBatch(edge_index=[2, 9108], x=[2490, 3], y=[64], batch=[2490], ptr=[65])
 - Batch 5: DataBatch(edge_index=[2, 10022], x=[2637, 3], y=[64], batch=[2637], ptr=[65])
 - Batch 6: DataBatch(edge_index=[2, 9732], x=[2726, 3], y=[64], batch=[2726], ptr=[65])
 - Batch 7: DataBatch(edge_index=[2, 9316], x=[2533, 3], y=[64], batch=[2533], ptr=[65])
 - Batch 8: DataBatch(edge_index=[2, 7994], x=[2074, 3], y=[64], batch=[2074], ptr=[65])
 - Batch 9: DataBatch(edge_index=[2, 11984], x=[3267, 3], y=[64], batch=[3267], ptr=[65])
 - Batch 

In [4]:
import torch
torch.manual_seed(0)
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool

- For the composition of the GIN layer, we need an MLP with at least two layers. We should introduce batch normaliation to standardise the inputs of each hidden layer, which stabilizes and speeds up training. In summary, our GIN layer has the following composition:

$$ Linear \to BatchNorm \to ReLU \to Linear \to ReLU$$

In [5]:
class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(dataset.num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Graph-level readout
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)

        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)

        return F.log_softmax(h, dim=1)

In [7]:
def train(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    epochs = 100

    model.train()
    for epoch in range(epochs+1):
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0

        # Train on batches
        for data in loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            total_loss += loss / len(loader)
            acc += accuracy(out.argmax(dim=1), data.y) / len(loader)
            loss.backward()
            optimizer.step()

            # Validation
            val_loss, val_acc = test(model, val_loader)

        # Print metrics every 20 epochs
        if(epoch % 20 == 0):
            print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}%')

    return model

@torch.no_grad()
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        loss += criterion(out, data.y) / len(loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

gin = GIN(dim_h=32)
gin = train(gin, train_loader)
test_loss, test_acc = test(gin, test_loader)
print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')

Epoch   0 | Train Loss: 1.07 | Train Acc: 61.45% | Val Loss: 0.60 | Val Acc: 61.32%
Epoch  20 | Train Loss: 0.55 | Train Acc: 75.30% | Val Loss: 0.52 | Val Acc: 76.86%
Epoch  40 | Train Loss: 0.50 | Train Acc: 74.93% | Val Loss: 0.52 | Val Acc: 75.23%
Epoch  60 | Train Loss: 0.50 | Train Acc: 75.89% | Val Loss: 0.53 | Val Acc: 77.58%
Epoch  80 | Train Loss: 0.48 | Train Acc: 76.37% | Val Loss: 0.47 | Val Acc: 80.83%
Epoch 100 | Train Loss: 0.47 | Train Acc: 78.78% | Val Loss: 0.49 | Val Acc: 74.73%
Test Loss: 0.55 | Test Acc: 72.14%
