## Stochastic Training of GNN with Multiple GPUs

본 튜토리얼에서는 여러 개의 GPU를 사용해 multi-layer GraphSAGE 모델을 학습하는 방법을 다룹니다. 

이때 일반적으로 `PyTorch`의 `DistributedDataParallel`(DDP)을 사용합니다. 

In [None]:
import numpy as np 
import dgl 
import dgl.nn as dglnn

import torch 
import torch.nn as nn 
from torch.nn.parallel import DistributedDataParallel
import torch.nn.functional as F 
import torch.multiprocessing as mp
import sklearn.metrics as metrics 
import tqdm 

import utils_KDD 

## Load Dataset 

In [None]:
def load_data():
    """Load the pickled dataset and prepare the graph for multiprocessing.

    Reads ``../data.pkl`` (resolved against the current working directory),
    which must unpickle to a 6-item sequence
    ``(graph, node_features, node_labels, train_nids, valid_nids, test_nids)``,
    then hands ``graph`` to ``utils_KDD.prepare_mp``.

    Returns:
        An 8-tuple: the six unpickled items, followed by the feature
        dimension (``node_features.shape[1]``) and the class count
        (``node_labels.max() + 1`` converted with ``.item()``; assumes labels
        are 0-based consecutive integers -- TODO confirm).
    """
    import pickle

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load a data.pkl from a trusted source.
    with open('../data.pkl', 'rb') as handle:
        (graph, node_features, node_labels,
         train_nids, valid_nids, test_nids) = pickle.load(handle)

    utils_KDD.prepare_mp(graph)

    feature_dim = node_features.shape[1]
    class_count = (node_labels.max() + 1).item()
    return (graph, node_features, node_labels,
            train_nids, valid_nids, test_nids,
            feature_dim, class_count)


## Customized Neighborhood Sampling

`MultiLayerNeighborSampler`를 사용하면 레이어마다 지정한 개수(fanout)만큼 이웃을 쉽게 샘플링할 수 있습니다. 

In [None]:
class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler):
    """Block sampler that draws a per-layer number of neighbors.

    Args:
        fanouts: One entry per layer. ``fanouts[i]`` is passed as the
            ``fanout`` argument of ``dgl.sampling.sample_neighbors`` when
            sampling layer ``i``; ``len(fanouts)`` sets the number of layers.
    """

    def __init__(self, fanouts):
        # Edge IDs are not needed by the model, so ask DGL not to return them.
        super().__init__(len(fanouts), return_eids=False)
        self.fanouts = fanouts

    def sample_frontier(self, layer_id, g, seed_nodes):
        """Sample neighbors of ``seed_nodes`` in ``g`` for layer ``layer_id``."""
        return dgl.sampling.sample_neighbors(
            g, seed_nodes, self.fanouts[layer_id]
        )

## Defining Data Loader for Distributed Data Parallel (DDP)

PyTorch DDP에서는 각 worker process에 `rank`가 할당됩니다. `rank`는 해당 worker process가 처리할 데이터 파티션을 나타냅니다. 

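병렬 처리를 하지 않는 경우에는 `MultiLayerNeighborSampler`로 이웃을 샘플링하는 단계만 구성하면 됩니다. 병렬 처리를 할 때에는 아래와 같이 노드 ID를 `world_size`개의 파티션으로 나누고, 각 worker process가 자신의 `rank`에 해당하는 파티션만 처리하도록 하여 병렬로 학습할 수 있습니다. 각 파티션의 크기는 `len(nids) // world_size`로 동일하게 맞추며, 나누어떨어지지 않고 남는 노드는 사용되지 않습니다. 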

In [2]:
def create_dataloader(rank, world_size, graph, nids, fanouts=(4, 4, 4), batch_size=1024):
    """Build the neighbor-sampling data loader for one DDP worker.

    ``nids`` is split into ``world_size`` contiguous, equally sized
    partitions, and this worker keeps partition number ``rank``.

    Args:
        rank: Index of this worker, in ``[0, world_size)``.
        world_size: Total number of workers; must be >= 1.
        graph: DGL graph to sample neighborhoods from.
        nids: Sliceable collection of seed node IDs (e.g. a tensor).
        fanouts: Number of neighbors sampled per layer; its length should
            match the model's layer count. Defaults to ``(4, 4, 4)``.
        batch_size: Number of seed nodes per mini-batch. Defaults to 1024.

    Returns:
        A ``dgl.dataloading.NodeDataLoader`` over this worker's partition,
        shuffled, keeping the last partial batch.

    Raises:
        ValueError: If ``world_size < 1`` or ``rank`` is outside
            ``[0, world_size)``.
    """
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")

    # Every rank gets the same partition size; the trailing
    # len(nids) % world_size IDs are intentionally dropped. Equal sizes give
    # every rank the same number of batches. DDP ranks that run a different
    # number of backward passes can block waiting on each other's gradient
    # all-reduce.
    partition_size = len(nids) // world_size
    partition_offset = partition_size * rank
    nids = nids[partition_offset:partition_offset + partition_size]

    sampler = MultiLayerNeighborSampler(list(fanouts))
    return dgl.dataloading.NodeDataLoader(
        graph, nids, sampler,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=0,
    )

## Defining Model 

In [None]:
class GraphSAGE(nn.Module):
    """Multi-layer GraphSAGE built from mean-aggregator ``SAGEConv`` layers.

    Layer widths are ``in_feats -> n_hidden -> ... -> n_hidden -> n_classes``
    with ``n_layers`` layers in total. A single layer maps ``in_feats``
    directly to ``n_classes``. ReLU is applied after every layer except the
    last.

    Args:
        in_feats: Input node-feature dimension.
        n_hidden: Hidden dimension of the intermediate layers.
        n_classes: Output dimension (number of classes).
        n_layers: Number of ``SAGEConv`` layers; must be >= 1.

    Raises:
        ValueError: If ``n_layers < 1``.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise assigning the ModuleList raises AttributeError.
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.in_feats = in_feats
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        # Integer layer count; the layer modules live in ``self.layers``.
        self.n_layers = n_layers

        dims = [in_feats] + [n_hidden] * (n_layers - 1) + [n_classes]
        self.layers = nn.ModuleList(
            dglnn.SAGEConv(d_in, d_out, 'mean')
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, bipartites, x):
        """Run the layers over the sampled blocks.

        Args:
            bipartites: Iterable of per-layer graphs, consumed in order
                alongside ``self.layers`` (presumably the blocks yielded by
                the data loader, input side first -- TODO confirm).
            x: Input features for the first layer's source nodes.

        Returns:
            Output of the last layer applied (no ReLU after the final layer).
        """
        last = len(self.layers) - 1
        for i, (layer, bipartite) in enumerate(zip(self.layers, bipartites)):
            x = layer(bipartite, x)
            if i != last:
                x = F.relu(x)
        return x

## Distributing the Model to GPUs

In [3]:
def init_model(rank, in_feats, n_hidden, n_classes, n_layers):
    """Create a GraphSAGE replica on device ``rank`` and wrap it in DDP.

    ``rank`` is used both as the target device for ``.to`` and as the single
    entry of ``device_ids`` / the ``output_device``. This presumably assumes
    one GPU per process and an already-initialized ``torch.distributed``
    process group -- TODO confirm against the caller.

    Returns:
        The model wrapped in ``DistributedDataParallel``.
    """
    local_model = GraphSAGE(in_feats, n_hidden, n_classes, n_layers)
    local_model = local_model.to(rank)
    return DistributedDataParallel(
        local_model,
        device_ids=[rank],
        output_device=rank,
    )

In [None]:
@utils_KDD.fix_openmp
def train(rank, world_size, data):
    torch.