# Train A Retriever on WebQSP with Weak Supervision

**Data**
- Knowledge Graph: Freebase
- Dataset: WebQSP

## Step 1. Prerequisites

### 1.1 Install & import dependencies

```bash
pip install srtk
```

In [None]:
import os
from pathlib import Path

import srsly
from pprint import pprint
from tqdm import tqdm

### 1.2 Define constants

In [None]:
formatted_data_dir = 'data/webqsp/formatted'
intermediate_dir = 'data/webqsp/intermediate'
dataset_dir = 'data/webqsp/dataset'
save_model_dir = 'artifacts/models/webqsp'
retrieve_subgraph_path = 'artifacts/subgraphs/webqsp.jsonl'

formatted_train_path = os.path.join(formatted_data_dir, 'train.jsonl')
formatted_test_path = os.path.join(formatted_data_dir, 'test.jsonl')
train_dataset_path = os.path.join(dataset_dir, 'train.jsonl')

### 1.3 Run SPARQL endpoint

The following steps assume that a Freebase SPARQL endpoint is running at `http://localhost:3001/sparql`.
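
Before running the later steps, it may be worth confirming that the endpoint actually answers queries. The sketch below uses only the standard library. It assumes the server reads the query from a `query` URL parameter and honors a `format=application/sparql-results+json` parameter, which is common for Virtuoso-style deployments but not guaranteed for every server. `build_query_url` and `endpoint_is_alive` are hypothetical helper names, not part of `srtk`.

```python
import json
import urllib.parse
import urllib.request

SPARQL_ENDPOINT = 'http://localhost:3001/sparql'  # endpoint assumed in this tutorial


def build_query_url(endpoint, query):
    """Build a GET URL for a SPARQL query (assumes `query` and `format` parameters)."""
    params = urllib.parse.urlencode({
        'query': query,
        'format': 'application/sparql-results+json',
    })
    return f'{endpoint}?{params}'


def endpoint_is_alive(endpoint=SPARQL_ENDPOINT, timeout=10):
    """Return True if a trivial ASK query gets a well-formed JSON answer."""
    url = build_query_url(endpoint, 'ASK { ?s ?p ?o }')
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            result = json.load(response)
    except (OSError, ValueError):
        # Connection errors, HTTP errors, timeouts, or a non-JSON body.
        return False
    return isinstance(result, dict) and 'boolean' in result
```

Calling `endpoint_is_alive()` in a notebook cell should return `True` when the server is up; a `False` result only means this particular probe failed, so the server logs are still the place to look for the cause.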

### 1.4 Download data

You can download the WebQSP dataset from the [Microsoft Download Center](https://www.microsoft.com/en-us/download/details.aspx?id=52763).

```bash
mkdir -p data/webqsp/raw
wget https://download.microsoft.com/download/F/5/0/F5012144-A4FB-4084-897F-CFDA99C60BDF/WebQSP.zip -P data/webqsp
unzip data/webqsp/WebQSP.zip -d data/webqsp/raw
```

### 1.5 Format the raw data

Each raw sample should be converted into a record of the following form:
```json
{
  "id": "sample-id",
  "question": "Which universities did Barack Obama graduate from?",
  "question_entities": [  ],
  "answer_entities": [  ]
}
```
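
A small validator can catch malformed records before they reach preprocessing. The sketch below is one possible check, not something `srtk` provides: `validate_sample` is a hypothetical helper, `id` is treated as optional, and the non-empty requirement on both entity lists mirrors the filter applied when the samples are formatted in this tutorial. The entity identifiers in the example are placeholders, not real Freebase MIDs.

```python
REQUIRED_LIST_FIELDS = ('question_entities', 'answer_entities')


def validate_sample(sample):
    """Return a list of problems found in one formatted sample (empty if none)."""
    problems = []
    question = sample.get('question')
    if not isinstance(question, str) or not question.strip():
        problems.append('question must be a non-empty string')
    for field in REQUIRED_LIST_FIELDS:
        value = sample.get(field)
        if not isinstance(value, list) or not value:
            problems.append(f'{field} must be a non-empty list')
        elif not all(isinstance(item, str) and item for item in value):
            problems.append(f'{field} must contain only non-empty strings')
    return problems


example = {
    'id': 'sample-id',
    'question': 'Which universities did Barack Obama graduate from?',
    'question_entities': ['m.placeholder_topic'],   # placeholder identifier
    'answer_entities': ['m.placeholder_answer'],    # placeholder identifier
}
print(validate_sample(example))
```

Returning a list of messages rather than raising on the first problem makes it easy to report every issue in a file in a single pass.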

In [None]:
raw_data_path = {
    'train': 'data/webqsp/raw/WebQSP/data/WebQSP.train.json',
    'test': 'data/webqsp/raw/WebQSP/data/WebQSP.test.json'
}

In [None]:
processed_samples = {}
for split, split_path in raw_data_path.items():
    raw_samples = srsly.read_json(split_path)['Questions']
    processed_split_samples = []
    for raw_sample in tqdm(raw_samples, desc=f'Processing {split} split'):
        # Collect entity answers across all parses; non-entity answers are skipped.
        answers = set()
        for parse in raw_sample['Parses']:
            for answer in parse['Answers']:
                if answer['AnswerType'] == 'Entity':
                    answers.add(answer['AnswerArgument'])
        sample = {
            'question': raw_sample['ProcessedQuestion'],
            # Deduplicate topic entities that are shared by several parses.
            'question_entities': list(set(parse['TopicEntityMid'] for parse in raw_sample['Parses'])),
            'answer_entities': list(answers)
        }
        # Keep only samples with at least one question entity and one answer entity.
        if len(sample['answer_entities']) > 0 and len(sample['question_entities']) > 0:
            processed_split_samples.append(sample)
    processed_samples[split] = processed_split_samples

Processing train split: 100%|██████████| 3098/3098 [00:00<00:00, 37617.62it/s]
Processing test split: 100%|██████████| 1639/1639 [00:00<00:00, 124835.92it/s]


In [None]:
Path(formatted_data_dir).mkdir(parents=True, exist_ok=True)
for split, split_samples in processed_samples.items():
    save_path = os.path.join(formatted_data_dir, f'{split}.jsonl')
    srsly.write_jsonl(save_path, split_samples)
    print(f'Formatted {split} samples are saved to {save_path}')

## Step 2. Preprocess

### 2.1 Preprocess and create the Training Data

In [None]:
!srtk preprocess --input data/webqsp/intermediate/scores.jsonl \
    --output $train_dataset_path \
    --intermediate-dir $intermediate_dir \
    --sparql-endpoint http://localhost:3001/sparql \
    --knowledge-graph freebase

Negative sampling: 100%|████████████████████| 2990/2990 [09:22<00:00,  5.32it/s]
Number of training records: 9802
Converting relation ids to labels: 100%|█| 9802/9802 [00:00<00:00, 148940.77it/s
Training samples are saved to data/webqsp/dataset/train.jsonl
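
The generated file can also be sanity-checked from Python. The sketch below assumes each line is a JSON record with `query`, `positive`, and `negatives` fields, as in the record inspected in section 2.2; `summarize_record` and `summarize_file` are hypothetical helpers. Because the sampled negatives may contain repeats, the summary reports both the total and the unique count.

```python
import json
from collections import Counter


def summarize_record(record):
    """Summarize one training record with `query`, `positive`, `negatives` fields."""
    negatives = record['negatives']
    counts = Counter(negatives)
    return {
        'query': record['query'],
        'positive': record['positive'],
        'num_negatives': len(negatives),
        'num_unique_negatives': len(counts),
        # A positive that also appears among the negatives would be a red flag.
        'positive_in_negatives': record['positive'] in counts,
    }


def summarize_file(path, limit=5):
    """Yield summaries for the first `limit` records of a JSONL training file."""
    with open(path, encoding='utf-8') as f:
        for index, line in enumerate(f):
            if index >= limit:
                break
            yield summarize_record(json.loads(line))
```

For example, `for summary in summarize_file(train_dataset_path): print(summary)` prints a compact overview of the first five records.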


### 2.2 Inspect the training data

In [None]:
!head -n 1 $train_dataset_path | jq

{
  "query": "what country is the grand bahama island in [SEP] ",
  "positive": "location.location.containedby",
  "negatives": [
    "location.location.nearby_airports",
    "location.location.contains",
    "location.location.time_zones",
    "location.location.nearby_airports",
    "kg.object_profile.prominent_type",
    "location.statistical_region.population",
    "common.topic.webpage",
    "location.location.contains",
    "location.location.time_zones",
    "common.topic.notable_types",
    "location.location.time_zones",
    "common.topic.article",
    "location.location.nearby_airports",
    "location.location.nea

## Step 3. Train the Retriever

In [None]:
%env TOKENIZERS_PARALLELISM=false

In [None]:
!CUDA_VISIBLE_DEVICES=1 srtk train --train-dataset $train_dataset_path \
    --model-name-or-path roberta-base \
    --output-dir $save_model_dir \
    --accelerator gpu \
    --learning-rate 1e-5 \
    --batch-size 64 \
    --max-epochs 5

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Found cached dataset json (/home/wiss/liao/.cache/huggingface/datasets/json/default-d39e2cbcbb9827f5/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)
Loading cached processed dataset at /home/wiss/liao/.cache/huggingface/datasets/json/default-d39e2cbcbb9827f5/0.0.0/0f7e3662

## Step 4. Evaluate the Retriever

To evaluate the retriever, pass the `--evaluate` flag to the `retrieve` subcommand.

In [None]:
formatted_test_path = os.path.join(formatted_data_dir, 'test.jsonl')

In [None]:
!srtk retrieve --input $formatted_test_path \
    --output $retrieve_subgraph_path \
    --sparql-endpoint http://localhost:3001/sparql \
    --knowledge-graph freebase \
    --scorer-model-path $save_model_dir \
    --beam-width 10 \
    --max-depth 2 \
    --evaluate

Retrieving subgraphs: 100%|██████████████████████████████████████████████████████████████████████████████████| 1582/1582 [58:41<00:00,  2.23s/it]
Retrieved subgraphs saved to to artifacts/subgraphs/webqsp.jsonl
Answer recall: 0.9121365360303414 (1443 / 1582)
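
The reported figure is a ratio of hits to evaluated samples (1443 / 1582). One plausible reading, which the sketch below assumes, is that a sample counts as a hit when at least one of its answer entities appears among the entities of the retrieved subgraph. This is an illustration of that definition only: `answer_recall` is a hypothetical helper, and the `(answer_entities, retrieved_entities)` pair layout is not the format of the file written by `srtk retrieve`.

```python
def answer_recall(samples):
    """Fraction of samples with at least one answer entity among retrieved entities.

    `samples` is an iterable of (answer_entities, retrieved_entities) pairs;
    this layout is an assumption for illustration, not srtk's output format.
    """
    hits = 0
    total = 0
    for answer_entities, retrieved_entities in samples:
        total += 1
        if set(answer_entities) & set(retrieved_entities):
            hits += 1
    recall = hits / total if total else 0.0
    return recall, hits, total


# Toy data with placeholder identifiers: two of the three samples are hits.
toy_samples = [
    (['a1'], ['a1', 'x']),
    (['a2', 'a3'], ['y', 'a3']),
    (['a4'], ['z']),
]
recall, hits, total = answer_recall(toy_samples)
print(f'Answer recall: {recall} ({hits} / {total})')
```

Under this definition, the numbers in the log above are consistent with each other: 1443 hits out of 1582 samples gives a recall of about 0.912.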