In [7]:
import os.path as osp
import numpy as np
import torch
from torch_geometric.datasets import AMiner
from torch_geometric.nn import MetaPath2Vec

In [8]:
# Load the AMiner dataset rooted at ../data/AMiner and keep the graph object
# stored at index 0 for the rest of the notebook.
path = osp.join('..', 'data', 'AMiner')
dataset = AMiner(root=path)
data = dataset[0]

In [9]:
# Summarise the heterogeneous graph: node types with their label tensors and
# node counts, plus the edge_index shape of every relation.
print(data)

HeteroData(
  author={
    y=[246678],
    y_index=[246678],
    num_nodes=1693531
  },
  venue={
    y=[134],
    y_index=[134],
    num_nodes=3883
  },
  paper={ num_nodes=3194405 },
  (paper, written_by, author)={ edge_index=[2, 9323605] },
  (author, writes, paper)={ edge_index=[2, 9323605] },
  (paper, published_in, venue)={ edge_index=[2, 3194405] },
  (venue, publishes, paper)={ edge_index=[2, 3194405] }
)


In [4]:
# edge_index_dict is a plain dict keyed by (source type, relation, target type);
# show the edge index tensor of the paper -> author relation.
print(type(data.edge_index_dict))
print(data.edge_index_dict[('paper', 'written_by', 'author')])

<class 'dict'>
tensor([[      0,       1,       2,  ..., 3194404, 3194404, 3194404],
        [      0,       1,       2,  ...,    4393,   21681,  317436]])


In [6]:
# Shape of the paper -> author edge index: two rows, one column per edge.
data.edge_index_dict[('paper', 'written_by', 'author')].shape

torch.Size([2, 9323605])

In [5]:
# num_nodes_dict maps every node type to its number of nodes.
print(type(data.num_nodes_dict))
print(data.num_nodes_dict)

<class 'dict'>
{'author': 1693531, 'venue': 3883, 'paper': 3194405}


In [6]:
# y_dict maps a node type to its label tensor; show the venue labels
# (presumably class ids, one per labelled venue -- confirm against the dataset docs).
print(type(data.y_dict))
print(data.y_dict["venue"])


<class 'dict'>
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])


In [7]:
# y_index_dict maps a node type to an index tensor with the same length as its
# labels (presumably the node ids the labels in y belong to -- confirm against
# the dataset docs); show the venue entries.
print(type(data.y_index_dict))
print(data.y_index_dict["venue"])

<class 'dict'>
tensor([1741, 2245,  111,  837, 2588, 2116, 2696, 3648, 3784,  313, 3414,  598,
        2995, 2716, 1423,  783, 1902, 3132, 1753, 2748, 2660, 3182,  775, 3339,
        1601, 3589,  156, 1145,  692, 3048,  925, 1587,  820, 1374, 3719,  819,
         492, 3830, 2777, 3001, 3693,  517, 1808, 2353, 3499, 1763, 2372, 1030,
         721, 2680, 3355, 1217, 3400, 1271, 1970, 1127,  407,  353, 1471, 1095,
         477, 3701,   65, 1009, 1899, 1442, 2073, 3143, 2466,  289, 1996, 1070,
        3871, 3695,  281, 3633,   50, 2642, 1925, 1285, 2587, 3814, 3582, 1873,
        1339, 3450,  271, 2966,  453, 2638, 1354, 3211,  391, 1588, 3875, 2216,
        2146, 3765, 2486,  661, 3367,  426,  750, 2158,  519,  230, 1677,  839,
        2945, 1313, 1037, 2879, 2225, 3523, 1247,  448,  227, 3385,  529, 2849,
        1584, 1229,  373, 2235, 1819, 1764, 3155, 2852, 2789, 3474, 1571, 2088,
         208,  462])


In [8]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:

# Metapath followed by the random walks: author -> paper -> venue -> paper -> author.
metapath = [
    ('author', 'writes', 'paper'),
    ('paper', 'published_in', 'venue'),
    ('venue', 'publishes', 'paper'),
    ('paper', 'written_by', 'author'),
]

# Build the MetaPath2Vec embedding model over the graph's relations, then move
# it to the selected device.
model = MetaPath2Vec(
    data.edge_index_dict,
    embedding_dim=32,
    metapath=metapath,
    walk_length=5,
    context_size=3,
    walks_per_node=3,
    num_negative_samples=1,
    sparse=True,
)
model = model.to(device)

In [10]:
# Ask the model for a data loader; each batch it yields is a pair of
# (positive, negative) random-walk tensors, as unpacked in the cells below.
loader = model.loader(batch_size=64, shuffle=True, num_workers=3)

