# Understanding how the SQuAD dataset is set up for the text extraction task with BERT

We are going to fine-tune [BERT implemented by HuggingFace](https://huggingface.co/bert-base-uncased) for the text extraction task using [SQuAD (The Stanford Question Answering Dataset)](https://rajpurkar.github.io/SQuAD-explorer/), a dataset of questions and answers.
The data is composed of questions and the corresponding paragraphs (contexts) that contain the answers.
The model will be trained to locate the answer in the context by predicting the positions where the answer starts and ends.
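Concretely, for every token the model produces one score for "the answer starts here" and one for "the answer ends here". The sketch below shows one common way to turn those scores into a span: pick the pair with the highest combined score such that the end is not before the start. The function name and the `max_answer_len` limit are hypothetical, and the scores are plain Python lists rather than real model outputs.

```python
from typing import List, Tuple


def best_span(start_scores: List[float], end_scores: List[float],
              max_answer_len: int = 30) -> Tuple[int, int]:
    """Return the (start, end) token pair, end inclusive, with the highest
    combined score, subject to start <= end < start + max_answer_len."""
    best, best_score = (0, 0), float("-inf")
    for start, s_score in enumerate(start_scores):
        # only consider ends at or after the start, within the length limit
        for end in range(start, min(start + max_answer_len, len(end_scores))):
            score = s_score + end_scores[end]
            if score > best_score:
                best, best_score = (start, end), score
    return best


# Hypothetical scores for a 6-token input
start_scores = [0.1, 2.0, 0.3, 0.2, 0.0, 0.1]
end_scores = [3.0, 0.1, 0.4, 2.2, 0.3, 0.0]
print(best_span(start_scores, end_scores))  # (1, 3)
```

Taking the two argmaxes independently here would give start 1 and end 0, which is not a valid span; constraining the search to `end >= start` avoids that.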

In this notebook we are going to see how the data is set up for training.

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

More info:
- [Glossary - HuggingFace docs](https://huggingface.co/transformers/glossary.html#model-inputs)
- [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [None]:
import os
import utility.data_processing as dpp
import utility.testing as testing
from datasets import load_dataset
from transformers import BertTokenizer
from tokenizers import BertWordPieceTokenizer
from rich.pretty import pprint

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

## The raw data

In [None]:
bert_cache = os.path.join(os.getcwd(), 'cache')

In [None]:
hf_dataset = load_dataset('squad')

In [None]:
hf_dataset

In [None]:
for i, _squad_example in enumerate(hf_dataset['train']):
    pprint(_squad_example)
    if i > 5:
        break

In [None]:
for i, _squad_example in enumerate(hf_dataset['validation']):
    pprint(_squad_example)
    if i > 5:
        break

In [None]:
len(hf_dataset['train']['title'])

In [None]:
len(hf_dataset['validation']['title'])

In [None]:
len(set(hf_dataset['train']['title']))

In [None]:
len(set(hf_dataset['validation']['title']))

In [None]:
squad_ex = hf_dataset['train'].select([20584])

In [None]:
squad_ex['title']

In [None]:
squad_ex['context']

In [None]:
squad_ex['question']

In [None]:
squad_ex['answers']

## The tokenizer

### Processing the data for training
Now we process the data so we can feed it to the model later.
The idea is to use BERT's tokenizer to replace words (and word pieces) with integer IDs, and to organize the training data as pairs of a context paragraph and a question.
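BERT's tokenizer uses WordPiece: each word is split greedily into the longest pieces found in the vocabulary, continuation pieces carry a `##` prefix, and every piece is then looked up to get its ID. The toy sketch below shows that lookup, assuming a tiny made-up vocabulary (the real `bert-base-uncased` vocabulary has about 30k entries, and the real tokenizer also splits off punctuation before applying WordPiece).

```python
from typing import Dict, List

# Made-up toy vocabulary; the real one is read from vocab.txt
VOCAB: Dict[str, int] = {"[UNK]": 0, "let": 1, "token": 2, "##ize": 3, "##s": 4, "something": 5}


def wordpiece(word: str, vocab: Dict[str, int]) -> List[str]:
    """Greedy longest-match-first split of a single word into WordPiece tokens."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return ["[UNK]"]  # no valid split: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces


def encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Lowercase, split on whitespace, WordPiece each word, map pieces to IDs."""
    return [vocab[p] for w in text.lower().split() for p in wordpiece(w, vocab)]


print(wordpiece("tokenize", VOCAB))         # ['token', '##ize']
print(encode("Tokenize something", VOCAB))  # [2, 3, 5]
```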

In [None]:
hf_model = 'bert-base-uncased'

slow_tokenizer = BertTokenizer.from_pretrained(
    hf_model,
    cache_dir=os.path.join(bert_cache, f'_{hf_model}-tokenizer'),
)

In [None]:
# Save the slow tokenizer's files (including vocab.txt) once, so the fast
# Rust-based tokenizer from the `tokenizers` library can be built from them
save_path = os.path.join(bert_cache, f'{hf_model}-tokenizer')
if not os.path.exists(save_path):
    os.makedirs(save_path)
    slow_tokenizer.save_pretrained(save_path)

# Load the fast tokenizer from the saved vocabulary file
tokenizer = BertWordPieceTokenizer(os.path.join(save_path, 'vocab.txt'),
                                   lowercase=True)

In [None]:
encoding = tokenizer.encode("Let's tokenize something?")

In [None]:
encoding.tokens

In [None]:
encoding.ids

In [None]:
tokenizer.decode(encoding.ids)

In [None]:
for i, j in encoding.offsets:
    print("Let's tokenize something?"[i: j])
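The offsets are what make the answer labels possible: SQuAD gives each answer as a character position (`answer_start`) in the context, and the offsets tell us which tokens cover those characters. The sketch below shows that mapping, using offsets for the string above that were written out by hand (assumed; compare them with `encoding.offsets`), where `(0, 0)` marks `[CLS]` and `[SEP]`. The helper name is hypothetical.

```python
from typing import List, Tuple

text = "Let's tokenize something?"
# (char_start, char_end) per token: [CLS] let ' s token ##ize something ? [SEP]
offsets = [(0, 0), (0, 3), (3, 4), (4, 5), (6, 11), (11, 14), (15, 24), (24, 25), (0, 0)]


def char_span_to_token_span(offsets: List[Tuple[int, int]],
                            char_start: int, char_end: int) -> Tuple[int, int]:
    """Return the first and last (inclusive) token indices overlapping [char_start, char_end)."""
    hits = [i for i, (s, e) in enumerate(offsets)
            if e > s and s < char_end and e > char_start]  # e > s skips special tokens
    if not hits:
        raise ValueError("answer span not covered by any token")
    return hits[0], hits[-1]


answer = "tokenize"
char_start = text.index(answer)  # plays the role of SQuAD's answer_start
start_tok, end_tok = char_span_to_token_span(offsets, char_start, char_start + len(answer))
print(start_tok, end_tok)  # 4 5
```

With an inclusive end index, `(4, 5)` covers the two pieces `token` and `##ize`.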

## Processing the data

In [None]:
max_len = 384

In [None]:
hf_dataset.flatten()
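`flatten()` turns the nested `answers` field into top-level columns, so `answers` becomes `answers.text` and `answers.answer_start`. The plain-Python sketch below mimics that renaming on a single made-up row; it only illustrates the resulting column names, not how `datasets` implements it.

```python
from typing import Any, Dict


def flatten_row(row: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Recursively turn nested dicts into dotted keys, e.g. answers -> answers.text."""
    flat: Dict[str, Any] = {}
    for key, value in row.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten_row(value, prefix=f"{name}."))
        else:
            flat[name] = value
    return flat


# Made-up row with the same shape as a SQuAD example
row = {
    "id": "x1",
    "title": "Example",
    "context": "Paris is the capital of France.",
    "question": "What is the capital of France?",
    "answers": {"text": ["Paris"], "answer_start": [0]},
}
print(list(flatten_row(row)))
# ['id', 'title', 'context', 'question', 'answers.text', 'answers.answer_start']
```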

In [None]:
%%time
processed_dataset = hf_dataset.flatten().map(
    lambda example: dpp.process_squad_item_batched(example, max_len, tokenizer),
    remove_columns=hf_dataset.flatten()['train'].column_names,
    batched=True,  # dpp.process_squad_item_batched needs `batched=True`
    num_proc=12
)
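With `batched=True`, `map` passes the function a dict of lists (one list per column) and expects a dict of lists back, and the returned batch may have a different number of rows than the input. When that happens the original columns can no longer line up with the new ones, which is presumably why they are dropped with `remove_columns`. Whether `dpp.process_squad_item_batched` actually drops examples (for instance ones that do not fit in `max_len`, as the Keras example does) is an assumption here. The function below is a hypothetical stand-in that only illustrates this contract, using whitespace splitting instead of the real tokenizer.

```python
from typing import Dict, List


def toy_batched_fn(batch: Dict[str, List], max_len: int) -> Dict[str, List]:
    """Hypothetical batched processing function: keeps only examples whose
    (whitespace) token count fits in max_len."""
    out: Dict[str, List] = {"tokens": [], "length": []}
    for context, question in zip(batch["context"], batch["question"]):
        tokens = context.split() + question.split()
        if len(tokens) > max_len:
            continue  # dropped: the output batch can be shorter than the input batch
        out["tokens"].append(tokens)
        out["length"].append(len(tokens))
    return out


batch = {"context": ["a b c", "a b c d e f"], "question": ["q ?", "q ?"]}
result = toy_batched_fn(batch, max_len=6)
print(len(batch["context"]), "->", len(result["tokens"]))  # 2 -> 1
```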

In [None]:
processed_dataset

In [None]:
train_dataset = processed_dataset["train"]
train_dataset.set_format(type='numpy')

# eval_dataset = processed_dataset["validation"]
# eval_dataset.set_format(type='torch')

### The SquadExample objects

In [None]:
squad_ex   # Alps

In [None]:
squad_ex_obj = dpp.create_squad_example(squad_ex[0], max_len, tokenizer)
type(squad_ex_obj)

In [None]:
squad_ex_obj.__dict__.keys()

## The training set

In [None]:
train_dataset

In [None]:
train_sample = train_dataset.select([20299])[0]
pprint(train_sample)

## The model input

In [None]:
(
    train_sample['input_ids'].shape,
    train_sample['token_type_ids'].shape,
    train_sample['attention_mask'].shape
)

In [None]:
train_sample['input_ids']

In [None]:
tokenizer.decode(train_sample['input_ids'])

## [Attention masks](https://huggingface.co/transformers/glossary.html#attention-mask)
To create training batches, every sequence is padded to the same length (`max_len`). The attention mask tells the model which positions hold real tokens (1) and which hold padding (0).
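The sketch below shows that padding step, assuming right-padding with ID 0 (the `[PAD]` ID in the BERT uncased vocabulary). The function name is hypothetical, and apart from `[CLS]` (101) and `[SEP]` (102) the token IDs are made up.

```python
from typing import List, Tuple


def pad_with_mask(input_ids: List[int], max_len: int,
                  pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Right-pad input_ids to max_len and build the matching attention mask."""
    if len(input_ids) > max_len:
        raise ValueError("sequence longer than max_len")
    padding = max_len - len(input_ids)
    padded = input_ids + [pad_id] * padding
    attention_mask = [1] * len(input_ids) + [0] * padding  # 1 = real token, 0 = padding
    return padded, attention_mask


ids, mask = pad_with_mask([101, 11, 12, 102], max_len=6)
print(ids)   # [101, 11, 12, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

Selecting `input_ids[attention_mask == 1]`, as a later cell does, undoes exactly this padding.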

In [None]:
train_sample['attention_mask']

In [None]:
# Keep only the real tokens (context and question), dropping the padding
unpadded_ids = train_sample['input_ids'][train_sample['attention_mask'] == 1]
tokenizer.decode(unpadded_ids)

## [Token type ids](https://huggingface.co/transformers/glossary.html#token-type-ids)
Token type IDs distinguish the two segments of the input: the tokens of the context paragraph (type 0) and the tokens of the question (type 1).
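In the Keras example this notebook is based on, the input is laid out as `[CLS] context [SEP] question [SEP]`: `[CLS]`, the context and the first `[SEP]` get type 0, while the question and the final `[SEP]` get type 1. Whether `dpp` builds the pair exactly this way is an assumption. The sketch uses made-up word-piece IDs and a hypothetical function name.

```python
from typing import List, Tuple

CLS, SEP = 101, 102  # [CLS] and [SEP] IDs in bert-base-uncased


def build_pair(context_ids: List[int], question_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Concatenate [CLS] context [SEP] question [SEP] and the matching token type IDs."""
    first = [CLS] + context_ids + [SEP]
    second = question_ids + [SEP]
    input_ids = first + second
    token_type_ids = [0] * len(first) + [1] * len(second)
    return input_ids, token_type_ids


ids, types = build_pair([11, 12, 13], [21, 22])  # made-up word-piece IDs
print(ids)    # [101, 11, 12, 13, 102, 21, 22, 102]
print(types)  # [0, 0, 0, 0, 0, 1, 1, 1]
```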

In [None]:
train_sample['token_type_ids']

In [None]:
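# Padding positions may also have type 0, so exclude them with the attention mask
paragraph_mask = (train_sample['token_type_ids'] == 0) & (train_sample['attention_mask'] == 1)
paragraph_encoded = train_sample['input_ids'][paragraph_mask]
tokenizer.decode(paragraph_encoded)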

In [None]:
question_encoded = train_sample['input_ids'][train_sample['token_type_ids'] == 1]
tokenizer.decode(question_encoded)

### The references (answer start and end token positions)

In [None]:
train_sample['start_token_idx'], train_sample['end_token_idx']
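Assuming `end_token_idx` points at the last answer token (inclusive, as in the Keras example), the answer text can also be recovered from the original context through the token offsets instead of by decoding IDs, which preserves the original casing and spacing. The sketch uses a toy context with offsets written out by hand (assumed values, not taken from the dataset) and a hypothetical helper name.

```python
from typing import List, Tuple

context = "The Alps are a mountain range."
# Hand-written offsets for: [CLS] the alps are a mountain range . [SEP]
offsets: List[Tuple[int, int]] = [(0, 0), (0, 3), (4, 8), (9, 12), (13, 14),
                                  (15, 23), (24, 29), (29, 30), (0, 0)]


def span_text(context: str, offsets: List[Tuple[int, int]],
              start_token_idx: int, end_token_idx: int) -> str:
    """Slice the context from the first char of the start token to the last char of the end token."""
    return context[offsets[start_token_idx][0]:offsets[end_token_idx][1]]


print(span_text(context, offsets, 5, 6))  # mountain range
print(span_text(context, offsets, 2, 2))  # Alps
```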

In [None]:
# `end_token_idx` is assumed to be inclusive (as in the Keras example), hence the `+ 1`
start_idx, end_idx = train_sample['start_token_idx'], train_sample['end_token_idx']
answer_ids = train_sample['input_ids'][start_idx:end_idx + 1]

print('\n * CONTEXT:\n', squad_ex_obj.context)
print('\n * QUESTION:\n', squad_ex_obj.question)
print('\n * ANSWER (REFERENCE):\n', squad_ex_obj.answer_text[0])
print('\n * ANSWER FROM CONTEXT:\n', tokenizer.decode(answer_ids))
print('\n\n === TRAINING SAMPLE ===')
print('\n * CONTEXT & QUESTION:\n', tokenizer.decode(train_sample['input_ids']))
print('\n * ANSWER TOKEN POSITIONS (start, end):\n', (start_idx, end_idx))