# Knowledge Graph Embedding with TorchKGE

In this tutorial, we will explore how to train a knowledge graph embedding model to perform link prediction in biological knowledge graphs.

We will consider a knowledge graph composed of nodes representing "genes/proteins," "diseases," and "drugs," and containing interactions such as "drug-protein," "disease-protein," and "drug-disease." Our goal will be to predict new interactions of these types.

## Conda environment

To get started, we need to set up a conda environment that will contain the necessary libraries. Create the conda environment with the following commands:

In [None]:
conda create --name torch_pyg python=3.10
conda activate torch_pyg
pip install torch
pip install torchkge
pip install pandas matplotlib numpy pyyaml tqdm
pip install pytorch-ignite
pip install ipykernel

## Knowledge Graph Preprocessing

### Data loading

The graph is stored in `TSV` format. Let's start by loading the file using the `pandas` library.
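As a minimal sketch of this step, the snippet below parses a small in-memory TSV with `pandas.read_csv`. The column names `head`, `relation`, `tail` and the example rows are hypothetical and only stand in for the real file. The renaming to `from`, `rel`, `to` reflects the column names that TorchKGE's `KnowledgeGraph` is documented to expect; check this against the version you have installed.

```python
import io

import pandas as pd

# Hypothetical three-column TSV standing in for the real graph file.
tsv_text = (
    "head\trelation\ttail\n"
    "drug_A\tdrug-protein\tprotein_1\n"
    "disease_X\tdisease-protein\tprotein_1\n"
    "drug_A\tdrug-disease\tdisease_X\n"
)

# sep="\t" tells pandas that the fields are tab-separated.
df = pd.read_csv(io.StringIO(tsv_text), sep="\t")

# Rename to the column names assumed to be expected by torchkge's KnowledgeGraph.
df = df.rename(columns={"head": "from", "relation": "rel", "tail": "to"})

print(df)
```

With a real file, the `io.StringIO(...)` argument would simply be replaced by the path to the TSV.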

In [None]:
### DIY ###

From this dataframe, we can instantiate a knowledge graph using [the `KnowledgeGraph` class from the TorchKGE library](https://torchkge.readthedocs.io/en/latest/reference/data.html).


In [None]:
### DIY ###

Explore the `KnowledgeGraph` class. What do `ent2ix` and `rel2ix` contain? How many entities and relations does the KG contain? Which methods does the `KnowledgeGraph` class provide?

**Answers here: ..............**

Split the KG into a training set, a validation set and a test set. 

In [None]:
### DIY ###

What are the training, validation and test sets in KGE? Why do we need such sets?

**Answers here: ..............**

## Knowledge Graph Embedding

#### Instantiating the KGE model

Choose a [model implemented in TorchKGE](https://torchkge.readthedocs.io/en/latest/reference/models.html). What hyperparameters do you need to define? Instantiate the chosen model with its corresponding parameters.

In [None]:
### DIY ###

Define the [loss function to train the model](https://torchkge.readthedocs.io/en/latest/reference/utils.html#losses).
For translational models, the loss should be a `MarginLoss`. For bilinear models, the loss should be a `BinaryCrossEntropyLoss`. Depending on the chosen loss, do you need to define new hyperparameters? Which ones?

In [None]:
### DIY ###

#### Defining an optimizer

What is an Optimizer?

**Answers here: ..............**

Check the documentation of the [Adam optimizer from PyTorch](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) and instantiate one. Which hyperparameters do you need to define?

In [None]:
### DIY ###

#### Defining a Negative Sampler

What is negative sampling in KGE? Check [TorchKGE's documentation on negative sampling](https://torchkge.readthedocs.io/en/latest/reference/sampling.html) and the various implementations available. What is the difference between those samplers? Choose one and instantiate it.
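To make the idea concrete, here is a purely conceptual sketch of uniform corruption written in plain Python. It is not TorchKGE's implementation: the function name `corrupt_triples`, the 50/50 choice between head and tail, and the toy triples are assumptions made for illustration.

```python
import random


def corrupt_triples(triples, n_entities, seed=0):
    """Return one negative triple per positive triple.

    For each (head, relation, tail), either the head or the tail is
    replaced by an entity index drawn uniformly at random.
    """
    rng = random.Random(seed)
    negatives = []
    for h, r, t in triples:
        new_entity = rng.randrange(n_entities)
        if rng.random() < 0.5:
            negatives.append((new_entity, r, t))  # corrupt the head
        else:
            negatives.append((h, r, new_entity))  # corrupt the tail
    return negatives


positives = [(0, 0, 1), (2, 1, 1), (0, 2, 2)]  # toy (head, relation, tail) indices
negatives = corrupt_triples(positives, n_entities=5)
for pos, neg in zip(positives, negatives):
    print(pos, "->", neg)
```

Note that a triple corrupted this way can happen to be a true fact, or even the original triple itself. How a sampler deals with such cases is one of the points worth comparing in the documentation.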

**Answers here: ..............**

In [None]:
### DIY ###

#### Optional: Defining a Learning Rate Scheduler

A learning rate scheduler is a tool that adjusts the learning rate during training to improve model convergence and performance. It typically decreases the learning rate over time or based on certain conditions, helping the model settle into an optimal solution. 

Many types of schedulers are available to suit different training needs. Here, we focus on [`lr_scheduler.CosineAnnealingWarmRestarts`](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html), which gradually decreases the learning rate following a cosine curve, then periodically "restarts" it at a higher rate to allow the model to explore new solutions and avoid local minima.
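The following sketch shows the shape of this schedule on a dummy parameter, so that it runs without a KGE model. The values `lr=0.01`, `T_0=5` and `eta_min=1e-5` are arbitrary choices for illustration, not recommended settings.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# A single dummy parameter stands in for the model parameters.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = Adam(params, lr=0.01)

# T_0 is the number of epochs before the first restart (illustrative value).
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=1e-5)

lrs = []
for epoch in range(10):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # a real loop would compute gradients first
    scheduler.step()   # advance the schedule by one epoch

print([round(lr, 5) for lr in lrs])
```

The recorded learning rates decrease over the first five epochs and return to the initial value at the sixth, which is the "warm restart".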

In [None]:
### DIY ###

#### Instantiating a Data Loader

Right now, we have defined:
* the KGE model
* an optimizer
* a negative sampler
* (optionally) a learning rate scheduler

Now, we need to define the training hyperparameters:
* number of training epochs
* training batch size

In [1]:
### DIY ###

TorchKGE provides a [`DataLoader`](https://github.com/torchkge-team/torchkge/blob/master/torchkge/utils/data.py) to read the data from the training set and pass it to the model. It takes as input:
* the training KG
* the batch size for loading the data in batches

Instantiate it.

In [None]:
### DIY ###

#### Defining the function for processing batches

Now that we have defined the DataLoader, we need to define the operations to perform on each batch. For this, we will use an [`Engine`](https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine) from PyTorch-Ignite.


In PyTorch-Ignite, an `Engine` is a core component used to abstract and manage the process of training or evaluating a model, making it easier to handle the loop over data batches and implement complex training workflows. The `Engine` runs a given `process_function`, which defines the specific steps to perform on each batch. 

With `Engine`, you can also attach event handlers at specific stages in the loop, like at the beginning or end of an epoch or after processing each batch. This enables you to incorporate logging, metrics, scheduling, and other functionalities in a structured way, making it highly flexible for building custom training and evaluation pipelines.

First, let's define the function for processing a batch.

In [None]:
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

def process_batch(engine, batch):

    # Unpack the batch into head, tail, and relation
    h, t, r = batch[0], batch[1], batch[2]
    
    # Generate corrupted (negative) samples using the sampler
    n_h, n_t = sampler.corrupt_batch(h, t, r)

    # Clear previous gradients in the optimizer
    optimizer.zero_grad()

    # Calculate the loss using positive and negative triplets
    pos, neg = model(h, t, r, n_h, n_t)
    loss = criterion(pos, neg)
    loss.backward()  # Perform backpropagation to compute gradients
    
    # Update model parameters with the optimizer
    optimizer.step()

    # Normalize model parameters 
    model.normalize_parameters()

    # Return the loss value for this batch
    return loss.item()

# Attaching the process_batch function to the training Engine
trainer = Engine(process_batch)

# Attach a running average of the loss to the trainer for tracking average loss over time
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss_ra')


The `process_batch` function defines a single training step, in which a batch of data is used to train the knowledge graph embedding model. First, it takes a batch containing head (`h`), tail (`t`), and relation (`r`) indices and generates corrupted (negative) examples (`n_h`, `n_t`) using the `corrupt_batch` method of our negative sampler. It then clears the gradients in the optimizer (`optimizer.zero_grad()`), calculates the loss using both positive and negative triplets (`pos` and `neg`) with the specified criterion, and performs backpropagation (`loss.backward()`) to compute gradients. Afterward, the optimizer updates the model parameters (`optimizer.step()`), and the model parameters are normalized (`model.normalize_parameters()`). Finally, the function returns the loss for this batch.

The line `RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss_ra')` serves to calculate a smoothed, running average of the loss over batches during training. This running average is attached to the `trainer` engine with name `loss_ra`. This running average will be accessible after each epoch with `trainer.state.metrics['loss_ra']`.
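To give an intuition of what such a smoothed average does, here is a conceptual sketch of exponential smoothing in plain Python. It is not Ignite's code: the function name, the smoothing factor `alpha=0.9` and the loss values are made up for illustration.

```python
def running_average(values, alpha=0.98):
    """Exponentially smoothed average of a sequence.

    The first value initialises the average; afterwards
    avg = alpha * avg + (1 - alpha) * value.
    """
    avg = None
    history = []
    for value in values:
        avg = value if avg is None else alpha * avg + (1.0 - alpha) * value
        history.append(avg)
    return history


batch_losses = [1.0, 0.5, 2.0, 0.4, 0.3]  # made-up per-batch losses
smoothed = running_average(batch_losses, alpha=0.9)
print(smoothed)
```

A value of `alpha` close to 1 gives a smoother curve that reacts slowly to individual noisy batches, such as the spike to 2.0 in the toy values.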


Look at [TorchKGE's documentation for models](https://torchkge.readthedocs.io/en/latest/reference/models.html), and explain why we run `pos, neg = model(h, t, r, n_h, n_t)` in the process_batch function.

**Answers here: ..............**

Look at [TorchKGE's documentation for losses](https://torchkge.readthedocs.io/en/latest/reference/utils.html#losses), and explain why we run `loss = criterion(pos, neg)` in the `process_batch` function and what score is obtained at this step.

**Answers here: ..............**

### Defining event handlers

We have defined the `trainer` engine to process data batches and compute the loss. Now, we can attach **events** to our `trainer` engine. [Various events can be included at various steps of the training](https://pytorch.org/ignite/generated/ignite.engine.events.Events.html), depending on the user's needs. 

Event handlers are defined as follows:

```
@trainer.on(<Event-type>)
def func(engine):
    # DO STUFF
    
```


The first line, `@trainer.on(<Event-type>)`, is a **decorator** in PyTorch-Ignite that registers the function below it as an event handler, to be executed whenever the given event occurs during training. Here is how it works:

1. `@trainer.on`: This decorator attaches the function defined immediately below it to the `trainer` engine.

2. [`<Event-type>`](https://pytorch.org/ignite/generated/ignite.engine.events.Events.html): This is the specific event, such as the end of an epoch (i.e., one complete pass through the dataset), that will trigger the function.

When applied, any function defined directly below `@trainer.on(<Event-type>)` will be executed each time `<Event-type>` occurs. This setup allows for performing actions like logging, validation, saving model checkpoints, or adjusting learning rates during training.

Create an event handler, triggered at the end of each epoch, that writes the training metrics to a CSV file. Each row should include the epoch number and the loss value.

In [None]:
### DIY ###

**(Optional):** If you have defined a CosineAnnealing learning rate scheduler, you need to define an event handler that updates the scheduler after each epoch is complete. To update a scheduler, you can use `scheduler.step()`.  Additionally, you can modify the handler defined above in order to also record the learning rate used at each epoch using `optimizer.param_groups[0]['lr']`.

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def update_scheduler(engine):
    scheduler.step()

### Run the training!

Let's sum up:
* we have defined a KGE model, an optimizer and created the Engine to process the data and train the model
* we have defined a data loader to send the data to the model
* we have defined an event to keep track of the loss value through the training process

We are ready to start training!

In [None]:
trainer.run(train_iterator, max_epochs=n_epochs)

### Taking a look at the training metrics

Load the training metrics file and plot the training loss evolution across training epochs. What do you think? 

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def plot_loss(df):
    plt.figure(figsize=(12, 5))

    # Column 0 holds the epoch number and column 1 the loss (the file has no header row)
    plt.plot(df[0], df[1], label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Training Loss')
    plt.title('Training Loss over Epochs')
    plt.show()

df = pd.read_csv('training_metrics.csv', header=None)
plot_loss(df)

**Answers here: ..............**

### Evaluating the model on Link Prediction

TorchKGE provides [a class for Link Prediction evaluation](https://torchkge.readthedocs.io/en/latest/tutorials/evaluation.html). Use it to evaluate how well your trained model predicts relationships from the test set, and interpret the results. What are the MRR and Hits@10?

Note: you need to put your model in evaluation mode with `model.eval()` before running the evaluation.
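To help interpret these metrics, the sketch below computes them by hand from a list of ranks, where each rank is the position of the true entity among all candidates sorted by score (1 is best). The function names and the example ranks are hypothetical; in practice the TorchKGE evaluator computes the ranks for you.

```python
def mean_reciprocal_rank(ranks):
    """Average of 1/rank over all test triples (1.0 is a perfect score)."""
    return sum(1.0 / r for r in ranks) / len(ranks)


def hits_at_k(ranks, k=10):
    """Fraction of test triples whose true entity is ranked in the top k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)


ranks = [1, 2, 10, 50]  # hypothetical ranks of the true entity
print(mean_reciprocal_rank(ranks), hits_at_k(ranks, k=10))
```

Both metrics lie between 0 and 1, and higher is better. The MRR is dominated by the triples ranked near the top, since a rank of 50 contributes only 0.02.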

In [None]:
### DIY ###

**Answers here: ..............**

### Inferring new interactions in the KG

Now that we have trained our embedding model and have an idea of its performance in predicting interactions from the test set, we can start using it to infer novel associations. 

TorchKGE provides a [class for inferring a missing entity given a head (or a tail) and a relation](https://torchkge.readthedocs.io/en/latest/reference/inference.html#entity-inference).

However, the original implementation from TorchKGE has an issue in its `evaluate` function. Use the following implementation instead:

In [None]:
from torch import empty, tensor
from tqdm import tqdm  
from torchkge.utils import filter_scores
from torchkge.utils.data import get_n_batches
import torch

class DataLoader_:
    """This class is inspired by :class:`torch.utils.data.DataLoader`.
    It is however way simpler.

    """
    def __init__(self, a, b, batch_size, use_cuda=None):
        """

        Parameters
        ----------
        batch_size: int
            Size of the required batches.
        use_cuda: str (opt, default = None)
            Can be either None (no use of cuda at all), 'all' to move all the
            dataset to cuda and then split in batches or 'batch' to simply move
            the batches to cuda before they are returned.
        """
        self.a = a
        self.b = b

        self.use_cuda = use_cuda
        self.batch_size = batch_size

        if use_cuda is not None and use_cuda == 'all':
            self.a = self.a.cuda()
            self.b = self.b.cuda()

    def __len__(self):
        return get_n_batches(len(self.a), self.batch_size)

    def __iter__(self):
        return _DataLoaderIter(self)


class _DataLoaderIter:
    def __init__(self, loader):
        self.a = loader.a
        self.b = loader.b

        self.use_cuda = loader.use_cuda
        self.batch_size = loader.batch_size

        self.n_batches = get_n_batches(len(self.a), self.batch_size)
        self.current_batch = 0

    def __next__(self):
        if self.current_batch == self.n_batches:
            raise StopIteration
        else:
            i = self.current_batch
            self.current_batch += 1

            tmp_a = self.a[i * self.batch_size: (i + 1) * self.batch_size]
            tmp_b = self.b[i * self.batch_size: (i + 1) * self.batch_size]

            if self.use_cuda is not None and self.use_cuda == 'batch':
                return tmp_a.cuda(), tmp_b.cuda()
            else:
                return tmp_a, tmp_b

    def __iter__(self):
        return self


class EntityInference(object):
    """Use trained embedding model to infer missing entities in triples.

    Parameters
    ----------
    model: torchkge.models.interfaces.Model
        Embedding model inheriting from the right interface.
    known_entities: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
        List of the indices of known entities.
    known_relations: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
        List of the indices of known relations.
    top_k: int
        Indicates the number of top predictions to return.
    missing: str
        String indicating if the missing entities are the heads or the tails.
    dictionary: dict, optional (default=None)
        Dictionary of possible heads or tails (depending on the value of `missing`).
        It is used to filter predictions that are known to be True in the training set
        in order to return only new facts.

    Attributes
    ----------
    predictions: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.long`
        List of the indices of predicted entities for each test fact.
    scores: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.float`
        List of the scores of resulting triples for each test fact.
    """

    def __init__(self, model, known_entities, known_relations, top_k=1, missing='tails', dictionary=None):
        assert missing in ['heads', 'tails'], "missing entity should either be 'heads' or 'tails'"
        
        self.model = model
        self.known_entities = known_entities
        self.known_relations = known_relations
        self.missing = missing
        self.top_k = top_k
        self.dictionary = dictionary

        self.predictions = empty(size=(len(known_entities), top_k), dtype=torch.long)
        self.scores = empty(size=(len(known_entities), top_k), dtype=torch.float)

    def evaluate(self, b_size, verbose=True):
        use_cuda = next(self.model.parameters()).is_cuda

        if use_cuda:
            dataloader = DataLoader_(self.known_entities, self.known_relations, batch_size=b_size, use_cuda='batch')
            self.predictions = self.predictions.cuda()
            self.scores = self.scores.cuda()
        else:
            dataloader = DataLoader_(self.known_entities, self.known_relations, batch_size=b_size)

        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
                             unit='batch', disable=(not verbose),
                             desc='Inference'):
            known_ents, known_rels = batch[0], batch[1]
            
            if self.missing == 'heads':
                _, t_emb, rel_emb, candidates = self.model.inference_prepare_candidates(
                    tensor([]).long(), known_ents, known_rels, entities=True
                )
                scores = self.model.inference_scoring_function(candidates, t_emb, rel_emb)
            else:
                h_emb, _, rel_emb, candidates = self.model.inference_prepare_candidates(
                    known_ents, tensor([]).long(), known_rels, entities=True
                )
                scores = self.model.inference_scoring_function(h_emb, candidates, rel_emb)

            if self.dictionary is not None:
                scores = filter_scores(scores, self.dictionary, known_ents, known_rels, None)

            scores, indices = scores.sort(descending=True)

            start_index = i * b_size
            end_index = min((i + 1) * b_size, self.predictions.size(0))
            
            self.predictions[start_index:end_index, :self.top_k] = indices[:end_index - start_index, :self.top_k]
            self.scores[start_index:end_index, :self.top_k] = scores[:end_index - start_index, :self.top_k]

        if use_cuda:
            self.predictions = self.predictions.cpu()
            self.scores = self.scores.cpu()


Define one or several incomplete triplets for which you want to predict the missing head or tail. Instantiate the corresponding `EntityInference` class and run the inference process with `.evaluate()`.

In [None]:
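import csv
import os

import torch
from torch.optim import Adam
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from torchkge.models import TransEModel
from torchkge.sampling import UniformNegativeSampler
from torchkge.utils import MarginLoss, DataLoader

# The hyperparameters (emb_dim, margin, learning_rate, ...), `scheduler` and the
# `link_pred` helper (expected to return the validation MRR) are assumed to be
# defined in earlier cells. If you use a scheduler, re-create it after the
# optimizer below so that it is bound to the new optimizer.
model = TransEModel(emb_dim, n_entities, n_relations, dissimilarity_type=dissimilarity_type)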
criterion = MarginLoss(margin)
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
sampler = UniformNegativeSampler(kg_train, kg_val=kg_val, kg_test=kg_test, n_neg=5)
train_iterator = DataLoader(kg_train, batch_size)
trainer = Engine(process_batch)
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss_ra')

if os.path.exists('training_metrics_with_validation.csv'):
    os.remove('training_metrics_with_validation.csv')

@trainer.on(Events.EPOCH_COMPLETED(every=10))
def evaluate(engine):
    model.eval()  # Put the model in evaluation mode
    with torch.no_grad():
        val_mrr = link_pred(model, kg_val, 32) 
    engine.state.metrics['val_mrr'] = val_mrr 

    model.train()  # Put the model back in training mode

@trainer.on(Events.EPOCH_COMPLETED)
def log_metrics_to_csv(engine):
    epoch = engine.state.epoch
    train_loss = engine.state.metrics['loss_ra']
    if 'val_mrr' in engine.state.metrics.keys():
        mrr = engine.state.metrics['val_mrr']
    else:
        mrr = 0
    lr = optimizer.param_groups[0]['lr']
    with open('training_metrics_with_validation.csv', mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch, train_loss, mrr, lr])
    print(f"Epoch: {epoch} ; Loss: {train_loss} ; Val MRR: {mrr} ; LR: {lr}")
    
@trainer.on(Events.EPOCH_COMPLETED)
def update_scheduler(engine):
    scheduler.step()
    
trainer.run(train_iterator, max_epochs=n_epochs)

Print the results. What are the heads or tails predicted for your incomplete triplets? How are the scores computed? Are the scores convincing? The higher the better or the lower the better? 

**Answers here: ..............**

## Training a KGE model with a validation set for Link Prediction

For now, the training procedure we implemented does not make use of the validation set. In the following part, we will include a procedure to evaluate the link prediction on the validation set during the training. **The goal will be to evaluate the training not on the loss value, but on the prediction of interactions from the validation set.**

Create an event handler that evaluates the link prediction on the validation set at the end of each epoch. Adapt the training metric tracking function to include the MRR on the validation set at each epoch, and rerun the whole training procedure. You can store the MRR score in `engine.state.metrics['val_mrr']` to access it in the metrics logging function. Don't forget to switch the model from training mode to evaluation mode with `model.eval()` and from evaluation mode back to training mode with `model.train()` when necessary. 

In [None]:
### DIY ###

Notice how evaluation is computationally intensive, making the whole training quite long...

Modify your code so that the evaluation is done only every 10 epochs instead.

In [None]:
### DIY ###

Why is it better to use link prediction evaluation on the validation set instead of tracking the loss only?

**Answers here: ..............**

Plot the training loss and the MRR over epochs and analyse the results.

In [None]:
### DIY ###

During training, you may notice the MRR decreasing after a certain number of epochs. If this happens, you could end up with a model that, by the last epoch, performs worse on the validation MRR compared to its performance a few epochs earlier. To ensure that you keep the best model in terms of MRR on the validation set, it’s essential to save the model whenever it reaches its highest score. You can achieve this using [`ModelCheckpoint` from `ignite.handlers`](https://pytorch.org/ignite/generated/ignite.handlers.checkpoint.ModelCheckpoint.html).

Here is how it works:

```
checkpoint_best_handler = ModelCheckpoint(
    dirname=checkpoint_dir,                 # Saving directory
    filename_prefix='best_model',           # Filename prefix
    n_saved=N,                              # Number of models to save (to keep the N best models)
    score_function=function_to_call,        # Function that returns the score of the model
    score_name='val_mrr',                   # Name of the score
    require_empty=False,                    # Should the saving directory be empty?
    create_dir=True,                        # Create the saving directory if it does not exist
    atomic=True                             # Ensure the file is not damaged
)
```

A checkpoint handler can be added to the training procedure with: 

```
trainer.add_event_handler(
    Events.EPOCH_COMPLETED,
    checkpoint_best_handler,
    {'model': model}
)
```

The `score_function` to pass to the checkpoint handler is a function that should return the current value of the performance metric you want to track. For instance, if you are interested in saving the model with the best MRR, the `score_function` should be a function that takes an engine as input and returns the current MRR as output. 

Define the score function, create the checkpoint handler, and add the corresponding event handler to the training procedure. Then rerun the training. Since we are only performing the validation evaluation every 10 epochs, how often should the checkpoint handler be launched?
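The expected shape of such a function can be sketched as follows. The name `val_mrr_score` is hypothetical, and the engine is replaced by a small stub so that the snippet runs on its own; in the real training loop, the `trainer` engine is passed in by Ignite.

```python
from types import SimpleNamespace


def val_mrr_score(engine):
    # ModelCheckpoint keeps the checkpoints with the highest scores,
    # so a metric where higher is better can be returned as-is.
    return engine.state.metrics["val_mrr"]


# Stub mimicking the `engine.state.metrics` layout, for illustration only.
fake_engine = SimpleNamespace(state=SimpleNamespace(metrics={"val_mrr": 0.42}))
print(val_mrr_score(fake_engine))
```

For a metric where lower is better, such as a loss, the function would return the negated value instead.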

In [None]:
### DIY ###

Now, we can load the best model to evaluate it on the test dataset and to infer new triplets in the KG! Checkpoints can be loaded with `torch.load(file)`. Once your checkpoint is loaded, you can use `model.load_state_dict(checkpoint)` to load the state of your best model. 
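The save and reload round trip can be sketched on a tiny stand-in module. This assumes the checkpoint file holds the model's state dict directly, as the `model.load_state_dict(checkpoint)` call above implies. The file name `best_model.pt` and the `torch.nn.Linear` module are placeholders for your actual checkpoint and KGE model.

```python
import os
import tempfile

import torch

# A tiny module stands in for the trained KGE model.
saved_model = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "best_model.pt")  # hypothetical file name
    torch.save(saved_model.state_dict(), path)

    # map_location="cpu" lets a checkpoint saved on GPU be read on a CPU-only machine.
    checkpoint = torch.load(path, map_location="cpu")

# The receiving model must be built with the same architecture and sizes.
restored_model = torch.nn.Linear(4, 2)
restored_model.load_state_dict(checkpoint)
restored_model.eval()  # switch to evaluation mode before evaluating or inferring

print(torch.equal(saved_model.weight, restored_model.weight))
```

With the KGE model, this means instantiating it with the same embedding dimension and the same numbers of entities and relations as during training before calling `load_state_dict`.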

In [None]:
### DIY ###

Evaluate your best model on the test set.

In [None]:
### DIY ###

Infer new triples with your trained model.

In [None]:
### DIY ###

## To go further: a few ideas...

Training can take a long time. Take a look at [`EarlyStopping` in ignite](https://pytorch.org/ignite/v0.5.1/generated/ignite.handlers.early_stopping.EarlyStopping.html) and try to add an early stopping mechanism to your model.

Implement a function to evaluate the link prediction on each relationship type.

Extract the node embeddings and visualize the latent space.

Train a model on triplet classification instead of link prediction.

Experiment with new models, and perform a benchmark of the best embedding models.