# Semi-supervised Community Detection using Graph Neural Networks

Almost every introductory programming class starts with a "Hello World" example. Like MNIST for deep learning, the graph domain has Zachary's Karate Club problem. The karate club is a social network of 34 members that documents pairwise links between members who interact outside the club. The club later splits into two communities led by the instructor (node 0) and the club president (node 33). The network is visualized below, with the color indicating the community.

<img src='../asset/karat_club.png' align='center' width="400px" height="300px" />

In this tutorial, you will learn how to:

* Formulate the community detection problem as a semi-supervised node classification task.
* Build a GraphSAGE model, a popular Graph Neural Network architecture proposed by [Hamilton et al.](https://arxiv.org/abs/1706.02216).
* Train the model and understand the result.

## Community detection as node classification

The study of community structure in graphs has a long history. Many proposed methods are *unsupervised* (or *self-supervised* by recent definition), where the model predicts the community labels from connectivity alone. Recently, [Kipf et al.](https://arxiv.org/abs/1609.02907) proposed formulating community detection as a semi-supervised node classification task. With the help of only a small portion of labeled nodes, a GNN can accurately predict the community labels of the others.

In this tutorial, we apply Kipf's setting to Zachary's Karate Club network to predict community membership, using the labels of only a few nodes.

In [None]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

## Message passing and GNNs

DGL follows the *message passing paradigm* inspired by the Message Passing Neural Network proposed by [Gilmer et al.](https://arxiv.org/abs/1704.01212). Essentially, they found that many GNN models fit into the following framework:

$$
m_{u\sim v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\sim v}^{(l-1)}\right)
$$

$$
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\sim v}^{(l)}
$$

$$
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
$$

where DGL calls $M^{(l)}$ the *message function*, $\sum$ the *reduce function*, and $U^{(l)}$ the *update function*. Note that $\sum$ here can represent any aggregation function and is not necessarily a summation.
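
The three equations above can be sketched in plain Python, without DGL, to make the data flow visible. This is only an illustration under our own assumptions: a three-node toy graph, scalar node features, no edge features, and hypothetical function names `message`, `reduce_sum`, and `update` that are not DGL APIs.

```python
# Tiny directed graph as an edge list (u, v): a message flows from u to v.
edges = [(0, 1), (2, 1), (1, 2)]
h = {0: 1.0, 1: 2.0, 2: 4.0}  # scalar node features h^{(l-1)}

def message(h_v, h_u):
    """M: here the message is simply the source feature (edge features omitted)."""
    return h_u

def reduce_sum(messages):
    """Reduce: sum of the incoming messages (a mean or max would also qualify)."""
    return sum(messages)

def update(h_v, m_v):
    """U: here the old feature plus the aggregated message."""
    return h_v + m_v

def message_passing_round(h, edges):
    # Collect the messages arriving at every destination node.
    inbox = {v: [] for v in h}
    for u, v in edges:
        inbox[v].append(message(h[v], h[u]))
    # Reduce each inbox and update the node feature.
    return {v: update(h[v], reduce_sum(inbox[v])) for v in h}

h_next = message_passing_round(h, edges)
print(h_next)  # {0: 1.0, 1: 7.0, 2: 6.0}
```

Node 0 has no incoming edge, so its feature is unchanged; node 1 receives from nodes 0 and 2, and node 2 receives only from node 1.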

This message passing paradigm only allows a node to access the data of its direct neighbors. In order to access data multiple hops away, we can perform multi-layer message passing.
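
To illustrate why stacking layers widens the neighborhood a node can see, the following sketch computes which nodes can influence a target node after a given number of message passing rounds. It is plain Python on a hypothetical five-node path graph; `receptive_field` is an illustrative helper, not a DGL function.

```python
def receptive_field(neighbors, node, num_layers):
    """Nodes whose input features can influence `node` after num_layers rounds."""
    reached = {node}
    for _ in range(num_layers):
        # Each round pulls in the direct neighbors of everything reached so far.
        reached |= {u for v in reached for u in neighbors[v]}
    return reached

# Undirected path graph 0 - 1 - 2 - 3 - 4, stored as adjacency lists.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

for k in range(4):
    print(k, sorted(receptive_field(path, 0, k)))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```

With two layers, as in the model built later in this tutorial, each node's output depends on nodes up to two hops away.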

<img src='../asset/multi_layer.png' align='center' width="800px" height="600px" />


## Define a GraphSAGE model

In this tutorial, we are going to use the GraphSAGE model proposed by [Hamilton et al.](https://arxiv.org/abs/1706.02216).

The equations of a GraphSAGE layer are:

$$
h_{\mathcal{N}(v)}^l\leftarrow \text{AGGREGATE}_l\{h_u^{l-1},\forall u\in\mathcal{N}(v)\}
$$

$$
h_v^l\leftarrow \sigma\left(W^l\cdot \text{CONCAT}(h_v^{l-1}, h_{\mathcal{N}(v)}^l) \right)
$$


You can see that message passing is directional: the message sent from node $u$ to node $v$ is not necessarily the same as the message sent from node $v$ to node $u$ in the opposite direction.
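
As a worked example of the two layer equations, here is a minimal sketch of one GraphSAGE layer with a mean aggregator, written with Python lists so every number can be checked by hand. The toy graph, the weight matrix `W`, the omission of a bias term, and the choice to use a zero vector for nodes without incoming neighbors are all assumptions made for this illustration.

```python
def relu(x):
    return max(x, 0.0)

def sage_mean_layer(h, in_neighbors, weight, activation=relu):
    """One GraphSAGE layer with a mean aggregator (bias omitted).

    h            : dict node -> feature list of length d
    in_neighbors : dict node -> list of nodes that send messages to it
    weight       : d_out x (2 * d) matrix given as a list of rows
    """
    out = {}
    for v, h_v in h.items():
        nbrs = in_neighbors[v]
        d = len(h_v)
        # AGGREGATE: element-wise mean of the neighbor features.
        if nbrs:
            h_neigh = [sum(h[u][i] for u in nbrs) / len(nbrs) for i in range(d)]
        else:
            h_neigh = [0.0] * d  # assumption: no neighbors means a zero vector
        # CONCAT(h_v, h_N(v))
        z = h_v + h_neigh
        # sigma(W . z)
        out[v] = [activation(sum(w * x for w, x in zip(row, z))) for row in weight]
    return out

h = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [4.0, 4.0]}
# Directed toy graph: node 2 sends to node 0 but receives nothing.
in_neighbors = {0: [1, 2], 1: [0], 2: []}
W = [[1.0, 0.0, 1.0, 0.0],    # self x plus neighbor-mean x
     [0.0, 1.0, 0.0, -1.0]]   # self y minus neighbor-mean y

h_new = sage_mean_layer(h, in_neighbors, W)
print(h_new)  # {0: [3.0, 0.0], 1: [1.0, 2.0], 2: [4.0, 4.0]}
```

Node 0 influences node 1, yet node 1's message to node 0 carries a different feature, and node 2 is unaffected by anyone, which is the directionality described above.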

In [None]:
import dgl.function as fn

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.
    
    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)
    
    def forward(self, g, h):
        """Forward computation
        
        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata['h'] = h
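            # update_all is DGL's message passing API.
            # Message and reduce steps: copy each source feature 'h' into the message 'm',
            # then average the incoming messages into 'h_neigh'.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_neigh'))
            # Update step: concatenate self and neighbor features, then project.
            # The nonlinearity is applied by the caller.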
            h_neigh = g.ndata['h_neigh']
            h_total = torch.cat([h, h_neigh], dim=1)
            return self.linear(h_total)

Implement a multi-layer GraphSAGE model. Our model consists of two layers, each of which computes new node representations by aggregating neighbor information.

In [None]:
# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)
    
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

## Load data

In [None]:
from tutorial_utils import load_zachery

# ----------- 0. load graph -------------- #
g = load_zachery()
print(g)

## Add embeddings as node data

In the original Zachary's Karate Club graph, nodes are feature-less. (The `'Age'` attribute is an artificial one, mainly for tutorial purposes.) For feature-less graphs, a common practice is to give every node an embedding vector that is updated during training.

We can use PyTorch's `Embedding` module to achieve this.

In [None]:
# ----------- 1. node features -------------- #
node_embed = nn.Embedding(g.number_of_nodes(), 5)  # Every node has an embedding of size 5.
inputs = node_embed.weight                         # Use the embedding weight as the node features.
nn.init.xavier_uniform_(inputs)
print(inputs)
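
To give a feel for the values the Xavier initialization above produces, the sketch below rebuilds a 34 x 5 embedding table in plain Python. It relies on the fact that `nn.init.xavier_uniform_` samples from $U(-a, a)$ with $a = \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$ at the default gain of 1; the helper name `xavier_uniform_table` and the seed are hypothetical, and the sampled numbers will not match PyTorch's.

```python
import math
import random

def xavier_uniform_table(num_nodes, dim, seed=0):
    """Return a num_nodes x dim table sampled from U(-bound, bound)."""
    rng = random.Random(seed)
    # For a 2-D weight, fan_in + fan_out is simply rows + columns.
    bound = math.sqrt(6.0 / (num_nodes + dim))
    table = [[rng.uniform(-bound, bound) for _ in range(dim)]
             for _ in range(num_nodes)]
    return table, bound

table, bound = xavier_uniform_table(34, 5)
print(len(table), len(table[0]), round(bound, 4))  # 34 5 0.3922

# Looking up the feature of a node is just indexing a row of the table.
print(table[33])
```

Every entry therefore starts within roughly $\pm 0.39$, and training moves the rows together with the model weights.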

## Split the dataset for training

The community label is stored in the `'club'` node data (0 for instructor, 1 for club president). We pick a few nodes in the graph as training nodes and use the remaining nodes as test nodes.

In [None]:
import random
import numpy as np
random.seed(0)
labels = g.ndata['club']
print('#nodes:', len(labels))
train_nodes = np.unique([0, 33] + random.sample(range(len(labels)), 3))
test_nodes = np.delete(np.arange(len(labels)), train_nodes)
print('#labeled nodes:', len(train_nodes))
print('Labels', labels[train_nodes])
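
One detail of this split is easy to miss: the randomly sampled nodes may coincide with node 0 or node 33, so after deduplication there can be fewer than five labeled nodes. The following sketch mirrors the same logic with the standard library only; `split_nodes` and its parameters are hypothetical names, and it is not guaranteed to pick the same nodes as the cell above.

```python
import random

def split_nodes(num_nodes, always_labeled, num_extra, seed=0):
    """Return (train, test) node lists that are disjoint and cover all nodes."""
    rng = random.Random(seed)
    extra = rng.sample(range(num_nodes), num_extra)
    # A set removes duplicates, just like np.unique in the cell above.
    train = sorted(set(always_labeled) | set(extra))
    train_set = set(train)
    test = [n for n in range(num_nodes) if n not in train_set]
    return train, test

train, test = split_nodes(34, [0, 33], 3)
print('#labeled nodes:', len(train))
print('#test nodes:', len(test))
```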

## Train the GraphSAGE model

In [None]:
# Create the model with given dimensions 
# input layer dimension: 5, node embeddings
# hidden layer dimension: 16
# output layer dimension: 2, the two classes, 0 and 1
net = GraphSAGE(5, 16, 2)

# ----------- 3. set up loss and optimizer -------------- #
# in this case, the loss is computed inside the training loop
optimizer = torch.optim.Adam(itertools.chain(net.parameters(), node_embed.parameters()), lr=0.01)

# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
    # forward
    logits = net(g, inputs)
    
    # compute loss
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_nodes], labels[train_nodes])
    
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    all_logits.append(logits.detach())
    
    if e % 5 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))
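
The loss in the loop above is the mean negative log-likelihood over the labeled nodes only; unlabeled nodes contribute nothing to it. A small worked example in plain Python makes this concrete. The logits, labels, and the helper names `log_softmax` and `masked_nll` are made up for illustration.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def masked_nll(logits, labels, train_nodes):
    """Mean negative log-likelihood over the labeled (training) nodes only."""
    logp = [log_softmax(row) for row in logits]
    return -sum(logp[n][labels[n]] for n in train_nodes) / len(train_nodes)

logits = [[2.0, 0.0], [0.0, 0.0], [-1.0, 3.0], [0.5, 0.5]]
labels = [0, 1, 1, 0]
train_nodes = [0, 2]  # only these two nodes are labeled

loss = masked_nll(logits, labels, train_nodes)
print(round(loss, 4))  # 0.0725

# Changing the label of an unlabeled node leaves the loss untouched.
labels_changed = [0, 0, 1, 1]
print(masked_nll(logits, labels_changed, train_nodes) == loss)  # True
```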

In [None]:
# ----------- 5. check results ------------------------ #
pred = torch.argmax(logits, dim=1)
# Accuracy on the held-out nodes: correct test predictions / number of test nodes.
print('Accuracy', (pred == labels)[test_nodes].sum().item() / len(test_nodes))
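
Test accuracy is the fraction of held-out nodes that are predicted correctly, so the count of correct predictions is normalized by the number of test nodes. The toy predictions and the `accuracy` helper below are illustrative only.

```python
def accuracy(pred, labels, nodes):
    """Fraction of the given nodes whose prediction matches the label."""
    correct = sum(1 for n in nodes if pred[n] == labels[n])
    return correct / len(nodes)

pred = [0, 0, 1, 1, 0, 1]
labels = [0, 1, 1, 1, 0, 0]
test_nodes = [1, 2, 3, 5]  # nodes 0 and 4 were used for training

print(accuracy(pred, labels, test_nodes))  # 0.5 (nodes 2 and 3 are correct)
```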

## Implement GraphSAGE with nn modules

DGL provides implementations of many popular neighbor aggregation modules. Each can be invoked with a single line of code. See the full list of supported [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv).

In [None]:
from dgl.nn import SAGEConv

# ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, num_classes, 'mean')
    
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
    
# Create the model with given dimensions 
# input layer dimension: 5, node embeddings
# hidden layer dimension: 16
# output layer dimension: 2, the two classes, 0 and 1
net = GraphSAGE(5, 16, 2)

## Exercise

Play with the GNN models by using other [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv). For example, how about graph attention networks?