In [11]:
# Peek at the first ten batches to check the shapes of the positive and
# negative random-walk tensors.
for idx, (pos_rw, neg_rw) in enumerate(loader):
    if idx == 10:
        break
    print(idx, pos_rw.shape, neg_rw.shape)

0 torch.Size([768, 3]) torch.Size([768, 3])
1 torch.Size([768, 3]) torch.Size([768, 3])
2 torch.Size([768, 3]) torch.Size([768, 3])
3 torch.Size([768, 3]) torch.Size([768, 3])
4 torch.Size([768, 3]) torch.Size([768, 3])
5 torch.Size([768, 3]) torch.Size([768, 3])
6 torch.Size([768, 3]) torch.Size([768, 3])
7 torch.Size([768, 3]) torch.Size([768, 3])
8 torch.Size([768, 3]) torch.Size([768, 3])
9 torch.Size([768, 3]) torch.Size([768, 3])


In [12]:
# Initialize the optimizer. The model was created with sparse=True, so a
# sparse-aware optimizer (SparseAdam) is used for its parameters.
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

In [14]:
def train(epoch, log_steps=500, eval_steps=1000):
    """Run one pass over ``loader``, printing running loss and accuracy.

    Uses the notebook-level ``model``, ``loader``, ``optimizer`` and
    ``device`` objects.

    Args:
        epoch: Epoch number; only used to label the printed progress lines.
        log_steps: Every ``log_steps`` batches, print the mean loss of the
            batches seen since the previous loss report.
        eval_steps: Every ``eval_steps`` batches, call ``test()`` and print
            the accuracy it returns.
    """
    model.train()

    total_loss = 0
    for i, (pos_rw, neg_rw) in enumerate(loader):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (i + 1) % log_steps == 0:
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Loss: {total_loss / log_steps:.4f}'))
            # Start a fresh accumulation window for the next loss report.
            total_loss = 0

        if (i + 1) % eval_steps == 0:
            acc = test()
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Acc: {acc:.4f}'))
            # test() switches the model to eval mode; switch back so the
            # remaining batches of this epoch are processed in training mode.
            model.train()

@torch.no_grad()
def test(train_ratio=0.1):
    model.eval()

    batch = data.y_index_dict['author'].to(device)
    z = model('author', batch=batch)
    y = data.y_dict['author']

    perm = torch.randperm(z.size(0))
    train_perm = perm[:int(z.size(0) * train_ratio)]
    test_perm = perm[int(z.size(0) * train_ratio):]

    return model.test(z[train_perm], y[train_perm], z[test_perm],
                      y[test_perm], max_iter=150)

# Train for a single epoch (range(1, 2) yields only epoch 1) and report the
# accuracy returned by test() once the epoch has finished.
for epoch in range(1, 2):
    train(epoch)
    acc = test()
    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')


Epoch: 1, Step: 00500/26462, Loss: 3.6145
Epoch: 1, Step: 01000/26462, Loss: 3.3771
Epoch: 1, Step: 01000/26462, Acc: 0.2792
Epoch: 1, Step: 01500/26462, Loss: 3.2212
Epoch: 1, Step: 02000/26462, Loss: 3.1220
Epoch: 1, Step: 02000/26462, Acc: 0.2785
Epoch: 1, Step: 02500/26462, Loss: 3.0549
Epoch: 1, Step: 03000/26462, Loss: 3.0186
Epoch: 1, Step: 03000/26462, Acc: 0.2790
Epoch: 1, Step: 03500/26462, Loss: 2.9998
Epoch: 1, Step: 04000/26462, Loss: 2.9734
Epoch: 1, Step: 04000/26462, Acc: 0.2787
Epoch: 1, Step: 04500/26462, Loss: 2.9631
Epoch: 1, Step: 05000/26462, Loss: 2.9510
Epoch: 1, Step: 05000/26462, Acc: 0.2788
Epoch: 1, Step: 05500/26462, Loss: 2.9447
Epoch: 1, Step: 06000/26462, Loss: 2.9255
Epoch: 1, Step: 06000/26462, Acc: 0.2786
Epoch: 1, Step: 06500/26462, Loss: 2.9268
Epoch: 1, Step: 07000/26462, Loss: 2.9194
Epoch: 1, Step: 07000/26462, Acc: 0.2789
Epoch: 1, Step: 07500/26462, Loss: 2.9076
Epoch: 1, Step: 08000/26462, Loss: 2.8999
Epoch: 1, Step: 08000/26462, Acc: 0.2788
