#### analyzing_hgt_loader  
- docs: [here](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.HGTLoader)  
- source code: [here](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/hgt_loader.html#HGTLoader)  

In [1]:
import os

from typing import Union, Dict, List, Tuple, Callable, Optional, Any
from torch_geometric.typing import NodeType

import torch
from torch import Tensor
import torch_geometric.transforms as T

from torch_geometric.data import HeteroData
from torch_geometric.datasets import OGB_MAG

from torch_geometric.loader.base import BaseDataLoader
from torch_geometric.loader.utils import to_hetero_csc, filter_hetero_data
from torch_geometric.loader import HGTLoader

`HGTLoader` is a subclass of `BaseDataLoader`:  

```python
class BaseDataLoader(DataLoader):
    r"""Extends the :class:`torch.utils.data.DataLoader` by integrating a
    custom :meth:`self.transform_fn` function to allow transformation of a
    returned mini-batch directly inside the main process.
    """
    def _get_iterator(self) -> Iterator:
        # Start from the iterator the parent DataLoader would normally return.
        iterator = super()._get_iterator()
        # Only loaders that define `transform_fn` get the wrapper; it is
        # handed both the parent iterator and that callable.
        if hasattr(self, 'transform_fn'):
            iterator = DataLoaderIterator(iterator, self.transform_fn)
        return iterator
```

In [None]:
class HGTLoader(BaseDataLoader):
    r"""Mini-batch loader that samples sub-graphs from a heterogeneous graph
    via the ``torch_sparse.hgt_sample`` operator.

    The work is split into two stages: :meth:`sample` is installed as the
    ``collate_fn`` and only produces the raw sampler output, while
    :meth:`transform_fn` converts that output into the mini-batch that is
    handed to the caller.

    Args:
        data: The heterogeneous graph to sample from.
        num_samples: The number of nodes to sample in each iteration and for
            each node type. A plain list/tuple is applied to every node type
            in ``data.node_types``; a dict gives one list per node type.
        input_nodes: The seed nodes for which neighbors are sampled. Either
            a node type (all nodes of that type are used) or a pair
            ``(node_type, index)`` where ``index`` is :obj:`None` (all
            nodes), a boolean mask, or a tensor of node indices.
        transform: Optional callable applied to every mini-batch right
            before it is returned.
        **kwargs: Forwarded to the base loader constructor (e.g.
            ``batch_size``, ``shuffle``). A user-supplied ``collate_fn`` is
            discarded because the loader installs its own.
    """
    def __init__(
        self,
        data: HeteroData,
        num_samples: Union[List[int], Dict[NodeType, List[int]]],
        input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]],
        transform: Optional[Callable] = None,
        **kwargs,
    ):
        # `self.sample` is always used as the collate function, so a
        # caller-provided one is silently dropped.
        kwargs.pop('collate_fn', None)

        # Broadcast a single sampling schedule to every node type.
        if isinstance(num_samples, (list, tuple)):
            num_samples = {key: num_samples for key in data.node_types}

        # Normalize `input_nodes` to a `(node_type, index_tensor)` pair.
        if isinstance(input_nodes, str):
            input_nodes = (input_nodes, None)
        assert isinstance(input_nodes, (list, tuple)), (
            "input_nodes must be a node type or a (node_type, index) pair")
        assert len(input_nodes) == 2, (
            "input_nodes must hold exactly a node type and its indices")
        assert isinstance(input_nodes[0], str), (
            "the first entry of input_nodes must be a node type string")

        node_type, index = input_nodes[0], input_nodes[1]
        if index is None:
            # No explicit selection: use every node of the given type.
            index = torch.arange(data[node_type].num_nodes)
            input_nodes = (node_type, index)
        elif index.dtype == torch.bool:
            # Boolean mask: convert to the positions of its `True` entries.
            index = index.nonzero(as_tuple=False).view(-1)
            input_nodes = (node_type, index)

        self.data = data
        self.num_samples = num_samples
        self.input_nodes = input_nodes
        # The number of sampling iterations is the longest per-type schedule.
        self.num_hops = max(len(v) for v in num_samples.values())
        self.transform = transform
        self.sample_fn = torch.ops.torch_sparse.hgt_sample

        # Convert the graph data into a suitable format for sampling.
        # NOTE: Since C++ cannot take dictionaries with tuples as key as
        # input, edge type triplets are converted into single strings.
        self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(
            data, device='cpu')

        # The underlying DataLoader iterates over the plain seed indices and
        # hands each batch of them to `self.sample`.
        super().__init__(input_nodes[1].tolist(), collate_fn=self.sample,
                         **kwargs)

    def sample(self, indices: List[int]) -> Tuple[Any, Any, Any, Any, int]:
        r"""Runs the sampler for one batch of seed node indices.

        Args:
            indices: Seed node indices of the input node type.

        Returns:
            The four outputs of the sampling operator (node, row, column and
            edge dictionaries) followed by the number of seed nodes.
        """
        input_node_dict = {self.input_nodes[0]: torch.tensor(indices)}
        node_dict, row_dict, col_dict, edge_dict = self.sample_fn(
            self.colptr_dict,
            self.row_dict,
            input_node_dict,
            self.num_samples,
            self.num_hops,
        )
        return node_dict, row_dict, col_dict, edge_dict, len(indices)

    def transform_fn(self, out: Any) -> HeteroData:
        r"""Builds the mini-batch from the raw output of :meth:`sample`.

        Args:
            out: The 5-tuple returned by :meth:`sample`.

        Returns:
            The filtered graph, with the number of seed nodes stored as
            ``batch_size`` on the input node type; passed through
            ``self.transform`` when one was given.
        """
        node_dict, row_dict, col_dict, edge_dict, batch_size = out
        data = filter_hetero_data(self.data, node_dict, row_dict, col_dict,
                                  edge_dict, self.perm_dict)
        data[self.input_nodes[0]].batch_size = batch_size

        return data if self.transform is None else self.transform(data)

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Transform applied when a graph is accessed from the dataset.
transform = T.ToUndirected(merge=True)

