## $\text{Training GNNs on Large Graphs}$

실제로 우리가 다루는 그래프는 엄청난 양의 데이터를 가지고 있습니다.

그렇기에 그래프의 node와 edge를 다루기에는 엄청난 양의 저장 공간이 요구됩니다. 

만약 GPU를 사용해서 연산 속도를 향상시키고자 하는 경우, GPU에서 전체 그래프의 훈련이 불가능한 경우가 많습니다. (데이터가 너무 많아서.)

따라서, 본 tutorial에서는 위와 같은 문제를 해결하고자 다음과 같은 두 가지 방법을 다룹니다. 

1. Stochastic training on graphs.
2. Neighbor sampling on graphs.

In [1]:
import dgl 
import dgl.function as fn 
from dgl.nn.pytorch import conv as dgl_conv 

import torch 
import torch.nn as nn 
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


## $\text{Mini-batch Construction from a Graph} $

일반적인 딥러닝에서는 stochastic training을 위해 학습용 데이터를 mini-batch 형태로 분할하고 각 단계에 필요한 정보만 GPU에 넣어서 연산을 진행합니다. 

Node Classification을 수행하는 경우에는 다음과 같습니다. 

In [2]:
import networkx as nx 

example_graph_nx = nx.Graph(
    [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10), 
     (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
     (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]
)

example_graph = dgl.from_networkx(example_graph_nx)

![image1](asset/figure_1.JPG)

## $\text{Single Layer} $

GraphSAGE layer를 사용하여 node 4와 6의 output representation을 계산하고자 하는 경우, node 4와 6의 input feature와 이웃에 대한 input feature 필요합니다. 

위 예제의 경우 node 4와 6의 feature, 그리고 두 node의 이웃인 node 0, 2, 7에 대한 feature도 요구됩니다.

위 feature를 mini-batch 구조로 만들기 위해 DGL API: `dgl.sample_neighbors`를 사용합니다. 

In [3]:
sampled_node_batch = torch.LongTensor([4, 6])
sampled_graph = dgl.sampling.sample_neighbors(example_graph, sampled_node_batch, 2)
print(f'|V|={sampled_graph.number_of_nodes()} |E|={sampled_graph.number_of_edges()}')
src, dst = sampled_graph.all_edges()

for s, d in zip(src, dst):
    print(s.numpy(), d.numpy())

|V|=12 |E|=4
7 4
0 4
0 6
2 6


`DGL`은 데이터 구조를 잘 반영하기 위해 `block`을 제공합니다. sub graph는 `dgl.to_block` 함수를 사용하여 `block`으로 쉽게 변환할 수 있습니다.

In [4]:
sampled_block = dgl.to_block(sampled_graph, sampled_node_batch)

def print_block_info(sampled_block):
    print('#source:', sampled_block.number_of_src_nodes())
    sampled_input_nodes = sampled_block.srcdata[dgl.NID]
    print('Node ID of source nodes in original graph:', sampled_input_nodes)
    
    sampled_output_nodes = sampled_block.dstdata[dgl.NID]
    print('#destination:', sampled_block.number_of_dst_nodes())
    print('Node ID of destination nodes in original graph:', sampled_output_nodes)
    
    sampled_block_edges_src, sampled_block_edges_dst = sampled_block.all_edges()
    print('edges in local node ids')
    
    for s, d in zip(sampled_block_edges_src, sampled_block_edges_dst):
        print(s.numpy(), d.numpy())
    
    sampled_block_edges_src_mapped = sampled_input_nodes[sampled_block_edges_src]
    sampled_block_edges_dst_mapped = sampled_output_nodes[sampled_block_edges_dst]
    print('edges in the original node ids')
    for s, d in zip(sampled_block_edges_src_mapped, sampled_block_edges_dst_mapped):
        print(s.numpy(), d.numpy())
    
print_block_info(sampled_block)

#source: 5
Node ID of source nodes in original graph: tensor([4, 6, 7, 0, 2])
#destination: 2
Node ID of destination nodes in original graph: tensor([4, 6])
edges in local node ids
2 0
3 0
3 1
4 1
edges in the original node ids
7 4
0 4
0 6
2 6


### $\text{Multiple Layers} $

지금부터 2-layer GraphSAGE를 사용하여 node 4와 6의 output을 계산하려고 합니다. 

node 4와 6의 output을 계산하기 위해서는 전체 그래프의 데이터가 필요한 것이 아니고, node 4, 6의 feature 그리고, 4, 6의 이웃 node가 필요합니다. 

따라서, 우리는 `NeighborSampler`를 통해 이웃 node를 sampling하고자 합니다. 

In [5]:
class NeighborSampler(object):
    def __init__(self, g, num_fanouts):
        """
        num_fanouts : list of fanout on each layer.
        """
        
        self.g = g 
        self.num_fanouts = num_fanouts 
        
    def sample(self, seeds):
        seeds = torch.LongTensor(self.num_fanouts)
        blocks = []
        for fanout in reversed(self.num_fanouts):
            if fanout >= self.g.number_of_nodes():
                sampled_graph = dgl.in_subgraph(self.g, seeds)
            
            else:
                sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
            
            sampled_block = dgl.to_block(sampled_graph, seeds)
            seeds = sampled_block.srcdata[dgl.NID]
            blocks.insert(0, sampled_block)
        return blocks 

In [6]:
block_sampler = NeighborSampler(example_graph, [2, 2])
sampled_blocks = block_sampler.sample(sampled_node_batch)

print('#blocks:', len(sampled_blocks))
print('Block for first layer')
print('---------------------')
print_block_info(sampled_blocks[0])
print('\nBlock for second layer')
print('----------------------')
print_block_info(sampled_blocks[1])

#blocks: 2
Block for first layer
---------------------
#source: 8
Node ID of source nodes in original graph: tensor([2, 1, 0, 3, 4, 5, 9, 8])
#destination: 4
Node ID of destination nodes in original graph: tensor([2, 1, 0, 3])
edges in local node ids
3 0
4 0
0 1
5 1
6 2
0 2
5 3
7 3
edges in the original node ids
3 2
4 2
2 1
5 1
9 0
2 0
5 3
8 3

Block for second layer
----------------------
#source: 4
Node ID of source nodes in original graph: tensor([2, 1, 0, 3])
#destination: 1
Node ID of destination nodes in original graph: tensor([2, 2])
edges in local node ids
1 0
2 0
2 0
3 0
edges in the original node ids
1 2
0 2
0 2
3 2


### $ \text{Minibatch training for 2-layer GraphSage} $

GraphSAGE on blocks 

sampled block은 bipartite graph 입니다. 이를 구축하기 위해서는 `dgl`의 `SAGEConv`를 사용하면 쉽게 구축이 가능합니다. 

`dgl`은 bipartite graph 만 지원하는 것 뿐만 아니라 homogenouse graph도 지원하기 때문에 목적에 맞게 사용할 수 있습니다.

In [7]:
import dgl.nn as dglnn
import torch.nn.functional as F 

class SAGENet(nn.Module):
    def __init__(self, n_layers, in_feats, out_feats, hidden_feats=None):
        super().__init__()
        self.convs = nn.ModuleList()
        
        if hidden_feats is None:
            hidden_feats = out_feats 
        
        if n_layers == 1:
            self.convs.append(dglnn.SAGEConv(in_feats, out_feats, 'mean'))
        
        else:
            self.convs.append(dglnn.SAGEConv(in_feats, hidden_feats, 'mean', activation=F.relu))
            for i in range(n_layers -2):
                self.convs.append(dglnn.SAGEConv(hidden_feats, hidden_feats, 'mean', activation=F.relu))
            self.convs.append(dglnn.SAGEConv(hidden_feats, out_feats, 'mean', activation=F.relu))
        
    def forward(self, blocks, input_features):
        """
        blocks: List of blocks generated block sampler.
        input_features: Input feature of the first block.
        """
        h = input_features 
        for layer, block in zip(self.convs, blocks):
            h = self.propagate(block, h, layer)
        return h 
    
    def propagate(self, block, src_feats, layer):
        '''
        graphSAGE를 적용하기 위해서는 이웃 feature 뿐만 아니라 node에 대한 feature도 필요합니다. 
        따라서, 현재 layer에 있는 output node의 feature를 복사한 후 사용하여야 합니다. 
        block의 ouptut node는 input node에서 첫 번째로 나타나기에 아래와 같은 코드로 사용이 가능합니다. 
        '''
        dst_feats = src_feats[:block.number_of_dst_nodes()]
        return layer(block, (src_feats, dst_feats))

### $\text{Inference with mini-batch} $

mini-batch fashion에서 inference를 하기 위해서는 먼저, 첫 번째 GraphSAGE layer에서 노드에 대한 representation을 계산하여야 합니다.

$1^{\text{th}}$-layer에서 계산한 representation을 바탕으로 $2^{\text{nd}}$-layer를 계산합니다. 이와 같은 과정을 last layer까지 반복한 후 결과를 산출합니다.

In [8]:
def inference_with_sagenet(sagenet, graph, input_features, batch_size):
    block_sampler = NeighborSampler(graph, [graph.number_of_nodes()])
    h = input_features
    
    with torch.no_grad():
        for conv in sagenet.convs:
            new_h_list = []
            node_ids = torch.arange((graph.number_of_nodes()))
            
            for batch_start in range(0, graph.number_of_nodes(), batch_size):
                block = block_sampler.sample(node_ids[batch_start:batch_start+batch_size])[0]
                input_node_ids = block.srcdata[dgl.NID]
                
                h_input = h[input_node_ids]
                
                new_h = sagenet.propagate(block, h_input, conv)
                new_h_list.append(new_h)
            
            h = torch.cat(new_h_list)
    
    return h

### $ \text{Load Dataset} $

In [9]:
import dgl.data 

dataset = dgl.data.citation_graph.load_pubmed()

graph = dataset[0]


in_feats = graph.ndata['feat'].shape[1]
num_labels = dataset.num_classes

train_nid = torch.where(graph.ndata['train_mask'])[0]
val_nid = torch.where(graph.ndata['val_mask'])[0]
test_nid = torch.where(graph.ndata['test_mask'])[0]

  NumNodes: 19717
  NumEdges: 88651
  NumFeats: 500
  NumClasses: 3
  NumTrainingSamples: 60
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [10]:
neighbor_sampler = NeighborSampler(graph, [10, 25])

In [12]:
from torch.utils.data import DataLoader 
import torch.optim as optim 
import torch.nn as nn 

BATCH_SIZE = 100

train_dataloader = DataLoader(train_nid, batch_size = BATCH_SIZE, collate_fn=neighbor_sampler.sample, shuffle=True)

HIDDEN_FEATURES = 50 
model = SAGENet(2, in_feats, num_labels, HIDDEN_FEATURES)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

### $\text{Evaluation} $

In [13]:
def calc_accuracy(pred, true):
    pred = torch.argmax(pred, dim=1)
    return (pred == true).float().mean().item()

### $ \text{Training Loop} $

In [None]:
NUM_EPOCHS = 10
EVAL_BATCH_SIZE = 10

for epoch in range(NUM_EPOCHS):
    model.train()
    for blocks in train_dataloader:
        input_nodes = blocks[0].srcdata[dgl.NID]
        output_nodes = blocks[-1].dstdata[dgl.NID]
        
        input_features = graph.ndata['feat'][input_nodes]
        output_labels = graph.ndata['label'][output_nodes]
        
        output_pred = model(blocks, input_features)
        loss = criterion(output_pred, output_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    
    if (epoch + 1) % 5 == 0:
        model.eval()
        all_predictions = inference_with_sagenet(model, graph, graph.ndata['feat'], EVAL_BATCH_SIZE)
        
        val_pred = all_predictions[val_nid]
        val_labels = graph.ndata['labels'][val_nid]
        test_pred = all_predictions[test_nid]
        test_labels = graph.ndata['labels'][test_nid]
        
        print('Validation acc: ', calc_accuracy(val_pred, val_labels), 'Test acc: ', calc_accuracy(test_pred, test_labels))