# Fine-tuning a Question-Answering (QA) Model with HuggingFace

This examples shows how to train a Question-Answering Model using HuggingFace.

The example was taken from one of the lectures of the [Udacity Generative AI Nanodegree](), and it uses several snippets from the repository of the [HuggingFace Examples](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/trainer_qa.py).

The example uses the [SQuAD 2.0](https://arxiv.org/abs/1806.03822) format, in which each QA pair is formatted as follows

```python
{
  'id': 'xxx',
  'title': 'my title',
  'context': 'Here the complete context or text document is added.', # our document
  'question': 'What is...?', # our question
  'answers': {
    'text': ['1925'], # list of answers (list(str))
    'answer_start': [354] # the chars in context where the answer text starts (list(int))
  }
}
```

As we can see, the QA example is *extractive* (the answer is in the text), and not *generative-abstractive* (the answer is deduced).

## Imports

In [35]:
import pathlib
import pandas as pd
from datasets import Dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    EvalPrediction,
    Trainer,
    default_data_collator,
    pipeline,
)
from transformers.trainer_utils import PredictionOutput, speed_metrics
import torch
import math
import time
import collections
import numpy as np
from tqdm.notebook import tqdm

## Dataset: Apply Correct Format

Our *dummy* dataset is about AP controllers (Access Point) -- it is very technical.

The dataset consists of 10 QA pairs in [`data/qa.csv`](./data/qa.csv); in that dataframe, we have three columns

- `question`
- `answer`
- `filename`: this field points to any of two TXT documents ([`CVE-2020-29583.txt`](./data/CVE-2020-29583.txt) and [`xss.txt`](./data/xss.txt)), where the context for the answer is provided.

