# Efficient Fine-Tuning and Inference Optimization for `TinyBERT`


Welcome to our final project notebook! In this work, we explore the recent `bitsandbytes` integration, whose 4-bit quantization techniques (by XXX) enable memory-efficient inference and training of large language models with minimal loss in performance.

In this notebook, we demonstrate how to load `huawei-noah/TinyBERT_General_4L_312D` in 4-bit precision and fine-tune it using Google Colab and the Hugging Face 🤗 PEFT library. Let’s dive into democratizing LLM inference and training together!



This notebook is adapted from [bnb-4bit-integration](https://colab.research.google.com/drive/1ge2F1QSK8Q7h0hn3YKuBCOAS0bK8E0wf?usp=sharing) provided by Hugging Face. It demonstrates techniques for efficient 4-bit quantization to optimize LLM inference and training.




To get started, install all the dependencies:

In [None]:
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
!pip install evaluate
!pip install xformers flash-attn

First, let's load the model we are going to use: `huawei-noah/TinyBERT_General_4L_312D`. Note that the model itself is around 54.74 MB in full precision.
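
As a rough check on that size, the parameter count can be derived from the architecture alone. The sketch below assumes the standard BERT layout and the values we believe this checkpoint uses (vocabulary 30522, hidden size 312, 4 layers, 512 positions, 2 token types, and an intermediate size of 1200). The intermediate size in particular is an assumption that should be verified against the model's `config.json`, and `bert_param_count` is a hypothetical helper.

```python
def bert_param_count(vocab=30522, hidden=312, layers=4,
                     intermediate=1200, max_pos=512, type_vocab=2):
    """Count parameters of a BERT encoder with pooler (no task head)."""
    layer_norm = 2 * hidden  # weight + bias
    embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
    # Query, key, value, and output projections, each with a bias
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    feed_forward = ((hidden * intermediate + intermediate)
                    + (intermediate * hidden + hidden) + layer_norm)
    pooler = hidden * hidden + hidden
    return embeddings + layers * (attention + feed_forward) + pooler


n_params = bert_param_count()
size_mib = n_params * 4 / 1024**2  # 4 bytes per float32 parameter
print(f"{n_params:,} parameters, about {size_mib:.2f} MiB in float32")
```

Under these assumptions the estimate comes out at roughly 14.35 million parameters, which is consistent with the size quoted above.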

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForSequenceClassification

model_id = "huawei-noah/TinyBERT_General_4L_312D"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
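
To give a feel for what 4-bit loading does to the weights, here is a deliberately simplified sketch of blockwise absmax quantization in plain Python. It is not the `bitsandbytes` implementation: NF4 uses a non-uniform 16-value codebook and packed storage, whereas this illustration assumes 16 evenly spaced levels. The helper names are hypothetical.

```python
def quantize_block(weights, levels=16):
    """Map floats to integer codes in [0, levels - 1] using one absmax scale."""
    absmax = max(abs(w) for w in weights) or 1.0  # avoid division by zero
    half = (levels - 1) / 2                       # 7.5 for 16 levels
    codes = [round((w / absmax) * half + half) for w in weights]
    return codes, absmax


def dequantize_block(codes, absmax, levels=16):
    """Invert quantize_block up to rounding error."""
    half = (levels - 1) / 2
    return [((c - half) / half) * absmax for c in codes]


block = [0.12, -0.50, 0.33, 0.0, -0.07, 0.49]
codes, scale = quantize_block(block)
restored = dequantize_block(codes, scale)
max_err = max(abs(a - b) for a, b in zip(block, restored))
print(codes, f"max abs error = {max_err:.4f}")
```

Each block stores one scale plus a 4-bit code per weight, and the reconstruction error is bounded by half a quantization step.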


Next, we prepare the quantized model for training with the `prepare_model_for_kbit_training` function from PEFT.

In [None]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [None]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "key", "value"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_CLS"
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

trainable params: 120434 || all params: 12146284 || trainable%: 0.9915295904492271
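
The trainable-parameter count above can be reproduced by hand. The sketch assumes a hidden size of 312 and 4 encoder layers (as the model name suggests), LoRA matrices of shape 312×16 and 16×312 on the query, key, and value projections, and a trainable 312→2 classification head. The last point is our reading of how PEFT treats `SEQ_CLS` heads, so treat it as an assumption.

```python
hidden, rank, num_layers = 312, 16, 4
targets_per_layer = 3  # query, key, value
num_labels = 2         # MRPC is a binary task

lora_per_module = hidden * rank + rank * hidden  # A and B matrices
lora_total = num_layers * targets_per_layer * lora_per_module
classifier = hidden * num_labels + num_labels    # weight + bias

trainable = lora_total + classifier
print(trainable)  # 120434, matching the count reported above
```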


Let's load the GLUE dataset, specifically the MRPC (Microsoft Research Paraphrase Corpus), to fine-tune our model on paraphrase detection.

In [None]:
# 4. Load GLUE dataset (MRPC task)
from datasets import load_dataset  # Import the load_dataset function
import evaluate


task = "mrpc"
dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)


In [None]:
import numpy as np
import transformers

# 4. Define compute_metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)  # Convert logits to predicted class
    return metric.compute(predictions=predictions, references=labels)
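
For intuition, the two numbers reported for MRPC (accuracy, and F1 on the positive class) can be reproduced without any library. The following is a plain-Python sketch on made-up logits. The helper `accuracy_and_f1` is hypothetical and is not the `evaluate` implementation.

```python
def accuracy_and_f1(logits, labels):
    """Argmax the logits, then score accuracy and F1 for class 1."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    pairs = list(zip(preds, labels))
    correct = sum(p == y for p, y in pairs)
    tp = sum(p == 1 and y == 1 for p, y in pairs)
    fp = sum(p == 1 and y == 0 for p, y in pairs)
    fn = sum(p == 0 and y == 1 for p, y in pairs)
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return {"accuracy": correct / len(labels), "f1": f1}


# Made-up logits for four sentence pairs: [not paraphrase, paraphrase]
logits = [[0.2, 1.1], [1.5, -0.3], [0.1, 0.4], [0.9, 0.8]]
labels = [1, 0, 0, 1]
scores = accuracy_and_f1(logits, labels)
print(scores)
```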

Set up W&B sweeps
=================

In [None]:
import wandb
wandb.login()

In [None]:
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'eval/accuracy',  # Key under which the Trainer logs validation accuracy to W&B
        'goal': 'maximize'
    },
    'parameters': {
        'optimizer': {
            'values': ['adamw_torch', 'adafactor', 'adamw_hf', 'adamw_8bit', 'sgd']
        },
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-3
        },
        'lr_scheduler': {
            'values': [
                'linear',
                'cosine',
                'constant',
                'constant_with_warmup',
                'polynomial'
            ]
        },
        'weight_decay': {
            'values': [0.0, 0.01, 0.001, 0.1]
        },
        'warmup_ratio': {
            'values': [0.05, 0.1, 0.15, 0.2]
        },
        'train_batch_size': {
            'values': [8, 16, 32, 64, 128]
        },
        'gradient_accumulation_steps': {
            'values': [2, 4, 8, 32, 64]
        },
        'lora_r': {
            'values': [8, 16, 32, 64, 128]
        },
        'lora_alpha': {
            'values': [4, 8, 16, 32, 64]
        },
        'lora_dropout': {
            'values': [0.05, 0.1, 0.2]
        }
    }
}
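
The `log_uniform_values` distribution used for the learning rate spreads samples evenly across orders of magnitude between `min` and `max`. A minimal sketch of that idea with the standard library follows. It is our own illustration of the sampling rule rather than W&B's code, and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random


def sample_log_uniform(low, high, rng=random):
    """Draw x between low and high such that log(x) is uniformly distributed."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the sketch is repeatable
samples = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(1000)]
below_1e4 = sum(s < 1e-4 for s in samples) / len(samples)
print(f"share of samples below 1e-4: {below_1e4:.2f}")
```

Roughly half of the samples fall below 1e-4, the geometric midpoint of the range, whereas a plain uniform draw would put about 9% there.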

In [None]:
sweep_id = wandb.sweep(sweep_config, project="TinyBert 101")

Create sweep with ID: pbit4fos
Sweep URL: https://wandb.ai/garima440-new-york-university/TinyBert%20101/sweeps/pbit4fos


The cell below defines the preprocessing and the training function used by the sweep. For the sake of the demo, we ran the sweep for a single run (`count=1`).

In [None]:
# Set a padding token for the tokenizer
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # Add a padding token

# 3. Preprocessing function for tokenization
def preprocess_function(examples):
    return tokenizer(
        examples["sentence1"],
        examples["sentence2"],
        truncation=True,
        padding="max_length",  # Ensure uniform input size
        max_length=512,       # Typical BERT max length
    )

# 4. Tokenize dataset
encoded_dataset = dataset.map(preprocess_function, batched=True)

# 5. Data collator
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config

        # NOTE: the model and its LoRA adapter are created once above, so the
        # lora_r, lora_alpha, and lora_dropout values sampled by the sweep are
        # not applied here, and successive runs would keep training the same model.

        # 6. Define Trainer with TrainingArguments
        trainer = transformers.Trainer(
            model=model,
            train_dataset=encoded_dataset["train"],
            eval_dataset=encoded_dataset["validation"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            args=transformers.TrainingArguments(
                per_device_train_batch_size=config.train_batch_size,  # BERT can handle larger batch sizes
                gradient_accumulation_steps=config.gradient_accumulation_steps,  # Adjust if GPU memory is limited
                warmup_ratio=config.warmup_ratio,
                max_steps=300,
                learning_rate=config.learning_rate,
                fp16=True,  # Enable mixed-precision if supported by your hardware
                logging_steps=50,
                evaluation_strategy="steps",  # Evaluate periodically
                output_dir="./outputs",
                save_steps=100,
                save_total_limit=2,  # Keep only the latest 2 checkpoints
                optim=config.optimizer,
                weight_decay=config.weight_decay,
                lr_scheduler_type=config.lr_scheduler,
            ),
            compute_metrics=compute_metrics,
        )

        # 7. Disable caching for training
        model.config.use_cache = False

        # 8. Train the model
        trainer.train()



In [None]:
wandb.agent(sweep_id, train, count=1)

wandb: Agent Starting Run: m65nty3u with config:
wandb:     gradient_accumulation_steps: 2
wandb:     learning_rate: 0.0003916496770060197
wandb:     lora_alpha: 8
wandb:     lora_dropout: 0.2
wandb:     lora_r: 16
wandb:     lr_scheduler: polynomial
wandb:     optimizer: adafactor
wandb:     train_batch_size: 128
wandb:     warmup_ratio: 0.2
wandb:     weight_decay: 0


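| Step | Training Loss | Validation Loss | Accuracy | F1 |
|------|---------------|-----------------|----------|----|
| 50 | 0.6618 | 0.6526 | 0.683824 | 0.812227 |
| 100 | 0.6024 | 0.57295 | 0.681373 | 0.810496 |
| 150 | 0.5324 | 0.499376 | 0.789216 | 0.857616 |
| 200 | 0.4823 | 0.458111 | 0.813725 | 0.87541 |
| 250 | 0.4597 | 0.434768 | 0.821078 | 0.877311 |
| 300 | 0.4397 | 0.430333 | 0.823529 | 0.879195 |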



Run history:

| Metric | History |
|--------|---------|
| eval/accuracy | ▁▁▆███ |
| eval/f1 | ▁▁▆███ |
| eval/loss | █▅▃▂▁▁ |
| eval/runtime | ▂▂█▆▂▁ |
| eval/samples_per_second | ▇▇▁▃▇█ |
| eval/steps_per_second | ▇▇▁▃▇█ |
| train/epoch | ▁▁▂▂▄▄▅▅▇▇███ |
| train/global_step | ▁▁▂▂▄▄▅▅▇▇███ |
| train/grad_norm | ▃▁▁█▃▆ |
| train/learning_rate | ██▆▅▃▁ |

Run summary:

| Metric | Value |
|--------|-------|
| eval/accuracy | 0.82353 |
| eval/f1 | 0.87919 |
| eval/loss | 0.43033 |
| eval/runtime | 2.234 |
| eval/samples_per_second | 182.635 |
| eval/steps_per_second | 22.829 |
| total_flos | 1079050000465920.0 |
| train/epoch | 20.0 |
| train/global_step | 300.0 |
| train/grad_norm | 0.51951 |
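
The `train/epoch` value of 20 follows from the sampled configuration. A quick check, assuming the 3,668-example MRPC training split, a single GPU, and that a trailing partial accumulation group counts as one optimizer step (this matches the log here, but the rounding rule may differ between `transformers` versions):

```python
import math

train_examples = 3668   # MRPC training split
per_device_batch = 128  # train_batch_size sampled by the sweep
grad_accum = 2          # gradient_accumulation_steps sampled by the sweep
max_steps = 300

effective_batch = per_device_batch * grad_accum
batches_per_epoch = math.ceil(train_examples / per_device_batch)
updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
epochs = max_steps / updates_per_epoch

print(effective_batch, batches_per_epoch, updates_per_epoch, epochs)
# 256 29 15 20.0
```

With only 15 optimizer steps per epoch, the fixed budget of 300 steps amounts to many passes over a small dataset.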


Running inference on fine-tuned and optimized model
===================================================

In [None]:
# run this cell if you load the model files as zip
!unzip fine-tuned-model.zip -d fine-tuned-model

Archive:  fine-tuned-model.zip
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/special_tokens_map.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/README.md  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/scheduler.pt  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/adapter_model.safetensors  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/vocab.txt  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/adapter_config.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/tokenizer_config.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/rng_state.pth  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/tokenizer.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/training_args.bin  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/trainer_state.json  
  inflating: fine-tuned-model/model_files_HPML_accuracy_82/optimizer.pt  


Load the fine-tuned model and move it to CUDA
---------------------------------------------

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import time

# Path to the saved model files; adjust it to match the directory extracted above
model_dir = "fine-tuned-model/model-files"

# Load the tokenizer and the fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
model.eval()  # Set to evaluation mode


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): lora.Linear(
                (base_layer): Linear(in_features=312, out_features=312, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=312, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_featur

In [None]:
model.to('cuda')


Inference Profiling using Flash Attention ONLY
---------------------------------------------

The following code imports the libraries needed for model inference: PyTorch, `transformers` for tokenization and model loading, `flash_attn` for efficient attention, and `datasets` for dataset handling. The `prepare_batch_inputs` function tokenizes pairs of input texts, pads and truncates them, and moves them to the same device as the model.

In [None]:
import torch
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from flash_attn import flash_attn_func
import torch.nn.functional as F
import numpy as np

from datasets import Dataset

# Device used for inference; the model was moved to CUDA above
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def prepare_batch_inputs(texts1, texts2):
    """
    Prepare a batch of input tokens and attention masks for sentence pairs

    :param texts1: List of first sentences
    :param texts2: List of second sentences
    :return: Tokenized batch inputs on the inference device
    """
    inputs = tokenizer(
        texts1, texts2,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )

    # Move inputs to the same device as the model
    return {k: v.to(device) for k, v in inputs.items()}

The following function is our custom Flash Attention step. It applies one extra self-attention pass, computed with `flash_attn_func`, to the model's final hidden states: it reshapes the tensor into 12 heads, uses the same tensor as query, key, and value (without learned projections), and restores the original shape. The data is cast to float16 because the Flash Attention kernels require half precision.

In [None]:
def apply_flash_attention(hidden_states):
    """
    Apply Flash Attention to the input batch

    :param hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_size)
    :return: Attention output with the same shape and dtype as the input
    """
    batch_size, seq_len, hidden_size = hidden_states.shape  # Get the original shape
    orig_dtype = hidden_states.dtype  # Remember the dtype so it can be restored
    num_heads = 12  # Hard-coded head count; head_dim = 312 // 12 = 26
    head_dim = hidden_size // num_heads

    # Reshape for Flash Attention (batch_size, seq_len, num_heads, head_dim)
    hidden_states = hidden_states.reshape(batch_size, seq_len, num_heads, head_dim)

    # Cast hidden_states to float16 before applying Flash Attention
    hidden_states = hidden_states.type(torch.float16)

    # Apply Flash Attention with the same tensor as query, key, and value
    q, k, v = hidden_states, hidden_states, hidden_states
    attn_output = flash_attn_func(q, k, v, dropout_p=0.0)

    # Reshape back to the original shape (batch_size, seq_len, hidden_size)
    attn_output = attn_output.reshape(batch_size, seq_len, hidden_size)

    # Cast attn_output back to the dtype the caller passed in
    # (the original code compared against the already-cast tensor, a no-op)
    attn_output = attn_output.type(orig_dtype)

    return attn_output
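
In equation form, and assuming `flash_attn_func` is left at its default softmax scale of $1/\sqrt{d}$ with no causal mask (our reading of the call above), each head of the function computes plain scaled dot-product self-attention, with the reshaped hidden states $X_h \in \mathbb{R}^{n \times d}$ standing in for queries, keys, and values:

```latex
\mathrm{Attn}(X_h) = \mathrm{softmax}\!\left(\frac{X_h X_h^{\top}}{\sqrt{d}}\right) X_h,
\qquad d = \frac{312}{12} = 26
```

Here $n$ is the padded sequence length of the batch. Flash Attention evaluates the same expression in tiles without materializing the full $n \times n$ score matrix, so the result should agree with a standard implementation up to float16 rounding.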

The following function handles batch inference on a dataset, processing each batch through our model to get predictions. It tracks inference time per batch and calculates performance metrics such as latency and sample throughput. For each batch, the function extracts text pairs, prepares inputs, applies `Flash Attention` on hidden states, and computes softmax probabilities. After processing, it returns the predictions and a dictionary with performance statistics.

In [None]:
def batch_inference(dataset, batch_size):
    inference_times = []
    all_predictions = []

    # Iterate through dataset in batches
    for i in range(0, len(dataset), batch_size):
        # Get the batch (select 'sentence1' and 'sentence2' columns)
        batch = dataset[i : i + batch_size]

        # Extract texts from the batch
        batch_texts1 = batch['sentence1']
        batch_texts2 = batch['sentence2']

        inputs = prepare_batch_inputs(batch_texts1, batch_texts2)

        torch.cuda.synchronize()
        start_time = time.time()

        with torch.no_grad():
            outputs = model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                output_hidden_states=True,
            )

            hidden_states = outputs.hidden_states[-1]
            attention_output = apply_flash_attention(hidden_states)
            logits = model.classifier(attention_output.mean(dim=1).type(torch.float32))
            probabilities = F.softmax(logits, dim=-1)

        torch.cuda.synchronize()
        end_time = time.time()
        batch_inference_time = end_time - start_time
        inference_times.append(batch_inference_time)

        all_predictions.extend(probabilities.cpu().numpy())

    performance_metrics = {
        'total_samples': len(dataset),
        'batch_size': batch_size,
        'device': str(device),
        'inference_times': inference_times,
        'avg_batch_latency_ms': np.mean(inference_times) * 1000,
        'std_batch_latency_ms': np.std(inference_times) * 1000,
        'avg_sample_latency_ms': (np.mean(inference_times) * 1000) / batch_size
    }

    return all_predictions, performance_metrics
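
The summary statistics at the end of `batch_inference` can be tried out in isolation. The sketch below uses only the standard library and a handful of made-up per-batch timings. It also derives throughput, which the function above does not report, and the helper `summarize_latency` is hypothetical.

```python
import statistics


def summarize_latency(times_s, batch_size):
    """Summarize per-batch wall-clock times (seconds) as latency and throughput."""
    mean_s = statistics.fmean(times_s)
    return {
        "avg_batch_latency_ms": mean_s * 1000,
        # Population standard deviation, the same default as np.std
        "std_batch_latency_ms": statistics.pstdev(times_s) * 1000,
        "avg_sample_latency_ms": mean_s * 1000 / batch_size,
        "samples_per_second": batch_size / mean_s,
    }


stats = summarize_latency([0.008, 0.007, 0.006, 0.007], batch_size=32)
for name, value in stats.items():
    print(f"{name}: {value:.3f}")
```

Like the original function, this divides by the nominal batch size, so the per-sample figure is slightly optimistic whenever the last batch is smaller than the rest.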


In [None]:
# Select test set
test_dataset = dataset['test']

# Configuration
batch_sizes = [32, 64, 128, 512, 1024]

for batch in batch_sizes:

    # Perform batch inference on test set
    all_predictions, performance_metrics = batch_inference(test_dataset, batch)

    # Print performance metrics
    print("\n--- Batch Inference Performance Metrics ---")
    for name, value in performance_metrics.items():  # `name` avoids shadowing the global `metric`
        print(f"{name}: {value}")



--- Batch Inference Performance Metrics ---
total_samples: 1725
batch_size: 32
device: cuda
inference_times: [0.00932002067565918, 0.006934165954589844, 0.0069463253021240234, 0.00872182846069336, 0.006987094879150391, 0.00688624382019043, 0.008557319641113281, 0.0065784454345703125, 0.006702899932861328, 0.006498098373413086, 0.009231328964233398, 0.0065708160400390625, 0.007443904876708984, 0.006562471389770508, 0.0064852237701416016, 0.006543159484863281, 0.006667375564575195, 0.0066373348236083984, 0.006604194641113281, 0.0065975189208984375, 0.006800413131713867, 0.0066988468170166016, 0.006528615951538086, 0.006893157958984375, 0.00730586051940918, 0.0065572261810302734, 0.006543636322021484, 0.00662684440612793, 0.008111000061035156, 0.007271766662597656, 0.0065920352935791016, 0.006590843200683594, 0.006611347198486328, 0.006510734558105469, 0.006804704666137695, 0.006814241409301758, 0.008165121078491211, 0.008106708526611328, 0.006478548049926758, 0.006487369537353516, 0.006

In [None]:
from torch.profiler import profile, ProfilerActivity, record_function

def profile_inference_latency(dataset, batch_size):
    """
    Use PyTorch Profiler to measure inference latency
    """
    # Take a small subset of the dataset for profiling
    batch = dataset[:batch_size]
    batch_texts1 = batch['sentence1']
    batch_texts2 = batch['sentence2']

    inputs = prepare_batch_inputs(batch_texts1, batch_texts2)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs"),
                 record_shapes=True, with_stack=True) as prof:
        with record_function("model_inference"):
            with torch.no_grad():
                outputs = model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    output_hidden_states=True
                )

                hidden_states = outputs.hidden_states[-1]
                attention_output = apply_flash_attention(hidden_states)
                logits = model.classifier(attention_output.mean(dim=1).type(torch.float32))
                probabilities = F.softmax(logits, dim=-1)

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Inference Benchmarking with PyTorch Profiler
--------------------------------------------------




In [None]:
# Perform PyTorch profiler on test set
# Configuration
batch_sizes = [32, 64, 128, 512, 1024]

for batch in batch_sizes:
    print(f"\n--- PyTorch Profiler with batch size {batch} ---")
    profile_inference_latency(test_dataset, batch)


--- PyTorch Profiler with batch size 32 ---
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.00%       0.000us         0.00%       0.000us       0.000us      16.491ms       374.77%      16.491ms      16.491ms             1  
                                        model_inference        44.82%       9.166ms        99.94%      20.438ms      20.438ms       0.000us         0.00%       4.

Inference profiling using Flash Attention + KV caching
-------------------------------------------------------

This time we run inference with `KV caching` combined with `flash attention` to evaluate whether reusing cached keys and values makes the attention computation more efficient.
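
To see where KV caching can pay off, it helps to count key/value projections. The toy count below (plain Python, hypothetical helper names) contrasts token-by-token generation, where earlier keys and values can be reused, with a single forward pass over a complete input, as in MRPC classification.

```python
def kv_projections_generation(prompt_len, new_tokens, use_cache):
    """Count per-layer K/V token projections while generating autoregressively."""
    total = 0
    for step in range(new_tokens):
        seq_len = prompt_len + step
        if use_cache:
            # Only tokens not seen before are projected
            total += prompt_len if step == 0 else 1
        else:
            # The whole prefix is recomputed at every step
            total += seq_len
    return total


def kv_projections_classification(seq_len):
    """A classifier reads the full input once, so nothing can be reused."""
    return seq_len


without_cache = kv_projections_generation(64, 32, use_cache=False)
with_cache = kv_projections_generation(64, 32, use_cache=True)
single_pass = kv_projections_classification(64)
print(without_cache, with_cache, single_pass)  # 2544 95 64
```

The saving comes entirely from repeated decoding steps. A single-pass classifier has no earlier step to reuse.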

In [None]:
import torch
import time
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from flash_attn import flash_attn_func
from datasets import Dataset

class InferenceModel:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.past_key_values = None

    def prepare_batch_inputs(self, texts1, texts2):
        """
        Prepare batch of input tokens and attention masks.
        :param texts1: List of input texts
        :param texts2: List of input texts
        :return: Tokenized batch inputs
        """
        inputs = self.tokenizer(
            texts1, texts2,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def apply_flash_attention(self, hidden_states, past_key_values=None):
        """
        Apply Flash Attention to the input batch with KV caching.

        :param hidden_states: Hidden states tensor (current input)
        :param past_key_values: Cached keys and values from previous steps
        :return: Attention output, updated past_key_values
        """
        # Reshape hidden_states to (batch_size, sequence_length, num_heads, head_dim)
        batch_size, seq_len, hidden_size = hidden_states.shape  # Get the original shape
        orig_dtype = hidden_states.dtype  # Remember the dtype so it can be restored
        num_heads = 12
        head_dim = hidden_size // num_heads

        # Reshape for Flash Attention (batch_size, seq_len, num_heads, head_dim)
        hidden_states = hidden_states.reshape(batch_size, seq_len, num_heads, head_dim)

        # Cast hidden_states to float16 before applying Flash Attention
        hidden_states = hidden_states.type(torch.float16)

        # Initialize past_key_values if they are not provided
        if past_key_values is None:
            past_key_values = (None, None)  # Initialize empty cache

        # Use past keys and values, if available, for efficient computation
        k, v = past_key_values

        # If no past keys/values, use the current hidden states for k and v
        if k is None or v is None:
            k, v = hidden_states, hidden_states
        else:
            # Concatenate new keys/values with the past ones (for autoregressive tasks)
            k = torch.cat((k, hidden_states), dim=1)  # Concatenate along the sequence dimension
            v = torch.cat((v, hidden_states), dim=1)

        # Apply Flash Attention with cached keys and values
        q = hidden_states  # Query is always the current hidden states
        attn_output = flash_attn_func(q, k, v, dropout_p=0.0)

        # Reshape back to the original shape (batch_size, seq_len, hidden_size)
        attn_output = attn_output.reshape(batch_size, seq_len, hidden_size)

        # Cast attn_output back to the dtype the caller passed in
        # (the original code compared against the already-cast tensor, a no-op)
        attn_output = attn_output.type(orig_dtype)

        # Return attention output and updated cached keys/values
        return attn_output, (k, v)


    def batch_inference(self, dataset, batch_size):
        inference_times = []
        all_predictions = []

        for i in range(0, len(dataset), batch_size):
            batch = dataset[i: i + batch_size]
            batch_texts1 = batch['sentence1']
            batch_texts2 = batch['sentence2']

            inputs = self.prepare_batch_inputs(batch_texts1, batch_texts2)

            torch.cuda.synchronize()
            start_time = time.time()

            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    output_hidden_states=True,
                )

                hidden_states = outputs.hidden_states[-1]
                # Apply Flash Attention; no cached keys/values are passed here,
                # so every batch is processed independently
                attention_output, _ = self.apply_flash_attention(hidden_states)
                logits = self.model.classifier(attention_output.mean(dim=1).type(torch.float32))
                probabilities = F.softmax(logits, dim=-1)

            torch.cuda.synchronize()
            end_time = time.time()
            batch_inference_time = end_time - start_time
            inference_times.append(batch_inference_time)

            all_predictions.extend(probabilities.cpu().numpy())

        performance_metrics = {
            'total_samples': len(dataset),
            'batch_size': batch_size,
            'device': str(self.device),
            'inference_times': inference_times,
            'avg_batch_latency_ms': np.mean(inference_times) * 1000,
            'std_batch_latency_ms': np.std(inference_times) * 1000,
            'avg_sample_latency_ms': (np.mean(inference_times) * 1000) / batch_size
        }

        return all_predictions, performance_metrics


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.config.use_cache = True
inference_model = InferenceModel(model=model, tokenizer=tokenizer, device=device)

batch_sizes = [32, 64, 128, 512, 1024]

for batch in batch_sizes:
    all_predictions, performance_metrics = inference_model.batch_inference(dataset["test"], batch_size=batch)

    for name, value in performance_metrics.items():  # `name` avoids shadowing the global `metric`
        print(f"{name}: {value}")
    print("\n---\n")


total_samples: 1725
batch_size: 32
device: cuda
inference_times: [0.008347272872924805, 0.006960391998291016, 0.006997823715209961, 0.006956577301025391, 0.006833553314208984, 0.0068645477294921875, 0.007044076919555664, 0.006682395935058594, 0.006621122360229492, 0.0066111087799072266, 0.006586551666259766, 0.006570339202880859, 0.006794929504394531, 0.006838560104370117, 0.006660938262939453, 0.006474494934082031, 0.006488323211669922, 0.006643533706665039, 0.006551504135131836, 0.006559848785400391, 0.00655364990234375, 0.006518840789794922, 0.010175466537475586, 0.006676435470581055, 0.006707906723022461, 0.006567955017089844, 0.0064737796783447266, 0.006512165069580078, 0.007138490676879883, 0.006635904312133789, 0.006411552429199219, 0.0066967010498046875, 0.0069310665130615234, 0.006721973419189453, 0.006443977355957031, 0.006539344787597656, 0.0064661502838134766, 0.006536960601806641, 0.00652003288269043, 0.006426811218261719, 0.0066106319427490234, 0.0066525936126708984, 0.00

Conclusion
===========
We improved `TinyBERT`'s performance on the `GLUE` "MRPC" task by training a `LoRA adapter` on top of a 4-bit quantized base model, which raised accuracy from **~40% to 82%** on the validation set. To optimize inference, we experimented with `Flash Attention` and `KV caching`. Flash Attention led to significant improvements in inference efficiency. Combining it with KV caching brought no notable additional gain, likely because KV caching pays off in autoregressive generation, where keys and values from earlier steps are reused, whereas MRPC is a classification task in which each input is processed in a single forward pass.