In [1]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


### Source material
https://docs.dgl.ai/tutorials/blitz/1_introduction.html#sphx-glr-tutorials-blitz-1-introduction-py

## 1. Deep Graph Library (DGL)
DGL integrates built-in functions which allow us to compute feature updates (for nodes and edges) based on the graph structure. 

To simultaneously update all node features, the following 3 functions are used:

1. **message_func(edges)**: 
Each edge has attributes edge=[src, dst, data]. This function sends information along each edge, $src \rightarrow dst$. It stores everything needed for the node-feature update in a dict called "mailbox", which can be accessed via "nodes.mailbox".
2. **reduce_func(nodes)**: Update the node feature by the update equation. All info in "mailbox" (obtained from message_func) will be used for the update.
3. **g.update_all(message_func,reduce_func)**: send messages through all edges (by message_func) and update features of all nodes (by reduce_func)

Additionally if the update equations involve edge updates, use **apply_edges(func)** to update edge features.

### Cora dataset
The Cora dataset is a single graph representing a citation network. It has
1. Nodes = papers
2. Edges = connectivity (citation) between the papers

In [2]:
# Load the Cora citation dataset that every later cell works on.
import dgl.data
dataset=dgl.data.CoraGraphDataset()

# extract the graph
# (index 0 is the graph printed below; per that printout its ndata carries
#  'feat', 'label' and the 'train_mask' / 'val_mask' / 'test_mask' entries)
g=dataset[0]
print(g)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})


### A simple case
Assume we want to update node features by taking the mean of its neighbors
$$h_i^{(l+1)}=\text{mean}\left(h_j^{(l)}: j\in N(i)\right)$$
Note: $N(i)=$ all nodes j that send info to $i$, i.e. $(j,i)$ is an edge.
1. message_func(self,edges): store the information $h_j^{(l)}$ of $j$ on each edge $(j,i)$
2. reduce_func(self,nodes): use the stored information to compute node update $\text{mean}\left(h_j^{(l)}: j\in N(i)\right)$.
3. g.update_all(message_func,reduce_func): perform message_func and reduce_func simultaneously on all nodes.

In [6]:
import dgl.function as fn

class mean_layer(nn.Module):
    """Parameter-free layer: each node's output is the mean of the features sent to it.

    Written with user-defined message/reduce functions (rather than DGL
    built-ins) to illustrate the message-passing API.
    """

    def __init__(self):
        super().__init__()

    def message_func(self,edges):
        """Place the source-node feature 'h' of every edge in the mailbox under key 'm'."""
        return {'m': edges.src['h']}

    def reduce_func(self,nodes):
        """Average the mailbox entry 'm' over dim 1 and expose the result as 'h_N'."""
        received=nodes.mailbox['m']
        return {'h_N': received.mean(dim=1)}

    def forward(self,g,h):
        """Return the aggregated features for graph ``g`` given input node features ``h``."""
        # local_scope: the temporary 'h' / 'h_N' entries are written inside the scope only
        with g.local_scope():
            g.ndata['h']=h
            g.update_all(self.message_func,self.reduce_func)
            return g.ndata['h_N']

            

In [7]:
# Apply the mean layer to the Cora features and compare node 0 before and
# after aggregation.
h=g.ndata['feat']
layer=mean_layer()

print("---- Node 0 features before update -----")
print(h[0])

print("---- Node 0 features after update -----")
out=layer(g,h)
print(h.shape, out.shape)
print(out[0])

---- Node 0 features before update -----
tensor([0., 0., 0.,  ..., 0., 0., 0.])
---- Node 0 features after update -----
torch.Size([2708, 1433]) torch.Size([2708, 1433])
tensor([0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0222, 0.0000])


## 2. Graph Convolutional Network (GCN) from scratch
In GCN, the update equation at each node is computed over its neighbors.
$$h_i^{(l+1)}=W^{(l)}\sum_{j\in N(i)}\dfrac{1}{c_{ji}}h_j^{(l)}+b^{(l)},$$
where the terms are 
1. $N(i)=$ all neighbors of $i$ 
2. $c_{ji}=$ normalization constant $\sqrt{\text{deg}(j)}\sqrt{\text{deg}(i)}$
3. $W$ (weight) and $b$ (bias) are parameters

