## Query answering on YAGO3-10

In this notebook we will show how to use the BESS-KGE package to perform knowledge graph completion on the YAGO3-10 dataset, a subset of [YAGO3](https://yago-knowledge.org/downloads/yago-3) (Yet Another Great Ontology 3) containing only entities with at least ten relations associated with them.

In [None]:
!pip install git+ssh://git@github.com/graphcore-research/bess-kge.git

In [None]:
# import ctypes
import pathlib
import time

import numpy as np
import poptorch
import torch

from besskge.batch_sampler import RigidShardedBatchSampler
from besskge.bess import EmbeddingMovingBessKGE, TopKQueryBessKGE
from besskge.dataset import KGDataset
from besskge.embedding import MarginBasedInitializer
from besskge.loss import MarginRankingLoss, LogSigmoidLoss
from besskge.metric import Evaluation
from besskge.negative_sampler import (
    PlaceholderNegativeSampler,
    RandomShardedNegativeSampler,
)
from besskge.scoring import ComplEx
from besskge.sharding import PartitionedTripleSet, Sharding

Download and preprocess the dataset with the built-in method of `KGDataset`.

In [None]:
# Build the YAGO3-10 dataset with the built-in constructor, using ../datasets/yago310/ as root directory
yago = KGDataset.build_yago310(root=pathlib.Path("../datasets/yago310/"))

# Summary statistics of the dataset
print(f"Number of entities: {yago.n_entity}\n")
print(f"Number of relation types: {yago.n_relation_type}\n")
# NOTE: the "validation/test" figure is the size of the validation split only
# (presumably the test split has the same size -- TODO confirm)
print(f"Number of triples: \n training: {yago.triples['train'].shape[0]} \n validation/test: {yago.triples['validation'].shape[0]}\n")

# Print example triple: columns 0, 1, 2 of a triple are looked up in the entity, relation and entity dictionaries respectively
ex_triple_id = 2500
ex_triple = yago.triples["train"][ex_triple_id]
print(f'Example triple: {yago.entity_dict[ex_triple[0]], yago.relation_dict[ex_triple[1]], yago.entity_dict[ex_triple[2]]}')

We want to train on 4 IPUs, so we construct a sharding of the entity table in 4 parts. The entity sharding induces a sharding of the triples into 4*4=16 shardpairs, based on the shard of head and tail entities.

In [None]:
# Seed passed to the sharding and to the samplers defined in this notebook
seed = 1234
# Number of entity shards (one per IPU, as explained in the text above)
n_shard = 4

# Seeded partition of the yago.n_entity entities into n_shard shards
sharding = Sharding.create(yago.n_entity, n_shard=n_shard, seed=seed)
print(f"Global entity IDs on {n_shard} shards:")
print(sharding.shard_and_idx_to_entity)

# The global entity IDs can be recovered, as a function of the shard ID and the local ID on the shard, by
# indexing shard_and_idx_to_entity with the (shard, local index) pair of each entity
print("\nReconstructed global entity IDs:")
print(sharding.shard_and_idx_to_entity[sharding.entity_to_shard, sharding.entity_to_idx])

# Partition the training triples according to the entity sharding
train_triples = PartitionedTripleSet.create_from_dataset(yago, "train", sharding)

print("\nNumber of triples per (h,t) shardpair:")
print(train_triples.triple_counts)

To iterate over the sharded set of triples we use a batch sampler.
`RigidShardedBatchSampler` consumes, at each step, the same number of triples from all 16 shardpairs
(resampling from the shorter ones, until the longest one is completed).

To sample negatives during training we use a negative sampler. 
`RandomShardedNegativeSampler` constructs, for each triple, negative samples by sampling random corrupted entities.

See the [biokg_training_inference](biokg_training_inference.ipynb) notebook for more details on these classes' options.

In [None]:
# Batching hyperparameters for training.
# The batch sampler produces device_iterations * accum_factor batches per step.
device_iterations = 10
accum_factor = 5
shard_bs = 240

# Negative sampler: one negative per triple, corruption scheme "ht"
neg_sampler = RandomShardedNegativeSampler(
    n_negative=1,
    sharding=sharding,
    seed=seed,
    corruption_scheme="ht",
    local_sampling=False,
    flat_negative_format=False,
)

# Batch sampler iterating over the sharded training triples
bs = RigidShardedBatchSampler(
    partitioned_triple_set=train_triples,
    negative_sampler=neg_sampler,
    shard_bs=shard_bs,
    batches_per_step=device_iterations * accum_factor,
    seed=seed,
    hrt_freq_weighting=False,
)

In [None]:
# PopTorch options for training: one replica per entity shard, with the
# device iterations and gradient accumulation factor chosen above
options = poptorch.Options()
options.replication_factor = sharding.n_shard
options.deviceIterations(device_iterations)
options.Training.gradientAccumulation(accum_factor)
options._popart.setPatterns({"RemoveAllReducePattern": True})

# Construct the dataloader with the dedicated utility function
train_dl = bs.get_dataloader(
    options=options,
    shuffle=True,
    num_workers=5,
    persistent_workers=True,
)

# Example batch: print the name and shape of each tensor
batch = next(iter(train_dl))
for k, v in batch.items():
    print(f"{k:<12} {str(v.shape):<30}")

Let's train the **ComplEx** KGE model with real embedding size 256 and **logsigmoid** loss function, using the `EmbeddingMovingBessKGE` distribution scheme.

In [None]:
# NOTE: despite its name, marg_rank_loss_fn holds a LogSigmoidLoss
marg_rank_loss_fn = LogSigmoidLoss(margin=12.0, negative_adversarial_sampling=True)

# The embedding initializer uses the same margin as the loss function
emb_initializer = MarginBasedInitializer(margin=marg_rank_loss_fn.margin)

# ComplEx scoring function; the same initializer is used for entity and relation embeddings
complex_score_fn = ComplEx(
    negative_sample_sharing=True,
    sharding=sharding,
    n_relation_type=yago.n_relation_type,
    embedding_size=256,
    entity_intializer=emb_initializer,
    relation_intializer=emb_initializer,
)

# Training model, combining sharding, negative sampler, scoring function and loss
model = EmbeddingMovingBessKGE(
    sharding=sharding,
    negative_sampler=neg_sampler,
    score_fn=complex_score_fn,
    loss_fn=marg_rank_loss_fn,
)
print(f"# model parameters: {model.n_embedding_parameters}")

We are now ready to train!

In [None]:
# AdamW optimizer with weight decay disabled; the accumulator and both
# momentum accumulators are kept in float32
opt = poptorch.optim.AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=0.0,
        accum_type=torch.float32,
        first_order_momentum_accum_type=torch.float32,
        second_order_momentum_accum_type=torch.float32,
    )

