In [1]:
import pandas as pd

from shared.schema import DatasetSchema, GraphSchema
from shared.graph.loading import pd_from_entity_schema

In [2]:
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)

In [3]:
explicit_label = False
explicit_timestamp = True
unix_timestamp = True
prefix_id = None
include_properties = lambda cs: [c for c in cs if c.startswith('feat_') or c == 'name']

nodes_dfs = {
    label: pd_from_entity_schema(
        entity_schema,
        explicit_label=explicit_label,
        explicit_timestamp=explicit_timestamp,
        include_properties=include_properties,
        unix_timestamp=unix_timestamp,
        prefix_id=prefix_id,
    ).set_index('id').drop(columns=['type']).sort_index()
    for label, entity_schema in schema.nodes.items()
}
node_mappings_dfs = {
    label: pd.Series(range(len(df)), index=df.index, name='nid')
    for label, df in nodes_dfs.items()
}

edges_dfs = {
    label: pd_from_entity_schema(
        entity_schema,
        explicit_label=explicit_label,
        explicit_timestamp=explicit_timestamp,
        include_properties=include_properties,
        unix_timestamp=unix_timestamp,
        prefix_id=prefix_id,
    )
        .reset_index()
        .drop(columns=['type'])
        .drop_duplicates(subset=['src', 'dst', 'timestamp'])
        .join(node_mappings_dfs[entity_schema.source_type], on='src')
        .drop(columns=['src'])
        .rename(columns={'nid': 'src'})
        .join(node_mappings_dfs[entity_schema.target_type], on='dst')
        .drop(columns=['dst'])
        .rename(columns={'nid': 'dst'})
    for label, entity_schema in schema.edges.items()
}

cursor = 0
for df in edges_dfs.values():
    df.index += cursor
    cursor += len(df)

In [4]:
import torch
import numpy as np
from torch_geometric.data import HeteroData, Data

In [7]:
data = HeteroData()
for ntype, ndf in nodes_dfs.items():
    columns = [c for c in ndf.columns if c.startswith('feat_')]
    data[ntype].x = torch.tensor(ndf[columns].values.astype(np.float32))
    if 'timestamp' in ndf.columns:
        data[ntype].timestamp = torch.tensor(ndf['timestamp'].values.astype(np.int32))

for etype, edf in edges_dfs.items():
    columns = [c for c in edf.columns if c.startswith('feat_')]
    edge_schema = schema.edges[etype]
    edge_type = (edge_schema.source_type, edge_schema.get_type(), edge_schema.target_type)
    data[edge_type].edge_attr = torch.tensor(edf[columns].values.astype(np.float32))
    data[edge_type].edge_index = torch.tensor(edf[['src', 'dst']].T.values.astype(np.int64))
    if 'timestamp' in edf.columns:
        data[edge_type].timestamp = torch.tensor(edf['timestamp'].values.astype(np.int32))

In [8]:
list(data.edge_index_dict.keys())

[('Character', 'INTERACTIONS', 'Character'),
 ('Character', 'MENTIONS', 'Character')]

In [9]:
data.edge_index_dict[('Character', 'INTERACTIONS', 'Character')]

tensor([[24,  0, 25,  ..., 96, 37, 91],
        [45, 28, 77,  ..., 79, 74, 29]])

In [10]:
data.edge_index_dict[('Character', 'INTERACTIONS', 'Character')].shape

torch.Size([2, 958])

In [11]:
metapath = [
    ('Character', 'INTERACTIONS', 'Character'),
    ('Character', 'MENTIONS', 'Character')
]

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cpu"

In [31]:
from ml.data import LinkSplitter

transform = LinkSplitter(
    num_val=0.3,
    num_test=0.0,
    edge_types=metapath,
)

In [32]:
train_data, val_data, test_data = transform(data)

In [33]:
print(train_data)