This GCN layer is available in dgl and can be called via **GraphConv(in_feats, out_feats)**. However, we will implement it from scratch by using message_func, reduce_func and update_all.

#### (a) GraphConv layer and model

In [16]:
class GCN_layer(nn.Module):
    """Graph convolution layer built from DGL message passing.

    For every node ``i`` it computes::

        out_i = W * sum_{j -> i} h_j / (sqrt(deg_out(j)) * sqrt(deg_in(i))) + b

    Degrees are clamped to a minimum of 1 before taking the inverse square
    root, so nodes without edges yield finite values instead of inf/nan.
    On a symmetric graph in-degree equals out-degree, so the two
    normalisation factors coincide.

    Args:
        in_dim: size of the input node features.
        out_dim: size of the output node features.
    """

    def __init__(self,in_dim,out_dim):
        super().__init__()
        self.linear=nn.Linear(in_dim,out_dim)

        # initialize params by xavier
        self.apply(self._init_weights)

    def _init_weights(self,module):
        """Xavier-uniform weights and zero bias for every nn.Linear submodule."""
        if isinstance(module,nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self,g,h):
        """Apply the layer.

        Args:
            g: input graph.
            h: input node features, one row per node.

        Returns:
            Tensor of updated node features after aggregation and the linear map.
        """
        with g.local_scope():
            # Sender side is scaled by out-degree, receiver side by in-degree.
            # clamp(min=1) keeps deg**-0.5 finite for nodes with no edges.
            src_norm=torch.pow(g.out_degrees().float().clamp(min=1),-0.5).unsqueeze(-1)
            dst_norm=torch.pow(g.in_degrees().float().clamp(min=1),-0.5).unsqueeze(-1)

            # normalize all node features: h_j = h_j/sqrt(deg(j))
            g.ndata['h']=h*src_norm

            # new features for 'h' = sum of all neighbor features
            #                 h_i=sum_j h_j/sqrt(deg(j))
            g.update_all(message_func=fn.copy_u('h','m'), reduce_func=fn.sum('m','h'))

            # normalize by sqrt(deg(i)) on the receiving side
            out=g.ndata['h']*dst_norm

            # linear layer
            return self.linear(out)



In [19]:
# model
# input -> gcn1 -> relu -> gcn2 -> classification
class GCN_model(nn.Module):
    """Two-layer GCN classifier: input -> gcn1 -> relu -> gcn2 -> class scores.

    Args:
        in_dim: size of the input node features.
        hidden_dim: width of the hidden representation.
        num_classes: number of output scores per node.
    """

    def __init__(self,in_dim,hidden_dim,num_classes):
        super().__init__()
        self.gcn1=GCN_layer(in_dim,hidden_dim)
        self.gcn2=GCN_layer(hidden_dim,num_classes)

    def forward(self,g,features):
        """Return per-node class scores for graph ``g`` with node features ``features``."""
        with g.local_scope():
            hidden=F.relu(self.gcn1(g,features))
            return self.gcn2(g,hidden)


In [20]:
# Model dimensions: input width from the node features, output width from the dataset.
features=g.ndata['feat']
in_dim=features.shape[-1]
hidden_dim=128
num_classes=dataset.num_classes

net=GCN_model(in_dim, hidden_dim, num_classes)
print(f"{sum(p.numel() for p in net.parameters())/1e6} million parameters")

# Smoke test: one forward pass over the whole graph, no gradients needed.
with torch.no_grad():
    out=net(g,features)
print(out.shape)


0.184455 million parameters
torch.Size([2708, 7])


#### (b) Train and Test

In [22]:
# train_mask, val_mask, test_mask
train_mask=g.ndata['train_mask']
val_mask=g.ndata['val_mask']
test_mask=g.ndata['test_mask']

print(f"train_nodes: {sum(train_mask)} | val_nodes: {sum(val_mask)} | test_nodes: {sum(test_mask)}")

