## KGE training and inference on OGB-BioKG

In this notebook we will learn how to use the BESS-KGE library to train KGE models and perform link prediction inference, using the public biomedical dataset [OGB-BioKG](https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg).

In [None]:
!pip install git+ssh://git@github.com/graphcore-research/bess-kge.git

In [None]:
import ctypes
import pathlib
import time

import numpy as np
import poptorch
import torch

from besskge.batch_sampler import RigidShardedBatchSampler
from besskge.bess import EmbeddingMovingBessKGE, ScoreMovingBessKGE
from besskge.dataset import KGDataset
from besskge.embedding import MarginBasedInitializer
from besskge.loss import LogSigmoidLoss
from besskge.metric import Evaluation
from besskge.negative_sampler import (
    RandomShardedNegativeSampler,
    TripleBasedShardedNegativeSampler,
)
from besskge.scoring import TransE
from besskge.sharding import PartitionedTripleSet, Sharding

The OGB-BioKG dataset can be downloaded and preprocessed with the built-in method of `KGDataset`. `KGDataset` is the standard class which holds data from the KG dataset, such as head-relation-tail triples (with entities and relation types suitably converted to their integer IDs), triple-specific data (e.g. negative heads/tails to be used to corrupt the triple), ID -> label lists, etc.

In [None]:
biokg = KGDataset.build_biokg(root=pathlib.Path("../datasets/biokg"))

print(f"Number of entities: {biokg.n_entity}\n")
print(f"Number of relation types: {biokg.n_relation_type}\n")
print(f"Number of triples: \n training: {biokg.triples['train'].shape[0]} \n validation/test: {biokg.triples['valid'].shape[0]}\n")
print(f"Number of negative heads/tails for validation/test triples: {biokg.neg_heads['valid'].shape[-1]}")

Entities in the BioKG dataset have different types. Entity IDs must always be assigned so that entities of the same type have contiguous IDs. The ID offsets of the different types are then stored in `KGDataset.type_offsets`:

In [None]:
biokg.type_offsets

This means that entities with IDs in the range [0, 10686] are of type 'disease', and so on.
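
As a minimal sketch of how such offsets can be used, the snippet below maps an entity ID back to its type with a binary search. The helper `entity_type` and the offset values are illustrative assumptions (only the 'disease' range is taken from the text above); the real offsets are those in `biokg.type_offsets`.

```python
from bisect import bisect_right

# Illustrative offsets only; the real values are stored in biokg.type_offsets.
# The start of "function" is a made-up placeholder.
type_offsets = {"disease": 0, "drug": 10687, "function": 21220}


def entity_type(entity_id, offsets):
    """Return the type whose contiguous ID range contains entity_id (hypothetical helper)."""
    names = list(offsets.keys())
    starts = list(offsets.values())  # assumed sorted in increasing order
    return names[bisect_right(starts, entity_id) - 1]


print(entity_type(0, type_offsets))      # disease
print(entity_type(10686, type_offsets))  # disease
print(entity_type(10687, type_offsets))  # drug
```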

To train on 4 IPUs, we shard the entity set into 4 parts of equal size. This is done using the `Sharding` class.

In [None]:
seed = 1234
n_shard = 4

sharding = Sharding.create(n_entity=biokg.n_entity, n_shard=n_shard, seed=seed, type_offsets=np.fromiter(biokg.type_offsets.values(), dtype=np.int32))

print(f"Number of shards: {sharding.n_shard}\n")

print(f"Number of entities in each shard: {sharding.max_entity_per_shard}\n")

print(f"Entity sharding:\n {sharding.shard_and_idx_to_entity}\n")

# If the number of entities is not divisible by n_shard, some shards will have one trailing padding entity (ID >= n_entity)
print(f"Number of actual (=non-padding) entities per shard:\n {sharding.shard_counts}\n")

# Entities of the same type maintain contiguous local IDs in each shard
print(f"Type offsets per shard: \n", sharding.entity_type_offsets)
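
To build intuition for the padding behaviour, here is a simplified sketch of random sharding in plain Python. It is not the implementation of `Sharding.create`: the function `shard_entities` is hypothetical and, unlike the library, it ignores entity types.

```python
import random


def shard_entities(n_entity, n_shard, seed):
    """Randomly split entity IDs into n_shard lists of equal length.

    Hypothetical helper: shards are padded with IDs >= n_entity when
    n_entity is not divisible by n_shard.
    """
    per_shard = -(-n_entity // n_shard)  # ceiling division
    ids = list(range(n_entity))
    random.Random(seed).shuffle(ids)
    # Padding IDs are appended last, so each shard receives at most one
    ids += list(range(n_entity, per_shard * n_shard))
    # Position p goes to shard p % n_shard; sorting puts padding at the end
    return [sorted(ids[s::n_shard]) for s in range(n_shard)]


toy_shards = shard_entities(n_entity=10, n_shard=4, seed=1234)
for shard_id, entities in enumerate(toy_shards):
    print(shard_id, entities)
```

With 10 entities and 4 shards, every shard holds 3 slots, and two of the shards end with a padding ID (10 or 11).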

The entity sharding induces a partitioning of the set of training triples into 4*4=16 **shardpairs**, based on the entity shard of the head and of the tail. Triple partitioning is performed using the `PartitionedTripleSet` class.

In [None]:
train_triples = PartitionedTripleSet.create_from_dataset(dataset=biokg, part="train", sharding=sharding, partition_mode="ht_shardpair")

# train_triples.triple_counts[i,j] is the number of triples with head entity in shard i and tail entity in shard j
print(f"Number of triples per (h,t) shardpair:\n {train_triples.triple_counts}")
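
The counting behind this table can be sketched on a toy example. The names `shardpair_counts` and `toy_entity_to_shard` are made up for illustration and are not part of the library.

```python
def shardpair_counts(triples, entity_to_shard, n_shard):
    """Count triples per (head shard, tail shard) pair (hypothetical helper)."""
    counts = [[0] * n_shard for _ in range(n_shard)]
    for h, _r, t in triples:
        counts[entity_to_shard[h]][entity_to_shard[t]] += 1
    return counts


# Toy setup: 4 entities split over 2 shards
toy_entity_to_shard = [0, 1, 0, 1]
toy_triples = [(0, 0, 1), (2, 1, 0), (1, 0, 3), (3, 1, 2), (0, 1, 3)]

toy_counts = shardpair_counts(toy_triples, toy_entity_to_shard, n_shard=2)
print(toy_counts)  # [[1, 2], [1, 1]]
```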

Triples in `train_triples.triples` have been sorted based on their shardpair: `(0,0), (0,1), ..., (n_shard-1, n_shard-1)`. One can use `train_triples.triple_sort_idx` to recover the original ordering.

Moreover, the global entity ID for heads and tails has been replaced with the local index on the corresponding shard.

In [None]:
triple_sorted = biokg.triples["train"][train_triples.triple_sort_idx]
triple_sorted[:,0] = sharding.entity_to_idx[triple_sorted[:,0]]
triple_sorted[:,2] = sharding.entity_to_idx[triple_sorted[:,2]]
np.all(triple_sorted == train_triples.triples)
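
The same idea can be written out in plain Python on a toy example, which also shows how a sort index can be inverted to recover the original ordering. All names here are illustrative and the data is made up.

```python
n_shard = 2
entity_to_shard = [0, 1, 0, 1]  # shard of each global entity ID
entity_to_idx = [0, 0, 1, 1]    # local index of each entity on its shard
triples = [(1, 0, 3), (0, 0, 1), (3, 1, 2), (2, 1, 0)]


def shardpair(triple):
    """Flattened (head shard, tail shard) bucket index of a triple."""
    h, _r, t = triple
    return entity_to_shard[h] * n_shard + entity_to_shard[t]


# Stable sort of triple positions by shardpair: (0,0), (0,1), (1,0), (1,1)
sort_idx = sorted(range(len(triples)), key=lambda i: shardpair(triples[i]))

# Sorted triples, with global entity IDs replaced by local indices
sorted_local = [
    (entity_to_idx[triples[i][0]], triples[i][1], entity_to_idx[triples[i][2]])
    for i in sort_idx
]

# Invert the permutation to go back to the original ordering
inverse = [0] * len(sort_idx)
for pos, i in enumerate(sort_idx):
    inverse[i] = pos
restored = [sorted_local[inverse[i]] for i in range(len(triples))]

print(sort_idx)  # [3, 1, 2, 0]
```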

A key component in KGE model training is the selection of negative samples to contrast against each positive triple. A standard strategy is to corrupt either the head or the tail of the positive triple.

**Negative samplers** are implemented in `besskge.negative_sampler`. Two suitable ones are the following:

* `RandomShardedNegativeSampler` randomly picks the corrupted entity among all the entities in the KG;
* `TypeBasedShardedNegativeSampler` selects the corrupted entity only among entities of the same type as the original one.

By default, BESS samples the same number of negative entities from each entity shard, to minimize selection bias. For each pair of devices, the same amount of data is exchanged in both directions, through collective operators.

* If `flat_negative_format=False`, negative entities are sampled on a triple basis. For each positive triple in a microbatch, `n_negative` corrupted entities are received from each device, for a total of `shard_bs * n_negative * n_shard` negatives used in the microbatch. We can then decide whether to score each triple only against the corresponding `n_negative * n_shard` negative samples, or against all negatives seen in the microbatch (negative sample sharing).

* If `flat_negative_format=True`, negative entities are sampled on a shardpair basis. Each device receives `n_negative` negatives from each shard, for a total of `n_negative * n_shard` negatives used in the microbatch. As the negatives are not sampled on a triple basis, this requires the use of negative sample sharing.

In order to reduce inter-device communication, we have the option of sampling all negatives just from the device where the triple loss will be computed, by setting `local_sampling=True`.
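
As a quick check of the counts quoted for the two negative formats, the arithmetic can be spelled out. The values of `shard_bs`, `n_negative` and `n_shard` are the ones used in this notebook, while the other variable names are just labels for this sketch.

```python
shard_bs = 240   # positive triples per device per microbatch
n_negative = 1   # negatives sampled per shard
n_shard = 4

# flat_negative_format=False: negatives are sampled per triple
negatives_per_triple = n_negative * n_shard                 # 4
negatives_per_microbatch = shard_bs * n_negative * n_shard  # 960

# flat_negative_format=True: negatives are sampled per shardpair
flat_negatives_per_microbatch = n_negative * n_shard        # 4

print(negatives_per_triple, negatives_per_microbatch, flat_negatives_per_microbatch)
```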

When instantiating the negative sampler, we must also specify the corruption scheme:
* `corruption_scheme='h'` if negative samples are to be constructed by corrupting the head entity;
* `corruption_scheme='t'` if negative samples are to be constructed by corrupting the tail entity;
* `corruption_scheme='ht'` if negative samples are to be constructed by corrupting the head entity for half the triples in the microbatch and the tail entity for the other half (in this case, when using negative sample sharing, negatives are shared only among the triples in the same half).
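
The three corruption schemes can be sketched as follows. This is a simplified illustration, not the library's sampler: the function `corrupt` is hypothetical, it draws a single random replacement per triple, and it does not exclude the original entity.

```python
import random


def corrupt(triples, scheme, n_entity, rng):
    """Return one corrupted copy of each (h, r, t) triple (hypothetical helper).

    scheme "h" corrupts heads, "t" corrupts tails, and "ht" corrupts heads
    in the first half of the list and tails in the second half.
    """
    half = len(triples) // 2
    corrupted = []
    for i, (h, r, t) in enumerate(triples):
        corrupt_head = scheme == "h" or (scheme == "ht" and i < half)
        new_entity = rng.randrange(n_entity)  # may coincide with the original
        corrupted.append((new_entity, r, t) if corrupt_head else (h, r, new_entity))
    return corrupted


toy_microbatch = [(0, 0, 1), (2, 1, 3), (4, 0, 5), (6, 1, 7)]
toy_negatives = corrupt(toy_microbatch, "ht", n_entity=10, rng=random.Random(0))
for positive, negative in zip(toy_microbatch, toy_negatives):
    print(positive, "->", negative)
```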

In [None]:
neg_sampler = RandomShardedNegativeSampler(n_negative=1, sharding=sharding, seed=seed, corruption_scheme="ht",
                                           local_sampling=False, flat_negative_format=False)

In order to start training, we are missing only one component: a **batch sampler**. This class is responsible for assembling the batches to pass to each device. This is not a trivial task, since at each step each device needs to know which embeddings stored in its local memory are needed by itself and by all other devices.
Different types of batch samplers are implemented in `besskge.batch_sampler`. All of them, at each step, sample the same number of triples from each of the 16 shardpair buckets. Here we use `RigidShardedBatchSampler`, where each bucket is consumed sequentially. The length of an epoch is then dictated by the length of the largest bucket.

In [None]:
device_iterations = 8
accum_factor = 6
# Microbatch size, i.e. number of positive triples processed on each device at each step
shard_bs = 240

batch_sampler = RigidShardedBatchSampler(partitioned_triple_set=train_triples, negative_sampler=neg_sampler,
                              shard_bs=shard_bs, batches_per_step=device_iterations*accum_factor, seed=seed)


print(f"# triples per shardpair per step: {batch_sampler.positive_per_partition} \n")

# Example batch
idx_sampler = iter(batch_sampler.get_dataloader_sampler(shuffle=True))
for k,v in batch_sampler[next(idx_sampler)].items():
    print(f"{k:<12} {str(v.shape):<30} {v.dtype};")

We see that each call returns `device_iterations * accum_factor = 48` batches. Each of them is composed of `60` triples from each of the 16 shardpair buckets. In particular, IPU `i` will process 240 triples, 60 from each of the four (h, t) shardpairs `(i,0), (i,1), (i,2), (i,3)`. `head[:,i,:,:], tail[:,i,:,:]` are the entity/relation IDs of the embeddings that need to be gathered from the SRAM of IPU `i`.

`negative[:,i,j,t,:]` are the negative entities sampled on IPU `i` to be used for triple `t` on IPU `j` (the trailing dimension of size 1 is the `n_negative` specified in the negative sampler).

`triple_mask` is a boolean mask telling which of the positive triples in the batch are non-padding (since, as mentioned above, `RigidShardedBatchSampler` will repeat triples in smaller shardpair buckets during an epoch).
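
The batch sizes quoted above follow from the configuration by simple arithmetic, sketched here. The last two quantities are derived for illustration and are not values reported by the library; they include any padding triples.

```python
device_iterations = 8
accum_factor = 6
shard_bs = 240
n_shard = 4

batches_per_step = device_iterations * accum_factor  # 48 batches per call
positive_per_partition = shard_bs // n_shard         # 60 triples per shardpair
n_shardpairs = n_shard * n_shard                     # 16 buckets

# Positive triples (padding included) across all devices
triples_per_batch = positive_per_partition * n_shardpairs  # 960 = shard_bs * n_shard
triples_per_step = triples_per_batch * batches_per_step    # 46080

print(batches_per_step, positive_per_partition, triples_per_batch, triples_per_step)
```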

We can use the method `get_dataloader` of the batch sampler to build the PopTorch dataloader which we iterate over during training.

In [None]:
options = poptorch.Options()
options.replication_factor = sharding.n_shard
options.deviceIterations(device_iterations)
options.Training.gradientAccumulation(accum_factor)
# Add a memory saving optimisation pattern. This removes an unnecessary
# entity_embedding gradient all-reduce, which is a no-op since it is fully
# sharded across replicas.
options._popart.setPatterns(dict(RemoveAllReducePattern=True))

train_dl = batch_sampler.get_dataloader(options=options, shuffle=True, num_workers=5, persistent_workers=True)

We are now ready to define the model and train it. We will use **TransE** with 128-dimensional embeddings and the **logsigmoid** loss function with negative adversarial sampling. We will also use **negative sample sharing** within the microbatches.
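
To make these choices concrete, here is a small pure-Python sketch of a TransE score and of a log-sigmoid loss with self-adversarial negative weighting, following the formulation commonly used in the KGE literature. It is an assumption that this matches the library's `TransE` and `LogSigmoidLoss` in every detail (sign conventions, reductions and the adversarial temperature `alpha` may differ); the function names are made up.

```python
import math


def transe_score(h, r, t, p=1):
    """Negative Lp distance between h + r and t (higher means more plausible)."""
    distance = sum(abs(hi + ri - ti) ** p for hi, ri, ti in zip(h, r, t)) ** (1 / p)
    return -distance


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def logsigmoid_loss(pos_score, neg_scores, margin, alpha=1.0):
    """Log-sigmoid loss for one positive triple and its negatives (sketch)."""
    # Self-adversarial weights: softmax over the negative scores
    top = max(neg_scores)
    exps = [math.exp(alpha * (s - top)) for s in neg_scores]
    weights = [e / sum(exps) for e in exps]
    pos_term = -log_sigmoid(margin + pos_score)
    neg_term = -sum(w * log_sigmoid(-(margin + s)) for w, s in zip(weights, neg_scores))
    return pos_term + neg_term


# A triple satisfying h + r = t exactly has distance 0
print(transe_score([1.0, 2.0], [1.0, 1.0], [2.0, 3.0]))
print(logsigmoid_loss(pos_score=-3.0, neg_scores=[-15.0, -20.0], margin=12.0))
```

The loss is small when the positive score is well above `-margin` and the negative scores are well below it.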

The BESS distribution scheme is implemented in `besskge.bess` with different flavours. Here we use the basic `EmbeddingMovingBessKGE` class, where entity embeddings are exchanged between IPUs (from the one where the embedding is stored to the one where it is needed for computation) through AllToAll collectives.

In [None]:
logsigmoid_loss_fn = LogSigmoidLoss(margin=12.0, negative_adversarial_sampling=True)
emb_initializer = MarginBasedInitializer(margin=logsigmoid_loss_fn.margin)
transe_score_fn = TransE(negative_sample_sharing=True, scoring_norm=1, sharding=sharding,
                  n_relation_type=biokg.n_relation_type, embedding_size=128,
                  entity_intializer=emb_initializer, relation_intializer=emb_initializer)

model = EmbeddingMovingBessKGE(sharding=sharding, negative_sampler=neg_sampler, score_fn=transe_score_fn,
                               loss_fn=logsigmoid_loss_fn)

opt = poptorch.optim.AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=0.0,
        accum_type=torch.float32,
        first_order_momentum_accum_type=torch.float32,
        second_order_momentum_accum_type=torch.float32,
    )

poptorch_model = poptorch.trainingModel(model, options=options, optimizer=opt)

# The variable entity_embedding needs to hold different values on each replica,
# corresponding to the shards of the entity embedding table
poptorch_model.entity_embedding.replicaGrouping(
            poptorch.CommGroupType.NoGrouping,
            0,
            poptorch.VariableRetrievalMode.OnePerGroup,
        )

# Compile model
batch = next(iter(train_dl))
_ = batch.pop("triple_mask")
res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})

In [None]:
# Train for 15 epochs
n_epochs = 15

for ep in range(n_epochs):
    ep_start_time = time.time()
    ep_log = []
    for batch in train_dl:
        step_start_time = time.time()
        triple_mask = batch.pop("triple_mask")
        res = poptorch_model(**{k: v.flatten(end_dim=1) for k, v in batch.items()})
        ep_log.append(dict(loss=res["loss"], step_time=(time.time()-step_start_time)))
    print(f"Epoch {ep+1} loss: {torch.concat([v['loss'] for v in ep_log]).mean().item():.6f}")
    print(f"Epoch duration (sec): {(time.time() - ep_start_time):.5f}, average step time (sec): {np.mean([v['step_time'] for v in ep_log]):.5f}")

poptorch_model.detachFromDevice()
del train_dl

Let's see how the trained model performs on the **validation** set. 

We create a new `PartitionedTripleSet` with the validation triples. Unlike in the training set, each triple now has a specific set of 500 negative heads/tails to be scored against, so we use the `TripleBasedShardedNegativeSampler` negative sampler class.

For the batch sampler, we again use a `RigidShardedBatchSampler`, but we now set the option `duplicate_batch=True`: this means that the two halves of the microbatch (where we corrupt heads and tails respectively) are based on the same positive triples, so that we can score negative heads and negative tails with the same model.

In [None]:
valid_triples = PartitionedTripleSet.create_from_dataset(dataset=biokg, part="valid", sharding=sharding, partition_mode="ht_shardpair")
ns_valid = TripleBasedShardedNegativeSampler(negative_heads=valid_triples.neg_heads, negative_tails=valid_triples.neg_tails,
                                             sharding=sharding, corruption_scheme="ht", seed=seed)
bs_valid = RigidShardedBatchSampler(partitioned_triple_set=valid_triples, negative_sampler=ns_valid, shard_bs=shard_bs, batches_per_step=10,
                                    seed=seed, duplicate_batch=True)

# Example batch
idx_sampler = iter(bs_valid.get_dataloader_sampler(shuffle=False))
for k,v in bs_valid[next(idx_sampler)].items():
    print(f"{k:<15} {str(v.shape):<35} {v.dtype};")

We see that the `negative` tensor now returned by the dataloader has trailing dimension 175, meaning that each validation query is scored against `4*175 = 700` negative heads/tails, more than the 500 actual negatives. This is because the triple-specific negatives are not split equally between the 4 shards, so some padding needs to be applied. The `negative_mask` returned by the negative sampler is used to identify the padding negative entities, so that the corresponding scores can be filtered out when computing the metrics.
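
The need for padding can be illustrated with a toy sketch that groups one triple's negatives by shard and pads every group to a common length. The function `split_negatives_by_shard` is hypothetical, and padding with local index 0 is an assumption made only for this example.

```python
def split_negatives_by_shard(neg_entities, entity_to_shard, entity_to_idx, n_shard, pad_to):
    """Group negatives by shard, then pad each group to length pad_to (sketch)."""
    per_shard = [[] for _ in range(n_shard)]
    for entity in neg_entities:
        per_shard[entity_to_shard[entity]].append(entity_to_idx[entity])
    padded, mask = [], []
    for local_ids in per_shard:
        n_pad = pad_to - len(local_ids)
        padded.append(local_ids + [0] * n_pad)  # padding index is arbitrary
        mask.append([True] * len(local_ids) + [False] * n_pad)
    return padded, mask


# Toy setup: 6 entities, 4 on shard 0 and 2 on shard 1
toy_entity_to_shard = [0, 0, 0, 0, 1, 1]
toy_entity_to_idx = [0, 1, 2, 3, 0, 1]
toy_negatives = [0, 1, 2, 4]  # 3 negatives live on shard 0, 1 on shard 1

padded, mask = split_negatives_by_shard(
    toy_negatives, toy_entity_to_shard, toy_entity_to_idx, n_shard=2, pad_to=3
)
print(padded)  # [[0, 1, 2], [0, 0, 0]]
print(mask)    # [[True, True, True], [True, False, False]]
```

Here 4 real negatives occupy `2*3 = 6` slots, in the same way that 500 negatives occupy `4*175` slots in the validation batches.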

We can now instantiate the inference model. We use the `besskge.metric.Evaluation` class to specify which metrics we want to compute and pass it to the BESS module. We now use a different flavour of BESS compared to training, namely `ScoreMovingBessKGE`: this is recommended when the number of negative entities to be fetched from other devices is large, as is typically the case when using `TripleBasedShardedNegativeSampler`. While `EmbeddingMovingBessKGE` sends the negative embeddings to the device where the positive triple is scored, `ScoreMovingBessKGE` fetches the queries with an AllGather, computes negative scores on the shard where the negative entities are stored, and then sends the scores back to the original device. This means that scores are communicated instead of embeddings, which is usually cheaper, although it requires additional collective communications between devices. We encourage the reader to play with different configurations to see which one gives the shortest overall validation time.
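
A rough back-of-the-envelope comparison shows why moving scores can be cheaper. This sketch only counts values exchanged per query and per remote shard, under the simplifying assumption that a query can be represented by a single vector of size `embedding_size`; it ignores how the collectives are actually laid out and is not a measurement.

```python
embedding_size = 128
negatives_per_shard = 175  # trailing dimension of `negative` in the validation batches

# Embedding moving: the remote shard sends every negative embedding
embedding_moving_values = negatives_per_shard * embedding_size  # 22400

# Score moving: the query vector goes out, one score per negative comes back
score_moving_values = embedding_size + negatives_per_shard      # 303

print(embedding_moving_values, score_moving_values)
```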

In [None]:
val_options = poptorch.Options()
val_options.replication_factor = sharding.n_shard
val_options.deviceIterations(bs_valid.batches_per_step)
val_options.outputMode(poptorch.OutputMode.All)

valid_dl = bs_valid.get_dataloader(options=val_options, shuffle=False, num_workers=5, persistent_workers=True)

# Each triple is now to be scored against a specific set of negatives, so we turn off negative sample sharing
transe_score_fn.negative_sample_sharing = False
evaluator = Evaluation(["mrr", "hits@1", "hits@5", "hits@10"])
model_inf = ScoreMovingBessKGE(sharding=sharding, negative_sampler=ns_valid, score_fn=transe_score_fn, evaluation=evaluator)

poptorch_model_inf = poptorch.inferenceModel(model_inf, options=val_options)

poptorch_model_inf.entity_embedding.replicaGrouping(
            poptorch.CommGroupType.NoGrouping,
            0,
            poptorch.VariableRetrievalMode.OnePerGroup,
        )

# Compile model
batch = next(iter(valid_dl))
_ = batch.pop("triple_mask")
res = poptorch_model_inf(**{k: v.flatten(end_dim=1) for k, v in batch.items()})

In [None]:
# Perform validation and print metrics

val_log = []
val_time = []
start_time = time.time()
# The final value of n_val_queries will be twice the number of triples in the validation set
# as each triple is scored against negative heads and negative tails separately
n_val_queries = 0
for batch_val in valid_dl:
    triple_mask = batch_val.pop("triple_mask")
    step_start_time = time.time()
    res = poptorch_model_inf(**{k: v.flatten(end_dim=1) for k, v in batch_val.items()})
    
    n_val_queries += triple_mask.sum()
    # Filter out padding triples using triple_mask
    val_log.append({k: v[triple_mask.flatten()].sum() for k, v in res["metrics"].items()})
    val_time.append(time.time()-step_start_time)

print(f"Validation time (sec): {(time.time() - start_time):.5f}, average step time (sec): {np.mean(val_time):.5f}\n")

for metric in val_log[0].keys():
    reduced_metric = sum([l[metric] for l in val_log]) / n_val_queries
    print("%s : %f" % (metric, reduced_metric))
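
For reference, the metrics requested above can be computed from the rank of each positive among its non-padding negatives. The sketch below uses hypothetical helpers and an optimistic tie-breaking rule, which is an assumption and may differ from how `Evaluation` handles ties.

```python
def rank_of_positive(pos_score, neg_scores, neg_mask):
    """Rank of the positive among its unmasked negatives (1 is best)."""
    # Optimistic ranking: only strictly better negatives push the rank down
    return 1 + sum(1 for s, keep in zip(neg_scores, neg_mask) if keep and s > pos_score)


def summarize(ranks, ks=(1, 5, 10)):
    """Mean reciprocal rank and hits@k over a list of ranks."""
    n = len(ranks)
    metrics = {"mrr": sum(1.0 / r for r in ranks) / n}
    for k in ks:
        metrics[f"hits@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


# The second negative scores higher but is padding, so it is ignored
toy_rank = rank_of_positive(-2.0, [-1.0, -0.5, -3.0], [True, False, True])
print(toy_rank)  # 2
print(summarize([1, 2, 4]))
```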

What next? You can try training with different combinations of KGE models, embedding sizes and loss functions (see `besskge.scoring` and `besskge.loss`), or change the number of negative samples and the sampling scheme (for instance, try training with `TypeBasedShardedNegativeSampler`).

If you are interested in using these models to make predictions based on incomplete queries, have a look at the [yago_topk_prediction](yago_topk_prediction.ipynb) notebook.