In [1]:
import dgl 
import torch 
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F 

## Message Passing and GNNs

DGL 패키지의 `Message Passing`은 [Gilmer et al.](https://arxiv.org/abs/1704.01212)이 제안한 MPNN(Message Passing Neural Network)에서 영감을 받았으며, MPNN의 수식은 다음과 같습니다.

$$
\begin{aligned}
m^{(l)}_{u \rightarrow v} &= M^{(l)}\left(h^{(l-1)}_v, h^{(l-1)}_u, e^{(l-1)}_{u \rightarrow v}\right) \\
m^{(l)}_v &= \sum_{u \in \mathcal{N}(v)} m^{(l)}_{u \rightarrow v} \\
h^{(l)}_v &= U^{(l)}\left(h^{(l-1)}_v, m^{(l)}_v\right)
\end{aligned}
$$

$M^{(l)}$은 message function을, $\sum$은 reduce function을, $U^{(l)}$은 update function을 의미합니다. 이때 reduce function $\sum$은 합계일 수도 있고, 평균이나 최댓값 같은 다른 집계 함수일 수도 있습니다.

`Message Passing`은 각 노드가 주변(이웃) 노드로부터 `Message`를 받는다는 것을 의미합니다.

본 튜토리얼에서는 2017년에 제안된 `GraphSAGE`를 사용하며, 수식은 아래와 같습니다.

$$
\begin{aligned}
h^k_{\mathcal{N}(v)} &\leftarrow \text{Average}\{ h^{k-1}_u, \forall u \in \mathcal{N}(v) \} \\
h^k_v &\leftarrow \text{ReLU}\left( W^k \cdot \text{CONCAT}(h^{k-1}_v, h^k_{\mathcal{N}(v)}) \right)
\end{aligned}
$$

In [7]:


import dgl.function as fn 

class SAGEConv(nn.Module):
    """GraphSAGE convolution layer with a mean aggregator.

    Each node averages the messages it receives, concatenates that average
    with its own feature vector, and projects the result with one linear
    layer.

    Args:
        in_feats: Size of each input node feature vector.
        out_feats: Size of each output node feature vector.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        # The projection consumes [own feature, aggregated feature], hence 2x.
        self.linear = nn.Linear(2 * in_feats, out_feats)

    def forward(self, g, h):
        """Run one round of message passing followed by the projection.

        Args:
            g: Graph on which messages are passed.
            h: Node features; concatenated along dim 1 with the aggregate.

        Returns:
            The output of the linear layer applied to ``[h, h_N]``.
        """
        # Message function: copy the sender's node feature `h` into message `m`.
        send_feature = fn.copy_u('h', 'm')
        # Reduce function: average the received messages `m` into `h_N`.
        average_inbox = fn.mean('m', 'h_N')

        # Inside the local scope, the fields written to g.ndata are not
        # reflected on the original graph once the block exits.
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(send_feature, average_inbox)
            neighbor_mean = g.ndata['h_N']
            return self.linear(torch.cat((h, neighbor_mean), dim=1))


class Model(nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Args:
        in_feats: Size of each input node feature vector.
        h_feats: Size of the hidden representation after the first layer.
        n_classes: Output size of the second layer (one score per class).
    """

    def __init__(self, in_feats, h_feats, n_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, n_classes)
        self.relu = nn.ReLU()

    def forward(self, g, in_feat):
        """Return the second layer's output for every node of ``g``.

        Args:
            g: Graph passed to both convolution layers.
            in_feat: Input node features.
        """
        hidden = self.relu(self.conv1(g, in_feat))
        # No activation after the last layer: the raw scores are returned.
        return self.conv2(g, hidden)

## Training Loop

In [8]:
import dgl.data 


# Load DGL's built-in Cora dataset; the loader prints the dataset summary
# (node/edge/feature/class counts and split sizes) shown in this cell's output.
dataset = dgl.data.CoraGraphDataset()
# Index 0 gives the graph object that the rest of the notebook works on.
g = dataset[0]

def criterion(pred_y, true_y, mask):
    """Cross-entropy loss restricted to the rows selected by ``mask``.

    Args:
        pred_y: Per-node class scores (logits).
        true_y: Per-node target class indices.
        mask: Index/boolean mask choosing which nodes contribute to the loss.

    Returns:
        The value of ``F.cross_entropy`` on the selected rows.
    """
    return F.cross_entropy(pred_y[mask], true_y[mask])

def accuracy(pred_y, true_y, mask):
    """Fraction of masked nodes whose arg-max prediction equals the label.

    Args:
        pred_y: Per-node class scores; the predicted class is the arg-max
            along dim 1.
        true_y: Per-node target class indices.
        mask: Index/boolean mask choosing which nodes are scored.

    Returns:
        A 0-dim float tensor holding the mean of the per-node hit indicator.
    """
    predicted = pred_y[mask].argmax(1)
    hits = predicted == true_y[mask]
    return hits.float().mean()

def trainer(g, model, n_epochs, device, optimizer, criterion):
    """Train ``model`` full-batch on ``g`` and track held-out accuracy.

    Every epoch runs one forward pass over the whole graph, optimises the
    loss on the training nodes, and scores that same forward pass on the
    train/validation/test masks. Progress is printed every 50 epochs.

    Args:
        g: Graph whose ``ndata`` holds ``feat``, ``label``, ``train_mask``,
            ``val_mask`` and ``test_mask``.
        model: Module called as ``model(g, features)``; the caller is
            expected to have placed it on ``device`` already.
        n_epochs: Number of training epochs.
        device: Device the graph is moved to.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``criterion(pred_y, labels, mask)`` returning
            the loss to back-propagate.

    Returns:
        tuple[float, float]: ``(best_val_acc, best_test_acc)`` where
        ``best_test_acc`` is the test accuracy measured at the epoch with
        the highest validation accuracy (model selection on validation
        only). Both are ``0.0`` when ``n_epochs`` is 0.
    """
    best_val_acc = 0.0
    best_test_acc = 0.0

    g = g.to(device)
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    for epoch in range(1, n_epochs + 1):
        model.train()
        pred_y = model(g, features)
        loss = criterion(pred_y, labels, train_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics need no gradients: score the detached logits of this
        # epoch's forward pass and keep plain Python floats.
        with torch.no_grad():
            logits = pred_y.detach()
            train_acc = accuracy(logits, labels, train_mask).item()
            val_acc = accuracy(logits, labels, val_mask).item()
            test_acc = accuracy(logits, labels, test_mask).item()

        # Record the test accuracy only when validation improves, so the
        # reported test number is never selected on the test set itself.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        if epoch % 50 == 0:
            print(f'epoch: [{epoch}/{n_epochs}]\ttrain acc: {train_acc*100:.2f}\tval acc: {val_acc*100:.2f}\ttest acc: {test_acc*100:.2f}')

    return best_val_acc, best_test_acc

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [10]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
g = g.to(device)
# Input width comes from the node-feature matrix, 16 is the hidden size,
# and the output width is the dataset's number of classes.
model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
# Train for 1000 epochs; trainer() prints accuracies every 50 epochs.
trainer(g, model, 1000, device, optimizer, criterion)

epoch: [50/1000]	train acc: 100.00	val acc: 75.20	test acc: 75.90
epoch: [100/1000]	train acc: 100.00	val acc: 75.20	test acc: 75.00
epoch: [150/1000]	train acc: 100.00	val acc: 75.60	test acc: 74.40
epoch: [200/1000]	train acc: 100.00	val acc: 76.00	test acc: 74.60
epoch: [250/1000]	train acc: 100.00	val acc: 75.80	test acc: 74.60
epoch: [300/1000]	train acc: 100.00	val acc: 76.00	test acc: 74.90
epoch: [350/1000]	train acc: 100.00	val acc: 75.80	test acc: 75.10
epoch: [400/1000]	train acc: 100.00	val acc: 75.80	test acc: 74.90
epoch: [450/1000]	train acc: 100.00	val acc: 75.80	test acc: 74.90
epoch: [500/1000]	train acc: 100.00	val acc: 75.80	test acc: 75.00
epoch: [550/1000]	train acc: 100.00	val acc: 75.40	test acc: 75.00
epoch: [600/1000]	train acc: 100.00	val acc: 75.40	test acc: 75.10
epoch: [650/1000]	train acc: 100.00	val acc: 75.20	test acc: 75.30
epoch: [700/1000]	train acc: 100.00	val acc: 75.40	test acc: 75.30
epoch: [750/1000]	train acc: 100.00	val acc: 75.80	test acc: 75