In [1]:
import pandas as pd

from shared.schema import DatasetSchema, GraphSchema
from shared.graph.loading import pd_from_entity_schema

In [2]:
# Load the 'star-wars' dataset description and derive the graph schema from it.
# `schema.nodes` / `schema.edges` are iterated by the cells below to build the
# per-label node and edge tables.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)

In [3]:
# Options forwarded unchanged to `pd_from_entity_schema` for both the node and
# the edge tables.
explicit_label = False
explicit_timestamp = True
unix_timestamp = True
prefix_id = None


def include_properties(cs):
    """Select the property columns to load: every ``feat_*`` column plus ``name``."""
    return [c for c in cs if c.startswith('feat_') or c == 'name']


# Shared keyword arguments so that nodes and edges are always loaded with the
# same settings.
load_kwargs = dict(
    explicit_label=explicit_label,
    explicit_timestamp=explicit_timestamp,
    include_properties=include_properties,
    unix_timestamp=unix_timestamp,
    prefix_id=prefix_id,
)

# One DataFrame per node label, indexed by the original node id and sorted by it
# so that the positional order of the rows is deterministic.
nodes_dfs = {
    label: (
        pd_from_entity_schema(entity_schema, **load_kwargs)
        .set_index('id')
        .drop(columns=['type'])
        .sort_index()
    )
    for label, entity_schema in schema.nodes.items()
}

# Per label: original node id -> dense positional id (`nid`) in 0..len(df)-1.
node_mappings_dfs = {
    label: pd.Series(range(len(df)), index=df.index, name='nid')
    for label, df in nodes_dfs.items()
}

# One DataFrame per edge label. The original `src` / `dst` node ids are replaced
# by the dense `nid` of the corresponding source / target node type.
edges_dfs = {
    label: (
        pd_from_entity_schema(entity_schema, **load_kwargs)
        .reset_index()
        .drop(columns=['type'])
        .drop_duplicates(subset=['src', 'dst', 'timestamp'])
        .join(node_mappings_dfs[entity_schema.source_type], on='src')
        .drop(columns=['src'])
        .rename(columns={'nid': 'src'})
        .join(node_mappings_dfs[entity_schema.target_type], on='dst')
        .drop(columns=['dst'])
        .rename(columns={'nid': 'dst'})
        # `drop_duplicates` keeps the surviving rows' old index labels, which
        # leaves gaps. Re-index densely so that offsetting by `len(df)` below
        # yields ids that are unique across edge types.
        .reset_index(drop=True)
    )
    for label, entity_schema in schema.edges.items()
}

# Shift every edge table's index by the number of edges that precede it, giving
# each edge a notebook-wide unique id.
cursor = 0
for df in edges_dfs.values():
    df.index += cursor
    cursor += len(df)

In [4]:
import torch
import numpy as np
from torch_geometric.data import HeteroData, Data
from torch_geometric.utils import negative_sampling

In [5]:
# Assemble the heterogeneous graph from the per-label node and edge tables.
data = HeteroData()

# Node stores: `feat_*` columns become the feature matrix `x`; a `timestamp`
# column, when present, is attached as a 32-bit integer tensor.
for ntype, ndf in nodes_dfs.items():
    columns = [c for c in ndf.columns if c.startswith('feat_')]
    node_store = data[ntype]
    node_features = ndf[columns].values.astype(np.float32)
    node_store.x = torch.tensor(node_features)
    if 'timestamp' in ndf.columns:
        node_times = ndf['timestamp'].values.astype(np.int32)
        node_store.timestamp = torch.tensor(node_times)