We need to transform the dataset CSV to the [SQuAD 2.0](https://arxiv.org/abs/1806.03822) format above, i.e., a list of dictionaries.

In [9]:
df = pd.read_csv("data/qa.csv")
df.head()

Unnamed: 0,question,answer,filename
0,Who is the manufacturer of the product?,Zyxel,CVE-2020-29583.txt
1,Who reported the vulnerability?,researchers from EYE Netherlands,CVE-2020-29583.txt
2,What is the vulnerability?,A hardcoded credential vulnerability was ident...,CVE-2020-29583.txt
3,How do users protect themselves?,we urge users to install the applicable updates,CVE-2020-29583.txt
4,What products are affected?,firewalls and AP controllers,CVE-2020-29583.txt


In [19]:
def qa_to_squad(
    question: str,
    answer: str,
    filename: str,
    identifier: str
) -> dict:
    filepath = pathlib.Path("data") / filename
    with open(filepath, "r") as f:
        context = f.read()
    
    # Assuming the answer appears exactly in the context
    # find where the answer starts in the context
    start_location = context.find(answer)
    qa_pair = {
        'id': identifier,
        'title': filepath.as_posix(),
        'context': context,
        'question': question,
        'answers': {
            'text': [answer],
            'answer_start': [start_location]
        }
    }
    return qa_pair

In [20]:
# Build a list of dictionaries
# being each dict the QA pair/row in SQuAD format
qa_list = list()
for i, row in df.iterrows():
    q = row['question']
    a = row['answer']
    f = row['filename']
    squad_dict = qa_to_squad(q, a, f, i)
    qa_list.append(squad_dict)

In [21]:
# Convert the list of dicts into a Dataset object
# We need to use pandas as intermediate auxiliary library
qa_df = pd.DataFrame(data=qa_list)
data = Dataset.from_pandas(qa_df)
print(data[0])
# {'id': 0, 'title': 'data/qa/CVE-2020-29583.txt', 'context': 'CVE: ...

{'id': 0, 'title': 'data/CVE-2020-29583.txt', 'context': 'CVE:   CVE-2020-29583 Summary Zyxel has released a patch for the hardcoded credential vulnerability of firewalls and AP controllers recently reported by researchers from EYE Netherlands. Users are advised to install the applicable firmware updates for optimal protection. What is the vulnerability? A hardcoded credential vulnerability was identified in the “zyfwp” user account in some Zyxel firewalls and AP controllers. The account was designed to deliver automatic firmware updates to connected access points through FTP. What versions are vulnerable—and what should you do? After a thorough investigation, we’ve identified the vulnerable products and are releasing firmware patches to address the issue, as shown in the table below. For optimal protection, we urge users to install the applicable updates. For those not listed, they are not affected. Contact your local Zyxel support team if you require further assistance or visit our  

In [22]:
# We can save the dataset to disk
data.save_to_disk("data/qa_data.hf")

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

In [24]:
# Load the dataset from disk
loaded_data = load_from_disk("data/qa_data.hf")

# Inspect the first few entries
print(loaded_data[0])         # Print the first example
print(loaded_data[:3])        # Print the first three examples

# Or convert to a pandas DataFrame for easier inspection
df = loaded_data.to_pandas()
print(df.head())

{'id': 0, 'title': 'data/CVE-2020-29583.txt', 'context': 'CVE:   CVE-2020-29583 Summary Zyxel has released a patch for the hardcoded credential vulnerability of firewalls and AP controllers recently reported by researchers from EYE Netherlands. Users are advised to install the applicable firmware updates for optimal protection. What is the vulnerability? A hardcoded credential vulnerability was identified in the “zyfwp” user account in some Zyxel firewalls and AP controllers. The account was designed to deliver automatic firmware updates to connected access points through FTP. What versions are vulnerable—and what should you do? After a thorough investigation, we’ve identified the vulnerable products and are releasing firmware patches to address the issue, as shown in the table below. For optimal protection, we urge users to install the applicable updates. For those not listed, they are not affected. Contact your local Zyxel support team if you require further assistance or visit our  

## Model

Since we are working on a *extractive* QA task, we can use any encoder or decoder transformer as the backbone; we add a small QA head on top, a linear layer that predicts:

- `start_logits`: probabilities of each token being the start of the answer
- and `end_logits`: probabilities of each token being the end of the answer

The selected model is [distilbert](https://huggingface.co/docs/transformers/en/model_doc/distilbert), and we build a custom `QuestionAnsweringTrainer(Trainer)` for training it, based on the official HuggingFace example [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/trainer_qa.py).

In [25]:
# We will use DistilBERT as the backbone foundation model to be fine-tuned
# Load the tokenizer for DistilBERT
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

# Load the model for the task: AutoModelForQuestionAnswering
# Note: This will throw warnings, which is expected!
model = AutoModelForQuestionAnswering.from_pretrained('distilbert-base-uncased')

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
# The Trainer subclass here is lightly modified from HuggingFace
# Original source at https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/trainer_qa.py
class QuestionAnsweringTrainer(Trainer):
    def __init__(self, *args, post_process_function=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.post_process_function = post_process_function

    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
        predict_dataloader = self.get_test_dataloader(predict_dataset)

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        start_time = time.time()
        try:
            output = eval_loop(
                predict_dataloader,
                description="Prediction",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        if self.post_process_function is None or self.compute_metrics is None:
            return output

        predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
        metrics = self.compute_metrics(predictions)

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
        metrics.update(output.metrics)
        return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)

## Pre-Processing and Post-Processing Functions

Pre- and post-processing functions are required to adapt the data and the outputs; they are taken from the HuggingFace `run_qa.py` example [on Github](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py):

1. `prepare_train_features(...)`: This function prepares raw QA examples for training by:
   - Tokenizing each (question, context) pair with padding and truncation.
   - Using return_offsets_mapping to trace back token positions to character positions in the context.
   - For each example:
     - Computes the start and end token indices of the answer span using the offsets.
     - If there's no answer, sets both positions to the CLS token index (for null prediction).
   - Returns tokenized inputs with additional "start_positions" and "end_positions" fields used for model training.
2. `postprocess_qa_predictions(...)`: This function transforms the model output logits into human-readable answer strings by:
   - Mapping predicted start and end logits back to their char spans in the original context using offset mappings.
   - Collecting top n-best candidate spans per feature based on score (start_logit + end_logit).
   - Filtering out invalid spans (too long, reversed, or not in max context).
   - Optionally handles "null answers" (no answer present) when version_2_with_negative=True.
   - Returns:
     - all_predictions: best answer per example.
     - all_nbest_json: top-n candidates per example.
     - Optionally, scores_diff_json: for null answer thresholding.
3. `post_processing_function(examples, features, predictions, stage="eval")`: A wrapper around `postprocess_qa_predictions`, specifically used in evaluation or inference, that:
   - Calls `postprocess_qa_predictions()` with sensible defaults.
   - Converts the final predictions to the format expected by HuggingFace metrics (e.g. "id", "prediction_text").
   - Prepares the reference answers in expected format (for metric computation like exact match or F1).
   - Returns an `EvalPrediction` object from the Transformers library.

In [28]:
# Training preprocessing
# Adapted from the HuggingFace `run_qa.py` example on Github:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py
def prepare_train_features(examples):
    """Preprocesses each example for training a question answering model."""
    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=512,
        padding="max_length",
        return_offsets_mapping=True
    )

    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        answers = examples["answers"][i]

        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1

            # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
            # Note: we could go after the last offset if the answer is the last word (edge case).
            while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            tokenized_examples["start_positions"].append(token_start_index - 1)
            while offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["end_positions"].append(token_end_index + 1)
    
    return tokenized_examples

def postprocess_qa_predictions(
    examples,
    features,
    predictions,
    version_2_with_negative = False,
    n_best_size = 20,
    max_answer_length = 30,
    null_score_diff_threshold = 0.0,
):
    """
    Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
    original contexts. This is the base postprocessing functions for models that only return start and end logits.

    Args:
        examples: The non-preprocessed dataset (see the main script for more information).
        features: The processed dataset (see the main script for more information).
        predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
            The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
            first dimension must match the number of elements of :obj:`features`.
        version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the underlying dataset contains examples with no answers.
        n_best_size (:obj:`int`, `optional`, defaults to 20):
            The total number of n-best predictions to generate when looking for an answer.
        max_answer_length (:obj:`int`, `optional`, defaults to 30):
            The maximum length of an answer that can be generated. This is needed because the start and end predictions
            are not conditioned on one another.
        null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
            The threshold used to select the null answer: if the best answer has a score that is less than the score of
            the null answer minus this threshold, the null answer is selected for this example (note that the score of
            the null answer for an example giving several features is the minimum of the scores for the null answer on
            each feature: all features must be aligned on the fact they `want` to predict a null answer).

            Only useful when :obj:`version_2_with_negative` is :obj:`True`.
    """
    if len(predictions) != 2:
        raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
    all_start_logits, all_end_logits = predictions

    if len(predictions[0]) != len(features):
        raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")

    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    if version_2_with_negative:
        scores_diff_json = collections.OrderedDict()

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_prediction = None
        prelim_predictions = []

        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]
            # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
            # available in the current feature.
            token_is_max_context = features[feature_index].get("token_is_max_context", None)

            # Update minimum null prediction.
            feature_null_score = start_logits[0] + end_logits[0]
            if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or len(offset_mapping[start_index]) < 2
                        or offset_mapping[end_index] is None
                        or len(offset_mapping[end_index]) < 2
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Don't consider answer that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue

                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_logits[start_index] + end_logits[end_index],
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )
        if version_2_with_negative and min_null_prediction is not None:
            # Add the minimum null prediction
            prelim_predictions.append(min_null_prediction)
            null_score = min_null_prediction["score"]

        # Only keep the best `n_best_size` predictions.
        predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

        # Add back the minimum null prediction if it was removed because of its low score.
        if (
            version_2_with_negative
            and min_null_prediction is not None
            and not any(p["offsets"] == (0, 0) for p in predictions)
        ):
            predictions.append(min_null_prediction)

        # Use the offsets to gather the answer text in the original context.
        context = example["context"]
        for pred in predictions:
            offsets = pred.pop("offsets")
            pred["text"] = context[offsets[0] : offsets[1]]

        # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
        # failure.
        if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
            predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

        # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
        # the LogSumExp trick).
        scores = np.array([pred.pop("score") for pred in predictions])
        exp_scores = np.exp(scores - np.max(scores))
        probs = exp_scores / exp_scores.sum()

        # Include the probabilities in our predictions.
        for prob, pred in zip(probs, predictions):
            pred["probability"] = prob

        # Pick the best prediction. If the null answer is not possible, this is easy.
        if not version_2_with_negative:
            all_predictions[example["id"]] = predictions[0]["text"]
        else:
            # Otherwise we first need to find the best non-empty prediction.
            i = 0
            while predictions[i]["text"] == "":
                i += 1
            best_non_null_pred = predictions[i]

            # Then we compare to the null prediction using the threshold.
            score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
            scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
            if score_diff > null_score_diff_threshold:
                all_predictions[example["id"]] = ""
            else:
                all_predictions[example["id"]] = best_non_null_pred["text"]

        # Make `predictions` JSON-serializable by casting np.float back to float.
        all_nbest_json[example["id"]] = [
            {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
            for pred in predictions
        ]

    return all_predictions

# Post-processing:
def post_processing_function(examples, features, predictions, stage="eval"):
    """Wraps postprocess_qa_predictions for use in evaluation or prediction"""
    # Post-processing: we match the start logits and end logits to answers in the original context.
    version_2_with_negative = False  # If true, some of the examples do not have an answer
    predictions = postprocess_qa_predictions(
        examples=examples,
        features=features,
        predictions=predictions,
        version_2_with_negative=version_2_with_negative,  # If true, some of the examples do not have an answer
        n_best_size=20,  # The total number of n-best predictions to generate when looking for an answer
        # The maximum length of an answer that can be generated. This is needed because the start
        # and end predictions are not conditioned on one another
        max_answer_length=30,
        # The threshold used to select the null answer: if the best answer has a score that is less than
        # the score of the null answer minus this threshold, the null answer is selected for this example.
        # Only useful when `version_2_with_negative=True`.
        null_score_diff_threshold=0.0,
        output_dir="./output_dir",
        # 'debug', 'info', 'warning', 'error' and 'critical', 
        # plus a 'passive' level which doesn't set anything and keeps the
        # current log level for the Transformers library (which will be `"warning"` by default).
        log_level="warning",
        prefix=stage,
    )
    # Format the result to the format the metric expects.
    if version_2_with_negative:
        formatted_predictions = [
            {"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
        ]
    else:
        formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()]


    answer_column_name = "answers"
    references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
    return EvalPrediction(predictions=formatted_predictions, label_ids=references)

## Training

After mapping the `prepare_train_features` pre-processing function to the dataset, we train the model with `QuestionAnsweringTrainer`.

In [29]:
# Preprocess dataset
data = data.map(prepare_train_features, batched=True)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [31]:
# Set up our trainer
trainer = QuestionAnsweringTrainer(
    model=model,
    train_dataset=data,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    post_process_function=post_processing_function
)

  super().__init__(*args, **kwargs)


In [32]:
# Run the trainer!
trainer.train()

Step,Training Loss


TrainOutput(global_step=6, training_loss=0.0, metrics={'train_runtime': 61.3409, 'train_samples_per_second': 0.489, 'train_steps_per_second': 0.098, 'total_flos': 3919593000960.0, 'train_loss': 0.0, 'epoch': 3.0})

In [33]:
# Save our model
trainer.save_model("./ft-distilbert")

In [34]:
# Load the fine-tuned model and tokenizer from the local directory
model = AutoModelForQuestionAnswering.from_pretrained('./ft-distilbert')
tokenizer = AutoTokenizer.from_pretrained('./ft-distilbert')

# Test loading by printing model config
print(model.config)

DistilBertConfig {
  "_attn_implementation_autoset": true,
  "activation": "gelu",
  "architectures": [
    "DistilBertForQuestionAnswering"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3",
  "vocab_size": 30522
}



## Inference

To run inference, we can use the `pipeline` function from HuggingFace, which in our case uses/returns a `QuestionAnsweringPipeline` object.

Note that DistilBert has a context window of 512 tokens only, whereas our contexts surpass that size; `QuestionAnsweringPipeline` handles that automatically by chunking the context with some overlap and running the model with each context. Then, the best answer (the one with the highest score) is chosen.

We can also write a custom `ask()` function that does that.

In [None]:
# Let's evaluate our model!
# Specify an input question and context
question = "What can an attacker do with XSS?"
with open("./data/xss.txt", "r") as f:
    context = f.read()

# The context is longer than the 512 token limit of DistilBERT
# but pipeline handles that by chunking the context
print(f"Number of words in context: {len(context.split())}")

# Use HuggingFace pipeline to answer the question
question_answerer = pipeline("question-answering", model="./ft-distilbert")
question_answerer(question=question, context=context)
# {'score': 5.4717063903808594e-05,
# 'start': 7927,
# 'end': 7942,
# 'answer': 'the victims did'}

Device set to use cpu


Number of words in context: 1355


{'score': 5.987742042634636e-05,
 'start': 7629,
 'end': 7653,
 'answer': 'the Samy worm on MySpace'}

In [None]:
# We can simulate what the pipeline does internally
# by running the model directly.
# To that end, we need to chunk the context
# and apply the model to each chunk.
# Finally, we select the answer with the highest score.
# However, as we see, the function needs some work to be fully robust...

import torch
import numpy as np

def ask(question, context, model, tokenizer, max_len=512, stride=128):
    # Tokenize with truncation & sliding window
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        max_length=max_len,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length"
    )

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    offset_mapping = inputs["offset_mapping"]
    overflow_to_sample_mapping = inputs["overflow_to_sample_mapping"]

    best_score = -float("inf")
    best_answer = ""

    for i in range(len(input_ids)):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids[i].unsqueeze(0),
                attention_mask=attention_mask[i].unsqueeze(0)
            )

        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]

        # Compute score matrix
        start_index = torch.argmax(start_logits)
        end_index = torch.argmax(end_logits)

        # Make sure it's a valid span
        if start_index <= end_index and end_index < len(offset_mapping[i]):
            offsets = offset_mapping[i][start_index:end_index + 1]

            if len(offsets) > 0 and offsets[0] is not None and offsets[-1] is not None:
                start_char = offsets[0][0]
                end_char = offsets[-1][1]
                answer = context[start_char:end_char]
                score = start_logits[start_index] + end_logits[end_index]
                if score > best_score:
                    best_score = score
                    best_answer = answer


    return {
        "question": question,
        "answer": best_answer,
        "score": float(best_score),
    }

In [81]:
model = AutoModelForQuestionAnswering.from_pretrained('./ft-distilbert')
tokenizer = AutoTokenizer.from_pretrained('./ft-distilbert')

In [82]:
question = "What can an attacker do with XSS?"
with open("./data/xss.txt", "r") as f:
    context = f.read()

In [None]:
# As we see, the function needs some work to be fully robust...
# It is not working as well as the pipeline function
answer = ask(question, context, model, tokenizer)
print(answer)

{'question': 'What can an attacker do with XSS?', 'answer': "The following JSP code segment reads an employee ID, eid, from an HTTP request and displays it to the user. The following ASP.NET code segment reads an employee ID number from an HTTP request and displays it to the user. The code in this example operates correctly if the Employee ID variable contains only standard alphanumeric text. If it has a value that includes meta-characters or source code, then the code will be executed by the web browser as it displays the HTTP response. Example 3 This example covers a Stored XSS (Type 2) scenario. The following JSP code segment queries a database for an employee with a given ID and prints the corresponding employee's name. The following ASP.NET code segment queries a database for an employee with a given employee ID and prints the name corresponding with the ID. This code can appear less dangerous because the value of name is read from a database, whose contents are apparently managed