#### Analyzing the PyTorch Geometric MetaPath2Vec class
- [source code](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/metapath2vec.html)  
- [docs](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.MetaPath2Vec)  

In [1]:
from typing import Dict, List, Optional, Tuple
import torch
from torch_geometric.typing import NodeType, EdgeType, OptTensor
from torch import Tensor
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor

EPS = 1e-15

In [22]:
import os
from torch_geometric.datasets import AMiner

# Load the AMiner dataset rooted at <cwd>/data/AMiner and take its first graph.
path = os.path.join(os.getcwd(), 'data/AMiner')
dataset = AMiner(path)
data = dataset[0]

# Every key of edge_index_dict is an edge type, i.e. a
# (src_node_type, rel_type, dst_node_type) triple; print each one next to the
# shape of its edge-index tensor.
for edge_type, index in data.edge_index_dict.items():
    print(edge_type, list(index.shape))

# The metapath followed by the walks: author -> paper -> venue -> paper -> author.
metapath = [
    ('author', 'writes', 'paper'),
    ('paper', 'published_in', 'venue'),
    ('venue', 'publishes', 'paper'),
    ('paper', 'written_by', 'author'),
]

('paper', 'written_by', 'author') [2, 9323605]
('author', 'writes', 'paper') [2, 9323605]
('paper', 'published_in', 'venue') [2, 3194405]
('venue', 'publishes', 'paper') [2, 3194405]


    edge_index_dict (Dict[Tuple[str, str, str], Tensor]): Dictionary
        holding edge indices for each
        :obj:`(src_node_type, rel_type, dst_node_type)` present in the heterogeneous graph.
    embedding_dim (int): The size of each embedding vector.

    metapath (List[Tuple[str, str, str]]): The metapath described as a list
        of :obj:`(src_node_type, rel_type, dst_node_type)` tuples.
    
    walk_length (int):
        The random walk length. Ex) 100

    context_size (int):
        = window size
        The actual context size which is considered for
        positive samples. This parameter increases the effective sampling
        rate by reusing samples across different source nodes.

    walks_per_node (int, optional): The number of walks to sample for each
        node. (default: :obj:`1`)

    num_negative_samples (int, optional): The number of negative samples to
        use for each positive sample. (default: :obj:`1`)

    num_nodes_dict (Dict[str, int], optional):
        Dictionary holding the number of nodes for each node type.
        (default: :obj:`None`)
        Ex) {'paper': 3194405, 'author': 1693531, 'venue': 3883}

    sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
        weight matrix will be sparse. (default: :obj:`False`)

In [33]:
# Example arguments used by the cells that follow.
# Graph structure: one edge-index tensor per edge type.
edge_index_dict = data.edge_index_dict

# Embedding table settings.
embedding_dim = 128
sparse = True

# Random-walk and sampling settings.
walk_length = 50
walks_per_node = 5
context_size = 7
num_negative_samples = 5

In [9]:
# Infer num_nodes_dict from the edges (what the class does when it is given
# None): for every node type, take the largest node index seen on either end
# of any edge type, plus one.
num_nodes_dict = {}
for keys, edge_index in edge_index_dict.items():
    # (node type, indices of that type) for the source and destination side.
    endpoints = ((keys[0], edge_index[0]), (keys[-1], edge_index[1]))
    for key, indices in endpoints:
        N = int(indices.max() + 1)
        num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

In [10]:
# Show the inferred number of nodes for each node type.
print(num_nodes_dict)

{'paper': 3194405, 'author': 1693531, 'venue': 3883}


In [12]:
# Build one sparse adjacency matrix per edge type
# (rows: source-type nodes, columns: destination-type nodes).
adj_dict = {}
for keys, edge_index in edge_index_dict.items():
    src_type, dst_type = keys[0], keys[-1]
    adj_dict[keys] = SparseTensor(
        row=edge_index[0],
        col=edge_index[1],
        sparse_sizes=(num_nodes_dict[src_type], num_nodes_dict[dst_type]),
    ).to('cpu')

In [13]:
# Inspect the adjacency matrices; each has size (num_src_nodes, num_dst_nodes).
print(adj_dict)
# ('paper', 'written_by', 'author')  -> (num_papers, num_authors) = (3194405, 1693531)
# ('paper', 'published_in', 'venue') -> (num_papers, num_venues)  = (3194405, 3883)

{('paper',
  'written_by',
  'author'): SparseTensor(row=tensor([      0,       1,       2,  ..., 3194404, 3194404, 3194404]),
              col=tensor([     0,      1,      2,  ...,   4393,  21681, 317436]),
              size=(3194405, 1693531), nnz=9323605, density=0.00%),
 ('author',
  'writes',
  'paper'): SparseTensor(row=tensor([      0,       0,       0,  ..., 1693528, 1693529, 1693530]),
              col=tensor([      0,   45988,  124807,  ..., 3194371, 3194387, 3194389]),
              size=(1693531, 3194405), nnz=9323605, density=0.00%),
 ('paper',
  'published_in',
  'venue'): SparseTensor(row=tensor([      0,       1,       2,  ..., 3194402, 3194403, 3194404]),
              col=tensor([2190, 2190, 2190,  ..., 3148, 3148, 3148]),
              size=(3194405, 3883), nnz=3194405, density=0.03%),
 ('venue',
  'publishes',
  'paper'): SparseTensor(row=tensor([   0,    0,    0,  ..., 3882, 3882, 3882]),
              col=tensor([2203069, 2203070, 2203071,  ...,  952391,  95239

In [16]:
# Collect every node type that appears at either end of a metapath step and
# put them in alphabetical order.
types = {name for step in metapath for name in (step[0], step[-1])}
types = sorted(types)
print(types)

['author', 'paper', 'venue']


In [17]:
# Lay all node types out on one contiguous index axis, giving each type a
# half-open range [start, end):
#   author: [0, 1693531)
#   paper : [1693531, 4887936)
#   venue : [4887936, 4891819)
# After the loop `count` holds the total number of nodes over all types.
count = 0
start, end = {}, {}
for type_name in types:
    first = count
    count = first + num_nodes_dict[type_name]
    start[type_name], end[type_name] = first, count

In [151]:
# Show the metapath (list of edge types) that the walks follow.
print(metapath)

[('author', 'writes', 'paper'), ('paper', 'published_in', 'venue'), ('venue', 'publishes', 'paper'), ('paper', 'written_by', 'author')]


In [163]:
# Build `offset`: one entry per position of a random walk (walk_length + 1
# positions), holding the start of the global index range of the node type
# visited at that position. Adding it to a walk of type-local indices turns
# them into indices of the shared embedding table.

# Position 0 is the source type of the first metapath step (author).
offset = [start[metapath[0][0]]]    # [0]

# Start points of the destination type of every step (paper, venue, paper, author).
# [1693531, 4887936, 1693531, 0]
step_offsets = [start[keys[-1]] for keys in metapath]

# Repeat ONLY the per-step offsets. Repeating the whole list, initial entry
# included, would duplicate the start offset at each cycle boundary
# ([0, 1693531, 4887936, 1693531, 0, 0, 1693531, ...]) and shift every
# position after the first cycle onto the wrong node type.
# 13 repetitions of 4 steps -> 1 + 52 = 53 entries.
offset += step_offsets * int((walk_length / len(metapath)) + 1)

offset = offset[:walk_length + 1]    # keep exactly one offset per walk position
assert len(offset) == walk_length + 1    # length: 51
offset = torch.tensor(offset)

In [None]:
# A single embedding table stores the vectors of all nodes of all types.
# Index `count` (one past the last real node) is reserved for a dummy node
# that isolated nodes are linked to, hence the table has count + 1 rows.
dummy_idx = count
embedding = Embedding(dummy_idx + 1, embedding_dim, sparse=sparse)

In [95]:
# forward method: return the embeddings of the nodes in `batch` that are of
# type `node_type` (all nodes of that type when `batch` is None).
# Ex) start['paper'], end['paper'] = (1693531, 4887936)

# node_type: str, batch: OptTensor
# index_select method ref: https://pytorch.org/docs/stable/generated/torch.index_select.html

# Example node type; the name was never assigned in this notebook, so the
# cell could not run on its own.
node_type = 'paper'

# Rows of the shared table that belong to `node_type`.
emb = embedding.weight[start[node_type]:end[node_type]]
# `batch` holds type-local indices, i.e. positions inside `emb`.
batch = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)
output = emb if batch is None else emb.index_select(dim=0, index=batch)
print(list(output.shape))

[8, 128]


---

##### Order: sample function -> _pos_sample, _neg_sample -> _sample -> loader

In [112]:
def sample(src: SparseTensor, subset: Tensor, num_neighbors: int, dummy_idx: int) -> Tensor:
    """Draw ``num_neighbors`` random neighbors for every entry of ``subset``.

    A sparse tensor records which rows interact with which columns. For each
    row of ``src`` listed in ``subset``, ``num_neighbors`` of the columns it
    interacts with are drawn at random (independently, so the same column may
    be drawn more than once). Entries that are not below ``dummy_idx``, or
    whose row has no interaction at all, yield ``dummy_idx`` instead.

    Args:
        src: Sparse adjacency matrix to sample from.
        subset: Indices that select which rows of ``src`` to sample for
            (assumed to be a 1-D long tensor -- TODO confirm against callers).
        num_neighbors: Number of columns drawn per entry of ``subset``.
        dummy_idx: Index returned for entries that have nothing to sample.

    Returns:
        Tensor of sampled column indices, one row per entry of ``subset`` and
        ``num_neighbors`` columns.
    """
    # subset: an index that determines which part of src to take as a sample
    # col: result of extracting the interacting elements for each batch element

    # Entries at or above dummy_idx are not real rows and are skipped.
    mask = subset < dummy_idx
    rowcount = torch.zeros_like(subset)

    # if size of SparseTensor: (r, c)
    # then SparseTensor.storage.rowcount() shape: (r, )
    # meaning rowcount represents the number of edges b/w each row and every column
    # sample of rowcount
    rowcount[mask] = src.storage.rowcount()[subset[mask]]
    # Rows without any edge cannot be sampled from either.
    mask = mask & (rowcount > 0)

    # sample of rowptr
    offset = torch.zeros_like(subset)
    offset[mask] = src.storage.rowptr()[subset[mask]]

    # Ex)
    # rowcount: ([ 32, 101,  64, 120,  25,  23,  90,  63])
    # -> 0 interacts with 32 neighbors
    # offset: ([  0,  32, 133, 197, 317, 342, 365, 455])
    # -> entry 0 owns positions 0~31 of src.storage.col(), entry 1 owns 32~132, ...
    # Scale a uniform number in [0, 1) by the row's edge count, truncate it to an
    # integer and shift it by the row's offset to get a position in src.storage.col().
    rand = torch.rand((rowcount.size(0), num_neighbors), device=subset.device)    # Ex: (8, 1)
    rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
    rand = rand.to(torch.long)    # Ex: [18, 92, 16, 102, 14, 12, 28, 15]
    rand.add_(offset.view(-1, 1))

    # Masked-out entries have rowcount 0 and offset 0, so they read position 0
    # here and are then overwritten with the dummy index.
    col = src.storage.col()[rand]
    col[~mask] = dummy_idx    # Ex: [739183, 2803719, 1982397, 633236, 1222151, 1562057, 323228, 2651545]
    return col

In [80]:
# Example: draw one 'paper' neighbor for every author index in `batch`.
adj = adj_dict['author', 'writes', 'paper']
b = sample(adj, batch, num_neighbors=1, dummy_idx=dummy_idx)

- pos_sample method

In [181]:
# Start nodes of the walks: the example batch, repeated once per walk to sample.
batch = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)
batch = batch.repeat(walks_per_node)    # (batch_size * walks_per_node, )

# Random walks: step i follows the i-th edge type of the metapath (cycling
# through it), and the nodes sampled at one step are the input of the next.
rws = [batch]
num_steps = len(metapath)
for i in range(walk_length):
    adj = adj_dict[metapath[i % num_steps]]
    batch = sample(adj, batch, num_neighbors=1, dummy_idx=dummy_idx).view(-1)
    rws.append(batch)

# One row per walk, one column per visited node (start node included).
rw = torch.stack(rws, dim=-1)    # (batch_size * walks_per_node, walk_length + 1)

# Shift the type-local indices by the offset of the node type expected at each
# position, so that they index the shared embedding table.
rw.add_(offset.view(1, -1))
rw[rw > dummy_idx] = dummy_idx    # anything pushed past "count" is reset to "count"

# Slide a window of context_size nodes over every walk.
num_walks_per_rw = 1 + walk_length + 1 - context_size
walks = [rw[:, j:j + context_size] for j in range(num_walks_per_rw)]
output = torch.cat(walks, dim=0)    # (batch_size * walks_per_node * num_walks_per_rw, context_size)

- neg_sample method

In [None]:
# Negative walks have the same layout as the positive ones, but every step is
# drawn at random from all nodes of the step's destination type instead of
# being sampled from the neighbors of the previous node.
batch = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)
batch = batch.repeat(walks_per_node * num_negative_samples)

rws = [batch]
for i in range(walk_length):
    dst_type = metapath[i % len(metapath)][-1]
    batch = torch.randint(0, num_nodes_dict[dst_type], (batch.size(0), ), dtype=torch.long)
    rws.append(batch)

# One row per walk; shift type-local indices into the shared index space.
rw = torch.stack(rws, dim=-1)
rw.add_(offset.view(1, -1))

# Slide a window of context_size nodes over every walk.
num_walks_per_rw = 1 + walk_length + 1 - context_size
walks = [rw[:, j:j + context_size] for j in range(num_walks_per_rw)]
output = torch.cat(walks, dim=0)

- _sample method

In [None]:
def _sample(self, batch: List[int]) -> Tuple[Tensor, Tensor]:
    """Turn a batch of start-node indices into positive and negative walks.

    Args:
        batch: Start-node indices, either a list of ints or already a tensor.

    Returns:
        The pair ``(self._pos_sample(batch), self._neg_sample(batch))``.
    """
    if not isinstance(batch, Tensor):
        batch = torch.tensor(batch, dtype=torch.long)
    # The two samplers are methods of the model, so they are looked up on
    # `self`; bare `_pos_sample` / `_neg_sample` names are not defined.
    return self._pos_sample(batch), self._neg_sample(batch)

- loader method

In [None]:
# collate_fn
# merges a list of samples to form a mini-batch of Tensor(s). 
# Used when using batched loading from a map-style dataset.

In [None]:
def loader(self, **kwargs):
    """
    Returns the data loader
    that creates both positive and negative random walks on the heterogeneous graph.

    The loader iterates over the indices of all nodes whose type is the source
    type of the first metapath step, and collates every batch of indices with
    ``self._sample``.

    **kwargs (optional):
        Arguments of torch.utils.data.DataLoader
        such as batch_size, shuffle, drop_last or num_workers
    """
    # Walks start at the beginning of the metapath.
    num_start_nodes = self.num_nodes_dict[self.metapath[0][0]]
    # collate_fn is called with the batch only, so it has to be the bound
    # method; the bare function would be missing its `self` argument.
    return DataLoader(range(num_start_nodes), collate_fn=self._sample, **kwargs)

---

##### loss, test method

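The objective to maximize for a start node $v$ and a context node $c_t$ is:

$$ \log \sigma (X_{c_t} \cdot X_v) + \sum_{m=1}^M \log \left( 1 - \sigma(X_{u^m} \cdot X_v) \right) $$  

$v$ is the start node, $c_t$ is a node in its context window, $u^m$ is the $m$-th negative node, and $M$ is the number of negative nodes drawn. The `loss` method below returns the negative of this objective, with each of the two terms averaged over its (start, other) node pairs.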

In [None]:
def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
    """Compute the loss given positive and negative random walks.

    In every walk the first node is the start node and the remaining nodes
    are scored against it with a dot product of their embeddings.

    Args:
        pos_rw: Positive walks, one walk per row.
        neg_rw: Negative walks, one walk per row.

    Returns:
        The mean of ``-log(sigmoid(score) + EPS)`` over the positive pairs plus
        the mean of ``-log(1 - sigmoid(score) + EPS)`` over the negative pairs.
    """

    def pair_scores(rw: Tensor) -> Tensor:
        # Dot product between the start node of each walk and every other
        # node of the same walk, flattened to one dimension.
        num_walks = rw.size(0)
        head, tail = rw[:, 0], rw[:, 1:].contiguous()
        h_head = self.embedding(head).view(num_walks, 1, self.embedding_dim)
        h_tail = self.embedding(tail.view(-1)).view(
            num_walks, -1, self.embedding_dim)
        return (h_head * h_tail).sum(dim=-1).view(-1)

    # Positive pairs should get a high score.
    pos_scores = pair_scores(pos_rw)
    pos_loss = -torch.log(torch.sigmoid(pos_scores) + EPS).mean()

    # Negative pairs should get a low score.
    neg_scores = pair_scores(neg_rw)
    neg_loss = -torch.log(1 - torch.sigmoid(neg_scores) + EPS).mean()

    return pos_loss + neg_loss

In [None]:
def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor,
            test_y: Tensor, solver: str = "lbfgs", multi_class: str = "auto",
            *args, **kwargs) -> float:
    """Evaluate latent space quality via a logistic regression downstream task.

    A scikit-learn ``LogisticRegression`` is fitted on the training embeddings
    and labels, and its ``score`` on the test embeddings and labels is returned.

    Args:
        train_z: Training embeddings.
        train_y: Training labels.
        test_z: Test embeddings.
        test_y: Test labels.
        solver: Forwarded to ``LogisticRegression``.
        multi_class: Forwarded to ``LogisticRegression``.
        *args: Extra positional arguments for ``LogisticRegression``.
        **kwargs: Extra keyword arguments for ``LogisticRegression``.

    Returns:
        The value of ``LogisticRegression.score`` on the test data.
    """
    from sklearn.linear_model import LogisticRegression

    def to_numpy(tensor: Tensor):
        # scikit-learn works on NumPy arrays: drop the graph, move to CPU, convert.
        return tensor.detach().cpu().numpy()

    clf = LogisticRegression(*args, solver=solver, multi_class=multi_class,
                             **kwargs)
    clf.fit(to_numpy(train_z), to_numpy(train_y))
    return clf.score(to_numpy(test_z), to_numpy(test_y))