# Intro

The explosion of interest and progress in natural language processing (NLP) and large language models (LLMs) is well-documented, and terms like “tokenization” and “transformers” can be seen everywhere. Yet the equally potent realm of graph neural networks (GNNs), and knowledge graph embeddings specifically, remains much less popular outside specialized circles. These technologies offer powerful solutions across various applications, from recommendation systems to link prediction and node classification.

One of the techniques in the GNN world that, in my opinion, deserves much more attention is NodePiece tokenization: a novel approach that imports several concepts from NLP to enhance the functionality of graph neural networks. It employs a finite set of universal “tokens” to represent nodes within a graph, which eliminates the need for a predefined vocabulary of node IDs, allows for a more adaptable representation of nodes, and improves the model's ability to generalize across different graphs.

Despite its potential, NodePiece tokenization is not nearly as widely discussed as, e.g., LLMs. This blog post seeks to demystify it, offering a clear, intuitive explanation and a practical Python implementation for hands-on learning, as the existing implementations are quite complex and built into pre-existing libraries.

By the end of this post, you will:

1. Grasp the fundamental concepts of NodePiece tokenization.
2. Comprehend the rationale and principles underpinning this technique.
3. Gain the skills to implement a basic version of NodePiece tokenization in Python.

## Jupyter notebook disclaimer

This is the notebook accompanying the blogpost about the NodePiece graph algorithm. You can read this notebook as a standalone post, or just refer to the code fragments mentioned in the text.

For clarity, longer pieces of code (such as the neural network itself and the training loop) are stored in separate modules in the same directory.

# Lib imports

In [1]:
import os

# Set variable for CUDA determinism

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

In [2]:
import networkx as nx
import numpy as np
import torch as th
import torch.nn as nn
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
import torch_geometric.data as pyg_data
import torch_geometric.datasets as pyg_datasets
import torch_geometric.utils as pyg_utils
import pytorch_lightning as pl
import torcheval.metrics.functional as tmf
import matplotlib.pyplot as plt
import pickle as pkl
import random
import inspect

from collections import defaultdict
from dataclasses import dataclass
from tqdm.auto import tqdm
from typing import List, Tuple
from networkx import Graph

import models as models
import tokenization as tok
from tokenizers import Tokenizer

In [3]:
device = 'cuda'

# Part 1: The magic of building blocks

## The NLP world - history

Reflecting on the evolution of NLP, early stages were dominated by methodologies like word embeddings, lemmatization, and tokenization. Pioneering algorithms such as Word2Vec **(Church, 2017)** and GloVe **(Brochier et al., 2019)** set the standard, operating under a relatively straightforward process:

1. Compile a large corpus of text.

2. Preprocess it through tokenization, lemmatization, and similar techniques.

3. Construct embedding lookups that map the processed words to unique IDs.

This approach, however, required the allocation of massive embedding matrices, which were both space-consuming and computationally expensive.

For instance, consider an embedding matrix defined as:

$$\text{Emb matrix} \in \mathbb{R}^{V \times E}$$

where V is the size of the vocabulary and E is the size of the embedding. A vocabulary of **100,000 words** with an embedding size of **300** would require the allocation of **30 million floats**, a significant demand on resources.
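To make the arithmetic concrete, here is a small back-of-the-envelope sketch. The 4 bytes per value (float32 storage) is my assumption; the text above only counts floats.

```python
# Back-of-the-envelope size of a word-embedding lookup table.
vocab_size = 100_000      # V: number of distinct words
embedding_dim = 300       # E: size of each embedding vector
bytes_per_float = 4       # assumption: values are stored as float32

n_floats = vocab_size * embedding_dim
size_mb = n_floats * bytes_per_float / 1e6

print(f"{n_floats:,} floats, about {size_mb:.0f} MB")
# prints: 30,000,000 floats, about 120 MB
```

The cost grows linearly with the vocabulary, which is exactly the pressure that sub-word tokenizers relieve.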

## The advent of transformer tokenizers

The introduction of transformers and their associated tokenization methods revolutionized the field. Techniques like Byte-Pair Encoding (BPE) **(Sennrich et al., 2016)** introduced the concept of sub-words, akin to building blocks or alphabets, significantly streamlining the tokenization process. These sub-words or tokens are not only more compact than whole words but also universal, enabling their application across various languages and the seamless introduction of new vocabulary.
Consider the sentence: 

> “Modern tokenizers revolutionized the way, how we process text these days”


Traditional methods would tokenize this into separate words and assign each an embedding. However, a modern tokenizer might segment it into sub-words, as shown below:

In [4]:
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
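# The loop variable is named `token` so that it does not shadow the `tok` module imported above
for token in tokenizer.encode("Modern tokenizers revolutionized the way, how we process text these days").tokens:
    print(token)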



[CLS]
modern
token
##izer
##s
revolution
##ized
the
way
,
how
we
process
text
these
days
[SEP]


Here, the word "tokenizers" is split into "token", "##izer" and "##s", "revolutionized" into "revolution" and "##ized", and so forth.

This removes the need for a unique ID for each new word like "tokenizers": "token", "##izer" and "##s" are already in the vocabulary and can be reused to build the representations of other words, which conserves embedding memory.
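The splitting itself can be sketched in a few lines. The snippet below is a simplified greedy longest-match-first splitter in the spirit of WordPiece, not the code that the `tokenizers` library actually runs; the function name `greedy_subword_split` and the toy vocabulary are made up for this illustration.

```python
from typing import List, Set


def greedy_subword_split(word: str, vocab: Set[str]) -> List[str]:
    """Split a word into the longest matching vocabulary pieces, left to right."""
    pieces, start = [], 0
    while start < len(word):
        end, match = len(word), None
        # Try the longest remaining substring first, then shrink it
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return ["[UNK]"]  # no piece fits: fall back to the unknown token
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, just large enough for the example sentence
toy_vocab = {"modern", "token", "##izer", "##s", "revolution", "##ized"}

print(greedy_subword_split("tokenizers", toy_vocab))      # ['token', '##izer', '##s']
print(greedy_subword_split("revolutionized", toy_vocab))  # ['revolution', '##ized']
```

Six vocabulary entries are enough here to cover words that never appear in the vocabulary as a whole.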

## The world of graphs

In the realm of graph processing and knowledge graph inference, the adoption of concepts analogous to those used in modern NLP tokenization has been slow to materialize. Traditional algorithms for knowledge graph reasoning, such as TransE **(Bordes et al., 2013)** and RotatE **(Sun et al., 2019)**, have predominantly relied on mapping entities and relations to unique embeddings. This approach, while straightforward, is markedly memory-intensive: each entity and relation requires its own unique identifier within the embedding space, just like words in word2vec and older NLP solutions.

<div class="alert alert-block alert-info">
<b>Tip:</b> For those interested in learning more about TransE, RotatE, and similar models, I highly recommend a post on Medium.com <a href="https://medium.com/stanford-cs224w/simple-schemes-for-knowledge-graph-embedding-dd07c61f3267">Simple Schemes for Knowledge Graph Embedding by Preston Carlson</a>. This article provides a comprehensive overview of the principles behind these models.
</div>

The search for scalable and effective solutions within the graph domain has been protracted, yet fruitful. **NodePiece** emerges as one of the pioneering algorithms in this context, **drawing significant inspiration from the advancements in NLP tokenization**. By **applying the principles of modern tokenization techniques to graph structures, NodePiece offers a novel approach to representing graph entities and relationships**, marking a significant leap towards more memory-efficient and generalizable models in the knowledge graph domain.

# Part 2: The NodePiece algorithm

## Contributors and sources

The NodePiece algorithm represents a significant advancement in the domain of knowledge graph embeddings, focusing on compositional and parameter-efficient representations of large knowledge graphs. Introduced in the paper **"NodePiece: Compositional and parameter-efficient representations of large knowledge graphs" (Galkin et al., 2021)**, this approach has been integrated into the **widely utilized Python library, PyKeen**.

<a href="https://arxiv.org/pdf/2106.12144.pdf">Galkin, M., Denis, E., Wu, J., & Hamilton, W. L. (2021). Nodepiece: Compositional and parameter-efficient representations of large knowledge graphs. arXiv preprint arXiv:2106.12144.</a>

PyKeen implementation can be found <a href="https://pykeen.readthedocs.io/en/stable/">here</a>.

Additionally, the implementation authored by the paper's contributors is accessible on GitHub for those interested in exploring the codebase further.

>  <a href="https://github.com/migalkin/NodePiece">Github: Compositional and Parameter-Efficient Representations for Large Knowledge Graphs (ICLR'22) </a>


For a comprehensive and accessible introduction to NodePiece from the theoretical point of view, I recommend reading the Medium blog post **"NodePiece: Tokenizing Knowledge Graphs" by Michael Galkin**, one of the co-authors of the original paper. This article offers an invaluable deep dive into the algorithm, providing insights directly from its developers.


<div class="alert alert-block alert-info">
<b>NodePiece Blog Post:</b> The blog post <a href="https://towardsdatascience.com/nodepiece-tokenizing-knowledge-graphs-6dd2b91847aa">NodePiece: Tokenizing Knowledge Graphs by Michael Galkin</a> stands out as a thorough explanation of the NodePiece algorithm, authored by one of its creators. This is an excellent resource for those seeking to understand the model beyond the original academic paper.
</div>



## Basic concepts

The NodePiece algorithm is founded on a set of core assumptions that **shift away from the traditional requirement of assigning a unique identifier to each entity within a knowledge graph**. Instead, entities are depicted through a combination of basic building blocks, drawing parallels to the tokenization concepts in NLP. These building blocks comprise:

1. **Positional Features**: The proximity of nodes to designated anchor nodes.
2. **Relational Features**: The types of relationships in which the nodes are involved.

### Positional features

Nodes are characterized **by their distances to a set of predetermined anchor nodes**. The method for selecting these anchor nodes varies, including options like **choosing the most connected nodes, employing clustering algorithms, or even random selection.**

Regardless of the selection technique, the underlying principle remains straightforward — **each node is defined by its distance to the K nearest anchor nodes**. With a strategic choice of anchor nodes, there's a high likelihood that the proximity to the nearest anchors provides a unique “fingerprint” for each node.

Consider the following example to illustrate this concept:

```mermaid
flowchart LR
    n1((node1)) --- n2((node2))
    n1 --- n3
    n2 --- n3
    n2 --- a1(((Anchor 1)))
    n3((node3)) --- n4((node4))
    n4 --- a1
    n4 --- n5((node5))
    n5 --- a2(((Anchor 2)))
    n2 --- a2
    a1 --- a2
```


The table below summarizes distances of each node to anchors:


| Node  | Anchor 1 | Anchor 2 |
|-------|----------|----------|
| node1 | 2        | 2        |
| node2 | 1        | 1        |
| node3 | 2        | 2        |
| node4 | 1        | 2        |
| node5 | 2        | 1        |
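These distances can be checked with NetworkX. The sketch below rebuilds the toy graph from the diagram (the lowercase node names are my own labels) and calls `nx.shortest_path_length` with only a source, which returns the distance to every reachable node.

```python
import networkx as nx

# Undirected toy graph from the diagram above
edges = [
    ("node1", "node2"), ("node1", "node3"), ("node2", "node3"),
    ("node2", "anchor1"), ("node3", "node4"), ("node4", "anchor1"),
    ("node4", "node5"), ("node5", "anchor2"), ("node2", "anchor2"),
    ("anchor1", "anchor2"),
]
toy_graph = nx.Graph(edges)

# One breadth-first search per anchor yields the distances to all nodes
anchor_distances = {
    anchor: nx.shortest_path_length(toy_graph, source=anchor)
    for anchor in ("anchor1", "anchor2")
}

for node in ("node1", "node2", "node3", "node4", "node5"):
    print(node, anchor_distances["anchor1"][node], anchor_distances["anchor2"][node])
```

Running one search per anchor, rather than one per node, keeps the cost proportional to the number of anchors.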


### Relational features

In heterogeneous knowledge graphs, edges encapsulate diverse types of relationships between nodes. The NodePiece algorithm leverages these relationships as integral components of node representation, enriching the understanding of each node's context and connectivity within the graph. To illustrate how relational features are incorporated, let's revisit and expand upon our previous example, this time including relation types:


```mermaid
flowchart LR
    n1((node1)) -- r1 --- n2((node2))
    n1 -- r2 --- n3
    n2 -- r1 --- n3
    n2 -- r3 --- a1(((Anchor 1)))
    n3((node3)) -- r2 --- n4((node4))
    n4 -- r1 --- a1
    n4 -- r3 --- n5((node5))
    n5 -- r2 --- a2(((Anchor 2)))
    n2 -- r3 --- a2
    a1 -- r1 --- a2
```

We can summarize the relations of each node in the following table:

| Node  | relationships  |
|-------|----------------|
| node1 | r1, r2         |
| node2 | r1, r1, r3, r3 |
| node3 | r1, r2, r2     |
| node4 | r1, r2, r3     |
| node5 | r2, r3         |

Or, counting each relation type:

| Node  | r1 | r2 | r3 |
|-------|----|----|----|
| node1 | 1  | 1  | 0  |
| node2 | 2  | 0  | 2  |
| node3 | 1  | 2  | 0  |
| node4 | 1  | 1  | 1  |
| node5 | 0  | 1  | 1  |
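The counts can be derived mechanically from the list of typed edges. This is a minimal sketch on the toy graph above; since the diagram is undirected, I assume every edge contributes its relation type to both of its endpoints.

```python
from collections import Counter, defaultdict

# (head, relation, tail) triples of the toy graph
typed_edges = [
    ("node1", "r1", "node2"), ("node1", "r2", "node3"), ("node2", "r1", "node3"),
    ("node2", "r3", "anchor1"), ("node3", "r2", "node4"), ("node4", "r1", "anchor1"),
    ("node4", "r3", "node5"), ("node5", "r2", "anchor2"), ("node2", "r3", "anchor2"),
    ("anchor1", "r1", "anchor2"),
]

relation_counts = defaultdict(Counter)
for head, relation, tail in typed_edges:
    # Undirected toy graph: the edge counts for both endpoints
    relation_counts[head][relation] += 1
    relation_counts[tail][relation] += 1

for node in ("node1", "node2", "node3", "node4", "node5"):
    counts = relation_counts[node]
    print(node, [counts[r] for r in ("r1", "r2", "r3")])
```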

You can probably see where this is going...

### The unique fingerprint - an intuition

The essence of the NodePiece algorithm lies in its ability to synthesize two distinct features—positional and relational—into a cohesive representation for each node. This amalgamation involves concatenating the proximity to the nearest anchor nodes with the types of relations a node engages in, thereby crafting a singular vector emblematic of the node's unique “fingerprint.”
To conceptualize this process, consider representing node 2 through the following simplified schema:

```mermaid
flowchart TB
    a1(((Anchor 1))) --Distance=1--> ae[Anchor dist. vector]
    a2(((Anchor 2))) --Distance=1--> ae[Anchor dist. vector]
    r1 --> re[Relation context]
    r3 --> re[Relation context]
    re --Concatenate+embed--> Enc[Fingerprint vector]
    ae --Concatenate+embed--> Enc[Fingerprint vector]
```

# Part 3: A formal definition

There are some important details to mention, before we start coding the NodePiece algorithm. For example:

1. **Selective Anchor Usage**: Not every anchor is pertinent to each node's representation. Only the nearest k anchors are considered for embedding any given node.
2. **Relational Sampling**: Similarly, a node's relational context is derived from sampling among its immediate outgoing relations, limited to *m* relations per node.
3. **Handling Disconnections**: When a node is disconnected, or has no link to any anchor or relation, a special `[DISCONNECTED]` token is employed, analogous to the `<OOV>` (out-of-vocabulary) token in NLP scenarios.

Now it is time for some math.
Let's start with the input to NodePiece:

In [5]:
%%latex
\begin{align*}
\text{Given: }& \\
&KG = \{N, E, R\} \text{, with } \\
    &|N| \text{ - number of nodes} \\
    &|E| \text{ - number of edges} \\
    &|R| \text{ - number of relations} \\
    &A \text{ - set of anchors selected from nodes} \\
    &V = R \cup A \text{ - vocabulary for NodePiece: relations + anchors} \\
    &k \text{ - number of anchors to sample for each node} \\
    &m \text{ - number of immediate outgoing relations to sample from node} \\
    &d \text{ - vector size, that NodePiece will use to construct embeddings. Embedding size} \\
     \end{align*}

<IPython.core.display.Latex object>

We can see that NodePiece is designed to work with:
1. **Knowledge graphs** - represented as a set of nodes (N) connected via some edges (E) that belong to particular relations (R).
2. **Selected anchor nodes (A)** - they will be a part of each node's "fingerprint" representation.
3. **Relations (R) and anchors (A) form a vocabulary (V) for the model**.
4. As stated in the paper - there is no need to use **all anchors for every node**, therefore *k* anchors are sampled.
5. The same applies to relations - only *m* relations are sampled for each node.

### Set-based form

With the foundational elements at our disposal, we can move on to a formal definition of the unique “fingerprint” of each node within the graph. This fingerprint is divided into three components:

1. **Anchor Set**: A selection of k anchors sampled from the set of all anchors (A), typically the k closest to the node.
2. **Anchor Distances**: The shortest path distances (SPDs) from the node to these k anchors. Should an anchor be unreachable, its distance is denoted with a predefined "magic value" (like -1 or another special token).
3. **Relational Context**: A subset of m direct outgoing relations sampled for the node, encapsulating its immediate relational environment.

Expressed formally, the representation for each node u in the node set N can be described as follows:

In [4]:
%%latex
\begin{align*}
\text{For each node } u &\in N: \\
    & \{a_u\}^k = \{a \mid a \in sample(A, k)\} \\
    & \{z_u\}^k =\{SPD(u, a) \mid a \in \{a_u\}^k \} \\
    & \{r_u\}^m = \{r \mid r \in sample(out(u), m)\} \\[1em]
    & hash(u) = \big[\{a_u\}^k,  \{z_u\}^k, \{r_u\}^m\big]
\tag{1}
\end{align*}


<IPython.core.display.Latex object>
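Equation (1) can be sketched directly on the toy graph from Part 2. This is an illustration under several assumptions of mine, not the code from the `tokenization` module: the function name `node_hash` is made up, the k nearest anchors are taken instead of a random sample, `-1` stands in for an unreachable anchor, relations are sampled among distinct relation types, and "outgoing" means "incident" because the toy graph is undirected.

```python
import random
import networkx as nx

typed_edges = [
    ("node1", "r1", "node2"), ("node1", "r2", "node3"), ("node2", "r1", "node3"),
    ("node2", "r3", "anchor1"), ("node3", "r2", "node4"), ("node4", "r1", "anchor1"),
    ("node4", "r3", "node5"), ("node5", "r2", "anchor2"), ("node2", "r3", "anchor2"),
    ("anchor1", "r1", "anchor2"),
]
toy_graph = nx.Graph()
toy_graph.add_edges_from((head, tail, {"relation": rel}) for head, rel, tail in typed_edges)


def node_hash(graph, node, anchors, k, m, rng):
    """Return (anchor ids, anchor distances, sampled relations) for one node."""
    distances = {}
    for anchor_id, anchor in enumerate(anchors):
        try:
            distances[anchor_id] = nx.shortest_path_length(graph, node, anchor)
        except nx.NetworkXNoPath:
            distances[anchor_id] = -1  # "magic value" for unreachable anchors
    # Reachable anchors first, then by increasing distance; keep the k nearest
    nearest = sorted(distances, key=lambda a: (distances[a] < 0, distances[a], a))[:k]
    # Distinct relation types on the edges touching the node
    relation_types = sorted({data["relation"] for _, _, data in graph.edges(node, data=True)})
    sampled = sorted(rng.sample(relation_types, min(m, len(relation_types))))
    return nearest, [distances[a] for a in nearest], sampled


print(node_hash(toy_graph, "node2", ["anchor1", "anchor2"], k=2, m=2, rng=random.Random(0)))
# ([0, 1], [1, 1], ['r1', 'r3'])
```

The result contains identifiers only; turning them into vectors is the job of the embedding step.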

The authors of the original paper advocate for the application of positional encoding to distances, thereby mapping each distance to a vector of dimension *d*. This approach ensures the maintenance of the embedding's intended dimensionality.

In [5]:
%%latex
\begin{align*}
f_z &= z_u \rightarrow \mathbb{R}^d \\
\end{align*}

<IPython.core.display.Latex object>

For both the anchors and their corresponding distances, an embedding lookup strategy is proposed. Each anchor id is associated with a specific row of the anchor embedding matrix (row 0 for anchor 0, row 1 for anchor 1, up to row $|A| - 1$ for the last anchor), and each distance value is likewise associated with a row of a distance embedding matrix (e.g., distance = 1 corresponds to embedding 1, etc.). This method facilitates an efficient and meaningful encoding of each node's unique signature within the graph.

**Bear in mind that the number of anchors and anchor distances is much smaller than the number of unique node IDs used by traditional methods.**
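An embedding lookup is nothing more than row indexing into a table. The NumPy sketch below uses illustrative sizes that I picked for this example (`max_distance`, `d`, and the hypothetical graph of 10,000 nodes are all assumptions) and compares the number of stored values with a classic one-row-per-node table.

```python
import numpy as np

rng = np.random.default_rng(0)
n_anchors, max_distance, d = 30, 10, 8   # illustrative sizes (assumptions)

anchor_table = rng.normal(size=(n_anchors, d))           # one row per anchor id
distance_table = rng.normal(size=(max_distance + 1, d))  # one row per distance value

anchor_ids = np.array([0, 1])     # k = 2 anchors sampled for one node
distance_ids = np.array([1, 1])   # their shortest-path distances

a_u = anchor_table[anchor_ids]       # shape (k, d)
z_u = distance_table[distance_ids]   # shape (k, d)
print(a_u.shape, z_u.shape)          # (2, 8) (2, 8)

n_nodes = 10_000                     # hypothetical graph size
per_node_values = n_nodes * d
nodepiece_values = (n_anchors + max_distance + 1) * d
print(per_node_values, nodepiece_values)   # 80000 328
```

The relation embeddings add one more small table, whose size depends on the number of relation types rather than on the number of nodes.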

### Embedding of anchors, distances and relations

Turning the sets from equation (1) into vectors through embedding involves several key steps:

1. Converting the anchors of a node, $\{a_u\}^k$, into a matrix $\mathbf{a}_u \in \mathbb{R}^{k\times d}$ by embedding the anchor identifiers into a *d*-dimensional space.
2. Transforming the anchor distances, $\{z_u\}^k$, into a matrix $\mathbf{z}_u \in \mathbb{R}^{k \times d}$, where each distance is treated as an identifier in an embedding lookup, effectively encoding the proximity to the sampled anchors.
3. Mapping the relational context, $\{r_u\}^m$, into a matrix $\mathbf{r}_u \in \mathbb{R}^{m \times d}$, with one row for each sampled relation of the node.

Expressed formally, the comprehensive node hash function can be represented as follows:



In [2]:
%%latex
\begin{align*}
\mathbf{hash}(u) &= \big[\mathbf{a_u}, \mathbf{z}_u,  \mathbf{r}_u \big] \tag{2} \\
\text{where: } \\
\mathbf{a_u} &\in \mathbb{R}^{k \times d} \text{ - k anchor ids sampled for node}\\
\mathbf{z_u} &\in \mathbb{R}^{k \times d} \text{ - k anchor distances }\\
\mathbf{r_u} &\in \mathbb{R}^{m \times d} \text{ - m outgoing relations} \\
\end{align*} 

<IPython.core.display.Latex object>

After that, the authors sum the anchor embeddings with the corresponding distance embeddings, so that the matrix representation of a node is as follows:

In [7]:
%%latex
\begin{align*}
\mathbf{hash}(u) = [\mathbf{a_u} + \mathbf{z_u}, \mathbf{r_u}] = [\mathbf{\hat{a}_u}, \mathbf{r}_u] 
\in \mathbb{R}^{(k+m) \times d} \tag{3}
\end{align*}

<IPython.core.display.Latex object>

### Encoding

The matrix from equation (3) is then turned into a single vector for each node *u* by an encoder, either an MLP (multilayer perceptron) or a Transformer, as explored in the original study.

This encoder operates by converting the matrix into a "flat" vector representation for each node, thereby distilling complex relational and positional information into a condensed form.

In [6]:
%%latex

\begin{align*}
enc: \mathbb{R}^{(k+m) \times d} \rightarrow \mathbb{R}^{d} \tag{4}
\end{align*}

<IPython.core.display.Latex object>
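A minimal sketch of the MLP variant of such an encoder in PyTorch is shown below. The class name `MLPEncoder`, the hidden size, and the layer layout are my assumptions for illustration; they are not taken from the paper or from the accompanying `models` module.

```python
import torch
import torch.nn as nn


class MLPEncoder(nn.Module):
    """Maps a (k + m) x d node hash matrix to a single d-dimensional vector."""

    def __init__(self, k: int, m: int, d: int, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(start_dim=1),          # (batch, k + m, d) -> (batch, (k + m) * d)
            nn.Linear((k + m) * d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),             # back to the embedding size d
        )

    def forward(self, node_hashes: torch.Tensor) -> torch.Tensor:
        return self.net(node_hashes)


k, m, d = 2, 2, 3
encoder = MLPEncoder(k, m, d)
batch = torch.randn(5, k + m, d)   # hash matrices of 5 nodes
print(encoder(batch).shape)        # torch.Size([5, 3])
```

Because the encoder weights are shared by all nodes, its parameter count does not depend on the size of the graph.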

To illustrate this encoding process with a practical example, let's consider the representation for **node 2**:

1. **Anchors are assigned identifiers** starting from zero, hence `id(Anchor 1) = 0` and `id(Anchor 2) = 1`.
2. **Relations receive identifiers similarly**, so `id(r1) = 0`, `id(r2) = 1`, and `id(r3) = 2`.
3. **Distances are indexed in the same manner**, with `id(dist=1) = 0`, `id(dist=2) = 1`, `id(dist=3) = 2`, etc.
4. **With k=m=2**, two anchors and two relations are sampled for each node.

Given these mappings and assuming an embedding dimension d = 3, the set-based representation according to equation (1) will look like this:

```mermaid
flowchart LR
    n1((node1)) -- r1 --- n2((node2))
    n1 -- r2 --- n3
    n2 -- r1 --- n3
    n2 -- r3 --- a1(((Anchor 1)))
    n3((node3)) -- r2 --- n4((node4))
    n4 -- r1 --- a1
    n4 -- r3 --- n5((node5))
    n5 -- r2 --- a2(((Anchor 2)))
    n2 -- r3 --- a2
    a1 -- r1 --- a2
```

```mermaid
flowchart TB
    a1(((Anchor 1))) --Distance=1--> ae[Anchor dist. vector]
    a2(((Anchor 2))) --Distance=1--> ae[Anchor dist. vector]
    r1 --> re[Relation context]
    r3 --> re[Relation context]
    re --Concatenate+embed--> Encoding
    ae --Concatenate+embed--> Encoding
```

In [7]:
%%latex
\begin{align*}
    & \{a_u\}^k = \{0, 1 \} \\
    & \{z_u\}^k =\{1, 1 \} \\
    & \{r_u\}^m = \{0, 2 \} \\[1em]
    & hash(\text{node}_2) = \big[\{0, 1\}^k,  \{1, 1 \}^k, \{0, 2 \}^m\big]
\end{align*}

<IPython.core.display.Latex object>

Subsequently, d-dimensional embeddings are applied to each anchor, distance, and relation, mapping them to a vector of size *d*. For instance, if *d=3*, the resulting matrices for anchors, distances, and relations, aligned with equation (2), would be:

In [8]:
%%latex
\begin{align*}
\mathbf{a}_u &= \begin{bmatrix}
    0 & 0 & 1 \\
    0 & 1 & 0 \end{bmatrix}  &\text{ 2x3 matrix for anchors 0, 1 }\\[1em]
\mathbf{z}_u &= \begin{bmatrix}
    1 & 1 & 1 \\
    1 & 1 & 1 \end{bmatrix}  &\text{ 2x3 matrix for anchor distances 1, 1 }\\[1em]
\mathbf{r}_u &= \begin{bmatrix}
    0 & 0 & 0 \\
    2 & 2 & 2 \end{bmatrix}  &\text{ 2x3 matrix for outgoing relations 0, 2 }\\[1em]
\mathbf{hash}(\text{node}_2) &= \big[\mathbf{a_u}, \mathbf{z}_u,  \mathbf{r}_u \big] \\
\end{align*}

<IPython.core.display.Latex object>

Upon summing the vectors related to anchors and then concatenating them as prescribed in equation (3), we derive:

In [3]:
%%latex
\begin{align*}
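\mathbf{\hat{a}_u} = \mathbf{a_u} + \mathbf{z_u} &= \begin{bmatrix}
    1 & 1 & 2 \\
    1 & 2 & 1 \end{bmatrix}  &\text{ 2x3 matrix of summed anchor and distance embeddings }\\[1em]
\mathbf{hash}(\text{node}_2) = \big[\mathbf{\hat{a}_u}, \mathbf{r}_u \big] &= \begin{bmatrix}
    1 & 1 & 2 \\
    1 & 2 & 1 \\
    0 & 0 & 0 \\
    2 & 2 & 2 \end{bmatrix}  &\text{ 4x3 matrix for node 2 }\\
\end{align*}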

<IPython.core.display.Latex object>
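The same arithmetic can be reproduced in NumPy. The matrices are copied from the worked example for node 2; the variable names are mine.

```python
import numpy as np

a_u = np.array([[0, 0, 1], [0, 1, 0]])   # embeddings of anchors 0 and 1
z_u = np.array([[1, 1, 1], [1, 1, 1]])   # embeddings of distances 1 and 1
r_u = np.array([[0, 0, 0], [2, 2, 2]])   # embeddings of relations 0 and 2

a_hat = a_u + z_u                                  # equation (3): sum anchors and distances
node_hash = np.concatenate([a_hat, r_u], axis=0)   # stack the rows: (k + m) x d

print(a_hat)
print(node_hash.shape)   # (4, 3)
```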

This step shows the encoding's ability to aggregate and simplify the complex array of anchors, distances, and relational data into a **unified, dense representation for each node**, making it readily consumable for downstream graph processing tasks.

### What to do with it?

From this point on, a user can proceed with any downstream task, like classification, clustering, etc.

**The representation of each node is now a flat vector, that can be used like any other features**.

It should capture unique properties of the node, like its position in the graph, its relations, etc.

The paper itself mentions multiple optimizations, tricks and improvements, that can be applied to the basic NodePiece algorithm, that were skipped in this simplified description. I strongly encourage you to read the original paper, if you are interested in all the details.

# Part 4: The (simplified) implementation

The existing implementations of the NodePiece algorithm, as provided by the original paper's authors and within the PyKeen library, are comprehensive yet complex. **These versions are optimized for performance and integration within existing frameworks, which, while beneficial for application development, may obscure the conceptual bridge between the algorithm's theoretical basis and its practical code representation. This complexity can pose challenges for those seeking to understand the algorithm.**

In light of the limited availability of implementations tailored for educational use or for those aiming to gain a more in-depth understanding of the algorithm, I have opted to develop a simplified version of NodePiece. This version is designed with clarity and extensibility in mind, offering a more accessible entry point for individuals looking to grasp the fundamental mechanics of the algorithm without the overhead of optimization and framework-specific considerations.

This streamlined implementation comprises two primary components:

1. **Tokenization Module**: This segment of the code is responsible for selecting anchors and relations, as well as constructing the node representations with identifiers for anchors, relations, and distances. It aligns with the process outlined in equation (1), encapsulating the initial steps of generating a node's unique “fingerprint” through its relationship to anchors and its participation in various relations.
2. **Models Module**: This section is tasked with embedding the identifiers into a vector space and constructing prediction models. It embodies the subsequent phase of the NodePiece algorithm, where the abstract representations of nodes are transformed into dense vector embeddings that can be utilized in downstream machine learning tasks.


## The data

For demonstration purposes, this tutorial will utilize a **small subset of the FB15k-237 dataset**. The rationale behind using a reduced version of the dataset is practicality: training a model on the full dataset locally could be prohibitively time-consuming. The subset is meticulously curated to ensure that the resultant graph remains consistent and connected, thus preserving the structural integrity and relational complexity characteristic of the full dataset. This approach allows for an expedited yet insightful exploration of the NodePiece algorithm, making it feasible to experiment with and extend the simplified implementation on a more manageable scale.

In [4]:
def get_data_subset(data: pyg_data.Data, node_count: float|int = 0.001, k_hops: int = 2, relabel_nodes: bool = True, num_nodes: int|None = None) -> pyg_data.Data:
    """Sample seed nodes at random and return the subgraph within `k_hops` hops of them."""
    # Largest node id in the edge list, converted to a plain Python int
    max_nodes = int(data.edge_index.max())
    # A float is read as a fraction of the nodes, an int as an absolute count
    if isinstance(node_count, float):
        node_count = int(max_nodes * node_count)

    selected_nodes = random.sample(range(max_nodes), node_count)
    subset, edge_index, mapping, edge_mask = pyg_utils.k_hop_subgraph(
        selected_nodes,
        edge_index=data.edge_index,
        num_hops=k_hops,
        relabel_nodes=relabel_nodes,
        num_nodes=num_nodes)
    # Keep only the relation types of the edges that survived the sampling
    subset_data = pyg_data.Data(
        edge_index=edge_index,
        edge_type=data.edge_type[edge_mask])
    return subset_data

In [5]:
path = './data/'
train_data_orig = pyg_datasets.FB15k_237(path, split='train')[0]
val_data_orig = pyg_datasets.FB15k_237(path, split='val')[0]
test_data_orig = pyg_datasets.FB15k_237(path, split='test')[0]

In [6]:
pyg.seed_everything(123)

train_sset = get_data_subset(train_data_orig, node_count=15, k_hops=2, relabel_nodes=False)
train_sset

Data(edge_index=[2, 123816], edge_type=[123816])

In [7]:
train_sset.num_edge_types

237

In [8]:
val_is_in_train = th.isin(val_data_orig.edge_index[0], train_sset.edge_index[0]) | \
    th.isin(val_data_orig.edge_index[0], train_sset.edge_index[1]) | \
    th.isin(val_data_orig.edge_index[1], train_sset.edge_index[0]) | \
    th.isin(val_data_orig.edge_index[1], train_sset.edge_index[1])

test_is_in_train = th.isin(test_data_orig.edge_index[0], train_sset.edge_index[0]) | \
    th.isin(test_data_orig.edge_index[0], train_sset.edge_index[1]) | \
    th.isin(test_data_orig.edge_index[1], train_sset.edge_index[0]) | \
    th.isin(test_data_orig.edge_index[1], train_sset.edge_index[1])

val_data_related_to_train = pyg_data.Data(
    edge_index=val_data_orig.edge_index[:, val_is_in_train],
    edge_type=val_data_orig.edge_type[val_is_in_train])

test_data_related_to_train = pyg_data.Data(
    edge_index=test_data_orig.edge_index[:, test_is_in_train],
    edge_type=test_data_orig.edge_type[test_is_in_train])


In [9]:
pyg.seed_everything(333)
val_sset = get_data_subset(val_data_orig, node_count=8, k_hops=2, relabel_nodes=False)
pyg.seed_everything(333)
test_sset = get_data_subset(test_data_orig, node_count=50, k_hops=2, relabel_nodes=False)

In [10]:
val_sset, test_sset

(Data(edge_index=[2, 402], edge_type=[402]),
 Data(edge_index=[2, 224], edge_type=[224]))

Our dataset consists of the following components:
1. Training dataset - 123,816 triplets.
2. Validation dataset - 402 triplets.
3. Test dataset - 224 triplets.
4. 237 unique relation types.

Now we have smaller subsets of the original datasets. This will speed up computation.

## Tokenization

Now it is time to tokenize our graphs. We will use custom functions with simplified NodePiece logic.

We will select:
1. 30 anchors from the whole dataset.
2. Use 20 closest anchors for each node.
3. Use 10 relations for each node.

In [11]:
recreate = True
if recreate:
    train_features = tok.tokenize_graph(train_data_orig, n_anchors=30, k_nearest_anchors=20, m_relations=10, use_closest=True)
    val_features = tok.tokenize_graph(val_data_orig, n_anchors=30, k_nearest_anchors=20, m_relations=10, use_closest=True)
    test_features = tok.tokenize_graph(test_data_orig, n_anchors=30, k_nearest_anchors=20, m_relations=10, use_closest=True)

    pkl.dump(train_features, open('./train_features.pkl', 'wb'))
    pkl.dump(val_features, open('./val_features.pkl', 'wb'))
    pkl.dump(test_features, open('./test_features.pkl', 'wb'))
else:
    train_features = pkl.load(open('./train_features.pkl', 'rb'))
    val_features = pkl.load(open('./val_features.pkl', 'rb'))
    test_features = pkl.load(open('./test_features.pkl', 'rb'))

In [12]:
train_data = train_sset.to(device)
val_data = val_sset.to(device)
test_data = test_sset.to(device)

In [13]:
if device == 'cuda':
    train_features.anchor_hashes = train_features.anchor_hashes.long().cuda()
    val_features.anchor_hashes = val_features.anchor_hashes.long().cuda()
    test_features.anchor_hashes = test_features.anchor_hashes.long().cuda()

    train_features.anchor_distances = train_features.anchor_distances.long().cuda()
    val_features.anchor_distances = val_features.anchor_distances.long().cuda()
    test_features.anchor_distances = test_features.anchor_distances.long().cuda()

    train_features.rel_hashes = train_features.rel_hashes.long().cuda()
    val_features.rel_hashes = val_features.rel_hashes.long().cuda()
    test_features.rel_hashes = test_features.rel_hashes.long().cuda()

### Anchor selection

Let's start with the anchor selection function.

In [14]:
print(inspect.getsource(tok.degree_anchor_select))

def degree_anchor_select(g: nx.Graph, n_anchors: int|float = 0.1) -> Tuple[List[int], Dict[int, int]]:
    """Anchor selection method, based on the node degree. It is based on the simplest heuristic,
    where the nodes with the highest degree are selected as anchors - as they will most likely be connected
    to the most nodes in the graph.

    Parameters
    ----------
    g : nx.Graph
        Networkx graph.
    n_anchors : int | float, optional
        Number of anchors to select, by default 0.1.
        If int - the number of anchors to select.
        If float - the fraction of the nodes to select as anchors.

    Returns
    -------
    Tuple[List[int], Dict[int, int]]
        1. List of anchor nodes.
        2. Dictionary mapping anchor node to its id. Anchor ids are in the range [0, n_anchors).
    """
    if type(n_anchors) == float:
        n_anchors = int(g.number_of_nodes() * n_anchors)

    degrees = sorted(g.degree, key=lambda x: x[1], reverse=True)
    anchor_2_id = {}

The `degree_anchor_select` function exemplifies a straightforward yet effective approach to this task, leveraging the degree of nodes as a heuristic for anchor selection. **This method presumes that nodes with higher degrees, implying a larger number of connections, serve as optimal anchors due to their likelihood of being connected to a vast portion of the graph**. Here's a step-by-step breakdown of how the function operates:

1. **Input Parameters**: The function accepts a NetworkX graph `g` and an `n_anchors` parameter, which dictates the number of anchors to be selected. The `n_anchors` parameter can be specified as either a float, representing the proportion of the graph's nodes to be designated as anchors (with a default of 0.1, or 10%), or an integer, indicating the exact count of anchors desired.
2. **Degree Calculation and Sorting**: It computes the degree for each node within the graph. The nodes are then sorted based on their degree in descending order, prioritizing nodes with the highest degrees for selection as anchors.
3. **Anchor to ID Mapping**: An `anchor_2_id` dictionary is initialized to map each selected anchor node to a unique identifier, starting from 0. This mapping facilitates the efficient identification and utilization of anchor nodes in subsequent steps of the NodePiece tokenization process.
4. **Return Values**: The function returns a tuple comprising a list of the selected anchor nodes and the `anchor_2_id` dictionary. The list contains the nodes chosen as anchors, while the dictionary maps these anchor nodes to their assigned IDs, so they can be referenced directly when constructing node representations.
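
The heuristic described in the steps above can be sketched in a few self-contained lines. This is an illustrative version, not the notebook's `degree_anchor_select`; the name `select_anchors_by_degree` and the toy star graph are assumptions made for the example.

```python
import networkx as nx


def select_anchors_by_degree(g, n_anchors=0.1):
    """Pick the highest-degree nodes as anchors (illustrative sketch)."""
    # A float is interpreted as a fraction of all nodes.
    if isinstance(n_anchors, float):
        n_anchors = int(g.number_of_nodes() * n_anchors)

    # g.degree yields (node, degree) pairs; sort by degree, highest first.
    by_degree = sorted(g.degree, key=lambda pair: pair[1], reverse=True)
    anchors = [node for node, _ in by_degree[:n_anchors]]

    # Consecutive ids in the range [0, n_anchors).
    anchor_to_id = {node: idx for idx, node in enumerate(anchors)}
    return anchors, anchor_to_id


# Toy graph: node 0 is a hub connected to nodes 1..4.
g = nx.star_graph(4)
anchors, anchor_to_id = select_anchors_by_degree(g, n_anchors=1)
print(anchors, anchor_to_id)
```

On the star graph the hub is the only node with degree above one, so it is the first anchor selected.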

### Building distance to K nearest anchors

Next, we will build a function to calculate the shortest path distances from each node to the K nearest anchor nodes.

In [24]:
print(inspect.getsource(tok.build_distance_to_k_nearest_anchors))

def build_distance_to_k_nearest_anchors(
        G: nx.Graph,
        anchors: List[int],
        anchor2id: dict,
        k_closest_anchors: int = 15,
        use_closest: bool = True) -> Tuple[np.ndarray, np.ndarray, int]:
    """For each node in the graph, calculate the distance to the k closest anchors.

    Parameters
    ----------
    G : nx.Graph
        Networkx graph.
    anchors : List[int]
        List of anchor nodes.
    anchor2id : dict
        Anchor to id mapping.
    k_closest_anchors : int, optional
        Number of k closest anchors to pick per node, by default 15
    use_closest : bool, optional
        Should closest anchors be used, or all? By default True

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, int]
        Tuple of:
        1. Node to anchor distance matrix. Shape: (num_nodes, num_anchors).
        2. Node to anchor id matrix. Shape: (num_nodes, num_anchors).
        3. Maximum distance in the graph. Will be used for distance encoding/embedd

Steps are as follows:

1. **Shortest Path Calculation**: The function iterates over each anchor node, computing the shortest path distances to all other nodes within the graph. This is accomplished using NetworkX's `nx.shortest_path_length` function. The distances are collected in a dictionary `node_distances`, where keys represent nodes and values are the distances to each anchor. Notably, this step can be computationally intensive, particularly for larger graphs. Optimizations in implementations like PyKEEN or the original NodePiece repository aim to enhance efficiency here.
2. **Matrix initialization**: Two numpy arrays, `node2anchor_dist` and `node2anchor_idx`, are initialized to hold the distance from each node to its k closest anchors, and the indices of these anchors, respectively. An `unreachable_anchor_token` is introduced, set to the total number of anchors plus one (`A+1`), to initially populate the `node2anchor_idx` array. This serves as a placeholder for anchors that are not reachable from a given node, analogous to an "out of vocabulary" token in NLP.
3. **Distance and index matrices population**: The function sorts the distances to anchors for each node and selects the closest k anchors. The distances to and indices of these selected anchors are then stored in the respective arrays. During this process, the `max_dist` variable is updated to reflect the maximum observed distance, so that any unreachable anchors can subsequently be marked with a distance exceeding this maximum (`max_dist + 1`, "a magic token").
4. **Handle unreachable anchors**: For nodes from which an anchor is unreachable (there is no path to the anchor), the distance is set to `max_dist + 1`. This is done by finding the entries of `node2anchor_idx` equal to the `unreachable_anchor_token` and setting the corresponding `node2anchor_dist` entries to this incremented maximum distance. This accounts for nodes isolated from certain anchors, which happens when the graph has disconnected components.
5. **Return values**: The function returns a tuple that includes the distance matrix `node2anchor_dist`, the anchor index matrix `node2anchor_idx`, and the `max_dist` value. Together these outputs describe each node's proximity to its nearest anchors, which is the foundation for constructing the NodePiece embedding.
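
The steps above can be sketched as follows. This is a simplified illustration rather than the notebook's `build_distance_to_k_nearest_anchors`: the function name `k_nearest_anchor_features` is hypothetical, nodes are assumed to be integers `0..N-1`, and the unreachable token follows the `A+1` convention from the description.

```python
import networkx as nx
import numpy as np


def k_nearest_anchor_features(g, anchors, anchor_to_id, k=2):
    """Distances to and ids of the k closest anchors per node (sketch)."""
    n = g.number_of_nodes()
    unreachable_token = len(anchors) + 1  # placeholder id, as in the text

    # 1. Shortest paths from every anchor to all nodes it can reach.
    per_node = {node: [] for node in g.nodes}
    for anchor in anchors:
        lengths = nx.shortest_path_length(g, source=anchor)
        for node, dist in lengths.items():
            per_node[node].append((dist, anchor_to_id[anchor]))

    # 2. Matrices pre-filled with the placeholder values.
    dist_mat = np.zeros((n, k), dtype=np.int64)
    idx_mat = np.full((n, k), unreachable_token, dtype=np.int64)

    # 3. Keep the k closest anchors for each node and track the max distance.
    max_dist = 0
    for node, pairs in per_node.items():
        for col, (dist, anchor_id) in enumerate(sorted(pairs)[:k]):
            dist_mat[node, col] = dist
            idx_mat[node, col] = anchor_id
            max_dist = max(max_dist, dist)

    # 4. Unreachable anchors get a distance larger than any observed one.
    dist_mat[idx_mat == unreachable_token] = max_dist + 1
    return dist_mat, idx_mat, max_dist


# Path 0-1-2 plus an isolated node 3; anchors are the two path endpoints.
g = nx.path_graph(3)
g.add_node(3)
dist_mat, idx_mat, max_dist = k_nearest_anchor_features(g, [0, 2], {0: 0, 2: 1}, k=2)
print(dist_mat)
print(idx_mat)
```

The isolated node 3 cannot reach any anchor, so its row keeps the placeholder anchor id and receives the distance `max_dist + 1`.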

### Extraction of relational context

The next step is to extract the relational context for each node.

In [25]:
print(inspect.getsource(tok.sample_rels))

def sample_rels(pyg_g: pyg_data.Data, max_rels: int = 50) -> th.Tensor:
    """Samples m outgoing relations for each node. If the node has less than m relations, it pads the output with a special token.

    Parameters
    ----------
    pyg_g : pyg_data.Data
        PyTorch Geometric graph.
    max_rels : int, optional
        Maximal number of relations to use, by default 50.

    Returns
    -------
    th.Tensor
        Matrix of relations for each node. Shape: (num_nodes, max_rels).
        Each row corresponds to specific node, each column to a relation (id).
    """
    rels_matrix = []
    missing_rel_token = pyg_g.edge_type.max() + 1
    for node in tqdm(range(pyg_g.num_nodes)):
        node_edges = pyg_g.edge_index[0] == node
        node_edge_types = pyg_g.edge_type[node_edges].unique()
        num_edge_types = len(node_edge_types)

        if num_edge_types < max_rels:
            pad = th.ones(max_rels - num_edge_types, dtype=th.long) * missing_rel_token
            padded

1. **Initialization**: The function requires a PyTorch Geometric (PyG) graph object, `pyg_g`, along with an optional `max_rels` parameter, which dictates the maximum number of relations to sample for each node. It defaults to 50. An empty list, `rels_matrix`, is prepared to hold the relational data for all nodes. Additionally, a `missing_rel_token` is defined to represent absent relations for a given node, effectively serving as an `"out of vocabulary"` token. This token is set to one more than the highest relation ID found in the graph.
2. **Unique relations discovery**: For each node in the graph, the function identifies all unique outgoing relation types (edge types) connected to it.
3. **Check number of relations**: The number of unique relations found for the node is compared against the `max_rels` threshold to decide whether the list needs padding or subsampling.
4. **Relation number adjustment**: If a node's number of unique relations is less than the `max_rels` threshold, the list of relations is padded with the `missing_rel_token` to reach the specified maximum. This ensures a uniform relational context length across all nodes. Conversely, if a node is connected through more relation types than `max_rels` allows, a subset is randomly selected to conform to the limit. This random sampling trades some detail for computational efficiency.
5. **Sort and store relations**: The sampled (and potentially padded) relations for each node are sorted to maintain a consistent order. The sorted list is then appended to the `rels_matrix`, gradually building the relational context for the entire graph.
6. **Tensor conversion and return results**: Once the relational context has been extracted for all nodes, the `rels_matrix` list is converted into a PyTorch tensor with shape (num_nodes x max_rels), where each row holds one node's relational context.
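
To make the pad-or-sample logic concrete, here is a minimal NumPy sketch of the same steps. It is not the notebook's `sample_rels` (which works on a PyTorch Geometric graph); the function name `sample_relations`, the plain-array inputs and the `seed` argument are assumptions made for the example.

```python
import numpy as np


def sample_relations(edge_index, edge_type, num_nodes, max_rels=3, seed=0):
    """Fixed-length relational context per node (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    missing_rel_token = int(edge_type.max()) + 1  # "out of vocabulary" id

    rows = []
    for node in range(num_nodes):
        # Unique relation types on edges whose source is this node.
        rels = np.unique(edge_type[edge_index[0] == node])
        if len(rels) < max_rels:
            # Too few relations: pad with the special token.
            pad = np.full(max_rels - len(rels), missing_rel_token)
            rels = np.concatenate([rels, pad])
        elif len(rels) > max_rels:
            # Too many relations: keep a random subset.
            rels = rng.choice(rels, size=max_rels, replace=False)
        rows.append(np.sort(rels))
    return np.stack(rows)


# Row 0: source nodes, row 1: target nodes; one relation id per edge.
edge_index = np.array([[0, 0, 0, 0, 1],
                       [1, 2, 3, 1, 0]])
edge_type = np.array([0, 1, 2, 3, 1])
rel_matrix = sample_relations(edge_index, edge_type, num_nodes=4, max_rels=3)
print(rel_matrix)
```

Node 0 has four distinct outgoing relations and is subsampled to three, node 1 is padded, and nodes 2 and 3 have no outgoing edges, so their rows consist only of the padding token.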

### Feature matrices

After all these operations we are left with three matrices:
1. `anchor_distances` - distances from each node to the K closest anchors. Dim: N nodes x K anchors.
2. `anchor_hashes` - indices of the K closest anchors for each node. Dim: N nodes x K anchors.
3. `rel_hashes` - relational context for each node. Dim: N nodes x M relations.

In [26]:
print(train_features.anchor_distances[:10, :5])
print(train_features.anchor_distances.shape)

tensor([[2, 2, 1, 2, 1],
        [2, 3, 2, 3, 2],
        [1, 2, 2, 2, 2],
        [1, 1, 2, 3, 2],
        [2, 2, 3, 2, 2],
        [1, 1, 2, 2, 1],
        [1, 1, 2, 2, 2],
        [1, 2, 2, 1, 2],
        [2, 3, 2, 2, 2],
        [3, 2, 3, 3, 3]], device='cuda:0')
torch.Size([14541, 30])


In [27]:
print(train_features.anchor_hashes[:10, :5])
print(train_features.anchor_hashes.shape)

tensor([[ 2,  4,  0,  1,  3],
        [ 0,  2,  4,  5,  7],
        [ 0,  5,  1,  2,  3],
        [ 0,  1,  5,  6, 13],
        [ 9,  0,  1,  3,  4],
        [ 0,  1,  4,  5,  7],
        [ 0,  1,  6, 13,  2],
        [ 0,  3,  6, 27,  1],
        [19, 20, 22, 28,  0],
        [19, 20, 22, 28,  1]], device='cuda:0')
torch.Size([14541, 30])


In [15]:
print(train_features.rel_hashes[:10, :5])
print(train_features.rel_hashes.shape)

tensor([[ 15,  23,  69,  72,  73],
        [237, 237, 237, 237, 237],
        [  1,  51, 107, 185, 201],
        [ 17,  25,  43,  48, 232],
        [  2, 237, 237, 237, 237],
        [ 11,  12,  13,  25,  26],
        [  3,   8,  17,  19,  24],
        [  3,   6,   8,  17,  19],
        [  4,   5,  80, 121, 237],
        [130, 178, 237, 237, 237]], device='cuda:0')
torch.Size([14541, 10])


# Model definition

For this exercise, we pivot towards a simpler but effective knowledge graph embedding model: TransE. While the original paper employed RotatE—a somewhat more intricate model—TransE offers a streamlined approach, focusing on the fundamental aspects of embedding relations within a graph.

TransE operates by evaluating a triplet comprising a `head` node, a `relation`, and a `tail` node. **It assigns a score indicating how plausible it is that the specified relation exists between the head and tail nodes**. The model is trained by optimizing a loss function designed to distinguish true triplets from artificially generated (corrupted) ones:

%%latex

\begin{align*}
\mathcal{L}= \sum_{(h,r,t) \in \mathcal{D}} \sum_{(h',r,t') \in \mathcal{D}'} \left[ \gamma + d(h+r,t) - d(h'+r,t') \right]_{+}
\tag{5}
\end{align*}

Where:
- $(h, r, t) \in \mathcal{D}$ are the true triplets
- $(h', r, t') \in \mathcal{D}'$ are the corrupted triplets (negative-sampled, non-existing relations)
- $d$ is the distance function; in the case of TransE it is the L1 or L2 norm of the difference between its two arguments
- $\gamma$ is the margin parameter, responsible for separating positive and negative samples
- $[x]_{+} = \max(0, x)$ keeps only the positive part of its argument
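
As a small worked example of this loss, the sketch below evaluates it with NumPy on hand-picked vectors. The function name `transe_margin_loss` is hypothetical, and for simplicity each true triplet is paired with exactly one corrupted triplet, whereas the formula sums over all pairs.

```python
import numpy as np


def transe_margin_loss(h, r, t, h_neg, t_neg, gamma=1.0, p=1):
    """Margin loss with one corrupted triplet per true triplet (sketch)."""
    # d(h + r, t): norm of the translation error for true triplets.
    pos_dist = np.linalg.norm(h + r - t, ord=p, axis=-1)
    # d(h' + r, t'): the same distance for corrupted triplets.
    neg_dist = np.linalg.norm(h_neg + r - t_neg, ord=p, axis=-1)
    # [gamma + d_pos - d_neg]_+ summed over the batch.
    return np.maximum(0.0, gamma + pos_dist - neg_dist).sum()


h = np.array([[0.0, 0.0]])
r = np.array([[1.0, 0.0]])
t = np.array([[1.0, 0.0]])         # h + r lands exactly on t
far_tail = np.array([[3.0, 0.0]])  # corrupted tail, far from h + r
near_tail = np.array([[1.5, 0.0]]) # corrupted tail, close to h + r

print(transe_margin_loss(h, r, t, h, far_tail))
print(transe_margin_loss(h, r, t, h, near_tail))
```

With the distant corrupted tail the margin is already satisfied and the loss is zero; with the nearby one the negative sample sits inside the margin, so the loss is positive.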


We could write it in pseudocode as follows:


```python
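def score(head, relation, tail):
    head_embed = EMBED(head)
    rel_embed = EMBED(relation)
    tail_embed = EMBED(tail)
    # Distance between the translated head and the tail (L1 or L2 norm)
    distance = NORM(head_embed + rel_embed - tail_embed)
    # Negate so that a smaller distance gives a higher score
    return -1 * distance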
```

The distance is multiplied by `-1` because a smaller distance between `head + relation` and `tail` should translate into a higher score, i.e. a more plausible triplet.

When interacting with NodePiece embeddings, TransE gets interesting when it comes to embedding head and tail nodes.

TransE embedding in this case will perform several steps. For each head or tail node:
1. Take its closest anchor indices and embed them.
2. Take its distances to the closest anchors and embed them.
3. Take its relational context and embed it.
4. According to the equation (3): add anchor ID embedding and distance embedding, concatenate with relational embedding into a single vector.
5. Pass this vector through the encoder (MLP or Transformer) to get the final embedding.

Our implementation will look as follows - this procedure is invoked for each head and tail node:

In [29]:
print(inspect.getsource(models.NodePieceTransE.embed_node))

    def embed_node(self, node: th.Tensor, closest_anchors, anchor_distances, rel_hash):

        # Dim: (N x K) values are anchor ids --> (N x K x D)
        anchor_embed = self.anchor_embed(closest_anchors[node])

        # Dim: (N x K) values are anchor distances --> (N x K x D)
        anchor_distances_embed = self.anchor_distances_embed(anchor_distances[node])

        # Dim: (N x M) values are relation types --> (N x M x D)
        rel_embed = self.rel_emb(rel_hash[node])

        # Dim: (N x K x D)
        combined_anchor_embed = anchor_embed + anchor_distances_embed

        # N x (K + M) x D
        stacked_embed = th.cat([combined_anchor_embed, rel_embed], dim=1)
        N, anchors_plus_rel, hidden_channels = stacked_embed.shape

        # reshape: (N x (K + M) x D) --> (N x (K + M) * D)
        flattened_embed = stacked_embed.view(N, anchors_plus_rel * hidden_channels)

        # N x (K + M) * D --> N x O
        lin_out = self.lin_layer(flattened_embed)

        return lin_out

1. `anchor_embed` is an embedding lookup for the anchor hashes. The weights matrix has dimensionality `((K anchors + 1) x D embedding size)`. It takes an `(N x K)` matrix of anchor hashes and returns an `(N x K x D)` tensor.
2. `anchor_distances_embed` is an embedding lookup for the anchor distances. The weights matrix has dimensionality `((max distance + 1) x D embedding size)`. It takes an `(N x K)` matrix of anchor distances and returns an `(N x K x D)` tensor.
3. `rel_embed` is an embedding lookup for the relation hashes. The weights matrix has dimensionality `((unique relations + 1) x D embedding size)`. It takes an `(N x M)` matrix of relation hashes and returns an `(N x M x D)` tensor.
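
The shape bookkeeping is easier to follow on a toy example. The sketch below mirrors the lookups and the concatenation described above with deliberately small, made-up sizes; a single `nn.Linear` stands in for the encoder, and none of the names are taken from the notebook's `models` module.

```python
import torch as th
from torch import nn

# Toy sizes (assumptions for this example only).
num_anchors, max_distance, num_relations = 6, 4, 5
K, M, D, out_dim = 3, 2, 4, 8
N = 5  # nodes in the batch

# One extra row per table for the special (unreachable / padding) token.
anchor_embed = nn.Embedding(num_anchors + 1, D)
distance_embed = nn.Embedding(max_distance + 1, D)
rel_embed = nn.Embedding(num_relations + 1, D)
encoder = nn.Linear((K + M) * D, out_dim)

# Random stand-ins for the three feature matrices.
anchor_hashes = th.randint(0, num_anchors + 1, (N, K))
anchor_distances = th.randint(0, max_distance + 1, (N, K))
rel_hashes = th.randint(0, num_relations + 1, (N, M))

# (N, K, D): anchor identity plus distance information.
combined = anchor_embed(anchor_hashes) + distance_embed(anchor_distances)
# (N, K + M, D): append the relational context along the token axis.
stacked = th.cat([combined, rel_embed(rel_hashes)], dim=1)
# (N, (K + M) * D) -> (N, out_dim)
node_vectors = encoder(stacked.flatten(start_dim=1))

print(combined.shape, stacked.shape, node_vectors.shape)
```

With the feature shapes printed earlier (K = 30 anchors and M = 10 relations per node) and an embedding size of D = 200, the flattened vector has (30 + 10) * 200 = 8000 entries per node.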

## TransE model training

We will now prepare a TransE model, wrap it in PyTorch Lightning and train it on the FB15k-237 data subset.

### Model training prep

In [16]:
num_nodes = max(train_data.edge_index.max(), test_data.edge_index.max(), val_data.edge_index.max())
num_nodes

tensor(14445, device='cuda:0')

In [17]:
max_distance = max(train_features.max_distance, val_features.max_distance, test_features.max_distance) + 1
max_distance

11

In [78]:
train_loader = pyg_nn.kge.loader.KGTripletLoader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=2048,
)

val_loader = pyg_nn.kge.loader.KGTripletLoader(
    head_index=val_data.edge_index[0, :250],
    rel_type=val_data.edge_type[:250],
    tail_index=val_data.edge_index[1, :250],
    batch_size=2048,
)


In [102]:
th.use_deterministic_algorithms(True)

In [110]:
use_swa = True
swa_lr = 5e-3

pyg.seed_everything(999)

params = models.KGModelParams(
    num_nodes=train_data_orig.num_nodes,
    num_relations=train_features.n_rels+1,
    embedding_dim=200,
    max_distance=max_distance+1,
    hidden_sizes=(400,),
    num_anchors=train_features.n_anchors,
    top_m_relations=train_features.m_relations,
    device=device,
    kg_model_type=models.ModelType.TransE,
    drop_prob=0.2
)
model_pl = models.NodePiecePL(
    params,
    lr=5e-3,
    train_features=train_features, 
    val_features=val_features)

In [111]:
model_pl.model

NodePieceTransE(
  (anchor_embed): Embedding(31, 200)
  (anchor_distances_embed): Embedding(13, 200)
  (lin_layer): Sequential(
    (0): BatchNorm(8000)
    (1): Linear(in_features=8000, out_features=400, bias=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=400, out_features=200, bias=True)
  )
  (rel_emb): Embedding(238, 200)
)

In [112]:
th.set_float32_matmul_precision('high')
early_stop = pl.callbacks.early_stopping.EarlyStopping(monitor='val_mean_rank', patience=3, mode='max', min_delta=0.001)
checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(monitor='val_mean_rank', mode='max', save_top_k=1)
callbacks = [early_stop, checkpoint]
if use_swa:
    swa = pl.callbacks.StochasticWeightAveraging(swa_lrs=swa_lr)
    callbacks.append(swa)

trainer = pl.Trainer(
    accelerator=device,
    max_epochs=100,
    check_val_every_n_epoch=2,
    num_sanity_val_steps=0,
    callbacks=callbacks,
    deterministic=True
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


First, we will evaluate the model on the validation set **before any training** to get the baseline performance.

### Evaluate model before training

In [113]:
nodepiece_val_before_train = trainer.validate(model_pl, dataloaders=val_loader)
nodepiece_val_before_train

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

[{'val_mean_rank_epoch': 0.000877909071277827, 'val_hits_at_k_epoch': 0.0}]

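Hits@k (k = 10) is 0.0 and the `val_mean_rank` metric is close to zero. Since the callbacks above monitor `val_mean_rank` with `mode='max'`, higher values are better, so this is the expected result for a model that has not been trained yet.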
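
For readers unfamiliar with ranking metrics, here is a generic NumPy sketch of how Hits@k and a reciprocal-rank average can be computed from a score matrix. The notebook's own metric code is not shown in this post, so this is not its implementation; `rank_metrics` is a hypothetical helper, and reading `val_mean_rank` as a "higher is better" quantity is an assumption based on `mode='max'` in the callbacks.

```python
import numpy as np


def rank_metrics(scores, true_tail, k=10):
    """Ranking metrics for tail prediction (generic sketch).

    scores:    (num_triplets, num_candidates), higher = more plausible
    true_tail: (num_triplets,) index of the correct tail per triplet
    """
    rows = np.arange(len(true_tail))
    true_scores = scores[rows, true_tail]
    # Rank = 1 + number of candidates scored strictly higher than the truth.
    ranks = 1 + (scores > true_scores[:, None]).sum(axis=1)
    return {
        "mean_reciprocal_rank": float((1.0 / ranks).mean()),
        "hits_at_k": float((ranks <= k).mean()),
    }


scores = np.array([[0.1, 0.9, 0.3],   # correct tail (index 1) ranked 1st
                   [0.8, 0.2, 0.5]])  # correct tail (index 2) ranked 2nd
metrics = rank_metrics(scores, true_tail=np.array([1, 2]), k=1)
print(metrics)
```

Both quantities lie between 0 and 1, and an untrained model that scores candidates more or less at random will sit near the bottom of that range when there are thousands of candidate tails.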

### Training the model

Now let's train the model. This may take some time, depending on your system and hardware.

In [114]:
trainer.fit(model_pl, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type            | Params
------------------------------------------
0 | model | NodePieceTransE | 3.4 M 
------------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.412    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

### Eval model after training

Now we can check the model performance after training. The checkpoint callback was used, so we can pick the best model across all iterations.


We should see a significant improvement in the Hits@k and mean rank metrics.

In [115]:
path = trainer.checkpoint_callback.best_model_path

best_model = models.NodePiecePL.load_from_checkpoint(checkpoint_path=path, model_params=params)

best_model.train_features = train_features
best_model.val_features = val_features
best_model.test_features = test_features

In [116]:
nodepiece_val_after_train = trainer.validate(best_model, dataloaders=val_loader)
nodepiece_val_after_train

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

[{'val_mean_rank_epoch': 0.2862264811992645,
  'val_hits_at_k_epoch': 0.3240000009536743}]

Indeed, the prediction quality improved: on the 250 validation triplets, `val_mean_rank` rose from about 0.0009 to 0.286 and Hits@k from 0.0 to 0.324. Keep in mind that even for such a small subset of data, training can take quite a long time to run.

To give you a sense of scale: the original paper used a much larger dataset, and the training took several hours on a GPU. One of the experiments (**table 10** in the paper) took 7 hours with 400 epochs and 1,000 anchors.

# Conclusion

In this blog post, we've taken a deep dive into the world of NodePiece, a novel approach to graph neural networks that draws inspiration from the tokenization techniques used in Transformers for natural language processing. **Just as Transformers revolutionized text analysis by breaking down text into manageable pieces, NodePiece applies a similar concept to graphs**. It uses a set of basic elements, or "tokens", to represent the various parts of a graph, making it easier to handle complex networks.

We started with an overview of how NodePiece borrows ideas from NLP's tokenization strategies, particularly those used in Transformers, to address the challenges of representing nodes in large and complex graphs. **This approach allows NodePiece to efficiently capture the essence of nodes and their relationships without needing to explicitly identify every single node, which is a significant advantage for tasks like link prediction, node classification, and more**.

The theoretical background of NodePiece was also covered, explaining how it creates a flexible and generalized way of representing nodes by focusing on their relationships and positions within the graph. This simplifies the representation of nodes and enhances the model's ability to learn from and adapt to different graphs.

Finally, we presented a simplified implementation of the NodePiece model, designed with educational purposes in mind. This implementation breaks down the concept into more understandable parts, making it easier for readers to grasp how NodePiece works and how it can be applied to real-world graph neural network tasks.

Hopefully you found it useful and will be able to apply NodePiece tokenization in your own graph projects!

# Bibliography

1. Bordes, A., Usunier, N., Garcia-Duran, A., Weston, J., & Yakhnenko, O. (2013). Translating embeddings for modeling multi-relational data. Advances in Neural Information Processing Systems, 26. https://proceedings.neurips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html
2. Brochier, R., Guille, A., & Velcin, J. (2019). Global Vectors for Node Representations. The World Wide Web Conference, 2587–2593. https://doi.org/10.1145/3308558.3313595
3. Church, K. W. (2017). Word2Vec. Natural Language Engineering, 23(1), 155–162.
4. Galkin, M., Denis, E., Wu, J., & Hamilton, W. L. (2021). NodePiece: Compositional and parameter-efficient representations of large knowledge graphs. arXiv preprint arXiv:2106.12144.
5. Sennrich, R., Haddow, B., & Birch, A. (2016). Neural Machine Translation of Rare Words with Subword Units (arXiv:1508.07909; Version 5). arXiv. https://doi.org/10.48550/arXiv.1508.07909
6. Sun, Z., Deng, Z.-H., Nie, J.-Y., & Tang, J. (2019). RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (arXiv:1902.10197; Version 1). arXiv. https://doi.org/10.48550/arXiv.1902.10197