
# Graph Isomorphism Network (GIN)

In [1]:
# Import Libraries 

import torch
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

import networkx as nx
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
plt.rcParams.update({'font.size': 10})



## 1. Proteins Dataset 

[PROTEINS](https://chrsmrrs.github.io/datasets/docs/datasets/) is a popular dataset in bioinformatics. It is a collection of 1113 graphs representing proteins, where nodes are amino acids. Two nodes are connected by an edge when they are less than 0.6 nanometers apart. The goal is to classify each protein as an **enzyme** or not.

Enzymes are a particular type of protein that act as catalysts to speed up chemical reactions in the cell. They are essential for digestion (e.g., lipases), respiration (e.g., oxidases), and other crucial functions of the human body. They are also used in commercial applications, such as the production of antibiotics.

In [2]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root=".", name="PROTEINS").shuffle()

print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {dataset[0].x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Dataset: PROTEINS(1113)
-------------------
Number of graphs: 1113
Number of nodes: 20
Number of features: 3
Number of classes: 2


## 2. Create Mini-Batches of the Dataset

In [3]:
from torch_geometric.loader import DataLoader

train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset   = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset  = dataset[int(len(dataset)*0.9):]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)  # evaluation does not need shuffling
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print('\nTrain loader:')
for i, subgraph in enumerate(train_loader):
    print(f'  Subgraph {i}: {subgraph}')

print('\nValidation loader:')
for i, subgraph in enumerate(val_loader):
    print(f'  Subgraph {i}: {subgraph}')

print('\nTest loader:')
for i, subgraph in enumerate(test_loader):
    print(f'  Subgraph {i}: {subgraph}')

Training set   = 890 graphs
Validation set = 111 graphs
Test set       = 112 graphs

Train loader:
  Subgraph 0: DataBatch(edge_index=[2, 6524], x=[1772, 3], y=[64], batch=[1772], ptr=[65])
  Subgraph 1: DataBatch(edge_index=[2, 13370], x=[3614, 3], y=[64], batch=[3614], ptr=[65])
  Subgraph 2: DataBatch(edge_index=[2, 9164], x=[2517, 3], y=[64], batch=[2517], ptr=[65])
  Subgraph 3: DataBatch(edge_index=[2, 8978], x=[2381, 3], y=[64], batch=[2381], ptr=[65])
  Subgraph 4: DataBatch(edge_index=[2, 9876], x=[2589, 3], y=[64], batch=[2589], ptr=[65])
  Subgraph 5: DataBatch(edge_index=[2, 10626], x=[2682, 3], y=[64], batch=[2682], ptr=[65])
  Subgraph 6: DataBatch(edge_index=[2, 6858], x=[1871, 3], y=[64], batch=[1871], ptr=[65])
  Subgraph 7: DataBatch(edge_index=[2, 9620], x=[2585, 3], y=[64], batch=[2585], ptr=[65])
  Subgraph 8: DataBatch(edge_index=[2, 8710], x=[2330, 3], y=[64], batch=[2330], ptr=[65])
  Subgraph 9: DataBatch(edge_index=[2, 8630], x=[2356, 3], y=[64], batch=[2356],

## 3. Representational Power of Graph Isomorphism Network (GIN)
[GIN](https://arxiv.org/abs/1810.00826v3) was designed to maximize the **representational power** of a GNN. 

### 3.1 Weisfeiler-Lehman Test
One way to characterize the "representational power" of a GNN is the [Weisfeiler-Lehman graph isomorphism test](https://davidbieber.com/post/2019-05-10-weisfeiler-lehman-isomorphism-test/). Two graphs are [isomorphic](https://en.wikipedia.org/wiki/Graph_isomorphism) when they have the same structure: identical connections, up to a permutation of the nodes.

Note that the Weisfeiler-Lehman (WL) test can tell when two graphs **are non-isomorphic**, but it cannot guarantee that they are isomorphic.

In the WL test:


1. Every node starts with **the same label**.
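2. Each node's label is combined with the labels of its neighbors, and the result is **aggregated and hashed** to produce a new label.
3. The previous step is **repeated** until the labels stop changing.

If the two graphs end up with different collections of labels, they are non-isomorphic.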
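
The steps above can be sketched in plain Python. This is a minimal illustration, not part of the original notebook: the function name `wl_histogram`, the adjacency-dict representation, and the use of Python's built-in `hash` as the hashing step are all assumptions made for the example.

```python
from collections import Counter


def wl_histogram(adjacency):
    """Run WL label refinement and return the histogram of final labels.

    adjacency: dict mapping each node to a list of its neighbors (hypothetical format).
    """
    # Step 1: every node starts with the same label.
    labels = {node: 0 for node in adjacency}
    for _ in range(len(adjacency)):
        # Step 2: hash the node's own label together with its sorted neighbor labels.
        new_labels = {
            node: hash((labels[node], tuple(sorted(labels[n] for n in adjacency[node]))))
            for node in adjacency
        }
        # Step 3: stop once a round no longer splits any group of nodes.
        if len(set(new_labels.values())) == len(set(labels.values())):
            break
        labels = new_labels
    return Counter(labels.values())


# A path and a star on 4 nodes: same number of edges, different structure.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
print(wl_histogram(path) != wl_histogram(star))  # True: proven non-isomorphic

# Two triangles vs. one 6-cycle: non-isomorphic, yet the test cannot tell them apart.
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
print(wl_histogram(two_triangles) == wl_histogram(hexagon))  # True: inconclusive
```

The second pair shows why equal histograms do not prove isomorphism: every node in both graphs has exactly two neighbors, so the labels never diverge.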

### 3.2 One Aggregator to Rule Them All
To be as good as the WL test, the aggregator that combines the feature vectors of neighboring nodes must produce different node embeddings for different neighborhoods; in other words, it must be injective. How do we design this aggregator? **We learn it with a Multi-Layer Perceptron (MLP)**.

$\displaystyle h_i = \mathrm{MLP}\Big((1+\epsilon)\cdot x_i+\sum_{j\in \mathcal{N}_i}x_j\Big)$

where $\epsilon$ is a learnable parameter (or a fixed constant) that represents the importance of the target node relative to its neighbors.
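
To see why the sum inside the formula matters, here is a small sketch with scalar features. It only computes the input to the MLP (the MLP itself is omitted), and the helper names and toy values are assumptions made for illustration, with $\epsilon$ fixed to 0.

```python
def gin_pre_mlp(x_i, neighbor_feats, eps=0.0):
    """Input to the MLP in the GIN update: (1 + eps) * x_i + sum of neighbor features."""
    return (1 + eps) * x_i + sum(neighbor_feats)


def mean_aggregate(neighbor_feats):
    """Mean aggregation, shown for comparison."""
    return sum(neighbor_feats) / len(neighbor_feats)


def max_aggregate(neighbor_feats):
    """Max aggregation, shown for comparison."""
    return max(neighbor_feats)


# Two different neighborhoods around a node with feature x_i = 1.0 (made-up values).
neighbors_a = [1.0, 1.0]
neighbors_b = [1.0, 1.0, 1.0, 1.0]

sum_a = gin_pre_mlp(1.0, neighbors_a)
sum_b = gin_pre_mlp(1.0, neighbors_b)
print(sum_a, sum_b)  # 3.0 5.0 -> the sum separates the two neighborhoods

print(mean_aggregate(neighbors_a), mean_aggregate(neighbors_b))  # 1.0 1.0 -> identical
print(max_aggregate(neighbors_a), max_aggregate(neighbors_b))    # 1.0 1.0 -> identical
```

Mean and max collapse the two neighborhoods to the same value, while the sum keeps them apart, which is what lets the MLP assign them different embeddings.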

### 3.3 Global Pooling 
Global Pooling or Graph-Level Readout consists of producing a graph embedding using the node embeddings calculated by the GNN. 

A simple way to obtain a graph embedding $h_G$ is to use the mean, sum, or max of the node embeddings $h_i$. In GIN, the node embeddings of each layer are summed, and the per-layer sums are concatenated:

$\displaystyle h_G = \sum_{i=0}^{N}h^0_i \,\Vert\, \dots \,\Vert\, \sum_{i=0}^{N}h^k_i$

where $\Vert$ denotes concatenation and $h^k_i$ is the embedding of node $i$ at layer $k$.
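
The readout can be sketched on a toy graph in plain Python. The helper names `sum_pool` and `gin_readout` and the embedding values are hypothetical, chosen only to make the sum-then-concatenate pattern visible.

```python
def sum_pool(node_embeddings):
    """Sum a list of equal-length node embedding vectors element-wise."""
    return [sum(values) for values in zip(*node_embeddings)]


def gin_readout(layer_embeddings):
    """Concatenate the per-layer sums into a single graph embedding."""
    graph_embedding = []
    for node_embeddings in layer_embeddings:
        graph_embedding.extend(sum_pool(node_embeddings))
    return graph_embedding


# Toy graph with 3 nodes, 2 layers, and 2-dimensional embeddings (made-up values).
layer_0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
layer_1 = [[0.5, 2.0], [1.5, 0.0], [1.0, 1.0]]

h_G = gin_readout([layer_0, layer_1])
print(h_G)  # [2.0, 2.0, 3.0, 3.0]
```

With $k+1$ layers of dimension $d$, the graph embedding has dimension $(k+1)\cdot d$.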



## 4. Benchmark the performance of GCN vs GIN

In [4]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU
from torch_geometric.nn import GINConv, GCNConv
from torch_geometric.nn import global_mean_pool, global_add_pool 

In [5]:
def train(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                      lr=0.01,
                                      weight_decay=0.01)
    epochs = 100

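    for epoch in range(epochs+1):
        # test() switches the model to eval mode, so re-enable training mode
        # (dropout, batch-norm updates) at the start of every epoch
        model.train()
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0

        # Train on batches
        for data in loader:
            optimizer.zero_grad()
            _, out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            # .item() keeps the running total from holding on to the autograd graph
            total_loss += loss.item() / len(loader)
            acc += accuracy(out.argmax(dim=1), data.y) / len(loader)
            loss.backward()
            optimizer.step()

        # Validation, once per epoch (uses the global val_loader)
        val_loss, val_acc = test(model, val_loader)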
        if(epoch % 10 == 0):
          print(f'Epoch {epoch:>3} '
                f'| Train Loss: {total_loss:.2f} '
                f'| Train Acc: {acc*100:>5.2f}% '
                f'| Val Loss: {val_loss:.2f} '
                f'| Val Acc: {val_acc*100:.2f}%')
          
    test_loss, test_acc = test(model, test_loader)
    print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')
    
    return model

@torch.no_grad()
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
        _, out = model(data.x, data.edge_index, data.batch)
        loss += criterion(out, data.y) / len(loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

### 4.1 Classify Graphs via GCN 

In [6]:
class GCN(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.conv1 = GCNConv(dim_in, dim_h)
        self.conv2 = GCNConv(dim_h, dim_h)
        self.conv3 = GCNConv(dim_h, dim_h)
        self.lin = Linear(dim_h, dim_out)

    def forward(self, x, edge_index, batch):
        # Node Embedding 
        h = self.conv1(x, edge_index)
        h = h.relu()
        h = self.conv2(h, edge_index)
        h = h.relu()
        h = self.conv3(h, edge_index)
        
        # Graph Embedding 
        hG = global_mean_pool(h, batch)

        # Classifier
        hG = F.dropout(hG, p=0.5, training=self.training)
        hG = self.lin(hG)
        
        return hG, F.log_softmax(hG, dim=1)

In [7]:
%%time

gcn = GCN(dim_in=dataset.num_features, dim_h=32, dim_out=dataset.num_classes)
gcn = train(gcn, train_loader)

Epoch   0 | Train Loss: 0.68 | Train Acc: 59.62% | Val Loss: 0.68 | Val Acc: 58.69%
Epoch  10 | Train Loss: 0.68 | Train Acc: 59.36% | Val Loss: 0.68 | Val Acc: 59.26%
Epoch  20 | Train Loss: 0.68 | Train Acc: 59.27% | Val Loss: 0.67 | Val Acc: 59.82%
Epoch  30 | Train Loss: 0.68 | Train Acc: 59.33% | Val Loss: 0.68 | Val Acc: 58.13%
Epoch  40 | Train Loss: 0.68 | Train Acc: 59.37% | Val Loss: 0.67 | Val Acc: 60.11%
Epoch  50 | Train Loss: 0.68 | Train Acc: 59.33% | Val Loss: 0.67 | Val Acc: 60.11%
Epoch  60 | Train Loss: 0.68 | Train Acc: 59.37% | Val Loss: 0.68 | Val Acc: 58.98%
Epoch  70 | Train Loss: 0.68 | Train Acc: 59.39% | Val Loss: 0.68 | Val Acc: 58.69%
Epoch  80 | Train Loss: 0.68 | Train Acc: 59.33% | Val Loss: 0.68 | Val Acc: 58.98%
Epoch  90 | Train Loss: 0.68 | Train Acc: 59.39% | Val Loss: 0.68 | Val Acc: 57.28%
Epoch 100 | Train Loss: 0.68 | Train Acc: 59.30% | Val Loss: 0.68 | Val Acc: 57.56%
Test Loss: 0.66 | Test Acc: 62.50%
CPU times: user 1min 33s, sys: 775 ms, to

### 4.2 Classify Graphs via GIN  

In [8]:
class GIN(torch.nn.Module): 
  def __init__(self, dim_in, dim_h, dim_out):
    super().__init__()
    self.conv1 = GINConv(
        Sequential(
            Linear(dim_in, dim_h),
            BatchNorm1d(dim_h), 
            ReLU(), 
            Linear(dim_h, dim_h), 
            ReLU()
            ))
    self.conv2 = GINConv(
        Sequential(
            Linear(dim_h, dim_h),
            BatchNorm1d(dim_h), 
            ReLU(), 
            Linear(dim_h, dim_h), 
            ReLU()
            ))
    self.conv3 = GINConv(
        Sequential(
            Linear(dim_h, dim_h),
            BatchNorm1d(dim_h), 
            ReLU(), 
            Linear(dim_h, dim_h), 
            ReLU()
            ))
    self.lin1 = Linear(dim_h * 3, dim_h *3)
    self.lin2 = Linear(dim_h * 3, dim_out)
  
  def forward(self, x, edge_index, batch):
    h1 = self.conv1(x, edge_index)
    h2 = self.conv2(h1, edge_index)
    h3 = self.conv3(h2, edge_index)

    h1 = global_add_pool(h1, batch)
    h2 = global_add_pool(h2, batch)
    h3 = global_add_pool(h3, batch)

    h = torch.cat((h1, h2, h3), dim=1)

    h = self.lin1(h)
    h = F.relu(h)
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.lin2(h)

    return h, F.log_softmax(h, dim=1)    

In [9]:
%%time

gin = GIN(dim_in=dataset.num_features, dim_h=32, dim_out=dataset.num_classes)
gin = train(gin, train_loader)

Epoch   0 | Train Loss: 1.55 | Train Acc: 61.43% | Val Loss: 0.60 | Val Acc: 64.23%
Epoch  10 | Train Loss: 0.56 | Train Acc: 75.38% | Val Loss: 0.57 | Val Acc: 70.20%
Epoch  20 | Train Loss: 0.56 | Train Acc: 73.33% | Val Loss: 0.56 | Val Acc: 73.74%
Epoch  30 | Train Loss: 0.54 | Train Acc: 74.02% | Val Loss: 0.56 | Val Acc: 70.89%
Epoch  40 | Train Loss: 0.55 | Train Acc: 74.09% | Val Loss: 0.55 | Val Acc: 71.26%
Epoch  50 | Train Loss: 0.53 | Train Acc: 74.37% | Val Loss: 0.55 | Val Acc: 72.96%
Epoch  60 | Train Loss: 0.53 | Train Acc: 75.44% | Val Loss: 0.55 | Val Acc: 72.39%
Epoch  70 | Train Loss: 0.52 | Train Acc: 76.05% | Val Loss: 0.56 | Val Acc: 71.54%
Epoch  80 | Train Loss: 0.53 | Train Acc: 74.88% | Val Loss: 0.57 | Val Acc: 71.54%
Epoch  90 | Train Loss: 0.52 | Train Acc: 75.02% | Val Loss: 0.56 | Val Acc: 71.26%
Epoch 100 | Train Loss: 0.53 | Train Acc: 74.33% | Val Loss: 0.55 | Val Acc: 74.02%
Test Loss: 0.40 | Test Acc: 82.81%
CPU times: user 1min 1s, sys: 265 ms, tot

### 4.3 Ensemble learning of GIN and GCN
Combining the predictions of multiple models can sometimes improve prediction performance, although it is not guaranteed to. The simplest approach is to take the mean of the normalized output vectors; the cell below averages the log-softmax outputs of the two models.

In [10]:
gcn.eval()
gin.eval()
acc_gcn = 0
acc_gin = 0
acc = 0

for data in test_loader:
    _, out_gcn = gcn(data.x, data.edge_index, data.batch)
    _, out_gin = gin(data.x, data.edge_index, data.batch)
    out = (out_gcn + out_gin)/2

    acc_gcn += accuracy(out_gcn.argmax(dim=1), data.y) / len(test_loader)
    acc_gin += accuracy(out_gin.argmax(dim=1), data.y) / len(test_loader)
    acc += accuracy(out.argmax(dim=1), data.y) / len(test_loader)

# Print results
print(f'GCN accuracy:     {acc_gcn*100:.2f}%')
print(f'GIN accuracy:     {acc_gin*100:.2f}%')
print(f'GCN+GIN accuracy: {acc*100:.2f}%')

GCN accuracy:     62.50%
GIN accuracy:     82.81%
GCN+GIN accuracy: 80.99%
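
The averaging step can be illustrated on made-up numbers. The sketch below is not from the original notebook: the helper names and the probabilities are assumptions. It compares averaging log-probabilities, as `(out_gcn + out_gin)/2` does, with averaging the probabilities themselves.

```python
import math


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda i: values[i])


def mean_log_probs(log_p_a, log_p_b):
    """Average two log-probability vectors element-wise."""
    return [(a + b) / 2 for a, b in zip(log_p_a, log_p_b)]


def mean_probs(log_p_a, log_p_b):
    """Average the probabilities recovered from two log-probability vectors."""
    return [(math.exp(a) + math.exp(b)) / 2 for a, b in zip(log_p_a, log_p_b)]


# Made-up outputs for one graph: model A favors class 0, model B favors class 1.
log_p_a = [math.log(0.9), math.log(0.1)]
log_p_b = [math.log(0.4), math.log(0.6)]

pred_from_logs = argmax(mean_log_probs(log_p_a, log_p_b))
pred_from_probs = argmax(mean_probs(log_p_a, log_p_b))
print(pred_from_logs, pred_from_probs)  # 0 0
```

With two models and two classes, both rules pick class 0 exactly when the two class-0 probabilities sum to more than 1, so they yield the same prediction. With more classes or more models they can disagree.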