# The dataset is stored below the current working directory.
path = os.path.join(os.getcwd(), 'data/OGB_MAG')
dataset = OGB_MAG(
    path,
    preprocess='metapath2vec',
    transform=transform,
)

Downloading https://data.pyg.org/datasets/mag_metapath2vec_emb.zip
Extracting c:\Users\Youyoung\Documents\hetegoenous-graph-transformer\data\OGB_MAG\mag\raw\mag_metapath2vec_emb.zip
Processing...
Done!


In [3]:
# Take the first graph of the dataset and move only its node features ('x')
# and labels ('y') to the target device ahead of time, so they are already
# there when mini-batches are sampled.
hetero_data = dataset[0]
hetero_data = hetero_data.to(device, 'x', 'y')

In [5]:
hetero_data

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)

**num_samples**  
The number of nodes to sample in each iteration and for each node type.  
If given as a list, the same number of nodes is sampled for each node type.

**input_nodes**  
The indices of nodes for which neighbors are sampled to create mini-batches.  
Needs to be passed as a tuple that holds the node type and corresponding node indices.  
If node indices are set to `None`, all nodes of this specific type will be considered.  

In [10]:
# Seed nodes: the 'paper' nodes selected by the training mask.
train_input_nodes = ('paper', hetero_data['paper'].train_mask)

# Extra options passed through to the loader.
kwargs = dict(batch_size=1024)

# Sample 32 nodes per node type in each of 4 sampling iterations.
train_loader = HGTLoader(
    hetero_data,
    num_samples=[32 for _ in range(4)],
    input_nodes=train_input_nodes,
    shuffle=True,
    **kwargs,
)

In [15]:
# Draw a single mini-batch from the training loader and move its
# 'edge_index' tensors to the target device.
batch = next(iter(train_loader)).to(device, 'edge_index')

In [16]:
batch

HeteroData(
  paper={
    x=[1152, 128],
    year=[1152],
    y=[1152],
    train_mask=[1152],
    val_mask=[1152],
    test_mask=[1152],
    batch_size=1024
  },
  author={ x=[128, 128] },
  institution={ x=[96, 128] },
  field_of_study={ x=[128, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 11] },
  (author, writes, paper)={ edge_index=[2, 143] },
  (paper, cites, paper)={ edge_index=[2, 271] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 27] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 117] },
  (paper, rev_writes, author)={ edge_index=[2, 141] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 4174] }
)