# Edge stores are keyed by (source type, relation type, target type) taken from
# the edge schema. `edge_index` is the 2 x num_edges matrix of (src, dst) ids.
for etype, edf in edges_dfs.items():
    columns = [c for c in edf.columns if c.startswith('feat_')]
    edge_schema = schema.edges[etype]
    edge_type = (edge_schema.source_type, edge_schema.get_type(), edge_schema.target_type)
    edge_store = data[edge_type]
    edge_features = edf[columns].values.astype(np.float32)
    edge_store.edge_attr = torch.tensor(edge_features)
    endpoints = edf[['src', 'dst']].values.T.astype(np.int64)
    edge_store.edge_index = torch.tensor(endpoints)
    if 'timestamp' in edf.columns:
        edge_times = edf['timestamp'].values.astype(np.int32)
        edge_store.timestamp = torch.tensor(edge_times)

In [6]:
# Sanity check: show the node / edge stores and tensor shapes of the assembled graph.
print(data)

HeteroData(
  Character={ x=[113, 32] },
  (Character, INTERACTIONS, Character)={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  (Character, MENTIONS, Character)={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)


In [7]:
from ml.data import LinkSplitter

# Split the edges of every edge type in `data` into train / validation / test views.
# NOTE(review): num_val=0.3 and num_test=0.0 are presumably fractions of the edges
# (30 % validation, no test edges) - confirm against LinkSplitter.
transform = LinkSplitter(
    num_val=0.3,
    num_test=0.0,
    edge_types=data.edge_types,
)

# The transform returns three graph objects; only train_data and val_data are
# used by the cells below.
train_data, val_data, test_data = transform(data)

In [9]:
import torch.nn.functional as F

from torch_geometric.nn import HGTConv, Linear
from torch_geometric.loader import HGTLoader

In [52]:
from typing import List, Any, Tuple
from torch_geometric.loader.base import BaseDataLoader

class EdgeLoader(BaseDataLoader):
    """Mini-batch loader over labelled node pairs.

    The dataset is a list of ``[idx_a, idx_b, label]`` rows. For every
    mini-batch the two endpoint columns are sampled separately through an
    internal :class:`HGTLoader`, so each batch is made of two sampled
    subgraphs (one per endpoint column) plus the label tensor.

    Args:
        data: Graph handed to the internal ``HGTLoader``.
        num_samples: Forwarded to ``HGTLoader``.
        input_nodes: Forwarded to ``HGTLoader``.
        input_edges: Tensor with one ``(idx_a, idx_b, label)`` row per example;
            converted with ``.tolist()`` and used as the dataset.
        **kwargs: Forwarded to the base loader (e.g. ``batch_size``,
            ``shuffle``, ``num_workers``).
    """

    def __init__(
            self,
            data: HeteroData,
            num_samples,
            input_nodes,
            input_edges,
            **kwargs
    ):
        # The inner loader is never iterated; only its `sample` and
        # `transform_fn` methods are used.
        self.hgt_loader = HGTLoader(data, num_samples, input_nodes)

        super().__init__(
            input_edges.tolist(),
            collate_fn=self.sample,
            **kwargs,
        )

    def sample(self, indices: List[Tuple[int, int, int]]):
        """Collate one mini-batch of ``(idx_a, idx_b, label)`` rows.

        Returns:
            A 3-tuple: the sampler output for the ``idx_a`` column, the sampler
            output for the ``idx_b`` column, and an ``int64`` label tensor.
        """
        # Transpose the rows into three columns. Plain lists are passed on
        # because `HGTLoader.sample` is typed for a list of indices.
        idx_a, idx_b, labels = (list(column) for column in zip(*indices))
        return (
            self.hgt_loader.sample(idx_a),
            self.hgt_loader.sample(idx_b),
            torch.tensor(labels, dtype=torch.int64),
        )

    def transform_fn(self, out: Any) -> Tuple[HeteroData, HeteroData, torch.Tensor]:
        """Apply the inner loader's ``transform_fn`` to both sampler outputs.

        ``out`` is the 3-tuple produced by :meth:`sample`; the labels are
        passed through unchanged.
        """
        out_a, out_b, labels = out
        return self.hgt_loader.transform_fn(out_a), self.hgt_loader.transform_fn(out_b), labels


In [91]:
# Positive training examples: every edge of the training split, per edge type.
train_data_pos_edge_index = {
    edge_type: train_data[edge_type].edge_index
    for edge_type in train_data.edge_types
}

# One random non-edge per positive edge. `num_nodes` is passed explicitly;
# otherwise it is inferred from the largest node id that occurs in the training
# edges, and nodes above that id can never be drawn as negatives.
# Assumes source and target node type are identical for every edge type
# (`edge_type[0]` is the source node type).
train_data_neg_edge_index = {
    edge_type: negative_sampling(
        train_data[edge_type].edge_index,
        num_nodes=data[edge_type[0]].num_nodes,
        num_neg_samples=train_data_pos_edge_index[edge_type].shape[1],
    )
    for edge_type in train_data.edge_types
}

# Append a label row (1 = positive, 0 = negative) to each 2 x E index matrix and
# concatenate positives and negatives, giving a 3 x (E_pos + E_neg) tensor.
train_data_edge_index = {
    edge_type: torch.cat([
        torch.cat([train_data_pos_edge_index[edge_type], torch.ones(1, train_data_pos_edge_index[edge_type].shape[1], dtype=torch.long)], dim=0),
        torch.cat([train_data_neg_edge_index[edge_type], torch.zeros(1, train_data_neg_edge_index[edge_type].shape[1], dtype=torch.long)], dim=0),
    ], dim=1)
    for edge_type in train_data.edge_types
}

# NOTE(review): neighbourhoods are sampled from the full `data` graph, not from
# `train_data`. If LinkSplitter removes held-out edges from `train_data`, message
# passing here can still see them - confirm this is intended.
train_loader = EdgeLoader(
    data,
    num_samples=[4] * 2,
    shuffle=True,
    input_nodes=('Character', torch.arange(data['Character'].num_nodes)),
    # All edge types are pooled; rows are (src, dst, label).
    input_edges=torch.cat([
        *train_data_edge_index.values()
    ], dim=1).t(),
    batch_size=8,
    num_workers=4,
)

# Preview one training batch: (left subgraph, right subgraph, labels).
next(iter(train_loader))

(HeteroData(
   Character={
     x=[16, 32],
     batch_size=8
   },
   (Character, INTERACTIONS, Character)={
     edge_attr=[182, 0],
     edge_index=[2, 182],
     timestamp=[182]
   },
   (Character, MENTIONS, Character)={
     edge_attr=[192, 0],
     edge_index=[2, 192],
     timestamp=[192]
   }
 ),
 HeteroData(
   Character={
     x=[16, 32],
     batch_size=8
   },
   (Character, INTERACTIONS, Character)={
     edge_attr=[123, 0],
     edge_index=[2, 123],
     timestamp=[123]
   },
   (Character, MENTIONS, Character)={
     edge_attr=[143, 0],
     edge_index=[2, 143],
     timestamp=[143]
   }
 ),
 tensor([0, 1, 1, 1, 1, 1, 1, 0]))

In [92]:
# Positive validation examples: edges of the validation view whose partition id
# is non-zero.
# NOTE(review): presumably partition 0 marks the training edges - confirm
# against LinkSplitter.
val_data_pos_edge_index = {
    edge_type: val_data[edge_type].edge_index[:, val_data[edge_type].edge_partitions != 0]
    for edge_type in val_data.edge_types
}

# One random non-edge per positive edge, sampled against all edges of the
# validation view. `num_nodes` is passed explicitly so that it is not inferred
# from the largest node id present in the edges.
# Assumes source and target node type are identical for every edge type
# (`edge_type[0]` is the source node type).
val_data_neg_edge_index = {
    edge_type: negative_sampling(
        val_data[edge_type].edge_index,
        num_nodes=data[edge_type[0]].num_nodes,
        num_neg_samples=val_data_pos_edge_index[edge_type].shape[1],
    )
    for edge_type in val_data.edge_types
}

# Append a label row (1 = positive, 0 = negative) and concatenate positives and
# negatives. Iterates over `val_data.edge_types` (not `train_data.edge_types`)
# so that this cell only depends on the validation view.
val_data_edge_index = {
    edge_type: torch.cat([
        torch.cat([val_data_pos_edge_index[edge_type], torch.ones(1, val_data_pos_edge_index[edge_type].shape[1], dtype=torch.long)], dim=0),
        torch.cat([val_data_neg_edge_index[edge_type], torch.zeros(1, val_data_neg_edge_index[edge_type].shape[1], dtype=torch.long)], dim=0),
    ], dim=1)
    for edge_type in val_data.edge_types
}

val_loader = EdgeLoader(
    data,
    num_samples=[4] * 2,
    shuffle=True,
    input_nodes=('Character', torch.arange(data['Character'].num_nodes)),
    # All edge types are pooled; rows are (src, dst, label).
    input_edges=torch.cat([
        *val_data_edge_index.values()
    ], dim=1).t(),
    batch_size=8,
    num_workers=4,
)

# Preview one validation batch (the original cell inspected `train_loader` here).
next(iter(val_loader))

(HeteroData(
   Character={
     x=[16, 32],
     batch_size=8
   },
   (Character, INTERACTIONS, Character)={
     edge_attr=[109, 0],
     edge_index=[2, 109],
     timestamp=[109]
   },
   (Character, MENTIONS, Character)={
     edge_attr=[76, 0],
     edge_index=[2, 76],
     timestamp=[76]
   }
 ),
 HeteroData(
   Character={
     x=[16, 32],
     batch_size=8
   },
   (Character, INTERACTIONS, Character)={
     edge_attr=[110, 0],
     edge_index=[2, 110],
     timestamp=[110]
   },
   (Character, MENTIONS, Character)={
     edge_attr=[86, 0],
     edge_index=[2, 86],
     timestamp=[86]
   }
 ),
 tensor([1, 1, 1, 1, 0, 0, 1, 0]))

In [93]:
from tqdm import tqdm

In [94]:
class HGT(torch.nn.Module):
    """Pairwise classifier built on Heterogeneous Graph Transformer layers.

    Two sampled subgraphs (one per endpoint of a candidate pair) are embedded
    with the same weights; the embeddings of the target node type are
    concatenated row-wise and classified by a linear layer.

    Args:
        hidden_channels: Width of the per-node-type input projection and of
            every ``HGTConv`` layer.
        out_channels: Number of output logits per pair.
        num_heads: Number of attention heads passed to ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
        metadata: ``(node_types, edge_types)`` tuple describing the graph.
            Defaults to ``data.metadata()`` of the notebook-level ``data``.
        target_node_type: Node type whose embeddings are returned by
            :meth:`forward_embed`. Defaults to ``'Character'``.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers,
                 metadata=None, target_node_type='Character'):
        super().__init__()

        # Fall back to the notebook-level graph so existing calls keep working.
        if metadata is None:
            metadata = data.metadata()
        node_types = metadata[0]
        self.target_node_type = target_node_type

        # Lazy (-1) input size: resolved on the first forward pass.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            self.convs.append(conv)

        # Takes the concatenation of the left and right embeddings.
        self.lin = Linear(hidden_channels * 2, out_channels)

    def forward_embed(self, x_dict, edge_index_dict):
        """Embed one subgraph and return the target node type's embeddings.

        The caller's ``x_dict`` is left untouched; the projected features are
        collected in a new dictionary.
        """
        h_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)

        return h_dict[self.target_node_type]

    def forward(self, larg, rarg):
        """Classify pairs from two ``(x_dict, edge_index_dict)`` tuples.

        Row ``i`` of the result holds the logits for the pair formed by row
        ``i`` of the left and row ``i`` of the right embedding matrix.
        """
        lemb = self.forward_embed(*larg)
        remb = self.forward_embed(*rarg)
        emb = torch.cat([lemb, remb], dim=-1)

        return self.lin(emb)

In [101]:
# Pick the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the loaders build labels from ones / zeros only, yet the model
# emits 4 logits - confirm whether out_channels=4 is intentional.
model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)

# Move both the full graph and the model to the selected device.
data = data.to(device)
model = model.to(device)
print(model)

HGT(
  (lin_dict): ModuleDict(
    (Character): Linear(-1, 64, bias=True)
  )
  (convs): ModuleList(
    (0): HGTConv(64, heads=2)
  )
  (lin): Linear(128, 4, bias=True)
)


In [102]:
@torch.no_grad()
def init_params():
    """Run one training batch through the model so lazy parameters get initialized.

    The output of the forward pass is discarded; gradients are disabled.
    """
    left, right, _ = next(iter(train_loader))
    left, right = left.to(device), right.to(device)
    model(
        (left.x_dict, left.edge_index_dict),
        (right.x_dict, right.edge_index_dict),
    )

In [103]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Returns:
        The cross-entropy loss averaged over all examples, each batch being
        weighted by its ``batch_size``.
    """
    model.train()

    seen = 0
    loss_sum = 0
    for left, right, target in tqdm(train_loader):
        optimizer.zero_grad()

        # Read the number of examples before moving the batches to the device.
        n_examples = left['Character'].batch_size
        left, right, target = left.to(device), right.to(device), target.to(device)

        logits = model(
            (left.x_dict, left.edge_index_dict),
            (right.x_dict, right.edge_index_dict),
        )
        # Only the first `n_examples` rows are scored against the labels.
        batch_loss = F.cross_entropy(logits[:n_examples], target)
        batch_loss.backward()
        optimizer.step()

        seen += n_examples
        loss_sum += float(batch_loss) * n_examples

    return loss_sum / seen


In [104]:
@torch.no_grad()
def test(loader):
    """Compute the classification accuracy of ``model`` over ``loader``.

    Args:
        loader: Iterable yielding ``(left batch, right batch, labels)`` triples.

    Returns:
        Fraction of examples whose arg-max prediction equals the label.
    """
    model.eval()

    seen = 0
    hits = 0
    for left, right, target in tqdm(loader):
        n_examples = left['Character'].batch_size
        left, right, target = left.to(device), right.to(device), target.to(device)

        logits = model(
            (left.x_dict, left.edge_index_dict),
            (right.x_dict, right.edge_index_dict),
        )
        # Only the first `n_examples` rows are compared with the labels.
        predicted = logits[:n_examples].argmax(dim=-1)

        seen += n_examples
        hits += int((predicted == target).sum())

    return hits / seen

In [105]:
# Initialize the model's lazy parameters with one forward pass; this cell runs
# before the optimizer is created from `model.parameters()`.
init_params()

In [106]:
# `train()` reads this notebook-level optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Baseline: validation accuracy of the untrained model.
val_acc = test(val_loader)
print(f'Epoch: untrained, Val Acc: {val_acc:.4f}')

# 20 epochs; each epoch is one pass over train_loader followed by a validation pass.
for epoch in range(1, 21):
    loss = train()
    val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')

100%|██████████| 156/156 [00:02<00:00, 60.67it/s]


Epoch: untrained, Val Acc: 0.3716


100%|██████████| 364/364 [00:12<00:00, 29.95it/s]
100%|██████████| 156/156 [00:02<00:00, 57.92it/s]


Epoch: 01, Loss: 0.6664, Val Acc: 0.7255


100%|██████████| 364/364 [00:12<00:00, 28.47it/s]
100%|██████████| 156/156 [00:02<00:00, 55.72it/s]


Epoch: 02, Loss: 0.5569, Val Acc: 0.7552


100%|██████████| 364/364 [00:11<00:00, 31.80it/s]
100%|██████████| 156/156 [00:02<00:00, 60.68it/s]


Epoch: 03, Loss: 0.5359, Val Acc: 0.7319


100%|██████████| 364/364 [00:12<00:00, 28.27it/s]
100%|██████████| 156/156 [00:02<00:00, 58.86it/s]


Epoch: 04, Loss: 0.5268, Val Acc: 0.7705


100%|██████████| 364/364 [00:12<00:00, 29.40it/s]
100%|██████████| 156/156 [00:02<00:00, 58.43it/s]


Epoch: 05, Loss: 0.5103, Val Acc: 0.7697


100%|██████████| 364/364 [00:11<00:00, 31.20it/s]
100%|██████████| 156/156 [00:02<00:00, 56.74it/s]


Epoch: 06, Loss: 0.5095, Val Acc: 0.7841


100%|██████████| 364/364 [00:12<00:00, 30.10it/s]
100%|██████████| 156/156 [00:02<00:00, 59.98it/s]


Epoch: 07, Loss: 0.5037, Val Acc: 0.7640


100%|██████████| 364/364 [00:12<00:00, 29.15it/s]
100%|██████████| 156/156 [00:02<00:00, 61.15it/s]


Epoch: 08, Loss: 0.4938, Val Acc: 0.7753


100%|██████████| 364/364 [00:11<00:00, 31.52it/s]
100%|██████████| 156/156 [00:02<00:00, 62.25it/s]


Epoch: 09, Loss: 0.5005, Val Acc: 0.7721


100%|██████████| 364/364 [00:12<00:00, 28.57it/s]
100%|██████████| 156/156 [00:02<00:00, 60.78it/s]


Epoch: 10, Loss: 0.4940, Val Acc: 0.7785


100%|██████████| 364/364 [00:11<00:00, 30.49it/s]
100%|██████████| 156/156 [00:02<00:00, 61.34it/s]


Epoch: 11, Loss: 0.4896, Val Acc: 0.7873


100%|██████████| 364/364 [00:11<00:00, 30.55it/s]
100%|██████████| 156/156 [00:02<00:00, 58.88it/s]


Epoch: 12, Loss: 0.4869, Val Acc: 0.7817


100%|██████████| 364/364 [00:11<00:00, 30.48it/s]
100%|██████████| 156/156 [00:02<00:00, 61.33it/s]


Epoch: 13, Loss: 0.4821, Val Acc: 0.7849


100%|██████████| 364/364 [00:12<00:00, 29.19it/s]
100%|██████████| 156/156 [00:02<00:00, 59.36it/s]


Epoch: 14, Loss: 0.4855, Val Acc: 0.7793


100%|██████████| 364/364 [00:13<00:00, 27.71it/s]
100%|██████████| 156/156 [00:02<00:00, 62.13it/s]


Epoch: 15, Loss: 0.4814, Val Acc: 0.7857


100%|██████████| 364/364 [00:13<00:00, 27.54it/s]
100%|██████████| 156/156 [00:02<00:00, 55.24it/s]


Epoch: 16, Loss: 0.4821, Val Acc: 0.7777


100%|██████████| 364/364 [00:12<00:00, 28.18it/s]
100%|██████████| 156/156 [00:02<00:00, 59.03it/s]


Epoch: 17, Loss: 0.4793, Val Acc: 0.7697


100%|██████████| 364/364 [00:13<00:00, 26.70it/s]
100%|██████████| 156/156 [00:02<00:00, 60.03it/s]


Epoch: 18, Loss: 0.4776, Val Acc: 0.7817


100%|██████████| 364/364 [00:12<00:00, 28.78it/s]
100%|██████████| 156/156 [00:02<00:00, 62.03it/s]


Epoch: 19, Loss: 0.4735, Val Acc: 0.7785


100%|██████████| 364/364 [00:12<00:00, 29.32it/s]
100%|██████████| 156/156 [00:02<00:00, 56.71it/s]

Epoch: 20, Loss: 0.4777, Val Acc: 0.7793



