In [1]:
from bidict import bidict
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
import itertools
import pandas as pd
import numpy as np
from franz.openrdf.connect import ag_connect
from franz.openrdf.query.query import QueryLanguage
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.io import read_txt_array
from torch_geometric.nn.models.re_net import RENet
from torch_geometric.datasets.icews import EventDataset
from tqdm import tqdm
import shutil
from pathlib import Path

# Every on-disk artefact of this notebook (the dataset's raw/ and processed/
# folders and the saved model) lives under ./data; create it if missing.
datapath = Path("data")
datapath.mkdir(exist_ok=True)

@dataclass
class _Resource(object):
    idx: int
    uri: str   
    def __hash__(self): return hash(self.uri)    
    def __repr__(self): return self.uri
        
class Entity(_Resource): pass
class Relation(_Resource): pass

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device (for PyTorch):", device)

Device (for PyTorch): cpu


## Construct a Dataframe

We will construct a dataframe where each row is represented as `(subject, predicate, object, timestamp)`.

The dataframe will be sorted by `timestamp`, and each timestamp is then converted to the number of whole days elapsed since the earliest event.

In [2]:
%%time
q = """SELECT DISTINCT ?s ?s_id ?p ?p_id ?o ?o_id ?t {
  ?event a <http://franz.com/Relation> ;
           <http://franz.com/date> ?t ;
           <http://franz.com/from> ?s ;
           <http://franz.com/predicate> ?p ;
           <http://franz.com/ID> ?p_id ;
           <http://franz.com/to> ?o .
  ?s <http://franz.com/ID> ?s_id .
  ?o <http://franz.com/ID> ?o_id .
}
ORDER BY ?t"""

data = defaultdict(list)
with ag_connect('Events2018', host="localhost", port=10000, catalog="demos") as conn:
    with conn.prepareTupleQuery(QueryLanguage.SPARQL, q).evaluate() as res:
        for bs in tqdm(res):
            data["s"].append(Entity(idx=bs.getValue("s_id").intValue(), uri=bs.getValue("s").value))
            data["p"].append(Relation(idx=bs.getValue("p_id").intValue(), uri=bs.getValue("p").value))
            data["o"].append(Entity(idx=bs.getValue("o_id").intValue(), uri=bs.getValue("o").value))
            data["t"].append(bs.getValue("t").toPython())
df = pd.DataFrame(data=data)
start_date = df["t"][0]
df.t = df.t.apply(lambda d: (d-start_date).days)
df

100%|██████████| 468558/468558 [01:25<00:00, 5465.97it/s]


CPU times: user 1min 38s, sys: 738 ms, total: 1min 39s
Wall time: 2min 6s


Unnamed: 0,s,p,o,t
0,http://franz.com/examples/Opposition%20Support...,http://franz.com/examples/Property#Accuse,http://franz.com/examples/Citizen%20(Russia),0
1,http://franz.com/examples/Government%20(Ukraine),http://franz.com/examples/Property#Make%20stat...,http://franz.com/examples/Military%20(Ukraine),0
2,http://franz.com/examples/Armed%20Rebel%20(Ukr...,http://franz.com/examples/Property#Accuse,http://franz.com/examples/Military%20(Ukraine),0
3,http://franz.com/examples/Militia%20(Ukraine),http://franz.com/examples/Property#Make%20stat...,http://franz.com/examples/Military%20(Ukraine),0
4,http://franz.com/examples/Philippine%20Nationa...,http://franz.com/examples/Property#Rally%20opp...,http://franz.com/examples/Military%20Personnel...,0
...,...,...,...,...
468553,http://franz.com/examples/Opposition%20Support...,http://franz.com/examples/Property#Accuse,http://franz.com/examples/Lawmaker%20(India),303
468554,http://franz.com/examples/Ant%C3%B3nio%20Manue...,http://franz.com/examples/Property#Make%20stat...,http://franz.com/examples/Foreign%20Affairs%20...,303
468555,http://franz.com/examples/Staffan%20de%20Mistura,http://franz.com/examples/Property#Make%20stat...,http://franz.com/examples/Foreign%20Affairs%20...,303
468556,http://franz.com/examples/Police%20(India),"http://franz.com/examples/Property#Arrest,%20d...",http://franz.com/examples/Saddam%20Hussein,303


## Construct entity2id and relation2id

To reference all entities and relations by integer ids, we build bidirectional dictionaries (`bidict`) that map each entity and each relation to the ID stored for it in the graph.

In [3]:
# Map every entity / relation object to the integer ID it carries. Entities
# occur both as subjects and as objects, so both columns are scanned; repeated
# occurrences of the same resource collapse onto a single key.
entity2id = bidict({
    ent: ent.idx
    for ent in tqdm(itertools.chain(df.s, df.o), desc="processing entity2id")
})
relation2id = bidict({
    rel: rel.idx
    for rel in tqdm(df.p, desc="processing relation2id")
})

print(f"{len(entity2id)} Entities, {len(relation2id)} Relations, {len(df)} Events in total")

processing entity2id: 937116it [00:04, 211297.96it/s]
processing relation2id: 100%|██████████| 468558/468558 [00:02<00:00, 218576.14it/s]

23033 Entities, 256 Relations, 468558 Events in total





## Splitting into training, validation and test datasets

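* train - the first 80% of events
* test - the last 10% of events
* validation - the 10% of events between train and test

Because the dataframe is sorted by timestamp, this is a chronological split: the model is evaluated on events that occur after the ones it was trained on.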

In [4]:
# Chronological 80/10/10 split. `df` is sorted by timestamp, so positional
# slicing gives train = earliest 80%, valid = next 10%, test = latest 10%.
# `.iloc` is used instead of `np.split(df, ...)`, which relies on the deprecated
# `DataFrame.swapaxes` and stops working on newer pandas.
train_end, valid_end = int(.8 * len(df)), int(.9 * len(df))
train = df.iloc[:train_end]
valid = df.iloc[train_end:valid_end]
test = df.iloc[valid_end:]
assert len(train) + len(valid) + len(test) == len(df)
len(train), len(valid), len(test)

(374846, 46856, 46856)

## Customising our Events2018 dataset class

By using torch_geometric's [Dataset API](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Dataset), we tell torch_geometric:

1. How to prepare raw data (see `download` method)
2. How to process raw data (see `process` method)

In [5]:
class Events2018(EventDataset):
    """Events2018 temporal knowledge graph as a torch_geometric ``EventDataset``.

    ``download`` writes the ``train``/``valid``/``test`` DataFrames from the
    split cell to tab-separated text files (subject id, relation id, object id,
    day offset) under ``<root>/raw``. ``process`` converts all events through
    ``EventDataset.process`` and saves one collated file per split under
    ``<root>/processed``.

    Args:
        root: dataset root directory.
        split: processed split to load, one of ``'train'``, ``'val'``, ``'test'``.
        transform, pre_transform, pre_filter: forwarded to ``EventDataset``.
    """

    # Snapshot the split DataFrames when the class is defined. The training cell
    # later rebinds the global names ``train`` and ``test`` to functions, so a
    # lazy global lookup inside ``download`` would pick those up instead if the
    # dataset were rebuilt after training.
    _split_frames = {"train.txt": train, "valid.txt": valid, "test.txt": test}

    # Cumulative event counts [0, n_train, n_train + n_valid, n_total]; used to
    # cut the concatenated events back into train/val/test in ``process``.
    splits = [0, len(train), len(train)+len(valid), len(train)+len(valid)+len(test)]  # Train/Val/Test splits.

    def __init__(self, root, split='train', transform=None, pre_transform=None, pre_filter=None):
        assert split in ['train', 'val', 'test']
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load only the requested split's collated (data, slices) pair.
        idx = self.processed_file_names.index('{}.pt'.format(split))
        self.data, self.slices = torch.load(self.processed_paths[idx])

    @property
    def num_nodes(self):
        # Entity vocabulary size. NOTE(review): assumes the graph's entity IDs are
        # exactly 0..len(entity2id)-1 — confirm against the data.
        return len(entity2id)

    @property
    def num_rels(self):
        # Relation vocabulary size; same contiguous 0-based ID assumption as num_nodes.
        return len(relation2id)

    @property
    def raw_file_names(self):
        return ['{}.txt'.format(name) for name in ['train', 'valid', 'test']]

    @property
    def processed_file_names(self):
        return ['train.pt', 'val.pt', 'test.pt']

    def download(self):
        """Write each split DataFrame to its raw text file in ``self.raw_dir``.

        One event per line, tab-separated: ``entity2id[s]``, ``relation2id[p]``,
        ``entity2id[o]``, ``t``. Paths come from ``self.raw_paths`` so the files
        land under this dataset's own ``root`` rather than a hard-coded folder.
        """
        for filename, path in zip(self.raw_file_names, self.raw_paths):
            rows = self._split_frames[filename].itertuples()
            with open(path, "w") as fd:
                # File objects buffer their writes already, so stream the lines.
                fd.writelines(
                    f"{entity2id[row.s]}\t{relation2id[row.p]}\t{entity2id[row.o]}\t{row.t}\n"
                    for row in tqdm(rows, desc=f"writing {filename}")
                )

    def process_events(self):
        """Read the raw train/valid/test files and stack them into one LongTensor.

        Raises:
            RuntimeError: if the number of raw events differs from ``splits[-1]``,
                i.e. the raw files were written for different split DataFrames.
        """
        events = torch.cat(
            [read_txt_array(path, sep='\t', end=4, dtype=torch.long) for path in self.raw_paths],
            dim=0,
        )
        # Stale raw files from an earlier run would silently misalign the
        # train/val/test slices taken in `process`; refuse them loudly instead.
        if events.size(0) != self.splits[-1]:
            raise RuntimeError(
                f"{self.raw_dir} holds {events.size(0)} events but the current splits "
                f"expect {self.splits[-1]}; delete the stale raw files and rebuild.")
        return events

    def process(self):
        """Build per-event ``Data`` objects and save one collated file per split."""
        s = self.splits
        # NOTE(review): slicing by `splits` assumes the parent returns exactly one
        # item per raw event, in file order; a `pre_filter` that drops events
        # would misalign the splits — confirm before passing one.
        data_list = super().process()
        for i, path in enumerate(self.processed_paths):
            torch.save(self.collate(data_list[s[i]:s[i + 1]]), path)

## Create data loaders for the training and test datasets

In [6]:
%%time
seq_len = 10 # how many historical events to look back
if datapath.joinpath("processed").exists(): shutil.rmtree(datapath.joinpath("processed"))
train_dataset = Events2018(datapath, pre_transform=RENet.pre_transform(seq_len))
test_dataset = Events2018(datapath, split='test')

train_loader = DataLoader(
    train_dataset,
    batch_size=1024,
    shuffle=True,
    follow_batch=['h_sub', 'h_obj'],
    num_workers=6)

test_loader = DataLoader(
    test_dataset,
    batch_size=1024,
    shuffle=False,
    follow_batch=['h_sub', 'h_obj'],
    num_workers=6)

Processing...




Done!
CPU times: user 2min 39s, sys: 728 ms, total: 2min 40s
Wall time: 2min 40s


## Define the RENet model

In [7]:
# Seed PyTorch's global RNG so that weight initialisation, dropout masks and
# the shuffled batch order of `train_loader` are reproducible run to run.
SEED = 42
torch.manual_seed(SEED)

model = RENet(
    train_dataset.num_nodes,  # entity vocabulary size
    train_dataset.num_rels,   # relation vocabulary size
    hidden_channels=200,
    seq_len=seq_len,          # same history length used by the dataset's pre_transform
    dropout=0.5,
).to(device)

# Adam with a small weight decay (L2 penalty on the parameters).
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, weight_decay=0.00001)

## Train the model and save it to disk

In [8]:
%%time
def train():
    model.train()
    # Train model via multi-class classification against the corresponding
    # object and subject entities.
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        log_prob_obj, log_prob_sub = model(data)
        loss_obj = F.nll_loss(log_prob_obj, data.obj)
        loss_sub = F.nll_loss(log_prob_sub, data.sub)
        loss = loss_obj + loss_sub
        loss.backward()
        optimizer.step()


def test(loader):
    model.eval()
    # Compute Mean Reciprocal Rank (MRR) and Hits@1/3/10.
    result = torch.tensor([0, 0, 0, 0], dtype=torch.float)
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            log_prob_obj, log_prob_sub = model(data)
        result += model.test(log_prob_obj, data.obj) * data.obj.size(0)
        result += model.test(log_prob_sub, data.sub) * data.sub.size(0)
    result = result / (2 * len(loader.dataset))
    return result.tolist()

for epoch in range(1, 21):
    train()
    results = test(test_loader)
    print('Epoch: {:02d}, MRR: {:.4f}, Hits@1: {:.4f}, Hits@3: {:.4f}, '
          'Hits@10: {:.4f}'.format(epoch, *results))

torch.save(model.state_dict(), datapath.joinpath("model.pt")) # save model states

Epoch: 01, MRR: 0.1946, Hits@1: 0.1228, Hits@3: 0.2158, Hits@10: 0.3353
Epoch: 02, MRR: 0.2381, Hits@1: 0.1530, Hits@3: 0.2687, Hits@10: 0.4043
Epoch: 03, MRR: 0.2567, Hits@1: 0.1662, Hits@3: 0.2899, Hits@10: 0.4331
Epoch: 04, MRR: 0.2667, Hits@1: 0.1736, Hits@3: 0.3021, Hits@10: 0.4485
Epoch: 05, MRR: 0.2726, Hits@1: 0.1778, Hits@3: 0.3100, Hits@10: 0.4575
Epoch: 06, MRR: 0.2763, Hits@1: 0.1806, Hits@3: 0.3138, Hits@10: 0.4623
Epoch: 07, MRR: 0.2788, Hits@1: 0.1831, Hits@3: 0.3164, Hits@10: 0.4647
Epoch: 08, MRR: 0.2794, Hits@1: 0.1826, Hits@3: 0.3177, Hits@10: 0.4674
Epoch: 09, MRR: 0.2818, Hits@1: 0.1855, Hits@3: 0.3205, Hits@10: 0.4695
Epoch: 10, MRR: 0.2820, Hits@1: 0.1853, Hits@3: 0.3208, Hits@10: 0.4708
Epoch: 11, MRR: 0.2820, Hits@1: 0.1844, Hits@3: 0.3226, Hits@10: 0.4702
Epoch: 12, MRR: 0.2823, Hits@1: 0.1853, Hits@3: 0.3220, Hits@10: 0.4715
Epoch: 13, MRR: 0.2827, Hits@1: 0.1857, Hits@3: 0.3224, Hits@10: 0.4717
Epoch: 14, MRR: 0.2827, Hits@1: 0.1857, Hits@3: 0.3226, Hits@10: