# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We have fine-tuned [BERT implemented by HuggingFace](https://huggingface.co/bert-base-uncased) for the text-extraction task on [SQuAD (The Stanford Question Answering Dataset)](https://rajpurkar.github.io/SQuAD-explorer/), a dataset of questions and answers. Let's evaluate the model.

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

More info:
- [Glossary - HuggingFace docs](https://huggingface.co/transformers/glossary.html#model-inputs)
- [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [None]:
import utility.data_processing as dproc
import utility.testing as testing
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertForQuestionAnswering
from tokenizers import BertWordPieceTokenizer

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

In [None]:
hf_model = 'bert-base-uncased'

In [None]:
tokenizer_simple = AutoTokenizer.from_pretrained(hf_model)
tokenizer = BertWordPieceTokenizer(vocab=tokenizer_simple.vocab, lowercase=True)

In [None]:
model = BertForQuestionAnswering.from_pretrained(hf_model)

In [None]:
# alternative checkpoint: '/scratch/snx3000/class401/bert_trained_deepspeed_example'
model_path_name = '/scratch/snx3000/sarafael/pytorch-training/bert_squad/model_trained_pytorch_2024-02-03-161157'

# load the model on cpu
model.load_state_dict(
    torch.load(model_path_name,
               map_location=torch.device('cpu'))
)

# load the model on gpu
# model.load_state_dict(torch.load(model_path_name))

In [None]:
hf_dataset = load_dataset('squad')

In [None]:
val_ds = hf_dataset['validation'].flatten()

In [None]:
max_len = 384

In [None]:
processed_val_ds = val_ds.map(
    lambda example: dproc.process_squad_item_batched(example, max_len, tokenizer),
    remove_columns=val_ds.column_names,
    batched=True,
    num_proc=12
)
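
The `dproc` helpers are not shown in this notebook, so what `process_squad_item_batched` produces is not visible here. As a rough sketch, assuming it follows the Keras example this notebook is based on, each context/question pair would be concatenated into one sequence and padded to `max_len`, with `token_type_ids` separating the two segments and `attention_mask` marking real tokens. The function `pack_example` and the token ids below are hypothetical and only illustrate the layout.

```python
from typing import Dict, List, Optional


def pack_example(
    context_ids: List[int],
    question_ids: List[int],
    max_len: int,
    pad_id: int = 0,
) -> Optional[Dict[str, List[int]]]:
    """Concatenate context and question ids and pad to `max_len`.

    Returns None when the pair does not fit (an assumed "skip" policy).
    """
    # Drop the question's leading [CLS]; the context already starts with one.
    question_ids = question_ids[1:]
    input_ids = context_ids + question_ids
    if len(input_ids) > max_len:
        return None
    token_type_ids = [0] * len(context_ids) + [1] * len(question_ids)
    attention_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    return {
        "input_ids": input_ids + [pad_id] * padding,
        "token_type_ids": token_type_ids + [0] * padding,
        "attention_mask": attention_mask + [0] * padding,
    }


# 101 = [CLS] and 102 = [SEP] in the bert-base-uncased vocabulary;
# the remaining ids are arbitrary placeholders.
context_ids = [101, 3000, 2003, 1999, 2605, 102]
question_ids = [101, 2073, 2003, 3000, 102]
features = pack_example(context_ids, question_ids, max_len=12)
print(features)
```

Padding every item to the same length is what allows the `DataLoader` to stack items into tensors without a custom collate function.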

In [None]:
processed_val_ds.set_format(type='torch')

In [None]:
batch_size = 1

eval_dataloader = DataLoader(
    processed_val_ds,
    shuffle=False,
    batch_size=batch_size
)

In [None]:
squad_example_objects = []
for item in val_ds:
    squad_examples = dproc.squad_examples_from_dataset(item, max_len, tokenizer)
    try:
        squad_example_objects.extend(squad_examples)
    except TypeError:
        squad_example_objects.append(squad_examples)
        
assert len(processed_val_ds) == len(squad_example_objects)

In [None]:
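start_sample = 24100
num_test_samples = 10
for i, eval_batch in enumerate(eval_dataloader):
    # stop once `num_test_samples` samples have been evaluated
    if i >= start_sample + num_test_samples:
        break

    # skip everything before `start_sample`;
    # indexing with `i` relies on batch_size == 1
    if i >= start_sample:
        testing.EvalUtility(eval_batch, [squad_example_objects[i]], model).results()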
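
The loop above inspects a handful of predictions through `testing.EvalUtility`, whose internals are not shown here. To put a number on answer quality, SQuAD results are usually reported as exact match (EM) and token-level F1 after normalizing both strings. The sketch below is modeled on the normalization used by the SQuAD v1.1 evaluation script (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace). The function names are illustrative, and it is an assumption that predictions and references are available as plain strings.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, truth: str) -> bool:
    """True when both answers are identical after normalization."""
    return normalize_answer(prediction) == normalize_answer(truth)


def f1_score(prediction: str, truth: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    # Multiset intersection counts each shared token at most as often as it
    # appears in both answers.
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


em = exact_match("the Eiffel Tower", "Eiffel Tower.")
f1 = f1_score("Eiffel Tower in Paris", "Eiffel Tower")
print(em, round(f1, 3))
```

Since SQuAD validation questions can have several reference answers, a full evaluation would typically take the maximum EM and F1 over the references for each question and then average over questions.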