#### Analyzing the HGTConv layer  
- source code: [here](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/hgt_conv.html#HGTConv)  
- docs: [here](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.HGTConv)  

**arguments**  
- in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.
- out_channels (int) – Size of each output sample.
- metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its list of node types and its list of edge types.
- heads (int, optional) – Number of attention heads (default: `1`).
- group (string, optional) – The aggregation scheme used to combine node embeddings produced by different relations: `"sum"`, `"mean"`, `"min"`, or `"max"` (default: `"sum"`).
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

from typing import Tuple, Union, Dict, Optional, List
from torch_geometric.typing import NodeType, EdgeType, Metadata

import math

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from torch_sparse import SparseTensor
from torch_geometric.nn.dense import Linear
from torch_geometric.utils import softmax
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, ones, reset

class HGTConv(MessagePassing):
    def __init__(
        self,
        in_channels: Union[int, Dict[str, int]],
        out_channels: int,
        metadata: Metadata,
        heads: int = 1,
        group: str = "sum",
        **kwargs,
    ):

        super().__init__(aggr='add', node_dim=0, **kwargs)

The `metadata` argument has type `Metadata`, which is defined in `torch_geometric.typing` as:  

```python
Metadata = Tuple[List[NodeType], List[EdgeType]]

# example: (['author', 'paper', 'term', 'conference'], [('author', 'to', 'paper'), ...])
```

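So the first element of `metadata` is the list of node types, and the second is the list of edge types, each given as a `(source node type, relation, destination node type)` triple.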
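A minimal plain-Python sketch of a toy `metadata` value and how it unpacks. The `NodeType`/`EdgeType` aliases below mirror the ones in `torch_geometric.typing` (plain strings and string triples); the node and edge type names are illustrative only:

```python
from typing import List, Tuple

# Local aliases mirroring torch_geometric.typing (illustrative redefinitions).
NodeType = str
EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[NodeType], List[EdgeType]]

# Toy heterogeneous graph with two node types and two relations.
metadata: Metadata = (
    ['author', 'paper'],
    [('author', 'to', 'paper'), ('paper', 'to', 'author')],
)

node_types, edge_types = metadata
for src_type, rel, dst_type in edge_types:
    # Every edge type should connect two node types listed in metadata[0].
    assert src_type in node_types and dst_type in node_types
    print(f'{src_type} --{rel}--> {dst_type}')
```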

The `__init__` method continues as follows:

```python
    # metadata[0] represents the list of node types
    # e.g. ['author', 'paper', 'term', 'conference']
    if not isinstance(in_channels, dict):
        in_channels = {node_type: in_channels for node_type in metadata[0]}

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.heads = heads
    self.group = group
```

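If `in_channels` is a single `int`, it is broadcast to every node type. This is the usual setup, because a per-node-type linear projection (a **linear dict**) is typically applied before the `HGTConv` layer so that all node types share one feature size.  
`in_channels` sets the input size of `k_lin`, `q_lin`, and `v_lin`; `a_lin` maps `out_channels` to `out_channels`, and `skip` holds one scalar gate per node type.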
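A minimal sketch of why a single `int` usually suffices, assuming hypothetical raw feature sizes and a hypothetical `lin_dict` built from plain `torch.nn.Linear` (a stand-in for whatever pre-projection a model uses, not part of `HGTConv`):

```python
import torch
from torch import nn

# Hypothetical raw feature sizes that differ per node type.
raw_dims = {'author': 8, 'paper': 12}
hidden = 64  # shared size every node type is projected to

# Hypothetical "linear dict": one projection per node type.
lin_dict = nn.ModuleDict({nt: nn.Linear(d, hidden) for nt, d in raw_dims.items()})

x_dict = {nt: torch.randn(5, d) for nt, d in raw_dims.items()}
x_dict = {nt: lin_dict[nt](x) for nt, x in x_dict.items()}  # all [5, 64] now

# With a shared size, one int in_channels is enough; HGTConv expands it like this:
in_channels = hidden
if not isinstance(in_channels, dict):
    in_channels = {node_type: in_channels for node_type in raw_dims}
print(in_channels)  # {'author': 64, 'paper': 64}
```

Passing a dict such as `{'author': 8, 'paper': 12}` directly would also work, since `HGTConv` then builds differently sized `k_lin`/`q_lin`/`v_lin` per node type.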

```python
    self.k_lin = torch.nn.ModuleDict()
    self.q_lin = torch.nn.ModuleDict()
    self.v_lin = torch.nn.ModuleDict()
    self.a_lin = torch.nn.ModuleDict()
    self.skip = torch.nn.ParameterDict()

    for node_type, in_channels in self.in_channels.items():
        self.k_lin[node_type] = Linear(in_channels, out_channels)
        self.q_lin[node_type] = Linear(in_channels, out_channels)
        self.v_lin[node_type] = Linear(in_channels, out_channels)
        self.a_lin[node_type] = Linear(out_channels, out_channels)
        self.skip[node_type] = Parameter(torch.Tensor(1))

```

`a_rel` is the per-relation, per-head attention matrix $W^{ATT}$,  
`m_rel` is the per-relation, per-head message matrix $W^{MSG}$,  
`p_rel` is the per-relation prior $\mu$: one scalar per head that scales the attention scores of that relation.  

```python
    self.a_rel = torch.nn.ParameterDict()
    self.m_rel = torch.nn.ParameterDict()
    self.p_rel = torch.nn.ParameterDict()

    dim = out_channels // heads
    for edge_type in metadata[1]:
        edge_type = '__'.join(edge_type)
        self.a_rel[edge_type] = Parameter(torch.Tensor(heads, dim, dim))
        self.m_rel[edge_type] = Parameter(torch.Tensor(heads, dim, dim))
        self.p_rel[edge_type] = Parameter(torch.Tensor(heads))
```
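A small sketch of the resulting parameter layout. `torch.nn.ParameterDict` keys must be strings, which is why each edge-type triple is joined with `'__'`. The values are filled with `torch.randn`/`torch.ones` purely for illustration; the layer itself allocates uninitialized tensors and initializes them elsewhere:

```python
import torch
from torch import nn

heads, out_channels = 2, 8
dim = out_channels // heads  # per-head size D = 4

edge_types = [('author', 'to', 'paper'), ('paper', 'to', 'author')]

a_rel, m_rel, p_rel = nn.ParameterDict(), nn.ParameterDict(), nn.ParameterDict()
for edge_type in edge_types:
    key = '__'.join(edge_type)  # tuple keys are not allowed, so build a string
    a_rel[key] = nn.Parameter(torch.randn(heads, dim, dim))  # W^ATT, one D x D per head
    m_rel[key] = nn.Parameter(torch.randn(heads, dim, dim))  # W^MSG, one D x D per head
    p_rel[key] = nn.Parameter(torch.ones(heads))             # prior mu, one scalar per head

print(list(a_rel.keys()))                # ['author__to__paper', 'paper__to__author']
print(a_rel['author__to__paper'].shape)  # torch.Size([2, 4, 4])
```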

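Apart from a final call to `self.reset_parameters()` (which uses the imported `glorot`, `ones`, and `reset` helpers to initialize these parameters), this completes the **init method**.  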

Now let's look at the **forward** method.  

def forward(
    self,
    x_dict: Dict[NodeType, Tensor],
    edge_index_dict: Union[Dict[EdgeType, Tensor],
                            Dict[EdgeType, SparseTensor]]  # Support both.
) -> Dict[NodeType, Optional[Tensor]]:
    ...  # the body is walked through step by step below

`forward` takes two arguments, documented as follows:  

```
x_dict
 A dictionary holding input node features for each individual node type.

edge_index_dict
 A dictionary holding graph connectivity information for each
 individual edge type, either as a `torch.LongTensor` of
 shape [2, num_edges] or a `torch_sparse.SparseTensor`.
```

Here `x_dict` corresponds to $H^{(l-1)}$, the node representations produced by layer $l-1$ (one feature matrix per node type).  
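A minimal sketch of inputs in the format described above, using arbitrary toy sizes. Row 0 of each `edge_index` holds source node indices and row 1 holds destination indices:

```python
import torch

# Toy node features: one [num_nodes, num_features] tensor per node type.
x_dict = {
    'author': torch.randn(3, 16),  # 3 authors, 16 features each
    'paper': torch.randn(4, 16),   # 4 papers, 16 features each
}

# Toy connectivity: one [2, num_edges] LongTensor per edge type.
edge_index_dict = {
    ('author', 'to', 'paper'): torch.tensor([[0, 1, 2, 2],
                                             [0, 0, 1, 3]]),
    ('paper', 'to', 'author'): torch.tensor([[0, 0, 1, 3],
                                             [0, 1, 2, 2]]),
}

for (src, _, dst), edge_index in edge_index_dict.items():
    # Source indices index rows of x_dict[src]; destination indices rows of x_dict[dst].
    assert int(edge_index[0].max()) < x_dict[src].size(0)
    assert int(edge_index[1].max()) < x_dict[dst].size(0)
    print(src, '->', dst, tuple(edge_index.shape))
```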

```python
    H, D = self.heads, self.out_channels // self.heads

    k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}

    # 1) pass through k/q/v linear layer
    for node_type, x in x_dict.items():
        k_dict[node_type] = self.k_lin[node_type](x).view(-1, H, D)
        q_dict[node_type] = self.q_lin[node_type](x).view(-1, H, D)
        v_dict[node_type] = self.v_lin[node_type](x).view(-1, H, D)
        out_dict[node_type] = []
```
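The `.view(-1, H, D)` call only regroups the `out_channels` columns of each projection into `H` heads of size `D`. A small sketch, using plain `torch.nn.Linear` as a stand-in for PyG's `Linear`:

```python
import torch
from torch import nn

heads, out_channels, in_channels = 2, 8, 16
H, D = heads, out_channels // heads

x = torch.randn(5, in_channels)               # 5 nodes of one type
k_lin = nn.Linear(in_channels, out_channels)  # stand-in for self.k_lin[node_type]

flat = k_lin(x)                # [num_nodes, out_channels]
k = flat.view(-1, H, D)        # [num_nodes, H, D]
print(k.shape)                 # torch.Size([5, 2, 4])

# Head h simply uses columns h*D .. (h+1)*D - 1 of the flat projection.
assert torch.equal(k[:, 1], flat[:, D:2 * D])
```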


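Next, the layer loops over the edge types. For each relation, the source keys are multiplied by the per-head matrix $W^{ATT}$ (`a_rel`) and the source values by $W^{MSG}$ (`m_rel`), and the result is propagated to the destination node type: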

```python
    # 2) Iterate over edge-types
    # Example
    # data.edge_index_dict[('author', 'to', 'paper')].shape -> (2, 19645)
    for edge_type, edge_index in edge_index_dict.items():
        src_type, _, dst_type = edge_type
        edge_type = '__'.join(edge_type)

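        # ready for heterogeneous mutual attention: apply W^ATT per head
        # k_dict[src_type]: [N_src, H, D] -> transpose -> [H, N_src, D]
        # @ a_rel [H, D, D] -> [H, N_src, D] -> transpose back -> [N_src, H, D]
        a_rel = self.a_rel[edge_type]
        k = (k_dict[src_type].transpose(0, 1) @ a_rel).transpose(1, 0)

        # ready for heterogeneous message passing: apply W^MSG per head (same shapes)
        m_rel = self.m_rel[edge_type]
        v = (v_dict[src_type].transpose(0, 1) @ m_rel).transpose(1, 0)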

        # propagate_type: (k: Tensor, q: Tensor, v: Tensor, rel: Tensor)
        out = self.propagate(
            edge_index, k=k, q=q_dict[dst_type], v=v,
            rel=self.p_rel[edge_type], size=None)
        out_dict[dst_type].append(out)
```
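The transpose/matmul/transpose pattern is a per-head matrix product: head `h` of every source node is multiplied by `a_rel[h]`. A small sketch checking this against an equivalent `torch.einsum` formulation on random tensors:

```python
import torch

N, H, D = 5, 2, 4
k_src = torch.randn(N, H, D)   # keys of the source node type: [N, H, D]
a_rel = torch.randn(H, D, D)   # one D x D matrix per head (W^ATT)

# As in HGTConv: heads to the batch dim, batched matmul, then swap back.
k = (k_src.transpose(0, 1) @ a_rel).transpose(1, 0)   # [N, H, D]

# The same per-head contraction written with einsum.
k_einsum = torch.einsum('nhd,hde->nhe', k_src, a_rel)
assert torch.allclose(k, k_einsum, atol=1e-5)
print(k.shape)  # torch.Size([5, 2, 4])
```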

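`propagate` calls the **message** method, which computes the attention-weighted message for every edge (`_i` tensors belong to the destination node of each edge, `_j` tensors to its source node):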

```python
def message(
    self,
    k_j: Tensor, q_i: Tensor, v_j: Tensor, rel: Tensor,
    index: Tensor,
    ptr: Optional[Tensor],
    size_i: Optional[int]
    ) -> Tensor:

    alpha = (q_i * k_j).sum(dim=-1) * rel
    alpha = alpha / math.sqrt(q_i.size(-1))
    alpha = softmax(alpha, index, ptr, size_i)
    out = v_j * alpha.view(-1, self.heads, 1)
    return out.view(-1, self.out_channels)
```
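To see what `message` computes with concrete tensors, here is a minimal sketch for one relation with 4 edges pointing at 2 destination nodes. It assumes a hypothetical loop-based `segment_softmax` as a stand-in for `torch_geometric.utils.softmax`, and builds the per-edge tensors `q_i`, `k_j`, `v_j` directly instead of letting `propagate` gather them via `edge_index`:

```python
import math

import torch


def segment_softmax(src: torch.Tensor, index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical stand-in for torch_geometric.utils.softmax.

    Normalizes the scores in src ([E, H]) over all edges that share the
    same destination node in index ([E]), separately for every head.
    """
    out = torch.zeros_like(src)
    for i in range(num_nodes):
        mask = index == i
        if mask.any():
            out[mask] = torch.softmax(src[mask], dim=0)
    return out


E, H, D = 4, 2, 3
q_i = torch.randn(E, H, D)          # query of each edge's destination node
k_j = torch.randn(E, H, D)          # relation-transformed key of each edge's source node
v_j = torch.randn(E, H, D)          # relation-transformed value of each edge's source node
rel = torch.ones(H)                 # per-head relation prior (p_rel)
index = torch.tensor([0, 0, 1, 1])  # destination node of each edge

alpha = (q_i * k_j).sum(dim=-1) * rel     # [E, H] raw attention scores
alpha = alpha / math.sqrt(D)              # scaled dot-product
alpha = segment_softmax(alpha, index, num_nodes=2)
out = v_j * alpha.view(-1, H, 1)          # weight each message per head
print(out.view(-1, H * D).shape)          # torch.Size([4, 6])
```

For each destination node, the attention weights of its incoming edges sum to 1 per head, and the `aggr='add'` set in `__init__` then sums the weighted messages per destination node.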

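Once every edge type has been processed, `out_dict` holds a list of per-relation outputs for each destination node type. The last step groups them, applies GELU and `a_lin`, and adds a gated skip connection:

```python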
    # 3) Iterate over node-types
    for node_type, outs in out_dict.items():
        out = group(outs, self.group)

        if out is None:
            out_dict[node_type] = None
            continue

        out = self.a_lin[node_type](F.gelu(out))
        if out.size(-1) == x_dict[node_type].size(-1):
            alpha = self.skip[node_type].sigmoid()
            out = alpha * out + (1 - alpha) * x_dict[node_type]
        out_dict[node_type] = out

    return out_dict
```
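A minimal sketch of the output step for a single node type. `a_lin` is a plain `torch.nn.Linear` stand-in, and the raw gate `skip` is set to 0 purely for illustration, so that `sigmoid(skip) = 0.5` and the result is an even blend of the transformed messages and the input:

```python
import torch
import torch.nn.functional as F
from torch import nn

C = 8
x = torch.randn(4, C)                # x_dict[node_type]: input features
agg = torch.randn(4, C)              # grouped messages for this node type
a_lin = nn.Linear(C, C)              # stand-in for self.a_lin[node_type]
skip = nn.Parameter(torch.zeros(1))  # raw gate; 0 is an illustrative value

out = a_lin(F.gelu(agg))
if out.size(-1) == x.size(-1):       # residual only when feature sizes match
    alpha = skip.sigmoid()           # learnable gate in (0, 1)
    out = alpha * out + (1 - alpha) * x
```

When the input and output sizes differ, the residual branch is skipped and `out` is just the `a_lin` projection.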

The **group** helper, which combines the per-relation outputs of one node type, is defined as follows:

```python
def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
    if len(xs) == 0:
        return None
    elif aggr is None:
        return torch.stack(xs, dim=1)
    elif len(xs) == 1:
        return xs[0]
    else:
        out = torch.stack(xs, dim=0)
        out = getattr(torch, aggr)(out, dim=0)
        out = out[0] if isinstance(out, tuple) else out
        return out
```