HeteroData(
  [1mCharacter[0m={ x=[113, 32] },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[671, 0],
    edge_index=[2, 671],
    timestamp=[671]
  },
  [1m(Character, MENTIONS, Character)[0m={
    edge_attr=[784, 0],
    edge_index=[2, 784],
    timestamp=[784]
  }
)


In [34]:
print(val_data)

HeteroData(
  [1mCharacter[0m={ x=[113, 32] },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  [1m(Character, MENTIONS, Character)[0m={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)


In [35]:
from torch_geometric.nn import MetaPath2Vec

In [36]:
model = MetaPath2Vec(
    train_data.edge_index_dict,
    embedding_dim=32,
    metapath=metapath,
    walk_length=5,
    context_size=3,
    walks_per_node=5,
    num_negative_samples=5,
    sparse=True
).to(device)

In [37]:
model_val = MetaPath2Vec(
    val_data.edge_index_dict,
    embedding_dim=32,
    metapath=metapath,
    walk_length=5,
    context_size=3,
    walks_per_node=5,
    num_negative_samples=5,
    sparse=True
)

In [38]:
loader = model.loader(batch_size=8, shuffle=True, num_workers=0)

In [39]:
loader_val = model_val.loader(batch_size=8, shuffle=True, num_workers=0)

In [40]:
for idx, (pos_rw, neg_rw) in enumerate(loader):
    print(idx, pos_rw.shape, neg_rw.shape)

0 torch.Size([160, 3]) torch.Size([800, 3])
1 torch.Size([160, 3]) torch.Size([800, 3])
2 torch.Size([160, 3]) torch.Size([800, 3])
3 torch.Size([160, 3]) torch.Size([800, 3])
4 torch.Size([160, 3]) torch.Size([800, 3])
5 torch.Size([160, 3]) torch.Size([800, 3])
6 torch.Size([160, 3]) torch.Size([800, 3])
7 torch.Size([160, 3]) torch.Size([800, 3])
8 torch.Size([160, 3]) torch.Size([800, 3])
9 torch.Size([160, 3]) torch.Size([800, 3])
10 torch.Size([160, 3]) torch.Size([800, 3])
11 torch.Size([160, 3]) torch.Size([800, 3])
12 torch.Size([160, 3]) torch.Size([800, 3])
13 torch.Size([160, 3]) torch.Size([800, 3])
14 torch.Size([20, 3]) torch.Size([100, 3])


In [41]:
print(pos_rw[5], neg_rw[0])

tensor([ 14,  27, 113]) tensor([15, 78, 43])


In [42]:
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.001)

In [43]:
def train(epoch, log_steps=10, eval_steps=1000):
    model.train()

    total_loss = 0
    for i, (pos_rw, neg_rw) in enumerate(loader):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (i + 1) % log_steps == 0:
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Loss: {total_loss / log_steps:.4f}'))
            total_loss = 0

    val_loss = test()

@torch.no_grad()
def test():
    losses = []
    for i, (pos_rw, neg_rw) in enumerate(loader_val):
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        losses.append(loss.item())
    print(f'Val Loss: {np.mean(losses):.4f}')
    return np.mean(losses)


for epoch in range(400):
    train(epoch)
    # acc = test()
    print(f'Epoch: {epoch}')

Epoch: 0, Step: 00010/15, Loss: 4.6709
Val Loss: 4.5841
Epoch: 0
Epoch: 1, Step: 00010/15, Loss: 4.3992
Val Loss: 4.5623
Epoch: 1
Epoch: 2, Step: 00010/15, Loss: 4.4315
Val Loss: 4.5363
Epoch: 2
Epoch: 3, Step: 00010/15, Loss: 4.3543
Val Loss: 4.4628
Epoch: 3
Epoch: 4, Step: 00010/15, Loss: 4.4213
Val Loss: 4.2370
Epoch: 4
Epoch: 5, Step: 00010/15, Loss: 4.3693
Val Loss: 4.2387
Epoch: 5
Epoch: 6, Step: 00010/15, Loss: 4.2699
Val Loss: 4.1733
Epoch: 6
Epoch: 7, Step: 00010/15, Loss: 4.1207
Val Loss: 4.2241
Epoch: 7
Epoch: 8, Step: 00010/15, Loss: 4.0273
Val Loss: 4.0768
Epoch: 8
Epoch: 9, Step: 00010/15, Loss: 4.0696
Val Loss: 4.0095
Epoch: 9
Epoch: 10, Step: 00010/15, Loss: 3.9247
Val Loss: 3.9732
Epoch: 10
Epoch: 11, Step: 00010/15, Loss: 3.9155
Val Loss: 3.9895
Epoch: 11
Epoch: 12, Step: 00010/15, Loss: 3.9311
Val Loss: 3.9387
Epoch: 12
Epoch: 13, Step: 00010/15, Loss: 3.7940
Val Loss: 3.8636
Epoch: 13
Epoch: 14, Step: 00010/15, Loss: 3.7046
Val Loss: 3.8229
Epoch: 14
Epoch: 15, Step