# Using pre-trained BERT for building a Q&A system

##### **😉 Hello there!**

Before starting to explore and run this notebook, here are a few things you should know:

<div class="alert alert-info">
<b>🧐 What is BERT?</b>

BERT stands for Bidirectional Encoder Representations from Transformers and is a language representation model. A pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. See more details in the official paper, [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
</div>

<div class="alert alert-info">
<b>🧐 What is Transformers?</b>

[Transformers](https://huggingface.co/docs/transformers/index) provides APIs and tools to easily download and train state-of-the-art pretrained models. Using pretrained models can reduce your compute costs, carbon footprint, and save you the time and resources required to train a model from scratch. These models support common tasks in different modalities, such as:

- Natural Language Processing
- Computer Vision
- Audio
- Multimodal
</div>

<div class="alert alert-info">
<b>🧐 What is HuggingFace?</b>

[Hugging Face, Inc.](https://huggingface.co/) is a French company that develops tools for building applications using machine learning. It is most notable for its Transformers library built for natural language processing applications and its platform that allows users to share machine learning models and datasets. The [Hugging Face Hub](https://huggingface.co/docs/hub/index) hosts over 120k models, 20k datasets, and 50k demo apps (Spaces), all open source and publicly available. It works as a central place where anyone can explore, experiment, collaborate and build technology with machine learning.
</div>

# Notebook Overview
- Import Dependencies
- Configure Logging and Define Constants and Paths
- Dataset
- Tokenizer
- Targets
- Training
- Inference

# Install Requirements and Import Dependencies

In [None]:
!pip install -r requirements.txt --quiet

In [None]:
# Standard Library Imports
import logging
from datetime import datetime


# Third-Party Libraries
import torch

# MLflow for Experiment Tracking and Model Management
import mlflow
import mlflow.pytorch

from datasets import load_dataset
from evaluate import load as load_metric
from tqdm.autonotebook import tqdm
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer, pipeline

# Define Constants and Paths and Configure Logging

In [None]:
# Define global experiment and run names to be used throughout the notebook
MODEL_CHECKPOINT = "distilbert-base-cased"
EXPERIMENT_SET = "BERT Q&A - distilbert-base-cased"
SAVE_MODEL_NAME = "distilbert_bertqa"

# Set up the paths
MODEL_PATH = "models:/BERT_QA"

# Set up the chunk separator for text processing
CHUNK_SEPARATOR = "\n\n"

In [None]:
# Configure the logging module with desired format and level
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Create a logger for this notebook
logger = logging.getLogger('training-notebook')
logger.info("Logging configured successfully")

In [None]:
start_time_all_execution = datetime.now() # Start time, used to measure how long the whole notebook takes to run

# Dataset

## <b>🧐 What is the SQuAD dataset?</b>

Stanford Question Answering Dataset ([SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)) is a reading comprehension dataset consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable (unanswerable questions only exist in SQuAD 2.0; the "squad" dataset loaded below is version 1.1, where every question has an answer). See more details on the SQuAD website.

In this case, we're going to download this dataset from the [HuggingFace datasets](https://huggingface.co/datasets) repository.

In [None]:
squad_dataset = load_dataset("squad") # Downloading the dataset
squad_dataset

By printing the squad_dataset variable, we can see it is a dict-like object with two keys: 'train', whose value is the whole train dataset, and 'validation', whose value is the whole validation dataset. In this first part, we're using just the train dataset. Let's explore it a little bit, shall we?

## **The train dataset**

As we saw, we have a dict with two keys. Let's access the "train" key.

In [None]:
squad_train_dataset = squad_dataset['train']
squad_train_dataset

We can see that squad_train_dataset is another type of data, a Dataset (very similar to a dict), whose printed summary shows two entries: 'features' and 'num_rows'.
This Dataset has 87599 rows, which correspond to 87599 indexes, and each index holds one question-answer input for our model, like this:



In [None]:
index_input = 1

squad_train_dataset[index_input]

Looking at this entry of the Dataset object, we can see what a single input looks like.

Each input is a dict whose keys come from the 'features' list we have just seen, so we have:

- id: A unique id for each input
- title: The title for the question-answer (to give context)
- context: The text input for the model to search the answer for the question
- question: The question based on the context
- answers: The answer based on the context

Each of these features can be accessed individually, like this:

In [None]:
logger.info(squad_train_dataset[index_input]['title'])
logger.info(squad_train_dataset[index_input]['context'])
logger.info(squad_train_dataset[index_input]['question'])
logger.info(squad_train_dataset[index_input]['answers'])

Something important to notice here is that the value of the 'answers' key is itself a dict with two keys: 'text' and 'answer_start'.
The 'text' key holds a list with the answer texts, so yes! A question can have more than one answer in the dataset!
In this case, this input has just one answer. But we should check the rest of the inputs. Let's do this using the filter method.

In [None]:
# Checking in the train dataset if we have just one answer for each question
squad_train_dataset.filter(lambda x: len(x['answers']['text']) != 1)

Great, we don't have more than one answer for each question. 

Well, we're done exploring the train dataset. Let's go to the next part, shall we?

# Tokenizer

<b>🧐 What is a tokenizer?</b>

A tokenizer is in charge of preparing the inputs for a model. There is a range of tokenizers, so let's get to know the BERT tokenizer and what it does with the text.

## **BERT Tokenizer**

The first thing we have to do is load the tokenizer of a pre-trained [BERT](https://huggingface.co/docs/transformers/model_doc/bert) model from the Hugging Face Hub. In this case, we're using the [distilbert-base-cased](https://huggingface.co/distilbert-base-cased) model.

In [None]:
model_checkpoint_bbc = MODEL_CHECKPOINT #"bert-base-cased" is a larger option if you want to test!
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint_bbc) # getting the model's tokenizer 

For exploring it, let's take the same sample from the train dataset exploration.

In [None]:
# getting just context and the question
#print(index_input) # uncomment to remember the index
context = squad_train_dataset[index_input]['context'] 
question = squad_train_dataset[index_input]['question']
logger.info(context)
logger.info(question)

Let's pass the question and the context to the tokenizer, in this order, since this is the way BERT receives the inputs.

In [None]:
inputs = tokenizer(question, context) # Encoding the inputs
inputs

We can see that it returned a big dict. Let's check its keys.

In [None]:
inputs.keys()

We have 2 keys here, but in this experiment, we're just interested in the 'input_ids', okay? Let's understand it.

### **input_ids**

Let's get the first 27 items.

In [None]:
print(inputs['input_ids'][0:27])

You may be asking yourself what those numbers are. Let's decode them with the same tokenizer we've just used.

In [None]:
tokenizer.decode(inputs['input_ids'][0:27])

Just like magic, right? 😌

What the tokenizer did was to concatenate the question and the context into one long sequence and output a dict with two keys, where input_ids holds the numerical representation of each token of the original question and context. For example, the token "What" from the text corresponds to the input_id 1327.

Ok, and how about the tokens [CLS] and [SEP]?

Don't worry, I got you! These are called special tokens. [CLS] is the classification token and comes before the question. The [SEP] token delimits the beginning and the end of the context. So, the format that BERT follows is:

**[CLS] 'question' [SEP] 'context' [SEP]**
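To make this layout concrete, here is a minimal sketch in plain Python. It is only a toy: whitespace-separated words stand in for BERT's real tokens, and `build_pair_input` is a hypothetical helper, not part of the Transformers library. Besides the tokens, it also marks which segment each position belongs to.

```python
# Toy illustration only: whitespace "tokens" stand in for BERT's real tokens.
def build_pair_input(question_tokens, context_tokens):
    """Return (tokens, segments) laid out as [CLS] question [SEP] context [SEP]."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # None marks a special token, 0 the question, 1 the context
    segments = (
        [None]
        + [0] * len(question_tokens)
        + [None]
        + [1] * len(context_tokens)
        + [None]
    )
    return tokens, segments


tokens, segments = build_pair_input(
    "What is BERT ?".split(),              # hypothetical question
    "BERT is a language model .".split(),  # hypothetical context
)
print(" ".join(tokens))
print(segments)
```

The question always sits between [CLS] and the first [SEP], and the context between the two [SEP] tokens.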

### **BERT and (way too) long contexts**

We saw that the context of the inputs can be very long. Let's get the context from the previous part.

In [None]:
sentences = context.count('.') # rough estimate: counts periods, not true sentence boundaries

logger.info(context)
print(f"Total number of characters: {len(context)}")
print(f"Total number of words: {len(context.split())}")
print(f"Total number of sentences: {sentences}")

Contexts this long used to be a concern in NLP: many other common applications take just one or two sentences as input, and BERT can only handle a limited number of tokens (512 for the standard BERT checkpoints). You could think, "Why don't we just truncate the context?" Well, this is a terrible option, since the answer could be cut off with it.

The solution that was found is to split the context into **multiple context windows**!
It means that one data sample will turn into multiple data samples, and at least one of them will certainly contain the answer.

But, what if a part of the answer begins in one window and is cut off and then the rest is in the next window?

Well, in this case, we use **overlapping windows**!
It is easier to understand and visualize all of those concepts when we use the tokenizer again. Let's go!
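Before turning to the real tokenizer, the windowing idea itself can be sketched in a few lines of plain Python. This is a simplified sketch under stated assumptions: `split_into_windows`, `window_size` and `overlap` are hypothetical names, the tokens are made up, and the window size here counts only context tokens (the real tokenizer also reserves room for the question and the special tokens).

```python
def split_into_windows(context_tokens, window_size, overlap):
    """Split tokens into windows of at most `window_size` tokens sharing `overlap` tokens."""
    if not 0 <= overlap < window_size:
        raise ValueError("overlap must be non-negative and smaller than window_size")
    step = window_size - overlap  # how far each new window moves forward
    windows = []
    start = 0
    while True:
        windows.append(context_tokens[start:start + window_size])
        if start + window_size >= len(context_tokens):
            break  # this window already reaches the end of the context
        start += step
    return windows


context_tokens = [f"t{i}" for i in range(10)]  # made-up tokens t0 ... t9
windows = split_into_windows(context_tokens, window_size=4, overlap=2)
for window in windows:
    print(window)
```

Each window repeats the last tokens of the previous one, so an answer shorter than the overlap can never be cut in half by every window.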

### **Understanding the model's tokenizer**

As before, we're using the same question and context from the previous part and also passing them to the tokenizer in this specific order.

- max_length: the maximum length of each input, counting the question, the context and the special tokens [CLS] and [SEP].
- truncation: "only_second" tells the tokenizer to truncate only the second sequence, which is the context.
- stride: the number of tokens shared by two consecutive context windows when the context is split up.
- return_overflowing_tokens: makes the tokenizer return the tokens that do not fit in the first window as additional windows, instead of discarding them.

In [None]:
inputs = tokenizer(
    question,
    context, 
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True
)

Let's see what we've got in the inputs and its keys.

In [None]:
logger.info(inputs)
logger.info(inputs.keys())

Although the 'attention_mask' key appears again, we're just interested in the other 2 keys. Let's dive in!

#### **input_ids**

We can see that 'input_ids' is now a list of lists.

In [None]:
print(f'Total number of lists: {len(inputs["input_ids"])}')

Let's decode those lists and see what we have.

In [None]:
for ids in inputs["input_ids"]: # `ids` avoids shadowing the built-in `id`
    print(tokenizer.decode(ids))

Wonderful! The tokenizer has split the context into 4 different inputs and has kept the question and the special tokens in each one of them.

#### **overflow_to_sample_mapping**

To better demonstrate what this new key is, let's pass more than one input sample to the tokenizer.

In [None]:
question_samples = squad_train_dataset[:3]["question"] # Getting the first 3 questions
context_samples = squad_train_dataset[:3]["context"] #  and contexts

for i in question_samples:
    print(i)

In [None]:
logger.info(context) # You can check if you want but those questions are from the same context, so no need to print all of the 3.

Let's set up the tokenizer with these new samples and one more argument to understand the overflow mapping:

- return_offsets_mapping: this returns the start and end character for each token (it will be explained after this part)

In [None]:
inputs = tokenizer(
    question_samples, 
    context_samples,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True
)

Let's see how many windows we have now with 3 inputs:

In [None]:
logger.info(len(inputs["input_ids"]))

Great! But how do we know how many windows each input was split into? That's what the overflow_to_sample_mapping gives us! Look at this:

In [None]:
inputs["overflow_to_sample_mapping"]

The first sample corresponds to 0, so it was split into 4 windows. The second sample corresponds to 1, and it was also split into 4 windows. The same goes for the third sample, which is number 2. Since they come from the same context, it's usual to have the same number of windows for the samples.
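Counting how many windows belong to each sample is just a matter of counting repeated values in the mapping. The sketch below uses a hand-written list that mimics the case described above (3 samples, 4 windows each); it is an illustrative stand-in, not the actual tokenizer output.

```python
from collections import Counter

# Hand-written stand-in for inputs["overflow_to_sample_mapping"]:
# position i holds the index of the original sample that window i comes from
overflow_to_sample_mapping = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

windows_per_sample = Counter(overflow_to_sample_mapping)
for sample_idx, n_windows in sorted(windows_per_sample.items()):
    print(f"sample {sample_idx}: {n_windows} windows")
```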

In [None]:
for ids in inputs["input_ids"]: # `ids` avoids shadowing the built-in `id`
    logger.info(tokenizer.decode(ids))

#### **offset_mapping**

To do a better demonstration, let's go back to the single input and pass it to the same tokenizer from the previous part.

In [None]:
# print(question, "\n", context) # uncomment this cell if you don't remember them

In [None]:
inputs = tokenizer(
    question,
    context, 
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True
)

inputs.keys()

We have now a new key, the offset_mapping key. Let's print it.

In [None]:
inputs['offset_mapping']

We can see that it is a list composed of lists of tuples.

In [None]:
logger.info(f"Total number of lists: {len(inputs['offset_mapping'])}")

Basically, the offset_mapping tells us the location of the start and the end of each token in the **ORIGINAL** samples! Take a look at this:

In [None]:
logger.info(tokenizer.decode(inputs['input_ids'][0])) # Taking the first window
logger.info(inputs['offset_mapping'][0])

Looking at the sentence and the offset_mapping, the (0,0) is for the special tokens. The (0,4) is the location of the token "What", which is composed of 4 chars, so it starts at position 0 and ends at position 4 (the end position is exclusive). Next comes (5, 7), the location of the token "is", which starts at position 5 and ends at position 7, and so on.
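The key property of these offsets is that slicing the original string with them gives back each token's text. Here is a small sketch of that idea, assuming a made-up question and a toy regex "tokenizer" that splits on words and punctuation; real BERT offsets come from the tokenizer itself and may split words into smaller pieces.

```python
import re

text = "What is BERT?"  # hypothetical question, not taken from the dataset

# Toy tokenizer: words and punctuation marks with their (start, end) character spans
offsets = [(m.start(), m.end()) for m in re.finditer(r"\w+|[^\w\s]", text)]
print(offsets)

# Slicing the original text with an offset recovers the token's surface form
for start, end in offsets:
    print((start, end), repr(text[start:end]))
```

Since the end position is exclusive, `text[start:end]` works directly with Python slicing.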

Notice that the word 'original' from the previous text is in bold and upper case. Let's understand why this is important in the next step.

# Targets

We have just seen that we can work with long contexts by splitting them into multiple windows, right?

In [None]:
for ids in inputs["input_ids"]: # `ids` avoids shadowing the built-in `id`
    logger.info(tokenizer.decode(ids))


Well, the thing is that the answer in the dataset comes with a start position, remember?

In [None]:
answer = squad_train_dataset[index_input]['answers']
print(answer)
#print(context) # uncomment here to remember the original context

But this start position refers to the original context, before it was split. After splitting it into context windows, that position is no longer valid, and the position inside each window is exactly the target the model is waiting for. So, we need to align these targets with the windows now, considering that sometimes the answer may not exist in one specific window or may only exist in part.

There's a useful method for this, sequence_ids, which we can call on the tokenizer output of all models (e.g. DistilBERT, bert-base-uncased, etc.). Let's take a look.

### **sequence_ids method**

In [None]:
print(inputs.sequence_ids(0)) # Getting the first window
print(tokenizer.decode(inputs['input_ids'][0])) # first window context

The meaning of each item is what follows:

- None: special tokens like [CLS] and [SEP]
- 0: is part of the question sentence
- 1: is part of the context sentence

So, with this encoding, we can now compare the start position of the answer in the original context and the start position of the answer in the windows!

### **Finding the answer: Window contexts and original context**

Let's see now what is the start index of the answer within the windows.

In [None]:
sequence_ids = inputs.sequence_ids(0) # Getting the first window

wind_ctx_start = sequence_ids.index(1) # Getting the first occurrence of 1, which means the index where the context begins
wind_ctx_end = (len(sequence_ids) - sequence_ids[::-1].index(1) - 1) # Getting the index of the last 1, where the context ends

wind_ctx_start, wind_ctx_end

And the original start answer position (within the context)

In [None]:
logger.info(answer)
ans_start_char = answer['answer_start'][0]
ans_end_char = ans_start_char + len(answer['text'][0]) # the start char plus the length of the answer text gives the final char of the answer

logger.info((ans_start_char, ans_end_char))

OK! Let's now use the offset mapping since it tells us about the char positions within the context.

In [None]:
offset = inputs['offset_mapping'][0] # First window
# print(offset) # uncomment to remember the offset
# print(tokenizer.decode(inputs['input_ids'][0])) # and how they correspond to the original sentences

The offset holds the original start and end chars of each token, plus (0,0) for the special tokens (here we are focusing on the [SEP] tokens, which tell us where the context starts and ends). So we can check whether the answer's original char positions fall inside this window and, if so, which tokens they match.

In [None]:
start_idx = 0
end_idx = 0

if offset[wind_ctx_start][0] > ans_start_char or offset[wind_ctx_end][1] < ans_end_char:
    logger.info("target is (0,0)")
else:
    i = wind_ctx_start
    for start_end_char in offset[wind_ctx_start:]:
        start, end = start_end_char
        if start == ans_start_char:
            start_idx = i

        if end == ans_end_char:
            end_idx = i 
            break

        i += 1
    
start_idx, end_idx

We need to take these indexes, use them on the input_ids of the first window, and then decode the result to see if the answers match.

In [None]:
input_ids = inputs['input_ids'][0]
# tokenizer.decode(input_ids) # uncomment to visualize

In [None]:
# Placing the start_idx and end_idx and decoding.
print(input_ids[start_idx:end_idx+1])
print(tokenizer.decode(input_ids[start_idx : end_idx + 1]))

Real answer

In [None]:
answer['text']

Yay! It matches!

Now, we just have to turn this into a function.

In [None]:
def find_asnwer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset
):
    """
    Finds the token-level start and end indices that align with the character-level 
    start and end positions of an answer within a given context span.

    Args:
        ctx_start (int): The index of the first token in the context span.
        ctx_end (int): The index of the last token in the context span.
        ans_start_char (int): The character-level start index of the answer in the original text.
        ans_end_char (int): The character-level end index of the answer in the original text.
        offset (List[Tuple[int, int]]): A list of tuples, where each tuple contains the 
            (start_char, end_char) positions of each token in the input sequence.

    Returns:
        tuple: A tuple (start_idx, end_idx) indicating the start and end token indices 
        of the answer within the context window. If the answer is not fully contained 
        within the context span, returns (0, 0).
    
    """
    

    start_idx = 0
    end_idx = 0
    try:
        # If the window doesn't fully contain the answer, keep (0, 0)
        if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
            pass  # Answer does not exist in the context window
        else:
            i = ctx_start
            # Iterate over the offsets within the context window
            for start_end_char in offset[ctx_start:]:
                start, end = start_end_char
                if start == ans_start_char:
                    start_idx = i
                if end == ans_end_char:
                    end_idx = i
                    break
                i += 1
        # Logged at debug level because this function runs once per window
        logger.debug("Finding the token-level start and end indices done successfully")
        return start_idx, end_idx

    except Exception as e:
        logger.error(f"Error finding the token-level start and end indices: {str(e)}")
        raise  # re-raise so the caller never receives None instead of a tuple

In [None]:
# now applying it to all the windows of this sample
start_idxs = []
end_idxs = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sequence_ids = inputs.sequence_ids(i)

    ctx_start = sequence_ids.index(1)
    ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1

    start_idx, end_idx = find_asnwer_token_idx(
        ctx_start,
        ctx_end,
        ans_start_char,
        ans_end_char,
        offset
    )

    start_idxs.append(start_idx)
    end_idxs.append(end_idx)

start_idxs, end_idxs

They are in this format because of the overlapping, remember?
In this input we have 4 windows, which means that for the first window, the answer starts at index 53 and ends at index 57. The second window also contains the answer. For the third and last windows, the answer does not appear. 😉
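The same labelling rule can be checked by eye on a tiny example. The sketch below is a self-contained toy under stated assumptions: the context, the answer and the question offsets are made up, `label_window` and `make_window` are hypothetical helpers that mirror the logic above, and each whitespace-separated word counts as one token.

```python
import re


def label_window(offsets, sequence_ids, ans_start_char, ans_end_char):
    """Return (start_token, end_token) for one window, or (0, 0) if the answer is not fully inside it."""
    ctx_start = sequence_ids.index(1)
    ctx_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
    # Answer (partly) outside this window: the target is (0, 0)
    if offsets[ctx_start][0] > ans_start_char or offsets[ctx_end][1] < ans_end_char:
        return 0, 0
    start_idx = end_idx = 0
    for i in range(ctx_start, ctx_end + 1):
        start, end = offsets[i]
        if start == ans_start_char:
            start_idx = i
        if end == ans_end_char:
            end_idx = i
            break
    return start_idx, end_idx


context = "BERT was released by Google in 2018"  # made-up context
answer_text = "Google"
ans_start_char = context.index(answer_text)
ans_end_char = ans_start_char + len(answer_text)

question_offsets = [(0, 3), (3, 4)]  # made-up offsets for the question "Who?"
context_offsets = [(m.start(), m.end()) for m in re.finditer(r"\S+", context)]


def make_window(ctx_offsets):
    """Lay out one window as [CLS] question [SEP] context-slice [SEP]."""
    offsets = [(0, 0)] + question_offsets + [(0, 0)] + ctx_offsets + [(0, 0)]
    sequence_ids = (
        [None] + [0] * len(question_offsets) + [None] + [1] * len(ctx_offsets) + [None]
    )
    return offsets, sequence_ids


# Windows of 4 context tokens that overlap by 2 tokens
windows = [make_window(context_offsets[i:i + 4]) for i in (0, 2, 4)]
labels = [label_window(o, s, ans_start_char, ans_end_char) for o, s in windows]
print(labels)
```

Notice that when two windows both contain the answer, the token indices are relative to each window, so the same answer gets different positions in each of them.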

### Applying the tokenizer

One common issue in this dataset is that some questions are badly formatted and have extra white spaces at the beginning or at the end.

In [None]:
for q in squad_dataset["train"]["question"][:1000]:
    if q.strip() != q:
        logger.info(q)

So, let's define our tokenizer function and add this particular part for dealing with extra white spaces.

In [None]:
# Defining some fixed args
max_length = 384 # Indicated by Google (the value used in the original BERT SQuAD fine-tuning setup)
stride = 128

In [None]:
def tokenize_fn_train(batch):
    """
    Tokenizes a batch of question-context pairs and computes the token-level 
    start and end positions of the answers for training a Question Answering model.

    Args:
        batch (dict): A dictionary containing the keys:
            - 'question': List[str], the questions to be asked.
            - 'context': List[str], the corresponding context paragraphs.
            - 'answers': List[dict], where each dict contains:
                - 'text': List[str], the ground truth answer text(s).
                - 'answer_start': List[int], character-level start index of the answer in the context.

    Returns:
        dict: A dictionary with tokenized inputs ready for model training. It includes:
            - input_ids: Token IDs.
            - token_type_ids: Segment IDs.
            - attention_mask: Attention mask.
            - start_positions: Token-level start indices of the answer.
            - end_positions: Token-level end indices of the answer.

    
    """
    try:
        # Remove leading/trailing whitespaces from questions
        questions = [q.strip() for q in batch['question']]

        # Tokenize with sliding window and offset tracking
        inputs = tokenizer(
            questions,
            batch['context'],
            max_length=max_length,
            truncation="only_second",  # truncate only the context, not the question
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length"
        )

        # Pull out the helper fields: they are needed to build the targets but are not model inputs
        offset_mapping = inputs.pop("offset_mapping")
        orig_sample_idxs = inputs.pop("overflow_to_sample_mapping")

        answers = batch['answers']
        start_idxs, end_idxs = [], []

        for i, offset in enumerate(offset_mapping):
            sample_idx = orig_sample_idxs[i]
            answer = answers[sample_idx]

            # Extract character positions of the answer
            ans_start_char = answer['answer_start'][0]
            ans_end_char = ans_start_char + len(answer['text'][0])

            sequence_ids = inputs.sequence_ids(i)

            # Locate the context span in the tokenized input
            ctx_start = sequence_ids.index(1)
            ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1

            # Align character-level answer to token indices
            start_idx, end_idx = find_asnwer_token_idx(
                ctx_start,
                ctx_end,
                ans_start_char,
                ans_end_char,
                offset
            )

            start_idxs.append(start_idx)
            end_idxs.append(end_idx)

        # Add the computed start and end positions to the tokenized inputs
        inputs["start_positions"] = start_idxs
        inputs["end_positions"] = end_idxs

        logger.info("tokenize_fn_train done successfully")
        return inputs
    
    except Exception as e:
        logger.error(f"Error tokenizing batch of question-context: {str(e)}")
        raise  # re-raise so that a failed batch is not silently skipped
    

#### Tokenizing the train dataset

In [None]:
train_dataset = squad_train_dataset.map(
    tokenize_fn_train,
    batched=True,
    remove_columns=squad_train_dataset.column_names
)

In [None]:
# The processed train dataset is a little bit bigger than the original
# because we've expanded the contexts into windows
logger.info(f'Processed dataset: {len(train_dataset)}\nOriginal dataset: {len(squad_dataset["train"])}')

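Now, let's create a similar function for the validation dataset. Instead of the targets, it keeps the IDs and the offsets, which we need later to turn predictions back into text.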

In [None]:
def tokenize_fn_validation(batch):
    """
    Tokenizes a batch of question-context pairs for validating a Question Answering model,
    keeping the information needed to map predictions back to the original samples.

    Args:
        batch (dict): A dictionary containing the keys:
            - 'id': List[str], the unique ID of each original sample.
            - 'question': List[str], the questions to be asked.
            - 'context': List[str], the corresponding context paragraphs.

    Returns:
        dict: A dictionary with tokenized inputs ready for validation. It includes:
            - input_ids: Token IDs.
            - token_type_ids: Segment IDs (only for tokenizers that produce them).
            - attention_mask: Attention mask.
            - offset_mapping: Character offsets of each token, with None for every
              token that is not part of the context.
            - sample_id: ID of the original sample each window comes from.

  
    """
    try:
        questions = [q.strip() for q in batch['question']]

        inputs = tokenizer(
            questions, 
            batch['context'],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length"
        )

        orig_sample_idxs = inputs.pop("overflow_to_sample_mapping")
        sample_ids = []

        for i in range(len(inputs["input_ids"])):
            # Getting the corresponding ID from the original samples (they identify the questions and contexts, remember?)
            sample_idx = orig_sample_idxs[i]
            sample_ids.append(batch["id"][sample_idx])

            sequence_ids = inputs.sequence_ids(i) # 1: context | 0: question | None: special tokens
            offset = inputs["offset_mapping"][i] # getting the offset mapping for this window

            # Modifying the original offset_mapping:
            # keep the offsets of the context tokens and replace
            # everything else (question and special tokens) with None
            inputs["offset_mapping"][i] = [
                x if sequence_ids[j] == 1 else None for j, x in enumerate(offset)
            ]

        inputs["sample_id"] = sample_ids
        logger.info("tokenize_fn_validation done successfully")
        return inputs
    
    except Exception as e:
        logger.error(f"Error tokenizing batch of question-context: {str(e)}")
        raise  # re-raise so that a failed batch is not silently skipped

In [None]:
validation_dataset = squad_dataset["validation"].map(
    tokenize_fn_validation,
    batched=True,
    remove_columns=squad_dataset["validation"].column_names
)

logger.info(f'Processed dataset: {len(validation_dataset)}\nOriginal dataset: {len(squad_dataset["validation"])}')

## Metrics and Logits

We can load a metric called "squad" and use it in our problem! Let's see how it works.

In [None]:
metric = load_metric("squad")

In [None]:
# making some examples

pred_answers = [
    {'id': '1', 'prediction_text': 'Strawberry'},
    {'id': '2', 'prediction_text': 'Agriculture industry'},
    {'id': '3', 'prediction_text': 'Red'}
]

true_answers = [
    {'id': '1', 'answers': {'text': ['Strawberry'], 'answer_start': [80]}},
    {'id': '2', 'answers': {'text': ['Agroindustry'], 'answer_start': [65]}},
    {'id': '3', 'answers': {'text': ['Red'], 'answer_start': [100]}}
]

# checking the metrics

metric.compute(predictions=pred_answers, references=true_answers)
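The two scores this metric reports, exact match and F1, can be approximated by hand to see what they measure. The sketch below is a simplified version under stated assumptions: `exact_match` and `token_f1` are hypothetical helpers, answers are only lowercased and split on whitespace, while the official SQuAD metric also strips punctuation and articles and takes the best score over all reference answers.

```python
from collections import Counter


def exact_match(prediction, truth):
    """1.0 if the two answers are identical after lowercasing and trimming, else 0.0."""
    return float(prediction.strip().lower() == truth.strip().lower())


def token_f1(prediction, truth):
    """Harmonic mean of precision and recall over the tokens shared by both answers."""
    pred_tokens = prediction.lower().split()
    true_tokens = truth.lower().split()
    n_common = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if n_common == 0:
        return 0.0
    precision = n_common / len(pred_tokens)
    recall = n_common / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


# Same prediction/reference pairs as in the example above
pairs = [
    ("Strawberry", "Strawberry"),
    ("Agriculture industry", "Agroindustry"),
    ("Red", "Red"),
]
em = 100 * sum(exact_match(p, t) for p, t in pairs) / len(pairs)
f1 = 100 * sum(token_f1(p, t) for p, t in pairs) / len(pairs)
print(f"exact_match={em:.2f}, f1={f1:.2f}")

# A partially correct (made-up) prediction gets F1 credit but no exact match
print(token_f1("a copper statue", "copper statue of Christ"))
```

Exact match is all or nothing, while F1 rewards predictions that overlap with the true answer.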

But, before heading to metrics, let's remember that the model outputs are logits, just numbers! We have to turn them back into text.

For that, we're downloading a pretrained question-answering model to get predictions that are not random, and we'll use those predictions to learn how to convert the logits into answer strings.

With it, we won't need the whole dataset, but just a part of it for learning how to turn them into strings!

Let's dive in!

##### Learning how to transform logits into answers

In [None]:
small_validation_dataset = squad_dataset["validation"].select(range(100)) # Getting just the first 100 samples from the validation set 
trained_checkpoint = "distilbert-base-cased-distilled-squad" # model already fine-tuned for Q&A

tokenizer2 = AutoTokenizer.from_pretrained(trained_checkpoint) # new tokenizer from distilbert-base-cased-distilled-squad

# tokenize_fn_validation reads the global variable `tokenizer`,
# and here we want it to use the tokenizer of the Q&A model instead,
# so we're temporarily exchanging this global variable for tokenizer2
original_tokenizer = tokenizer
tokenizer = tokenizer2

Now, let's process this small validation dataset

In [None]:
small_validation_processed = small_validation_dataset.map( # Now, we can use this new tokenizer from distilbert-base-cased-distilled-squad
    tokenize_fn_validation,                                 # and map it into our small validation dataset using the function tokenize_fn_validation
    batched=True,
    remove_columns=squad_dataset["validation"].column_names
)

Once this cell is done, let's switch back to the original tokenizer from the distilbert-base-cased model.

In [None]:
tokenizer = original_tokenizer

Now, it's time to change some things in our small dataset: remove the unused columns and switch to the torch format, so that the inputs can be processed on the GPU.

In [None]:
small_model_inputs =  small_validation_processed.remove_columns(['sample_id', 'offset_mapping']) # unused columns
small_model_inputs.set_format("torch")

#### Setting the GPU

Once the previous step is done, it's time to set the GPU as our device and move the inputs (now tensors) there.

In [None]:
# Setting the GPU as the current device (falling back to the CPU if no GPU is available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

In [None]:
small_model_inputs_gpu = {
    k: small_model_inputs[k].to(device) for k in small_model_inputs.column_names
}
# All the inputs now live on the selected device

Downloading the distilbert-base-cased-distilled-squad model and moving it to the GPU

In [None]:
trained_model =  AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(device)

Getting the model's output

In [None]:
with torch.no_grad(): # Disables gradient computation, since we're only running inference (not training)
    outputs = trained_model(**small_model_inputs_gpu) # passing the inputs to distilbert-base-cased-distilled-squad and getting the outputs

Great! Let's see those outputs

In [None]:
outputs

This QuestionAnsweringModelOutput object holds two fields, start_logits and end_logits, each with one score per token of every window.
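To give an idea of where this is heading, here is a minimal sketch of one common way to pick an answer span from the two sets of logits of a single window. It is an illustration under stated assumptions, not necessarily the exact procedure this notebook follows: `best_span`, `n_best` and `max_answer_len` are hypothetical names, the logits are made up, and a full version would also skip tokens that are not part of the context.

```python
def best_span(start_logits, end_logits, n_best=3, max_answer_len=5):
    """Return (score, start, end) of the highest-scoring valid span in one window, or None."""
    # Indices of the n_best largest logits, highest first
    top_starts = sorted(
        range(len(start_logits)), key=lambda i: start_logits[i], reverse=True
    )[:n_best]
    top_ends = sorted(
        range(len(end_logits)), key=lambda i: end_logits[i], reverse=True
    )[:n_best]

    best = None
    for s in top_starts:
        for e in top_ends:
            # Skip spans that end before they start or that are too long
            if e < s or e - s + 1 > max_answer_len:
                continue
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[0]:
                best = (score, s, e)
    return best


# Made-up logits for a window with 6 tokens
start_logits = [0.1, 0.2, 4.0, 0.3, 1.5, 0.0]
end_logits = [0.0, 3.5, 0.2, 3.0, 0.1, 0.4]
print(best_span(start_logits, end_logits))
```

Taking the highest start and the highest end independently would pair start 2 with end 1 here, which is not a valid span. That is why the candidates are combined and filtered before choosing.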

#### Turning the logits into IDs

In [None]:
# Here, we're getting the logits, moving back to CPU and formatting as a numpy array (we don't need them in the tensor format anymore)
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

Let's recall what the IDs look like in small_validation_processed

In [None]:
small_validation_processed["sample_id"][:3] # remember that small_validation_processed was processed by distilbert-base-cased-distilled-squad tokenizer

And what they look like in our validation dataset

In [None]:
validation_dataset["sample_id"][:3]

Also, they're not unique! Remember that one ID can appear in more than one tokenized input because of the windows? A single question-context input can be split into 2, 3 or more windows, but they all belong to the same original sample.

In [None]:
logger.info(f"Total IDs from validation: {len(validation_dataset['sample_id'])},\nTotal unique IDs from validation: {len(set(validation_dataset['sample_id']))}")
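
To make the windowing idea concrete, here is a minimal sketch of how one long input turns into several overlapping windows that all share the same ID. The helper `split_into_windows` and the ID `"sample-0"` are hypothetical and only for illustration; the real tokenizer also handles the question and the special tokens, which this toy version ignores.

```python
def split_into_windows(tokens, max_length, stride):
    """Split `tokens` into overlapping windows of at most `max_length` items.

    `stride` is the number of tokens shared by two consecutive windows.
    """
    if not 0 <= stride < max_length:
        raise ValueError("stride must be non-negative and smaller than max_length")
    step = max_length - stride
    windows = []
    start = 0
    while True:
        windows.append(tokens[start:start + max_length])
        if start + max_length >= len(tokens):
            break  # the last window already reaches the end of the input
        start += step
    return windows


# Toy "context" of 10 token IDs, split with a window of 4 and an overlap of 2
tokens = list(range(10))
windows = split_into_windows(tokens, max_length=4, stride=2)

# Every window carries the same (hypothetical) sample ID
features = [{"sample_id": "sample-0", "tokens": w} for w in windows]
for feature in features:
    print(feature)
```

One original sample produces several features here, which is why the same ID shows up more than once after tokenization.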

To handle this case, we build a dict where the key is the ID and the value is a list with the indices that this ID corresponds to in small_validation_processed. If an input was split into 3 parts, for example, we want something like this:

{'56be4db0acb8001400a502ef': [5, 6, 7]}

In [None]:
sample_id2idxs = {}

for i, id_ in enumerate(small_validation_processed['sample_id']): # looping through all the IDs and enumerating to get the index
    if id_ not in sample_id2idxs: # Checking if this ID already exists
        sample_id2idxs[id_] = [i] # If not, we create an entry with the format we just saw above
    else:
        sample_id2idxs[id_].append(i) # If it exists, we just append to the existing list

In [None]:
# sample_id2idxs # uncomment to see the result
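
As a side note, the same mapping can be written more compactly with `collections.defaultdict`. This is only a sketch on made-up IDs (`"id-a"`, `"id-b"`, `"id-c"` are hypothetical), not a replacement for the cell above.

```python
from collections import defaultdict

# Hypothetical feature IDs: the first sample was split into three windows
feature_ids = ["id-a", "id-a", "id-a", "id-b", "id-c", "id-c"]

id_to_indices = defaultdict(list)
for i, id_ in enumerate(feature_ids):
    id_to_indices[id_].append(i)  # a missing key starts as an empty list

id_to_indices = dict(id_to_indices)  # back to a plain dict for display
print(id_to_indices)
```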

Great! Now, let's understand how we're turning this into strings.

Let's check the shape of our logits. We expect them to be in this shape:

(number_of_samples, max_length)

In [None]:
start_logits.shape, end_logits.shape # remember that they come from the outputs we got from distilbert-base-cased-distilled-squad

We need to sort the indices of the logits so that we know which positions hold the largest values.
The argsort method returns the indices that would sort an array in ascending order, so we first negate the values by placing a '-' in front of the array. The largest logits then become the smallest numbers.

In [None]:
# uncomment to see
#print(start_logits[0])
#print()
#print(-start_logits[0])

This way, the largest original values become the smallest ones. Then, when we call the argsort method, which sorts in ascending order, the indices of the largest logits come first:

In [None]:
indices = (-start_logits[0]).argsort() # here, we are taking just the first position for example.
indices

If we use those indices in the original array, we get this result:

In [None]:
start_logits[0][indices]

We have the start_logits showing in descending order!
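
The negate-then-argsort trick can be checked on a tiny array. This is a minimal sketch with made-up values (`toy_logits` is not part of the notebook's data):

```python
import numpy as np

toy_logits = np.array([0.1, 2.5, -1.0, 1.7])  # made-up values

order = (-toy_logits).argsort()  # indices from the largest to the smallest logit
top_2 = order[:2]                # keep only the 2 best positions

print(order)              # positions 1, 3, 0, 2
print(toy_logits[order])  # the logits in descending order
print(top_2)              # positions 1 and 3
```

Slicing the sorted indices, as in `order[:2]`, is how a cut-off such as "keep only the best candidates" can be applied.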


Now, let's actually transform the logits into strings, and everything will fall into place.

In [None]:
n_largest = 20 # Number of start and end logits we want to search
max_answer_length = 30 # Max answer length we want to allow
predict_answers = [] # List where the predicted answers will be stored

for sample in small_validation_dataset: # For each sample in the raw (NOT tokenized!) small validation dataset
    sample_id = sample["id"] # Get the id from this sample
    context = sample["context"] # and the context

    # Initializing best_score and best_answer (they'll be updated in the loop below)
    best_score = float("-inf")
    best_answer = None

    for idx in sample_id2idxs[sample_id]: # For each tokenized window that belongs to this sample (remember that sample_id2idxs is a dict)
        # Grabbing the start and end logits for this index
        start_logit = start_logits[idx] 
        end_logit = end_logits[idx]
        # And also get the offset mapping for this index
        # Note that this is the processed offset mapping: it contains None for any position that is not in the context
        offsets = small_validation_processed[idx]["offset_mapping"]
        # Sorting the logits as we saw
        start_indices = (-start_logit).argsort() 
        end_indices = (-end_logit).argsort()

        # Next step is to loop through the n_largest start and end logits
        for start_idx in start_indices[:n_largest]:
            for end_idx in end_indices[:n_largest]:
                # Checking the cases where the answer:
                if offsets[start_idx] is None or offsets[end_idx] is None: # Answer is not in the context
                    continue
                if end_idx < start_idx: # Answer does not exist (since it has negative length)
                    continue
                if (end_idx - start_idx + 1) > max_answer_length: # Answer is longer than allowed
                    continue

                # If we have an answer,
                score = start_logit[start_idx] + end_logit[end_idx] # Compute the score for this answer
                if score > best_score: # Checking if score is better than the current best_score
                    best_score = score # If so, update the best score

                    # Getting the position of the first character and of the last character
                    first_ch = offsets[start_idx][0]
                    last_ch = offsets[end_idx][1]
                    # Retrieving the answer as actual text by using them as indices in the context
                    best_answer = context[first_ch:last_ch]
    # Finally, append exactly one prediction per sample (after looping over all of its windows)
    predict_answers.append({"id": sample_id, "prediction_text": best_answer})
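
To see the span search on something small, here is a toy version for a single window. Everything in it is made up for illustration: the helper `best_span`, the context, the offsets and the logits. It also uses plain Python `sorted` instead of NumPy's `argsort`, so it is a sketch of the idea rather than the exact cell above.

```python
def best_span(start_logit, end_logit, offsets, context, n_largest=3, max_answer_length=5):
    """Return (score, text) of the best valid answer span for one window."""
    # Token positions ordered from the highest to the lowest logit
    start_order = sorted(range(len(start_logit)), key=lambda i: start_logit[i], reverse=True)
    end_order = sorted(range(len(end_logit)), key=lambda i: end_logit[i], reverse=True)

    best_score, best_text = float("-inf"), None
    for s in start_order[:n_largest]:
        for e in end_order[:n_largest]:
            if offsets[s] is None or offsets[e] is None:
                continue  # token is outside the context
            if e < s or (e - s + 1) > max_answer_length:
                continue  # negative length or longer than allowed
            score = start_logit[s] + end_logit[e]
            if score > best_score:
                best_score = score
                best_text = context[offsets[s][0]:offsets[e][1]]
    return best_score, best_text


context = "BERT was released by Google in 2018."
# Made-up character offsets; None marks question and special tokens
offsets = [None, None, None, (0, 4), (5, 8), (9, 17), (18, 20),
           (21, 27), (28, 30), (31, 35), (35, 36), None]
start_logit = [0.1, 0.0, 0.2, 0.5, 0.1, 0.3, 0.2, 4.0, 0.1, 1.0, 0.0, 0.1]
end_logit = [0.2, 0.1, 0.0, 1.5, 0.2, 0.1, 0.3, 3.5, 0.2, 1.2, 0.1, 0.0]

score, answer = best_span(start_logit, end_logit, offsets, context)
print(score, answer)
```

The pair with the highest summed logits that passes all three checks wins, and the offsets turn its token positions back into a slice of the context.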

Once this is done, we just need to put the true answers in the right format for computing the metrics.

To recall the format, go to the beginning of the "Metrics and Logits" outline.

In [None]:
true_answers = [
    {
    "id": x["id"],
    "answers": x["answers"]
    }
    for x in small_validation_dataset
]

In [None]:
#true_answers # uncomment to see the result

Yay! We can now turn the logits into strings and finally compute the metrics!

In [None]:
metric.compute(predictions=predict_answers, references=true_answers)
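
For intuition about what the two numbers mean, here is a simplified sketch of exact match and token-level F1 for a single prediction and a single reference. It is modelled on the usual SQuAD evaluation logic (lowercasing, removing punctuation and English articles) but it is not the `metric` object used above; the function names are hypothetical. As far as I know, the real metric also takes the best score over all reference answers and averages over the dataset as a percentage.

```python
import re
import string
from collections import Counter


def normalize(text):
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """True when both strings are identical after normalization."""
    return normalize(prediction) == normalize(truth)


def f1_score(prediction, truth):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize(prediction).split()
    true_tokens = normalize(truth).split()
    common = Counter(pred_tokens) & Counter(true_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("the Eiffel Tower", "Eiffel Tower!"))
print(round(f1_score("Eiffel Tower in Paris", "Eiffel Tower"), 3))
```

Exact match is all-or-nothing, while F1 gives partial credit when the prediction overlaps with the reference.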

Let's turn the whole process into a function called compute_metrics

#### Computing metrics

In [None]:
def compute_metrics(start_logits, end_logits, processed_dataset, orig_dataset):
    """
    Computes evaluation metrics (e.g., EM/F1) for a Question Answering model by comparing
    predicted answer spans with ground truth answers.

    Args:
        start_logits (List[np.ndarray]): List or array of start logits for each tokenized sample.
        end_logits (List[np.ndarray]): List or array of end logits for each tokenized sample.
        processed_dataset (Dataset or List[dict]): The tokenized dataset used for evaluation.
            Must contain:
                - "sample_id": Original sample ID corresponding to each tokenized chunk.
                - "offset_mapping": Mapping of token indices to character spans in the context.
        orig_dataset (List[dict]): The original dataset (before tokenization) with keys:
            - "id" (str): Unique identifier for each QA sample.
            - "context" (str): Context paragraph containing the answer.
            - "answers" (dict): Ground truth answers with keys:
                - "text": List of correct answer strings.
                - "answer_start": List of starting character positions.

    Returns:
        dict: A dictionary containing computed metrics, such as:
            - "exact_match": Percentage of exact match predictions.
            - "f1": F1 score comparing predicted vs. ground-truth answers.

    Note:
        Relies on the names `n_largest`, `max_answer_length` and `metric`
        defined in earlier cells.

    Example:
        >>> from datasets import load_metric
        >>> metric = load_metric("squad")
        >>> compute_metrics(start_logits, end_logits, tokenized_data, raw_data)
    """
    try:   
        sample_id2idxs = {}

        for i, id_ in enumerate(processed_dataset["sample_id"]):
            if id_ not in sample_id2idxs:
                sample_id2idxs[id_] = [i]
            else:
                sample_id2idxs[id_].append(i)

        predicted_answers = []
        for sample in tqdm(orig_dataset):

            sample_id = sample["id"]
            context = sample['context']

            best_score = float("-inf")
            best_answer = None

            for idx in sample_id2idxs[sample_id]:
                start_logit = start_logits[idx]
                end_logit = end_logits[idx]

                offsets = processed_dataset[idx]["offset_mapping"]

                start_indices = (-start_logit).argsort()
                end_indices = (-end_logit).argsort()

                for start_idx in start_indices[:n_largest]:
                    for end_idx in end_indices[:n_largest]:
                        if offsets[start_idx] is None or offsets[end_idx] is None:
                            continue

                        if end_idx < start_idx:
                            continue

                        if (end_idx - start_idx + 1) > max_answer_length:
                            continue

                        score = start_logit[start_idx] + end_logit[end_idx]
                        if score > best_score:
                            best_score = score

                            first_ch = offsets[start_idx][0] 
                            last_ch = offsets[end_idx][1]
                            best_answer = context[first_ch:last_ch]
                    
            predicted_answers.append({"id": sample_id, "prediction_text": best_answer})
        true_answers = [{"id": x["id"], "answers": x["answers"]} for x in orig_dataset]
        y = metric.compute(predictions=predicted_answers, references=true_answers)

        logger.info("Computing metrics done successfully")
        return y
    except Exception as e:
        logger.error(f"Error computing the metrics: {str(e)}")
        raise  # re-raise so the caller does not silently receive None

Let's run the function on the small datasets we used earlier

In [None]:
compute_metrics(
    start_logits,
    end_logits,
    small_validation_processed,
    small_validation_dataset
)

Great!
This function will be used after our training step is done!

# Training

In [None]:
mlflow.end_run()
mlflow.set_experiment(EXPERIMENT_SET)

In [None]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint_bbc) # Loading the model we want to fine-tune (distilbert-base-cased)

Now it is time to create our TrainingArguments object with all the necessary arguments for the training step

In [None]:
args = TrainingArguments(
    "finetuned-squad", # output directory where the checkpoints are saved
    evaluation_strategy="no", # no evaluation during training, because we'll compute the metrics manually
    save_strategy="epoch", # saving a checkpoint at the end of each epoch (you can use "steps" as well)
    learning_rate=2e-5, # learning rate value
    num_train_epochs=3, # 3 epochs in total (more than 4 is not recommended, since our inputs are very large)
    weight_decay=0.01, # regularization technique
    fp16=True # mixed-precision training to speed up the process (requires a GPU)
)

Now, let's instantiate a trainer object

In [None]:
trainer = Trainer(
    model=model, # our model
    args=args, # our args
    train_dataset=train_dataset, # our datasets
    eval_dataset=validation_dataset,
    tokenizer=tokenizer # and our tokenizer
)

Checking if the GPU is available

In [None]:
torch.cuda.is_available()

The time has come! Let's train!

In [None]:
mlflow.end_run()
start_time_training = datetime.now() # this is for computing the time the training takes
with mlflow.start_run():
    trainer.train() 
logger.info(f'Total time for training: {datetime.now() - start_time_training}')

## Prediction

Now, let's do the evaluation.

In [None]:
trainer_prediction = trainer.predict(validation_dataset) # getting the predictions for the validation set
trainer_prediction

And grab just the prediction values from this object

In [None]:
predictions, _, _ = trainer_prediction
predictions

We have a tuple with two arrays, the start_logits and end_logits!

In [None]:
start_logits, end_logits = predictions

##### Computing the metrics

In [None]:
compute_metrics(
    start_logits,
    end_logits,
    validation_dataset,
    squad_dataset['validation']
)

Saving the model for further usage

In [None]:
trainer.save_model(SAVE_MODEL_NAME)

# Inference

We can create a question-answering pipeline from transformers and pass our model to it.

In [None]:
qa = pipeline(
    'question-answering',
    model = SAVE_MODEL_NAME,
    device=0 #GPU
)

Testing the pipeline

In [None]:
context = "Tomorrow the Atlântico is going to have a delicious team lunch!"
question = "What is the Atlântico going to have tomorrow?"
qa(context=context, question=question)
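
For a single question, the question-answering pipeline is documented to return a dict with the keys `score`, `start`, `end` and `answer`, where `start` and `end` are character positions in the context. The sketch below uses a hand-written stand-in for such a result (the answer text and the score are made up, not real model output) just to show how those fields relate to each other.

```python
context = "Tomorrow the Atlântico is going to have a delicious team lunch!"

# Hand-written stand-in for a pipeline result; the score is a made-up value
answer = "a delicious team lunch"
start = context.index(answer)
result = {"score": 0.9, "start": start, "end": start + len(answer), "answer": answer}

# Slicing the context with `start` and `end` gives back the answer text
print(context[result["start"]:result["end"]])
```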

In [None]:
logger.info(f'Total execution time: {datetime.now() - start_time_all_execution}')