## Graph Convolution Network (GCN)

Paper: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/pdf/1609.02907.pdf) (ICLR 2017)


**Propagation Rule**

$$H^{(l+1)} = \sigma \left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)} \right) ,$$

with $\tilde{A} = A + I$, where $A$ is the adjacency matrix, $I$ is the identity matrix, and $\tilde{D}$ is the diagonal node degree matrix of $\tilde{A}$. $W^{(l)}$ is a weight matrix for the $l$-th neural network layer and $\sigma(\cdot)$ is a non-linear activation function like the $\text{ReLU}$.

Example of a two-layer GCN: 

$$Z = f(X, A) = \text{softmax}\left(\hat{A} \text{ ReLU}\left(\hat{A}XW^{(0)} \right) W^{(1)} \right) ,$$

where $\hat{A} = \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}}$, $W^{(0)}$ and $W^{(1)}$ are weight matrices.

**Message Passing Perspective**

|Notion | Meaning | 
|---|---|
|$\mathcal{G}$ = $(V, E)$ | Input graph |
|$x_v$ | Node features for node $v\in V$|
|$h_v$ | Node embedding for node $v\in V$ |
|$\mathcal{N}(v)$ | Neighbours of node $v\in V$|

Initial:
$$h^{(0)}_v = x_v , \forall v \in V .$$

Aggregate:
$$\hat{h}_v \leftarrow\sum_{u\in \{\mathcal{N}(v) \cup \{v\}\} }\frac{h^{(l-1)}_u}{\sqrt{|\mathcal{N}(u)| |\mathcal{N}(v)}|} , \forall v \in V .$$

Update: 
$$h^{(l)}_v \leftarrow \sigma \left(W^{(l)}\cdot \hat{h}_v \right), \forall v \in V.$$

## Reproduce Results

|Dataset | Citeseer | Cora | Pubmed | 
|---| --- | --- | ---|
|Original Paper | 70.3 | 81.5 | 79.0 | 
|Ours | 70.4 | 81.6 | 78.9 |

Refererence of implementation: https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/gcn_mp.py

In [1]:
import time

import torch 
import torch.nn as nn
import torch.nn.functional as F
import dgl 
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset

Using backend: pytorch


In [2]:
def message_func(edges):
    return {'m': edges.src['h'] * edges.src['norm']}

In [3]:
def reduce_func(nodes):
    return {'h': torch.sum(nodes.mailbox['m'], 1) * nodes.data['norm']}

In [4]:
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, dropout=0.5):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats, bias=False)
        self.dropout = nn.Dropout(p=dropout)
    
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = self.dropout(h)
            g.update_all(message_func=message_func, reduce_func=reduce_func)
            return self.linear(g.ndata['h'])

In [5]:
# A two-layer GCN as described in the paper
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes, dropout=0.5):
        super(GCN, self).__init__()
        self.conv1 = GCNLayer(in_feats, h_feats, dropout=dropout)
        self.conv2 = GCNLayer(h_feats, num_classes, dropout=dropout)

    def forward(self, g, h):
        h = self.conv1(g, h)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

In [6]:
# we use the same configurations as the paper's
dropout = 0.5
wd = 5e-4
hidden_size = 16
lr = 0.01
epochs = 200

In [7]:
def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

In [8]:
def main(dataset):
    g = dataset[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = dataset.num_classes
    n_edges = g.number_of_edges()
    
    # add self loop
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    n_edges = g.number_of_edges()

    # normalization
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    g.ndata['norm'] = norm.unsqueeze(1)
    
    model = GCN(in_feats, hidden_size, n_classes, dropout)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.CrossEntropyLoss()
    
    early_stopping_cnt = 0
    best_val = 100 # large enough
    for epoch in range(epochs):
        start = time.time()
        model.train()
        # forward
        logits = model(g, features)
        loss = loss_fn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        end = time.time()
        logits = model(g, features)
        val_loss = loss_fn(logits[val_mask], labels[val_mask])
        
        
        if val_loss < best_val:
            best_val = val_loss
            early_stopping_cnt = 0
        else:
            early_stopping_cnt += 1
            if early_stopping_cnt == 10:
                print("Early stopping (val loss does not decrease for 10 consecutive epochs).")
                break
        
        print("Epoch {:03d} | Time(s) {:.4f} | Train Loss {:.4f} | Val Loss {:.4f} | ".format(epoch, end - start, loss.item(), val_loss.item()))
    
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))

