# 3장: GNN 모듈 만들기
이 장에서는 PyTorch 백엔드를 사용한 SAGEConv를 예제로 커스텀 DGL NN 모듈을 만드는 방법을 소개한다.

## 3.1 DGL NN 모듈 생성 함수
생성 함수는 다음 단계들을 수행한다:
1. 옵션 설정
2. 학습할 파라메터 또는 서브모듈 등록
3. 파라메터 리셋

In [None]:
import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    """GraphSAGE convolution module (option-setting part of the constructor).

    Parameters
    ----------
    in_feats
        Input feature size; passed through ``expand_as_pair`` to obtain
        separate source and destination sizes.
    out_feats
        Output feature size, stored on the instance.
    aggregator_type
        Name of the aggregator, stored on the instance.
    bias
        Accepted by the signature; not read by the statements in this class
        body as shown.
    norm
        Optional normalization callable, stored as given.
    activation
        Optional activation callable, stored as given.
    """

    def __init__(self, in_feats, out_feats, aggregator_type,
                 bias=True, norm=None, activation=None):
        super().__init__()

        # Input sizes for the source side and the destination side.
        src_size, dst_size = expand_as_pair(in_feats)
        self._in_src_feats = src_size
        self._in_dst_feats = dst_size

        # Remaining options are kept exactly as passed in.
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

In [None]:
# Supported aggregator types: mean, pool, lstm, gcn.
if aggregator_type not in ('mean', 'pool', 'lstm', 'gcn'):
    raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))

# Aggregator-specific submodule: a linear layer for 'pool', an LSTM for 'lstm'.
if aggregator_type == 'pool':
    self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
elif aggregator_type == 'lstm':
    self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)

# Every aggregator except 'gcn' also gets a linear layer over the
# destination-side input size.
if aggregator_type != 'gcn':
    self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)

# The neighbor-side linear layer is always created.
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)

# Initialize the weights of everything registered above.
self.reset_parameters()

In [None]:
def reset_parameters(self):
    """Re-initialize the learnable weights of the layers this module owns.

    Linear layers get Xavier-uniform weights scaled by the ReLU gain; the
    LSTM (only for the 'lstm' aggregator) uses its own ``reset_parameters``.
    Which layers are touched depends on ``self._aggre_type``.
    """
    relu_gain = nn.init.calculate_gain('relu')
    aggregator = self._aggre_type

    # Collect the linear layers in the order they must be initialized.
    linear_layers = []
    if aggregator == 'pool':
        linear_layers.append(self.fc_pool)
    elif aggregator == 'lstm':
        self.lstm.reset_parameters()
    if aggregator != 'gcn':
        linear_layers.append(self.fc_self)
    linear_layers.append(self.fc_neigh)

    for layer in linear_layers:
        nn.init.xavier_uniform_(layer.weight, gain=relu_gain)

## 3.2 DGL NN 모듈의 Forward 함수
- NN 모듈에서 `forward()` 함수는 실제 메시지 전달과 연산을 수행
- `forward()` 함수는 3단계로 수행
    - 그래프 체크 및 그래프 타입 명세화
    - 메시지 전달
    - 피쳐 업데이트

#### (1) 그래프 체크와 그래프 타입 명세화(graph type specification)

In [None]:
def forward(self, graph, feat):
    """Begin the forward pass: derive source and destination features.

    ``feat`` is handed to ``expand_as_pair`` together with ``graph`` and the
    result is unpacked into ``feat_src`` and ``feat_dst``. All of this runs
    inside ``graph.local_scope()``.

    NOTE(review): only the first step of the forward pass is shown here;
    nothing is returned by this portion.
    """
    # NOTE(review): local_scope presumably keeps feature writes made inside
    # the block from persisting on the graph -- confirm against DGL docs.
    with graph.local_scope():
        # Specify graph type then expand input feature according to graph type
        feat_src, feat_dst = expand_as_pair(feat, graph)