train_nodes: 140 | val_nodes: 500 | test_nodes: 1000


In [24]:
def train(model, graph,loss_fn,optimizer):
    """Run one full-graph optimisation step and report metrics on the training nodes.

    Args:
        model: module called as ``model(graph, features)`` returning per-node scores.
        graph: graph whose ``ndata`` holds 'feat', 'label' and 'train_mask'.
        loss_fn: callable ``loss_fn(scores, labels)`` returning a scalar loss.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(loss, acc)``: the (detached) training loss and the training accuracy,
        both computed from the scores produced before the parameter update.
    """
    model.train()

    features=graph.ndata['feat']
    labels=graph.ndata['label']
    # Read the mask from the graph that was passed in rather than from a
    # notebook global, so the function works for any graph and any cell order.
    mask=graph.ndata['train_mask']

    # forward and backward
    optimizer.zero_grad()
    # prediction on the whole graph
    logits=model(graph,features)
    # only the training nodes contribute to the loss
    loss=loss_fn(logits[mask], labels[mask])
    loss.backward()
    optimizer.step()

    # compute accuracy (no gradient tracking needed for the metric)
    with torch.no_grad():
        preds=logits.argmax(dim=-1)
        acc=(preds[mask]==labels[mask]).float().mean()

    # detach so callers that keep the value do not hold on to autograd state
    return loss.detach(), acc


@torch.no_grad()
def evaluate(model,graph,loss_fn,mask=None):
    """Compute loss and accuracy on a subset of nodes without tracking gradients.

    Args:
        model: module called as ``model(graph, features)`` returning per-node scores.
        graph: graph whose ``ndata`` holds 'feat', 'label' and 'val_mask'.
        loss_fn: callable ``loss_fn(scores, labels)`` returning a scalar loss.
        mask: optional boolean node mask selecting the nodes to score.
            Defaults to ``graph.ndata['val_mask']``; pass
            ``graph.ndata['test_mask']`` to score the test split.

    Returns:
        ``(loss, acc)`` on the selected nodes.
    """
    model.eval()
    features=graph.ndata['feat']
    labels=graph.ndata['label']
    # Default to the validation split stored on the graph itself instead of
    # relying on a notebook-level global.
    if mask is None:
        mask=graph.ndata['val_mask']

    # forward
    logits=model(graph,features)
    loss=loss_fn(logits[mask],labels[mask])

    # compute acc
    preds=logits.argmax(dim=-1)
    acc=(preds[mask]==labels[mask]).float().mean()

    return loss,acc

In [25]:
# Seed the RNG before the model is built so its initial parameters are repeatable.
torch.manual_seed(1442)

num_epochs=30
loss_fn=F.cross_entropy

model=GCN_model(in_dim,hidden_dim,num_classes)
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")
optimizer=torch.optim.AdamW(model.parameters(),lr=0.1)

# Each epoch: one full-graph training step, then a validation pass.
for epoch in range(num_epochs):
    (train_loss, train_acc), (val_loss, val_acc)=train(model,g,loss_fn,optimizer), evaluate(model,g,loss_fn)
    print(f"Epoch: {epoch} | train_loss: {train_loss: .4f} | train_acc: {train_acc*100: .2f}% | val_loss: {val_loss: .4f} | val_acc: {val_acc*100: .2f}%")


0.184455 million parameters
Epoch: 0 | train_loss:  1.9456 | train_acc:  10.71% | val_loss:  1.8188 | val_acc:  46.20%
Epoch: 1 | train_loss:  1.7384 | train_acc:  64.29% | val_loss:  1.7972 | val_acc:  22.40%
Epoch: 2 | train_loss:  1.3862 | train_acc:  57.14% | val_loss:  1.2557 | val_acc:  73.40%
Epoch: 3 | train_loss:  0.9328 | train_acc:  86.43% | val_loss:  1.1686 | val_acc:  67.00%
Epoch: 4 | train_loss:  0.5854 | train_acc:  90.00% | val_loss:  0.9531 | val_acc:  74.20%
Epoch: 5 | train_loss:  0.2967 | train_acc:  97.14% | val_loss:  0.8331 | val_acc:  75.20%
Epoch: 6 | train_loss:  0.1710 | train_acc:  98.57% | val_loss:  0.8206 | val_acc:  72.20%
Epoch: 7 | train_loss:  0.1195 | train_acc:  98.57% | val_loss:  0.7423 | val_acc:  77.60%
Epoch: 8 | train_loss:  0.0515 | train_acc:  100.00% | val_loss:  0.7458 | val_acc:  78.60%
Epoch: 9 | train_loss:  0.0248 | train_acc:  100.00% | val_loss:  0.8103 | val_acc:  76.80%
Epoch: 10 | train_loss:  0.0150 | train_acc:  100.00% | val_

