# Scaling and Optimization

This notebook explores techniques for scaling transformers efficiently, including model compression, quantization, optimized inference engines, and deployment strategies.

## 1. Introduction to Model Optimization

As models grow larger, optimization becomes crucial for practical deployment. We'll explore various techniques to reduce memory, improve speed, and enable deployment on resource-constrained devices.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import pandas as pd
from IPython.display import display, HTML
import time
import math

# Plot styling shared by every figure in this notebook
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Seed the RNGs this notebook draws from so Restart & Run All reproduces the outputs
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 2. Understanding Model Optimization Challenges

Let's visualize the challenges of deploying large transformer models.

In [None]:
# Model scaling challenges
def visualize_scaling_challenges() -> None:
    """Visualize the challenges of scaling transformer models.

    Renders a 2x2 figure:
      1. weight-storage memory per model at FP32/FP16/INT8/INT4 precision, with
         horizontal reference lines for T4/V100/A100 GPU memory (log y-axis);
      2. parameter count vs. a hard-coded, illustrative per-token latency;
      3. hard-coded, illustrative memory/speed/quality figures for optimization techniques;
      4. a text panel of deployment scenarios and recommended optimizations.
    Afterwards prints a footnote that the GPT-4 parameter count is an estimate.
    """
    
    # Model sizes and requirements
    models = ['BERT-Base', 'BERT-Large', 'GPT-2', 'GPT-3', 'LLaMA-7B', 'LLaMA-65B', 'GPT-4*']
    params = [110e6, 340e6, 1.5e9, 175e9, 7e9, 65e9, 1.7e12]
    
    # Memory requirements (GB)
    # NOTE: these count weights only (params x bytes per param); activations, KV cache
    # and optimizer state are not included.
    fp32_memory = [p * 4 / 1e9 for p in params]  # 4 bytes per param
    fp16_memory = [p * 2 / 1e9 for p in params]  # 2 bytes per param
    int8_memory = [p * 1 / 1e9 for p in params]  # 1 byte per param
    int4_memory = [p * 0.5 / 1e9 for p in params]  # 0.5 bytes per param
    
    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Memory requirements by precision
    ax = axes[0, 0]
    x = np.arange(len(models))
    width = 0.2
    
    # Four bars per model, centred on the tick: offsets -1.5w, -0.5w, +0.5w, +1.5w
    ax.bar(x - 1.5*width, fp32_memory, width, label='FP32', color='darkred')
    ax.bar(x - 0.5*width, fp16_memory, width, label='FP16', color='orange')
    ax.bar(x + 0.5*width, int8_memory, width, label='INT8', color='green')
    ax.bar(x + 1.5*width, int4_memory, width, label='INT4', color='darkgreen')
    
    # Add GPU memory lines
    gpu_limits = {'T4': 16, 'V100': 32, 'A100': 80}
    for gpu, limit in gpu_limits.items():
        ax.axhline(y=limit, color='gray', linestyle='--', alpha=0.5)
        # Label sits just above the line at the right edge of the axis
        ax.text(len(models)-0.5, limit+5, gpu, fontsize=9, color='gray')
    
    ax.set_xlabel('Model', fontsize=12)
    ax.set_ylabel('Memory Required (GB)', fontsize=12)
    ax.set_title('Memory Requirements by Precision', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45, ha='right')
    # Log scale: values span ~0.05 GB (BERT-Base INT4) to ~6800 GB (GPT-4* FP32)
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 2. Inference speed vs model size
    ax = axes[0, 1]
    
    # Simulated inference times (ms per token)
    inference_times = [5, 8, 15, 500, 30, 200, 2000]
    
    # Colour encodes model index only (viridis gradient); `scatter` is not used afterwards
    scatter = ax.scatter(params, inference_times, s=100, c=range(len(models)), 
                        cmap='viridis', alpha=0.6)
    
    for i, model in enumerate(models):
        ax.annotate(model, (params[i], inference_times[i]), 
                   xytext=(5, 5), textcoords='offset points', fontsize=8)
    
    ax.set_xlabel('Model Parameters', fontsize=12)
    ax.set_ylabel('Inference Time (ms/token)', fontsize=12)
    ax.set_title('Model Size vs Inference Speed', fontsize=14)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)
    
    # 3. Optimization techniques impact
    ax = axes[1, 0]
    
    techniques = ['Original', 'Pruning\n(50%)', 'Quantization\n(INT8)', 
                 'Distillation\n(6x smaller)', 'Combined\nOptimization']
    # Illustrative percentages relative to 'Original' (= 100).
    # NOTE(review): `memory_reduction` appears to hold memory *remaining* (lower is
    # better), not the reduction itself — the name is misleading; TODO confirm/rename.
    memory_reduction = [100, 60, 25, 17, 10]
    speed_improvement = [100, 140, 180, 600, 800]
    quality_retention = [100, 98, 97, 95, 93]
    
    x = np.arange(len(techniques))
    width = 0.25
    
    ax.bar(x - width, memory_reduction, width, label='Memory (%)', color='lightcoral')
    ax.bar(x, speed_improvement, width, label='Speed (%)', color='lightgreen')
    ax.bar(x + width, quality_retention, width, label='Quality (%)', color='lightblue')
    
    ax.set_ylabel('Relative Performance (%)', fontsize=12)
    ax.set_title('Impact of Optimization Techniques', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(techniques)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    
    # 4. Deployment scenarios
    ax = axes[1, 1]
    # Text-only panel: hide the axes frame and ticks
    ax.axis('off')
    
    deployment_text = """
    🚀 Deployment Scenarios and Recommended Optimizations:
    
    📱 Mobile/Edge (< 1GB memory):
       • Extreme quantization (INT4)
       • Model distillation (10-50x smaller)
       • Pruning + quantization combo
    
    💻 Desktop/Laptop (4-16GB):
       • INT8 quantization
       • Dynamic quantization
       • Flash attention for long sequences
    
    🖥️ Server (16-80GB):
       • FP16 mixed precision
       • KV caching for generation
       • Batch optimization
    
    ☁️ Cloud (Multi-GPU):
       • Model parallelism
       • Pipeline parallelism
       • Optimized serving frameworks
    """
    
    # Anchor the text block at the top-left corner in axes coordinates
    ax.text(0.05, 0.95, deployment_text, transform=ax.transAxes,
           fontsize=11, verticalalignment='top',
           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.show()
    
    print("* GPT-4 parameters are estimated")

visualize_scaling_challenges()

## 3. Model Pruning Techniques

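Pruning zeroes out low-importance weights (here: the ones with the smallest magnitude) to reduce model size. Note that unstructured sparsity only speeds up inference when the kernels/hardware exploit it; structured pruning (removing whole heads, channels or layers) is what yields speedups on dense hardware.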

In [None]:
class MagnitudePruning:
    """Unstructured magnitude-based pruning.

    Zeroes the ``sparsity`` fraction of entries with the smallest absolute value and
    leaves every other entry unchanged.
    """
    
    def __init__(self, sparsity: float = 0.5):
        """
        Args:
            sparsity: fraction of entries to remove, in [0, 1].

        Raises:
            ValueError: if ``sparsity`` lies outside [0, 1].
        """
        if not 0.0 <= sparsity <= 1.0:
            raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")
        self.sparsity = sparsity
        
    def prune_weights(self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prune ``weight`` by magnitude.

        Args:
            weight: tensor of any shape.

        Returns:
            (pruned_weight, mask): ``pruned_weight`` is ``weight`` with pruned entries set
            to zero; ``mask`` is a bool tensor of the same shape, True where an entry is
            kept. Entries tied with the cut-off magnitude are pruned as well, so the
            actual sparsity can slightly exceed the target when magnitudes repeat.
        """
        magnitudes = torch.abs(weight)
        num_to_prune = int(round(self.sparsity * weight.numel()))

        if num_to_prune == 0:
            # Nothing to remove (sparsity 0, or a tiny/empty tensor): keep every entry.
            # A 0-quantile threshold compared with a strict `>` would wrongly drop the
            # single smallest weight here.
            mask = torch.ones_like(weight, dtype=torch.bool)
        else:
            # The k-th smallest magnitude is the exact cut-off. kthvalue is used instead of
            # torch.quantile, which interpolates and rejects very large inputs.
            threshold = magnitudes.flatten().kthvalue(num_to_prune).values
            mask = magnitudes > threshold

        pruned_weight = weight * mask
        return pruned_weight, mask
    
    def calculate_sparsity(self, weight: torch.Tensor) -> float:
        """Return the fraction of entries in ``weight`` that are exactly zero (0.0 if empty)."""
        if weight.numel() == 0:
            return 0.0
        return (weight == 0).sum().item() / weight.numel()

# Demonstrate pruning
def demonstrate_pruning():
    """Show magnitude pruning of a random 256x256 weight matrix at several sparsities.

    Produces:
      * a grid with the top-left 50x50 corner of the pruned matrix per sparsity level;
      * a table of target vs. actual sparsity, kept (non-zero) weight count and
        compression ratio;
      * the weight distribution and sparsity pattern for the 50% level.
    """
    
    # Create a sample weight matrix
    weight = torch.randn(256, 256) * 0.1
    
    # Apply different sparsity levels
    sparsity_levels = [0.0, 0.3, 0.5, 0.7, 0.9]
    # Level shown in the distribution / sparsity-pattern figure (must be in the list above)
    detail_sparsity = 0.5
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    results = []
    # Keep every level's output so the detail figure can use the right one, rather than
    # whatever the last loop iteration happened to leave behind.
    pruned_by_level = {}
    
    for i, sparsity in enumerate(sparsity_levels):
        pruner = MagnitudePruning(sparsity=sparsity)
        pruned_weight, mask = pruner.prune_weights(weight)
        pruned_by_level[sparsity] = (pruned_weight, mask)
        actual_sparsity = pruner.calculate_sparsity(pruned_weight)
        
        # Visualize weight matrix
        axes[i].imshow(pruned_weight[:50, :50].numpy(), cmap='coolwarm',
                       vmin=-0.3, vmax=0.3)
        axes[i].set_title(f'Sparsity: {sparsity:.0%} (Actual: {actual_sparsity:.1%})',
                          fontsize=12)
        axes[i].set_xlabel('Input dimension')
        axes[i].set_ylabel('Output dimension')
        
        # `mask` is True where a weight survives, so its sum is the non-zero count.
        kept_weights = mask.sum().item()
        compression = (f'{1/(1-actual_sparsity):.1f}x' if actual_sparsity < 1
                       else 'inf')
        results.append({
            'Target Sparsity': f'{sparsity:.0%}',
            'Actual Sparsity': f'{actual_sparsity:.1%}',
            'Non-zero Weights': kept_weights,
            'Compression Ratio': compression
        })
    
    # Remove extra subplot
    axes[-1].axis('off')
    
    plt.suptitle('Effect of Magnitude Pruning on Weight Matrices', fontsize=16)
    plt.tight_layout()
    plt.show()
    
    # Show statistics
    df_results = pd.DataFrame(results)
    print("\nPruning Statistics:")
    display(df_results)
    
    # Visualize pruning distribution for the detail level
    pruned_detail, mask_detail = pruned_by_level[detail_sparsity]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Weight magnitude distribution
    weights_flat = weight.flatten().numpy()
    ax1.hist(weights_flat, bins=50, alpha=0.7, label='Original', density=True)
    
    pruned_flat = pruned_detail.flatten().numpy()
    pruned_nonzero = pruned_flat[pruned_flat != 0]
    ax1.hist(pruned_nonzero, bins=50, alpha=0.7,
             label=f'After {detail_sparsity:.0%} pruning', density=True)
    
    ax1.set_xlabel('Weight Value', fontsize=12)
    ax1.set_ylabel('Density', fontsize=12)
    ax1.set_title('Weight Distribution Before/After Pruning', fontsize=14)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Sparsity pattern (dots mark kept weights)
    ax2.spy(mask_detail[:100, :100].numpy(), markersize=1)
    ax2.set_title(f'Sparsity Pattern ({detail_sparsity:.0%} pruning, first 100x100)',
                  fontsize=14)
    ax2.set_xlabel('Input dimension')
    ax2.set_ylabel('Output dimension')
    
    plt.tight_layout()
    plt.show()

demonstrate_pruning()

## 4. Quantization Techniques

Quantization reduces numerical precision to save memory and accelerate computation.

In [None]:
class QuantizationDemo:
    """Demonstrate uniform (affine) quantization of tensors.

    A real value ``x`` is mapped to an integer code
    ``q = clamp(round(x / scale + zero_point), qmin, qmax)`` and recovered as
    ``x_hat = (q - zero_point) * scale``.
    """

    # Lower bound on the scale so all-zero / constant tensors never divide by zero.
    _MIN_SCALE = 1e-12

    @staticmethod
    def _storage_dtype(qmin: int, qmax: int) -> torch.dtype:
        """Return the narrowest integer dtype that can hold every code in [qmin, qmax]."""
        if qmin >= 0 and qmax <= 255:
            return torch.uint8
        if qmin >= -128 and qmax <= 127:
            return torch.int8
        if qmin >= -(2 ** 15) and qmax <= 2 ** 15 - 1:
            return torch.int16
        if qmin >= -(2 ** 31) and qmax <= 2 ** 31 - 1:
            return torch.int32
        return torch.int64
    
    @staticmethod
    def quantize_tensor(tensor: torch.Tensor, bits: int = 8, 
                       symmetric: bool = True) -> Tuple[torch.Tensor, float, float]:
        """Quantize ``tensor`` to ``bits``-bit integer codes.

        Args:
            tensor: floating-point tensor to quantize.
            bits: code width, >= 1.
            symmetric: if True, use signed codes in [-2^(bits-1), 2^(bits-1) - 1] with
                zero_point 0 and scale = max|x| / qmax. For ``bits == 1`` that range has
                qmax = 0 (infinite scale), so binary sign codes {-1, +1} with
                scale = mean|x| are used instead. If False, use unsigned codes in
                [0, 2^bits - 1] covering the range [min(x, 0), max(x, 0)] with an integer
                zero point.

        Returns:
            (codes, scale, zero_point): ``codes`` stored in the narrowest integer dtype
            that fits the code range (int8 for signed <= 8 bits, uint8 for unsigned
            <= 8 bits, wider otherwise, so codes never wrap around); ``scale`` as a float;
            ``zero_point`` as a float holding an integer value.

        Raises:
            ValueError: if ``bits < 1``.
        """
        if bits < 1:
            raise ValueError(f"bits must be >= 1, got {bits}")
        if tensor.numel() == 0:
            return torch.zeros_like(tensor, dtype=torch.int8), 1.0, 0.0

        if symmetric and bits == 1:
            # Binary quantization: sign codes with the mean magnitude as scale, which is
            # the L2-optimal scale for a fixed sign pattern.
            scale = max(tensor.abs().mean().item(), QuantizationDemo._MIN_SCALE)
            codes = (tensor >= 0).to(torch.int8) * 2 - 1
            return codes, scale, 0.0

        if symmetric:
            # Symmetric quantization
            qmin = -(2 ** (bits - 1))
            qmax = 2 ** (bits - 1) - 1
            scale = max(tensor.abs().max().item() / qmax, QuantizationDemo._MIN_SCALE)
            zero_point = 0.0
        else:
            # Asymmetric quantization. The range is widened to include 0 so that the
            # zero point is a valid integer code and 0.0 is represented exactly.
            qmin = 0
            qmax = 2 ** bits - 1
            lo = min(tensor.min().item(), 0.0)
            hi = max(tensor.max().item(), 0.0)
            scale = max((hi - lo) / (qmax - qmin), QuantizationDemo._MIN_SCALE)
            zero_point = float(min(max(round(qmin - lo / scale), qmin), qmax))
        
        # Quantize
        quantized = torch.round(tensor / scale + zero_point)
        quantized = torch.clamp(quantized, qmin, qmax)
        
        return quantized.to(QuantizationDemo._storage_dtype(qmin, qmax)), scale, zero_point
    
    @staticmethod
    def dequantize_tensor(quantized: torch.Tensor, scale: float, 
                         zero_point: float) -> torch.Tensor:
        """Map integer codes back to float: ``(quantized - zero_point) * scale``."""
        return (quantized.float() - zero_point) * scale

# Visualize quantization effects
def visualize_quantization():
    """Plot how uniform quantization at various bit widths distorts a noisy sine wave.

    Figure 1: a 2x3 grid overlaying the original and dequantized signal for
    32/16/8/4/2/1 bits (32 = the unquantized float reference, MSE 0).
    Figure 2: reconstruction MSE (log scale) and nominal compression ratio (32 / bits)
    for every bit width from 1 to 16.
    """
    grid_x = torch.linspace(-2, 2, 1000)
    signal = torch.sin(2 * np.pi * grid_x) + 0.1 * torch.randn_like(grid_x)

    quantizer = QuantizationDemo()

    def roundtrip(bits):
        """Quantize then dequantize `signal`; return (reconstruction, MSE vs. signal)."""
        codes, scale, zp = quantizer.quantize_tensor(signal, bits=bits)
        recon = quantizer.dequantize_tensor(codes, scale, zp)
        return recon, torch.mean((signal - recon) ** 2).item()

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    for ax, bits in zip(axes.flatten(), [32, 16, 8, 4, 2, 1]):
        if bits != 32:
            recon, mse = roundtrip(bits)
            ax.plot(grid_x, signal, 'b-', alpha=0.3, label='Original')
            ax.plot(grid_x, recon, 'r-', alpha=0.8, label=f'{bits}-bit')
        else:
            # 32-bit panel is the float reference itself
            ax.plot(grid_x, signal, 'b-', alpha=0.8, label='Original')
            mse = 0

        ax.set_title(f'{bits}-bit Quantization\nMSE: {mse:.6f}', fontsize=12)
        ax.set(xlabel='x', ylabel='y')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle('Effect of Quantization on Signal Representation', fontsize=16)
    plt.tight_layout()
    plt.show()

    # Quantization error analysis over every bit width 1..16
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    widths = list(range(1, 17))
    errors = [roundtrip(b)[1] for b in widths]
    compression_ratios = [32 / b for b in widths]

    ax1.semilogy(widths, errors, 'bo-')
    ax1.set_xlabel('Bit Width', fontsize=12)
    ax1.set_ylabel('Mean Squared Error', fontsize=12)
    ax1.set_title('Quantization Error vs Bit Width', fontsize=14)
    ax1.grid(True, alpha=0.3)

    # Label the commonly used widths; errors[b - 1] belongs to width b
    for bits in (1, 2, 4, 8, 16):
        ax1.annotate(f'{bits}bit', (bits, errors[bits - 1]),
                     xytext=(5, 5), textcoords='offset points', fontsize=9)

    ax2.plot(widths, compression_ratios, 'go-')
    ax2.set_xlabel('Bit Width', fontsize=12)
    ax2.set_ylabel('Compression Ratio', fontsize=12)
    ax2.set_title('Memory Compression vs Bit Width', fontsize=14)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

visualize_quantization()

# Compare quantization schemes
# Round-trip one random weight matrix through several (bits, symmetric) settings of
# QuantizationDemo and tabulate scale, zero point, reconstruction MSE and the nominal
# compression ratio relative to 32-bit floats.
print("\n--- Quantization Schemes Comparison ---")

# Create a sample weight matrix
weight = torch.randn(100, 100) * 0.5

# (display name, bit width, symmetric?) for each scheme under comparison
schemes = [
    ('Symmetric INT8', 8, True),
    ('Asymmetric INT8', 8, False),
    ('INT4', 4, True),
    ('Binary', 1, True)
]

quantizer = QuantizationDemo()
results = []

for name, bits, symmetric in schemes:
    q_weight, scale, zp = quantizer.quantize_tensor(weight, bits, symmetric)
    dq_weight = quantizer.dequantize_tensor(q_weight, scale, zp)
    
    # Reconstruction error of the quantize -> dequantize round trip
    error = torch.mean((weight - dq_weight) ** 2).item()
    # Nominal ratio vs. FP32 storage (ignores the per-tensor scale/zero-point overhead)
    compression = 32 / bits
    
    results.append({
        'Scheme': name,
        'Bits': bits,
        'Scale': f'{scale:.4f}',
        'Zero Point': f'{zp:.1f}',
        'MSE': f'{error:.6f}',
        'Compression': f'{compression:.1f}x'
    })

df_quant = pd.DataFrame(results)
display(df_quant)

## 5. Knowledge Distillation

Distillation transfers knowledge from a large teacher model to a smaller student model.

In [None]:
def visualize_distillation() -> None:
    """Visualize knowledge distillation concept and process.

    Renders a 2x2 figure:
      1. a teacher -> student schematic;
      2. softmax of a fixed logit vector at temperatures 0.5/1/3/10, showing how a
         higher temperature flattens the distribution;
      3. a pie chart of a 70/30 distillation/student loss weighting (alpha = 0.7);
      4. hard-coded, illustrative accuracy figures for teacher vs. students, annotated
         with parameter counts and accuracy retained relative to the teacher.
    """
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Distillation concept
    ax = axes[0, 0]
    ax.axis('off')
    ax.set_title('Knowledge Distillation Concept', fontsize=14)
    
    # Teacher model
    teacher_rect = plt.Rectangle((0.1, 0.5), 0.3, 0.4, 
                                facecolor='lightblue', edgecolor='black', linewidth=2)
    ax.add_patch(teacher_rect)
    ax.text(0.25, 0.7, 'Teacher\nModel\n(Large)', ha='center', va='center', 
           fontsize=12, weight='bold')
    
    # Student model (drawn smaller than the teacher to convey relative size)
    student_rect = plt.Rectangle((0.6, 0.6), 0.2, 0.2, 
                                facecolor='lightgreen', edgecolor='black', linewidth=2)
    ax.add_patch(student_rect)
    ax.text(0.7, 0.7, 'Student\nModel\n(Small)', ha='center', va='center', 
           fontsize=10, weight='bold')
    
    # Knowledge transfer
    ax.arrow(0.4, 0.7, 0.15, 0, head_width=0.03, head_length=0.03, 
            fc='red', ec='red', linewidth=2)
    ax.text(0.475, 0.75, 'Knowledge\nTransfer', ha='center', fontsize=10, color='red')
    
    # Inputs/Outputs
    ax.text(0.25, 0.4, 'Input', ha='center', fontsize=10)
    ax.text(0.25, 0.3, 'Soft Labels\n(Probabilities)', ha='center', fontsize=9, style='italic')
    ax.text(0.7, 0.5, 'Input', ha='center', fontsize=10)
    ax.text(0.7, 0.4, 'Learn from\nSoft Labels', ha='center', fontsize=9, style='italic')
    
    # Size comparison
    ax.text(0.5, 0.1, 'Teacher: 340M params → Student: 60M params (5.7x compression)', 
           ha='center', fontsize=11, bbox=dict(boxstyle='round', facecolor='wheat'))
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    
    # 2. Temperature effect
    ax = axes[0, 1]
    
    # Softmax with different temperatures
    logits = torch.tensor([2.0, 1.0, 0.5, 0.3, -0.5])
    temperatures = [0.5, 1.0, 3.0, 10.0]
    
    x = np.arange(len(logits))
    
    for temp in temperatures:
        # Dividing logits by T > 1 flattens the distribution; T < 1 sharpens it
        probs = F.softmax(logits / temp, dim=0)
        ax.plot(x, probs, 'o-', label=f'T={temp}', markersize=8)
    
    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Probability', fontsize=12)
    ax.set_title('Effect of Temperature on Softmax', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels(['Cat', 'Dog', 'Bird', 'Fish', 'Other'])
    
    # 3. Loss composition
    ax = axes[1, 0]
    
    # Pie chart of loss components
    sizes = [70, 30]  # Distillation loss, Student loss
    labels = ['Distillation Loss\n(KL Divergence)', 'Student Loss\n(Cross Entropy)']
    colors = ['lightcoral', 'lightblue']
    explode = (0.1, 0)
    
    ax.pie(sizes, explode=explode, labels=labels, colors=colors, autopct='%1.0f%%',
          shadow=True, startangle=90)
    ax.set_title('Typical Loss Composition (α=0.7)', fontsize=14)
    
    # 4. Performance comparison
    ax = axes[1, 1]
    
    methods = ['Teacher\n(BERT-L)', 'Student\n(No KD)', 'Student\n(With KD)', 
              'Student\n(KD+Data Aug)']
    accuracy = [93.5, 85.2, 91.8, 92.5]  # illustrative values (%)
    params = [340, 60, 60, 60]  # millions of parameters
    
    x = np.arange(len(methods))
    width = 0.35
    
    # NOTE(review): only one bar series is drawn, yet it is offset by -width/2, so the
    # bars sit left of their tick labels; `label=` is set but ax.legend() is never called.
    bars1 = ax.bar(x - width/2, accuracy, width, label='Accuracy (%)', color='lightgreen')
    
    # Add parameter count on top
    for i, (bar, param) in enumerate(zip(bars1, params)):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
               f'{param}M', ha='center', va='bottom', fontsize=9)
    
    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('Distillation Performance Comparison', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(methods)
    # Zoomed y-range so the few-point accuracy differences are visible
    ax.set_ylim(80, 95)
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add performance retention percentage (student accuracy / teacher accuracy)
    teacher_acc = accuracy[0]
    for i in range(1, len(accuracy)):
        retention = accuracy[i] / teacher_acc * 100
        ax.text(i, 82, f'{retention:.1f}%\nretention', ha='center', fontsize=9, 
               color='darkgreen', weight='bold')
    
    plt.tight_layout()
    plt.show()

visualize_distillation()

# Distillation loss implementation
print("\n--- Distillation Loss Example ---")

def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    """Knowledge-distillation objective (soft-target KL plus hard-label cross entropy).

    Args:
        student_logits: raw student scores; the class axis is the last one.
        teacher_logits: raw teacher scores, same shape as ``student_logits``.
        labels: integer class targets for the hard-label term.
        temperature: softening temperature T applied to both logit sets.
        alpha: weight of the distillation term; the hard-label term gets ``1 - alpha``.

    Returns:
        ``(total, distill, student)`` where
        ``distill = T**2 * KL(softmax(teacher / T) || softmax(student / T))`` (batch mean),
        ``student = cross_entropy(student_logits, labels)`` and
        ``total = alpha * distill + (1 - alpha) * student``.
    """
    T = temperature

    # F.kl_div expects log-probabilities as input and probabilities as target.
    teacher_probs = F.softmax(teacher_logits / T, dim=-1)
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)

    # T**2 keeps the soft-target gradients on the same scale as the hard-label ones.
    kd_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * T ** 2
    ce_term = F.cross_entropy(student_logits, labels)

    return alpha * kd_term + (1 - alpha) * ce_term, kd_term, ce_term

# Example
# Loss hyperparameters are passed explicitly so the breakdown below always uses the
# same weights as the loss itself (instead of re-typing the function's defaults).
ALPHA = 0.7
TEMPERATURE = 3.0

batch_size = 4
num_classes = 10

teacher_logits = torch.randn(batch_size, num_classes) * 2
student_logits = torch.randn(batch_size, num_classes)
labels = torch.randint(0, num_classes, (batch_size,))

total, distill, student = distillation_loss(student_logits, teacher_logits, labels,
                                            temperature=TEMPERATURE, alpha=ALPHA)

print(f"Distillation Loss: {distill:.4f}")
print(f"Student Loss: {student:.4f}")
print(f"Total Loss: {total:.4f}")
print("\nLoss Breakdown:")
# Weighted share of each term in the total; the two shares sum to 100%.
print(f"  Distillation: {ALPHA * distill / total * 100:.1f}%")
print(f"  Student: {(1 - ALPHA) * student / total * 100:.1f}%")

## 6. Optimized Attention Mechanisms

Efficient attention implementations are crucial for scaling to longer sequences.

In [None]:
def visualize_attention_optimizations() -> None:
    """Visualize different attention optimization techniques.

    Renders a 2x3 figure on a toy 16-token sequence:
      1. full (dense) attention mask;
      2. a sparse mask: local sliding window plus a global first token;
      3. a tiling picture of block-wise (Flash) attention;
      4. a schematic of multi-head vs. multi-query attention, with per-token sizes;
      5. relative memory scaling curves for standard / sparse / flash attention;
      6. hard-coded, illustrative speed and memory numbers for several implementations.
    """
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    seq_len = 16
    
    # 1. Standard Attention: every query attends to every key
    ax = axes[0, 0]
    standard_mask = np.ones((seq_len, seq_len))
    im = ax.imshow(standard_mask, cmap='Blues', vmin=0, vmax=1)
    ax.set_title('Standard Attention\nO(n²) memory', fontsize=12)
    ax.set_xlabel('Keys')
    ax.set_ylabel('Queries')
    
    # 2. Sparse Attention
    ax = axes[0, 1]
    sparse_mask = np.zeros((seq_len, seq_len))
    # Local attention window
    window_size = 4
    # NOTE(review): this range spans i - window_size//2 .. i + window_size//2 inclusive,
    # i.e. window_size + 1 keys per query for an even window_size.
    for i in range(seq_len):
        for j in range(max(0, i-window_size//2), min(seq_len, i+window_size//2+1)):
            sparse_mask[i, j] = 1
    # Global attention tokens: token 0 attends to all keys and is attended by all queries
    sparse_mask[0, :] = 1
    sparse_mask[:, 0] = 1
    
    ax.imshow(sparse_mask, cmap='Greens', vmin=0, vmax=1)
    ax.set_title('Sparse Attention\nO(n·w) memory', fontsize=12)
    ax.set_xlabel('Keys')
    ax.set_ylabel('Queries')
    
    # 3. Flash Attention blocks
    ax = axes[0, 2]
    block_size = 4
    flash_visual = np.zeros((seq_len, seq_len))
    
    # Show block processing: colours only distinguish neighbouring tiles; they do not
    # encode a processing order.
    for i in range(0, seq_len, block_size):
        for j in range(0, seq_len, block_size):
            # Different colors for different blocks
            color = (i//block_size + j//block_size) % 3 + 1
            flash_visual[i:i+block_size, j:j+block_size] = color / 3
    
    ax.imshow(flash_visual, cmap='viridis')
    ax.set_title('Flash Attention\nBlock-wise computation', fontsize=12)
    ax.set_xlabel('Keys')
    ax.set_ylabel('Queries')
    
    # 4. Multi-Query Attention
    ax = axes[1, 0]
    ax.axis('off')
    
    # Visualize MQA
    num_heads = 8
    head_dim = 64
    
    # Standard MHA: one Q, K and V box per head
    ax.text(0.2, 0.9, 'Multi-Head Attention', fontsize=12, weight='bold')
    for i in range(num_heads):
        # Q, K, V for each head
        y_pos = 0.7 - i * 0.08
        for j, (label, color) in enumerate([('Q', 'lightblue'), 
                                           ('K', 'lightgreen'), 
                                           ('V', 'lightcoral')]):
            rect = plt.Rectangle((0.1 + j*0.06, y_pos), 0.05, 0.06,
                               facecolor=color, edgecolor='black')
            ax.add_patch(rect)
    
    # MQA
    ax.text(0.6, 0.9, 'Multi-Query Attention', fontsize=12, weight='bold')
    # Multiple Q heads
    for i in range(num_heads):
        y_pos = 0.7 - i * 0.08
        rect = plt.Rectangle((0.5, y_pos), 0.05, 0.06,
                           facecolor='lightblue', edgecolor='black')
        ax.add_patch(rect)
    
    # Single K, V shared by all query heads
    rect_k = plt.Rectangle((0.56, 0.4), 0.05, 0.06,
                         facecolor='lightgreen', edgecolor='black', linewidth=2)
    rect_v = plt.Rectangle((0.62, 0.4), 0.05, 0.06,
                         facecolor='lightcoral', edgecolor='black', linewidth=2)
    ax.add_patch(rect_k)
    ax.add_patch(rect_v)
    
    # Per-token sizes counting Q, K and V vectors: MHA = 3 * num_heads * head_dim,
    # MQA = (num_heads + 2) * head_dim.
    # NOTE(review): only K and V are cached during generation, so the KV-cache saving of
    # MQA (2 vs. 2 * num_heads head_dim vectors) is larger than the ratio shown; confirm
    # which metric is intended.
    ax.text(0.2, 0.05, f'Memory: {num_heads * 3 * head_dim} per token', 
           ha='center', fontsize=10)
    ax.text(0.6, 0.05, f'Memory: {num_heads + 2} * {head_dim} per token\n'
                       f'({(num_heads + 2) / (num_heads * 3) * 100:.0f}% of MHA)', 
           ha='center', fontsize=10)
    
    ax.set_xlim(0, 0.8)
    ax.set_ylim(0, 1)
    
    # 5. Attention complexity comparison
    ax = axes[1, 1]
    
    seq_lengths = np.array([512, 1024, 2048, 4096, 8192, 16384])
    
    # Memory requirements (relative): quadratic for dense attention, linear in n with
    # illustrative constants for the windowed and block-wise variants
    standard_mem = seq_lengths ** 2
    sparse_mem = seq_lengths * 256  # window size 256
    flash_mem = seq_lengths * 64  # block size 64
    
    ax.loglog(seq_lengths, standard_mem, 'b-o', label='Standard Attention')
    ax.loglog(seq_lengths, sparse_mem, 'g-s', label='Sparse Attention')
    ax.loglog(seq_lengths, flash_mem, 'r-^', label='Flash Attention')
    
    ax.set_xlabel('Sequence Length', fontsize=12)
    ax.set_ylabel('Memory (relative)', fontsize=12)
    ax.set_title('Attention Memory Scaling', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 6. Speed comparison
    ax = axes[1, 2]
    
    methods = ['Standard', 'Flash\nAttention', 'xFormers', 'Sparse\n(BigBird)']
    relative_speed = [1.0, 2.5, 2.2, 3.8]  # illustrative speedup vs. Standard
    memory_usage = [100, 30, 35, 25]  # illustrative % of Standard's memory
    
    x = np.arange(len(methods))
    width = 0.35
    
    # NOTE(review): speed (1-3.8x) and memory (25-100%) share one y-axis, so the speed
    # bars are dwarfed; a twin axis would make both series readable.
    bars1 = ax.bar(x - width/2, relative_speed, width, label='Speed (x)', color='lightgreen')
    bars2 = ax.bar(x + width/2, memory_usage, width, label='Memory (%)', color='lightcoral')
    
    ax.set_ylabel('Relative Performance', fontsize=12)
    ax.set_title('Attention Implementation Comparison\n(8K sequence length)', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(methods)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()

visualize_attention_optimizations()

# KV Cache demonstration
print("\n--- KV Cache for Generation ---")

class KVCacheDemo:
    """Minimal pre-allocated key/value cache for autoregressive decoding (batch size 1).

    Keys and values live in tensors of shape (1, num_heads, max_seq_len, head_dim);
    ``cache_len`` counts how many sequence positions have been written.
    """
    
    def __init__(self, max_seq_len=2048, num_heads=8, head_dim=64, dtype=torch.float32):
        """
        Args:
            max_seq_len: number of sequence positions to reserve.
            num_heads: attention heads stored per position.
            head_dim: size of each head's key/value vector.
            dtype: element type of the cache buffers (default float32).
        """
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        # Pre-allocate the whole cache so `update` only writes into existing memory.
        self.k_cache = torch.zeros(1, num_heads, max_seq_len, head_dim, dtype=dtype)
        self.v_cache = torch.zeros(1, num_heads, max_seq_len, head_dim, dtype=dtype)
        self.cache_len = 0
        
    def update(self, k_new, v_new):
        """Append new keys/values along the sequence axis and return the filled prefix.

        Args:
            k_new, v_new: tensors of shape (1, num_heads, seq_len, head_dim).

        Returns:
            (keys, values): slices of shape (1, num_heads, cache_len, head_dim) covering
            every position written so far.

        Raises:
            ValueError: if appending ``seq_len`` positions would exceed ``max_seq_len``;
                the cache is left unchanged in that case.
        """
        seq_len = k_new.shape[2]
        end = self.cache_len + seq_len
        if end > self.max_seq_len:
            raise ValueError(
                f"KV cache overflow: {self.cache_len} cached + {seq_len} new tokens "
                f"exceeds max_seq_len={self.max_seq_len}"
            )
        
        # Store in cache
        self.k_cache[:, :, self.cache_len:end] = k_new
        self.v_cache[:, :, self.cache_len:end] = v_new
        self.cache_len = end
        
        return self.k_cache[:, :, :end], self.v_cache[:, :, :end]
    
    def memory_usage(self):
        """Return the memory reserved by the K and V buffers, in MiB (1024 ** 2 bytes)."""
        total_bytes = (self.k_cache.numel() + self.v_cache.numel()) * self.k_cache.element_size()
        return total_bytes / (1024 ** 2)

# Demonstrate cache usage
cache = KVCacheDemo()
print(f"KV Cache initialized:")
print(f"  Max sequence length: {cache.max_seq_len}")
print(f"  Memory allocated: {cache.memory_usage():.1f} MB")

# Simulate generation
num_steps = 5
print("\nSimulating token generation:")
for i in range(num_steps):
    # New token generates new K, V
    k_new = torch.randn(1, cache.num_heads, 1, cache.head_dim)
    v_new = torch.randn(1, cache.num_heads, 1, cache.head_dim)
    
    k_full, v_full = cache.update(k_new, v_new)
    print(f"  Step {i+1}: Cache contains {cache.cache_len} tokens")

# Without a cache, step t recomputes K/V for all t tokens seen so far (1 + 2 + ... + n);
# with a cache only the newest token's K/V is computed, so the saving is (n + 1) / 2,
# i.e. 3x for 5 steps (not n x).
kv_without_cache = num_steps * (num_steps + 1) // 2
kv_with_cache = num_steps
print(f"\nWithout cache: {kv_without_cache} token K/V computations")
print(f"With cache: {kv_with_cache} token K/V computations "
      f"({kv_without_cache / kv_with_cache:.1f}x fewer)")

## 7. Deployment Optimization Strategies

Let's explore strategies for efficient model deployment in production.

In [None]:
def visualize_deployment_strategies():
    """Visualize different deployment optimization strategies.

    Renders a 2x2 figure:
      1. a timeline comparing static and dynamic batching of five toy requests on a
         single server (continuous batching is only labelled);
      2. a schematic production serving architecture;
      3. hard-coded, illustrative throughput / latency numbers per optimization;
      4. a deployment optimization checklist.
    """
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Batching strategies
    ax = axes[0, 0]
    
    # Timeline visualization: toy requests with arrival time and processing length (ms)
    requests = [
        {'id': 'A', 'arrival': 0, 'length': 100},
        {'id': 'B', 'arrival': 10, 'length': 50},
        {'id': 'C', 'arrival': 15, 'length': 150},
        {'id': 'D', 'arrival': 20, 'length': 75},
        {'id': 'E', 'arrival': 25, 'length': 100}
    ]
    max_batch = 3       # requests per batch, both strategies
    batch_window = 10   # ms: dynamic batching admits requests arriving this soon
    timeline_end = 0    # latest finish time, used to size the x-axis
    
    # Static batching: fixed groups of `max_batch` in arrival order. A group can only
    # start once all its members have arrived AND the previous group has finished
    # (single server); otherwise consecutive batches would overlap on the timeline.
    y_static = 3
    server_free_at = 0
    for i in range(0, len(requests), max_batch):
        batch = requests[i:i + max_batch]
        start = max(server_free_at, max(r['arrival'] for r in batch))
        duration = max(r['length'] for r in batch)
        
        rect = plt.Rectangle((start, y_static), duration, 0.8,
                             facecolor='lightblue', edgecolor='black')
        ax.add_patch(rect)
        
        for j, req in enumerate(batch):
            ax.text(start + 5, y_static + 0.2 + j*0.2, req['id'], fontsize=8)
        
        server_free_at = start + duration
    timeline_end = max(timeline_end, server_free_at)
    
    ax.text(-20, y_static + 0.4, 'Static\nBatching', fontsize=10, va='center')
    
    # Dynamic batching: whenever the server is free, take up to `max_batch` pending
    # requests that arrive within `batch_window` ms. Work on a copy so `requests`
    # itself is not consumed.
    y_dynamic = 1.5
    pending = list(requests)
    current_time = 0
    while pending:
        batch = [r for r in pending
                 if r['arrival'] <= current_time + batch_window][:max_batch]
        
        if batch:
            for req in batch:
                pending.remove(req)
            # A batch cannot start before its last member has actually arrived.
            start = max(current_time, max(r['arrival'] for r in batch))
            duration = max(r['length'] for r in batch)
            
            rect = plt.Rectangle((start, y_dynamic), duration, 0.8,
                                 facecolor='lightgreen', edgecolor='black')
            ax.add_patch(rect)
            
            for j, req in enumerate(batch):
                ax.text(start + 5, y_dynamic + 0.2 + j*0.2, req['id'], fontsize=8)
            
            current_time = start + duration
        else:
            # Nothing arrives within the window yet: let time advance
            current_time += batch_window
    timeline_end = max(timeline_end, current_time)
    
    ax.text(-20, y_dynamic + 0.4, 'Dynamic\nBatching', fontsize=10, va='center')
    
    # Continuous batching
    y_continuous = 0
    ax.text(-20, y_continuous + 0.4, 'Continuous\nBatching', fontsize=10, va='center')
    ax.text(50, y_continuous + 0.4, '(Iteration-level batching for generation)',
            fontsize=9, style='italic')
    
    # Widen the axis when batches run past the default 200 ms so none are clipped
    ax.set_xlim(-30, max(200, timeline_end + 10))
    ax.set_ylim(-0.5, 4.5)
    ax.set_xlabel('Time (ms)', fontsize=12)
    ax.set_title('Batching Strategies Comparison', fontsize=14)
    ax.grid(True, alpha=0.3, axis='x')
    
    # 2. Model serving architecture
    ax = axes[0, 1]
    ax.axis('off')
    ax.set_title('Production Serving Architecture', fontsize=14)
    
    # Components (centre positions in axes units)
    components = [
        {'name': 'Load\nBalancer', 'pos': (0.5, 0.9), 'color': 'lightcoral'},
        {'name': 'Model\nServer 1', 'pos': (0.2, 0.6), 'color': 'lightblue'},
        {'name': 'Model\nServer 2', 'pos': (0.5, 0.6), 'color': 'lightblue'},
        {'name': 'Model\nServer 3', 'pos': (0.8, 0.6), 'color': 'lightblue'},
        {'name': 'KV Cache\nStore', 'pos': (0.2, 0.3), 'color': 'lightyellow'},
        {'name': 'Model\nRegistry', 'pos': (0.5, 0.3), 'color': 'lightgreen'},
        {'name': 'Metrics\nCollector', 'pos': (0.8, 0.3), 'color': 'lightgray'}
    ]
    
    for comp in components:
        rect = plt.Rectangle((comp['pos'][0] - 0.08, comp['pos'][1] - 0.05),
                             0.16, 0.1, facecolor=comp['color'], edgecolor='black')
        ax.add_patch(rect)
        ax.text(comp['pos'][0], comp['pos'][1], comp['name'], 
                ha='center', va='center', fontsize=9)
    
    # Connections (box edge to box edge)
    connections = [
        ((0.5, 0.85), (0.2, 0.65)),
        ((0.5, 0.85), (0.5, 0.65)),
        ((0.5, 0.85), (0.8, 0.65)),
        ((0.2, 0.55), (0.2, 0.35)),
        ((0.5, 0.55), (0.5, 0.35)),
        ((0.8, 0.55), (0.8, 0.35))
    ]
    
    for start, end in connections:
        ax.plot([start[0], end[0]], [start[1], end[1]], 'k-', linewidth=1)
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    
    # 3. Optimization techniques comparison (illustrative numbers)
    ax = axes[1, 0]
    
    techniques = ['Baseline', 'Quantized\n(INT8)', 'Pruned\n(50%)', 
                  'Distilled\n(6x)', 'All\nCombined']
    throughput = [100, 180, 150, 600, 1000]
    latency = [50, 30, 35, 10, 8]
    
    x = np.arange(len(techniques))
    width = 0.35
    
    # Second y-axis because throughput and latency have different units
    ax2 = ax.twinx()
    
    ax.bar(x - width/2, throughput, width, label='Throughput', color='lightgreen')
    ax2.bar(x + width/2, latency, width, label='Latency', color='lightcoral')
    
    ax.set_xlabel('Optimization Technique', fontsize=12)
    ax.set_ylabel('Throughput (req/s)', fontsize=12, color='green')
    ax2.set_ylabel('Latency (ms)', fontsize=12, color='red')
    ax.set_title('Performance Impact of Optimizations', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(techniques)
    
    # Legends
    ax.legend(loc='upper left')
    ax2.legend(loc='upper right')
    
    ax.grid(True, alpha=0.3)
    
    # 4. Deployment checklist
    ax = axes[1, 1]
    ax.axis('off')
    
    checklist = """
    ✅ Deployment Optimization Checklist:
    
    🔧 Model Optimization:
    □ Quantization (INT8/INT4)
    □ Pruning (structured/unstructured)
    □ Knowledge distillation
    □ Operator fusion
    
    🚀 Inference Optimization:
    □ KV caching for generation
    □ Flash/Memory-efficient attention
    □ Dynamic batching
    □ Continuous batching
    
    🖥️ Infrastructure:
    □ GPU/Hardware selection
    □ Model parallelism setup
    □ Load balancing
    □ Auto-scaling policies
    
    📊 Monitoring:
    □ Latency tracking (p50, p95, p99)
    □ Throughput metrics
    □ GPU utilization
    □ Memory usage
    □ Error rates
    """
    
    ax.text(0.05, 0.95, checklist, transform=ax.transAxes,
            fontsize=10, verticalalignment='top', family='monospace',
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    
    plt.tight_layout()
    plt.show()

visualize_deployment_strategies()

# Performance benchmarking example
print("\n--- Performance Benchmarking Example ---")

def benchmark_optimization(model_size_mb=1000, optimization='none', ms_per_mb=0.05):
    """Simulate performance metrics for one optimization strategy.

    Nothing is measured. Baseline latency is modelled as linear in model size
    (``ms_per_mb`` milliseconds per MB) and baseline throughput as
    1000 ms / latency. Each strategy then scales the baseline by fixed factors.

    Args:
        model_size_mb: Size of the unoptimized model in MB. Must be > 0, since
            0 would make baseline latency 0 and the throughput division fail.
        optimization: One of 'none', 'int8', 'int4', 'pruning',
            'distillation', 'combined'.
        ms_per_mb: Baseline latency per MB of model. The default is 0.05,
            i.e. 50 ms for a 1000 MB model. Must be > 0.

    Returns:
        dict with these keys:
        - 'optimization'
        - 'latency_ms'
        - 'throughput_rps'
        - 'memory_mb'
        - 'quality_retention' (percent of baseline quality)
        - 'speedup' and 'compression' (both as multiples of the baseline)

    Raises:
        ValueError: if ``model_size_mb`` or ``ms_per_mb`` is not positive.
        KeyError: if ``optimization`` is not one of the known strategies.
    """
    if model_size_mb <= 0 or ms_per_mb <= 0:
        raise ValueError(
            f"model_size_mb and ms_per_mb must be positive, "
            f"got {model_size_mb!r} and {ms_per_mb!r}"
        )

    # Multiplicative factors relative to the unoptimized baseline. The
    # throughput factors are listed separately and are not exactly the
    # inverse of the latency factors (e.g. int8: 1 / 0.6 != 1.8).
    optimizations = {
        'none': {'latency': 1.0, 'throughput': 1.0, 'memory': 1.0, 'quality': 1.0},
        'int8': {'latency': 0.6, 'throughput': 1.8, 'memory': 0.25, 'quality': 0.98},
        'int4': {'latency': 0.5, 'throughput': 2.2, 'memory': 0.125, 'quality': 0.95},
        'pruning': {'latency': 0.7, 'throughput': 1.5, 'memory': 0.6, 'quality': 0.97},
        'distillation': {'latency': 0.2, 'throughput': 6.0, 'memory': 0.17, 'quality': 0.93},
        'combined': {'latency': 0.15, 'throughput': 10.0, 'memory': 0.1, 'quality': 0.90}
    }
    if optimization not in optimizations:
        # Raise KeyError (what a plain dict lookup would raise) but spell out
        # the valid choices.
        raise KeyError(
            f"unknown optimization {optimization!r}; "
            f"expected one of {list(optimizations)}"
        )
    opt = optimizations[optimization]

    # Baseline metrics
    base_latency = model_size_mb * ms_per_mb  # ms
    base_throughput = 1000 / base_latency  # req/s: 1000 ms per second / latency
    base_memory = model_size_mb  # MB

    return {
        'optimization': optimization,
        'latency_ms': base_latency * opt['latency'],
        'throughput_rps': base_throughput * opt['throughput'],
        'memory_mb': base_memory * opt['memory'],
        'quality_retention': opt['quality'] * 100,
        'speedup': 1 / opt['latency'],
        'compression': 1 / opt['memory']
    }

# Benchmark every strategy on the same baseline model. The size is defined once
# so the heading and the simulated runs cannot drift apart.
benchmark_model_size_mb = 1000
benchmark_strategies = ['none', 'int8', 'int4', 'pruning', 'distillation', 'combined']

print(f"Benchmarking {benchmark_model_size_mb / 1000:g}GB model with different optimizations:\n")

# One metrics dict per strategy; build the frame once from the records.
benchmark_records = [
    benchmark_optimization(benchmark_model_size_mb, strategy)
    for strategy in benchmark_strategies
]

df_benchmark = pd.DataFrame(benchmark_records)
display(df_benchmark.round(1))

# Deployment cost analysis: turn each configuration's simulated throughput into
# a GPU count and a daily bill, assuming the request load is spread evenly over
# the day.
print("\n--- Deployment Cost Analysis ---")

gpu_cost_per_hour = 2.0  # $/hour for A100
requests_per_day = 1_000_000

print(f"Assumptions: {requests_per_day:,} requests/day, ${gpu_cost_per_hour}/GPU-hour\n")

for config in df_benchmark.itertuples(index=False):
    # Daily capacity of one GPU = per-second throughput * seconds in a day.
    requests_per_gpu_day = config.throughput_rps * 86400
    gpu_count = np.ceil(requests_per_day / requests_per_gpu_day)
    cost_per_day = gpu_count * gpu_cost_per_hour * 24

    print(f"{config.optimization:12} | GPUs: {int(gpu_count):2} | "
          f"Cost: ${cost_per_day:6.0f}/day | "
          f"Quality: {config.quality_retention:.0f}%")

## 8. Hardware-Specific Optimizations

Different hardware platforms require different optimization strategies.

In [None]:
# Hardware optimization guide. The "Typical Speedup" figures are hard-coded
# rules of thumb for illustration, not measurements from this notebook.
hardware_optimizations = pd.DataFrame({
    'Platform': ['NVIDIA GPU', 'TPU', 'CPU', 'Mobile/Edge', 'Apple Silicon'],
    'Key Optimizations': [
        'TensorRT, FP16, Flash Attention, CUDA graphs',
        'XLA compilation, bfloat16, TPU embeddings',
        'ONNX Runtime, INT8, OpenVINO, operator fusion',
        'TFLite, Core ML, NNAPI, extreme quantization',
        'Core ML, Metal Performance Shaders, ANE'
    ],
    'Best Practices': [
        'Use mixed precision, optimize memory transfers',
        'Batch efficiently, use TPU-specific ops',
        # MKL-DNN was renamed (via DNNL) to oneDNN; keep the old name for searchability.
        'Vectorize operations, use oneDNN (formerly MKL-DNN)',
        'Model splitting, on-device caching',
        'Leverage unified memory, use Metal'
    ],
    'Typical Speedup': ['5-10x', '10-20x', '2-5x', '10-50x', '5-15x']
})

print("\n🔧 Hardware-Specific Optimization Guide")
print("=" * 80)
display(hardware_optimizations)

# Memory bandwidth analysis
print("\n📊 Memory Bandwidth Requirements")

def calculate_bandwidth_requirements(model_size_gb, batch_size, seq_len, dtype_bytes=2,
                                     num_layers=32, latency_s=0.1):
    """Estimate memory traffic for transformer inference and the bandwidth it needs.

    Despite its name, ``model_size_gb`` is the parameter count in *billions*
    (e.g. 7 for a 7B-parameter model). Weight bytes are computed as
    ``model_size_gb * 1e9 * dtype_bytes``.

    All sizes are decimal GB (1e9 bytes). The activation and KV-cache terms
    are rough estimates built on a hidden size inferred from the parameter
    count and ``num_layers``.

    Args:
        model_size_gb: Parameter count in billions. Must be >= 0.
        batch_size: Number of sequences processed together.
        seq_len: Tokens per sequence.
        dtype_bytes: Bytes per stored value. Default 2 (e.g. FP16/BF16).
        num_layers: Layer count used for the hidden-size, activation and
            KV-cache estimates. Default 32. Must be > 0.
        latency_s: Latency budget in seconds. Default 0.1, i.e. 100 ms.
            Must be > 0.

    Returns:
        dict with 'weights_gb', 'activations_gb', 'kv_cache_gb', 'total_gb'
        and 'bandwidth_gbps'. 'bandwidth_gbps' is the GB/s (gigabytes, not
        gigabits) needed to move 'total_gb' within ``latency_s``.

    Raises:
        ValueError: on a negative ``model_size_gb`` or a non-positive
            ``num_layers`` / ``latency_s``.
    """
    if model_size_gb < 0:
        raise ValueError(f"model_size_gb must be >= 0, got {model_size_gb!r}")
    if num_layers <= 0:
        raise ValueError(f"num_layers must be > 0, got {num_layers!r}")
    if latency_s <= 0:
        raise ValueError(f"latency_s must be > 0, got {latency_s!r}")

    num_params = model_size_gb * 1e9

    # Model weights read
    weights_read = num_params * dtype_bytes

    # Hidden size from the common decoder estimate
    # params ~= 12 * num_layers * hidden_size**2 (attention ~4*d^2 plus MLP
    # ~8*d^2 per layer), so it stays consistent with num_layers.
    # For 7B params and 32 layers this gives ~4270.
    hidden_size = int((num_params / (12 * num_layers)) ** 0.5)

    # Activations (rough estimate): 4 hidden-sized values per token per layer.
    activation_memory = (
        batch_size * seq_len * hidden_size * num_layers * dtype_bytes * 4
    )

    # KV cache: one key and one value vector of hidden_size per token per layer.
    kv_memory = batch_size * seq_len * hidden_size * num_layers * 2 * dtype_bytes

    total_memory = weights_read + activation_memory + kv_memory

    # All bytes must move within the latency budget -> GB/s.
    bandwidth_gbps = total_memory / 1e9 / latency_s

    return {
        'weights_gb': weights_read / 1e9,
        'activations_gb': activation_memory / 1e9,
        'kv_cache_gb': kv_memory / 1e9,
        'total_gb': total_memory / 1e9,
        'bandwidth_gbps': bandwidth_gbps
    }

# Compare a 7B and a 70B model, each at two batch sizes, all at 2048 tokens.
scenarios = [
    ('7B model, batch=1', 7, 1, 2048),
    ('7B model, batch=32', 7, 32, 2048),
    ('70B model, batch=1', 70, 1, 2048),
    ('70B model, batch=8', 70, 8, 2048)
]

bandwidth_columns = ['scenario', 'weights_gb', 'activations_gb',
                     'kv_cache_gb', 'total_gb', 'bandwidth_gbps']

# One estimate dict per scenario, tagged with its label.
bandwidth_results = [
    {**calculate_bandwidth_requirements(size_b, batch_size, seq_len), 'scenario': label}
    for label, size_b, batch_size, seq_len in scenarios
]

df_bandwidth = pd.DataFrame(bandwidth_results)[bandwidth_columns]

print("\nMemory bandwidth requirements (100ms latency target):")
display(df_bandwidth.round(1))

# Peak memory bandwidth of common hardware, in GB/s, for comparison.
print("\n🖥️ Hardware Memory Bandwidth Capabilities:")
hardware_bandwidth = {
    'CPU (DDR4)': 100,
    'V100': 900,
    'A100': 1555,
    'H100': 3350,
    'TPU v4': 1200
}

print("\n".join(
    f"  {hw_name:<15}: {gb_per_s:>5} GB/s"
    for hw_name, gb_per_s in hardware_bandwidth.items()
))

## 9. End-to-End Optimization Pipeline

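Let's model an end-to-end optimization pipeline analytically. Each step contributes an assumed speedup, compression ratio, and quality-retention factor, and we track how these compound from one step to the next.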

In [None]:
class OptimizationPipeline:
    """What-if calculator for a sequence of model optimizations.

    Each step is described only by multiplicative factors (speedup,
    compression, quality retention); no real model is involved. ``analyze``
    compounds the factors in the order the steps were added.
    """

    def __init__(self, model_name: str = "transformer-base"):
        """Create an empty pipeline.

        Args:
            model_name: Label for the model being optimized. Stored on the
                instance; not used by ``analyze``.
        """
        self.model_name = model_name
        # One dict per step, in insertion order. Keys: 'name', 'speedup',
        # 'compression', 'quality_impact'.
        self.optimization_steps: List[Dict[str, object]] = []

    def add_optimization(self, name: str, speedup: float, compression: float,
                         quality_impact: float):
        """Append an optimization step.

        Args:
            name: Label shown in the 'step' column of ``analyze``.
            speedup: Latency divisor for this step (> 0; 2.0 halves latency).
            compression: Size divisor for this step (> 0; 4.0 quarters size).
            quality_impact: Quality multiplier (>= 0; 0.98 keeps 98%).

        Raises:
            ValueError: if ``speedup`` or ``compression`` is not positive, or
                ``quality_impact`` is negative. These are checked here so a bad
                factor fails where it was introduced, instead of later as a
                ZeroDivisionError inside ``analyze``.
        """
        if speedup <= 0:
            raise ValueError(f"speedup must be > 0, got {speedup!r}")
        if compression <= 0:
            raise ValueError(f"compression must be > 0, got {compression!r}")
        if quality_impact < 0:
            raise ValueError(f"quality_impact must be >= 0, got {quality_impact!r}")
        self.optimization_steps.append({
            'name': name,
            'speedup': speedup,
            'compression': compression,
            'quality_impact': quality_impact
        })

    def analyze(self, base_params: float = 1e9, base_latency: float = 100):
        """Compound all steps and return one row per stage.

        Args:
            base_params: Baseline parameter count.
            base_latency: Baseline latency in milliseconds.

        Returns:
            pandas.DataFrame with columns 'step', 'params', 'latency_ms',
            'speedup', 'compression' and 'quality'. Row 0 is the baseline, and
            each later row reflects every step up to and including that one.

            - 'params' is ``base_params`` divided by the cumulative compression.
              It is therefore an effective size: a quantization step shrinks it
              even though the real parameter count is unchanged.
            - 'speedup' and 'compression' are cumulative multiples.
            - 'quality' is a percentage of the baseline's 100.
        """
        results = [{
            'step': 'Baseline',
            'params': base_params,
            'latency_ms': base_latency,
            'speedup': 1.0,
            'compression': 1.0,
            'quality': 100.0
        }]

        current_params = base_params
        current_latency = base_latency
        current_quality = 100.0
        cumulative_speedup = 1.0
        cumulative_compression = 1.0

        # Factors are treated as independent, so they simply multiply.
        for opt in self.optimization_steps:
            current_params /= opt['compression']
            current_latency /= opt['speedup']
            current_quality *= opt['quality_impact']
            cumulative_speedup *= opt['speedup']
            cumulative_compression *= opt['compression']

            results.append({
                'step': opt['name'],
                'params': current_params,
                'latency_ms': current_latency,
                'speedup': cumulative_speedup,
                'compression': cumulative_compression,
                'quality': current_quality
            })

        return pd.DataFrame(results)

# Illustrative (name, speedup, compression, quality retention) factors, applied
# in this order to a BERT-Large-sized baseline (340M params, 50 ms latency).
pipeline_steps = [
    ("Pruning (50% sparsity)", 1.5, 1.8, 0.99),
    ("INT8 Quantization", 1.8, 4.0, 0.98),
    ("Knowledge Distillation", 3.0, 6.0, 0.96),
    ("Flash Attention", 2.0, 1.0, 1.0),
    ("Operator Fusion", 1.2, 1.0, 1.0),
]

pipeline = OptimizationPipeline("BERT-Large")
for step_name, step_speedup, step_compression, step_quality in pipeline_steps:
    pipeline.add_optimization(step_name, speedup=step_speedup,
                              compression=step_compression,
                              quality_impact=step_quality)

results = pipeline.analyze(base_params=340e6, base_latency=50)

print("\n🚀 Optimization Pipeline Analysis")
print("=" * 80)
display(results.round(2))

# Visualize optimization progression
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Left: cumulative speedup (left y-axis) vs. retained quality (right y-axis)
ax1.plot(results['step'], results['speedup'], 'g-o', label='Speedup', markersize=8)
ax1_twin = ax1.twinx()
ax1_twin.plot(results['step'], results['quality'], 'r-s', label='Quality', markersize=8)

ax1.set_xlabel('Optimization Step', fontsize=12)
ax1.set_ylabel('Speedup (x)', fontsize=12, color='green')
ax1_twin.set_ylabel('Quality (%)', fontsize=12, color='red')
ax1.set_title('Optimization Impact on Speed and Quality', fontsize=14)
ax1.tick_params(axis='x', rotation=45)
ax1.grid(True, alpha=0.3)

# Label every post-baseline point with its cumulative speedup. The categorical
# x-axis places the step names at positions 0..n-1 in plot order.
for pos, speedup in enumerate(results['speedup']):
    if pos == 0:  # Skip baseline
        continue
    ax1.annotate(f'{speedup:.1f}x',
                 (pos, speedup),
                 xytext=(0, 10), textcoords='offset points',
                 fontsize=9, ha='center')

# The two series live on different Axes, so merge their handles into one
# legend. Draw it on the twin because the twin is rendered on top.
speed_handles, speed_labels = ax1.get_legend_handles_labels()
quality_handles, quality_labels = ax1_twin.get_legend_handles_labels()
ax1_twin.legend(speed_handles + quality_handles, speed_labels + quality_labels,
                loc='upper center')

# Right: model size (bars, left y-axis) vs. latency (line, right y-axis)
ax2.bar(results['step'], results['params'] / 1e6, color='lightblue', alpha=0.7,
        label='Model size')
ax2_twin = ax2.twinx()
ax2_twin.plot(results['step'], results['latency_ms'], 'ko-', markersize=8,
              label='Latency')

ax2.set_xlabel('Optimization Step', fontsize=12)
ax2.set_ylabel('Model Size (M params)', fontsize=12, color='blue')
ax2_twin.set_ylabel('Latency (ms)', fontsize=12)
ax2.set_title('Model Size and Latency Progression', fontsize=14)
ax2.tick_params(axis='x', rotation=45)
ax2.grid(True, alpha=0.3, axis='y')

size_handles, size_labels = ax2.get_legend_handles_labels()
latency_handles, latency_labels = ax2_twin.get_legend_handles_labels()
ax2_twin.legend(size_handles + latency_handles, size_labels + latency_labels,
                loc='upper right')

plt.tight_layout()
plt.show()

# Final summary: compare the last pipeline row (all steps applied) with the
# baseline row.
original, final = results.iloc[0], results.iloc[-1]

summary_lines = [
    "\n✅ Final Optimization Results:",
    f"  Total Speedup: {final['speedup']:.1f}x",
    f"  Total Compression: {final['compression']:.1f}x",
    f"  Quality Retention: {final['quality']:.1f}%",
    f"  Latency: {original['latency_ms']:.1f}ms → {final['latency_ms']:.1f}ms",
    f"  Model Size: {original['params']/1e6:.0f}M → {final['params']/1e6:.1f}M params",
]
print("\n".join(summary_lines))

## 10. Summary and Best Practices

Let's summarize the key optimization techniques and provide practical recommendations.

In [None]:
# Optimization decision matrix: one row per technique, with the situation it
# suits and rough (hard-coded) speed/quality trade-offs.
print("\n🎯 Optimization Technique Selection Guide")
print("=" * 80)

decision_columns = ['Technique', 'Use When', 'Speedup', 'Quality Impact', 'Implementation']
decision_rows = [
    ('Pruning',
     'Model has redundancy, structured pruning preferred',
     '1.5-2x', '< 1%', 'Medium'),
    ('INT8 Quantization',
     'Need 4x memory reduction with minimal quality loss',
     '1.5-3x', '1-2%', 'Easy'),
    ('INT4 Quantization',
     'Extreme memory constraints, can tolerate 2-5% quality loss',
     '2-4x', '2-5%', 'Medium'),
    ('Knowledge Distillation',
     'Can train smaller model, have teacher model available',
     '5-10x', '2-10%', 'Hard'),
    ('Flash Attention',
     'Long sequences (>1K tokens), memory-bound',
     '2-10x', '0%', 'Hard'),
    ('KV Caching',
     'Autoregressive generation, multiple tokens per sequence',
     '5-100x', '0%', 'Easy'),
]
decision_matrix = pd.DataFrame(decision_rows, columns=decision_columns)

display(decision_matrix)

# Best practices summary
print("\n📋 Optimization Best Practices:")
print("=" * 60)

best_practices = [
    "1. **Profile First**: Identify bottlenecks before optimizing",
    "2. **Combine Techniques**: Use multiple optimizations together",
    "3. **Measure Quality**: Always validate model quality after optimization",
    "4. **Hardware-Aware**: Optimize for your specific deployment hardware",
    "5. **Iterative Approach**: Start with least invasive optimizations",
    "6. **Monitor Production**: Track latency, throughput, and errors",
    "7. **Cache Strategically**: Use KV cache for generation workloads",
    "8. **Batch Dynamically**: Implement dynamic batching for better utilization"
]

# The entries use Markdown emphasis (**...**). print() would show the asterisks
# literally, so render the list as Markdown instead.
from IPython.display import Markdown

display(Markdown("\n".join(best_practices)))

# Recommended workflow: stages run from cheap, low-risk changes toward
# model-level compression, and end with production monitoring.
print("\n\n🔄 Recommended Optimization Workflow:")
print("=" * 60)

workflow = """
1. Baseline Benchmarking
   └─ Profile model performance
   └─ Identify bottlenecks
   
2. Quick Wins
   └─ Mixed precision (FP16)
   └─ Operator fusion
   └─ Batch size optimization
   
3. Memory Optimization
   └─ Quantization (INT8 first)
   └─ KV caching
   └─ Memory-efficient attention
   
4. Model Compression
   └─ Pruning (if applicable)
   └─ Distillation (if feasible)
   
5. Deployment Optimization
   └─ Dynamic batching
   └─ Model serving framework
   └─ Hardware-specific optimizations
   
6. Production Monitoring
   └─ A/B testing
   └─ Performance tracking
   └─ Continuous optimization
"""

print(workflow)

# Closing message
for closing_line in (
    "\n\n✅ You're now ready to optimize transformer models for production deployment!",
    "Remember: Measure twice, optimize once! 📊",
):
    print(closing_line)