# Wrap the model with PopTorch for training, using the options defined for the training dataloader
poptorch_model = poptorch.trainingModel(model, options=options, optimizer=opt)

# The variable entity_embedding needs to hold different values on each replica,
# corresponding to the shards of the entity embedding table
poptorch_model.entity_embedding.replicaGrouping(
            poptorch.CommGroupType.NoGrouping,
            0,
            poptorch.VariableRetrievalMode.OnePerGroup,
        )

# Graph compilation, by calling the model on the example batch.
# triple_mask is removed first, as it is not passed to the model;
# flatten(end_dim=1) merges the first two dimensions of each tensor.
_ = batch.pop("triple_mask")
res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})

n_epochs = 25

for ep in range(n_epochs):
    ep_start_time = time.time()
    # Per-step log (loss and wall-clock time) for the current epoch
    ep_log = []
    for batch in iter(train_dl):
        step_start_time = time.time()
        # Remove the mask from the batch (it is not passed to the model)
        triple_mask = batch.pop("triple_mask")
        res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})
        ep_log.append(dict(loss=res["loss"], step_time=(time.time()-step_start_time)))
    # Epoch loss: mean over all the loss values returned by the steps of the epoch
    print(f"Epoch {ep+1} loss: {torch.concat([v['loss'] for v in ep_log]).mean().item():.6f}")
    print(f"Epoch duration (sec): {(time.time() - ep_start_time):.5f}, average step time (sec): {np.mean([v['step_time'] for v in ep_log]):.5f}")

# Detach the training model from the device and drop the training dataloader
# (presumably to free the IPUs and the persistent workers before inference -- TODO confirm)
poptorch_model.detachFromDevice()
del train_dl

