# Optimizing a Large Language Model for Low-Latency Inference

This notebook demonstrates how to optimize a pre-trained large language model for low-latency inference using quantization and other optimization techniques.

### Table of Contents

1. Setup and Imports
2. Load Pre-trained Model
3. Mixed Precision Fine-tuning
4. Post-Training Quantization (PTQ)
5. Efficient INT8 GEMM Operations
6. Quantized Key-Value Caching
7. Dynamic Quantization for Activations
8. Deployment with ONNX Runtime
9. Performance Benchmarking

## 1. Setup and Imports

First, let's import the necessary libraries:

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import onnx
import onnxruntime as ort

# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed PyTorch's default RNG so repeated runs draw the same random numbers
torch.manual_seed(42)

## 2. Load Pre-trained Model

We'll use a pre-trained GPT-2 model for this example:

In [None]:
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The tokenizer ships without a padding token, so reuse the end-of-sequence token.
# The EOS token is already part of the vocabulary, so the vocabulary does not grow
# and the embedding matrix needs no resizing.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    # Keep the model config in sync so padding-aware code paths see the same id.
    model.config.pad_token_id = tokenizer.pad_token_id

print(f"Model loaded: {model_name}")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

## 3. Mixed Precision Fine-tuning

Now, let's implement mixed precision fine-tuning:

In [None]:
def fine_tune(model, train_dataloader, optimizer, epochs=3):
    scaler = torch.amp.GradScaler()
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_dataloader:
            optimizer.zero_grad()
            
            with torch.cuda.amp.autocast():  # autocast for mixed precision
                inputs = batch["input_ids"].to(device)
                labels = inputs.clone()
                outputs = model(inputs, labels=labels)
                loss = outputs.loss
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

# Load dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# The raw wikitext split contains blank separator lines. Tokenized with
# max-length padding they would turn into sequences made only of padding,
# so drop them before tokenizing.
dataset = dataset.filter(lambda example: example["text"].strip() != "")


def tokenize_function(examples):
    """Tokenize a batch of rows, truncating and padding every text to 512 tokens."""
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)


tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Return torch tensors instead of Python lists when rows are read
tokenized_dataset = tokenized_dataset.with_format("torch")

# Causal-LM collator (mlm=False): stacks rows into uniform batches
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Create DataLoader
train_dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Fine-tune the model
fine_tune(model, train_dataloader, optimizer)

## 4. Post-Training Quantization (PTQ)

Now that we have fine-tuned our model, let's apply Post-Training Quantization:

In [None]:
import torch.quantization

def _wrap_linear_layers(module, qconfig, skip_modules, prefix=""):
    """Wrap each ``nn.Linear`` below ``module`` in quant/dequant stubs and tag it with ``qconfig``."""
    for child_name, child in list(module.named_children()):
        qualified_name = f"{prefix}.{child_name}" if prefix else child_name
        if qualified_name in skip_modules:
            continue
        if isinstance(child, nn.Linear):
            # The stubs mark where activations enter and leave the INT8 region,
            # so the surrounding float layers keep receiving float tensors.
            wrapped = nn.Sequential(
                torch.quantization.QuantStub(),
                child,
                torch.quantization.DeQuantStub(),
            )
            wrapped.qconfig = qconfig
            setattr(module, child_name, wrapped)
        else:
            _wrap_linear_layers(child, qconfig, skip_modules, qualified_name)


