In [1]:
from datasets import load_dataset

In [3]:
data = load_dataset('squad')

Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.63 MiB, post-processed: Unknown size, total: 119.14 MiB) to /Users/jamesbriggs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453...


Downloading: 30.3MB [00:35, 846kB/s]
Downloading: 4.85MB [00:00, 11.7MB/s]


Dataset squad downloaded and prepared to /Users/jamesbriggs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453. Subsequent calls will reuse this data.


In [4]:
data

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [7]:
data['train']['answers'][:5]

Downloading: 20.1MB [06:12, 54.1kB/s]


[{'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]},
 {'text': ['a copper statue of Christ'], 'answer_start': [188]},
 {'text': ['the Main Building'], 'answer_start': [279]},
 {'text': ['a Marian place of prayer and reflection'], 'answer_start': [381]},
 {'text': ['a golden statue of the Virgin Mary'], 'answer_start': [92]}]

In [25]:
data['train']['answers'][0]

{'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

In [51]:
from tqdm.auto import tqdm  # for showing progress bar

def add_end_idx(answers, contexts):
    new_answers = []
    # loop through each answer-context pair
    for answer, context in tqdm(zip(answers, contexts)):
        # quick reformating to remove lists
        answer['text'] = answer['text'][0]
        answer['answer_start'] = answer['answer_start'][0]
        # gold_text refers to the answer we are expecting to find in context
        gold_text = answer['text']
        # we already know the start index
        start_idx = answer['answer_start']
        # and ideally this would be the end index...
        end_idx = start_idx + len(gold_text)

        # ...however, sometimes squad answers are off by a character or two
        if context[start_idx:end_idx] == gold_text:
            # if the answer is not off :)
            answer['answer_end'] = end_idx
        else:
            # this means the answer is off by 1-2 tokens
            for n in [1, 2]:
                if context[start_idx-n:end_idx-n] == gold_text:
                    answer['answer_start'] = start_idx - n
                    answer['answer_end'] = end_idx - n
        new_answers.append(answer)
    return new_answers

def prep_data(dataset):
    questions = dataset['question']
    contexts = dataset['context']
    answers = add_end_idx(
        dataset['answers'],
        contexts
    )
    return {
        'question': questions,
        'context': contexts,
        'answers': answers
    }

In [34]:
dataset = prep_data(data['train'])

In [35]:
dataset['answers'][:5]

[{'text': 'Saint Bernadette Soubirous',
  'answer_start': 515,
  'answer_end': 541},
 {'text': 'a copper statue of Christ', 'answer_start': 188, 'answer_end': 213},
 {'text': 'the Main Building', 'answer_start': 279, 'answer_end': 296},
 {'text': 'a Marian place of prayer and reflection',
  'answer_start': 381,
  'answer_end': 420},
 {'text': 'a golden statue of the Virgin Mary',
  'answer_start': 92,
  'answer_end': 126}]

Next, we tokenize our questions and contexts.

In [36]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# tokenize
train = tokenizer(dataset['context'], dataset['question'],
                  truncation=True, padding='max_length',
                  max_length=512, return_tensors='pt')

We now how context-question pairs represented as `Encoding` objects. Let's take a look at a decoded context-question pair.

In [46]:
tokenizer.decode(train['input_ids'][0])[:855]

'[CLS] architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary. [SEP] to whom did the virgin mary allegedly appear in 1858 in lourdes france? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

Next, we need to add our token start and end positions - remember before we identified the character start and end positions of our answers? Now we must do the same but translated into *token positions*. We will write an `add_token_positions` function to do this.

In [52]:
def add_token_positions(encodings, answers):
    # initialize lists to contain the token indices of answer start/end
    start_positions = []
    end_positions = []
    for i in tqdm(range(len(answers))):
        # append start/end token position using char_to_token method
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end']))

        # if start position is None, the answer passage has been truncated
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        # end position cannot be found, char_to_token found space, so shift position until found
        shift = 1
        while end_positions[-1] is None:
            end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end'] - shift)
            shift += 1
    # update our encodings object with the new token-based start/end positions
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

In [53]:
# apply function to our data
add_token_positions(train, dataset['answers'])

100%|██████████| 87599/87599 [00:02<00:00, 40006.03it/s]


This has added two more tensors to our `Encoding` object, `start_positions` and `end_positions`. Each of these is a list containing the start and end position of the answer.

In [54]:
train.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])

In [56]:
train['start_positions'][:5], train['end_positions'][:5]

([114, 40, 65, 85, 20], [121, 44, 67, 92, 27])

Now we initialize our dataset for training with PyTorch. We use a dataset object for this.

In [57]:
import torch

class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# build datasets for both our training and validation sets
train_dataset = SquadDataset(train)