# Quantizing a BERT-based Model for Question Answering

This notebook demonstrates how to optimize a pre-trained BERT model for question answering tasks using various quantization techniques.


### Table of Contents

1. Setup and Imports
2. Load Pre-trained BERT Model and QA Dataset
3. Fine-tune with Mixed Precision
4. Apply Post-Training Quantization (PTQ)
5. Implement Per-Token Quantization
6. Quantized Layer Normalization
7. Dynamic Quantization for Variable Sequence Lengths
8. Optimize Quantized Attention Mechanism
9. Deploy with Quantization-Aware Inference Framework
10. Performance Benchmarking

## 1. Setup and Imports

First, let's import the necessary libraries:

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertForQuestionAnswering, BertTokenizerFast, squad_convert_examples_to_features
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import DataCollatorWithPadding
from datasets import load_dataset
import numpy as np
from torch.cuda.amp import autocast, GradScaler
import math

## 2. Load Pre-trained BERT Model and QA Dataset

Now, let's load a pre-trained BERT model and the SQuAD dataset:

In [9]:
# Load the pre-trained checkpoint with a question-answering head, together
# with the matching tokenizer.
model_name = 'bert-base-uncased'
model = BertForQuestionAnswering.from_pretrained(model_name)
# The fast tokenizer is needed because preprocess_function asks for offset
# mappings (return_offsets_mapping=True).
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Prefer the GPU when one is available; every later cell reuses `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load the SQuAD dataset (the cells below use its "train" and "validation" splits)
squad_dataset = load_dataset("squad")

# Preprocessing function
# Preprocessing function
def preprocess_function(examples):
    """Tokenize a batch of QA examples and attach answer token positions.

    Each question/context pair is tokenized to at most 384 tokens. Only the
    context (second sequence) is truncated; the remainder is emitted as
    additional overlapping features (stride 128), so one example can yield
    several features. Every feature is padded to the maximum length.

    Args:
        examples: A batch with "question", "context" and "answers" columns.
            Only the first entry of each answer's "text" / "answer_start"
            lists is used.

    Returns:
        The tokenizer output with "offset_mapping" and
        "overflow_to_sample_mapping" removed and with "start_positions" /
        "end_positions" added (one entry per feature). A feature whose context
        window does not fully contain the answer is labelled (0, 0).
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",  # never cut the question, only the context
        stride=128,  # overlap between consecutive windows of one context
        return_overflowing_tokens=True,
        return_offsets_mapping=True,  # character spans per token; fast tokenizer only
        padding="max_length",
    )

    # Both mappings are only needed here to compute labels, so they are
    # removed from the returned features.
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        # Map the feature back to the example it was cut from.
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        # Character span of the first listed answer inside the context.
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context (tokens whose sequence id is 1)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        # (index 0 is presumably the [CLS] token -- confirm against the tokenizer)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            # Walk forward past every token starting at or before the answer
            # start; the last such token is the start token.
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            # Walk backward past every token ending at or after the answer
            # end; the last such token is the end token.
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

# Apply preprocessing to the dataset
processed_datasets = squad_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)

# Use a data collator to ensure proper batching and padding
data_collator = DataCollatorWithPadding(tokenizer)

# Prepare dataloaders
train_dataloader = DataLoader(processed_datasets["train"], shuffle=True, batch_size=16, collate_fn=data_collator)
eval_dataloader = DataLoader(processed_datasets["validation"], batch_size=16, collate_fn=data_collator)

print("Loaded BERT model and SQuAD dataset")

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

Loaded BERT model and SQuAD dataset


## 3. Fine-tune with Mixed Precision

Let's implement mixed precision training for fine-tuning:

In [10]:
def train_epoch(model, dataloader, optimizer, scheduler, scaler):
    """Run one mixed-precision training epoch and return the mean batch loss.

    Args:
        model: Module whose forward accepts ``input_ids``, ``attention_mask``,
            ``start_positions``, ``end_positions`` (and optionally
            ``token_type_ids``) and returns an object with a ``loss`` attribute.
        dataloader: Iterable of batches (mappings from field name to tensor).
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scheduler: Learning-rate scheduler stepped once per batch.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.

    Returns:
        The average loss over the batches seen, or 0.0 when the dataloader
        yields no batches.
    """
    model.train()
    # Follow the model's own placement instead of a module-level global so the
    # inputs always land on the same device as the weights.
    model_device = next(model.parameters()).device
    use_amp = model_device.type == 'cuda'

    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        optimizer.zero_grad()

        model_inputs = {
            'attention_mask': batch['attention_mask'].to(model_device),
            'start_positions': batch['start_positions'].to(model_device),
            'end_positions': batch['end_positions'].to(model_device),
        }
        # Segment ids tell BERT which tokens are question and which are
        # context; forward them whenever the batch carries them.
        if 'token_type_ids' in batch:
            model_inputs['token_type_ids'] = batch['token_type_ids'].to(model_device)

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast and is
        # disabled outright when the model is not on a CUDA device.
        with torch.autocast(device_type=model_device.type, enabled=use_amp):
            outputs = model(batch['input_ids'].to(model_device), **model_inputs)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches ourselves: avoids ZeroDivisionError on an empty loader and
    # works for iterables that do not define len().
    return total_loss / num_batches if num_batches else 0.0

# Fine-tuning setup
# Fine-tuning setup
# torch.optim.AdamW stands in for the deprecated transformers.AdamW. eps and
# weight_decay are pinned to that class's defaults (1e-6 and 0.0) to keep the
# optimisation as close as possible to the original setup.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
num_epochs = 3
# The linear schedule decays the learning rate over every optimizer step of the run.
num_training_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
# Gradient scaling only matters for CUDA mixed precision, so switch it off
# explicitly on CPU instead of relying on a runtime warning.
scaler = GradScaler(enabled=device.type == "cuda")

# Fine-tuning loop
for epoch in range(num_epochs):
    avg_loss = train_epoch(model, train_dataloader, optimizer, scheduler, scaler)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

print("Mixed precision fine-tuning completed")

  scaler = GradScaler()
  with autocast():


## 4. Apply Post-Training Quantization (PTQ)

Now, let's apply Post-Training Quantization:

In [None]:
def apply_ptq(model, dataloader):
    """Return a statically quantized (int8, fbgemm) copy of ``model``.

    The caller's model is left untouched: it keeps its device, its
    train/eval mode and gains no ``qconfig`` attribute. All work happens on a
    CPU copy.

    Args:
        model: Module exposing ``bert.embeddings`` and ``qa_outputs``
            submodules; both are excluded from quantization.
        dataloader: Iterable of calibration batches providing ``input_ids``
            and ``attention_mask`` (and optionally ``token_type_ids``).

    Returns:
        The calibrated and converted copy of ``model``.

    NOTE(review): eager-mode static quantization normally also needs
    QuantStub/DeQuantStub boundaries around the quantized region; this
    function does not insert them -- confirm the converted model runs end to
    end before relying on it.
    """
    import copy  # local import keeps this cell self-contained

    # Quantize a private CPU copy so later cells can keep using the original
    # model on its original device.
    model_fp32 = copy.deepcopy(model).cpu().eval()
    model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    # Exclude the embeddings and the QA output head from quantization; they
    # stay in floating point.
    model_fp32.bert.embeddings.qconfig = None
    model_fp32.qa_outputs.qconfig = None

    # The copy above already isolates the caller's model, so prepare/convert
    # can work in place without a second deep copy.
    model_prepared = torch.quantization.prepare(model_fp32, inplace=True)

    # Calibration: run representative batches so the observers record
    # activation ranges.
    with torch.no_grad():
        for batch in dataloader:
            optional_inputs = {}
            if 'token_type_ids' in batch:
                optional_inputs['token_type_ids'] = batch['token_type_ids']
            model_prepared(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **optional_inputs,
            )

    model_quantized = torch.quantization.convert(model_prepared, inplace=True)
    return model_quantized

# Use a subset of the training data for calibration
# Use a subset of the training data for calibration.
# The collator is required: without it the default collation turns each field
# into a list of per-position tensors instead of one tensor per batch, and the
# model call inside apply_ptq fails.
calibration_dataloader = DataLoader(
    processed_datasets["train"].select(range(1000)),
    batch_size=16,
    collate_fn=data_collator,
)
model_int8 = apply_ptq(model, calibration_dataloader)
print("Post-Training Quantization completed")


## 5. Implement Per-Token Quantization

For per-token quantization, we need to modify the BERT encoder:

In [None]:
class QuantizedBertEncoder(nn.Module):
    """Encoder that runs every layer of an existing encoder through QuantizedBertLayer.

    Args:
        encoder: Object exposing an iterable ``layer`` attribute; each entry is
            wrapped in a ``QuantizedBertLayer``.
    """

    def __init__(self, encoder):
        super().__init__()
        self.layer = nn.ModuleList([QuantizedBertLayer(layer) for layer in encoder.layer])

    def forward(self, hidden_states, attention_mask=None, head_mask=None, **kwargs):
        """Apply all layers in order and return ``(hidden_states,)``.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed unchanged to every layer.
            head_mask: Optional per-layer sequence of head masks; ``None``
                means no head masking for any layer.
            **kwargs: Accepted and ignored, so callers that pass extra encoder
                options (e.g. ``output_attentions``, ``return_dict``) do not
                fail with a TypeError.

        NOTE(review): the result is always a plain 1-tuple; callers that
        expect a model-output object with named fields need adapting.
        """
        for i, layer_module in enumerate(self.layer):
            # head_mask defaults to None, which cannot be indexed.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask)
            hidden_states = layer_outputs[0]
        return (hidden_states,)

class QuantizedBertLayer(nn.Module):
    """Transformer layer wrapper that brackets the computation with quant stubs.

    The attention, intermediate and output submodules are taken from ``layer``
    as-is; the input passes through a QuantStub and the final hidden states
    through a DeQuantStub.
    """

    def __init__(self, layer):
        super().__init__()
        # Registration order is kept stable so state-dict / child ordering does not change.
        self.attention = layer.attention
        self.intermediate = layer.intermediate
        self.output = layer.output
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        """Return ``(layer_output, *extra_attention_outputs)``."""
        attn_result = self.attention(self.quant(hidden_states), attention_mask, head_mask)
        attn_hidden = attn_result[0]

        # Feed-forward part: the output module receives both the intermediate
        # activations and the attention output.
        ffn_hidden = self.output(self.intermediate(attn_hidden), attn_hidden)

        return (self.dequant(ffn_hidden),) + attn_result[1:]

# Swap the encoder for the wrapper. The wrapper reuses the existing layers'
# attention / intermediate / output submodules instead of copying them.
model_int8.bert.encoder = QuantizedBertEncoder(model_int8.bert.encoder)
print("Implemented per-token quantization")

## 6. Quantized Layer Normalization

Now, let's implement a quantized version of Layer Normalization:

In [None]:
class QuantizedLayerNorm(nn.LayerNorm):
    """LayerNorm whose input and output pass through quant / dequant stubs.

    Constructor arguments mirror ``nn.LayerNorm`` and are forwarded to it.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input):
        """Normalize ``input`` between the quant and dequant stubs."""
        normalized = super().forward(self.quant(input))
        return self.dequant(normalized)

def replace_layer_norm(module):
    """Recursively swap every ``nn.LayerNorm`` under ``module`` for a QuantizedLayerNorm.

    The replacement inherits the original layer's learned parameters, device,
    dtype and train/eval mode; without that the swap would silently reset the
    normalization to freshly initialised weights. Layers that are already
    ``QuantizedLayerNorm`` are left alone, so the function can be re-run
    safely.

    Args:
        module: Root module, modified in place.
    """
    # Snapshot the children because the loop reassigns attributes on `module`.
    for name, child in list(module.named_children()):
        if isinstance(child, QuantizedLayerNorm):
            continue
        if isinstance(child, nn.LayerNorm):
            replacement = QuantizedLayerNorm(child.normalized_shape, child.eps, child.elementwise_affine)
            # strict=False tolerates extra or missing entries (e.g. a layer
            # created without bias) while still copying weight and bias.
            replacement.load_state_dict(child.state_dict(), strict=False)
            if child.weight is not None:
                replacement.to(device=child.weight.device, dtype=child.weight.dtype)
            replacement.train(child.training)
            setattr(module, name, replacement)
        else:
            replace_layer_norm(child)

# Replace the LayerNorm modules of the quantized model in place.
replace_layer_norm(model_int8)
print("Implemented quantized layer normalization")

## 7. Dynamic Quantization for Variable Sequence Lengths

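Dynamic quantization converts the weights of the selected layers to int8 ahead of time and quantizes activations on the fly during inference, so it needs no calibration pass and copes with inputs of any sequence length: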

In [None]:
def apply_dynamic_quantization(model):
    """Return the result of dynamically quantizing ``model``'s ``nn.Linear`` layers to int8."""
    layer_types = {nn.Linear}
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        qconfig_spec=layer_types,
        dtype=torch.qint8,
    )
    return quantized_model

# Dynamically quantize the nn.Linear layers still present in the PTQ model.
model_dynamic = apply_dynamic_quantization(model_int8)
print("Applied dynamic quantization")

## 8. Optimize Quantized Attention Mechanism

Let's optimize the attention mechanism for quantized inference:

In [None]:
class OptimizedQuantizedAttention(nn.Module):
    """Attention wrapper whose nested linear layers are dynamically quantized.

    Args:
        attention: Attention module to wrap. It is not modified; the wrapper
            holds a quantized copy.
    """

    def __init__(self, attention):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        # Quantize all linear layers in the attention module.
        # quantize_dynamic walks the whole module tree, so nested layers are
        # replaced as well. Assigning via setattr with dotted names from
        # named_modules() never replaced them and mutated the module dict
        # while it was being iterated.
        self.attention = torch.quantization.quantize_dynamic(
            attention, {nn.Linear}, dtype=torch.qint8
        )

    def forward(self, hidden_states, attention_mask=None, head_mask=None, **kwargs):
        """Run the wrapped attention between the quant and dequant stubs.

        Extra keyword arguments (e.g. ``output_attentions``) are forwarded to
        the wrapped module so the wrapper can stand in for it transparently.

        Returns:
            A tuple whose first element is the dequantized attention output,
            followed by any further outputs of the wrapped module.
        """
        hidden_states = self.quant(hidden_states)
        outputs = self.attention(hidden_states, attention_mask, head_mask, **kwargs)
        outputs = (self.dequant(outputs[0]),) + tuple(outputs[1:])
        return outputs

def replace_attention(module):
    """Recursively wrap children whose name contains "attention".

    Matching is by attribute name (case-insensitive). A matched child is
    wrapped in ``OptimizedQuantizedAttention`` and not descended into.
    Children that are already wrapped are skipped, so re-running the function
    (e.g. re-executing the notebook cell) does not stack wrappers.

    Args:
        module: Root module, modified in place.
    """
    # Snapshot the children because the loop reassigns attributes on `module`.
    for name, child in list(module.named_children()):
        if isinstance(child, OptimizedQuantizedAttention):
            continue
        if "attention" in name.lower():
            setattr(module, name, OptimizedQuantizedAttention(child))
        else:
            replace_attention(child)

# Wrap the attention modules of the dynamically quantized model in place.
replace_attention(model_dynamic)
print("Optimized quantized attention mechanism")

## 9. Deploy with Quantization-Aware Inference Framework

For deployment, we'll use TorchScript:

In [None]:
def deploy_model(model):
    """Trace ``model`` with TorchScript using a random ``(1, 384)`` batch of token ids.

    The example input is created on the device the model lives on (CPU when
    the model has no parameters) rather than on a global device, so a
    CPU-only quantized model can be traced on a machine that also has a GPU.

    Args:
        model: Module taking a tensor of token ids as its first argument. It
            is switched to eval mode.

    Returns:
        The traced module.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    example_input = torch.randint(0, 1000, (1, 384), device=model_device)
    with torch.no_grad():
        # strict=False lets the tracer accept dict-like outputs, which it
        # rejects by default.
        # NOTE(review): for Hugging Face models, loading with torchscript=True
        # is the documented route -- confirm which one applies here.
        traced_model = torch.jit.trace(model, example_input, strict=False)
    return traced_model

# Trace the dynamically quantized model and save the TorchScript archive to disk.
deployed_model = deploy_model(model_dynamic)
torch.jit.save(deployed_model, "quantized_bert_qa.pt")
print("Model deployed using TorchScript")

## 10. Performance Benchmarking

Finally, let's benchmark our optimized model:

In [None]:
def benchmark(model, input_shape, num_runs=100):
    """Measure the mean forward-pass latency of ``model`` in milliseconds.

    The random input is created on the model's own device (CPU when the model
    has no parameters). Timing uses a wall clock with explicit CUDA
    synchronisation, so the same code measures CPU-only (e.g. quantized) and
    GPU models and also runs on machines without CUDA.

    Args:
        model: Module taking a tensor of token ids as its first argument. It
            is switched to eval mode.
        input_shape: Shape of the random integer input, e.g. ``(1, 384)``.
        num_runs: Number of timed forward passes; must be positive.

    Returns:
        Average time per forward pass in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    import time  # local import keeps this cell self-contained

    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    model.eval()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    on_cuda = model_device.type == "cuda"
    input_tensor = torch.randint(0, 1000, input_shape, device=model_device)

    with torch.no_grad():
        # Warm-up run
        for _ in range(10):
            _ = model(input_tensor)

        # CUDA kernels run asynchronously: drain the queue before starting
        # and before stopping the clock.
        if on_cuda:
            torch.cuda.synchronize(model_device)

        # Timed runs
        start = time.perf_counter()
        for _ in range(num_runs):
            _ = model(input_tensor)
        if on_cuda:
            torch.cuda.synchronize(model_device)
        elapsed_seconds = time.perf_counter() - start

    elapsed_time = elapsed_seconds * 1000.0 / num_runs
    return elapsed_time

# Time the fine-tuned model, the quantized model and the traced model on the
# same input shape (batch of 1, 384 tokens).
original_time = benchmark(model, (1, 384))
quantized_time = benchmark(model_dynamic, (1, 384))
deployed_time = benchmark(deployed_model, (1, 384))

print(f"Original model inference time: {original_time:.2f} ms")
print(f"Quantized model inference time: {quantized_time:.2f} ms")
print(f"Deployed model inference time: {deployed_time:.2f} ms")
# Speedup compares the original model against the traced, quantized model.
print(f"Speedup: {original_time / deployed_time:.2f}x")