def apply_ptq(model, calibration_batches=None, max_calibration_batches=32,
              backend='fbgemm', skip_modules=("lm_head",)):
    """Return a statically quantized (INT8) copy of ``model``; ``model`` itself is untouched.

    Only ``nn.Linear`` layers are quantized. Every other module, including the
    embeddings and the modules named in ``skip_modules``, stays in floating
    point. FP16 is deliberately not used for those modules: the rest of the
    model runs in FP32, so half-precision embeddings would feed FP16
    activations into FP32 layers.

    NOTE(review): layers implemented with a custom projection class rather
    than ``nn.Linear`` are left in float by this function - confirm which
    layer types the target model actually uses.

    Args:
        model: Float model to quantize. It is deep-copied, never modified.
        calibration_batches: Iterable of batches providing ``"input_ids"``.
            Defaults to the notebook-level ``train_dataloader``.
        max_calibration_batches: Upper bound on the batches used to collect
            activation statistics; ``None`` consumes the whole iterable.
        backend: Quantization backend name used to select the default qconfig.
        skip_modules: Qualified module names that must remain in float.

    Returns:
        The converted model, in eval mode.
    """
    import copy
    from itertools import islice

    if calibration_batches is None:
        calibration_batches = train_dataloader

    # Work on a private copy so the caller keeps a usable float model.
    working_copy = copy.deepcopy(model)
    working_copy.eval()

    # Attach the qconfig only to the wrapped linear layers; a root-level qconfig
    # would also reach embeddings and normalisation layers.
    qconfig = torch.quantization.get_default_qconfig(backend)
    _wrap_linear_layers(working_copy, qconfig, set(skip_modules))

    # Insert observers that record activation ranges
    model_prepared = torch.quantization.prepare(working_copy, inplace=True)

    # Calibrate on the device the model actually lives on
    first_param = next(model_prepared.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        for batch in islice(calibration_batches, max_calibration_batches):
            model_prepared(batch["input_ids"].to(target_device))

    # Swap the observed modules for their INT8 implementations
    return torch.quantization.convert(model_prepared, inplace=True)

import copy

# Module.cpu() moves a module in place and returns the same object. Quantize a
# copy so the original `model` stays on `device` for the later benchmark.
model_int8 = apply_ptq(copy.deepcopy(model).cpu())  # PTQ requires CPU
print("Model quantized to INT8")

## 5. Implement efficient INT8 GEMM operations for attention mechanisms

For this step, we'll use PyTorch's built-in quantized linear layers:

In [None]:
class QuantizedAttention(nn.Module):
    """Dynamically quantized (INT8) query/key/value projections of an attention module.

    ``attention`` must expose ``query``, ``key`` and ``value`` sub-modules. The
    originals are not modified; quantized copies are stored on this wrapper.

    NOTE(review): ``forward`` returns the three raw projections, not an
    attention output - confirm that callers expect this tuple.
    """

    def __init__(self, attention):
        super().__init__()
        # quantize_dynamic only swaps *children* of the module it is given, so a
        # bare nn.Linear passed directly comes back unquantized. Grouping the
        # projections in a container makes each one a child that gets converted.
        projections = nn.ModuleList([attention.query, attention.key, attention.value])
        quantized = torch.quantization.quantize_dynamic(
            projections, {torch.nn.Linear}, dtype=torch.qint8
        )
        self.query, self.key, self.value = quantized[0], quantized[1], quantized[2]

    def forward(self, hidden_states):
        """Return ``(query, key, value)`` projections of ``hidden_states``."""
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)
        return query_layer, key_layer, value_layer

# Swap the attention module of every transformer block for its quantized wrapper.
# NOTE(review): QuantizedAttention reads `.query`, `.key` and `.value` from the
# module it wraps - confirm the blocks' attention modules expose those names.
transformer_blocks = model_int8.transformer.h
for layer in transformer_blocks:
    quantized_attention = QuantizedAttention(layer.attn)
    layer.attn = quantized_attention

print("Implemented efficient INT8 GEMM operations for attention mechanisms")

## 6. Apply quantized key-value caching for faster autoregressive inference