For validation, for each query we want to predict the top-10 most likely tails among all the 123k+ entities in the knowledge graph. We can do so by using the `TopKQueryBessKGE` distribution scheme of BESS.

Since there are no specific tail candidates to sample, our dataloader does not need to pass negative indices to the IPUs. We therefore use the `PlaceholderNegativeSampler` class.

When using `TopKQueryBessKGE` we partition triples based just on the shard of the head entity (or the tail entity, if we wanted to predict heads), specifying `partition_mode='h_shard'` when constructing the `PartitionedTripleSet`. Moreover, we set the option `return_triple_idx=True` to return the indices of the triples in the batch with respect to the ordered list `validation_triples.triples`.

In [None]:
# NOTE: device_iterations and shard_bs are rebound here, replacing the values used for training
device_iterations = 3
shard_bs = 480

# Validation triples, partitioned by the shard of the head entity
validation_triples = PartitionedTripleSet.create_from_dataset(
    yago,
    "validation",
    sharding,
    partition_mode="h_shard",
)

# Placeholder sampler for tail prediction (corruption scheme "t")
candidate_sampler = PlaceholderNegativeSampler(corruption_scheme="t", seed=seed)

# Batch sampler for validation, also returning the indices of the triples in each batch
bs_valid = RigidShardedBatchSampler(
    partitioned_triple_set=validation_triples,
    negative_sampler=candidate_sampler,
    shard_bs=shard_bs,
    batches_per_step=device_iterations,
    seed=seed,
    duplicate_batch=False,
    return_triple_idx=True,
)

print("Number of triples per h_shard:")
print(validation_triples.triple_counts)

In [None]:
# PopTorch options for inference: one replica per shard, as many device
# iterations as batches per step of the validation batch sampler
val_options = poptorch.Options()
val_options.replication_factor = n_shard
val_options.deviceIterations(bs_valid.batches_per_step)
val_options.outputMode(poptorch.OutputMode.All)

# Validation dataloader (no shuffling)
valid_dl = bs_valid.get_dataloader(
    options=val_options,
    shuffle=False,
    num_workers=5,
    persistent_workers=True,
)

# Example batch: print the name and shape of each tensor
batch = next(iter(valid_dl))
for k, v in batch.items():
    print(f"{k:<12} {str(v.shape):<30}")

We see that each of the 3 batches returned by a call to the dataloader has a microbatch size of 480 triples. Notice that, when constructing the partitioned triple set with `partition_mode='h_shard'`, the tensor `head` will contain the **local** entity IDs in the corresponding shard, while `tail` contains the **global** IDs of the ground truth tails:

In [None]:
# Validation triples reordered with the sort index of the partitioned triple set
# (assuming triple_sort_idx is an index array, this indexing yields a copy, leaving the dataset untouched)
triple_sorted = yago.triples["validation"][validation_triples.triple_sort_idx]
# Replace the global head IDs (column 0) with the local IDs on the corresponding shard
triple_sorted[:,0] = sharding.entity_to_idx[triple_sorted[:,0]]
# Displayed as cell output: whether the result coincides with the triples stored in the partitioned triple set
np.all(triple_sorted == validation_triples.triples)

Let us now compile the inference `TopKQueryBessKGE` model. We use the `Evaluation` class to specify which metrics we want to compute (see `besskge.metric`)

In [None]:
# Set worst_rank_infty=True to assign a reciprocal rank of 0 if the ground truth tail is not among the top-10 predicted tails (otherwise the reciprocal rank would be 1/11).
evaluation = Evaluation(["mrr", "hits@3", "hits@10"], worst_rank_infty=True)
# Inference model returning the top-10 predictions; it reuses the scoring function (complex_score_fn) of the training model
inf_model = TopKQueryBessKGE(k=10, sharding=sharding, candidate_sampler=candidate_sampler, score_fn=complex_score_fn, evaluation=evaluation)

# Wrap the model with PopTorch for inference, using the validation options
poptorch_inf_model = poptorch.inferenceModel(inf_model, options=val_options)

# Same replica grouping for entity_embedding as set on the training model
poptorch_inf_model.entity_embedding.replicaGrouping(
            poptorch.CommGroupType.NoGrouping,
            0,
            poptorch.VariableRetrievalMode.OnePerGroup,
        )