In [None]:
def expand_as_pair(input_, g=None):
    """Return a ``(source, destination)`` pair derived from ``input_``.

    Parameters
    ----------
    input_
        A tuple (returned unchanged), a mapping of per-key features, or a
        single feature object.
    g
        Optional graph. Only consulted when ``input_`` is not a tuple; its
        ``is_block`` flag selects the block handling.

    Returns
    -------
    tuple
        ``input_`` itself when it is a tuple; ``(input_, input_)`` when there
        is no graph or the graph is not a block; otherwise ``input_`` paired
        with a row-narrowed version of it.
    """
    # Bipartite graph case: the caller already supplied the pair.
    if isinstance(input_, tuple):
        return input_

    # Homogeneous graph case: both sides share the same features.
    if g is None or not g.is_block:
        return input_, input_

    # Subgraph block case: narrow each feature to the destination-node count
    # (narrow_row(x, 0, n) presumably keeps rows [0, n) -- confirm in backend).
    if isinstance(input_, Mapping):
        dst_feats = {}
        for key, value in input_.items():
            dst_feats[key] = F.narrow_row(value, 0, g.number_of_dst_nodes(key))
    else:
        dst_feats = F.narrow_row(input_, 0, g.number_of_dst_nodes())
    return input_, dst_feats

#### (2) 메시지 전달과 축약

In [None]:
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

# Features of the destination nodes themselves; fed to fc_self at the end.
# Without this binding the final fc_self(h_self) call would hit an undefined
# name for the 'mean' and 'pool' aggregators.
h_self = feat_dst

if self._aggre_type == 'mean':
    # Copy source features along edges and average them per destination node.
    graph.srcdata['h'] = feat_src
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst
    # Sum the incoming messages per destination node.
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
    # Divide by (in-degree + 1): the node's own feature is added to the sum,
    # so it counts as one extra contribution.
    degs = graph.in_degrees().to(feat_dst)
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
    # Transform source features with fc_pool + ReLU, then take the
    # element-wise max over incoming messages.
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
else:
    # Any aggregator name without a branch above (this includes 'lstm')
    # is rejected here.
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

#### (3) 출력값을 위한 축약 후 피쳐 업데이트

In [None]:
# activation
if self.activation is not None:
    rst = self.activation(rst)
# normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst

## 3.3 Heterogeneous GraphConv 모듈
#### (1) HeteroGraphConv 구현 로직:

In [None]:
import torch.nn as nn

class HeteroGraphConv(nn.Module):
    """Container of per-key modules plus an aggregation function.

    Parameters
    ----------
    mods
        Mapping of names to modules; wrapped in an ``nn.ModuleDict``.
    aggregate
        Either a string, which is resolved to a function through
        ``get_aggregate_fn``, or an object that is stored and used directly
        as the aggregation function. Defaults to ``'sum'``.
    """

    def __init__(self, mods, aggregate='sum'):
        super().__init__()
        self.mods = nn.ModuleDict(mods)
        # Strings name a common aggregation looked up by an internal helper;
        # anything else is taken as the aggregation function itself.
        self.agg_fn = (
            get_aggregate_fn(aggregate)
            if isinstance(aggregate, str)
            else aggregate
        )

In [None]:
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
    """Set up the forward pass: default the extra arguments, prepare outputs.

    ``mod_args`` and ``mod_kwargs`` fall back to empty dicts when omitted,
    and an empty result list is created for every entry of ``g.dsttypes``.
    """
    mod_args = {} if mod_args is None else mod_args
    mod_kwargs = {} if mod_kwargs is None else mod_kwargs

    # One (initially empty) list of results per destination node type.
    outputs = {}
    for nty in g.dsttypes:
        outputs[nty] = []

In [None]:
# Source-side features are always the inputs as given. For a block, the
# destination-side features are each input sliced to the number of
# destination nodes of its key; otherwise both sides share the same inputs.
if g.is_block:
    src_inputs = inputs
    dst_inputs = {
        ntype: feats[:g.number_of_dst_nodes(ntype)]
        for ntype, feats in inputs.items()
    }
else:
    dst_inputs = inputs
    src_inputs = inputs

for stype, etype, dtype in g.canonical_etypes:
    rel_graph = g[stype, etype, dtype]
    # Skip relations without edges, and relations whose source or
    # destination type has no input features.
    if (rel_graph.num_edges() == 0
            or stype not in src_inputs
            or dtype not in dst_inputs):
        continue
    # Run the module registered for this edge type on the relation graph,
    # forwarding any extra positional/keyword arguments supplied for it.
    conv = self.mods[etype]
    feat_pair = (src_inputs[stype], dst_inputs[dtype])
    dstdata = conv(
        rel_graph,
        feat_pair,
        *mod_args.get(etype, ()),
        **mod_kwargs.get(etype, {}))
    outputs[dtype].append(dstdata)

In [None]:
# Aggregate the collected results per node type; node types that received
# no results are left out of the final mapping.
rsts = {
    nty: self.agg_fn(alist, nty)
    for nty, alist in outputs.items()
    if len(alist) > 0
}