# Fine-tuning a Question-Answering (QA) Model with HuggingFace

This example shows how to fine-tune a Question-Answering (QA) model using HuggingFace.

The example is taken from one of the lectures of the [Udacity Generative AI Nanodegree](), and it reuses several snippets from the [HuggingFace Examples](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/trainer_qa.py) repository.

The example uses the [SQuAD 2.0](https://arxiv.org/abs/1806.03822) format, in which each QA pair is formatted as follows:

```python
{
  'id': 'xxx',
  'title': 'my title',
  'context': 'Here the complete context or text document is added.', # our document
  'question': 'What is...?', # our question
  'answers': {
    'text': ['1925'], # list of answers (list(str))
    'answer_start': [354] # the chars in context where the answer text starts (list(int))
  }
}
```

As we can see, the QA example is *extractive* (the answer is a span of the context, located by `answer_start`), not *abstractive* (the answer is deduced and need not appear verbatim in the text).
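
To make the role of `answer_start` concrete, here is a minimal sketch that slices the answer back out of the context. The record and the helper name `extract_spans` are hypothetical and only illustrate the format; they are not part of the dataset or of any library.

```python
# Hypothetical record in the SQuAD-style layout shown above (not from the dataset)
record = {
    "id": "demo-0",
    "title": "demo",
    "context": "The company was founded in 1925 in Vienna.",
    "question": "When was the company founded?",
    "answers": {"text": ["1925"], "answer_start": [27]},
}


def extract_spans(record: dict) -> list[str]:
    """Slice each answer out of the context using its character offset."""
    context = record["context"]
    answers = record["answers"]
    return [
        context[start:start + len(text)]
        for text, start in zip(answers["text"], answers["answer_start"])
    ]


spans = extract_spans(record)
# For a consistent record, the slices reproduce the answer texts exactly
print(spans == record["answers"]["text"])
```

If the comparison is false for a record, its `answer_start` does not point at the answer text, which is worth catching before training.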


## Imports

```python
import collections
import math
import pathlib
import time

import numpy as np
import pandas as pd
from datasets import Dataset, load_from_disk
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    EvalPrediction,
    Trainer,
    default_data_collator,
    pipeline,
)
from transformers.trainer_utils import PredictionOutput, speed_metrics
```

## Dataset: Apply Correct Format

Our *dummy* dataset is about AP (Access Point) controllers; it is very technical.

The dataset consists of 10 QA pairs in [`data/qa.csv`](./data/qa.csv); that dataframe has three columns:

- `question`
- `answer`
- `filename`: this field points to one of two TXT documents ([`CVE-2020-29583.txt`](./data/CVE-2020-29583.txt) and [`xss.txt`](./data/xss.txt)), which provides the context for the answer.

We need to transform the dataset CSV to the [SQuAD 2.0](https://arxiv.org/abs/1806.03822) format above, i.e., a list of dictionaries.

```python
df = pd.read_csv("data/qa.csv")
df.head()
```

| Unnamed: 0 | question | answer | filename |
|---|---|---|---|
| 0 | Who is the manufacturer of the product? | Zyxel | CVE-2020-29583.txt |
| 1 | Who reported the vulnerability? | researchers from EYE Netherlands | CVE-2020-29583.txt |
| 2 | What is the vulnerability? | A hardcoded credential vulnerability was ident... | CVE-2020-29583.txt |
| 3 | How do users protect themselves? | we urge users to install the applicable updates | CVE-2020-29583.txt |
| 4 | What products are affected? | firewalls and AP controllers | CVE-2020-29583.txt |


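```python
def qa_to_squad(
    question: str,
    answer: str,
    filename: str,
    identifier: int,  # the DataFrame row index is passed here
) -> dict:
    """Build one SQuAD-style QA record from a row of the CSV."""
    filepath = pathlib.Path("data") / filename
    # The documents contain non-ASCII characters (e.g., curly quotes);
    # we assume the files are UTF-8 encoded
    with open(filepath, "r", encoding="utf-8") as f:
        context = f.read()

    # We assume the answer appears verbatim in the context and look up
    # the character offset where it starts.
    # Note: str.find() returns -1 if the answer is not found.
    start_location = context.find(answer)
    qa_pair = {
        'id': identifier,
        'title': filepath.as_posix(),
        'context': context,
        'question': question,
        'answers': {
            'text': [answer],
            'answer_start': [start_location]
        }
    }
    return qa_pair
```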
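
Because `str.find()` returns `-1` when the answer text is not present, a QA pair whose answer differs from the document even slightly (case, whitespace, quote characters) would silently get an invalid `answer_start`. A minimal sketch of a stricter lookup follows; `find_answer_start` is a hypothetical helper name, and the sample context is made up for illustration.

```python
def find_answer_start(context: str, answer: str) -> int:
    """Return the character offset of `answer` in `context`, or raise."""
    start = context.find(answer)
    if start == -1:
        raise ValueError(f"Answer not found verbatim in context: {answer!r}")
    return start


# Hypothetical context, used only to exercise the helper
context = "Zyxel has released a patch for the vulnerability."
print(find_answer_start(context, "Zyxel"))  # 0

try:
    find_answer_start(context, "zyxel")  # differs in case, so no exact match
except ValueError as err:
    print(err)
```

Whether to raise, warn, or drop such rows is a design choice; raising is the most conservative option for a dataset this small.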

```python
# Build a list of dictionaries,
# where each dict is one QA pair (one CSV row) in SQuAD format
qa_list = []
for i, row in df.iterrows():
    q = row['question']
    a = row['answer']
    f = row['filename']
    squad_dict = qa_to_squad(q, a, f, i)
    qa_list.append(squad_dict)
```

```python
# Convert the list of dicts into a Dataset object,
# going through a pandas DataFrame as an intermediate step
qa_df = pd.DataFrame(data=qa_list)
data = Dataset.from_pandas(qa_df)
print(data[0])
```

{'id': 0, 'title': 'data/CVE-2020-29583.txt', 'context': 'CVE:   CVE-2020-29583 Summary Zyxel has released a patch for the hardcoded credential vulnerability of firewalls and AP controllers recently reported by researchers from EYE Netherlands. Users are advised to install the applicable firmware updates for optimal protection. What is the vulnerability? A hardcoded credential vulnerability was identified in the “zyfwp” user account in some Zyxel firewalls and AP controllers. The account was designed to deliver automatic firmware updates to connected access points through FTP. What versions are vulnerable—and what should you do? After a thorough investigation, we’ve identified the vulnerable products and are releasing firmware patches to address the issue, as shown in the table below. For optimal protection, we urge users to install the applicable updates. For those not listed, they are not affected. Contact your local Zyxel support team if you require further assistance or visit our  

```python
# We can save the dataset to disk
data.save_to_disk("data/qa_data.hf")
```


```python
# Load the dataset from disk
loaded_data = load_from_disk("data/qa_data.hf")

# Inspect the first few entries
print(loaded_data[0])         # Print the first example
print(loaded_data[:3])        # Print the first three examples

# Or convert to a pandas DataFrame for easier inspection
df = loaded_data.to_pandas()
print(df.head())
```

{'id': 0, 'title': 'data/CVE-2020-29583.txt', 'context': 'CVE:   CVE-2020-29583 Summary Zyxel has released a patch for the hardcoded credential vulnerability of firewalls and AP controllers recently reported by researchers from EYE Netherlands. Users are advised to install the applicable firmware updates for optimal protection. What is the vulnerability? A hardcoded credential vulnerability was identified in the “zyfwp” user account in some Zyxel firewalls and AP controllers. The account was designed to deliver automatic firmware updates to connected access points through FTP. What versions are vulnerable—and what should you do? After a thorough investigation, we’ve identified the vulnerable products and are releasing firmware patches to address the issue, as shown in the table below. For optimal protection, we urge users to install the applicable updates. For those not listed, they are not affected. Contact your local Zyxel support team if you require further assistance or visit our  