In [None]:
class CachedQuantizedAttention(QuantizedAttention):
    """Quantized attention projections with a growing key/value cache.

    The first call projects the whole input and stores the resulting keys and
    values. Every later call projects only the last position of its input and
    appends that to the cache along dimension 1. Call ``reset_cache()`` before
    starting a new sequence; otherwise the new sequence is appended to the
    previous one.
    """

    def __init__(self, attention):
        super().__init__(attention)
        self.cached_key = None
        self.cached_value = None

    def reset_cache(self):
        """Forget all cached keys and values so the next call starts a fresh sequence."""
        self.cached_key = None
        self.cached_value = None

    def forward(self, hidden_states):
        """Return ``(query, cached_keys, cached_values)`` after updating the cache.

        The query is computed for every position of ``hidden_states``.
        """
        query_layer = self.query(hidden_states)

        if self.cached_key is None:
            # First call: the full projections seed the cache.
            self.cached_key = self.key(hidden_states)
            self.cached_value = self.value(hidden_states)
        else:
            # Later calls: only the newest position is projected and appended.
            last_hidden_state = hidden_states[:, -1:, :]
            new_key = self.key(last_hidden_state)
            new_value = self.value(last_hidden_state)
            self.cached_key = torch.cat([self.cached_key, new_key], dim=1)
            self.cached_value = torch.cat([self.cached_value, new_value], dim=1)

        return query_layer, self.cached_key, self.cached_value

# Re-wrap every block's attention module with the key/value-caching variant
caching_targets = model_int8.transformer.h
for layer in caching_targets:
    cached_attention = CachedQuantizedAttention(layer.attn)
    layer.attn = cached_attention

print("Applied quantized key-value caching")

## 7. Use dynamic quantization for activations to handle varying sequence lengths

Dynamic quantization is applied at runtime, so we'll implement a wrapper for inference:

In [None]:
import weakref

# Quantized copies keyed weakly by their source model, so an entry disappears
# once the source model is garbage-collected.
_DYNAMIC_QUANT_CACHE = weakref.WeakKeyDictionary()


def dynamic_quantized_inference(model, input_ids, refresh=False):
    """Run ``input_ids`` through a dynamically quantized (INT8 linear) copy of ``model``.

    Quantizing deep-copies the whole model, so the quantized copy is built once
    per source model and reused on later calls instead of being rebuilt for
    every inference.

    Args:
        model: Float model to quantize; it is never modified.
        input_ids: Input passed to the quantized copy.
        refresh: Rebuild the quantized copy first. Use this after ``model``'s
            weights have changed, because a cached copy does not track them.

    Returns:
        Whatever the quantized model returns for ``input_ids``.
    """
    quantized_model = None if refresh else _DYNAMIC_QUANT_CACHE.get(model)
    if quantized_model is None:
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        _DYNAMIC_QUANT_CACHE[model] = quantized_model

    with torch.no_grad():
        return quantized_model(input_ids)

print("Implemented dynamic quantization for activations")

## 8. Deploy using an inference-optimized runtime like ONNX Runtime

For this step, we'll export our model to ONNX format and use ONNX Runtime for inference:

In [None]:
def export_to_onnx(model, input_shape, output_path="quantized_gpt2.onnx", vocab_size=1000):
    """Export ``model`` to an ONNX file, tracing it with a random batch of token ids.

    The batch and sequence dimensions of ``input_ids`` and ``logits`` are
    declared dynamic in the exported graph.

    Args:
        model: Model to export.
        input_shape: ``(batch, sequence)`` shape of the dummy input used for tracing.
        output_path: Destination file for the ONNX graph.
        vocab_size: Exclusive upper bound of the random token ids in the dummy input.
    """
    # Build the dummy input on the model's device. A model without parameters
    # gives no device to follow, so fall back to the CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randint(0, vocab_size, input_shape, device=target_device)

    dynamic_axes = {
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'logits': {0: 'batch_size', 1: 'sequence'},
    }
    torch.onnx.export(model, dummy_input, output_path,
                      input_names=['input_ids'],
                      output_names=['logits'],
                      dynamic_axes=dynamic_axes)
    print("Model exported to ONNX format")

# ONNX Runtime sessions keyed by (absolute path, modification time). Creating a
# session loads the model file, so sessions are reused across calls.
_ONNX_SESSION_CACHE = {}


