# MobileBERT for Question Answering on the SQuAD dataset

### 2. Fine-tuning the model

In these notebooks we are going to use [MobileBERT implemented by HuggingFace](https://huggingface.co/docs/transformers/model_doc/mobilebert) on the question answering task by text extraction on the [Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer/). The data is composed of a set of questions and paragraphs that contain the answers. The model will be trained to locate the answer in the context by predicting the token positions where the answer starts and ends.

In this notebook we are going to fine-tune the model.

More info from HuggingFace docs:
- [Question Answering](https://huggingface.co/tasks/question-answering)
- [Glossary](https://huggingface.co/transformers/glossary.html#model-inputs)
- [Question Answering chapter of NLP course](https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt)

In [None]:
import torch
from transformers import AutoTokenizer, MobileBertForQuestionAnswering
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

In [None]:
hf_model_checkpoint = 'google/mobilebert-uncased'

In [None]:
tokenizer = AutoTokenizer.from_pretrained(hf_model_checkpoint)
model = MobileBertForQuestionAnswering.from_pretrained(hf_model_checkpoint)

In [None]:
hf_dataset = load_dataset('squad')

In [None]:
# Preprocessing data
# Find more info about this in the notebook about exploring the dataset

MAX_SEQ_LEN = 300

def tokenize_dataset(squad_example, tokenizer=tokenizer):
    """Tokenize the text in the dataset and convert
    the start and ending positions of the answers
    from text to tokens"""
    max_len = MAX_SEQ_LEN
    context = squad_example['context']
    answer_start = squad_example['answers']['answer_start'][0]
    answer = squad_example['answers']['text'][0]
    squad_example_tokenized = tokenizer(
        context, squad_example['question'],
        padding='max_length',
        max_length=max_len,
        truncation='only_first',
    )
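    # Including the first character of the answer typically adds one token
    # to the count, which offsets the leading [CLS] token of the model input
    token_start = len(tokenizer.tokenize(context[:answer_start + 1]))
    # The end index is exclusive: one past the last answer token
    token_end = len(tokenizer.tokenize(answer)) + token_start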

    squad_example_tokenized['start_token_idx'] = token_start
    squad_example_tokenized['end_token_idx'] = token_end

    return squad_example_tokenized


def filter_samples_by_max_seq_len(squad_example):
    """Filter out the samples where the answers are
    not within the first `MAX_SEQ_LEN` tokens"""
    max_len = MAX_SEQ_LEN
    answer_start = squad_example['answers']['answer_start'][0]
    answer = squad_example['answers']['text'][0]
    token_start = len(tokenizer.tokenize(squad_example['context'][:answer_start]))
    token_end = len(tokenizer.tokenize(answer)) + token_start
    # Return an explicit boolean so rejected samples yield False instead of None
    return token_end < max_len
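
The conversion from character offsets to token indices used in the preprocessing above can be illustrated with a toy example. This is only a sketch: `toy_tokenize` is a hypothetical whitespace tokenizer standing in for the real WordPiece tokenizer, and `char_span_to_token_span` is an illustrative helper, not a function from the notebook. It assumes a single `[CLS]` token is prepended to the context and returns an exclusive end index.

```python
def toy_tokenize(text):
    """Whitespace tokenizer standing in for the real WordPiece tokenizer."""
    return text.split()


def char_span_to_token_span(context, answer, answer_start):
    """Map a character-level answer span to token indices.

    Index 0 is reserved for a leading [CLS] token, so the first
    context token has index 1. The returned end is exclusive.
    """
    n_before = len(toy_tokenize(context[:answer_start]))
    token_start = n_before + 1  # +1 for the [CLS] token
    token_end = token_start + len(toy_tokenize(answer))
    return token_start, token_end


context = "MobileBERT is a compact version of BERT for mobile devices"
answer = "compact version"
answer_start = context.index(answer)  # character offset, as stored in SQuAD

tokens = ["[CLS]"] + toy_tokenize(context)
start, end = char_span_to_token_span(context, answer, answer_start)
print(tokens[start:end])  # ['compact', 'version']
```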

In [None]:
dataset_filtered = hf_dataset.filter(
    filter_samples_by_max_seq_len,
    num_proc=12,
)

In [None]:
dataset_tok = dataset_filtered.map(
    tokenize_dataset,
    remove_columns=hf_dataset['train'].column_names,
    num_proc=12,
)
dataset_tok.set_format('pt')
dataset_tok

In [None]:
batch_size = 50

train_dataloader = DataLoader(
    dataset_tok['train'],
    shuffle=False,
    batch_size=batch_size,
)

# eval_dataloader = DataLoader(
#     dataset_tok['validation'],
#     shuffle=True,
#     batch_size=batch_size
# )

In [None]:
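# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')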
model.to(device)
model.train();

In [None]:
optim = torch.optim.AdamW(model.parameters(), lr=3e-5)

In [None]:
for epoch in range(1):
    for i, batch in enumerate(train_dataloader):
        optim.zero_grad()
        outputs = model(input_ids=batch['input_ids'].to(device),
                        token_type_ids=batch['token_type_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        start_positions=batch['start_token_idx'].to(device),
                        end_positions=batch['end_token_idx'].to(device))
        loss = outputs[0]
        loss.backward()
        optim.step()

#         print(loss)
#         if i > 10:
#             break
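
When `start_positions` and `end_positions` are passed, the first element of the model output is the training loss. To our understanding, the HuggingFace question answering heads compute it as the average of two cross-entropy terms, one over the start logits and one over the end logits. The sketch below reproduces that computation in plain Python for a single example; `cross_entropy` and `qa_loss` are hypothetical helper names and the logits are made-up values.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of `logits` evaluated at index `target`."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def qa_loss(start_logits, end_logits, start_pos, end_pos):
    """Average of the start-position and end-position cross-entropies."""
    start_loss = cross_entropy(start_logits, start_pos)
    end_loss = cross_entropy(end_logits, end_pos)
    return (start_loss + end_loss) / 2


# Made-up logits for a 4-token input
confident = qa_loss([0.0, 6.0, 0.0, 0.0], [0.0, 0.0, 6.0, 0.0], 1, 2)
uniform = qa_loss([0.0] * 4, [0.0] * 4, 1, 2)
print(confident, uniform)
```

With uniform logits over 4 tokens the loss equals `log(4)`, and it drops as the logits concentrate on the correct positions.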

In [None]:
torch.save(model.state_dict(), 'mobilebertqa_ft')