# Shallow Node Embeddings

Based on the PyTorch Geometric tutorial on shallow node embeddings: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/shallow_node_embeddings.html

In [1]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec

In [3]:
# Load the Planetoid "Cora" dataset, with its files stored under
# /tmp/data/Planetoid, and keep the dataset's first (index 0) entry as `data`.
data = Planetoid(root="/tmp/data/Planetoid", name="Cora")[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [5]:
# Seed PyTorch's global RNG *before* the model is constructed so that the
# initial embedding weights are the same after every Restart & Run All.
# NOTE(review): whether random-walk sampling also draws from this RNG depends
# on the installed sampling backend — confirm if exactly repeatable walks matter.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Node2Vec model built from the graph's edge index and moved to `device`.
model = Node2Vec(
    data.edge_index,
    embedding_dim=128,  # length of each node's embedding vector
    walks_per_node=10,
    walk_length=20,
    context_size=10,  # matches the 10-column walk batches the loader yields
    p=1.0,  # NOTE(review): p = q = 1.0 presumably means unbiased walks — confirm
    q=1.0,
    num_negative_samples=1,
).to(device)

# Adam over all model parameters with a fixed learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [10]:
# Build the model's walk loader and pull a single batch from it to inspect
# what the training loop will consume.
# NOTE(review): loader() runs with its defaults here (no explicit batch_size
# or shuffle) — confirm that is intended for training.
dataloader = model.loader()
pos_rw, neg_rw = next(iter(dataloader))

# Shapes of the positive and the negative walk batch, shown as a tuple.
tuple(walks.shape for walks in (pos_rw, neg_rw))

(torch.Size([110, 10]), torch.Size([110, 10]))

In [11]:
# Show the first row of the positive batch: one walk window of node indices.
pos_rw[0]

tensor([   0,  633,    0, 1862,    0, 1862,    0, 1862,  926, 1862])

In [12]:
def train():
    """Run one full pass over ``dataloader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``dataloader`` and
    ``device``. Every batch is moved to ``device``, its loss is computed with
    ``model.loss`` and one optimizer step is taken.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()

    def _step(pos_batch, neg_batch):
        # One optimisation step on a single (positive, negative) walk batch.
        optimizer.zero_grad()
        batch_loss = model.loss(pos_batch.to(device), neg_batch.to(device))
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    # Losses are collected in loader order and summed left to right.
    batch_losses = [_step(pos_batch, neg_batch) for pos_batch, neg_batch in dataloader]
    return sum(batch_losses) / len(dataloader)

In [15]:
# Train for one pass over the loader and display the mean batch loss.
# Not idempotent: each run of this cell takes further optimizer steps on `model`.
train()

3.2982661218507605

In [17]:
# Calling the model with no argument yields the full node-level embeddings.
# Kept under its own name so it is not immediately overwritten below.
z_all = model()  # Full node-level embeddings.

# Passing node indices selects only those nodes. The index tensor is created
# on `device`, next to the model, in the same way the training loop moves its
# inputs with .to(device); a CPU index against a GPU model may otherwise fail.
z = model(torch.tensor([0, 1, 2], device=device))  # Embeddings of first three nodes.

In [18]:
# Display the embeddings currently held in `z` (the lookup for nodes 0, 1, 2).
z

tensor([[ 1.1862e+00,  3.1759e-01, -1.4581e+00, -1.0847e+00, -5.7915e-01,
         -5.3900e-01, -3.5411e-01,  3.1054e-01, -3.8078e-01, -8.7181e-01,
         -2.9935e-01, -3.4637e-02,  8.7255e-01,  1.9898e-01, -5.6562e-01,
         -5.7564e-01,  8.3648e-02,  4.1237e-01,  7.6757e-01,  1.5138e+00,
          8.4080e-01, -4.7636e-01, -1.7962e-01, -4.4291e-01,  3.4876e-02,
         -4.7311e-01, -1.1404e-01, -7.8320e-01, -9.6307e-01,  1.0453e+00,
          6.1779e-01,  6.3907e-01, -4.5890e-03, -6.8076e-01,  8.8446e-01,
          9.9760e-01, -7.1722e-02, -5.8272e-01,  2.8821e-01,  2.0746e-01,
         -2.0808e+00,  1.6464e+00,  5.5154e-01, -8.1724e-02, -7.8051e-01,
         -3.5968e-01,  2.6516e-01, -7.0609e-01,  2.3366e-01,  7.0895e-01,
         -1.1908e+00,  3.1359e-01,  8.0355e-01,  4.2230e-01, -6.2012e-01,
         -1.2188e-01,  4.7258e-01, -1.1506e+00, -9.2656e-01,  1.7154e+00,
          7.2737e-01, -3.7786e-01, -6.3595e-01,  2.5232e-01,  9.9620e-01,
          1.0171e-01, -2.6627e-01,  3.