# Train a Retriever in a Distant Supervision setting

**Why Distant Supervision**

A retriever must learn to find the correct path from the question entities to the answer entities. The traditional way to collect training data is therefore to manually annotate a reasoning path in the knowledge graph. This kind of annotation is much more complex than common annotation tasks such as classification, and it is also expensive and time-consuming.

An alternative way to generate training data is **distant supervision**. In this setting, given the question entities and answer entities, we use the shortest path between them as the reasoning path. This is a reasonable heuristic because the shortest path between two entities is often the one most relevant to the question. It is much cheaper than manual annotation and scales easily to large datasets.
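To make the idea concrete, here is a minimal sketch of shortest-path labeling on a toy in-memory graph using breadth-first search. The graph, the entity names, and the `shortest_relation_paths` helper are hypothetical; `srtk` searches paths by querying a SPARQL endpoint instead, and its implementation may differ.

```python
# Hypothetical toy knowledge graph: head entity -> list of (relation, tail entity).
# Plain names stand in for Wikidata IDs to keep the example readable.
TOY_KG = {
    "Barack Obama": [
        ("educated at", "Columbia University"),
        ("educated at", "Harvard Law School"),
        ("spouse", "Michelle Obama"),
    ],
    "Michelle Obama": [("educated at", "Princeton University")],
}


def shortest_relation_paths(kg, source, target):
    """Return every shortest relation path from `source` to `target` (BFS).

    The search expands the graph one hop at a time and stops at the first
    hop count at which `target` is reached, so all returned paths have the
    same, minimal length. An empty list means the entities are not connected.
    """
    frontier = [(source, ())]
    visited = {source}
    while frontier:
        found = set()
        next_frontier = []
        for entity, path in frontier:
            for relation, tail in kg.get(entity, []):
                new_path = path + (relation,)
                if tail == target:
                    found.add(new_path)
                elif tail not in visited:
                    next_frontier.append((tail, new_path))
        if found:
            return sorted(found)
        # Entities reached at this hop are never revisited at a longer hop.
        visited.update(tail for tail, _ in next_frontier)
        frontier = next_frontier
    return []


# Distant supervision: the question and answer entities are known, the path is not.
print(shortest_relation_paths(TOY_KG, "Barack Obama", "Harvard Law School"))
# [('educated at',)]
print(shortest_relation_paths(TOY_KG, "Barack Obama", "Princeton University"))
# [('spouse', 'educated at')]
```

The recovered relation chain (for example `spouse` then `educated at`) is what then serves as the reasoning-path label for training.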

**Data**

- Knowledge graph: [Wikidata](https://www.wikidata.org)
- Dataset: [Mintaka](https://huggingface.co/datasets/AmazonScience/mintaka)
- Note: In Mintaka, question entities and answer entities are annotated, but the reasoning path between them is not.

## Step 0. Preparation

Install & import dependencies

In [None]:
!pip install srtk datasets

In [None]:
import os
from pathlib import Path
from pprint import pprint

import srsly
from datasets import load_dataset
from tqdm import tqdm

Run a Wikidata SPARQL endpoint. Please refer to [Setup Wikidata](https://srtk.readthedocs.io/en/latest/setups/setup_wikidata.html) for setup instructions. We assume that:

- the SPARQL endpoint service is deployed at `http://localhost:1234/api/endpoint/sparql`.

Define paths and other constant variables.

In [None]:
data_root = Path('data/mintaka/')
data_root.mkdir(parents=True, exist_ok=True)
converted_dataset_path = data_root / 'dataset.jsonl'

## Step 1. Prepare Training Data

In this step, we load the Mintaka dataset and convert it to the required format. Each training sample should have the following format:

```json
{
  "id": "sample-id",
  "question": "Which universities did Barack Obama graduate from?",
  "question_entities": [
    "Q76"
  ],
  "answer_entities": [
    "Q49122",
    "Q1346110",
    "Q4569677"
  ]
}
```
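Before converting the data, it can help to pin down what a valid sample looks like. The sketch below is a hypothetical checker, not part of `srtk`: it assumes the four keys shown above are required and that every entity is a Wikidata item ID (a `Q` followed by a positive integer).

```python
import re

# Wikidata item IDs are "Q" followed by a positive integer, e.g. "Q76".
QID_PATTERN = re.compile(r"Q[1-9]\d*")
REQUIRED_KEYS = ("id", "question", "question_entities", "answer_entities")


def validate_sample(sample):
    """Return a list of problems found in one training sample (empty if valid)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in sample]
    if problems:
        return problems
    if not isinstance(sample["question"], str) or not sample["question"].strip():
        problems.append("question must be a non-empty string")
    for key in ("question_entities", "answer_entities"):
        entities = sample[key]
        if not isinstance(entities, list) or not entities:
            problems.append(f"{key} must be a non-empty list")
            continue
        for entity in entities:
            if not isinstance(entity, str) or not QID_PATTERN.fullmatch(entity):
                problems.append(f"{key} contains a non-QID value: {entity!r}")
    return problems


sample = {
    "id": "sample-id",
    "question": "Which universities did Barack Obama graduate from?",
    "question_entities": ["Q76"],
    "answer_entities": ["Q49122", "Q1346110", "Q4569677"],
}
print(validate_sample(sample))  # []
```

A non-entity annotation such as Mintaka's ordinal `"7"` would be reported as a non-QID value, which is the kind of sample the filtering step below drops.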



### 1.1 Load Mintaka Dataset

In [None]:
# Load the dataset from huggingface datasets
mintaka = load_dataset('AmazonScience/mintaka', split='train')
# Show the metadata of the dataset
print(mintaka)
# Examine a sample
mintaka[0]

No config specified, defaulting to: mintaka/en
Found cached dataset mintaka (/home/wiss/liao/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d)


Dataset({
    features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
    num_rows: 14000
})


{'id': 'a9011ddf',
 'lang': 'en',
 'question': 'What is the seventh tallest mountain in North America?',
 'answerText': 'Mount Lucania',
 'category': 'geography',
 'complexityType': 'ordinal',
 'questionEntity': [{'name': 'Q49',
   'entityType': 'entity',
   'label': 'North America',
   'mention': 'North America',
   'span': [40, 53]},
  {'name': '7',
   'entityType': 'ordinal',
   'label': '',
   'mention': 'seventh',
   'span': [12, 19]}],
 'answerEntity': [{'name': 'Q1153188', 'label': 'Mount Lucania'}]}

### 1.2 Convert and Filter Data

Not every Mintaka question has its question entities and answer entities annotated as Wikidata entities. For instance, the ordinal `seventh` in the sample above has `entityType` `ordinal` rather than a Wikidata ID. We therefore keep only question entities of type `entity`, and exclude any sample left without question entities or without answer entities.

In [None]:
skipped = 0
processed_samples = []
for sample in tqdm(mintaka, desc='Preparing samples'):
    question_entities = [e['name'] for e in sample['questionEntity'] if e['entityType']=='entity']
    answer_entities = [e['name'] for e in sample['answerEntity']]
    if len(question_entities) == 0 or len(answer_entities) == 0:
        skipped += 1
        continue
    processed_sample = {
        'id': str(sample['id']),
        'question': sample['question'],
        'question_entities': question_entities,
        'answer_entities': answer_entities,
    }
    processed_samples.append(processed_sample)

srsly.write_jsonl(converted_dataset_path, processed_samples)
print(f'Processed {len(processed_samples)} samples, skipped {skipped} samples, total {len(mintaka)} samples')
print(f'Output saved to {converted_dataset_path}')

Preparing samples: 100%|██████████| 14000/14000 [00:01<00:00, 8354.38it/s]

Processed 9880 samples, skipped 4120 samples, total 14000 samples
Output saved to data/mintaka/dataset.jsonl





## Step 2. Preprocess the Training Data


The `srtk preprocess` command streamlines preprocessing of the training data. It performs the following operations:

1. It searches for the shortest paths between `question_entities` and `answer_entities` in the knowledge graph. These paths consist of a chain of relations.
2. Each path is then scored by the Jaccard score between the answer entities and the entities reached by following the path from the question entities.
3. Negative sampling of relations is then performed.
4. Finally, training samples are generated. Each sample consists of:
    - a question plus previous relations
    - the next positive relation
    - k negative relations (where k defaults to 15).

As a result, three files are generated in the output directory:

- `paths.jsonl`: contains the shortest paths between question entities and answer entities.
- `scores.jsonl`: contains the scores of the paths.
- `train.jsonl`: contains the training samples, in which negative samples are also included.
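The scoring in step 2 can be illustrated with a small sketch: follow a relation path from the question entities, then compare the reached entities with the answer entities using the Jaccard index |A ∩ B| / |A ∪ B|. The toy graph and the helper names are hypothetical, and `srtk` may compute the score differently (for example, per question entity).

```python
# Hypothetical toy graph: head entity -> list of (relation, tail entity).
TOY_KG = {
    "Barack Obama": [
        ("educated at", "Occidental College"),
        ("educated at", "Columbia University"),
        ("educated at", "Harvard Law School"),
        ("spouse", "Michelle Obama"),
    ],
    "Michelle Obama": [("educated at", "Princeton University")],
}


def follow_path(kg, start_entities, path):
    """Return the set of entities reached by following `path` from `start_entities`."""
    current = set(start_entities)
    for relation in path:
        current = {
            tail
            for entity in current
            for rel, tail in kg.get(entity, [])
            if rel == relation
        }
    return current


def jaccard(a, b):
    """Jaccard index |A & B| / |A | B| of two collections (0.0 if both are empty)."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def score_path(kg, question_entities, answer_entities, path):
    """Score a relation path by how well its end entities match the answers."""
    return jaccard(follow_path(kg, question_entities, path), answer_entities)


answers = ["Columbia University", "Harvard Law School"]
print(score_path(TOY_KG, ["Barack Obama"], answers, ("educated at",)))  # 2/3 ≈ 0.667
print(score_path(TOY_KG, ["Barack Obama"], answers, ("spouse", "educated at")))  # 0.0
```

Here `educated at` also reaches an entity that is not an answer, so it scores below 1.0, while a path that reaches none of the answers scores 0.0.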


For more information on the preprocessing options, use the command `srtk preprocess --help`. Additional details about the preprocessing pipeline can be found in the [Preprocessing API documentation](https://srtk.readthedocs.io/en/latest/cli.html#srtk-preprocess).
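To clarify what one line of `train.jsonl` represents, the sketch below turns a single relation path into records following step 4: one record per hop, whose query is the question followed by `[SEP]` and the relations already traversed, with the next relation as the positive and k sampled negatives. `build_training_records` and its inputs are hypothetical. How `srtk` joins previous relations on later hops is an assumption here (space-separated), and sampling negatives with replacement is a guess, consistent with the repeated negatives in the record inspected in step 2.2.

```python
import random

SEP = "[SEP]"


def build_training_records(question, path, candidates_per_hop, k=15, seed=0):
    """Turn one relation path into one training record per hop (hypothetical).

    `candidates_per_hop[i]` lists the relations that could be taken at hop i
    (for example, the outgoing relations of the entities reached so far).
    Negatives are drawn from those candidates, excluding the positive
    relation, with replacement, so they may repeat.
    """
    rng = random.Random(seed)
    records = []
    for hop, positive in enumerate(path):
        # Assumption: previously traversed relations follow [SEP], separated
        # by spaces. For the first hop this part is empty.
        previous = " ".join(path[:hop])
        query = f"{question} {SEP} {previous}"
        pool = [r for r in candidates_per_hop[hop] if r != positive]
        if not pool:
            continue  # no negatives available for this hop
        negatives = [rng.choice(pool) for _ in range(k)]
        records.append({"query": query, "positive": positive, "negatives": negatives})
    return records


records = build_training_records(
    "Which actor starred in Vanilla Sky and was married to Katie Holmes?",
    ("spouse",),
    [["spouse", "cast member", "child", "given name", "screenwriter"]],
)
print(records[0]["query"])
print(records[0]["positive"], len(records[0]["negatives"]))  # spouse 15
```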

### 2.1 Preprocess and Create the Training Data

In [None]:
!srtk preprocess --input $converted_dataset_path \
    --output-dir $data_root \
    --sparql-endpoint http://localhost:1234/api/endpoint/sparql \
    --knowledge-graph wikidata

Searching paths: 100%|██████████████████████| 9880/9880 [06:30<00:00, 25.30it/s]
Processed 5711 samples; skipped 4169 samples without any paths between question entities and answer entities; total 9880 samples
Retrieved paths saved to data/mintaka/paths.jsonl
Scoring paths: 100%|███████████████████████████████████████| 5711/5711 [01:16<00:00, 74.87it/s]
Scored paths saved to data/mintaka/scores.jsonl
Negative sampling: 100%|████████████████████| 5711/5711 [03:24<00:00, 27.89it/s]
Number of training records: 25107
Converting relation ids to labels: 100%|█| 25107/25107 [00:39<00:00, 628.79it/s]
Training samples are saved to data/mintaka/train.jsonl


### 2.2 Inspect the Created Training Data

In [None]:
!head -n 1 data/mintaka/train.jsonl | jq

{
  "query": "Which actor starred in Vanilla Sky and was married to Katie Holmes? [SEP] ",
  "positive": "spouse",
  "negatives": [
    "cast member",
    "composer",
    "child",
    "eye color",
    "filming location",
    "given name",
    "instance of",
    "different from",
    "child",
    "given name",
    "original language of film or TV show",
    "described by source",
    "distributed by",
    "distributed by",
    "screenwriter"
  ]
}


## Step 3. Train the retriever

With `srtk train`, you can train a retriever (i.e., a relation scorer) with a single command. The important arguments include:

- `--input`: the path to the training data generated in the previous step.
- `--model-name-or-path`: the pretrained model to start from. It can be any HuggingFace model identifier or a local path to a model.
- `--accelerator`: the accelerator to use: `cpu`, `gpu`, or `tpu`.
- `--output-dir`: the directory where the trained model is saved. The model is saved in HuggingFace model format, so it can be uploaded to the HuggingFace Hub and shared with the community.

Common training arguments such as `max_epochs` and `batch_size` can also be passed to the command.

Internally, the model is wrapped in a PyTorch Lightning Module and trained with a [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) trainer. A [Wandb](https://wandb.ai/) logger records the training progress and metrics.

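For more information on the training options, use the command `srtk train --help`. Additional details about the training pipeline can be found in the [Training API documentation](https://srtk.readthedocs.io/en/latest/cli.html#srtk-train). Note that the training command below passes `--fast-dev-run`; assuming it maps to PyTorch Lightning's `fast_dev_run` option, it runs only a single batch as a quick smoke test, so remove it for a full training run.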
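The notebook does not show which objective `srtk train` optimizes. A common choice for training a relation scorer from (query, positive, negatives) records is a contrastive loss: score each candidate relation by the similarity of its embedding to the query embedding, then apply softmax cross-entropy with the positive relation as the target. The sketch below is our assumption, with plain Python lists standing in for encoder embeddings; the actual loss and similarity function in `srtk` may differ.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def contrastive_loss(query_vec, positive_vec, negative_vecs):
    """Cross-entropy over [positive] + negatives, with the positive as the target.

    Each candidate relation is scored by its dot product with the query
    embedding; the loss is -log(softmax(scores)[0]).
    """
    scores = [dot(query_vec, positive_vec)] + [dot(query_vec, n) for n in negative_vecs]
    max_score = max(scores)  # subtract the max before exp() for numerical stability
    log_sum_exp = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return log_sum_exp - scores[0]


# Toy 2-d "embeddings": the loss is lower when the query is closest to the positive.
query = [1.0, 0.0]
print(contrastive_loss(query, positive_vec=[1.0, 0.0], negative_vecs=[[0.0, 1.0], [-1.0, 0.0]]))  # lower loss
print(contrastive_loss(query, positive_vec=[0.0, 1.0], negative_vecs=[[1.0, 0.0], [-1.0, 0.0]]))  # higher loss
```

Minimizing this loss pushes the query embedding toward the positive relation and away from the sampled negatives, which is what lets the trained model rank relations at retrieval time.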

In [None]:
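# `!export` only sets the variable in a throwaway subshell, so set it in the
# kernel's environment instead; `!` commands run afterwards inherit it.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'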

In [None]:
!srtk train --input data/mintaka/train.jsonl \
    --output-dir artifacts/mintaka \
    --model-name-or-path smallbenchnlp/roberta-small \
    --accelerator gpu \
    --fast-dev-run

Some weights of the model checkpoint at smallbenchnlp/roberta-small were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at smallbenchnlp/roberta-small and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fo