# Optimizing for Inference Speed

This notebook demonstrates techniques for optimizing a trained language model for faster inference.

## 1. Setup and Imports

First, let's import the necessary libraries:

In [3]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import onnx
import onnxruntime as ort

# Report the library versions and whether a CUDA device is visible, so the
# timings printed later in the notebook can be tied to a concrete environment.
print(f"transformers version: {transformers.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

transformers version: 4.45.1
PyTorch version: 2.4.1+cu121
CUDA available: False


## 2. Load Pre-trained Model

We'll use a GPT-2 model for this example:

In [None]:
# Fetch the GPT-2 checkpoint together with its matching tokenizer.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Run on the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

## 3. Apply Post-Training Quantization to INT8

We'll use PyTorch's dynamic quantization:

In [None]:
import copy

# PyTorch's dynamically quantized kernels run on the CPU only, so quantize a
# CPU copy and leave `model` itself untouched on `device`. The copy is
# converted in place to avoid holding a third set of weights.
#
# NOTE: Hugging Face's GPT-2 implements its attention and MLP projections with
# its own `Conv1D` module rather than `nn.Linear`, so only genuine `nn.Linear`
# modules (such as the LM head) are converted by this call.
quantized_model = torch.quantization.quantize_dynamic(
    copy.deepcopy(model).to("cpu"),
    {torch.nn.Linear},
    dtype=torch.qint8,
    inplace=True,
)

print("Model quantized to INT8")

## 4. Use Knowledge Distillation (Simulated)

In practice, knowledge distillation involves training a smaller model to mimic a larger one. For this notebook, we'll simulate a distilled model by using a smaller version of GPT-2:

In [None]:
# DistilGPT-2 stands in for the outcome of a real distillation run; load it
# and place it on the same device as the other models in one step.
distilled_model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
print("Loaded distilled model (DistilGPT-2)")

## 5. Implement Efficient Attention Mechanism
For this example, we'll implement attention with PyTorch's `scaled_dot_product_attention` function, which dispatches to a fused Flash Attention kernel when that backend is enabled (`torch.backends.cuda.enable_flash_sdp(True)`) and the inputs are eligible for it.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel
import torch.backends.cuda
from transformers import GPT2Config, GPT2Model

class FlashAttention(nn.Module):
    """GPT-2 style causal self-attention built on ``F.scaled_dot_product_attention``.

    The module keeps the attribute names (``c_attn``, ``c_proj``) and the
    ``forward`` signature of the attention layer it replaces, so it can be
    swapped into a GPT-2 block. Which fused kernel actually runs is decided by
    PyTorch from the globally enabled SDPA backends
    (``torch.backends.cuda.enable_*_sdp``) and the inputs; this class does not
    force a particular backend.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads``,
            ``attn_pdrop`` and ``resid_pdrop``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.embed_dim}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # Fused query/key/value projection followed by the output projection.
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

        # Kept as a plain float: SDPA takes the dropout probability directly.
        self.attn_dropout = config.attn_pdrop
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Reshape ``(batch, seq, embed)`` into ``(batch, head, seq, head_features)``."""
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Inverse of ``_split_heads``: back to ``(batch, seq, embed)``."""
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    @staticmethod
    def _causal_mask(q_len, k_len, device):
        """Boolean ``(q_len, k_len)`` mask, True where attention is allowed.

        The queries are the *last* ``q_len`` positions of the ``k_len`` keys,
        so the triangle is aligned to the bottom-right corner.
        """
        allowed = torch.ones(q_len, k_len, dtype=torch.bool, device=device)
        return allowed.tril(diagonal=k_len - q_len)

    def _build_attn_mask(self, attention_mask, query, key):
        """Return the ``(attn_mask, is_causal)`` pair to hand to SDPA.

        SDPA rejects an explicit mask together with ``is_causal=True``, and its
        built-in causal mask is only correct when query and key have the same
        length, so the two are never combined here.
        """
        q_len, k_len = query.size(-2), key.size(-2)

        if attention_mask is None:
            if q_len == k_len:
                # No cache, no padding: let the kernel apply causality itself.
                return None, True
            if q_len == 1:
                # A single new token may attend to every cached key.
                return None, False
            return self._causal_mask(q_len, k_len, query.device), False

        min_value = torch.finfo(query.dtype).min
        if attention_mask.dim() == 2:
            # (batch, key_len) mask with 1 = keep, 0 = padding.
            keep = attention_mask[:, None, None, :].to(dtype=query.dtype)
            additive = (1.0 - keep) * min_value
        elif attention_mask.dtype == torch.bool:
            # Higher-rank boolean mask, True = keep.
            additive = torch.zeros_like(attention_mask, dtype=query.dtype)
            additive = additive.masked_fill(~attention_mask, min_value)
        else:
            # Higher-rank float masks are assumed to already be additive and
            # broadcastable to (batch, head, q_len, k_len).
            additive = attention_mask.to(dtype=query.dtype)  # fp16 compatibility
        additive = additive.to(device=query.device)

        if q_len > 1:
            causal = self._causal_mask(q_len, k_len, query.device)
            additive = torch.where(causal, additive, additive.new_full((), min_value))
        return additive, False

    def forward(self, hidden_states, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False):
        """Apply causal self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, embed_dim)``.
            layer_past: Optional ``(key, value)`` pair from earlier steps; it
                is prepended to the new keys/values along the sequence axis.
            attention_mask: Optional mask. A 2-D ``(batch, key_len)`` mask uses
                1 for tokens to attend to and 0 for padding; higher-rank masks
                are treated as described in ``_build_attn_mask``.
            head_mask: Accepted for interface compatibility; not applied.
            use_cache: When truthy, return the concatenated ``(key, value)``.
            output_attentions: When truthy, a trailing ``None`` is appended,
                because SDPA does not expose the attention weights.

        Returns:
            ``(attn_output, present)`` or ``(attn_output, present, None)``.
        """
        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache else None

        attn_mask, is_causal = self._build_attn_mask(attention_mask, query, key)

        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (None,)  # We don't have attention weights due to using scaled_dot_product_attention

        return outputs  # a, present, (attentions)

# Prefer the Flash Attention kernel, but keep the fallback kernels enabled.
# The Flash kernel is only eligible for some calls (on CUDA it needs fp16/bf16
# inputs and cannot take an arbitrary attention mask). With the fallbacks
# switched off, every call that is not eligible has no kernel left to run and
# `scaled_dot_product_attention` fails instead of using a slower path.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

# Function to check if Flash Attention is being used
def is_using_flash_attention():
    """Return True if the Flash Attention SDPA kernel can run on this machine.

    The probe restricts ``scaled_dot_product_attention`` to the Flash backend
    only and runs it on a small half-precision CUDA tensor. PyTorch raises a
    ``RuntimeError`` when no permitted kernel can serve the call, which is
    reported here as ``False``.

    Returns:
        bool: ``False`` when CUDA is unavailable or the Flash kernel refuses
        the probe; ``True`` when the probe runs.
    """
    if not torch.cuda.is_available():
        return False

    # Imported here so the check still answers on CPU-only installs where
    # this submodule may be missing.
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Half precision because the Flash kernel does not accept fp32 inputs.
    probe = torch.randn(2, 4, 8, 16, device='cuda', dtype=torch.float16)
    try:
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            F.scaled_dot_product_attention(probe, probe, probe)
    except RuntimeError:
        return False
    return True

# Modified function to replace attention layers
def replace_attention_layers(module, config):
    """Recursively swap every GPT-2 attention layer for ``FlashAttention``.

    Each replacement inherits the trained projection weights of the layer it
    replaces, as well as its device, dtype and train/eval mode; a freshly
    constructed ``FlashAttention`` would otherwise be randomly initialised.

    Args:
        module: Root ``nn.Module`` to rewrite in place.
        config: Model config forwarded to ``FlashAttention``.
    """
    gpt2_attention_cls = transformers.models.gpt2.modeling_gpt2.GPT2Attention
    for name, child in module.named_children():
        if not isinstance(child, gpt2_attention_cls):
            replace_attention_layers(child, config)
            continue

        replacement = FlashAttention(config)
        with torch.no_grad():
            for proj_name in ("c_attn", "c_proj"):
                source = getattr(child, proj_name)
                target = getattr(replacement, proj_name)
                # Hugging Face's Conv1D stores its weight as (in, out), the
                # transpose of nn.Linear's (out, in) layout.
                if isinstance(source, nn.Linear):
                    weight = source.weight
                else:
                    weight = source.weight.t()
                target.weight.copy_(weight)
                target.bias.copy_(source.bias)

        reference = child.c_proj.weight
        replacement.to(device=reference.device, dtype=reference.dtype)
        replacement.train(child.training)
        setattr(module, name, replacement)

# Load the model and its configuration
# NOTE: this rebinds `model`. From here on it refers to a headless `GPT2Model`,
# not the causal-LM model loaded in section 2.
model_name = "gpt2"
model = GPT2Model.from_pretrained(model_name)
config = model.config

# After loading the model, replace the attention layers
replace_attention_layers(model, config)
# Move the model only after the swap so that the newly created attention
# modules end up on the same device as the rest of the network and the inputs.
model.to(device)
print("Replaced attention layers with FlashAttention")

# Verify Flash Attention usage
print(f"Using Flash Attention: {is_using_flash_attention()}")

## 6. Optimize Model Architecture

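We'll replace LayerNorm with RMSNorm for efficiency. Note that RMSNorm is not numerically equivalent to LayerNorm (it skips the mean-centering step and has no bias), so a pretrained model normally needs fine-tuning after this swap before its outputs are meaningful: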

In [None]:
from torch import nn

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the last axis is divided by the square root of its mean
    squared value (plus ``eps`` for numerical stability) and then multiplied
    by a learned per-feature scale. No mean is subtracted and no bias is added.

    Args:
        dim: Size of the last dimension, i.e. the length of the scale vector.
        eps: Constant added to the mean square before taking the reciprocal
            square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Scale starts at one, so a fresh module only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled to unit RMS along the last axis, times ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return x * inverse_rms * self.weight

def replace_layernorm_with_rmsnorm(model):
    """Recursively replace every ``nn.LayerNorm`` in ``model`` with ``RMSNorm``.

    The replacement reuses the LayerNorm's ``eps`` and, when available, its
    learned scale, device and dtype, so the new module lives alongside the
    rest of the network instead of on the default device with a scale of one.
    The LayerNorm bias has no counterpart in ``RMSNorm`` and is dropped.

    Args:
        model: Root ``nn.Module`` to rewrite in place.
    """
    for name, module in model.named_children():
        if not isinstance(module, nn.LayerNorm):
            replace_layernorm_with_rmsnorm(module)
            continue

        rms_norm = RMSNorm(module.normalized_shape[0], eps=module.eps)
        scale = module.weight  # None when elementwise_affine=False
        if scale is not None:
            rms_norm.to(device=scale.device, dtype=scale.dtype)
            # Only a 1-D LayerNorm scale matches RMSNorm's weight vector.
            if scale.shape == rms_norm.weight.shape:
                with torch.no_grad():
                    rms_norm.weight.copy_(scale)
        setattr(model, name, rms_norm)

# Rewrite `model` in place: every nn.LayerNorm found in its module tree is
# swapped for an RMSNorm.
replace_layernorm_with_rmsnorm(model)
print("Replaced LayerNorm with RMSNorm")

## 7. Benchmark Inference Speed

Let's compare the inference speed of our original and optimized models:

In [None]:
def benchmark_inference(model_or_func, input_ids, num_runs=100, num_warmup=10):
    """Return the mean seconds per call of ``model_or_func(input_ids)``.

    Args:
        model_or_func: Callable to time. If it has an ``eval`` method (e.g. an
            ``nn.Module``) it is switched to eval mode first.
        input_ids: Argument passed unchanged to every call.
        num_runs: Number of timed calls; must be positive.
        num_warmup: Number of untimed calls made before timing starts.

    Returns:
        float: Average wall-clock seconds per timed call.

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    if hasattr(model_or_func, 'eval'):
        model_or_func.eval()

    def _wait_for_gpu():
        # CUDA kernels are launched asynchronously; without a sync the timer
        # would only measure how long it takes to enqueue the work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warmup
        for _ in range(num_warmup):
            model_or_func(input_ids)
        _wait_for_gpu()

        # perf_counter is monotonic and has the highest available resolution,
        # unlike time.time(), which follows the adjustable system clock.
        start_time = time.perf_counter()
        for _ in range(num_runs):
            model_or_func(input_ids)
        _wait_for_gpu()
        end_time = time.perf_counter()

    return (end_time - start_time) / num_runs

input_text = "Once upon a time"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

# `model` was rebound in section 5 to a GPT2Model whose attention (and later
# LayerNorm) layers were replaced, so it can no longer serve as the baseline.
# Reload the untouched checkpoint for the "original" measurement.
# NOTE: the baseline, quantized and distilled models include the LM head,
# while `model` is the headless GPT2Model, so the last timing does less work.
original_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

print("Original model inference time:")
original_time = benchmark_inference(original_model, input_ids)
print(f"{original_time:.4f} seconds")

print("\nQuantized model inference time:")
# Feed the quantized model from whichever device its weights live on, in case
# that differs from `device`.
quantized_device = next(quantized_model.parameters()).device
quantized_time = benchmark_inference(quantized_model, input_ids.to(quantized_device))
print(f"{quantized_time:.4f} seconds")

print("\nDistilled model inference time:")
distilled_time = benchmark_inference(distilled_model, input_ids)
print(f"{distilled_time:.4f} seconds")

print("\nFlash Attention model inference time:")
flash_time = benchmark_inference(model, input_ids)
print(f"{flash_time:.4f} seconds")

## 8. Export to ONNX for Optimized Inference

Finally, let's export our optimized model to ONNX format for even faster inference:

In [None]:
# Export the model to ONNX
# Trace in eval mode and without autograd so dropout is inactive and no
# gradient bookkeeping ends up in the exported graph.
model.eval()
dummy_input = torch.randint(
    0, model.config.vocab_size, (1, 512), dtype=torch.long, device=device
)
# NOTE(review): `model` is the headless GPT2Model at this point, so the first
# output named 'logits' below carries hidden states rather than vocabulary
# logits. Confirm the intended model before relying on that name.
with torch.no_grad():
    torch.onnx.export(model, dummy_input, "optimized_gpt2.onnx",
                      input_names=['input_ids'],
                      output_names=['logits'],
                      dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'},
                                    'logits': {0: 'batch_size', 1: 'sequence'}},
                      # scaled_dot_product_attention is only exportable from
                      # opset 14 onwards.
                      opset_version=14)

# Create an ONNX inference session
import onnxruntime as ort
# Load the graph written by the export step into an ONNX Runtime session.
ort_session = ort.InferenceSession("optimized_gpt2.onnx")

# Run inference with ONNX Runtime
def onnx_inference(session, input_ids):
    """Run ``session`` on ``input_ids`` and return its first output.

    Args:
        session: Object with a ``run(output_names, feeds)`` method, such as an
            ONNX Runtime inference session.
        input_ids: Tensor of token ids; it is copied to the CPU and converted
            to a NumPy array before being fed under the name ``'input_ids'``.

    Returns:
        The first element of the sequence returned by ``session.run``.
    """
    feeds = {'input_ids': input_ids.cpu().numpy()}
    # Passing None as the output list asks for every output of the graph.
    return session.run(None, feeds)[0]

def _run_onnx(batch):
    """Adapt `onnx_inference` to the one-argument call made by `benchmark_inference`."""
    return onnx_inference(ort_session, batch)


print("\nONNX model inference time:")
onnx_time = benchmark_inference(_run_onnx, input_ids)
print(f"{onnx_time:.4f} seconds")