## 3. GCN with dgl.nn.pytorch.conv.GraphConv
dgl.nn.pytorch.conv.GraphConv(in_feats, out_feats, norm='both', weight=True, bias=True, activation=None, allow_zero_in_degree=False)

In [28]:
from dgl.nn import GraphConv

# model: input -> gcn1 -> relu -> gcn2 -> classification
class GCN_Net(nn.Module):
    """Two-layer classifier built from DGL's ``GraphConv``: gcn1 -> relu -> gcn2.

    Args:
        in_dim: size of the input node features.
        hidden_dim: width of the hidden representation.
        num_classes: number of output scores per node.
    """

    def __init__(self,in_dim,hidden_dim,num_classes):
        super().__init__()
        self.gcn1=GraphConv(in_dim,hidden_dim)
        self.gcn2=GraphConv(hidden_dim,num_classes)

    def forward(self,g, features):
        """Return per-node class scores for graph ``g`` with node features ``features``."""
        hidden=F.relu(self.gcn1(g,features))
        return self.gcn2(g,hidden)


In [29]:
# load model
# Same training recipe as for the from-scratch model, now with the GraphConv-based network.
loss_fn=F.cross_entropy

model2=GCN_Net(in_dim,hidden_dim,num_classes)
print(f"{sum(p.numel() for p in model2.parameters())/1e6} million parameters")
optimizer=torch.optim.AdamW(model2.parameters(), lr=0.1)

# Each epoch: one full-graph training step, then a validation pass.
for epoch in range(num_epochs):
    (train_loss, train_acc), (val_loss, val_acc)=train(model2,g,loss_fn,optimizer), evaluate(model2,g,loss_fn)
    print(f"Epoch: {epoch} | train_loss: {train_loss: .4f} | train_acc: {train_acc*100: .2f}% | val_loss: {val_loss: .4f} | val_acc: {val_acc*100: .2f}%")


0.184455 million parameters
Epoch: 0 | train_loss:  1.9451 | train_acc:  17.14% | val_loss:  1.9049 | val_acc:  34.00%
Epoch: 1 | train_loss:  1.7513 | train_acc:  57.14% | val_loss:  1.6535 | val_acc:  45.40%
Epoch: 2 | train_loss:  1.4345 | train_acc:  61.43% | val_loss:  1.4392 | val_acc:  60.40%
Epoch: 3 | train_loss:  0.9092 | train_acc:  92.14% | val_loss:  1.2300 | val_acc:  59.80%
Epoch: 4 | train_loss:  0.5409 | train_acc:  92.86% | val_loss:  0.8813 | val_acc:  78.80%
Epoch: 5 | train_loss:  0.2485 | train_acc:  97.14% | val_loss:  0.7233 | val_acc:  77.40%
Epoch: 6 | train_loss:  0.1545 | train_acc:  97.86% | val_loss:  0.6567 | val_acc:  81.20%
Epoch: 7 | train_loss:  0.0649 | train_acc:  99.29% | val_loss:  0.7653 | val_acc:  77.40%
Epoch: 8 | train_loss:  0.0383 | train_acc:  100.00% | val_loss:  0.8597 | val_acc:  75.00%
Epoch: 9 | train_loss:  0.0273 | train_acc:  100.00% | val_loss:  0.8131 | val_acc:  79.20%
Epoch: 10 | train_loss:  0.0102 | train_acc:  100.00% | val_