# Inference at Scale: Interactive Tutorial

This notebook provides hands-on experience with serving transformer models at scale, including batching strategies, optimization techniques, and production deployment patterns.

## 📋 Learning Objectives

- **Understand** production serving architectures
- **Implement** different batching strategies
- **Apply** caching and memory optimization
- **Design** load balancing and auto-scaling
- **Deploy** models for edge and cloud environments

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import asyncio
import threading
from collections import defaultdict, deque
from typing import Dict, List, Optional, Tuple, Any
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('default')
sns.set_palette("husl")

print("✅ Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 1. Production Serving Fundamentals

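Let's start by tracking the key production metrics: latency percentiles (p50/p95/p99), throughput, memory and CPU usage, error rate, and SLA compliance. The traffic below is synthetic (a periodic load pattern plus random noise), so the numbers illustrate trends rather than real hardware behavior.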
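Mean latency alone can hide the tail that users actually notice. Below is a minimal sketch of the kind of summary the tracker computes, using NumPy's default linear-interpolation percentiles. `latency_report` is a hypothetical helper, and the 100 ms SLA mirrors the threshold used in the plotting code.

```python
import numpy as np


def latency_report(latencies_ms, sla_ms=100.0):
    """Summarize request latencies (ms) into mean, percentiles, and SLA compliance."""
    lat = np.asarray(latencies_ms, dtype=float)
    return {
        "mean": float(lat.mean()),
        "p50": float(np.percentile(lat, 50)),
        "p95": float(np.percentile(lat, 95)),
        "p99": float(np.percentile(lat, 99)),
        # Fraction of requests that finished within the SLA threshold
        "sla_compliance": float(np.mean(lat <= sla_ms)),
    }


# 99 fast requests plus one very slow one: the mean shifts, the median does not
report = latency_report([50.0] * 99 + [2000.0], sla_ms=100.0)
print(report)
```

With only 100 samples, p99 interpolates between the two largest values, so the single 2000 ms outlier pulls it only to about 69.5 ms while compliance drops to 99%. Tail percentiles need many samples before they are stable.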

In [None]:
class ProductionMetrics:
    """Track and visualize production serving metrics."""
    
    def __init__(self):
        self.metrics = {
            'latency': [],
            'throughput': [],
            'memory_usage': [],
            'cpu_usage': [],
            'error_rate': []
        }
        
    def record(self, metric_name: str, value: float):
        """Record a metric value."""
        if metric_name in self.metrics:
            self.metrics[metric_name].append(value)
            
    def get_summary(self):
        """Get summary statistics."""
        summary = {}
        for name, values in self.metrics.items():
            if values:
                summary[name] = {
                    'mean': np.mean(values),
                    'p50': np.percentile(values, 50),
                    'p95': np.percentile(values, 95),
                    'p99': np.percentile(values, 99)
                }
        return summary
        
    def plot_metrics(self):
        """Plot key metrics."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        
        # Latency histogram
        if self.metrics['latency']:
            axes[0, 0].hist(self.metrics['latency'], bins=30, alpha=0.7)
            axes[0, 0].set_title('Latency Distribution')
            axes[0, 0].set_xlabel('Latency (ms)')
            axes[0, 0].set_ylabel('Frequency')
            
        # Throughput over time
        if self.metrics['throughput']:
            axes[0, 1].plot(self.metrics['throughput'])
            axes[0, 1].set_title('Throughput Over Time')
            axes[0, 1].set_xlabel('Time Step')
            axes[0, 1].set_ylabel('Requests/sec')
            
        # Memory usage
        if self.metrics['memory_usage']:
            axes[0, 2].plot(self.metrics['memory_usage'], color='red')
            axes[0, 2].set_title('Memory Usage')
            axes[0, 2].set_xlabel('Time Step')
            axes[0, 2].set_ylabel('Memory (GB)')
            
        # CPU usage
        if self.metrics['cpu_usage']:
            axes[1, 0].plot(self.metrics['cpu_usage'], color='green')
            axes[1, 0].set_title('CPU Usage')
            axes[1, 0].set_xlabel('Time Step')
            axes[1, 0].set_ylabel('CPU %')
            
        # Error rate
        if self.metrics['error_rate']:
            axes[1, 1].plot(self.metrics['error_rate'], color='orange')
            axes[1, 1].set_title('Error Rate')
            axes[1, 1].set_xlabel('Time Step')
            axes[1, 1].set_ylabel('Error %')
            
        # SLA compliance
        if self.metrics['latency']:
            sla_threshold = 100  # 100ms SLA
            compliance = [1 if lat <= sla_threshold else 0 for lat in self.metrics['latency']]
            axes[1, 2].plot(np.cumsum(compliance) / np.arange(1, len(compliance) + 1) * 100)
            axes[1, 2].axhline(y=99.9, color='r', linestyle='--', label='99.9% SLA')
            axes[1, 2].set_title('SLA Compliance')
            axes[1, 2].set_xlabel('Request')
            axes[1, 2].set_ylabel('Compliance %')
            axes[1, 2].legend()
        
        plt.tight_layout()
        plt.show()

# Create metrics tracker
metrics = ProductionMetrics()

# Simulate production traffic
print("Simulating production traffic...")
for i in tqdm(range(1000)):
    # Simulate varying load
    base_latency = 50
    load_factor = 1 + np.sin(i / 100) * 0.5  # Periodic load
    
    latency = base_latency * load_factor + np.random.normal(0, 10)
    throughput = 1000 / max(latency, 1)  # Inverse relationship
    memory = 4 + np.random.normal(0, 0.5)  # 4GB baseline
    cpu = min(95, 30 + latency * 0.5 + np.random.normal(0, 5))
    error_rate = max(0, min(10, (latency - 100) * 0.1 + np.random.normal(0, 0.5)))
    
    metrics.record('latency', max(0, latency))
    metrics.record('throughput', max(0, throughput))
    metrics.record('memory_usage', max(0, memory))
    metrics.record('cpu_usage', max(0, cpu))
    metrics.record('error_rate', max(0, error_rate))

# Display summary
summary = metrics.get_summary()
print("\n📊 Production Metrics Summary:")
for metric, stats in summary.items():
    print(f"\n{metric.replace('_', ' ').title()}:")
    for stat, value in stats.items():
        print(f"  {stat}: {value:.2f}")

# Plot metrics
metrics.plot_metrics()

## 2. Batching Strategies

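Batching raises throughput by amortizing per-call overhead across many requests, but it costs latency (requests wait for a batch to form) and compute (shorter sequences are padded to the longest one in their batch). Let's compare static, dynamic, and length-bucketed batching on the same synthetic workload.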
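Padding cost is easy to quantify: every sequence in a batch is padded to the batch's longest sequence. A minimal sketch (the helper `padding_tokens` is hypothetical) comparing batches formed in arrival order with batches formed after sorting by length, which is the idea behind bucket batching:

```python
def padding_tokens(lengths, batch_size):
    """Count pad tokens added when each batch is padded to its longest sequence."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        total += max(batch) * len(batch) - sum(batch)
    return total


lengths = [10, 200, 15, 190]
arrival_order = padding_tokens(lengths, batch_size=2)          # batches (10, 200), (15, 190)
length_sorted = padding_tokens(sorted(lengths), batch_size=2)  # batches (10, 15), (190, 200)
print(arrival_order, length_sorted)  # 365 15
```

Sorting cuts padding sharply here, but in a live server it means reordering requests, so a short request may wait for enough similar-length requests to arrive.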

In [None]:
class BatchingComparison:
    """Compare different batching strategies."""
    
    def __init__(self):
        self.results = {}
        
    def static_batching(self, requests, batch_size=32):
        """Process requests in fixed-size batches of consecutive arrivals."""
        batches = []
        latencies = []
        
        for i in range(0, len(requests), batch_size):
            batch = requests[i:i + batch_size]
            
            # A static batch is dispatched only once its last request has arrived
            # (the final batch may be smaller than batch_size)
            dispatch_time = batch[-1]['arrival_time']
            
            # Simulate processing time: grows with padded length and batch size
            max_seq_len = max(req['length'] for req in batch)
            process_time = max_seq_len * 0.001 + len(batch) * 0.01
            
            # Latency = time waiting for the batch to fill + processing time
            for req in batch:
                wait_time = dispatch_time - req['arrival_time']
                latencies.append(wait_time + process_time)
                
            batches.append({
                'size': len(batch),
                'max_length': max_seq_len,
                'process_time': process_time,
                'padding': max_seq_len * len(batch) - sum(req['length'] for req in batch)
            })
            
        return batches, latencies
        
    def dynamic_batching(self, requests, max_batch_size=32, max_wait_ms=50):
        """Dispatch a batch when it is full or its oldest request has waited max_wait_ms."""
        batches = []
        latencies = []
        queue = []
        
        def flush(batch, dispatch_time):
            # Process one batch and record each request's latency
            max_seq_len = max(r['length'] for r in batch)
            process_time = max_seq_len * 0.001 + len(batch) * 0.01
            for r in batch:
                wait_time = dispatch_time - r['arrival_time']
                latencies.append(wait_time + process_time)
            batches.append({
                'size': len(batch),
                'max_length': max_seq_len,
                'process_time': process_time,
                'padding': max_seq_len * len(batch) - sum(r['length'] for r in batch)
            })
        
        for req in requests:
            queue.append(req)
            
            # Simplification: the timeout is only checked when a new request arrives
            should_process = (
                len(queue) >= max_batch_size or
                (req['arrival_time'] - queue[0]['arrival_time']) >= max_wait_ms
            )
            
            if should_process:
                flush(queue[:max_batch_size], dispatch_time=req['arrival_time'])
                queue = queue[max_batch_size:]
                
        # Flush whatever is still queued after the last arrival
        if queue:
            flush(queue, dispatch_time=requests[-1]['arrival_time'])
                
        return batches, latencies
        
    def bucket_batching(self, requests, bucket_sizes=(128, 256, 512), batch_size=32):
        """Group requests into sequence-length buckets, then batch within each bucket."""
        bucket_sizes = sorted(bucket_sizes)
        buckets = {size: [] for size in bucket_sizes}
        batches = []
        latencies = []
        
        # Put each request in the smallest bucket that fits; longer requests go
        # to the largest bucket (assumes they would be truncated to that length)
        for req in requests:
            bucket_size = next((b for b in bucket_sizes if b >= req['length']), bucket_sizes[-1])
            buckets[bucket_size].append(req)
            
        # Process each bucket in batches of batch_size, in arrival order
        for bucket_size, bucket_requests in buckets.items():
            for i in range(0, len(bucket_requests), batch_size):
                batch = bucket_requests[i:i + batch_size]
                
                # Every request in the batch is padded to the bucket size
                process_time = bucket_size * 0.001 + len(batch) * 0.01
                
                # As with static batching, dispatch when the batch's last request arrives
                dispatch_time = batch[-1]['arrival_time']
                for req in batch:
                    latencies.append(dispatch_time - req['arrival_time'] + process_time)
                    
                batches.append({
                    'size': len(batch),
                    'max_length': bucket_size,
                    'process_time': process_time,
                    'padding': sum(bucket_size - min(req['length'], bucket_size) for req in batch)
                })
                
        return batches, latencies
        
    def compare_strategies(self, num_requests=1000):
        """Compare all batching strategies on the same synthetic workload."""
        # Generate a synthetic request stream (times in ms)
        requests = []
        current_time = 0.0
        for i in range(num_requests):
            # Jittered inter-arrival gap: 50 ms plus an exponential term (mean 20 ms)
            current_time += 50 + np.random.exponential(20)
            length = int(np.random.lognormal(5, 1))  # Log-normal sequence lengths
            length = min(max(length, 10), 1000)  # Clamp to reasonable range
            
            requests.append({
                'id': i,
                'arrival_time': current_time,
                'length': length
            })
            
        # Test strategies
        strategies = {
            'Static': lambda: self.static_batching(requests, 32),
            'Dynamic': lambda: self.dynamic_batching(requests, 32, 50),
            'Bucket': lambda: self.bucket_batching(requests, [128, 256, 512, 1024])
        }
        
        results = {}
        for name, strategy in strategies.items():
            print(f"Testing {name} batching...")
            batches, latencies = strategy()
            
            # Compute-side throughput: requests served per second of simulated
            # processing time (assumes process_time is in ms)
            busy_time_s = sum(b['process_time'] for b in batches) / 1000
            
            results[name] = {
                'avg_latency': np.mean(latencies),
                'p95_latency': np.percentile(latencies, 95),
                'throughput': len(latencies) / busy_time_s if busy_time_s > 0 else 0,
                'num_batches': len(batches),
                'avg_batch_size': np.mean([b['size'] for b in batches]),
                'total_padding': sum(b['padding'] for b in batches),
                'latencies': latencies
            }
            
        return results, requests
        
    def visualize_comparison(self, results):
        """Visualize batching strategy comparison."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        strategies = list(results.keys())
        
        # Average latency
        latencies = [results[s]['avg_latency'] for s in strategies]
        bars1 = axes[0, 0].bar(strategies, latencies)
        axes[0, 0].set_title('Average Latency')
        axes[0, 0].set_ylabel('Latency (ms)')
        
        # P95 latency
        p95_latencies = [results[s]['p95_latency'] for s in strategies]
        bars2 = axes[0, 1].bar(strategies, p95_latencies, color='orange')
        axes[0, 1].set_title('P95 Latency')
        axes[0, 1].set_ylabel('Latency (ms)')
        
        # Throughput
        throughputs = [results[s]['throughput'] for s in strategies]
        bars3 = axes[0, 2].bar(strategies, throughputs, color='green')
        axes[0, 2].set_title('Throughput')
        axes[0, 2].set_ylabel('Requests/sec')
        
        # Batch size distribution
        avg_batch_sizes = [results[s]['avg_batch_size'] for s in strategies]
        bars4 = axes[1, 0].bar(strategies, avg_batch_sizes, color='red')
        axes[1, 0].set_title('Average Batch Size')
        axes[1, 0].set_ylabel('Batch Size')
        
        # Padding overhead
        paddings = [results[s]['total_padding'] for s in strategies]
        bars5 = axes[1, 1].bar(strategies, paddings, color='purple')
        axes[1, 1].set_title('Total Padding Tokens')
        axes[1, 1].set_ylabel('Padding Tokens')
        
        # Latency distribution
        for i, strategy in enumerate(strategies):
            axes[1, 2].hist(
                results[strategy]['latencies'][:100],  # Sample for visibility
                alpha=0.6,
                label=strategy,
                bins=20
            )
        axes[1, 2].set_title('Latency Distribution')
        axes[1, 2].set_xlabel('Latency (ms)')
        axes[1, 2].set_ylabel('Frequency')
        axes[1, 2].legend()
        
        plt.tight_layout()
        plt.show()
        
        # Print summary
        print("\n📊 Batching Strategy Comparison:")
        print("-" * 80)
        print(f"{'Strategy':<12} {'Avg Latency':<12} {'P95 Latency':<12} {'Throughput':<12} {'Avg Batch':<12}")
        print("-" * 80)
        
        for strategy in strategies:
            r = results[strategy]
            print(f"{strategy:<12} {r['avg_latency']:<12.2f} {r['p95_latency']:<12.2f} "
                  f"{r['throughput']:<12.2f} {r['avg_batch_size']:<12.2f}")

# Run batching comparison
comparison = BatchingComparison()
results, requests = comparison.compare_strategies(1000)
comparison.visualize_comparison(results)
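
The comparison above replays recorded arrival times offline. In a live server, dynamic batching usually runs as a loop that waits for the first request, then keeps collecting until the batch is full or a deadline passes. A minimal sketch of that loop with `asyncio`, assuming a synchronous batched function stands in for the model; `DynamicBatcher` and `double_all` are hypothetical names, not a library API.

```python
import asyncio
import time


class DynamicBatcher:
    """Hypothetical asyncio batcher: flush when full, or when max_wait_ms has
    passed since the first request of the current batch was taken."""

    def __init__(self, process_batch, max_batch_size=32, max_wait_ms=50):
        self.process_batch = process_batch  # list of inputs -> list of outputs
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_ms / 1000
        self.queue = asyncio.Queue()

    async def submit(self, item):
        """Enqueue one request and wait until its batch has been processed."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def run(self):
        """Form and process batches until this task is cancelled."""
        while True:
            item, future = await self.queue.get()  # block until the first request
            batch, futures = [item], [future]
            deadline = time.monotonic() + self.max_wait_s
            while len(batch) < self.max_batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    item, future = await asyncio.wait_for(self.queue.get(), remaining)
                except asyncio.TimeoutError:
                    break
                batch.append(item)
                futures.append(future)
            # Hand each caller its own output from the batched call
            for fut, output in zip(futures, self.process_batch(batch)):
                fut.set_result(output)


async def main():
    batch_sizes = []

    def double_all(xs):  # stand-in for a batched model forward pass
        batch_sizes.append(len(xs))
        return [2 * x for x in xs]

    batcher = DynamicBatcher(double_all, max_batch_size=4, max_wait_ms=20)
    worker = asyncio.create_task(batcher.run())
    outputs = await asyncio.gather(*(batcher.submit(i) for i in range(10)))
    worker.cancel()
    return outputs, batch_sizes


outputs, sizes = asyncio.run(main())
print(outputs)
print(sizes)
```

Inside Jupyter an event loop is already running, so `asyncio.run` raises an error there; use `outputs, sizes = await main()` in a notebook cell instead. A production batcher would also run the model off the event loop (for example in an executor) so that a long forward pass does not block new submissions.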

## 3. KV Cache Management

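In autoregressive generation, each new token attends to every earlier token. Without a cache, every step recomputes keys and values for the whole prefix; a KV cache stores them so each step only processes the new token. This reduces per-step attention work from O(L²) to O(L) for sequence length L, at the cost of cache memory that grows linearly with L.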
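The equivalence behind the cache can be checked directly. A minimal single-head sketch in NumPy, assuming fixed random hidden states instead of tokens fed back from a real decoder (the weights `W_q`, `W_k`, `W_v` and head size `d` are illustrative): decoding with a cache gives the same outputs as recomputing K and V for the whole prefix at every step.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16  # hypothetical head dimension
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))


def attend(q, k, v):
    """Single-head attention of one query (1, d) over keys/values (t, d)."""
    scores = q @ k.T / np.sqrt(d)             # (1, t)
    weights = np.exp(scores - scores.max())   # numerically stable softmax
    weights /= weights.sum()
    return weights @ v                        # (1, d)


def decode_without_cache(xs):
    """Re-project K and V for the whole prefix at every step."""
    outs = []
    for t in range(1, len(xs) + 1):
        prefix = xs[:t]
        outs.append(attend(xs[t - 1:t] @ W_q, prefix @ W_k, prefix @ W_v))
    return np.concatenate(outs)


def decode_with_cache(xs):
    """Project only the newest token and append its K/V to the cache."""
    k_cache = np.empty((0, d))
    v_cache = np.empty((0, d))
    outs = []
    for t in range(len(xs)):
        x = xs[t:t + 1]
        k_cache = np.vstack([k_cache, x @ W_k])
        v_cache = np.vstack([v_cache, x @ W_v])
        outs.append(attend(x @ W_q, k_cache, v_cache))
    return np.concatenate(outs)


xs = rng.standard_normal((8, d))  # stand-in hidden states for 8 tokens
print(np.allclose(decode_without_cache(xs), decode_with_cache(xs)))  # True
```

`np.vstack` copies the cache on every step to keep the sketch short; real implementations usually preallocate the cache up to a maximum length (or manage it in fixed-size blocks) and write into it in place.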

In [None]:
class KVCacheDemo:
    """Demonstrate KV cache efficiency."""
    
    def __init__(self):
        self.cache_stats = defaultdict(list)
        
    def simulate_generation_without_cache(self, seq_length=100, vocab_size=30000):
        """Simulate generation without KV cache."""
        hidden_size = 768
        num_heads = 12
        head_dim = hidden_size // num_heads
        
        # Start with prompt
        input_ids = torch.randint(0, vocab_size, (1, 10))
        
        total_flops = 0
        memory_usage = []
        
        for step in range(seq_length):
            current_length = 10 + step
            
            # Without cache: recompute everything each time
            # Attention computation: O(L²) where L is sequence length
            attention_flops = current_length ** 2 * hidden_size
            total_flops += attention_flops
            
            # Memory: store full attention matrices
            attention_memory = current_length ** 2 * num_heads * 4  # bytes
            memory_usage.append(attention_memory / 1e6)  # MB
            
            self.cache_stats['without_cache_flops'].append(attention_flops)
            
        return total_flops, memory_usage
        
    def simulate_generation_with_cache(self, seq_length=100, vocab_size=30000):
        """Simulate generation with KV cache."""
        hidden_size = 768
        num_heads = 12
        head_dim = hidden_size // num_heads
        
        # Start with prompt
        input_ids = torch.randint(0, vocab_size, (1, 10))
        prompt_length = 10
        
        total_flops = 0
        memory_usage = []
        
        # Initial computation for prompt
        initial_flops = prompt_length ** 2 * hidden_size
        total_flops += initial_flops
        
        # Cache memory: store K,V for all positions
        cache_memory_per_token = num_heads * head_dim * 2 * 4  # K+V, float32
        
        for step in range(seq_length):
            current_length = prompt_length + step
            
            # With cache: only compute attention for new token
            # New token attends to all previous tokens: O(L)
            attention_flops = current_length * hidden_size
            total_flops += attention_flops
            
            # Memory: cache grows linearly
            cache_memory = current_length * cache_memory_per_token
            memory_usage.append(cache_memory / 1e6)  # MB
            
            self.cache_stats['with_cache_flops'].append(attention_flops)
            
        return total_flops, memory_usage
        
    def compare_cache_strategies(self):
        """Compare generation with and without cache."""
        seq_lengths = [50, 100, 200, 500]
        
        results = {
            'without_cache': {'flops': [], 'memory': [], 'time': []},
            'with_cache': {'flops': [], 'memory': [], 'time': []}
        }
        
        for seq_len in seq_lengths:
            print(f"Testing sequence length: {seq_len}")
            
            # Without cache
            start_time = time.time()
            flops_no_cache, memory_no_cache = self.simulate_generation_without_cache(seq_len)
            time_no_cache = time.time() - start_time
            
            # With cache
            start_time = time.time()
            flops_with_cache, memory_with_cache = self.simulate_generation_with_cache(seq_len)
            time_with_cache = time.time() - start_time
            
            results['without_cache']['flops'].append(flops_no_cache)
            results['without_cache']['memory'].append(max(memory_no_cache))
            results['without_cache']['time'].append(time_no_cache)
            
            results['with_cache']['flops'].append(flops_with_cache)
            results['with_cache']['memory'].append(max(memory_with_cache))
            results['with_cache']['time'].append(time_with_cache)
            
        return results, seq_lengths
        
    def visualize_cache_efficiency(self, results, seq_lengths):
        """Visualize KV cache benefits."""
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        
        # FLOPS comparison
        axes[0, 0].plot(seq_lengths, results['without_cache']['flops'], 
                       'ro-', label='Without Cache', linewidth=2)
        axes[0, 0].plot(seq_lengths, results['with_cache']['flops'], 
                       'go-', label='With Cache', linewidth=2)
        axes[0, 0].set_xlabel('Sequence Length')
        axes[0, 0].set_ylabel('Total FLOPs')
        axes[0, 0].set_title('Computation Complexity')
        axes[0, 0].legend()
        axes[0, 0].set_yscale('log')
        
        # Memory usage
        axes[0, 1].plot(seq_lengths, results['without_cache']['memory'], 
                       'ro-', label='Without Cache', linewidth=2)
        axes[0, 1].plot(seq_lengths, results['with_cache']['memory'], 
                       'go-', label='With Cache', linewidth=2)
        axes[0, 1].set_xlabel('Sequence Length')
        axes[0, 1].set_ylabel('Peak Memory (MB)')
        axes[0, 1].set_title('Memory Usage')
        axes[0, 1].legend()
        
        # Speedup
        speedups = [no_cache / with_cache for no_cache, with_cache in 
                   zip(results['without_cache']['flops'], results['with_cache']['flops'])]
        axes[1, 0].bar(range(len(seq_lengths)), speedups, color='blue', alpha=0.7)
        axes[1, 0].set_xlabel('Sequence Length')
        axes[1, 0].set_ylabel('Speedup Factor')
        axes[1, 0].set_title('KV Cache Speedup (FLOP ratio)')
        axes[1, 0].set_xticks(range(len(seq_lengths)))
        axes[1, 0].set_xticklabels(seq_lengths)
        
        # Per-step FLOPS comparison
        if self.cache_stats['without_cache_flops'] and self.cache_stats['with_cache_flops']:
            steps = range(min(50, len(self.cache_stats['without_cache_flops'])))
            axes[1, 1].plot(steps, self.cache_stats['without_cache_flops'][:50], 
                           'ro-', label='Without Cache', alpha=0.7)
            axes[1, 1].plot(steps, self.cache_stats['with_cache_flops'][:50], 
                           'go-', label='With Cache', alpha=0.7)
            axes[1, 1].set_xlabel('Generation Step')
            axes[1, 1].set_ylabel('FLOPs per Step')
            axes[1, 1].set_title('Per-Step Computation')
            axes[1, 1].legend()
            axes[1, 1].set_yscale('log')
        
        plt.tight_layout()
        plt.show()
        
        # Print efficiency summary
        print("\n🚀 KV Cache Efficiency Summary:")
        print("-" * 60)
        print(f"{'Seq Length':<12} {'Speedup':<12} {'Memory Ratio':<15} {'FLOP Ratio':<12}")
        print("-" * 60)
        
        for i, seq_len in enumerate(seq_lengths):
            speedup = results['without_cache']['flops'][i] / results['with_cache']['flops'][i]
            memory_ratio = results['with_cache']['memory'][i] / results['without_cache']['memory'][i]
            flop_ratio = results['with_cache']['flops'][i] / results['without_cache']['flops'][i]
            
            print(f"{seq_len:<12} {speedup:<12.2f} {memory_ratio:<15.2f} {flop_ratio:<12.4f}")

# Run KV cache demo
cache_demo = KVCacheDemo()
results, seq_lengths = cache_demo.compare_cache_strategies()
cache_demo.visualize_cache_efficiency(results, seq_lengths)
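
The demo above counts a single layer and stores the cache in float32 (4 bytes per value). For a whole model, a common back-of-the-envelope estimate multiplies K and V storage across every layer and KV head: 2 × layers × kv_heads × head_dim × seq_len × batch × bytes per element. A sketch, where `kv_cache_bytes` is a hypothetical helper and the model shape (12 layers, 12 heads of 64 dims, fp16 cache) is an assumed example matching the hidden size of 768 used above:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """Bytes needed to cache keys and values for every layer, KV head, and position."""
    # Leading factor 2: one tensor for keys, one for values
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Hypothetical 12-layer model, 12 heads x 64 dims (hidden size 768), fp16 cache
per_token = kv_cache_bytes(12, 12, 64, seq_len=1)
full = kv_cache_bytes(12, 12, 64, seq_len=2048, batch_size=8)
print(per_token)   # 36864 bytes, i.e. 36 KiB per token
print(full / 1e9)  # 0.603979776, about 0.6 GB
```

Because the cache grows with both sequence length and batch size, it can limit how many concurrent sequences fit on a device; techniques such as grouped-query attention shrink it by lowering `num_kv_heads`.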

## 4. Load Balancing and Auto-scaling

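Let's compare strategies for routing requests across a pool of servers with different capacities. Auto-scaling, which changes the size of the pool itself, follows in the next section.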
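Plain round robin ignores the servers' different capacities. One common capacity-aware alternative, not among the strategies compared below, is smooth weighted round robin, the scheme nginx uses for weighted upstream servers: every server earns credit equal to its weight on each pick, the server with the most credit wins, and the winner pays back the total weight. A minimal sketch (`smooth_weighted_round_robin` is a hypothetical helper):

```python
def smooth_weighted_round_robin(weights, num_picks):
    """Return the sequence of server indices chosen by smooth weighted round robin."""
    current = [0] * len(weights)
    total = sum(weights)
    picks = []
    for _ in range(num_picks):
        # Every server earns its weight in credit
        current = [c + w for c, w in zip(current, weights)]
        # The server with the most credit is chosen (first index wins ties)
        chosen = max(range(len(weights)), key=lambda i: current[i])
        current[chosen] -= total
        picks.append(chosen)
    return picks


print(smooth_weighted_round_robin([5, 1, 1], 7))  # [0, 0, 1, 0, 2, 0, 0]
```

The heavy server gets five of every seven picks, but its turns are interleaved with the others rather than sent in a burst. Using the capacities from the cell below as weights would route traffic in proportion to capacity without tracking load; unlike the load-aware strategies, it cannot react when a server runs slower than its weight suggests.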

In [None]:
class LoadBalancingDemo:
    """Demonstrate load balancing strategies."""
    
    def __init__(self, num_servers=4):
        self.num_servers = num_servers
        self.server_loads = [0] * num_servers
        self.server_latencies = [[] for _ in range(num_servers)]
        self.server_capacities = [100 + i * 20 for i in range(num_servers)]  # Varying capacity
        self.request_history = []
        
    def round_robin(self, request_id):
        """Simple round-robin load balancing."""
        server_id = request_id % self.num_servers
        return server_id
        
    def least_loaded(self, request_id):
        """Route to least loaded server."""
        server_id = self.server_loads.index(min(self.server_loads))
        return server_id
        
    def weighted_least_loaded(self, request_id):
        """Route based on load and capacity."""
        # Calculate load ratios
        load_ratios = [
            load / capacity 
            for load, capacity in zip(self.server_loads, self.server_capacities)
        ]
        server_id = load_ratios.index(min(load_ratios))
        return server_id
        
    def latency_aware(self, request_id):
        """Route based on recent latencies."""
        avg_latencies = []
        for latencies in self.server_latencies:
            if latencies:
                avg = sum(latencies[-10:]) / min(len(latencies), 10)
            else:
                avg = 0
            avg_latencies.append(avg)
            
        server_id = avg_latencies.index(min(avg_latencies))
        return server_id
        
    def simulate_request(self, server_id, request_size=1):
        """Queue a request on a server and return its simulated latency."""
        # Add the request's work to the server's backlog
        self.server_loads[server_id] += request_size
        
        # Simulate latency based on backlog relative to capacity
        load_factor = self.server_loads[server_id] / self.server_capacities[server_id]
        base_latency = 50  # ms
        latency = base_latency * (1 + load_factor ** 2) + np.random.normal(0, 5)
        latency = max(10, latency)  # Minimum latency
        
        # Record latency
        self.server_latencies[server_id].append(latency)
        
        return latency
        
    def drain(self, service_rate=0.003):
        """Let every server finish part of its backlog between arrivals.
        
        Assumption: a server completes capacity * service_rate units of work per
        request tick (about 1.56 units/tick in total vs. about 1.5 arriving).
        Removing the load right after routing would leave every load at 0, so
        the load-aware strategies would always pick server 0.
        """
        self.server_loads = [
            max(0.0, load - capacity * service_rate)
            for load, capacity in zip(self.server_loads, self.server_capacities)
        ]
        
    def test_strategy(self, strategy_name, strategy_func, num_requests=1000):
        """Test a load balancing strategy."""
        # Reset state
        self.server_loads = [0] * self.num_servers
        self.server_latencies = [[] for _ in range(self.num_servers)]
        
        latencies = []
        server_assignments = []
        
        for i in range(num_requests):
            # Generate request with varying size
            request_size = 1 + np.random.poisson(0.5)  # 1-3 units typically
            
            # Route request
            server_id = strategy_func(i)
            server_assignments.append(server_id)
            
            # Process request
            latency = self.simulate_request(server_id, request_size)
            latencies.append(latency)
            
            # Servers work through part of their backlog before the next arrival
            self.drain()
            
        return {
            'strategy': strategy_name,
            'latencies': latencies,
            'server_assignments': server_assignments,
            'avg_latency': np.mean(latencies),
            'p95_latency': np.percentile(latencies, 95),
            # Share of requests routed to each server (not CPU utilization)
            'server_utilization': [
                server_assignments.count(i) / num_requests
                for i in range(self.num_servers)
            ]
        }
        
    def compare_strategies(self):
        """Compare all load balancing strategies."""
        strategies = {
            'Round Robin': self.round_robin,
            'Least Loaded': self.least_loaded,
            'Weighted Least Loaded': self.weighted_least_loaded,
            'Latency Aware': self.latency_aware
        }
        
        results = {}
        for name, func in strategies.items():
            print(f"Testing {name} strategy...")
            results[name] = self.test_strategy(name, func)
            
        return results
        
    def visualize_load_balancing(self, results):
        """Visualize load balancing results."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        strategies = list(results.keys())
        
        # Average latency
        avg_latencies = [results[s]['avg_latency'] for s in strategies]
        bars1 = axes[0, 0].bar(strategies, avg_latencies, color='skyblue')
        axes[0, 0].set_title('Average Latency by Strategy')
        axes[0, 0].set_ylabel('Latency (ms)')
        axes[0, 0].tick_params(axis='x', rotation=45)
        
        # P95 latency
        p95_latencies = [results[s]['p95_latency'] for s in strategies]
        bars2 = axes[0, 1].bar(strategies, p95_latencies, color='lightcoral')
        axes[0, 1].set_title('P95 Latency by Strategy')
        axes[0, 1].set_ylabel('Latency (ms)')
        axes[0, 1].tick_params(axis='x', rotation=45)
        
        # Latency distribution
        for strategy in strategies:
            latencies = results[strategy]['latencies'][:200]  # Sample for visibility
            axes[0, 2].hist(latencies, alpha=0.6, label=strategy, bins=20)
        axes[0, 2].set_title('Latency Distribution')
        axes[0, 2].set_xlabel('Latency (ms)')
        axes[0, 2].set_ylabel('Frequency')
        axes[0, 2].legend()
        
        # Share of requests per server (heatmap)
        utilization_matrix = np.array([
            results[s]['server_utilization'] for s in strategies
        ])
        
        im = axes[1, 0].imshow(utilization_matrix, cmap='YlOrRd', aspect='auto')
        axes[1, 0].set_title('Share of Requests per Server')
        axes[1, 0].set_xlabel('Server ID')
        axes[1, 0].set_ylabel('Strategy')
        axes[1, 0].set_yticks(range(len(strategies)))
        axes[1, 0].set_yticklabels(strategies)
        axes[1, 0].set_xticks(range(self.num_servers))
        plt.colorbar(im, ax=axes[1, 0], label='Share of Requests')
        
        # Server capacity vs share of requests
        for i, strategy in enumerate(strategies):
            utilization = results[strategy]['server_utilization']
            axes[1, 1].scatter(
                self.server_capacities, utilization, 
                label=strategy, s=60, alpha=0.7
            )
        axes[1, 1].set_title('Capacity vs Request Share')
        axes[1, 1].set_xlabel('Server Capacity')
        axes[1, 1].set_ylabel('Share of Requests')
        axes[1, 1].legend()
        
        # Efficiency score (lower latency + balanced utilization)
        efficiency_scores = []
        for strategy in strategies:
            r = results[strategy]
            # Normalize metrics
            latency_score = 1 / (r['avg_latency'] / min(avg_latencies))
            balance_score = 1 / (np.std(r['server_utilization']) + 0.01)
            efficiency = latency_score * balance_score
            efficiency_scores.append(efficiency)
            
        bars6 = axes[1, 2].bar(strategies, efficiency_scores, color='lightgreen')
        axes[1, 2].set_title('Overall Efficiency')
        axes[1, 2].set_ylabel('Efficiency Score')
        axes[1, 2].tick_params(axis='x', rotation=45)
        
        plt.tight_layout()
        plt.show()
        
        # Print summary
        print("\n⚖️ Load Balancing Strategy Comparison:")
        print("-" * 90)
        print(f"{'Strategy':<20} {'Avg Latency':<12} {'P95 Latency':<12} {'Std Dev Share':<15} {'Busiest Server':<14}")
        print("-" * 90)
        
        for strategy in strategies:
            r = results[strategy]
            std_share = np.std(r['server_utilization'])
            busiest_server = np.argmax(r['server_utilization'])
            
            print(f"{strategy:<20} {r['avg_latency']:<12.2f} {r['p95_latency']:<12.2f} "
                  f"{std_share:<15.3f} {busiest_server:<14}")

# Run load balancing demo
lb_demo = LoadBalancingDemo(num_servers=4)
lb_results = lb_demo.compare_strategies()
lb_demo.visualize_load_balancing(lb_results)

## 5. Auto-scaling Simulation

Let's see how auto-scaling responds to changing load patterns.
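The simulation below moves one instance at a time based on fixed CPU and latency thresholds. Kubernetes' Horizontal Pod Autoscaler instead computes a target count in one step, desiredReplicas = ceil(currentReplicas × currentMetric / targetMetric), and skips the change when the ratio is within a tolerance of 1.0 (0.1 by default). A minimal sketch of that rule, with `desired_replicas` as a hypothetical helper and the min/max bounds borrowed from the simulation's defaults:

```python
import math


def desired_replicas(current, metric, target, min_replicas=1, max_replicas=10, tolerance=0.1):
    """Target-tracking scale decision: resize so the per-replica metric nears target."""
    ratio = metric / target
    if abs(ratio - 1.0) <= tolerance:
        return current  # close enough to target: avoid flapping
    desired = math.ceil(current * ratio)
    return max(min_replicas, min(max_replicas, desired))


print(desired_replicas(4, metric=90, target=60))   # 6
print(desired_replicas(4, metric=20, target=60))   # 2
print(desired_replicas(4, metric=63, target=60))   # 4 (within tolerance)
print(desired_replicas(8, metric=100, target=50))  # 10 (clamped to max)
```

Jumping straight to the target count reacts faster to spikes than single-step scaling, while the tolerance band plays the same anti-flapping role as the gap between the scale-up and scale-down thresholds below.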

In [None]:
class AutoScalingSimulation:
    """Simulate auto-scaling behavior."""
    
    def __init__(self, initial_instances=2, min_instances=1, max_instances=10):
        self.current_instances = initial_instances
        self.min_instances = min_instances
        self.max_instances = max_instances
        
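        # Scaling thresholds
        self.scale_up_cpu = 70
        self.scale_down_cpu = 30
        self.scale_up_latency = 100  # ms
        # The latency model below rarely drops under ~30 ms, so a 20 ms
        # scale-down threshold would almost never fire; 40 ms is reachable
        self.scale_down_latency = 40  # ms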
        
        # History tracking
        self.history = {
            'instances': [],
            'cpu_usage': [],
            'latency': [],
            'throughput': [],
            'cost': [],
            'scaling_events': []
        }
        
    def generate_traffic_pattern(self, time_steps=288):  # 24 hours in 5-min intervals
        """Generate realistic traffic pattern."""
        traffic = []
        
        for t in range(time_steps):
            hour = (t * 5 / 60) % 24  # Convert to hour of day
            
            # Base traffic pattern (higher during business hours)
            if 9 <= hour < 17:  # Business hours
                base_traffic = 80 + 20 * np.sin((hour - 9) / 8 * np.pi)
            elif 17 <= hour < 22:  # Evening
                base_traffic = 60 + 10 * np.sin((hour - 17) / 5 * np.pi)
            else:  # Night/early morning
                base_traffic = 20 + 10 * np.random.normal(0, 1)
                
            # Add weekly pattern (lower on weekends); this only takes effect
            # when time_steps spans more than one day (288 steps per day)
            day_of_week = (t // 288) % 7
            if day_of_week >= 5:  # Weekend
                base_traffic *= 0.6
                
            # Add random spikes
            if np.random.random() < 0.05:  # 5% chance of spike
                base_traffic *= (2 + np.random.random())
                
            # Add noise
            traffic.append(max(10, base_traffic + np.random.normal(0, 5)))
            
        return traffic
        
    def calculate_metrics(self, traffic_load):
        """Calculate system metrics based on traffic and instances."""
        # CPU usage (inversely related to instance count)
        cpu_per_instance = traffic_load / self.current_instances
        cpu_usage = min(100, cpu_per_instance)
        
        # Latency (increases with CPU usage)
        if cpu_usage < 50:
            latency = 30 + cpu_usage * 0.5
        elif cpu_usage < 80:
            latency = 55 + (cpu_usage - 50) * 1.5
        else:
            latency = 100 + (cpu_usage - 80) * 3  # Steepest linear segment near saturation
            
        # Add random variation
        latency += np.random.normal(0, 5)
        latency = max(10, latency)
        
        # Throughput (decreases with high latency)
        throughput = min(1000, traffic_load * self.current_instances / max(latency, 30) * 30)
        
        # Cost (linear with instances)
        cost_per_instance_per_hour = 2.0  # $2/hour per instance
        cost = self.current_instances * cost_per_instance_per_hour / 12  # 5-min interval
        
        return cpu_usage, latency, throughput, cost
        
    def make_scaling_decision(self, cpu_usage, latency):
        """Decide whether to scale up, down, or stay the same."""
        # Scale up conditions
        if (cpu_usage > self.scale_up_cpu or latency > self.scale_up_latency) and \
           self.current_instances < self.max_instances:
            return 1  # Scale up
            
        # Scale down conditions (more conservative)
        if cpu_usage < self.scale_down_cpu and latency < self.scale_down_latency and \
           self.current_instances > self.min_instances:
            return -1  # Scale down
            
        return 0  # No scaling
        
    def run_simulation(self):
        """Run the auto-scaling simulation."""
        traffic_pattern = self.generate_traffic_pattern()
        
        for t, traffic_load in enumerate(traffic_pattern):
            # Calculate current metrics
            cpu_usage, latency, throughput, cost = self.calculate_metrics(traffic_load)
            
            # Record metrics
            self.history['instances'].append(self.current_instances)
            self.history['cpu_usage'].append(cpu_usage)
            self.history['latency'].append(latency)
            self.history['throughput'].append(throughput)
            self.history['cost'].append(cost)
            
            # Make scaling decision
            scaling_decision = self.make_scaling_decision(cpu_usage, latency)
            
            if scaling_decision != 0:
                old_instances = self.current_instances
                self.current_instances += scaling_decision
                self.current_instances = max(self.min_instances, 
                                           min(self.max_instances, self.current_instances))
                
                self.history['scaling_events'].append({
                    'time': t,
                    'action': 'scale_up' if scaling_decision > 0 else 'scale_down',
                    'from': old_instances,
                    'to': self.current_instances,
                    'trigger': f"CPU: {cpu_usage:.1f}%, Latency: {latency:.1f}ms"
                })
                
        return self.history
        
    def visualize_simulation(self):
        """Visualize the auto-scaling simulation results."""
        time_hours = np.arange(len(self.history['instances'])) * 5 / 60  # Convert to hours
        
        fig, axes = plt.subplots(3, 2, figsize=(16, 14))
        
        # Instance count over time
        axes[0, 0].plot(time_hours, self.history['instances'], 'bo-', linewidth=2, markersize=3)
        axes[0, 0].set_title('Instance Count Over Time')
        axes[0, 0].set_xlabel('Time (hours)')
        axes[0, 0].set_ylabel('Number of Instances')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Mark scaling events
        for event in self.history['scaling_events']:
            time_point = event['time'] * 5 / 60
            color = 'green' if event['action'] == 'scale_up' else 'red'
            axes[0, 0].axvline(x=time_point, color=color, alpha=0.5, linestyle='--')
            
        # CPU usage
        axes[0, 1].plot(time_hours, self.history['cpu_usage'], color='orange', linewidth=1)
        axes[0, 1].axhline(y=self.scale_up_cpu, color='red', linestyle='--', label='Scale Up Threshold')
        axes[0, 1].axhline(y=self.scale_down_cpu, color='green', linestyle='--', label='Scale Down Threshold')
        axes[0, 1].set_title('CPU Usage')
        axes[0, 1].set_xlabel('Time (hours)')
        axes[0, 1].set_ylabel('CPU Usage (%)')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # Latency
        axes[1, 0].plot(time_hours, self.history['latency'], color='red', linewidth=1)
        axes[1, 0].axhline(y=self.scale_up_latency, color='red', linestyle='--', label='Scale Up Threshold')
        axes[1, 0].axhline(y=self.scale_down_latency, color='green', linestyle='--', label='Scale Down Threshold')
        axes[1, 0].set_title('Latency')
        axes[1, 0].set_xlabel('Time (hours)')
        axes[1, 0].set_ylabel('Latency (ms)')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # Throughput
        axes[1, 1].plot(time_hours, self.history['throughput'], color='green', linewidth=1)
        axes[1, 1].set_title('Throughput')
        axes[1, 1].set_xlabel('Time (hours)')
        axes[1, 1].set_ylabel('Requests/sec')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Cumulative cost
        cumulative_cost = np.cumsum(self.history['cost'])
        axes[2, 0].plot(time_hours, cumulative_cost, color='purple', linewidth=2)
        axes[2, 0].set_title('Cumulative Cost')
        axes[2, 0].set_xlabel('Time (hours)')
        axes[2, 0].set_ylabel('Cost ($)')
        axes[2, 0].grid(True, alpha=0.3)
        
        # Scaling events timeline
        scale_up_times = []
        scale_down_times = []
        
        for event in self.history['scaling_events']:
            time_point = event['time'] * 5 / 60
            if event['action'] == 'scale_up':
                scale_up_times.append(time_point)
            else:
                scale_down_times.append(time_point)
                
        if scale_up_times:
            axes[2, 1].scatter(scale_up_times, [1] * len(scale_up_times), 
                             color='green', s=50, label='Scale Up', alpha=0.7)
        if scale_down_times:
            axes[2, 1].scatter(scale_down_times, [0] * len(scale_down_times), 
                             color='red', s=50, label='Scale Down', alpha=0.7)
            
        axes[2, 1].set_title('Scaling Events')
        axes[2, 1].set_xlabel('Time (hours)')
        axes[2, 1].set_ylabel('Event Type')
        axes[2, 1].set_yticks([0, 1])
        axes[2, 1].set_yticklabels(['Scale Down', 'Scale Up'])
        axes[2, 1].legend()
        axes[2, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print summary statistics
        total_cost = sum(self.history['cost'])
        avg_latency = np.mean(self.history['latency'])
        p95_latency = np.percentile(self.history['latency'], 95)
        avg_cpu = np.mean(self.history['cpu_usage'])
        num_scale_events = len(self.history['scaling_events'])
        
        print("\n📈 Auto-scaling Simulation Summary:")
        print("-" * 50)
        print(f"Total simulation time: {len(self.history['instances']) * 5 / 60:.1f} hours")
        print(f"Total cost: ${total_cost:.2f}")
        print(f"Average latency: {avg_latency:.2f} ms")
        print(f"P95 latency: {p95_latency:.2f} ms")
        print(f"Average CPU usage: {avg_cpu:.1f}%")
        print(f"Number of scaling events: {num_scale_events}")
        print(f"Instance range: {min(self.history['instances'])} - {max(self.history['instances'])}")
        
        # Show recent scaling events
        if self.history['scaling_events']:
            print("\n🔄 Recent Scaling Events:")
            for event in self.history['scaling_events'][-5:]:  # Last 5 events
                time_str = f"{event['time'] * 5 / 60:.1f}h"
                print(f"  {time_str}: {event['action']} ({event['from']} → {event['to']}) - {event['trigger']}")

# Run auto-scaling simulation
print("Running auto-scaling simulation...")
scaling_sim = AutoScalingSimulation(initial_instances=2, min_instances=1, max_instances=8)
history = scaling_sim.run_simulation()
scaling_sim.visualize_simulation()
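
Because `make_scaling_decision` re-evaluates every 5-minute interval, a noisy signal can make the instance count flap up and down. A common mitigation is a cooldown window: after any scaling action, ignore further signals for a few intervals. The sketch below is a minimal illustration, assuming a CPU-only trigger and a cooldown of three intervals (15 minutes); the `CooldownScaler` class and its fields are hypothetical, not part of the simulation above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CooldownScaler:
    """Hypothetical CPU-threshold scaler with a cooldown window (illustrative only)."""
    min_instances: int = 1
    max_instances: int = 8
    scale_up_cpu: float = 70.0
    scale_down_cpu: float = 30.0
    cooldown_steps: int = 3            # assumed: 3 intervals = 15 minutes at 5-minute steps
    current: int = 2
    last_change: Optional[int] = None  # interval index of the most recent scaling action

    def step(self, t: int, cpu_usage: float) -> int:
        """Evaluate interval t and return the instance count afterwards."""
        # Hold steady while still inside the cooldown window after a change.
        if self.last_change is not None and t - self.last_change < self.cooldown_steps:
            return self.current
        target = self.current
        if cpu_usage > self.scale_up_cpu:
            target = min(self.max_instances, self.current + 1)
        elif cpu_usage < self.scale_down_cpu:
            target = max(self.min_instances, self.current - 1)
        if target != self.current:
            self.current = target
            self.last_change = t
        return self.current


scaler = CooldownScaler()
cpu_trace = [90, 90, 90, 90, 20, 20, 20, 20]
print([scaler.step(t, cpu) for t, cpu in enumerate(cpu_trace)])  # [3, 3, 3, 4, 4, 4, 3, 3]
```

With `cooldown_steps=0` the same trace climbs to 6 instances and falls back to 2; the window trades reaction speed for stability.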

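The simulated scaler changes the fleet by at most one instance per interval, so a large spike takes several intervals to absorb. An alternative is proportional (target-tracking) scaling. The Kubernetes Horizontal Pod Autoscaler, for example, computes desired replicas as ceil(current replicas × observed metric / target metric). A minimal sketch, assuming a single CPU-style metric and the notebook's 1 to 8 instance bounds; `desired_instances` is a hypothetical helper, and real autoscalers add tolerances and stabilization windows that are omitted here.

```python
import math


def desired_instances(current: int, metric: float, target: float,
                      min_instances: int = 1, max_instances: int = 8) -> int:
    """Proportional (target-tracking) scaling, clamped to [min_instances, max_instances]."""
    if target <= 0:
        raise ValueError("target must be positive")
    # Resize the fleet by the ratio of the observed metric to its target.
    raw = math.ceil(current * metric / target)
    return max(min_instances, min(max_instances, raw))


print(desired_instances(2, metric=100, target=60))  # 4: jumps two instances in one step
print(desired_instances(4, metric=20, target=60))   # 2: shrinks when load is well below target
```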
## 6. Edge Deployment Optimization

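Let's explore optimizations specifically for edge deployment scenarios. Edge devices impose hard limits on model size, memory, and latency; the analysis below compares simulated compression techniques (quantization, pruning, distillation, and combinations) against the budgets of four device classes and recommends the best-fitting option for each.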

In [None]:
class EdgeOptimizationDemo:
    """Demonstrate edge deployment optimizations."""
    
    def __init__(self):
        self.optimization_results = {}
        
    def simulate_model_sizes(self):
        """Simulate different model compression techniques."""
        # Base model stats
        base_model = {
            'size_mb': 1200,  # 1.2GB model
            'latency_ms': 150,
            'accuracy': 95.5,
            'memory_mb': 2400
        }
        
        optimizations = {
            'Original': {
                'size_mb': base_model['size_mb'],
                'latency_ms': base_model['latency_ms'],
                'accuracy': base_model['accuracy'],
                'memory_mb': base_model['memory_mb']
            },
            'Quantization (INT8)': {
                'size_mb': base_model['size_mb'] * 0.25,  # 4x smaller
                'latency_ms': base_model['latency_ms'] * 0.7,  # 30% faster
                'accuracy': base_model['accuracy'] - 0.5,  # Slight accuracy drop
                'memory_mb': base_model['memory_mb'] * 0.4
            },
            'Pruning (50%)': {
                'size_mb': base_model['size_mb'] * 0.5,  # 50% smaller
                'latency_ms': base_model['latency_ms'] * 0.6,  # 40% faster
                'accuracy': base_model['accuracy'] - 1.0,  # Moderate accuracy drop
                'memory_mb': base_model['memory_mb'] * 0.6
            },
            'Distillation': {
                'size_mb': base_model['size_mb'] * 0.3,  # 70% smaller
                'latency_ms': base_model['latency_ms'] * 0.4,  # 60% faster
                'accuracy': base_model['accuracy'] - 2.0,  # Larger accuracy drop
                'memory_mb': base_model['memory_mb'] * 0.5
            },
            'Quantization + Pruning': {
                'size_mb': base_model['size_mb'] * 0.125,  # 8x smaller
                'latency_ms': base_model['latency_ms'] * 0.45,  # 55% faster
                'accuracy': base_model['accuracy'] - 1.8,  # Combined accuracy drop
                'memory_mb': base_model['memory_mb'] * 0.3
            },
            'Mobile Optimized': {
                'size_mb': base_model['size_mb'] * 0.08,  # 12.5x smaller
                'latency_ms': base_model['latency_ms'] * 0.3,  # 70% faster
                'accuracy': base_model['accuracy'] - 3.5,  # Significant accuracy drop
                'memory_mb': base_model['memory_mb'] * 0.2
            }
        }
        
        return optimizations
        
    def simulate_device_constraints(self):
        """Simulate different edge device capabilities."""
        devices = {
            'High-End Mobile': {
                'max_memory_mb': 8000,
                'max_model_size_mb': 500,
                'target_latency_ms': 100,
                'power_budget_w': 5
            },
            'Mid-Range Mobile': {
                'max_memory_mb': 4000,
                'max_model_size_mb': 200,
                'target_latency_ms': 200,
                'power_budget_w': 3
            },
            'IoT Device': {
                'max_memory_mb': 1000,
                'max_model_size_mb': 50,
                'target_latency_ms': 500,
                'power_budget_w': 1
            },
            'Edge Server': {
                'max_memory_mb': 16000,
                'max_model_size_mb': 2000,
                'target_latency_ms': 50,
                'power_budget_w': 100
            }
        }
        
        return devices
        
    def check_compatibility(self, model_stats, device_constraints):
        """Check if model is compatible with device constraints."""
        compatible = True
        issues = []
        
        if model_stats['size_mb'] > device_constraints['max_model_size_mb']:
            compatible = False
            issues.append("Model too large")
            
        if model_stats['memory_mb'] > device_constraints['max_memory_mb']:
            compatible = False
            issues.append("Memory requirement too high")
            
        if model_stats['latency_ms'] > device_constraints['target_latency_ms']:
            compatible = False
            issues.append("Latency too high")
            
        return compatible, issues
        
    def analyze_edge_deployment(self):
        """Analyze which optimizations work for which devices."""
        models = self.simulate_model_sizes()
        devices = self.simulate_device_constraints()
        
        compatibility_matrix = {}
        
        for device_name, device_constraints in devices.items():
            compatibility_matrix[device_name] = {}
            
            for model_name, model_stats in models.items():
                compatible, issues = self.check_compatibility(model_stats, device_constraints)
                
                compatibility_matrix[device_name][model_name] = {
                    'compatible': compatible,
                    'issues': issues,
                    'efficiency_score': self._calculate_efficiency_score(
                        model_stats, device_constraints
                    )
                }
                
        return models, devices, compatibility_matrix
        
    def _calculate_efficiency_score(self, model_stats, device_constraints):
        """Calculate efficiency score for model-device combination."""
        # Normalize metrics
        size_score = min(1.0, device_constraints['max_model_size_mb'] / model_stats['size_mb'])
        memory_score = min(1.0, device_constraints['max_memory_mb'] / model_stats['memory_mb'])
        latency_score = min(1.0, device_constraints['target_latency_ms'] / model_stats['latency_ms'])
        accuracy_score = model_stats['accuracy'] / 100.0
        
        # Weighted combination
        efficiency = (size_score * 0.2 + memory_score * 0.2 + 
                     latency_score * 0.3 + accuracy_score * 0.3)
        
        return efficiency
        
    def visualize_edge_analysis(self, models, devices, compatibility_matrix):
        """Visualize edge deployment analysis."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        model_names = list(models.keys())
        device_names = list(devices.keys())
        
        # Model size vs accuracy trade-off
        sizes = [models[m]['size_mb'] for m in model_names]
        accuracies = [models[m]['accuracy'] for m in model_names]
        
        scatter = axes[0, 0].scatter(sizes, accuracies, c=range(len(model_names)), 
                                   s=100, cmap='viridis', alpha=0.7)
        
        for i, name in enumerate(model_names):
            axes[0, 0].annotate(name, (sizes[i], accuracies[i]), 
                              xytext=(5, 5), textcoords='offset points', fontsize=8)
        
        axes[0, 0].set_xlabel('Model Size (MB)')
        axes[0, 0].set_ylabel('Accuracy (%)')
        axes[0, 0].set_title('Size vs Accuracy Trade-off')
        axes[0, 0].set_xscale('log')
        
        # Latency vs memory usage
        latencies = [models[m]['latency_ms'] for m in model_names]
        memories = [models[m]['memory_mb'] for m in model_names]
        
        scatter2 = axes[0, 1].scatter(latencies, memories, c=range(len(model_names)), 
                                    s=100, cmap='viridis', alpha=0.7)
        
        for i, name in enumerate(model_names):
            axes[0, 1].annotate(name, (latencies[i], memories[i]), 
                              xytext=(5, 5), textcoords='offset points', fontsize=8)
        
        axes[0, 1].set_xlabel('Latency (ms)')
        axes[0, 1].set_ylabel('Memory Usage (MB)')
        axes[0, 1].set_title('Latency vs Memory Usage')
        
        # Compatibility heatmap
        compatibility_data = np.zeros((len(device_names), len(model_names)))
        
        for i, device in enumerate(device_names):
            for j, model in enumerate(model_names):
                compatibility_data[i, j] = compatibility_matrix[device][model]['compatible']
                
        im = axes[0, 2].imshow(compatibility_data, cmap='RdYlGn', aspect='auto')
        axes[0, 2].set_title('Model-Device Compatibility')
        axes[0, 2].set_xlabel('Model')
        axes[0, 2].set_ylabel('Device')
        axes[0, 2].set_xticks(range(len(model_names)))
        axes[0, 2].set_xticklabels(model_names, rotation=45, ha='right')
        axes[0, 2].set_yticks(range(len(device_names)))
        axes[0, 2].set_yticklabels(device_names)
        
        # Add text annotations
        for i in range(len(device_names)):
            for j in range(len(model_names)):
                text = '✓' if compatibility_data[i, j] else '✗'
                axes[0, 2].text(j, i, text, ha='center', va='center', 
                               color='white', fontsize=12, fontweight='bold')
        
        # Efficiency scores heatmap
        efficiency_data = np.zeros((len(device_names), len(model_names)))
        
        for i, device in enumerate(device_names):
            for j, model in enumerate(model_names):
                efficiency_data[i, j] = compatibility_matrix[device][model]['efficiency_score']
                
        im2 = axes[1, 0].imshow(efficiency_data, cmap='viridis', aspect='auto')
        axes[1, 0].set_title('Efficiency Scores')
        axes[1, 0].set_xlabel('Model')
        axes[1, 0].set_ylabel('Device')
        axes[1, 0].set_xticks(range(len(model_names)))
        axes[1, 0].set_xticklabels(model_names, rotation=45, ha='right')
        axes[1, 0].set_yticks(range(len(device_names)))
        axes[1, 0].set_yticklabels(device_names)
        plt.colorbar(im2, ax=axes[1, 0], label='Efficiency Score')
        
        # Device constraints comparison
        # (constraint key, divisor that puts all bars on a comparable scale, legend label)
        constraint_specs = [
            ('max_memory_mb', 1000, 'Max Memory (GB)'),
            ('max_model_size_mb', 100, 'Max Model Size (x100 MB)'),
            ('target_latency_ms', 100, 'Target Latency (x100 ms)'),
        ]
        x = np.arange(len(device_names))
        width = 0.25
        
        for i, (constraint, divisor, label) in enumerate(constraint_specs):
            values = [devices[d][constraint] / divisor for d in device_names]
            axes[1, 1].bar(x + i * width, values, width, label=label)
        
        axes[1, 1].set_xlabel('Device Type')
        axes[1, 1].set_ylabel('Normalized Constraint Value')
        axes[1, 1].set_title('Device Constraints Comparison')
        axes[1, 1].set_xticks(x + width)
        axes[1, 1].set_xticklabels(device_names)
        axes[1, 1].legend()
        
        # Optimization effectiveness
        original_size = models['Original']['size_mb']
        compression_ratios = [original_size / models[m]['size_mb'] for m in model_names]
        accuracy_drops = [models['Original']['accuracy'] - models[m]['accuracy'] for m in model_names]
        
        scatter3 = axes[1, 2].scatter(compression_ratios, accuracy_drops, 
                                    c=range(len(model_names)), s=100, cmap='viridis', alpha=0.7)
        
        for i, name in enumerate(model_names):
            axes[1, 2].annotate(name, (compression_ratios[i], accuracy_drops[i]), 
                              xytext=(5, 5), textcoords='offset points', fontsize=8)
        
        axes[1, 2].set_xlabel('Compression Ratio')
        axes[1, 2].set_ylabel('Accuracy Drop (%)')
        axes[1, 2].set_title('Compression vs Accuracy Trade-off')
        axes[1, 2].set_xscale('log')
        
        plt.tight_layout()
        plt.show()
        
        # Print deployment recommendations
        print("\n📱 Edge Deployment Recommendations:")
        print("-" * 70)
        
        for device_name in device_names:
            print(f"\n{device_name}:")
            
            # Find best compatible model
            best_model = None
            best_score = 0
            
            for model_name in model_names:
                compat_info = compatibility_matrix[device_name][model_name]
                if compat_info['compatible'] and compat_info['efficiency_score'] > best_score:
                    best_model = model_name
                    best_score = compat_info['efficiency_score']
                    
            if best_model:
                model_stats = models[best_model]
                print(f"  ✅ Recommended: {best_model}")
                print(f"     Size: {model_stats['size_mb']:.1f} MB")
                print(f"     Latency: {model_stats['latency_ms']:.1f} ms")
                print(f"     Accuracy: {model_stats['accuracy']:.1f}%")
                print(f"     Efficiency Score: {best_score:.3f}")
            else:
                print("  ❌ No compatible model found")
                print("     Consider more aggressive optimization")

# Run edge optimization analysis
edge_demo = EdgeOptimizationDemo()
models, devices, compatibility_matrix = edge_demo.analyze_edge_deployment()
edge_demo.visualize_edge_analysis(models, devices, compatibility_matrix)
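
A worked example clarifies how the recommendation is chosen. `check_compatibility` requires size, memory, and latency to be within budget, so for any compatible model the first three ratios in `_calculate_efficiency_score` saturate at 1.0 and the score reduces to 0.7 + 0.3 × accuracy / 100. In effect, the most accurate compatible model wins. The standalone sketch below repeats the same weighting (the function name `efficiency_score` is illustrative) using the simulated Mid-Range Mobile budget and two of the simulated model variants.

```python
def efficiency_score(model: dict, device: dict) -> float:
    """Same weighting as _calculate_efficiency_score in the cell above."""
    size = min(1.0, device['max_model_size_mb'] / model['size_mb'])
    memory = min(1.0, device['max_memory_mb'] / model['memory_mb'])
    latency = min(1.0, device['target_latency_ms'] / model['latency_ms'])
    accuracy = model['accuracy'] / 100.0
    return 0.2 * size + 0.2 * memory + 0.3 * latency + 0.3 * accuracy


# Values derived from the simulated base model (1200 MB, 150 ms, 95.5%, 2400 MB).
mid_range = {'max_memory_mb': 4000, 'max_model_size_mb': 200, 'target_latency_ms': 200}
quant_prune = {'size_mb': 150.0, 'latency_ms': 67.5, 'accuracy': 93.7, 'memory_mb': 720.0}
mobile = {'size_mb': 96.0, 'latency_ms': 45.0, 'accuracy': 92.0, 'memory_mb': 480.0}

print(f"Quantization + Pruning: {efficiency_score(quant_prune, mid_range):.4f}")  # 0.9811
print(f"Mobile Optimized:       {efficiency_score(mobile, mid_range):.4f}")       # 0.9760
```

Both variants fit the budget, so the smaller, faster Mobile Optimized model earns no credit for its extra headroom; if headroom matters, the ratios would need to be uncapped or weighted differently.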

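The 0.25 size factor assumed for 'Quantization (INT8)' corresponds to storing each 4-byte FP32 weight as a 1-byte integer plus a shared scale factor. The following stdlib-only sketch shows symmetric per-tensor INT8 quantization, assuming a single scale for the whole tensor (production toolkits, PyTorch's quantization APIs among them, also offer per-channel scales and zero points). The helpers `quantize_int8` and `dequantize` are illustrative, and the weights are made-up values.

```python
from array import array


def quantize_int8(values):
    """Symmetric per-tensor INT8 quantization of a non-empty sequence.

    scale = max|x| / 127 and q = round(x / scale), clamped to [-127, 127].
    """
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = array('b', (max(-127, min(127, round(v / scale))) for v in values))
    return q, scale


def dequantize(q, scale):
    """Map INT8 codes back to approximate float values."""
    return [v * scale for v in q]


weights = array('f', [0.5, -1.27, 0.003, 0.9, -0.25])  # 'f' = 4-byte float32
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
print(weights.itemsize * len(weights), "bytes as FP32 ->", q.itemsize * len(q), "bytes as INT8")  # 20 -> 5
print("max abs error:", max(abs(a - b) for a, b in zip(weights, restored)))
```

Rounding keeps every reconstructed weight within half a quantization step (scale / 2) of the original, which is why accuracy typically drops only slightly.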
## 7. Complete Inference Pipeline

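Let's put it all together in a complete inference pipeline demonstration. The simulation below sends each request through parsing, tokenization, dynamic batching, model inference, post-processing, and response formatting, while a simulated response cache (15% hit rate) lets some requests skip the pipeline entirely.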

In [None]:
class InferencePipelineDemo:
    """Demonstrate a complete inference pipeline."""
    
    def __init__(self):
        self.pipeline_metrics = defaultdict(list)
        
    def simulate_end_to_end_pipeline(self, num_requests=1000):
        """Simulate complete inference pipeline."""
        
        # Pipeline stages with their latencies
        stages = {
            'request_parsing': {'min_ms': 1, 'max_ms': 5, 'std_ms': 1},
            'tokenization': {'min_ms': 2, 'max_ms': 10, 'std_ms': 2},
            'batching_wait': {'min_ms': 0, 'max_ms': 50, 'std_ms': 15},
            'model_inference': {'min_ms': 30, 'max_ms': 200, 'std_ms': 20},
            'post_processing': {'min_ms': 1, 'max_ms': 8, 'std_ms': 2},
            'response_formatting': {'min_ms': 1, 'max_ms': 3, 'std_ms': 0.5}
        }
        
        results = {
            'total_latency': [],
            'stage_latencies': {stage: [] for stage in stages.keys()},
            'cache_hits': [],
            'batch_sizes': [],
            'queue_times': []
        }
        
        # Simulate cache
        cache_hit_rate = 0.15  # 15% cache hit rate
        
        # Simulate batching
        batch_queue = []
        max_batch_size = 32
        
        print("Simulating end-to-end inference pipeline...")
        
        for i in tqdm(range(num_requests)):
            total_latency = 0
            stage_times = {}
            
            # Check cache hit
            cache_hit = np.random.random() < cache_hit_rate
            results['cache_hits'].append(cache_hit)
            
            if cache_hit:
                # Cache hit - much faster
                total_latency = np.random.normal(5, 1)  # 5ms average for cache hit
                for stage in stages.keys():
                    stage_times[stage] = 0 if stage != 'request_parsing' else 1
            else:
                # Full pipeline
                for stage, timing in stages.items():
                    if stage == 'batching_wait':
                        # Simulate dynamic batching
                        batch_queue.append(i)
                        
                        if len(batch_queue) >= max_batch_size or (i % 50 == 0):
                            # Process batch
                            batch_size = len(batch_queue)
                            results['batch_sizes'].append(batch_size)
                            
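                            # Batching reduces per-request inference time. Draw from the
                            # model_inference distribution (`timing` here belongs to batching_wait).
                            inference_timing = stages['model_inference']
                            inference_speedup = min(2.0, 1 + (batch_size - 1) * 0.02)
                            stage_times['model_inference'] = max(10, 
                                np.random.normal(inference_timing['min_ms'] + 50, inference_timing['std_ms']) / inference_speedup
                            )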
                            
                            batch_queue = []
                            wait_time = np.random.normal(timing['min_ms'] + 20, timing['std_ms'])
                        else:
                            wait_time = np.random.normal(timing['min_ms'] + 10, timing['std_ms'])
                            
                        stage_times[stage] = max(0, wait_time)
                        
                    else:
                        # Normal stage processing
                        if stage not in stage_times:  # Skip if already set by batching
                            latency = np.random.normal(
                                (timing['min_ms'] + timing['max_ms']) / 2,
                                timing['std_ms']
                            )
                            stage_times[stage] = max(timing['min_ms'], latency)
                            
                total_latency = sum(stage_times.values())
                
            # Record results
            results['total_latency'].append(total_latency)
            for stage, latency in stage_times.items():
                results['stage_latencies'][stage].append(latency)
                
        return results
        
    def visualize_pipeline_analysis(self, results):
        """Visualize pipeline performance analysis."""
        fig, axes = plt.subplots(3, 2, figsize=(16, 18))
        
        # Total latency distribution
        axes[0, 0].hist(results['total_latency'], bins=50, alpha=0.7, color='skyblue')
        axes[0, 0].axvline(np.mean(results['total_latency']), color='red', 
                          linestyle='--', label=f'Mean: {np.mean(results["total_latency"]):.1f}ms')
        axes[0, 0].axvline(np.percentile(results['total_latency'], 95), color='orange', 
                          linestyle='--', label=f'P95: {np.percentile(results["total_latency"], 95):.1f}ms')
        axes[0, 0].set_title('Total Latency Distribution')
        axes[0, 0].set_xlabel('Latency (ms)')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].legend()
        
        # Stage latency breakdown
        stage_names = list(results['stage_latencies'].keys())
        stage_means = [np.mean(results['stage_latencies'][stage]) for stage in stage_names]
        
        bars = axes[0, 1].bar(range(len(stage_names)), stage_means)
        axes[0, 1].set_title('Average Latency by Stage')
        axes[0, 1].set_xlabel('Pipeline Stage')
        axes[0, 1].set_ylabel('Average Latency (ms)')
        axes[0, 1].set_xticks(range(len(stage_names)))
        axes[0, 1].set_xticklabels([s.replace('_', '\n') for s in stage_names], rotation=0)
        
        # Cache hit impact
        cache_hits = np.array(results['cache_hits'])
        cache_hit_latencies = np.array(results['total_latency'])[cache_hits]
        cache_miss_latencies = np.array(results['total_latency'])[~cache_hits]
        
        axes[1, 0].hist(cache_hit_latencies, bins=30, alpha=0.7, 
                       label=f'Cache Hits (n={len(cache_hit_latencies)})', color='green')
        axes[1, 0].hist(cache_miss_latencies, bins=30, alpha=0.7, 
                       label=f'Cache Misses (n={len(cache_miss_latencies)})', color='red')
        axes[1, 0].set_title('Latency: Cache Hits vs Misses')
        axes[1, 0].set_xlabel('Latency (ms)')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].legend()
        
        # Batch size distribution
        if results['batch_sizes']:
            axes[1, 1].hist(results['batch_sizes'], bins=range(1, max(results['batch_sizes']) + 2), 
                           alpha=0.7, color='purple', edgecolor='black')
            axes[1, 1].set_title('Batch Size Distribution')
            axes[1, 1].set_xlabel('Batch Size')
            axes[1, 1].set_ylabel('Frequency')
            axes[1, 1].axvline(np.mean(results['batch_sizes']), color='red', 
                              linestyle='--', label=f'Mean: {np.mean(results["batch_sizes"]):.1f}')
            axes[1, 1].legend()
        
        # Latency over time (showing trends)
        window_size = 50
        rolling_latency = []
        for i in range(window_size, len(results['total_latency'])):
            rolling_latency.append(
                np.mean(results['total_latency'][i-window_size:i])
            )
            
        axes[2, 0].plot(rolling_latency, alpha=0.8)
        axes[2, 0].set_title(f'Rolling Average Latency (window={window_size})')
        axes[2, 0].set_xlabel('Request Number')
        axes[2, 0].set_ylabel('Latency (ms)')
        axes[2, 0].grid(True, alpha=0.3)
        
        # Performance percentiles
        percentiles = [50, 75, 90, 95, 99]
        latency_percentiles = [np.percentile(results['total_latency'], p) for p in percentiles]
        
        bars = axes[2, 1].bar(range(len(percentiles)), latency_percentiles, 
                             color='lightcoral', alpha=0.7)
        axes[2, 1].set_title('Latency Percentiles')
        axes[2, 1].set_xlabel('Percentile')
        axes[2, 1].set_ylabel('Latency (ms)')
        axes[2, 1].set_xticks(range(len(percentiles)))
        axes[2, 1].set_xticklabels([f'P{p}' for p in percentiles])
        
        # Add value labels on bars
        for i, (bar, value) in enumerate(zip(bars, latency_percentiles)):
            axes[2, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                           f'{value:.1f}', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.show()
        
        # Print comprehensive summary
        print("\n🏭 End-to-End Pipeline Performance Summary:")
        print("=" * 60)
        
        # Overall metrics
        total_requests = len(results['total_latency'])
        cache_hit_rate = np.mean(results['cache_hits']) * 100
        avg_latency = np.mean(results['total_latency'])
        p95_latency = np.percentile(results['total_latency'], 95)
        
        print(f"\n📊 Overall Performance:")
        print(f"  Total Requests: {total_requests:,}")
        print(f"  Cache Hit Rate: {cache_hit_rate:.1f}%")
        print(f"  Average Latency: {avg_latency:.2f} ms")
        print(f"  P95 Latency: {p95_latency:.2f} ms")
        print(f"  Single-Stream Throughput: {1000 / avg_latency:.1f} RPS (one request in flight at a time)")
        
        # Stage breakdown
        print(f"\n⏱️ Stage Latency Breakdown:")
        total_stage_time = sum(np.mean(results['stage_latencies'][stage]) for stage in stage_names)
        
        for stage in stage_names:
            avg_time = np.mean(results['stage_latencies'][stage])
            percentage = (avg_time / total_stage_time) * 100
            print(f"  {stage.replace('_', ' ').title():<20}: {avg_time:>6.2f} ms ({percentage:>5.1f}%)")
        
        # Cache impact analysis
        if len(cache_hit_latencies) > 0 and len(cache_miss_latencies) > 0:
            cache_speedup = np.mean(cache_miss_latencies) / np.mean(cache_hit_latencies)
            print(f"\n💾 Cache Impact:")
            print(f"  Cache Hit Latency: {np.mean(cache_hit_latencies):.2f} ms")
            print(f"  Cache Miss Latency: {np.mean(cache_miss_latencies):.2f} ms")
            print(f"  Cache Speedup: {cache_speedup:.1f}x")
        
        # Batching analysis
        if results['batch_sizes']:
            avg_batch_size = np.mean(results['batch_sizes'])
            # Fill ratio relative to the largest batch observed, not the configured maximum
            batch_efficiency = avg_batch_size / max(results['batch_sizes'])
            print(f"\n📦 Batching Analysis:")
            print(f"  Average Batch Size: {avg_batch_size:.1f}")
            print(f"  Batch Fill Ratio (vs. largest batch): {batch_efficiency:.1%}")
            print(f"  Total Batches: {len(results['batch_sizes'])}")
        
        # SLA compliance
        sla_threshold = 100  # 100ms SLA
        sla_compliance = np.mean(np.array(results['total_latency']) <= sla_threshold) * 100
        print(f"\n📋 SLA Compliance:")
        print(f"  SLA Threshold: {sla_threshold} ms")
        print(f"  Compliance Rate: {sla_compliance:.1f}%")
        
        if sla_compliance < 99.0:
            print(f"  ⚠️  SLA compliance below 99% - consider optimization")
        else:
            print(f"  ✅ SLA compliance meets target")

# Run complete pipeline demonstration
pipeline_demo = InferencePipelineDemo()
pipeline_results = pipeline_demo.simulate_end_to_end_pipeline(num_requests=2000)
pipeline_demo.visualize_pipeline_analysis(pipeline_results)
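
The throughput figure printed above, computed as `1000 / avg_latency`, is the rate a single worker reaches when it handles one request at a time. Little's law, $L = \lambda W$, generalizes this: sustained throughput $\lambda$ equals the average number of requests in flight $L$ divided by the average latency $W$. The helpers below (`throughput_rps`, `required_concurrency`) are hypothetical names for a minimal sketch, assuming steady state and latency that does not rise with concurrency, an assumption that stops holding once the server saturates.

```python
def throughput_rps(avg_latency_ms: float, concurrency: int = 1) -> float:
    """Little's law: throughput = requests in flight / average latency."""
    if avg_latency_ms <= 0:
        raise ValueError("avg_latency_ms must be positive")
    return concurrency * 1000.0 / avg_latency_ms


def required_concurrency(target_rps: float, avg_latency_ms: float) -> float:
    """Requests that must be in flight to sustain target_rps at this latency."""
    return target_rps * avg_latency_ms / 1000.0


print(throughput_rps(25.0))                 # one request at a time
print(throughput_rps(25.0, concurrency=8))  # eight concurrent slots (e.g. a batch of 8)
print(required_concurrency(500, 25.0))      # in-flight requests needed for 500 RPS
```

This is why batching raises throughput: a batch of 8 at the same per-request latency keeps 8 requests in flight instead of one.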

## 🎯 Key Takeaways

**Production Serving Fundamentals:**
- Monitor latency, throughput, availability, and cost
- Design for 99.9%+ availability with proper error handling
- Keep resource utilization around 70-90%, leaving headroom for traffic spikes
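
To make an availability target concrete, it helps to convert it into an error budget: the downtime it permits over a period. A minimal sketch, assuming a 30-day period (the period length and the `downtime_budget_minutes` helper are illustrative assumptions, not part of the notebook):

```python
def downtime_budget_minutes(availability: float, period_days: float = 30.0) -> float:
    """Maximum downtime (minutes) allowed in a period for an availability target."""
    if not 0.0 <= availability <= 1.0:
        raise ValueError("availability must be a fraction in [0, 1]")
    return (1.0 - availability) * period_days * 24 * 60


for target in (0.99, 0.999, 0.9999):
    print(f"{target:.2%}: {downtime_budget_minutes(target):.1f} min per 30 days")
```

Each extra "nine" cuts the budget by a factor of ten, which is why 99.9%+ targets demand automated failover rather than manual recovery.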

**Batching Strategies:**
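- Static batching: Simple, but requests wait until a full batch forms, adding queueing delay
- Dynamic batching: Better latency-throughput balance
- Continuous batching: Best suited to autoregressive generation, since finished sequences leave the batch and new requests join at each decoding step
- Bucket batching: Groups similar-length sequences to reduce padding overhead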
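
The padding savings from bucketing can be measured by counting pad tokens when each batch is padded to its longest sequence. This sketch approximates bucket batching by sorting a window of requests by length before slicing it into batches; production systems usually use fixed length buckets and a time window instead, and the helper names here are hypothetical.

```python
from typing import List


def padding_tokens(batches: List[List[int]]) -> int:
    """Pad tokens added when each batch is padded to its longest sequence."""
    return sum(max(b) * len(b) - sum(b) for b in batches if b)


def fixed_batches(lengths: List[int], batch_size: int) -> List[List[int]]:
    """Batch requests in arrival order, ignoring length."""
    return [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]


def bucket_batches(lengths: List[int], batch_size: int) -> List[List[int]]:
    """Sort by length first so each batch holds similar-length sequences."""
    return fixed_batches(sorted(lengths), batch_size)


# Sequence lengths (in tokens) for a window of 8 requests: mixed short and long
lengths = [12, 200, 15, 180, 10, 190, 14, 210]
print("arrival order:", padding_tokens(fixed_batches(lengths, 4)), "pad tokens")
print("bucketed:     ", padding_tokens(bucket_batches(lengths, 4)), "pad tokens")
```

The trade-off is ordering: sorting can make a short request wait for a window to fill, so the window size bounds the extra latency.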

**Memory Optimizations:**
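- KV cache avoids recomputing keys and values for earlier tokens, so each decoding step processes only the newest token; the savings grow roughly linearly with output length
- Response caching avoids recomputing results for repeated requests
- Memory pools reduce allocation overhead by reusing buffers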
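
Counting the token positions fed through the model shows how the KV cache's benefit scales. Without a cache, step $t$ re-encodes the prompt of length $p$ plus the $t$ tokens generated so far, for $np + n(n-1)/2$ positions over $n$ steps; with a cache, it is $p + (n - 1)$. A simplified sketch (the `tokens_processed` helper is hypothetical) that ignores per-position cost: attention over the cached context still grows with sequence length.

```python
def tokens_processed(prompt_len: int, new_tokens: int, kv_cache: bool) -> int:
    """Token positions pushed through the model to generate `new_tokens` tokens."""
    if new_tokens <= 0:
        return 0
    if kv_cache:
        # Step 0 encodes the prompt; every later step feeds only the newest token
        return prompt_len + (new_tokens - 1)
    # Without a cache, step t re-encodes the prompt plus the t tokens generated so far
    return sum(prompt_len + t for t in range(new_tokens))


for n in (10, 100, 1000):
    without = tokens_processed(100, n, kv_cache=False)
    with_cache = tokens_processed(100, n, kv_cache=True)
    print(f"{n:>5} new tokens: {without:>7} vs {with_cache:>5} positions "
          f"({without / with_cache:.1f}x fewer with cache)")
```

For $n \gg p$ the ratio approaches roughly $n/2$, so the benefit grows linearly with output length, at the cost of storing keys and values for every layer and position.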

**Scaling and Load Balancing:**
- Least-loaded and latency-aware routing handle uneven load better than naive round-robin
- Base auto-scaling on multiple signals (CPU, latency, queue depth)
- Scale down conservatively (e.g., with a cooldown period) to avoid thrashing
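
One way to combine several signals with conservative scale-down is to scale up when any signal is overloaded, but scale down only when all signals are low and a cooldown has elapsed since the last change. The `ScalingPolicy` and `AutoScaler` classes and their thresholds below are hypothetical and illustrative; a real deployment would tune them and would usually also rate-limit scale-up.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScalingPolicy:
    # Illustrative thresholds; tune for your workload
    cpu_high: float = 0.80
    latency_high_ms: float = 100.0
    queue_high: int = 50
    cpu_low: float = 0.30
    scale_down_cooldown_s: float = 300.0
    min_replicas: int = 1
    max_replicas: int = 20


class AutoScaler:
    """Scale up on any overloaded signal; scale down only when every signal is
    low and the cooldown since the last replica change has elapsed."""

    def __init__(self, policy: ScalingPolicy, replicas: int = 1):
        self.policy = policy
        self.replicas = replicas
        self.last_change_s: Optional[float] = None

    def decide(self, now_s: float, cpu: float, p95_latency_ms: float, queue_depth: int) -> int:
        p = self.policy
        overloaded = (cpu > p.cpu_high or p95_latency_ms > p.latency_high_ms
                      or queue_depth > p.queue_high)
        underloaded = (cpu < p.cpu_low and p95_latency_ms < p.latency_high_ms / 2
                       and queue_depth == 0)
        cooled_down = (self.last_change_s is None
                       or now_s - self.last_change_s >= p.scale_down_cooldown_s)

        target = self.replicas
        if overloaded:
            target = min(self.replicas + 1, p.max_replicas)
        elif underloaded and cooled_down:
            target = max(self.replicas - 1, p.min_replicas)

        if target != self.replicas:
            self.replicas = target
            self.last_change_s = now_s
        return self.replicas


scaler = AutoScaler(ScalingPolicy(), replicas=2)
print(scaler.decide(0, cpu=0.9, p95_latency_ms=80, queue_depth=10))   # CPU overload: scale up
print(scaler.decide(60, cpu=0.2, p95_latency_ms=20, queue_depth=0))   # idle, but still in cooldown
print(scaler.decide(400, cpu=0.2, p95_latency_ms=20, queue_depth=0))  # cooldown elapsed: scale down
```

The asymmetry is deliberate: adding capacity late hurts latency immediately, while removing it late only costs money.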

**Edge Deployment:**
- Resource-constrained edge devices require aggressive optimization
- Combining quantization with pruning yields the largest size reduction
- Weigh accuracy, latency, and model size together before choosing a deployment configuration
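
The size arithmetic behind these techniques can be checked on a single weight matrix. This sketch uses plain NumPy with symmetric per-tensor int8 quantization and unstructured magnitude pruning, which are assumptions chosen for illustration rather than the notebook's own edge pipeline; the helper names are hypothetical.

```python
import numpy as np


def quantize_int8(w: np.ndarray):
    """Symmetric per-tensor int8 quantization: w is approximated by scale * q."""
    max_abs = float(np.max(np.abs(w)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def magnitude_prune(w: np.ndarray, sparsity: float) -> np.ndarray:
    """Zero out the smallest-magnitude fraction `sparsity` of the weights."""
    k = int(sparsity * w.size)
    if k == 0:
        return w.copy()
    threshold = np.partition(np.abs(w).ravel(), k - 1)[k - 1]
    return np.where(np.abs(w) <= threshold, 0.0, w).astype(w.dtype)


rng = np.random.default_rng(0)
w = rng.standard_normal((512, 512)).astype(np.float32)

q, scale = quantize_int8(w)
err = np.abs(q.astype(np.float32) * scale - w).max()
print(f"fp32: {w.nbytes} B, int8: {q.nbytes} B ({w.nbytes / q.nbytes:.0f}x smaller)")
print(f"max abs reconstruction error: {err:.4f} (bound: scale/2 = {scale / 2:.4f})")

pruned = magnitude_prune(w, sparsity=0.5)
print(f"nonzeros after 50% pruning: {np.count_nonzero(pruned) / w.size:.1%}")
```

Note that pruning alone leaves a dense array the same size in memory; its savings appear only with a sparse storage format, structured pruning, or compression on disk, which is why it is typically paired with quantization.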

## 🚀 Next Steps

Ready to evaluate your scaled models? Continue to [Topic 14: Evaluation and Safety](../14-evaluation-safety/) to learn comprehensive evaluation methodologies and safety measures for production LLMs!