def onnx_inference(onnx_path, input_ids):
    """Run ``input_ids`` through the ONNX model at ``onnx_path`` and return its first output.

    Args:
        onnx_path: Path of the ONNX file. A ``bytes`` value is handed to ONNX
            Runtime unchanged and is not cached.
        input_ids: Torch tensor (any device) or array accepted by ONNX Runtime.

    Returns:
        The first output of the session as returned by ONNX Runtime.
    """
    import os

    if isinstance(onnx_path, (bytes, bytearray)):
        session = ort.InferenceSession(onnx_path)
    else:
        resolved = os.path.abspath(os.fspath(onnx_path))
        # The modification time is part of the key so that re-exporting to the
        # same path is picked up instead of serving a stale session.
        cache_key = (resolved, os.stat(resolved).st_mtime_ns)
        session = _ONNX_SESSION_CACHE.get(cache_key)
        if session is None:
            for stale_key in [key for key in _ONNX_SESSION_CACHE if key[0] == resolved]:
                del _ONNX_SESSION_CACHE[stale_key]
            session = ort.InferenceSession(onnx_path)
            _ONNX_SESSION_CACHE[cache_key] = session

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    if isinstance(input_ids, torch.Tensor):
        # .numpy() only works on CPU tensors that do not track gradients.
        input_ids = input_ids.detach().cpu().numpy()
    return session.run([output_name], {input_name: input_ids})[0]

# Write the quantized model to disk as an ONNX graph, traced with a 1 x 512 batch
export_shape = (1, 512)
export_to_onnx(model_int8, export_shape)

# Feed a 1 x 128 batch of random token ids through ONNX Runtime
input_ids = torch.randint(0, 1000, (1, 128))
onnx_output = onnx_inference("quantized_gpt2.onnx", input_ids)
print("Performed inference using ONNX Runtime")

## 9. Performance Benchmarking

Finally, let's benchmark our optimized model against the original:

In [None]:
import time

def benchmark(model, input_ids, runs=100, warmup=3):
    """Return the average seconds one ``model(input_ids)`` call takes over ``runs`` calls.

    Args:
        model: Any callable taking ``input_ids``.
        input_ids: Input forwarded unchanged to every call.
        runs: Number of timed calls; must be positive.
        warmup: Number of untimed calls made first, so one-off start-up costs
            do not count towards the average.

    Raises:
        ValueError: If ``runs`` is not positive.
    """
    if runs <= 0:
        raise ValueError("runs must be a positive integer")

    # CUDA work is queued asynchronously. Without a synchronisation point the
    # clock would stop before the queued work has finished.
    on_cuda = torch.cuda.is_available() and getattr(input_ids, "is_cuda", False)

    with torch.no_grad():
        for _ in range(warmup):
            model(input_ids)
        if on_cuda:
            torch.cuda.synchronize()

        # perf_counter is monotonic and high-resolution; time.time() can jump.
        start_time = time.perf_counter()
        for _ in range(runs):
            model(input_ids)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

    return elapsed / runs

input_ids = torch.randint(0, 1000, (1, 512)).to(device)

# Time each model on the device it lives on, in eval mode:
# - the original model may have been moved to the CPU for quantization and was
#   left in training mode by fine-tuning, so move it back and switch to eval;
# - the quantized model was built on the CPU, so it gets a CPU copy of the input.
original_time = benchmark(model.to(device).eval(), input_ids)
quantized_time = benchmark(model_int8.eval(), input_ids.cpu())
onnx_time = benchmark(lambda x: onnx_inference("quantized_gpt2.onnx", x), input_ids.cpu())

print(f"Original model average inference time: {original_time:.4f} seconds")
print(f"Quantized model average inference time: {quantized_time:.4f} seconds")
print(f"ONNX Runtime average inference time: {onnx_time:.4f} seconds")
print(f"Speedup: {original_time / onnx_time:.2f}x")