In [9]:
dataset = CiteseerGraphDataset()
main(dataset)

  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 000 | Time(s) 0.6588 | Train Loss 1.7917 | Val Loss 1.7910 | 
Epoch 001 | Time(s) 0.5917 | Train Loss 1.7896 | Val Loss 1.7896 | 
Epoch 002 | Time(s) 0.7696 | Train Loss 1.7863 | Val Loss 1.7879 | 
Epoch 003 | Time(s) 0.7181 | Train Loss 1.7829 | Val Loss 1.7864 | 
Epoch 004 | Time(s) 0.6971 | Train Loss 1.7784 | Val Loss 1.7850 | 
Epoch 005 | Time(s) 0.6932 | Train Loss 1.7750 | Val Loss 1.7831 | 
Epoch 006 | Time(s) 0.7345 | Train Loss 1.7689 | Val Loss 1.7817 | 
Epoch 007 | Time(s) 0.6775 | Train Loss 1.7664 | Val Loss 1.7792 | 
Epoch 008 | Time(s) 0.7664 | Train Loss 1.7596 | Val Loss 1.7774 | 
Epoch 009 | Time(s) 0.7022 | Train Loss 1.7533 | Val Loss 1.7762 | 
Epoch 010 | Time(s) 0.7091 | Train Loss 1.7469 | Val Loss 1.7738 | 
Epoch 011 | Time(s) 0.7565 | Train Loss 1.7414 | Val Loss 1.7717 | 
Epo

In [10]:
dataset = CoraGraphDataset()
main(dataset)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 000 | Time(s) 0.4625 | Train Loss 1.9459 | Val Loss 1.9447 | 
Epoch 001 | Time(s) 0.4397 | Train Loss 1.9437 | Val Loss 1.9427 | 
Epoch 002 | Time(s) 0.4622 | Train Loss 1.9404 | Val Loss 1.9404 | 
Epoch 003 | Time(s) 0.4949 | Train Loss 1.9359 | Val Loss 1.9381 | 
Epoch 004 | Time(s) 0.5053 | Train Loss 1.9322 | Val Loss 1.9364 | 
Epoch 005 | Time(s) 0.4676 | Train Loss 1.9275 | Val Loss 1.9333 | 
Epoch 006 | Time(s) 0.4554 | Train Loss 1.9223 | Val Loss 1.9309 | 
Epoch 007 | Time(s) 0.4697 | Train Loss 1.9190 | Val Loss 1.9269 | 
Epoch 008 | Time(s) 0.4046 | Train Loss 1.9141 | Val Loss 1.9250 | 
Epoch 009 | Time(s) 0.3998 | Train Loss 1.9053 | Val Loss 1.9217 | 
Epoch 010 | Time(s) 0.4422 | Train Loss 1.8983 | Val Loss 1.9191 | 
Epoch 011 | Time(s) 0.4379 | Train Loss 1.8909 | Val Loss 1.9151 | 
Ep

In [11]:
dataset = PubmedGraphDataset()
main(dataset)

  NumNodes: 19717
  NumEdges: 88651
  NumFeats: 500
  NumClasses: 3
  NumTrainingSamples: 60
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 000 | Time(s) 2.0052 | Train Loss 1.0988 | Val Loss 1.0975 | 
Epoch 001 | Time(s) 1.8083 | Train Loss 1.0966 | Val Loss 1.0962 | 
Epoch 002 | Time(s) 1.7889 | Train Loss 1.0941 | Val Loss 1.0945 | 
Epoch 003 | Time(s) 1.7796 | Train Loss 1.0912 | Val Loss 1.0927 | 
Epoch 004 | Time(s) 1.7522 | Train Loss 1.0876 | Val Loss 1.0907 | 
Epoch 005 | Time(s) 1.7654 | Train Loss 1.0836 | Val Loss 1.0878 | 
Epoch 006 | Time(s) 1.7825 | Train Loss 1.0802 | Val Loss 1.0852 | 
Epoch 007 | Time(s) 1.9341 | Train Loss 1.0741 | Val Loss 1.0829 | 
Epoch 008 | Time(s) 1.9418 | Train Loss 1.0699 | Val Loss 1.0784 | 
Epoch 009 | Time(s) 1.8770 | Train Loss 1.0641 | Val Loss 1.0763 | 
Epoch 010 | Time(s) 1.7873 | Train Loss 1.0570 | Val Loss 1.0731 | 
Epoch 011 | Time(s) 1.8458 | Train Loss 1.0499 | Val Loss 1.0684 | 
Epo