# Call the model on the example batch (triple_mask and triple_idx are removed first,
# as they are not passed to the model); flatten(end_dim=1) merges the first two dimensions of each tensor
_ = batch.pop("triple_mask")
_ = batch.pop("triple_idx")
res = poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})

Let us now iterate over the validation set to compute the predictions and the corresponding metrics.

In [None]:
# Iterate over the validation set, logging for each step the sum of every
# metric over the non-padding triples of the step.
val_log = []
start_time = time.time()
n_val_queries = 0
for batch_val in valid_dl:
    # triple_mask and triple_idx are not passed to the model. The values of the
    # last step (together with res) are left in the namespace, where
    # check_prediction reads them.
    triple_mask = batch_val.pop("triple_mask")
    triple_idx = batch_val.pop("triple_idx")
    res = poptorch_inf_model(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})

    # Number of non-padding triples in the step
    n_val_queries += triple_mask.sum()
    # Filter out padding triples using triple_mask
    val_log.append({k: v[triple_mask.flatten()].sum() for k, v in res["metrics"].items()})

print(f"Validation time (sec): {(time.time() - start_time):.5f}\n")

# Average each metric over all validation queries
for metric in val_log[0].keys():
    reduced_metric = sum(step_log[metric] for step_log in val_log) / n_val_queries
    print("%s : %f" % (metric, reduced_metric))

A mean reciprocal rank of 0.2 means that the harmonic mean of the ranks of the correct tails is 5: loosely speaking, the correct tail is typically around the 5th most likely one predicted by the model.

We can check that this is correct by computing the MRR directly on the un-sharded triples, on CPU...

In [None]:
# Un-sharded entity table: row i holds the embedding of the entity with global ID i
ent_table = complex_score_fn.entity_embedding.detach()[sharding.entity_to_shard, sharding.entity_to_idx]
# Relation embedding table (not used further in this cell)
rel_table = complex_score_fn.relation_embedding.detach()

# Head, relation and tail columns of the (un-sharded) validation triples
val_hrt = yago.triples["validation"]
val_heads = val_hrt[:, 0]
val_relations = val_hrt[:, 1]
val_tails = val_hrt[:, 2]

# Score every validation query against all entities as candidate tails
scores = complex_score_fn.score_tails(
    ent_table[val_heads],
    torch.from_numpy(val_relations),
    ent_table.unsqueeze(0),
)
# Keep the 10 best-scoring candidates per query and compute the reciprocal ranks of the true tails
top_k = torch.topk(scores, dim=-1, k=10)
rec_rank_true = evaluation.metrics_from_indices(
    torch.from_numpy(val_tails),
    top_k.indices.squeeze(),
)["mrr"]

print(f"CPU validation MRR: {rec_rank_true.mean():.6f}")

...and have a look at some of the predictions made by the model

In [None]:
def check_prediction(val_triple_id):
    """Print a validation query and the 10 tails predicted for it by the model.

    The query is taken from the last batch processed by the validation loop:
    the function reads the notebook globals `triple_idx`, `triple_mask` and
    `res` (as well as `yago` and `validation_triples`), so it must be called
    after that loop has run.

    Args:
        val_triple_id: position of the query among the non-padding triples
            of the last batch.
    """
    # Recover the non-padding triples seen in the last batch using triple_idx and triple_mask
    sorted_val_triples = yago.triples["validation"][validation_triples.triple_sort_idx]
    last_batch_triples = sorted_val_triples[triple_idx[triple_mask]]
    h, r, t = last_batch_triples[val_triple_id]

    # Top-10 tails predicted by the KGE model for the selected query
    predictions = res["topk_global_id"][triple_mask.flatten()]
    top10_t = predictions[val_triple_id]

    print(f'Example query: ({yago.entity_dict[h]}, {yago.relation_dict[r]}, ?)\n')
    print(f"Correct tail: {yago.entity_dict[t]}\n")
    print(f"10 most likely predicted tails:")
    for i, pt in enumerate(top10_t):
        line = f"{i+1}) {yago.entity_dict[pt]}"
        # Flag the prediction coinciding with the ground truth tail
        if pt == t:
            line = line + "   <-----"
        print(line)
    print("\n")

# Inspect two example queries; the arguments are positions among the non-padding triples of the last validation batch
check_prediction(1615)
check_prediction(1990)