# A deeper dive into faster fine-tuning with Packed BERT

Bin-packing has previously been explored for training large language models: when pre-training BERT on enormous datasets such as the Wikipedia corpus, it achieved a [2x speed increase](https://www.graphcore.ai/posts/introducing-packed-bert-for-2x-faster-training-in-natural-language-processing) using Graphcore's non-negative least-squares histogram packing algorithm. This notebook is an aside to the **Accelerating fine-tuning and inference with PackedBERT** article and the Graphcore Optimum [Packed BERT notebooks](https://github.com/huggingface/optimum-graphcore/tree/main/notebooks/packed_bert), which offer an easy-to-use packageable implementation of packing to rapidly get you going with faster fine-tuning and inference.

## Contents
1. [Preprocessing: the packing algorithm](#the-packing-algorithm)
2. [Preprocessing: creating the dataset](#creating-the-actual-dataset)
3. [Modifications to the model](#modifications-to-model-processing)
4. [Postprocessing](#postprocessing-the-returned-logits)
5. [A note on training and inference](#a-note-on-training-and-inference)
6. [Summary](#in-summary)

In this notebook walkthrough, we look at the details of *how* to implement packing to significantly increase throughput for fine-tuning (and inference) tasks such as sequence classification and question answering, from preprocessing through to the implementation of the model. Concepts from prior research, together with some reconsiderations, have allowed us to create a more online-capable packing approach. Fine-tuning and inference for smaller tasks might not take as long as pre-training, but can still take minutes to hours. Packing can reduce this to *seconds to minutes*, a crucial step for online solutions, tasks like batched inference, and rapid prototyping for speedy deployment.

For more of a background on packing and a simple and quick walkthrough of how to get going with packing, take a look at the article: [Accelerating fine-tuning and inference with PackedBERT]

The implementation of packing outlined here focuses on the principles of how to functionally enable packing for a dataset, create a packed dataset and prepare it for training or inference. It uses single-label sequence classification as the example, noting what changes are required for other tasks such as multi-label classification and question answering, to give a conceptual understanding of the kind of changes needed to implement packing for a new task. A couple of essential points should be noted when attempting to implement packing for a new dataset or new task:
* Ensure the dataset's sequence length distribution is sufficiently skewed to the shorter side - if more than 50% of the sequences are shorter than 50% of the maximum sequence length, you can expect some performance benefit.
* Packing is designed to work with transformer models, leveraging the token- and sequence-independent behaviour made possible by the self-attention mechanism, and may generally be applicable to models which do not have significant cross-token interaction within the model*. It may be important to understand what the output of the final layer looks like to determine how best to 'un-pack' the logits for the final output.

**We can expect to modify the transformer output stage to avoid cross-token and cross-sequence interaction in the loss calculation and logits retrieval.*
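
The first point above can be checked before committing to packing. The following is a minimal sketch, not part of the Optimum utilities: the helper name `short_sequence_fraction` and the toy lengths are assumptions made for illustration.

```python
import numpy as np


def short_sequence_fraction(seq_lens, max_seq_len, length_ratio=0.5):
    """Return the fraction of sequences shorter than length_ratio * max_seq_len."""
    seq_lens = np.asarray(seq_lens)
    return float(np.mean(seq_lens < length_ratio * max_seq_len))


# Toy lengths (made up): 4 of the 5 sequences are shorter than 128 tokens
toy_lens = [12, 30, 45, 100, 250]
toy_fraction = short_sequence_fraction(toy_lens, max_seq_len=256)
print(f"{toy_fraction:.0%} of sequences are below half the maximum length")
```

A value above 0.5 suggests, by the rule of thumb above, that the dataset is a reasonable candidate for packing.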

The current implementation of packing is open to optimisation. It demonstrates the principles and performs fast enough for fine-tuning and inference. However, it is a bare-bones implementation in Python, and the preprocessing time could be significantly improved by using multi-processing, Rust/C++ based dataset creation and more. This notebook is largely framework-agnostic in terms of training: it focuses on how to create the dataset, preprocess it, and create and instantiate the model, to help you understand packing or implement it for your own use case. It doesn't focus on dataset-specific training or inference, but rather on how to get your setup to a point where running training and inference does not differ from a generic BERT model.

<a id='the-packing-algorithm'></a>
## Preprocessing: the packing algorithm

For simplicity, let's use a single-label classification dataset with a binary label. `sst2` is small in size, with simple postprocessing.

The first step is to load and tokenize the dataset, as the packing process operates on the tokenized dataset. This step is not specific to packing, but there is one caveat: make sure **not to pad** when tokenizing. Any necessary padding is done later, during the packed dataset creation stage.

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

max_seq_len = 256

# Load the dataset
dataset = load_dataset("sst2")

# Load tokenizer from the default pre-trained BERT checkpoint
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

# Use map function to iterate over each sample in dataset
def preprocess_function(examples):
    return tokenizer(examples["sentence"], truncation=True, max_length=max_seq_len)


# Retrieve the tokenized dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True)
tokenized_dataset = tokenized_dataset["train"]

In [None]:
tokenized_dataset

That is all we need to do to have the dataset ready for packing. The preprocess function will change for different tasks, and label preprocessing may be needed. This is demonstrated in the task-specific Graphcore Optimum notebooks.

The process of generating a packing strategy and creating an iterable PyTorch dataset for your model is contained within the `PackedDatasetCreator` class for the Optimum notebooks. The first step is to create the histogram, in the function `generate_histogram()`.

### Generating the histogram

We first generate a histogram to pack the data. The histogram has a bin size of 1, with the number of bins equal to the maximum sequence length; this granularity is necessary to account for every sequence in the dataset. It lists the number of sequences of each length up to the maximum. This summarises the sizes of the sequences in the dataset, allowing the packing algorithm to create a "strategy", i.e., an arrangement of sequences based on how many can fit into one pack, so that as many sequences as possible are packed together into as few packs as possible.

First, we retrieve sequence lengths and create an empty histogram array.

In [None]:
import numpy as np

dataset_seq_lens = np.array([len(seq) for seq in tokenized_dataset["input_ids"]])
histogram = np.zeros(max_seq_len, dtype=np.int64)

The useful `np.unique` function returns the sorted unique values of the lengths array and, when `return_counts` is set to `True`, the number of times each unique length occurs. This allows us to easily create the histogram: when the bin size is 1, the histogram values are exactly the returned `counts`.

In [None]:
seq_lens, counts = np.unique(dataset_seq_lens, return_counts=True)
histogram[seq_lens - 1] = counts
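
To make the indexing concrete, here is the same construction on a toy set of lengths (the `toy_` names and values are made up for illustration and are not part of the notebook's dataset):

```python
import numpy as np

toy_max_seq_len = 5
toy_seq_lens = np.array([3, 1, 3, 5])

toy_histogram = np.zeros(toy_max_seq_len, dtype=np.int64)
toy_lengths, toy_counts = np.unique(toy_seq_lens, return_counts=True)
# Length l is stored at index l - 1, since no sequence has length 0
toy_histogram[toy_lengths - 1] = toy_counts

print(toy_histogram)  # [1 0 2 0 1]
```

One sequence has length 1, two have length 3 and one has length 5, so indices 0, 2 and 4 are filled.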

We can then plot this histogram, to have a look at the skewed length distribution:

In [None]:
import matplotlib.pyplot as plt

# Formatting
plt.style.use(plt.style.available[-2])
plt.rcParams["font.family"] = "serif"
plt.rcParams["mathtext.fontset"] = "dejavuserif"
plt.rcParams["figure.figsize"] = (10, 4)
plt.rcParams["figure.dpi"] = 100

# Plotting
plt.subplot(1, 2, 1)
# Bin the raw sequence lengths (not the per-length counts) in steps of 10
plt.hist(dataset_seq_lens, bins=list(range(0, max_seq_len + 10, 10)))
plt.title("Sequence length histogram")
plt.xlabel("Sequence lengths")
plt.ylabel("Frequency")

plt.subplot(1, 2, 2)
plt.plot(np.arange(max_seq_len) + 1, histogram)
plt.title("Sequence length distribution")
plt.xlabel("Sequence lengths")
plt.ylabel("Num. samples")

plt.show()

The histogram is heavily skewed towards smaller sequence lengths, and should be a good fit for packing.

### Generating a strategy using a histogram-based algorithm

There are three algorithms mentioned in the original [pre-training article](https://www.graphcore.ai/posts/introducing-packed-bert-for-2x-faster-training-in-natural-language-processing), these are:
 
* Non-negative least squares histogram packing (NNLSHP)
* Shortest-pack-first histogram packing (SPFHP)
* Longest-pack-first histogram packing (LPFHP)

All of these algorithms use the sequence length histogram to combine lengths into packs, aiming to fill the fewest possible packs up to the maximum sequence length. [Bin-packing](https://en.wikipedia.org/wiki/Bin_packing_problem) is an NP-hard problem, and is non-trivial to solve optimally.

In previous implementations for large dataset pre-training, the configuration of lengths closest to optimal was achieved using NNLSHP. This is a least-squares approach where, for $Ax=b$, $A$ is the strategy matrix, the histogram is $b$, and we attempt to retrieve a strategy vector $x$. This algorithm is explained neatly in the original article:

> *"The tricky part was the strategy matrix. Each column has a maximum sum of three and encodes which sequences get packed together to exactly match the desired total length; 512 in our case. The rows encode each of the potential combinations to reach a length the total length. The strategy vector x is what we were looking for, which describes how often we choose whichever of the 20k combinations. Interestingly, only around 600 combinations were selected at the end. To get an exact solution, the strategy counts in x would have to be positive integers, but we realised that an approximate rounded solution with just non-negative x was sufficient. For an approximate solution, a simple out-of-the-box solver could be used to get a result within 30 seconds."*

This algorithm increases significantly in complexity when asked for an approximately optimal solution with more than 3 sequences per pack, and the time taken can become unreasonably long. For at most 3 sequences per pack, it solves in up to 30 seconds. That is reasonable for pre-training, but for online-capable tasks, fine-tuning on smaller datasets, and inference, it can be far too long. In addition, fine-tuning datasets commonly have severely skewed length distributions, and to take full advantage of packing we would like to pack many more sequences per pack, from 6 up to 12 or even higher if the dataset allows it. This is infeasible with NNLSHP, so let's take a look at the simpler, slightly more naive but faster algorithm, SPFHP.

SPFHP scales incredibly well with an increasing number of sequences per pack. Unlike NNLSHP, it doesn't attempt to pre-emptively solve the problem. Rather, it operates on a histogram of lengths sorted from longest to shortest, looks at each sequence length in turn, and checks whether it fits into any existing pack. If it fits into one pack, it is placed in that pack, without considering any future sequences; if it fits into multiple packs, it is placed into the pack with the shortest total sequence length. Although this doesn't guarantee optimality, it is more appropriate for small to medium sized datasets, and its complexity is not increased by increasing the number of packs. It solves for any given dataset in almost constant time, taking under 0.02 seconds for up to 16 million samples. The complexity of SPFHP increases with sequence length rather than dataset size or sequences-per-pack, so it remains relatively constant for different dataset sizes!
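
The placement rule can be illustrated with a deliberately simplified sketch. It handles sequences one at a time rather than working on a histogram, so it does not have the near-constant runtime of the histogram-based algorithm; the function name `shortest_pack_first` and the toy lengths are assumptions made for illustration.

```python
def shortest_pack_first(lengths, max_len, max_per_pack):
    """Greedy per-sequence packing: longest first, into the shortest pack that fits."""
    packs = []  # each pack is a list of sequence lengths
    for length in sorted(lengths, reverse=True):
        # Packs with enough free tokens and a free sequence slot
        candidates = [
            pack
            for pack in packs
            if sum(pack) + length <= max_len and len(pack) < max_per_pack
        ]
        if candidates:
            # Several packs fit: choose the one with the fewest tokens so far
            min(candidates, key=sum).append(length)
        else:
            # Fits nowhere: open a new pack
            packs.append([length])
    return packs


toy_packs = shortest_pack_first([5, 4, 3, 3, 2, 1], max_len=8, max_per_pack=3)
print(toy_packs)  # [[5, 3], [4, 3], [2, 1]]
```

Note how the final sequence of length 1 goes into `[2]` rather than `[4, 3]`, even though it would have completed the latter exactly: the rule prefers the shortest pack and does not look ahead.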

LPFHP is a shortest-to-longest variant of SPFHP, splitting counts to get more ideal fits. It rarely offers significantly more efficiency and is slightly more computationally complex, although in some cases the slightly better packing can be useful. The utils created for the Optimum Graphcore packed BERT notebooks allow either SPFHP or LPFHP to be used, to mitigate the potential preprocessing bottleneck of NNLSHP and to allow any maximum number of sequences per pack.

*The algorithm code for all packing algorithms is available in the [blogs code](https://github.com/graphcore/tutorials/tree/master/blogs_code/packedBERT).*

Let's have a look at **SPFHP**:

First we reverse the histogram to put it in order from longest to shortest sequence, and initialise the dictionaries that store the resulting strategies for packing the data:

In [None]:
from collections import defaultdict

reversed_histogram = np.flip(histogram)

tmp_strategies_per_length = defaultdict(list)
strategies_per_length = defaultdict(list)

Then define an `add_pack` function which simply adds a pack, and the number of times the pack can be created, to the final strategies if the pack is complete, or the temporary strategies if space is left.

In [None]:
def add_pack(pack, count, tmp, final, limit, offset):
    if len(pack) == limit or offset == 0:
        final[offset].append((count, pack))
    else:
        tmp[offset].append((count, pack))

Let's also define the `max_seq_per_pack`, i.e. the maximum number of sequences we allow a pack to contain:

In [None]:
max_seq_per_pack = 6

This is the central loop of the algorithm. It iterates through each possible length in longest-to-shortest order, checking each time whether the new length can fit into any existing pack and otherwise creating a new pack. It uses the temporary strategies to hold packs which still have space for new sequences; once a pack is full (no space left, or the maximum number of sequences reached), it is added to the final dictionary of packs.

With multiple sequences for each length, it checks whether multiple packs of the same configuration can be formed at each step, reducing the iterations significantly by creating multiple packs in one iteration. 
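
As a small illustration of this batching, consider a hypothetical state (all values here are made up) in which 5 identical open packs exist and 3 sequences of one length are waiting to be placed:

```python
# Hypothetical state: 5 identical open packs [200] and 3 sequences of length 40 to place
open_pack_count, open_pack = 5, [200]
n_to_bin, length_to_bin = 3, 40

# All placements that share a configuration are handled in a single step
count = min(open_pack_count, n_to_bin)
new_pack = open_pack + [length_to_bin]

packs_left_open = open_pack_count - count  # still [200], waiting for other lengths
sequences_left = n_to_bin - count  # sequences of length 40 still to place

print(f"{count} x {new_pack}, {packs_left_open} x {open_pack} left open, {sequences_left} left to bin")
```

Three packs of `[200, 40]` are created at once instead of in three separate iterations, which is why the loop count depends on the number of distinct lengths rather than the number of sequences.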

Let's also time each block so we can summarise the efficiency at the end:

In [None]:
import time

# Iterate over every possible sequence length, from longest to shortest
st_loop = time.time()

for i in range(max_seq_len):
    # Number of sequences of the current length (the histogram is reversed)
    n_sequences_to_bin = reversed_histogram[i]
    # The sequence length being packed in this iteration
    length_to_bin = max_seq_len - i
    # Start from the largest possible offset (space left in a pack once this length is added)
    offset = i + 1

    # Iterate through the number of sequences for this bin
    while n_sequences_to_bin > 0:

        # Check temporary strategies for packs where this new length could fit.
        if (length_to_bin + offset) in tmp_strategies_per_length:

            # Extract shortest pack that will get modified
            n_sequences_to_pack, pack = tmp_strategies_per_length[
                length_to_bin + offset
            ].pop()

            # Update the pack (create new pack) with added length
            new_pack = pack + [length_to_bin]

            # Check how many temporary packs this length can fit into
            count = min(n_sequences_to_pack, n_sequences_to_bin)

            # Only update the number of packs the length fits into
            if n_sequences_to_pack > n_sequences_to_bin:

                # Old pack gets reduced
                n_sequences_to_pack -= n_sequences_to_bin

                # Update the temporary packs
                tmp_strategies_per_length[length_to_bin + offset].append(
                    (n_sequences_to_pack, pack)
                )
                n_sequences_to_bin = 0
            else:
                n_sequences_to_bin -= n_sequences_to_pack

            # Add the pack to the correct strategy dict
            add_pack(
                new_pack,
                count,
                tmp_strategies_per_length,
                strategies_per_length,
                max_seq_per_pack,
                offset,
            )

            # Clean up to speed up main key search
            if not tmp_strategies_per_length[length_to_bin + offset]:
                tmp_strategies_per_length.pop(length_to_bin + offset)
        else:
            offset -= 1

        # Does not fit anywhere. Create new pack
        if offset < 0:
            add_pack(
                [length_to_bin],
                n_sequences_to_bin,
                tmp_strategies_per_length,
                strategies_per_length,
                max_seq_per_pack,
                i,
            )
            n_sequences_to_bin = 0

en_loop = time.time()

Finally, all of the strategies are merged. By the time the whole dataset has been iterated over, some packs will still be in the temporary packs stage, as they could fit more sequences. This is fine, as we are trying to achieve the fastest practical solution. These are added into the final strategies, and the dictionaries are flattened.

In [None]:
st_clean = time.time()

# Merge all strategies
for key in tmp_strategies_per_length:
    strategies_per_length[key].extend(tmp_strategies_per_length[key])

# Flatten strategies dictionary
strategy_set = []
strategy_repeat_count = []
for key in strategies_per_length:
    for count, pack in strategies_per_length[key]:
        pack.reverse()
        strategy_set.append(pack)
        strategy_repeat_count.append(count)

en_clean = time.time()

We can summarise the efficiency of the packing algorithm:

In [None]:
duration = (en_loop - st_loop) + (en_clean - st_clean)
sequence_lengths = np.arange(1, max_seq_len + 1)
strategy_repeat_count = np.array(strategy_repeat_count)
n_strategies = len(strategy_set)
old_number_of_samples = histogram.sum()
new_number_of_samples = strategy_repeat_count.sum()
sequences = sum(
    [count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)]
)
total_tokens = max_seq_len * new_number_of_samples
empty_tokens = sum(
    [
        count * (max_seq_len - sum(pack))
        for count, pack in zip(strategy_repeat_count, strategy_set)
    ]
)
efficiency = 100 - empty_tokens / total_tokens * 100
speedup_upper_bound = 1.0 / (
    1 - (histogram * (1 - sequence_lengths / max_seq_len)).sum() / old_number_of_samples
)

print(
    f"Packing efficiency (% of real tokens): {efficiency:3.4f}\n"
    f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n"
    f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n"
    f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds."
)

We have packed 67349 sequences in 0.002s, and achieved an approximate 5.15x speed-up. This is an ideal result, highlighting the speed of the algorithm and its speed benefit.
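
The quantities computed in the summary cell are closely related, which can be seen by simplifying the expressions used in the code. The symbols below are introduced only for this derivation: $h_l$ is the number of sequences of length $l$ (`histogram`), $L$ is `max_seq_len`, $N$ is `old_number_of_samples`, $P$ is `new_number_of_samples`, and $\bar{l}$ is the mean sequence length.

$$
\text{speed-up upper bound} = \frac{1}{1 - \frac{1}{N}\sum_{l=1}^{L} h_l\left(1 - \frac{l}{L}\right)} = \frac{L}{\bar{l}}, \qquad \bar{l} = \frac{1}{N}\sum_{l=1}^{L} l\,h_l
$$

$$
\text{efficiency} = 1 - \frac{\text{empty tokens}}{L\,P} = \frac{N\,\bar{l}}{L\,P} \quad\Rightarrow\quad \frac{N}{P} = \text{efficiency} \times \frac{L}{\bar{l}}
$$

In other words, the theoretical limit is the ratio of the maximum length to the mean length, and the achieved speed-up is that limit scaled by the packing efficiency (expressed here as a fraction, whereas the code reports it as a percentage).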

### Understanding the strategy and mixture (repeat count)

The key elements returned by the packing algorithm are the **strategy set** and the **strategy repeat count**, or mixture. Let's have a look at part of these returned lists and see what they look like:

In [None]:
print(strategy_set[:50])

In [None]:
print(strategy_repeat_count[:50])

The `strategy_set` tells us the **lengths** of sequences that can be packed together. For instance: 

```
strategy_set[0] = [41, 41, 41, 42, 43, 44]
len(strategy_set[0]) = 6
```

means we can pack 6 sequences into a pack corresponding to this first strategy. Those must be sequences of the lengths of 41, 41, 41, 42, 43 and 44 (the order is not essential). Length 41 repeats three times, meaning we can pack three sequences of length 41 in the input.

Repeated lengths within a strategy are normal: the strategy is derived from the dataset's histogram, so there will always be enough sequences of each length in the data to fill the repeats.

The `strategy_repeat_count` tells us the number of times we can **repeat** the strategy.

```
strategy_repeat_count[0] = 3
```

means we can repeat the above `strategy_set` 3 times, in other words, we are able to have 3 packs with sequence length configurations of `[41, 41, 41, 42, 43, 44]`.
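
One way to convince yourself that the two lists together account for every sequence is to rebuild the length histogram from them. This is a sketch for illustration only; the helper name `histogram_from_strategy` and the toy strategy are assumptions, not part of the packing utilities.

```python
import numpy as np


def histogram_from_strategy(strategy_set, strategy_repeat_count, max_seq_len):
    """Count how many sequences of each length a packing strategy consumes."""
    rebuilt = np.zeros(max_seq_len, dtype=np.int64)
    for pack, repeat in zip(strategy_set, strategy_repeat_count):
        for length in pack:
            # Length l lives at index l - 1, as in the sequence length histogram
            rebuilt[length - 1] += repeat
    return rebuilt


# Toy strategy (made up): two packs of [2, 3] and one pack of [5]
toy_strategy_set = [[2, 3], [5]]
toy_repeat_count = [2, 1]
toy_rebuilt = histogram_from_strategy(toy_strategy_set, toy_repeat_count, max_seq_len=5)
print(toy_rebuilt)  # [0 2 2 0 1]
```

Applied to the `strategy_set`, `strategy_repeat_count` and `max_seq_len` from this notebook, the result should match `histogram`, since every sequence is assigned to exactly one pack.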

<a id='creating-the-actual-dataset'></a>
## Preprocessing: Creating the actual dataset

The strategy provides length and position information for the dataset, but it doesn't link directly to the sequences in the dataset. The tokenized inputs for BERT need to be extracted and transformed into a 'packed' style. One might assume that this is as simple as concatenating, for each sample, sequences whose lengths match the strategy values. This is true for some inputs, but a few further considerations are needed for the model to recognise the sequence-specific information for each input.

### Shape of the data

First, retrieve the columns in the tokenized dataset:

In [None]:
unpacked_input_ids = tokenized_dataset["input_ids"]
unpacked_attention_mask = tokenized_dataset["attention_mask"]
unpacked_token_type_ids = tokenized_dataset["token_type_ids"]
unpacked_labels = tokenized_dataset["label"]

Then, we initialise the fields required for training the model for BERT, which include the `input_ids`, `attention_mask`, `token_type_ids`, `labels`, and the `position_ids` which provide the model with positional information for each individual token within a packed input. These pre-initialised inputs make the packing process more efficient, and are preset to the defined padding values.

We need to know the size of the new dataset we will be creating using the strategy. This is the total number of packs, which we can retrieve by summing the `strategy_repeat_count`:

In [None]:
total_num_packs = np.sum(strategy_repeat_count)
print(total_num_packs)

Then, create the empty arrays to store the packs, each row should have the size of the maximum sequence length:

In [None]:
packed_input_ids = np.zeros((total_num_packs, max_seq_len), dtype=int)
packed_attention_mask = np.zeros((total_num_packs, max_seq_len), dtype=int)
packed_token_type_ids = np.zeros((total_num_packs, max_seq_len), dtype=int)
packed_position_ids = np.zeros((total_num_packs, max_seq_len), dtype=int)
packed_labels = -100 * np.ones((total_num_packs, max_seq_per_pack), dtype=int)
packed_example_ids = -100 * np.ones((total_num_packs, max_seq_per_pack), dtype=int)

For other tasks, like SQuAD, some other columns custom to that task would also need to be added, or existing label shapes may need changing, for instance, having start and end positions instead of labels, or the offset mapping, with the correct data type defined. 

The dataset column format should be fairly intuitive. Token-level columns such as `input_ids` keep the shape `[rows, max_seq_len]`, with the number of rows reduced from the number of samples to the number of packs. Per-sequence columns such as labels gain a `max_seq_per_pack` dimension, followed by any trailing dimension `n` they already had, for instance a one-hot-encoded label array for each sequence in multi-label classification:

```
# -------- BEFORE PACKING --------
num_labels (multi label) = 5, max_seq_len = 384, total_samples = 1000

unpacked_data.shape        = [1000, 384]
unpacked_data_labels.shape = [1000,   5]

# -------- AFTER PACKING --------
total_num_packs = 250, max_seq_per_pack = 4

packed_data.shape          = [ 250, 384]
packed_data_labels.shape   = [ 250,   4,  5]
```


A few key notes here: 
* The **`position_ids`** are not created when tokenizing the dataset with a pre-trained tokenizer, we create them when iterating through the strategies.
* The **`labels`** are padded to -100, rather than to 0. This makes it easier to mask labels when they are packed together: as the maximum number of sequences per pack may not always be used, padding can occur not just at the end but intermittently within an input batch.
* For some tasks, where non-binary loss functions like cross-entropy loss are being used, -100 is the default `ignore_index` value in PyTorch, and the model will not calculate the loss for logits which correspond to ignored indices in the labels, avoiding the need for masking. So, -100 is an ideal value to set label padding to.
* In the case of pure inference (with no labels) we also create an extra column for the indices of the data received: `example_ids`. While packing optimises the size of a dataset by packing according to sequence length, it does not maintain a consistent order of inputs. This is an issue for inference, particularly with large amounts of data, where each output must be matched to its input sequence. Passing IDs that correspond to the original order of the dataset allows us to re-sort the logits at prediction time.
* Some reasonable adjustments may need to be made depending on the task the model performs. For instance, in question answering the output logits denote the start and end positions of an answer. Concatenating inputs into one tensor makes the start and end positions provided in the training set inaccurate, because each sequence is shifted by the total length of the sequences placed before it in the pack. For this case, we introduced a `positions_offset` that shifts the position labels according to the offset created for their respective context tokens when packing. For instance:
   
    - Answer positions before packing:
    
    ```
    Length of sequence 1: 100 tokens (index 0 to 99)   , start position: 30, end position: 35
    Length of sequence 2: 120 tokens (index 0 to 119)  , start position: 15, end position: 25
    ```
    - Answer positions after packing:
    
    ```
    Length of sequence 1 in pack 1: 100 tokens (index 0 to 99)   , start position: 30, end position: 35
    Length of sequence 2 in pack 1: 120 tokens (index 100 to 219), start position: 115, end position: 125 
    ```

    - The positions have been shifted by the total length of preceding sequences in the pack, to align predictions.
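
The position shift in the question-answering example above can be sketched as follows. This is an illustrative stand-in rather than the actual `positions_offset` implementation, and the function name `shift_answer_positions` is an assumption.

```python
from itertools import accumulate


def shift_answer_positions(seq_lens, start_positions, end_positions):
    """Shift per-sequence answer positions to their locations inside one pack."""
    # Offset of each sequence = total length of the sequences packed before it
    offsets = [0] + list(accumulate(seq_lens))[:-1]
    packed_starts = [s + o for s, o in zip(start_positions, offsets)]
    packed_ends = [e + o for e, o in zip(end_positions, offsets)]
    return packed_starts, packed_ends


# The two sequences from the example above: lengths 100 and 120
packed_starts, packed_ends = shift_answer_positions([100, 120], [30, 15], [35, 25])
print(packed_starts, packed_ends)  # [30, 115] [35, 125]
```

The first sequence keeps its positions, while the second is shifted by the 100 tokens that precede it in the pack.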


### Packing the contents - a simplified example

For some columns of the dataset, packing them is equivalent to a simple concatenation - for others, special considerations must be made. 

The function to pack the data is somewhat dense. To interpret it more easily, let's look at the conversion required to create the packed dataset columns using a very minimal example, with 3 fake sequences converted to 1 pack.

Create some fake data:

In [None]:
import random

ex_input_ids = [
    [101, 3, 457, 6, 7, 102],
    [101, 3, 243, 3, 2, 102],
    [101, 65, 341, 23, 11, 24, 81, 3, 102],
]
ex_attention_mask = [[1 for i in j] for j in ex_input_ids]
ex_token_type_ids = [[0 for i in j] for j in ex_input_ids]
ex_labels = [random.randint(0, 2) for i in ex_input_ids]

print(
    f"*input_ids      = {ex_input_ids}\n*attention_mask = {ex_attention_mask}\n*token_type_ids = {ex_token_type_ids}\n*labels         = {ex_labels}"
)

In this simplified example, we have 3 sequences, which will be converted into 1 pack, so let's use reduced parameters to demonstrate this. First, initialise the example packed columns as in the previous section:

In [None]:
ex_total_num_packs = 1
ex_max_seq_len = 24
ex_max_seq_per_pack = 3

ex_packed_input_ids = np.zeros((ex_total_num_packs, ex_max_seq_len), dtype=int)
ex_packed_attention_mask = np.zeros((ex_total_num_packs, ex_max_seq_len), dtype=int)
ex_packed_token_type_ids = np.zeros((ex_total_num_packs, ex_max_seq_len), dtype=int)
ex_packed_position_ids = np.zeros((ex_total_num_packs, ex_max_seq_len), dtype=int)
ex_packed_labels = -100 * np.ones((ex_total_num_packs, ex_max_seq_per_pack), dtype=int)

A strategy for the 3 input sequences would look like this (sequence lengths of 9, 6 and 6) - defining the order in which to pack sequences of the specified lengths into one input:

In [None]:
ex_strategy = [9, 6, 6]

Once we have the initialised arrays, a definition of the total number of packs, and all the strategies, we retrieve all of the sequence lengths in the dataset, sort them according to size and store the indices which correspond to these lengths in the dataset using `np.argsort`. The indices and sequence lengths are stacked into `sorted_seqs`. This array will help traverse the dataset using the strategy to create the dataset.

In [None]:
dataset_seq_lens = np.array([len(seq) for seq in ex_input_ids])
len_sorted_seq_idxs = np.argsort(dataset_seq_lens)
len_sorted_seq_lens = dataset_seq_lens[len_sorted_seq_idxs]
sorted_seqs = np.stack((len_sorted_seq_lens, len_sorted_seq_idxs))

print(sorted_seqs)

At this point, we begin iterating through the strategies, repeating each strategy over the dataset for its corresponding repeat count. We only have one pack to create here, so that outer loop isn't required. The first stage is to iterate through the lengths in the strategy, using the *final result* of `np.argwhere` to select exactly one sequence with the same length as each strategy value:

1. Iterate through strategy.
2. Get all sequences which match length of strategy_set. 
3. Select only one of these sequences (easiest is to use `[-1]` to take last appearing one)
4. Add the index of this sequence length from **`sorted_seqs[0,:]`** to the reference indices (`ref_inds`) 
5. Set the value of this sequence length in **`sorted_seqs[0,:]`** to -1, so it cannot be chosen again
6. The final indices for the sequences **in the dataset** are obtained by retrieving the dataset indices from `sorted_seqs[1,:]`

This provides a list of dataset indices to retrieve sequences in the pack from - and removes the risk of creating duplicate packs. The code for this process is as follows:

In [None]:
ref_inds = []
for x in ex_strategy:
    ref_ind = np.argwhere(sorted_seqs[0] == x)[-1]
    sorted_seqs[0, ref_ind] = -1
    ref_inds.append(ref_ind)

inds = sorted_seqs[1, ref_inds].ravel()

print(inds)

To follow this process, take a look at the diagram below, which visualises converting the three example sequences into one packed sequence, assuming a global maximum sequence length of 24:

![Packed input creator example](images/packing-creator-example.png)

The packs are then created by concatenating the `input_ids`, the `token_type_ids` (not included in the diagram as they are not required for all tasks, but they are simply concatenated) and the `labels`.

For the attention mask, we concatenate runs of integers, one run per sequence in the pack, with the integer increasing by 1 for each sequence.

We also skip the first token for all of the tokenized columns and place it at the end (labels are not tokenized, so nothing is skipped for them); the simplified code below shows only the skipping step. This behaviour is specific to classification tasks and is not required for prediction tasks like question answering.

In [None]:
import itertools

input_id_pack = list(itertools.chain(*[ex_input_ids[x][1:] for x in inds]))
attention_mask_pack = list(
    itertools.chain(
        *[
            itertools.repeat(n + 1, len(ex_attention_mask[v]) - 1)
            for n, v in enumerate(inds)
        ]
    )
)
token_type_ids_pack = list(itertools.chain(*[ex_token_type_ids[x][1:] for x in inds]))
position_ids_pack = list(
    itertools.chain(*[range(1, len(ex_attention_mask[v])) for n, v in enumerate(inds)])
)
labels_pack = [ex_labels[x] for x in inds]

ex_packed_input_ids[0, : len(input_id_pack)] = input_id_pack
ex_packed_attention_mask[0, : len(attention_mask_pack)] = attention_mask_pack
ex_packed_token_type_ids[0, : len(token_type_ids_pack)] = token_type_ids_pack
ex_packed_position_ids[0, : len(position_ids_pack)] = position_ids_pack
ex_packed_labels[0, : len(labels_pack)] = labels_pack


print(
    f"*input_ids = {ex_packed_input_ids}\n*attention_mask = {ex_packed_attention_mask}\n*token_type_ids = {ex_packed_token_type_ids}\n*position_ids = {ex_packed_position_ids}\n*labels = {ex_packed_labels}"
)

If we are doing inference, the `example_ids` for inference can simply be stored - this stores the order of rearrangement that happens due to packing, and can be used to re-sort the data for inference. The `example_ids` are just the indices of the sequences chosen (as in the above diagram, and the order they were chosen in).

In [None]:
ex_packed_example_ids = -100 * np.ones(
    (ex_total_num_packs, ex_max_seq_per_pack), dtype=int
)

example_ids_pack = inds
ex_packed_example_ids[0, : len(example_ids_pack)] = example_ids_pack

print(f"*example_ids = {ex_packed_example_ids}")
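
To show how the stored `example_ids` could be used after inference, here is a minimal sketch. It assumes the model returns one row of logits per sequence slot, i.e. an array of shape `[num_packs, max_seq_per_pack, num_labels]`; that output shape, the helper name `unpack_in_original_order` and the toy values are assumptions made for illustration.

```python
import numpy as np


def unpack_in_original_order(packed_outputs, packed_example_ids, pad_value=-100):
    """Return one output row per original example, ordered by example id."""
    valid = packed_example_ids != pad_value  # unused slots keep the padding value
    example_ids = packed_example_ids[valid]
    outputs = packed_outputs[valid]
    return outputs[np.argsort(example_ids)]


# Toy example (made up): 2 packs, up to 3 sequences per pack, 2 logits per sequence
toy_example_ids = np.array([[2, 1, 0], [3, -100, -100]])
toy_logits = np.array(
    [
        [[0.2, 0.8], [0.1, 0.9], [0.0, 1.0]],
        [[0.3, 0.7], [0.0, 0.0], [0.0, 0.0]],
    ]
)
restored = unpack_in_original_order(toy_logits, toy_example_ids)
print(restored)
```

Row `k` of `restored` then corresponds to example `k` of the original, unpacked dataset, and the padded slots are dropped.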

We have now provisionally packed the data, lets understand the `attention_mask` and the `position_ids` and why the first value of the sequence was skipped for each sequence when packing:

**2D increasing integer attention mask:** A typical binary attention mask marks the valid part of a single sequence with 1 and the padding with 0. This attention mask instead represents multiple sequences: padding is still 0, but all tokens of one sequence share an integer value, and that value increases with each sequence in the pack. Here, since we have 3 sequences, the mask contains the values 1, 2 and 3.

The attention mask actually used by the model will be a binary 3D attention mask, expanded from size `[batch, max_seq_len]` to size `[batch, max_seq_len, max_seq_len]`, where each row for a single index in the batch is the binary mask of the tokens that the corresponding token may attend to, i.e. the tokens of its own sequence. The conversion from the 2D attention mask to the 3D 'extended' attention mask is done in the forward pass, and is discussed further in the [Modifications to model processing](#modifications-to-model-processing) section.

**Position IDs:** Where the attention mask outlines which sequence the model should 'pay attention to' at any token, the position IDs are indices for each token within a sequence, i.e. how far into the sequence that specific token is. For a packed input, the indices can simply be concatenated, and the corresponding indices will be interpreted according to the active part of the attention mask for each token.

**Skipping first token (CLS) of sequences:** The BERT pooler will be outlined in the next section. For preprocessing, we must simply note that the model uses the hidden state of the CLS (opening) token as the representation of the sequence when generating an output. **This applies only to classification tasks: the step is only necessary if BERT is using the Pooler.** The CLS token is usually the first token in the sequence and can be retrieved easily. In a packed input, however, there are multiple CLS tokens, one per sequence in the pack, and these are difficult to locate and obtain efficiently in the model, so we remove them from the beginning of the sequences. This reduces each sequence's length by 1, which is reflected by reducing the attention mask length and token type ID length.

This creates a final extra step applicable to classification tasks: we need to put the CLS tokens somewhere else in the input, where the model can find them easily, so the CLS tokens are all placed at the end of the packed input. When the actual number of sequences is lower than the maximum sequences per pack, we still set all of the last `max_seq_per_pack` indices to the CLS value, and then mask the unused ones at the loss stage.

In [None]:
ex_packed_input_ids[0, -ex_max_seq_per_pack:] = [
    ex_input_ids[0][0] for _ in range(ex_max_seq_per_pack)
]
ex_packed_attention_mask[0, -ex_max_seq_per_pack:] = list(
    range(1, ex_max_seq_per_pack + 1)
)

print(ex_packed_input_ids)

This stage may not be needed for tasks like question-answering, which works on a token level and doesn't require a global representation of the sequence (there is no Pooler stage). With that in mind, note that there may be further task-specific considerations to make for different use cases. When implementing a new use case, we need to consider the most reasonable way to create packed dataset columns with respect to the operation of the model.

For example, for SQuAD the model's output is *positional*: it assumes the sequence starts at position 0, and the start and end positions of the answer within the context are obtained relative to that. Once we pack the data, this assumption is false for every sequence after the first one in the pack, so this requires the special consideration of a `positions_offset`, which adds the total length of the previous sequences in the pack to the start/end position values of the current sequence.
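
A minimal sketch of this offset calculation is shown below. The strategy and the answer spans are made-up toy values; only the `positions_offset` expression mirrors the one used in the full packing function further down.

```python
# Hypothetical strategy: one pack holding sequences of length 5, 3 and 4
strategy = [5, 3, 4]

# Offset of each sequence = total length of the sequences before it in the pack
positions_offset = [sum(strategy[:n]) for n in range(len(strategy))]

# Made-up answer spans, given as token indices within each unpacked sequence
toy_start_positions = [2, 0, 1]
toy_end_positions = [3, 1, 3]

# Shift each span so it indexes into the packed input instead
packed_starts = [s + off for s, off in zip(toy_start_positions, positions_offset)]
packed_ends = [e + off for e, off in zip(toy_end_positions, positions_offset)]

print(positions_offset)  # [0, 5, 8]
print(packed_starts)     # [2, 5, 9]
print(packed_ends)       # [3, 6, 11]
```

The full function additionally clamps the shifted values with `max(..., 0)`.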

### Packing the contents - the full function

The previous section outlined the process for packing for a single packed input example, for clarity. 

Using essentially the same steps as above, let's have a look at how the whole dataset, for which we previously generated a strategy, is iterated over and packed. This is the process within the `create()` function in the `PackedDatasetCreator`:

In [None]:
# Arguments passed into the function
problem_type = "single_label_classification"  # Pass one of the supported problem types for the special considerations
training = True
validation = False
inference = False

strategy_set = strategy_set
strategy_repeat_count = strategy_repeat_count

# Other arguments defined by the problem type
skip_cls = 1  # Use this for classification tasks
adjust_offset_positions = False  # Use this for positional predictions

# Sort the sequences by length
dataset_seq_lens = np.array([len(seq) for seq in unpacked_input_ids])
len_sorted_seq_idxs = np.argsort(dataset_seq_lens)
len_sorted_seq_lens = dataset_seq_lens[len_sorted_seq_idxs]
sorted_seqs = np.stack((len_sorted_seq_lens, len_sorted_seq_idxs))

# Pack the data using the developed strategies
pack_index = 0

st = time.perf_counter()
for i in range(len(strategy_repeat_count)):
    strategy = strategy_set[i]

    # This is the offset we apply to the start positions to account for the positional change of the logits for SQuAD
    if adjust_offset_positions:
        positions_offset = [sum(strategy[:n]) for n in range(len(strategy))]

    for _ in range(strategy_repeat_count[i]):
        ref_inds = []
        for x in strategy:
            ref_ind = np.argwhere(sorted_seqs[0] == x)[-1]
            sorted_seqs[0, ref_ind] = -1
            ref_inds.append(ref_ind)

        inds = sorted_seqs[1, ref_inds].ravel()

        # Exclude the CLS tokens to put them at the end later
        input_id_pack = list(
            itertools.chain(*[unpacked_input_ids[x][skip_cls:] for x in inds])
        )
        attention_mask_pack = list(
            itertools.chain(
                *[
                    itertools.repeat(n + 1, len(unpacked_attention_mask[v]) - skip_cls)
                    for n, v in enumerate(inds)
                ]
            )
        )
        token_type_ids_pack = list(
            itertools.chain(*[unpacked_token_type_ids[x][skip_cls:] for x in inds])
        )
        position_ids_pack = list(
            itertools.chain(
                *[
                    range(skip_cls, len(unpacked_attention_mask[v]))
                    for n, v in enumerate(inds)
                ]
            )
        )

        # Create the equivalent tokenised packed dataset - we operate with python arrays due to inhomogeneous dataset size
        packed_input_ids[pack_index, : len(input_id_pack)] = input_id_pack
        packed_attention_mask[
            pack_index, : len(attention_mask_pack)
        ] = attention_mask_pack
        packed_token_type_ids[
            pack_index, : len(token_type_ids_pack)
        ] = token_type_ids_pack
        packed_position_ids[pack_index, : len(position_ids_pack)] = position_ids_pack

        if problem_type == "single_label_classification":
            if training or validation:
                labels_pack = [unpacked_labels[x] for x in inds]
                packed_labels[pack_index, : len(labels_pack)] = labels_pack
            if inference:
                example_ids_pack = inds
                packed_example_ids[
                    pack_index, : len(example_ids_pack)
                ] = example_ids_pack

        if problem_type == "multi_label_classification":
            if training or validation:
                labels_pack = np.stack([unpacked_labels[x] for x in inds])
                packed_labels[pack_index, : labels_pack.shape[0], :] = labels_pack
            if inference:
                example_ids_pack = inds
                packed_example_ids[
                    pack_index, : len(example_ids_pack)
                ] = example_ids_pack

        if problem_type == "question_answering":
            if training:
                start_positions_pack = [
                    max(start_positions[v] + positions_offset[n], 0)
                    for n, v in enumerate(inds)
                ]
                end_positions_pack = [
                    max(end_positions[v] + positions_offset[n], 0)
                    for n, v in enumerate(inds)
                ]
                packed_start_positions[
                    pack_index, : len(start_positions_pack)
                ] = start_positions_pack
                packed_end_positions[
                    pack_index, : len(end_positions_pack)
                ] = end_positions_pack

            if validation or inference:
                example_ids_pack = [unpacked_example_ids[x] for x in inds]
                offset_mapping_pack = list(
                    itertools.chain(*[unpacked_offset_mapping[x] for x in inds])
                )

                packed_example_ids[
                    pack_index, : len(example_ids_pack)
                ] = example_ids_pack
                packed_offset_mapping[
                    pack_index, : len(offset_mapping_pack)
                ] = offset_mapping_pack

        # Now add the CLS tokens and their masks at the end of the pack if classification task
        if skip_cls:
            packed_input_ids[pack_index, -max_seq_per_pack:] = [
                unpacked_input_ids[0][0] for _ in range(max_seq_per_pack)
            ]
            packed_attention_mask[pack_index, -max_seq_per_pack:] = list(
                range(1, max_seq_per_pack + 1)
            )

        pack_index += 1
en = time.perf_counter()

print(f"Time to pack dataset: {en-st}s")

The `packed_` dataset columns are then used to create a new PyTorch dataset. First, we need to create the dataset class that defines all of the required columns; this base template may differ according to the task. Its `__getitem__` method returns each sample as a dictionary of columns, and adds the labels to that dictionary only if they were passed to the dataset:

In [None]:
from torch.utils.data import Dataset


class PackedClassificationDataset(Dataset):
    def __init__(self, input_ids, attention_mask, token_type_ids, position_ids, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        input_ids = self.input_ids[index]
        attention_masks = self.attention_mask[index]
        token_type_ids = self.token_type_ids[index]
        position_ids = self.position_ids[index]
        labels = self.labels[index] if self.labels is not None else None

        sample = {
            "input_ids": input_ids,
            "attention_mask": attention_masks,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
        }

        if self.labels is not None:
            sample["labels"] = labels

        return sample

Next, pass the `packed_` dataset columns directly to the dataset class:

In [None]:
packed_dataset = PackedClassificationDataset(
    input_ids=packed_input_ids,
    attention_mask=packed_attention_mask,
    token_type_ids=packed_token_type_ids,
    position_ids=packed_position_ids,
    labels=packed_labels,
)

And that's it! We have covered all of the preprocessing steps required and have our packed dataset, ready to be passed to the `DataLoader`. But there are a few more considerations to make when it comes to the model itself. While we do not need to modify anything in the internal forward pass for BERT, the task-specific model heads need to be slightly modified for packed inputs, and there are some small considerations you may need to make for postprocessing/evaluation.
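
As a quick sanity check of the batching behaviour, the sketch below feeds a toy packed dataset to a plain PyTorch `DataLoader`. `ToyPackedDataset` and its all-zero contents are hypothetical stand-ins, and the use of the stock `DataLoader` with default collation is an assumption; the actual training setup may wrap the dataset differently.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyPackedDataset(Dataset):
    """Hypothetical stand-in for a packed dataset with two columns."""

    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        return {"input_ids": self.input_ids[index], "labels": self.labels[index]}


# 4 packs, max sequence length 8, up to 3 sequences per pack (all values made up)
toy_input_ids = np.zeros((4, 8), dtype=int)
toy_labels = -100 * np.ones((4, 3), dtype=int)

loader = DataLoader(ToyPackedDataset(toy_input_ids, toy_labels), batch_size=2)
batch = next(iter(loader))

# The default collate function stacks the NumPy rows into tensors
print(batch["input_ids"].shape)  # torch.Size([2, 8])
print(batch["labels"].shape)     # torch.Size([2, 3])
```

Note that the labels keep their `[batch_size, max_seq_per_pack]` shape here; flattening to the larger effective batch size happens inside the model.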

<a id='modifications-to-model-processing'></a>
## Modifications to model processing

This section covers the modifications made to the Transformers BERT modeling head to enable packing.

### Packed BERT Pooler for classification

The BERT pooler class operates on the hidden states output by the forward pass of the generic BERT model. The hidden state corresponding to the first token of the sequence (the CLS token) is taken and passed through a linear layer and an activation function to generate the pooled output. This gives a global representation of the sequence (rather than a token-by-token representation), which is then used to generate the logits.

Note that the BERT Pooler is only needed for classification tasks - more specifically, for tasks where the hidden state outputs must be amalgamated to provide labels. Tasks which return positional outputs, like question answering, do not require a pooling stage.

Let's look at the forward pass in the BERT pooler class (for a single unpacked sequence), as it exists currently:

```python
class BertPooler(nn.Module):
    ...
    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding to first token (CLS)
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
```

Then, let's look at a chunk of the forward pass of the BERT model head for sequence classification to understand how these pooled outputs are used to generate logits:

```python
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    
    def forward(self, **kwargs):
        #This calls the pooler at the end stage for classification
        outputs = self.bert(**kwargs) 

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        ...
...

```


Packing sequences increases the number of elements per batch: within each sample inside a batch, there are now up to `max_seq_per_pack` sequences. In order to reuse the classification heads from 🤗 Transformers, we need a special pooler that modifies the existing `BertPooler` class. Instead of pooling the hidden states of a single sequence, it needs to pool multiple sequences (up to the maximum number of sequences in the pack) and order them along the batch dimension, so that the output size of the pooler becomes:

```
[batch_size * max_sequences_per_pack, hidden_size]
```

To make this process easier, the CLS tokens for the maximum sequences in a pack were moved to the end of each input during preprocessing. The diagram below shows how pooling works for Packed BERT with a maximum of 2 sequences per pack:

![Packed BERT Pooler operation](images/pooling-draw-edit.png)


Based on this, let's define the `PackedBertPooler` class:


In [2]:
import torch
import torch.nn as nn


class PackedBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.max_seq_per_pack = config.max_sequences_per_pack
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        sh = hidden_states.shape
        last_tokens_tensors = hidden_states[:, -self.max_seq_per_pack :]
        last_reshape = last_tokens_tensors.reshape(sh[0] * self.max_seq_per_pack, sh[2])
        # output size: [bs x max_sequences_per_pack, hidden_size]
        output = self.dense(last_reshape)
        output = self.activation(output)

        return output

From the loss function's point of view, everything will appear as if we were simply using a larger batch size (`batch_size * max_seq_per_pack`). Remember that a pack does not always contain `max_seq_per_pack` sequences, which is the *ideal* case. When the number of sequences in the pack is lower than `max_seq_per_pack`, we still put the maximum number of CLS tokens at the end of each input to be pooled. So even if an input holds only 4 out of a possible 6 sequences, 6 sets of logits are generated from the pooled outputs.

This is fine, because the unused logits are ignored by labelling them with the default `ignore_index` (-100) of the loss, which we already prepared during preprocessing. This is a useful feature of cross-entropy loss (`torch.nn.CrossEntropyLoss`): targets equal to `ignore_index` do not contribute to the loss or to the gradients. We can exploit this feature to help us deal with intermittent padding.
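
The effect of `ignore_index` can be checked with a small sketch. The logits and labels below are made-up values for a single pack with `max_seq_per_pack = 3`, of which only two slots hold real sequences.

```python
import torch
import torch.nn.functional as F

# Made-up pooled logits for one pack: [max_seq_per_pack, num_labels]
toy_logits = torch.tensor([[2.0, 0.5], [0.1, 1.5], [3.0, -1.0]])
# The third slot is unused, so its label is the ignore index
toy_labels = torch.tensor([0, 1, -100])

# Loss over all slots, with the unused slot ignored
loss_packed = F.cross_entropy(toy_logits, toy_labels, ignore_index=-100)
# Loss over only the two real sequences
loss_valid = F.cross_entropy(toy_logits[:2], toy_labels[:2])

# Both are the mean over the two used slots, so they match
print(torch.allclose(loss_packed, loss_valid))  # True
```

With the default mean reduction, the average is taken over the non-ignored targets only, so the unused slot changes neither the loss value nor the gradients.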

### Custom inputs: 3D attention mask

The attention mask should be used in a specific way for Packed BERT. During preprocessing, we created an integer type attention mask with increasing values from 1 to `max_seq_per_pack` representing the positional information for each sequence within a packed input. This can now be extended to a 3D attention mask (including the batch dimension) representing the positional information for each sequence as a binary attention mask. It is transformed from the shape `[batch_size, max_seq_len]` to `[batch_size, max_seq_len, max_seq_len]` where each row now represents the binary attention mask applicable to that specific token.

![Custom 3D attention mask](images/attention-mask-draw.png)

We will create the extended attention mask as in the following example. With this mask, self-attention treats each sequence of the pack separately (and also ignores the padding). The conversion is done in the modified BERT model head.

To build some intuition, let's have a look at a minimal example, using a smaller version of the type of attention mask we created during preprocessing:

In [3]:
# 1 : Flattened attention mask generated by the dataset. Each sequence has a different index. 0 is padding:
ex_attention_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 3, 3, 0, 0, 1, 2, 3]])

In [4]:
# 2: Generate the boolean extended attention mask
ex_attention_mask = ex_attention_mask[:, None, :].repeat(
    1, ex_attention_mask.shape[1], 1
)
ex_attention_mask = (ex_attention_mask == ex_attention_mask.transpose(1, 2)) * (
    ex_attention_mask != 0
)

In [5]:
# Notice that the mask is always False for the padding tokens.
print(ex_attention_mask.to(int))

tensor([[[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1]]])


This results in an attention mask which allows the model to process each sequence separately, and ignores padding.
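
To illustrate what "separately" means in practice, the sketch below applies a toy single-head attention to a packed input and compares the result with attending to each sequence on its own. The `attend` function is a simplified stand-in written for this example (no learned projections, a large negative fill value for masked scores), not the attention implementation used by BERT.

```python
import torch

torch.manual_seed(0)

# One pack: a sequence of 3 tokens, a sequence of 2 tokens, 1 padding token
mask_2d = torch.tensor([[1, 1, 1, 2, 2, 0]])
seq_len, dim = mask_2d.shape[1], 4
x = torch.randn(1, seq_len, dim)  # stand-in for queries, keys and values

# Same conversion as above: True where two tokens share a non-padding sequence
mask_3d = mask_2d[:, None, :].repeat(1, seq_len, 1)
mask_3d = (mask_3d == mask_3d.transpose(1, 2)) & (mask_3d != 0)


def attend(x, mask):
    """Toy scaled dot-product attention with a boolean [batch, seq, seq] mask."""
    scores = x @ x.transpose(1, 2) / x.shape[-1] ** 0.5
    # A large finite value keeps fully masked (padding) rows free of NaNs
    scores = scores.masked_fill(~mask, -1e9)
    return torch.softmax(scores, dim=-1) @ x


packed_out = attend(x, mask_3d)
first_alone = attend(x[:, :3], torch.ones(1, 3, 3, dtype=torch.bool))
second_alone = attend(x[:, 3:5], torch.ones(1, 2, 2, dtype=torch.bool))

print(torch.allclose(packed_out[:, :3], first_alone, atol=1e-5))
print(torch.allclose(packed_out[:, 3:5], second_alone, atol=1e-5))
```

Both comparisons should hold, since the masked scores receive zero weight after the softmax.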

### Custom outputs: Unpacking logits

For single label classification, as above, we use the `ignore_index` argument of the loss to automatically mask unused logits in the pooled output when computing the loss. Binary loss functions, for instance `torch.nn.BCEWithLogitsLoss`, do not take an `ignore_index` argument, as they expect just 1 or 0 for the labels and compute a probabilistic outcome for each of them. In these cases (such as in multi-label classification), the logits and loss must be manually masked in the Packed BERT head, to ensure that unused slots do not have detrimental effects on the loss calculation (and in turn the training of the model). Finally, the logits need to be reshaped to be the same shape as the labels (essentially detaching the combined logits into ones respective to each sequence).

For question-answering, and similar token-specific tasks, as mentioned before, we are not looking for a global representation of entire sequences, but rather token-specific information. When the hidden states output is received, instead of pooling, it is directly passed through a linear layer to generate a set of logits for **each token**. The logits are then the same size as the entire packed input, and sequence-specific information must be extracted manually. For this kind of task, we create an *unpacking* mask to extract the logits for each sequence in a pack into separate sets of logits, and then reshape these into a larger batch size of `batch_size * max_seq_per_pack` for prediction. We take care to remove tokens corresponding to unused sequences in a pack when applying this 'unpacking mask', which allows us to mask padding at the same time.

#### For multi-label classification

Let's look at how we might mask the logits and loss for multi-label classification after receiving the output from the forward pass. Note that the attention mask used here is the 2D integer attention mask, taken before it is converted to the 3D attention mask in the forward pass. This is an excerpt from the `PackedBertOutputsForMultiLabel` class. Once again, this is not the focus of the classification example used in this notebook, but it is useful to observe the alternate customisations that may need to be applied to the model for different tasks.

First, obtain the number of labels (sequences) used by each input in the batch. The maximum value of the pre-transform attention mask for an input equals the number of sequences packed into it, so we take the maximum along the sequence dimension for every input, excluding the CLS slots at the end:
```python
max_labels = torch.max(attention_mask[:,:-self.max_seq_per_pack], dim=-1).values.unsqueeze(1)
```

Then, we can use these values to create a mask. We first create an array holding the range of values 0 to `max_seq_per_pack - 1` for each input, then compare it with `max_labels`, setting entries below the number of used labels to 1 and the rest to 0. The result is a mask we can directly apply to the logits.

```python
label_mask = torch.arange(0, self.max_seq_per_pack).unsqueeze(0).repeat(batch_dim, 1)
label_mask = torch.where(label_mask < max_labels, 
                            torch.ones(batch_dim, self.max_seq_per_pack), 
                            torch.zeros(batch_dim, self.max_seq_per_pack))
label_mask = label_mask.view(-1).unsqueeze(1)
```
Then we can multiply this mask by the logits directly, zero-ing all logits which correspond to padding.
```python
logits = label_mask * outputs.logits
```

Then, we can flatten the labels to be of the larger effective batch size `batch_size*max_seq_per_pack` and similarly multiply this by the mask to zero all unused label values:
```python
loss = None
if labels is not None:
    labels = labels.view(-1, *(labels.size()[2:])).to(torch.float32)
    labels = label_mask * labels
```

As the loss applies the sigmoid function to the logits, a zeroed logit still yields a non-zero loss term (sigmoid(0) = 0.5), so the label mask is also applied to the loss:
```python    
    loss = self.multi_loss(logits, labels)
    loss *= label_mask
```

Then take the mean over each multi-class prediction, dividing by the total number of values to obtain the mean loss.
```python
    
    loss = torch.sum(loss) / (torch.sum(max_labels)*labels.shape[-1])
    loss = poptorch.identity_loss(loss, reduction='none')
```

Finally, reshape the logits to be the same shape as the labels for postprocessing/validation:
```python
    logits = logits.reshape([batch_dim, self.max_seq_per_pack, logits.shape[-1]])

    return (loss, logits)
```
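
The excerpts above rely on attributes of the surrounding class, so here is a self-contained sketch of the same masking steps on toy tensors. All values are made up, and using `torch.nn.BCEWithLogitsLoss(reduction="none")` as a stand-in for `self.multi_loss` is an assumption. The final check compares the masked mean with the loss computed on the used slots only.

```python
import torch

torch.manual_seed(0)
max_seq_per_pack, num_labels = 3, 4

# Two packs: the first holds 2 sequences, the second 3.
# The last max_seq_per_pack columns are the CLS slots.
attention_mask = torch.tensor(
    [[1, 1, 2, 2, 0, 1, 2, 3],
     [1, 2, 2, 3, 3, 1, 2, 3]]
)
batch_dim = attention_mask.shape[0]
logits = torch.randn(batch_dim * max_seq_per_pack, num_labels)
labels = torch.randint(0, 2, (batch_dim, max_seq_per_pack, num_labels)).float()

# Number of sequences actually used in each pack (CLS slots excluded)
max_labels = torch.max(attention_mask[:, :-max_seq_per_pack], dim=-1).values.unsqueeze(1)
# 1 for used slots, 0 for unused ones, flattened to [batch * max_seq_per_pack, 1]
label_mask = (torch.arange(max_seq_per_pack).unsqueeze(0) < max_labels).float().reshape(-1, 1)

flat_labels = labels.reshape(-1, num_labels)
loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")  # assumed stand-in for self.multi_loss
loss = loss_fct(logits * label_mask, flat_labels * label_mask) * label_mask
mean_loss = loss.sum() / (max_labels.sum() * num_labels)

# Reference: mean loss over the used slots only
used = label_mask.squeeze(1).bool()
reference = torch.nn.functional.binary_cross_entropy_with_logits(logits[used], flat_labels[used])
print(torch.allclose(mean_loss, reference))  # True
```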





#### For question answering:
We need a custom output to 'unpack' and mask the logits before computing the loss on them. This code excerpt from the `PackedBertOutputsForQA` class shows the process for creating the unpacking mask and using it to extract the logits for each sequence, so that the loss is computed on an emulated larger batch size. Once again, this is not the focus of the classification example used in this notebook.

First, create the unpacking mask that separates the packed logits into sequence-specific logits, by repeating the 2D attention mask once for each of the maximum number of sequences:
```python
unpacking_mask = attention_mask[:,None,:].repeat(1, self.max_sequences_per_pack, 1)
```

Then, generate the sequence indices 1 to `max_sequences_per_pack`. A logical comparison of the repeated mask against these indices sets the entries belonging to each sequence to 1 in that sequence's row, and all others to 0:
```python
pack_seq_ids = torch.arange(1, self.max_sequences_per_pack + 1).view(self.max_sequences_per_pack, 1) 
# Use this to mask out the indices that are padding
unpacking_mask = (unpacking_mask == pack_seq_ids)
```
Multiplying the start and end logits by the unpacking mask zeroes the logits that are unused or correspond to padding:
```python
unpacked_start_logits = final_layer_output.start_logits[:,None,:] * unpacking_mask
unpacked_end_logits = final_layer_output.end_logits[:,None,:] * unpacking_mask
```

Then, we flatten the start and end position labels to compute the loss on them:
```python
# Calculate loss on logits/labels with initial [bs, mspp, ...] dims collapsed into one [bs*mspp, ...]
total_loss = None
if start_positions is not None and end_positions is not None:
    start_positions = start_positions.view(-1)
    end_positions = end_positions.view(-1)
```

The logits are also flattened:
```python
    unpacked_start_logits=unpacked_start_logits.contiguous()
    unpacked_end_logits=unpacked_end_logits.contiguous()

    unpacked_start_logits = unpacked_start_logits.view(-1, unpacked_start_logits.shape[-1])
    unpacked_end_logits = unpacked_end_logits.view(-1, unpacked_end_logits.shape[-1])
```

Then loss can be computed as normal for each set of logits.
```python
    loss_fct = nn.CrossEntropyLoss()
    start_loss = loss_fct(unpacked_start_logits, start_positions)
    end_loss = loss_fct(unpacked_end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2
    # Since we need to unpack and mask logits prior to loss, the step to reshape logits is not required for QA

```
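
For a concrete view of what the unpacking mask does, the sketch below runs the same mask construction on a toy pack. The attention mask and start logits are made-up values; in the model they would come from the dataset and the QA output layer respectively.

```python
import torch

max_sequences_per_pack = 3

# One pack: a 3-token sequence, a 2-token sequence and one padding token
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
# Made-up start logits, one per token of the packed input
start_logits = torch.tensor([[0.5, 2.0, 0.1, 1.0, 3.0, 0.7]])

# Repeat the 2D mask once per possible sequence, then keep only the
# tokens whose sequence index matches the row
unpacking_mask = attention_mask[:, None, :].repeat(1, max_sequences_per_pack, 1)
pack_seq_ids = torch.arange(1, max_sequences_per_pack + 1).view(max_sequences_per_pack, 1)
unpacking_mask = unpacking_mask == pack_seq_ids

# One row of logits per sequence slot, flattened to [batch * max_sequences_per_pack, seq_len]
unpacked_start_logits = start_logits[:, None, :] * unpacking_mask
unpacked_start_logits = unpacked_start_logits.reshape(-1, unpacked_start_logits.shape[-1])

print(unpacked_start_logits)
```

The first row keeps only the logits of the first sequence, the second row those of the second sequence, and the third row (an unused slot) is entirely zero, as are the padding positions.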


When implementing packing for a new model, you may need to make some small considerations as above for how best to 'unpack' the returned logits at the output stage, for the classification/predictions process to handle the outputs as if they were simply a larger batch size. For the single label sequence classification task outlined in postprocessing, as the output is pooled and the loss automatically ignores unused logits, custom outputs are not required. However, the logits still need to be reshaped to the same shape as the labels at the output stage.

We need to instantiate new model heads above the default Transformers BERT heads in order to handle the changes in output created by packing. In 🤗 Optimum-Graphcore, you can find the source for these model heads which include the output processing for all supported tasks in `notebooks/packed_bert/models/modeling_bert_packed.py`.

Finally, for our sequence classification task with `sst2`, let's define the full model class for single label sequence classification with the above modifications for packing.

In [None]:
import poptorch  # needed for poptorch.BeginBlock in parallelize()
import torch
from optimum.graphcore.models.bert.modeling_bert import BertPipelineMixin
from transformers import BertForSequenceClassification
from typing import Optional, Tuple


class PipelinedPackedBertForSequenceClassification(
    BertForSequenceClassification, BertPipelineMixin
):
    def __init__(self, config):
        super().__init__(config)
        self.max_seq_per_pack = config.max_sequences_per_pack
        self.problem_type = config.problem_type
        self.num_labels = config.num_labels
        self.bert.pooler = PackedBertPooler(config)

    def parallelize(self):
        # The parallelize function enables pipelining on the IPU when inheriting from BertPipelineMixin
        super().parallelize()
        last_ipu = self.ipu_config.ipus_per_replica - 1
        self.classifier = poptorch.BeginBlock(
            self.classifier, "Classifier Output", ipu_id=last_ipu
        )
        return self

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor]:

        # We can obtain the global batch size and custom max sequence length from the data shape, so as to not pass extra custom args
        bs = input_ids.shape[0]
        seq_len = input_ids.shape[1]

        # Generate the custom 3D attention mask
        attention_mask_3d = attention_mask[:, None, :].repeat(1, seq_len, 1)
        attention_mask_3d = (attention_mask_3d == attention_mask_3d.transpose(1, 2)) * (
            attention_mask_3d != 0
        )

        # Manual masking of logits and loss only needed for multi-label, single-label loss allows ignore_index
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask_3d,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            labels=labels
            if labels is not None and self.problem_type == "single_label_classification"
            else None,
        )

        if self.problem_type == "single_label_classification":
            if labels is not None:
                # The logits must be reshaped at the end to be the same shape as the labels, for easier validation.
                logits = output.logits.reshape(
                    [-1, self.max_seq_per_pack, self.num_labels]
                )
                output.logits = logits

        return output

### Instantiating the model

A few small considerations also need to be made when instantiating the model. The `max_seq_per_pack` parameter introduces an extra option to add to the config. Also, a few parameters must be passed when initialising one of the packed BERT model classes. The number of labels and the problem type are also essential to add to the BERT `AutoConfig`, so that the custom model classes can easily obtain elements needed for some necessary internal parameters.

In [None]:
from transformers import AutoConfig
import torch

model_checkpoint = "Graphcore/bert-base-uncased"
num_labels = 2

config = AutoConfig.from_pretrained(model_checkpoint)
config.max_sequences_per_pack = max_seq_per_pack
config.num_labels = num_labels
config.problem_type = problem_type

model = (
    PipelinedPackedBertForSequenceClassification.from_pretrained(
        model_checkpoint, config=config
    )
    .train()
    .half()
)

print(model)

<a id='postprocessing-the-returned-logits'></a>
## Postprocessing the returned logits

This notebook won't cover the training process for any specific model. This could be different based on the framework you are using, but there are no specific modifications to training/fine-tuning for packing. However, during validation, we need to postprocess the logits returned, as well as the IDs for inference. This is done on the CPU, so we can easily use dynamic indexing to extract the relevant logits. For instance, for single label classification:

We can simply generate a mask by creating a boolean array for the indices where labels are not equal to -100 (as this corresponds to padding). Then index both the labels and predictions using this mask! 

```python
    # Remove the padding labels
    mask = (labels != -100)
    labels = labels[mask]
    
    predictions = np.argmax(predictions, axis=-1)
    predictions = predictions[mask]
```

The shape of labels may change for different tasks, so be sure to index the correct dimension when creating a postprocessing mask. Multi-label, for instance, returns more than one label, so if there are, say, 5 possible labels for one input, you might change the mask as follows:

```python
    mask = (labels != -100)[:,0]
    
    labels = labels[mask,:]
    predictions = predictions[mask,:]
```

Now, since there is an extra dimension of size 5 (or however many labels there are), we keep the first value of the second dimension of the mask.

Essentially, make sure you ignore or remove the padding when postprocessing your logits, just as we do for the loss during training. This prevents the accuracy values from being contaminated by comparisons against unused labels and logits.

<a id='a-note-on-training-and-inference'></a>
## A note on training and inference

After completing the steps above, you should have a correctly formatted packed dataset and a simple modified model head that lets the model interpret the packed data. With these modifications, you can train and validate the model, and use the dataset in a training or inference loop, just as you would with the generic unpacked version.

### Hyperparameter tuning
You may need to reconsider the hyperparameters for the packed model. Packing in effect creates a much larger emulated batch size than the defined batch size, so a reasonable starting assumption is to increase the learning rate in proportion to the number of sequences packed per input, mirroring the tuning you might need if you simply increased the batch size by a factor of `max_sequences_per_pack`.

We tested this by aligning convergence with the unpacked version, and found that it is more precise to increase the learning rate by the *average* number of sequences per pack in the dataset rather than the theoretical maximum, as the average better reflects the true increase in batch size. When implementing packing for a new task or dataset, a useful convergence test is to maintain an equal global batch size (that is, reduce the gradient accumulation steps so that packed BERT and standard BERT process the same total number of samples) and to set a fixed seed. This helps ensure that none of the modifications and additions to model and data processing covered above have an adverse effect on training values.
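As a rough sketch of this scaling rule, the snippet below computes the average number of sequences per pack and multiplies a base learning rate by it. The helper name `scale_learning_rate`, the per-pack counts and the base learning rate are all hypothetical values chosen for illustration; treat the result as a starting point for tuning rather than a fixed recipe.

```python
def scale_learning_rate(base_learning_rate, sequences_per_pack):
    """Scale a learning rate by the average number of sequences per pack.

    `sequences_per_pack` holds, for each pack in the dataset, how many
    sequences were actually packed into it.
    """
    average = sum(sequences_per_pack) / len(sequences_per_pack)
    return base_learning_rate * average, average


# Hypothetical packing result for 8 packs with max_seq_per_pack = 3
sequences_per_pack = [3, 2, 3, 1, 3, 3, 2, 3]
max_seq_per_pack = 3
base_learning_rate = 2e-5  # assumed learning rate of the unpacked run

packed_learning_rate, average = scale_learning_rate(
    base_learning_rate, sequences_per_pack
)

print(f"average sequences per pack: {average}")
print(f"scaled learning rate: {packed_learning_rate:.2e}")
print(f"scaling by the maximum would give: {base_learning_rate * max_seq_per_pack:.2e}")
```

In this example the average is 2.5 sequences per pack, so scaling by the maximum of 3 would overestimate the effective batch size increase.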

Other hyperparameters may also have a small effect on the training values, so it may be useful to perform some hyperparameter tuning when implementing packing for a new task. Note that the hyperparameter changes needed for packing will usually mirror those needed when increasing the batch size by the same multiplier as the packing does, but without the overhead of a much larger batch size, which is what allows for much higher throughput!

<a id='in-summary'></a>
## In summary

We have covered all of the changes to the model, preprocessing and dataset creation needed for Packed BERT, to explain why and how this optimisation is achieved. After the model is instantiated, you can train it on the dataset of your choice using PopTorch, or within Optimum Graphcore using the `IPUTrainer`, just as you would train BERT without packing. Evaluation needs a few minor considerations: for an accuracy calculation on the CPU, for example, we can slice out the logits corresponding to unused inputs by finding the indices of the labels equal to -100, as covered briefly in the postprocessing section.

The validation metrics for each task are covered in the [PackedBERT notebooks](https://console.paperspace.com/github/gradient-ai/Graphcore-HuggingFace?file=%2Fpacked-bert) available on Paperspace, which also cover the full training process in Optimum, simplifying the steps outlined in this notebook to abstract the complexities of the preprocessing as well as the internal model changes.