# Fine-tuning a MobileBERT model for Q&A with the SQuAD dataset

We are going to fine-tune [MobileBERT implemented by HuggingFace](https://huggingface.co/docs/transformers/model_doc/mobilebert) for the text-extraction task on [The Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer/).

The data consists of a set of questions and paragraphs that contain the answers.
The model will be trained to locate the answer in the context by predicting the token positions where the answer starts and ends.

In this notebook we are going to evaluate the model from a checkpoint we already obtained.

More info:
- [Glossary - HuggingFace docs](https://huggingface.co/transformers/glossary.html#model-inputs)
- [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, MobileBertForQuestionAnswering
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


# Turn off the `datasets` progress bars and caching for this notebook session.
disable_progress_bar()
disable_caching()

In [None]:
# HuggingFace Hub identifier used to create both the tokenizer and the model.
hf_model_checkpoint = 'google/mobilebert-uncased'

In [None]:
# Build the tokenizer and the question-answering model from the same
# checkpoint name. The fine-tuned weights are loaded into `model` from a
# local file in a later cell.
tokenizer = AutoTokenizer.from_pretrained(hf_model_checkpoint)
model = MobileBertForQuestionAnswering.from_pretrained(hf_model_checkpoint)

In [None]:
# Load the SQuAD dataset; the 'train' and 'validation' splits are used below.
hf_dataset = load_dataset('squad')

In [None]:
# Length (in tokens) every encoded example is padded/truncated to; also the
# threshold used when filtering out samples whose answer lies too far in.
MAX_SEQ_LEN = 300

def tokenize_dataset(squad_example, tokenizer=tokenizer):
    """Encode one SQuAD example and attach its answer span as token indices.

    The context and the question are encoded together as a pair, padded
    and truncated to `MAX_SEQ_LEN` tokens. Only the first annotated answer
    of the example is used: its character offset within the context is
    converted into token positions, which are stored in the returned
    encoding under 'start_token_idx' and 'end_token_idx'.
    """
    passage = squad_example['context']
    annotations = squad_example['answers']
    char_start = annotations['answer_start'][0]
    answer_text = annotations['text'][0]

    encoded = tokenizer(
        passage,
        squad_example['question'],
        max_length=MAX_SEQ_LEN,
        truncation=True,
        padding='max_length',
    )

    # Count the context tokens up to and including the first character of
    # the answer.
    # NOTE(review): presumably the extra character yields one extra token
    # that compensates for the special token placed at the start of the
    # encoded pair -- confirm.
    prefix_tokens = tokenizer.tokenize(passage[:char_start + 1])
    span_start = len(prefix_tokens)

    # The end index is the start plus the answer's own token count, so it
    # works as an exclusive bound when slicing the token sequence.
    span_end = span_start + len(tokenizer.tokenize(answer_text))

    encoded['start_token_idx'] = span_start
    encoded['end_token_idx'] = span_end
    return encoded


def filter_samples_by_max_seq_len(squad_example):
    """Filter out the samples where the answers are
    not within the first `MAX_SEQ_LEN` tokens.

    Only the first annotated answer of the example is considered. The
    position of the answer's end is estimated as the number of tokens in
    the context before the answer plus the number of tokens in the answer.

    Returns:
        True when that end position is strictly below `MAX_SEQ_LEN`
        (keep the sample), False otherwise (drop it).
    """
    max_len = MAX_SEQ_LEN
    answer_start = squad_example['answers']['answer_start'][0]
    answer = squad_example['answers']['text'][0]
    token_start = len(tokenizer.tokenize(squad_example['context'][:answer_start]))
    token_end = len(tokenizer.tokenize(answer)) + token_start
    # NOTE(review): only context tokens are counted here, while
    # `tokenize_dataset` encodes the question (and any special tokens) within
    # the same `MAX_SEQ_LEN` budget, so a sample that passes this check may
    # presumably still have its answer truncated -- confirm.
    # Always return an explicit bool rather than falling through to `None`
    # for the samples that must be dropped.
    return token_end < max_len

In [None]:
# Keep only the samples accepted by `filter_samples_by_max_seq_len`,
# running the filter in 12 worker processes.
dataset_filtered = hf_dataset.filter(
    filter_samples_by_max_seq_len,
    num_proc=12,
)

In [None]:
# Tokenize every remaining sample. The original text columns are removed so
# that only the fields produced by `tokenize_dataset` are kept.
dataset_tok = dataset_filtered.map(
    tokenize_dataset,
    remove_columns=hf_dataset['train'].column_names,
    num_proc=12,
)
# Have the dataset return PyTorch tensors.
dataset_tok.set_format('pt')
# Last expression of the cell: displays a summary of the tokenized dataset.
dataset_tok

In [None]:
# Batches of 10 tokenized validation samples for the evaluation loop.
eval_dataloader = DataLoader(
    dataset_tok['validation'],
    shuffle=True,   # shuffle to print different predictions
    batch_size=10
)

In [None]:
# Put the model in evaluation mode. The trailing semicolon keeps the notebook
# from echoing the value returned by `eval()`.
model.eval();

In [None]:
# Load the fine-tuned weights from a local checkpoint file, mapping every
# tensor to the CPU.
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints
# from a trusted source (consider passing `weights_only=True` if the
# installed torch version supports it).
model.load_state_dict(
    torch.load('mobilebertforaq_trained_thu',
               map_location=torch.device('cpu'))
)

In [None]:
# Print the model's predicted answer next to the reference answer for the
# first few batches of the (shuffled) validation dataloader.
# Gradients are not needed for inference, so the whole loop runs under
# `torch.no_grad()` to avoid building the autograd graph.
with torch.no_grad():
    for i, batch in enumerate(eval_dataloader):
        # evaluate the model
        outputs = model(
            input_ids=batch['input_ids'],
            token_type_ids=batch['token_type_ids'],
            attention_mask=batch['attention_mask'],
        )

        # obtain the predicted start and end positions and
        # apply a softmax to them
        pred_start = F.softmax(outputs.start_logits, dim=-1)
        pred_end = F.softmax(outputs.end_logits, dim=-1)

        # loop over the samples of the batch
        for context_tokens, start_ref, end_ref, start_probs, end_probs in zip(
            batch['input_ids'],
            batch['start_token_idx'], batch['end_token_idx'],
            pred_start, pred_end,
        ):
            context_text = tokenizer.decode(context_tokens, skip_special_tokens=True,
                                            clean_up_tokenization_spaces=True)
            # the most probable position is the predicted token index
            start = torch.argmax(start_probs)
            end = torch.argmax(end_probs)

            # predicted answer (the end index is used as an exclusive bound,
            # the same way the reference indices are sliced below)
            answer_tokens = context_tokens[start:end]
            answer_text = tokenizer.decode(answer_tokens, skip_special_tokens=True,
                                           clean_up_tokenization_spaces=True)

            # reference answer
            answer_tokens_ref = context_tokens[start_ref:end_ref]
            answer_text_ref = tokenizer.decode(answer_tokens_ref, skip_special_tokens=True,
                                               clean_up_tokenization_spaces=True)

            print(f'* {context_text}\n')
            print(f'[robot] {answer_text}')
            print(f'[ ref ] {answer_text_ref}\n')
            print('---')

        # stop once the batch with index 11 has been printed (12 batches)
        if i > 10:
            break