## Message Passing Paradigm

### Definition 

Let $x_v\in \mathbb{R}^{d_1}$ be the feature of node $v$, and let $w_e\in \mathbb{R}^{d_2}$ be the feature of edge $e = (u, v)$, which points from the source node $u$ to the destination node $v$. The **message passing paradigm** defines the following edge-wise and node-wise computations at step $t+1$:

$$ \text{Edge-wise: } m_e^{(t+1)} = \phi \left(x_v^{(t)}, x_u^{(t)}, w_e^{(t)} \right), (u, v, e) \in \mathcal{E} .$$

$$\text{Node-wise: } x_v^{(t+1)} = \psi\left(x_v^{(t)}, \rho\left(\{ m_e^{(t+1)} : (u, v, e)\in \mathcal{E}\}\right) \right) ,$$

where
- $\phi$ is a **message function**, defined on each edge, that generates a message by combining the edge feature with the features of its incident nodes (the source $u$ and the destination $v$);
- $\rho$ is a **reduce function** that aggregates all messages arriving at a node;
- $\psi$ is an **update function**, defined on each node, that computes the new node feature from the current one and the output of $\rho$ over the node's incoming messages.
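To make the three roles concrete, here is a framework-free sketch of one message-passing step on a three-node graph. It is only an illustration: the scalar features and the particular choices $\phi(x_v, x_u, w_e) = w_e \cdot x_u$, $\rho = \text{sum}$ and $\psi(x_v, a) = x_v + a$ are assumptions made for the example, not part of the definition.

```python
# Framework-free sketch of one message-passing step (illustration only).
# Assumed for the example: scalar features, phi = w_e * x_u, rho = sum, psi = x_v + aggregate.

def message_passing_step(x, edges, w, phi, rho, psi):
    """x: node -> feature; edges: (u, v, e) triples for edges u -> v; w: edge id -> feature."""
    # Edge-wise: m_e = phi(x_v, x_u, w_e) for every edge
    messages = {e: phi(x[v], x[u], w[e]) for (u, v, e) in edges}
    # Collect the messages arriving at each destination node v
    inbox = {v: [] for v in x}
    for (u, v, e) in edges:
        inbox[v].append(messages[e])
    # Node-wise: x_v' = psi(x_v, rho(incoming messages))
    return {v: psi(x[v], rho(inbox[v])) for v in x}


def phi(x_v, x_u, w_e):
    return w_e * x_u          # message: source feature scaled by the edge feature


def psi(x_v, aggregated):
    return x_v + aggregated   # update: add the aggregate to the old feature


x = {0: 1.0, 1: 2.0, 2: 3.0}
edges = [(0, 1, 'a'), (2, 1, 'b'), (1, 0, 'c')]
w = {'a': 0.5, 'b': 1.0, 'c': 2.0}

x_next = message_passing_step(x, edges, w, phi, sum, psi)  # rho = built-in sum
print(x_next)  # {0: 5.0, 1: 5.5, 2: 3.0}
```

Node 2 has no incoming edges, so `sum([])` returns 0 and its feature is unchanged; how a real framework treats such nodes depends on the reduce function and the library.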

### DGL's Message Passing APIs

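`update_all()` is DGL's API for message passing and aggregation: it sends a message along every edge and then updates every node. Its two main arguments are:
- `message_func`: how each edge computes its message, typically from source-node features. A user-defined message function takes a single argument `edges`, which has three members `src`, `dst` and `data` for accessing the features of source nodes, destination nodes and edges, respectively.
- `reduce_func`: how each node aggregates the messages it receives. A user-defined reduce function takes a single argument `nodes`, whose member `mailbox` holds the messages received by the nodes in the batch. For common cases, `dgl.function` provides built-in functions, e.g. the message function `copy_u` and the reducers `sum`, `max`, `min` and `mean`.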
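The data flow behind `update_all()` with user-defined functions can be imitated in plain Python. In the sketch below, `EdgeBatchSketch`, `NodeBatchSketch` and `update_all_sketch` are hypothetical stand-ins written for illustration; they are not DGL classes, and real DGL works on batched tensors (there the mailbox is a tensor of shape (nodes, in-degree, feature size), so a mean reducer would be written as `nodes.mailbox['m'].mean(1)`). The two functions at the end play the roles of `fn.copy_u('h', 'm')` and `fn.mean('m', 'h_N')`.

```python
from dataclasses import dataclass, field


@dataclass
class EdgeBatchSketch:          # hypothetical stand-in for the `edges` argument
    src: dict                   # features of the edge's source node
    dst: dict                   # features of the edge's destination node
    data: dict                  # features of the edge itself


@dataclass
class NodeBatchSketch:          # hypothetical stand-in for the `nodes` argument
    data: dict                  # the node's own features
    mailbox: dict = field(default_factory=dict)  # message name -> list of received messages


def update_all_sketch(ndata, edata, edges, message_func, reduce_func):
    """ndata/edata: per-node/per-edge feature dicts; edges: list of (u, v), u -> v."""
    mailboxes = [{} for _ in ndata]
    # Message phase: run message_func on each edge and deliver to the destination
    for (u, v), ef in zip(edges, edata):
        msg = message_func(EdgeBatchSketch(src=ndata[u], dst=ndata[v], data=ef))
        for name, value in msg.items():
            mailboxes[v].setdefault(name, []).append(value)
    # Reduce phase: nodes without messages are skipped in this sketch
    for v, box in enumerate(mailboxes):
        if box:
            ndata[v].update(reduce_func(NodeBatchSketch(data=ndata[v], mailbox=box)))


def message_func(edges):
    return {'m': edges.src['h']}           # same role as fn.copy_u('h', 'm')


def reduce_func(nodes):
    msgs = nodes.mailbox['m']
    return {'h_N': sum(msgs) / len(msgs)}  # same role as fn.mean('m', 'h_N')


ndata = [{'h': 1.0}, {'h': 2.0}, {'h': 6.0}]
edges = [(0, 2), (1, 2), (2, 0)]
edata = [{}, {}, {}]
update_all_sketch(ndata, edata, edges, message_func, reduce_func)
print([n.get('h_N') for n in ndata])  # [6.0, None, 1.5]
```

Node 1 receives nothing and is left without `h_N` here; as far as we understand, DGL instead fills zeros for nodes with no incoming messages, so check the DGL documentation before relying on either behavior.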

## GraphSAGE

With the mean aggregator, layer $k$ of GraphSAGE updates each node $v$ as follows:
$$h^k_{\mathcal{N}(v)} \leftarrow \text{Average}\{h^{k-1}_u, \forall u\in \mathcal{N}(v)\}$$

$$h^k_v \leftarrow \text{ReLU}\left(W^k \cdot \text{CONCAT}\left(h^{k-1}_v, h^k_{\mathcal{N}(v)}\right)\right)$$
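As a worked instance of the two formulas, the sketch below computes $h^k_v$ for one node in plain Python. The graph, the 2-dimensional features and the weight matrix $W^k$ (shape $2 \times 4$, because it acts on a concatenation of two 2-dimensional vectors) are made-up illustration values.

```python
# Worked GraphSAGE step (mean aggregator) for a single node, in plain Python.
# The graph, features and weight matrix W are made-up illustration values.

def relu(vec):
    return [max(0.0, a) for a in vec]


def matvec(W, vec):
    return [sum(w_ij * a_j for w_ij, a_j in zip(row, vec)) for row in W]


def sage_update(h_prev, neighbors, W, v):
    # h^k_{N(v)}: element-wise average of the neighbors' previous-layer vectors
    nbrs = [h_prev[u] for u in neighbors[v]]
    h_N = [sum(col) / len(nbrs) for col in zip(*nbrs)]
    # CONCAT(h^{k-1}_v, h^k_{N(v)}) is list concatenation; then apply W^k and ReLU
    return relu(matvec(W, h_prev[v] + h_N))


h_prev = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [2.0, 2.0]}
neighbors = {0: [1, 2]}
W = [[1.0, 0.0, 1.0, 0.0],     # shape (out_dim, 2 * in_dim) = (2, 4)
     [0.0, -1.0, 0.0, -1.0]]
print(sage_update(h_prev, neighbors, W, 0))  # [2.0, 0.0]
```

Here $h_{\mathcal{N}(0)} = [1, 2]$, the concatenation is $[1, 0, 1, 2]$, and $W^k$ maps it to $[2, -2]$, which ReLU clips to $[2, 0]$.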



In [1]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

Using backend: pytorch


In [2]:
import dgl.function as fn

In [22]:
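class GraphSage(nn.Module):
    """One GraphSAGE layer with the mean aggregator."""
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # W^k acts on CONCAT(h_v, h_N(v)), so the input width is 2 * in_feat
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        # local_scope() discards the temporary ndata fields set below on exit
        with g.local_scope():
            g.ndata['h'] = h
            # Message: copy each source node's 'h' onto the edge as 'm'.
            # Reduce: average the incoming 'm' of each node into 'h_N'.
            g.update_all(message_func=fn.copy_u('h', 'm'),
                         reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)  # shape: (num_nodes, 2 * in_feat)
            # W^k · CONCAT(...); the ReLU between layers is applied by the enclosing model
            return self.linear(h_total)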
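For sanity checks on small graphs, it can help to have a reference computation of what `fn.copy_u('h', 'm')` followed by `fn.mean('m', 'h_N')` produce, written over parallel `src`/`dst` edge lists. `mean_aggregate` and `sage_input` are hypothetical helpers for illustration, and the all-zero `h_N` for nodes with no incoming edges is an assumption that, as far as we know, matches DGL's behavior.

```python
# Reference computation of copy_u('h', 'm') + mean('m', 'h_N') over an edge list.
# Hypothetical helpers; assumes nodes with no incoming edges get an all-zero h_N.

def mean_aggregate(src, dst, h):
    num_nodes, dim = len(h), len(h[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    in_deg = [0] * num_nodes
    for u, v in zip(src, dst):          # message flows u -> v
        in_deg[v] += 1
        for k in range(dim):
            sums[v][k] += h[u][k]       # copy_u: the message is the source feature
    return [[s / in_deg[v] if in_deg[v] else 0.0 for s in sums[v]]
            for v in range(num_nodes)]


def sage_input(src, dst, h):
    # Row-wise equivalent of torch.cat([h, h_N], dim=1): width doubles to 2 * in_feat
    h_N = mean_aggregate(src, dst, h)
    return [hv + hn for hv, hn in zip(h, h_N)]


h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
src, dst = [0, 1, 0], [1, 2, 2]
print(sage_input(src, dst, h))
# [[1.0, 2.0, 0.0, 0.0], [3.0, 4.0, 1.0, 2.0], [5.0, 6.0, 2.0, 3.0]]
```

Because messages flow from `src` to `dst`, only in-neighbors contribute to `h_N`; on a graph stored with both edge directions, this coincides with averaging over all neighbors.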

In [5]:
import dgl.data

In [6]:
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [7]:
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']

In [12]:
print(features.shape)
print(labels.shape)

torch.Size([2708, 1433])
torch.Size([2708])


In [25]:
model = GraphSage(features.shape[1], 16)

In [26]:
y = model(g, features)

In [27]:
y.shape

torch.Size([2708, 16])

In [30]:
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = GraphSage(in_feats, h_feats)
        self.conv2 = GraphSage(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
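Since each `GraphSage` layer is a single `nn.Linear(2 * in_feat, out_feat)`, the size of this two-layer model can be worked out by hand from the Cora summary (1433 input features, 16 hidden units, 7 classes). The small helpers below are written only for this arithmetic; in PyTorch the total should agree with `sum(p.numel() for p in model.parameters())`.

```python
# Back-of-the-envelope parameter count for Model(1433, 16, 7) built from GraphSage layers.
# Each GraphSage layer is nn.Linear(2 * in_feat, out_feat): weight matrix plus bias.

def linear_params(in_features, out_features):
    return in_features * out_features + out_features


def graphsage_params(in_feat, out_feat):
    # the Linear sees CONCAT(h_v, h_N(v)), i.e. 2 * in_feat inputs
    return linear_params(2 * in_feat, out_feat)


in_feats, h_feats, num_classes = 1433, 16, 7
conv1 = graphsage_params(in_feats, h_feats)     # 2866 * 16 + 16
conv2 = graphsage_params(h_feats, num_classes)  # 32 * 7 + 7
print(conv1, conv2, conv1 + conv2)  # 45872 231 46103
```

Almost all parameters sit in the first layer, because Cora's bag-of-words features are 1433-dimensional and the concatenation doubles that width.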

In [33]:
dataset.num_classes

7

In [31]:
model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)

In [34]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
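The training loop below measures accuracy only on the nodes selected by a boolean mask and reports the test accuracy from the epoch with the best validation accuracy. That logic can be checked in isolation with plain Python lists; the predictions, labels, masks and per-epoch accuracies here are made-up values, and `masked_accuracy` is a hypothetical helper mirroring `(pred[mask] == labels[mask]).float().mean()`.

```python
# Plain-Python version of the masked accuracy and model-selection logic.
# Predictions, labels, masks and per-epoch accuracies are made-up values.

def masked_accuracy(pred, labels, mask):
    # same idea as (pred[mask] == labels[mask]).float().mean()
    pairs = [(p, y) for p, y, m in zip(pred, labels, mask) if m]
    return sum(p == y for p, y in pairs) / len(pairs)


pred = [0, 1, 1, 2, 2, 0]
labels = [0, 1, 2, 2, 1, 0]
train_mask = [True, True, False, False, False, False]
val_mask = [False, False, True, True, False, False]
print(masked_accuracy(pred, labels, train_mask),
      masked_accuracy(pred, labels, val_mask))  # 1.0 0.5

# Keep the test accuracy from the epoch with the best validation accuracy.
history = [(0.40, 0.42), (0.70, 0.68), (0.65, 0.71)]  # (val_acc, test_acc) per epoch
best_val_acc, best_test_acc = 0.0, 0.0
for val_acc, test_acc in history:
    if best_val_acc < val_acc:
        best_val_acc, best_test_acc = val_acc, test_acc
print(best_val_acc, best_test_acc)  # 0.7 0.68
```

The third epoch has the highest test accuracy but is not selected, because selection uses only validation accuracy; this keeps the reported test number from being tuned on the test set.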

In [37]:
best_val_acc = 0
best_test_acc = 0

for e in range(200):
    logits = model(g, features)
    pred = logits.argmax(1)
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    
    # Compute accuracy on training/validation/test
    train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
    val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
    test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

    # Save the best validation accuracy and the corresponding test accuracy.
    if best_val_acc < val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
            e, loss, val_acc, best_val_acc, test_acc, best_test_acc))
    

In epoch 0, loss: 1.951, val acc: 0.116 (best 0.116), test acc: 0.120 (best 0.120)
In epoch 5, loss: 1.897, val acc: 0.126 (best 0.126), test acc: 0.138 (best 0.138)
In epoch 10, loss: 1.784, val acc: 0.338 (best 0.338), test acc: 0.344 (best 0.344)
In epoch 15, loss: 1.607, val acc: 0.414 (best 0.414), test acc: 0.422 (best 0.422)
In epoch 20, loss: 1.367, val acc: 0.484 (best 0.484), test acc: 0.499 (best 0.499)
In epoch 25, loss: 1.084, val acc: 0.548 (best 0.548), test acc: 0.549 (best 0.549)
In epoch 30, loss: 0.793, val acc: 0.608 (best 0.608), test acc: 0.595 (best 0.595)
In epoch 35, loss: 0.538, val acc: 0.682 (best 0.682), test acc: 0.651 (best 0.651)
In epoch 40, loss: 0.344, val acc: 0.712 (best 0.712), test acc: 0.712 (best 0.712)
In epoch 45, loss: 0.214, val acc: 0.728 (best 0.728), test acc: 0.743 (best 0.743)
In epoch 50, loss: 0.134, val acc: 0.730 (best 0.730), test acc: 0.754 (best 0.754)
In epoch 55, loss: 0.087, val acc: 0.732 (best 0.732), test acc: 0.752 (best 0