In [None]:
# default_exp data.question_answering

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# data.question_answering

> This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for question/answering tasks.

In [None]:
#export
import ast
from functools import reduce

from blurr.utils import *
from blurr.data.core import *

import torch
from transformers import *
from fastai2.text.all import *

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#cuda
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')

Using GPU #1: GeForce GTX 1080 Ti


## Question/Answering tokenization, batch transform, and DataBlock methods

Question answering tasks require two text inputs: a question and a context that contains the answer. The objective is to predict the start and end tokens of the answer within the context.
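
To make the objective concrete, here is a minimal sketch in plain Python. It assumes whitespace "tokens" and hand-written `<s>`/`</s>` markers as stand-ins for real subword tokenization, so the indices are illustrative only. The point is that the targets are positions in the joined question + context sequence, with an exclusive end index.

```python
# Toy illustration: whitespace "tokens" stand in for real subword tokens (an assumption).
question = "In what year did the Ottomans destroy Limassol?"
context = "In 1539 the Ottomans destroyed Limassol"

# The question and the context are joined into a single input sequence.
tokens = ["<s>"] + question.split() + ["</s>", "</s>"] + context.split() + ["</s>"]

# The targets are positions in that joined sequence; the end index is exclusive.
start = tokens.index("1539")
end = start + 1
answer = " ".join(tokens[start:end])
print(start, end, answer)
```

Slicing the input with `[start:end]` recovers the answer, which is the same convention the rest of this notebook uses for `tok_answer_start` and `tok_answer_end`.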

In [None]:
path = Path('./')
squad_df = pd.read_csv(path/'squad_sample.csv'); len(squad_df)

1000

The sample loaded above and previewed below is a small subset of a pre-processed SQuAD v2 dataset, provided for demonstration purposes only. Much more could be done to make this fully functional; the idea here is just to show how things can work for tasks beyond sequence classification.

In [None]:
squad_df.head(2)

Unnamed: 0,title,context,question_id,question_text,is_impossible,answer_text,answer_start,answer_end
0,New_York_City,"The New York City Fire Department (FDNY), provides fire protection, technical rescue, primary response to biological, chemical, and radioactive hazards, and emergency medical services for the five boroughs of New York City. The New York City Fire Department is the largest municipal fire department in the United States and the second largest in the world after the Tokyo Fire Department. The FDNY employs approximately 11,080 uniformed firefighters and over 3,300 uniformed EMTs and paramedics. The FDNY's motto is New York's Bravest.",56d1076317492d1400aab78c,What does FDNY stand for?,False,New York City Fire Department,4,33
1,Cyprus,"Following the death in 1473 of James II, the last Lusignan king, the Republic of Venice assumed control of the island, while the late king's Venetian widow, Queen Catherine Cornaro, reigned as figurehead. Venice formally annexed the Kingdom of Cyprus in 1489, following the abdication of Catherine. The Venetians fortified Nicosia by building the Venetian Walls, and used it as an important commercial hub. Throughout Venetian rule, the Ottoman Empire frequently raided Cyprus. In 1539 the Ottomans destroyed Limassol and so fearing the worst, the Venetians also fortified Famagusta and Kyrenia.",572e7f8003f98919007566df,In what year did the Ottomans destroy Limassol?,False,1539,481,485


In [None]:
task = HF_TASKS_AUTO.ForQuestionAnswering

pretrained_model_name = 'roberta-base' #'xlm-mlm-ende-1024'
config = AutoConfig.from_pretrained(pretrained_model_name)

hf_arch, hf_tokenizer, hf_config, hf_model = BLURR_MODEL_HELPER.get_auto_hf_objects(pretrained_model_name, 
                                                                                    task=task, 
                                                                                    config=config)

In [None]:
#export
def pre_process_squad(row, hf_arch, hf_tokenizer):
    context, qst, ans = row['context'], row['question_text'], row['answer_text']
    
    add_prefix_space = hf_arch in ['gpt2', 'roberta']
    
    if(hf_tokenizer.padding_side == 'right'):
        tok_input = hf_tokenizer.convert_ids_to_tokens(hf_tokenizer.encode(qst, context, 
                                                                           add_prefix_space=add_prefix_space))
    else:
        tok_input = hf_tokenizer.convert_ids_to_tokens(hf_tokenizer.encode(context, qst, 
                                                                           add_prefix_space=add_prefix_space))
                                                                       
    tok_ans = hf_tokenizer.tokenize(str(row['answer_text']), 
                                    add_special_tokens=False, 
                                    add_prefix_space=add_prefix_space)
    
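    # locate the first occurrence of the answer tokens in the tokenized input;
    # (0, 0) is kept when the answer is empty or cannot be found
    start_idx, end_idx = 0, 0
    if tok_ans:
        for idx in range(len(tok_input) - len(tok_ans) + 1):
            if tok_input[idx:idx + len(tok_ans)] == tok_ans:
                start_idx, end_idx = idx, idx + len(tok_ans)
                break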
            
    row['tokenized_input'] = tok_input
    row['tokenized_input_len'] = len(tok_input)
    row['tok_answer_start'] = start_idx
    row['tok_answer_end'] = end_idx
    
    return row

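The `pre_process_squad` function is written around the structure of the SQuAD DataFrame above: it expects `context`, `question_text`, and `answer_text` columns, and it adds the tokenized input, its length, and the answer's start/end token indices to each row. When the answer tokens cannot be found in the tokenized input, both indices are left at 0.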
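
The span search at the heart of `pre_process_squad` can be sketched without a tokenizer. The helper name `find_answer_span` is hypothetical, and the token lists are copied by hand from the tokenized output shown further down in this notebook; this illustrates the matching logic and is not a replacement for the function above.

```python
def find_answer_span(tok_input, tok_ans):
    """Hypothetical helper: return (start, end) of the first contiguous match, else (0, 0)."""
    if not tok_ans:
        return 0, 0
    for idx in range(len(tok_input) - len(tok_ans) + 1):
        if tok_input[idx:idx + len(tok_ans)] == tok_ans:
            return idx, idx + len(tok_ans)  # end index is exclusive
    return 0, 0

# The start of the first tokenized row (question, separators, start of context).
tok_input = ["<s>", "ĠWhat", "Ġdoes", "ĠFD", "NY", "Ġstand", "Ġfor", "?", "</s>", "</s>",
             "ĠThe", "ĠNew", "ĠYork", "ĠCity", "ĠFire", "ĠDepartment", "Ġ(", "FD", "NY", ")"]
tok_ans = ["ĠNew", "ĠYork", "ĠCity", "ĠFire", "ĠDepartment"]
found = find_answer_span(tok_input, tok_ans)

# "(Somali Post)" is split differently in context than the standalone answer, so no match.
missed = find_answer_span(["Ġ(", "S", "om", "ali", "ĠPost", ")"], ["ĠSomali", "ĠPost"])
print(found, missed)
```

The second call shows why some rows end up with `(0, 0)`: the answer can tokenize differently on its own than it does inside the context, in which case the exact token match fails.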

In [None]:
squad_df = squad_df.apply(partial(pre_process_squad, hf_arch=hf_arch, hf_tokenizer=hf_tokenizer), axis=1)

Token indices sequence length is longer than the specified maximum sequence length for this model (16 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (11 > 512). Running this sequence through the model will result in indexing errors


In [None]:
max_seq_len= 128

In [None]:
squad_df = squad_df[(squad_df.answer_end < max_seq_len) & (squad_df.is_impossible == False)]

In [None]:
#hide
squad_df.head(2)

Unnamed: 0,title,context,question_id,question_text,is_impossible,answer_text,answer_start,answer_end,tokenized_input,tokenized_input_len,tok_answer_start,tok_answer_end
0,New_York_City,"The New York City Fire Department (FDNY), provides fire protection, technical rescue, primary response to biological, chemical, and radioactive hazards, and emergency medical services for the five boroughs of New York City. The New York City Fire Department is the largest municipal fire department in the United States and the second largest in the world after the Tokyo Fire Department. The FDNY employs approximately 11,080 uniformed firefighters and over 3,300 uniformed EMTs and paramedics. The FDNY's motto is New York's Bravest.",56d1076317492d1400aab78c,What does FDNY stand for?,False,New York City Fire Department,4,33,"[<s>, ĠWhat, Ġdoes, ĠFD, NY, Ġstand, Ġfor, ?, </s>, </s>, ĠThe, ĠNew, ĠYork, ĠCity, ĠFire, ĠDepartment, Ġ(, FD, NY, ),, Ġprovides, Ġfire, Ġprotection, ,, Ġtechnical, Ġrescue, ,, Ġprimary, Ġresponse, Ġto, Ġbiological, ,, Ġchemical, ,, Ġand, Ġradioactive, Ġhazards, ,, Ġand, Ġemergency, Ġmedical, Ġservices, Ġfor, Ġthe, Ġfive, Ġborough, s, Ġof, ĠNew, ĠYork, ĠCity, ., ĠThe, ĠNew, ĠYork, ĠCity, ĠFire, ĠDepartment, Ġis, Ġthe, Ġlargest, Ġmunicipal, Ġfire, Ġdepartment, Ġin, Ġthe, ĠUnited, ĠStates, Ġand, Ġthe, Ġsecond, Ġlargest, Ġin, Ġthe, Ġworld, Ġafter, Ġthe, ĠTokyo, ĠFire, ĠDepartment, ., ĠThe, Ġ...",118,11,16
5,Communications_in_Somalia,"The Somali Postal Service (Somali Post) is the national postal service of the Federal Government of Somalia. It is part of the Ministry of Information, Posts and Telecommunication.",56e1b959cd28a01900c67ad1,What is the name of the National postal service of Somalia?,False,Somali Post,4,15,"[<s>, ĠWhat, Ġis, Ġthe, Ġname, Ġof, Ġthe, ĠNational, Ġpostal, Ġservice, Ġof, ĠSomalia, ?, </s>, </s>, ĠThe, ĠSomali, ĠPostal, ĠService, Ġ(, S, om, ali, ĠPost, ), Ġis, Ġthe, Ġnational, Ġpostal, Ġservice, Ġof, Ġthe, ĠFederal, ĠGovernment, Ġof, ĠSomalia, ., ĠIt, Ġis, Ġpart, Ġof, Ġthe, ĠMinistry, Ġof, ĠInformation, ,, ĠPosts, Ġand, ĠTele, communication, ., </s>]",52,0,0


In [None]:
vocab = dict(enumerate(range(max_seq_len)));

Below we use the `@typedispatch` decorator to completely change how the data is tokenized for the `ForQuestionAnsweringTask`. This requires defining a custom type to identify question/answer inputs.

In [None]:
#export
class HF_QuestionAnswerInput(list): pass
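
The idea of selecting an implementation by argument type can be sketched with the standard library's `functools.singledispatch`, used here purely as a stand-in for `@typedispatch` (which is a separate implementation from fastcore). The class and function names below are hypothetical and exist only for this illustration.

```python
from functools import singledispatch

class DefaultTask: pass          # hypothetical stand-in for a generic task type
class QATask(DefaultTask): pass  # hypothetical stand-in for a question answering task type

@singledispatch
def build_input(task, a, b=None):
    # Fallback when no implementation is registered for the task's type.
    raise NotImplementedError(type(task).__name__)

@build_input.register
def _(task: DefaultTask, a, b=None):
    return ("default", a, b)

@build_input.register
def _(task: QATask, a, b=None):
    # The most specific registered type wins, so QATask instances land here.
    return ("qa", a, b)

default_result = build_input(DefaultTask(), "some text")
qa_result = build_input(QATask(), "question", "context")
print(default_result, qa_result)
```

Passing a different task object is all it takes to switch implementations, which is the same lever used below when a task instance is handed to the batch transform.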

In [None]:
#export
@typedispatch
def build_hf_input(task:ForQuestionAnsweringTask, tokenizer, 
                   a_tok_ids, b_tok_ids=None, targets=None,
                   max_length=512, pad_to_max_length=True, truncation_strategy=None):

    if (truncation_strategy is None):
        truncation_strategy = "only_second" if tokenizer.padding_side == "right" else "only_first"

    res = tokenizer.prepare_for_model(a_tok_ids if tokenizer.padding_side == "right" else b_tok_ids, 
                                      b_tok_ids if tokenizer.padding_side == "right" else a_tok_ids,
                                      max_length=max_length, 
                                      pad_to_max_length=pad_to_max_length,
                                      truncation_strategy=truncation_strategy, 
                                      return_special_tokens_mask=True,
                                      return_tensors='pt')
    
    input_ids = res['input_ids'][0]
    attention_mask = res['attention_mask'][0] if ('attention_mask' in res) else tensor([-9999]) 
    token_type_ids = res['token_type_ids'][0] if ('token_type_ids' in res) else tensor([-9999]) 
    
    # cls_index: location of the CLS token (used by xlnet and xlm) ... the tensor equivalent of list.index(value)
    cls_index = (input_ids == tokenizer.cls_token_id).nonzero()[0]
    
    # p_mask: mask with 1 for tokens that cannot be in the answer, else 0 (used by xlnet and xlm)
    p_mask = tensor(res['special_tokens_mask']) if ('special_tokens_mask' in res) else tensor([-9999]) 
    
    return HF_QuestionAnswerInput([input_ids, attention_mask, token_type_ids, cls_index, p_mask]), targets
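
The ordering and truncation choices in `build_hf_input` come down to one rule: whichever side the tokenizer pads on, it is the context that gets truncated, never the question. A minimal sketch of that rule in plain Python follows. The helper `order_and_truncate` is hypothetical, it ignores special tokens and padding, and it assumes truncation trims from the end of the context.

```python
def order_and_truncate(question_ids, context_ids, padding_side="right", max_length=8):
    """Hypothetical helper: order the pair by padding side and trim only the context."""
    room = max(max_length - len(question_ids), 0)
    context_ids = context_ids[:room]  # assumption: trim from the end of the context
    if padding_side == "right":
        # question first, context second: mirrors truncating "only_second"
        return question_ids + context_ids
    # context first, question second: mirrors truncating "only_first"
    return context_ids + question_ids

question_ids = [1, 2, 3]
context_ids = [10, 11, 12, 13, 14, 15, 16]
right = order_and_truncate(question_ids, context_ids, padding_side="right")
left = order_and_truncate(question_ids, context_ids, padding_side="left")
print(right, left)
```

In both cases the full question survives and the result fits within `max_length`.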

Here we demonstrate more of the framework's extensibility by passing in our own instance of `HF_BatchTransform`.

Notice how we set `task=ForQuestionAnsweringTask()` so that our custom `build_hf_input` above, written for question/answering tasks, gets called rather than the default implementation.

In [None]:
# (optional): override HF_BatchTransform defaults
hf_batch_tfm = HF_BatchTransform(hf_arch, hf_tokenizer, max_seq_len=max_seq_len, truncation_strategy='only_second', 
                                 task=ForQuestionAnsweringTask())

blocks = (
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm), 
    CategoryBlock(vocab=vocab),
    CategoryBlock(vocab=vocab)
)

dblock = DataBlock(blocks=blocks, 
                   get_x=lambda x: (x.question_text, x.context),
                   get_y=[ColReader('tok_answer_start'), ColReader('tok_answer_end')],
                   splitter=RandomSplitter(),
                   n_inp=1)

In [None]:
dls = dblock.dataloaders(squad_df, bs=4)

In [None]:
b = dls.one_batch(); len(b), len(b[0]), len(b[1]), len(b[2])

(3, 5, 4, 4)

In [None]:
b[0][0].shape, b[0][1].shape, b[0][2].shape, b[0][3].shape, b[0][4].shape, b[1].shape, b[2].shape

(torch.Size([4, 128]),
 torch.Size([4, 128]),
 torch.Size([4, 1]),
 torch.Size([4, 1]),
 torch.Size([4, 128]),
 torch.Size([4]),
 torch.Size([4]))

In [None]:
#export
@typedispatch
def show_batch(x:HF_QuestionAnswerInput, y, samples, hf_tokenizer, skip_special_tokens=True, 
               ctxs=None, max_n=6, **kwargs):  
    res = L()
    for inp, start, end in zip(x[0], *y):
        txt = hf_tokenizer.decode(inp, skip_special_tokens=skip_special_tokens).replace(hf_tokenizer.pad_token, '')
        ans_toks = hf_tokenizer.convert_ids_to_tokens(inp, skip_special_tokens=False)[start:end]
        res.append((txt, (start.item(),end.item()), hf_tokenizer.convert_tokens_to_string(ans_toks)))
                       
    display_df(pd.DataFrame(res, columns=['text', 'start/end', 'answer'])[:max_n])
    return ctxs

The `show_batch` method above allows us to create a more interpretable view of our question/answer data.
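
The answer column is produced by slicing the tokens with the start/end targets and turning that slice back into text. A rough sketch of that step follows, assuming byte-level BPE tokens where `Ġ` marks a leading space. The helper `decode_answer` is hypothetical and only handles plain ASCII text; `show_batch` itself relies on the tokenizer's `convert_tokens_to_string`.

```python
def decode_answer(tokens, start, end):
    """Hypothetical helper: join byte-level BPE tokens in [start, end) into text (ASCII only)."""
    return "".join(tokens[start:end]).replace("Ġ", " ").strip()

# The start of the first tokenized row shown earlier in this notebook.
tokens = ["<s>", "ĠWhat", "Ġdoes", "ĠFD", "NY", "Ġstand", "Ġfor", "?", "</s>", "</s>",
          "ĠThe", "ĠNew", "ĠYork", "ĠCity", "ĠFire", "ĠDepartment", "Ġ(", "FD", "NY", ")"]

answer = decode_answer(tokens, 11, 16)
acronym = decode_answer(tokens, 3, 5)
print(answer, "|", acronym)
```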

In [None]:
dls.show_batch(hf_tokenizer=hf_tokenizer, skip_special_tokens=False, max_n=2)

Unnamed: 0,text,start/end,answer
0,"<s> When was Brazil colonized?</s></s> Child labour has been a consistent struggle for children in Brazil ever since the country was colonized on April 22, 1550 by Pedro Álvares Cabral. Work that many children took part in was not always visible, legal, or paid. Free or slave labour was a common occurrence for many youths and was a part of their everyday lives as they grew into adulthood. Yet due to there being no clear definition of how to classify what a child or youth is, there has been little historical documentation of child labour during the colonial period. Due to this lack of documentation, it is hard</s>","(28, 33)","April 22, 1550"
1,"<s> What proportion of the federal census records in existence does the National Archives Building house?</s></s> The National Archives Building in downtown Washington holds record collections such as all existing federal census records, ships' passenger lists, military unit records from the American Revolution to the Philippine–American War, records of the Confederate government, the Freedmen's Bureau records, and pension and land records.</s>","(31, 32)",all


## Cleanup

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_data-core.ipynb.
Converted 01a_data-language-modeling.ipynb.
Converted 01c_data-question-answering.ipynb.
Converted 01d_data-token-classification.ipynb.
Converted 02_modeling-core.ipynb.
Converted 02a_modeling-language-modeling.ipynb.
Converted 02c_modeling-question-answering.ipynb.
Converted 02d_modeling-token-classification.ipynb.
Converted index.ipynb.
