# Continuous Batching and Inference Pipelines

This notebook explores continuous batching, a critical technique for maximizing LLM inference throughput.

We'll cover:
1. **Problem**: Static batching limitations
2. **Solution**: Continuous (dynamic) batching
3. **PagedAttention**: Efficient KV cache management
4. **Request scheduling**: Priority and fairness
5. **Throughput optimization**: Real-world patterns

## The Problem: Static Batching

Traditional batching issues:
- Wait for batch to fill → high latency
- Sequences finish at different times → wasted computation
- Fixed batch size → poor GPU utilization
- Can't add new requests mid-batch

**Result**: much of the GPU's capacity sits idle, and naive static batching can leave utilization far below what the hardware supports.
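A quick calculation makes the wasted-computation point concrete. The sketch below assumes every decode step costs the same whether or not some sequences in the batch have already finished; `static_padding_waste` is a hypothetical helper, not part of any library.

```python
from typing import List

def static_padding_waste(output_lengths: List[int]) -> float:
    """Fraction of decode slots spent on already-finished sequences in one static batch.

    Every sequence occupies a slot for max(output_lengths) steps,
    but only does useful work for its own length.
    """
    steps = max(output_lengths)
    total_slots = steps * len(output_lengths)
    useful = sum(output_lengths)
    return 1 - useful / total_slots

# Four requests that want 10, 50, 100 and 20 new tokens
print(f"{static_padding_waste([10, 50, 100, 20]):.0%}")  # 55%
```

The batch runs for 100 steps with 4 slots (400 slot-steps), but only 180 of them produce a token.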

## 1. Import Libraries

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from collections import deque
import heapq

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Static Batching (Baseline)

Traditional approach:
- Collect requests until batch is full
- Process entire batch together
- Wait for all sequences to finish
- Pad shorter sequences

**Problems:**
- Head-of-line blocking
- Wasted padding computation
- Long wait times for new requests
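Head-of-line blocking can be quantified the same way. Assuming a full batch and uniform step cost, a newly arrived request under static batching waits for the longest in-flight sequence, while an iteration-level scheduler admits it as soon as the shortest one finishes. `wait_for_free_slot` is a hypothetical helper.

```python
from typing import List

def wait_for_free_slot(remaining_tokens: List[int], static: bool) -> int:
    """Decode steps a newly arrived request waits before it can join a full batch.

    remaining_tokens: tokens each in-flight sequence still has to generate.
    Static batching admits nobody until the whole batch drains; continuous
    batching admits the newcomer as soon as the first sequence finishes.
    """
    return max(remaining_tokens) if static else min(remaining_tokens)

in_flight = [3, 40, 120, 8]  # hypothetical batch of 4, all slots taken
print(wait_for_free_slot(in_flight, static=True))   # 120
print(wait_for_free_slot(in_flight, static=False))  # 3
```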

In [None]:
@dataclass
class Request:
    """Represents an inference request."""
    id: int
    prompt_tokens: List[int]
    max_tokens: int
    generated_tokens: List[int]
    arrival_time: float
    start_time: Optional[float] = None
    finish_time: Optional[float] = None
    
    @property
    def is_complete(self):
        return len(self.generated_tokens) >= self.max_tokens
    
    @property
    def latency(self):
        # Compare against None: arrival_time == 0.0 is valid but falsy
        if self.finish_time is not None and self.arrival_time is not None:
            return self.finish_time - self.arrival_time
        return None

def static_batching_simulation(requests, batch_size, tokens_per_second=100):
    """
    Simulate static batching inference.
    
    Args:
        requests: List of Request objects
        batch_size: Fixed batch size
        tokens_per_second: Throughput rate
    
    Returns:
        Completed requests, stats
    """
    completed = []
    current_time = 0
    request_queue = deque(requests)
    
    total_compute_time = 0
    total_wasted_time = 0  # Slot-time spent on already-finished sequences (padding)
    
    while request_queue:
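        # Wait for the batch to fill (or for the queue to run out)
        batch = []
        while len(batch) < batch_size and request_queue:
            batch.append(request_queue.popleft())
        
        if not batch:
            break
        
        # A static batch cannot start until its last member has arrived
        current_time = max(current_time, max(req.arrival_time for req in batch))
        for req in batch:
            req.start_time = current_time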
        
        # Find max sequence length in batch
        max_gen_tokens = max(req.max_tokens for req in batch)
        
        # Process batch (all sequences to max length)
        for step in range(max_gen_tokens):
            # Count active sequences
            active = sum(1 for req in batch if not req.is_complete)
            
            # One decode iteration; assumed to cost the same regardless of how many sequences are active
            step_time = 1.0 / tokens_per_second
            current_time += step_time
            total_compute_time += step_time
            
            # Wasted computation on padding
            if active < len(batch):
                wasted_fraction = (len(batch) - active) / len(batch)
                total_wasted_time += step_time * wasted_fraction
            
            # Generate tokens
            for req in batch:
                if not req.is_complete:
                    req.generated_tokens.append(0)  # Dummy token
                    if req.is_complete:
                        req.finish_time = current_time
        
        completed.extend(batch)
    
    stats = {
        'total_time': current_time,
        'total_compute': total_compute_time,
        'total_wasted': total_wasted_time,
        'efficiency': 1 - (total_wasted_time / total_compute_time) if total_compute_time > 0 else 0,
        'avg_latency': np.mean([r.latency for r in completed]),
        'throughput': sum(len(r.generated_tokens) for r in completed) / current_time
    }
    
    return completed, stats

# Generate sample requests
def generate_requests(num_requests, avg_prompt_len=10, avg_output_len=50):
    requests = []
    for i in range(num_requests):
        prompt_len = max(1, int(np.random.normal(avg_prompt_len, avg_prompt_len * 0.3)))
        output_len = max(1, int(np.random.normal(avg_output_len, avg_output_len * 0.5)))
        
        req = Request(
            id=i,
            prompt_tokens=list(range(prompt_len)),
            max_tokens=output_len,
            generated_tokens=[],
            arrival_time=i * 0.1  # Requests arrive every 0.1s
        )
        requests.append(req)
    return requests

# Test static batching
requests = generate_requests(20)
completed, stats = static_batching_simulation(requests, batch_size=4)

print("=== Static Batching Results ===")
print(f"Total time: {stats['total_time']:.2f}s")
print(f"Compute efficiency: {stats['efficiency']*100:.1f}%")
print(f"Wasted computation: {stats['total_wasted']:.2f}s ({stats['total_wasted']/stats['total_compute']*100:.1f}%)")
print(f"Average latency: {stats['avg_latency']:.2f}s")
print(f"Throughput: {stats['throughput']:.1f} tokens/s")

## 3. Continuous Batching

**Key innovation**: Add/remove requests dynamically at each iteration

**Benefits:**
- Remove completed sequences immediately
- Add new requests without waiting
- No padding waste
- Much higher GPU utilization, since slots freed by finished sequences are refilled immediately

**Implementation:**
1. Maintain active batch of requests
2. At each step:
   - Generate one token per active request
   - Remove completed requests
   - Add new requests if space available
3. Dynamic batch size based on capacity
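The loop described above fits in a few lines. This is a minimal FCFS sketch with illustrative request names and lengths; it ignores arrival times and memory limits and only tracks which requests share each decode step.

```python
from collections import deque
from typing import Dict, List

def trace_continuous_batching(lengths: Dict[str, int], max_batch: int) -> List[List[str]]:
    """Return the batch composition at every decode step.

    lengths maps a request name to how many tokens it will generate.
    Requests are admitted in dict order whenever a slot is free (FCFS).
    """
    waiting = deque(lengths.items())
    active: Dict[str, int] = {}  # name -> tokens still to generate
    trace = []
    while waiting or active:
        # Admit new requests into free slots before the step
        while waiting and len(active) < max_batch:
            name, n = waiting.popleft()
            active[name] = n
        trace.append(sorted(active))
        # One decode step: every active request emits one token
        for name in list(active):
            active[name] -= 1
            if active[name] == 0:
                del active[name]  # evict immediately, freeing the slot
    return trace

for step, batch in enumerate(trace_continuous_batching({"A": 2, "B": 4, "C": 1, "D": 2}, max_batch=2)):
    print(step, batch)
```

With `max_batch=2` this workload takes 5 steps (`[A, B]`, `[A, B]`, `[B, C]`, `[B, D]`, `[D]`). Static batching in arrival order would need 4 steps for `{A, B}` plus 2 for `{C, D}`, i.e. 6.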

In [None]:
class ContinuousBatchingEngine:
    """
    Continuous batching inference engine.
    Dynamically adds/removes requests at each iteration.
    """
    def __init__(self, max_batch_size, tokens_per_second=100):
        self.max_batch_size = max_batch_size
        self.tokens_per_second = tokens_per_second
        
        self.active_batch: List[Request] = []
        self.waiting_queue = deque()
        self.completed: List[Request] = []
        
        self.current_time = 0
        self.total_tokens_generated = 0
    
    def add_request(self, request: Request):
        """Add request to waiting queue."""
        self.waiting_queue.append(request)
    
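    def step(self):
        """Execute one inference step. Returns False when all work is done."""
        if not self.active_batch and not self.waiting_queue:
            return False
        
        # Admit requests that have already arrived, while slots are free
        # (assumes waiting_queue is ordered by arrival_time)
        while (len(self.active_batch) < self.max_batch_size and self.waiting_queue
               and self.waiting_queue[0].arrival_time <= self.current_time):
            req = self.waiting_queue.popleft()
            req.start_time = self.current_time
            self.active_batch.append(req)
        
        if not self.active_batch:
            # GPU idle: jump the clock to the next arrival
            self.current_time = self.waiting_queue[0].arrival_time
            return True
        
        # One decode iteration generates one token for each active request.
        # Assumption: an iteration costs the same wall time regardless of batch
        # size (decode is memory-bound), matching static_batching_simulation.
        step_time = 1.0 / self.tokens_per_second
        self.current_time += step_time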
        
        # Process each request
        to_remove = []
        for req in self.active_batch:
            req.generated_tokens.append(0)  # Dummy token
            self.total_tokens_generated += 1
            
            if req.is_complete:
                req.finish_time = self.current_time
                to_remove.append(req)
                self.completed.append(req)
        
        # Remove completed requests (continuous batching!)
        for req in to_remove:
            self.active_batch.remove(req)
        
        return True
    
    def run(self, requests: List[Request]):
        """Run inference on all requests."""
        # Add all requests
        for req in requests:
            self.add_request(req)
        
        # Process until done
        while self.step():
            pass
        
        return self.get_stats()
    
    def get_stats(self):
        """Calculate performance statistics."""
        if not self.completed:
            return {}
        
        latencies = [r.latency for r in self.completed]
        return {
            'total_time': self.current_time,
            'avg_latency': np.mean(latencies),
            'p50_latency': np.percentile(latencies, 50),
            'p95_latency': np.percentile(latencies, 95),
            'p99_latency': np.percentile(latencies, 99),
            'throughput': self.total_tokens_generated / self.current_time,
            'total_requests': len(self.completed)
        }

# Test continuous batching
requests = generate_requests(20)
engine = ContinuousBatchingEngine(max_batch_size=8)
stats = engine.run(requests)

print("\n=== Continuous Batching Results ===")
print(f"Total time: {stats['total_time']:.2f}s")
print(f"Average latency: {stats['avg_latency']:.2f}s")
print(f"P95 latency: {stats['p95_latency']:.2f}s")
print(f"P99 latency: {stats['p99_latency']:.2f}s")
print(f"Throughput: {stats['throughput']:.1f} tokens/s")
print(f"Total requests: {stats['total_requests']}")

## 4. Comparison: Static vs Continuous Batching

Let's compare both approaches across different scenarios.
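Before running the simulation, a simple bound helps interpret its output. Assuming all requests are queued at time 0 and each decode step costs the same, static batching needs the sum of each batch's longest request, while no continuous schedule can beat either `ceil(total tokens / batch size)` or the single longest request. The helper names are hypothetical.

```python
import math
from typing import List

def static_steps(lengths: List[int], batch_size: int) -> int:
    """Decode steps for static batching: each batch runs as long as its longest member."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))

def continuous_lower_bound(lengths: List[int], batch_size: int) -> int:
    """No schedule beats keeping every slot busy, nor finishes before the longest request."""
    return max(math.ceil(sum(lengths) / batch_size), max(lengths))

lengths = [10, 50, 100, 20, 30, 80, 5, 60]
print(static_steps(lengths, 4), continuous_lower_bound(lengths, 4))  # 180 100
```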

In [None]:
def compare_batching_strategies():
    """
    Compare static vs continuous batching.
    """
    scenarios = [
        {'name': 'Low Load (10 req)', 'num_requests': 10, 'batch_size': 4},
        {'name': 'Medium Load (50 req)', 'num_requests': 50, 'batch_size': 8},
        {'name': 'High Load (100 req)', 'num_requests': 100, 'batch_size': 16},
    ]
    
    results = []
    
    for scenario in scenarios:
        print(f"\n=== {scenario['name']} ===")
        
        # Generate requests
        requests_static = generate_requests(scenario['num_requests'])
        requests_continuous = [Request(
            id=r.id, 
            prompt_tokens=r.prompt_tokens.copy(),
            max_tokens=r.max_tokens,
            generated_tokens=[],
            arrival_time=r.arrival_time
        ) for r in requests_static]
        
        # Static batching
        _, static_stats = static_batching_simulation(
            requests_static, 
            batch_size=scenario['batch_size']
        )
        
        # Continuous batching
        engine = ContinuousBatchingEngine(max_batch_size=scenario['batch_size'])
        continuous_stats = engine.run(requests_continuous)
        
        print(f"\nStatic Batching:")
        print(f"  Latency: {static_stats['avg_latency']:.2f}s")
        print(f"  Throughput: {static_stats['throughput']:.1f} tok/s")
        print(f"  Efficiency: {static_stats['efficiency']*100:.1f}%")
        
        print(f"\nContinuous Batching:")
        print(f"  Latency: {continuous_stats['avg_latency']:.2f}s (P95: {continuous_stats['p95_latency']:.2f}s)")
        print(f"  Throughput: {continuous_stats['throughput']:.1f} tok/s")
        
        improvement = (static_stats['avg_latency'] - continuous_stats['avg_latency']) / static_stats['avg_latency'] * 100
        throughput_improvement = (continuous_stats['throughput'] - static_stats['throughput']) / static_stats['throughput'] * 100
        
        print(f"\nImprovement:")
        print(f"  Latency: {improvement:.1f}% better")
        print(f"  Throughput: {throughput_improvement:.1f}% better")
        
        results.append({
            'scenario': scenario['name'],
            'static_latency': static_stats['avg_latency'],
            'continuous_latency': continuous_stats['avg_latency'],
            'static_throughput': static_stats['throughput'],
            'continuous_throughput': continuous_stats['throughput'],
            'latency_improvement': improvement,
            'throughput_improvement': throughput_improvement
        })
    
    return results

results = compare_batching_strategies()

## 5. Visualization: Static vs Continuous Batching

In [None]:
# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

scenarios = [r['scenario'] for r in results]
x_pos = np.arange(len(scenarios))
width = 0.35

# Plot 1: Latency comparison
ax = axes[0]
static_latencies = [r['static_latency'] for r in results]
continuous_latencies = [r['continuous_latency'] for r in results]

ax.bar(x_pos - width/2, static_latencies, width, label='Static Batching', color='#ff7f0e')
ax.bar(x_pos + width/2, continuous_latencies, width, label='Continuous Batching', color='#2ca02c')

ax.set_xlabel('Scenario', fontsize=11)
ax.set_ylabel('Average Latency (s)', fontsize=11)
ax.set_title('Latency Comparison', fontsize=12, fontweight='bold')
ax.set_xticks(x_pos)
ax.set_xticklabels([s.split('(')[0].strip() for s in scenarios], rotation=15, ha='right')
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Plot 2: Throughput comparison
ax = axes[1]
static_throughput = [r['static_throughput'] for r in results]
continuous_throughput = [r['continuous_throughput'] for r in results]

ax.bar(x_pos - width/2, static_throughput, width, label='Static Batching', color='#ff7f0e')
ax.bar(x_pos + width/2, continuous_throughput, width, label='Continuous Batching', color='#2ca02c')

ax.set_xlabel('Scenario', fontsize=11)
ax.set_ylabel('Throughput (tokens/s)', fontsize=11)
ax.set_title('Throughput Comparison', fontsize=12, fontweight='bold')
ax.set_xticks(x_pos)
ax.set_xticklabels([s.split('(')[0].strip() for s in scenarios], rotation=15, ha='right')
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

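print("\nKey Observations (compare with the numbers printed above):")
print("1. Continuous batching lowers average latency: requests start as soon as a slot frees up")
print("2. Throughput gains are largest when the system is saturated; at light load both are bound by the arrival rate")
print("3. The size of the benefit depends on load and on how much output lengths vary")
print("4. No padding waste: every slot in an iteration produces a useful token")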

## 6. PagedAttention: Memory Management

**Problem**: KV cache memory fragmentation
- Each sequence needs contiguous memory
- Hard to predict final length
- Over-allocation wastes memory
- Under-allocation causes failures
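To put numbers on the problem: per token, the KV cache stores a key and a value vector for every layer and KV head. The sketch assumes a Llama-2-7B-like shape (32 layers, 32 KV heads, head dimension 128, fp16) and hypothetical request lengths, and shows how little of a contiguous max-length reservation is actually used.

```python
from typing import List

def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """Keys and values (x2) for every layer and KV head, at the given precision."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

def contiguous_reservation_utilization(actual_lengths: List[int], reserved_len: int) -> float:
    """Share of reserved KV slots holding real tokens when each request pre-reserves reserved_len."""
    assert all(n <= reserved_len for n in actual_lengths), "request exceeded its reservation"
    return sum(actual_lengths) / (reserved_len * len(actual_lengths))

# Assumed Llama-2-7B-like shape: 32 layers, 32 KV heads, head_dim 128, fp16
per_token = kv_bytes_per_token(32, 32, 128)
print(per_token / 2**20, "MiB per token")  # 0.5 MiB per token

used = [300, 1200, 90]  # hypothetical actual lengths, each reserving 2048 slots
print(f"{contiguous_reservation_utilization(used, 2048):.1%}")  # 25.9%
wasted = (2048 * len(used) - sum(used)) * per_token
print(f"{wasted / 2**30:.2f} GiB reserved but unused")  # 2.22 GiB reserved but unused
```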

**Solution: PagedAttention (vLLM)**
- Split KV cache into fixed-size blocks (pages)
- Allocate blocks on-demand
- Non-contiguous memory OK
- Share blocks for identical prefixes
- Waste is limited to the unfilled tail of each sequence's last block, so KV cache utilization is far higher than with contiguous pre-allocation
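The core of PagedAttention's bookkeeping is a per-sequence block table that maps logical block numbers to physical ones, much like a page table. A minimal sketch, assuming 16-token blocks and a made-up block table:

```python
import math
from typing import List, Tuple

BLOCK_SIZE = 16  # tokens per block (assumed)

def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Blocks required to hold num_tokens; only the last one may be partly empty."""
    return math.ceil(num_tokens / block_size)

def translate(block_table: List[int], token_idx: int, block_size: int = BLOCK_SIZE) -> Tuple[int, int]:
    """Map a token's logical position to (physical block id, offset inside that block)."""
    return block_table[token_idx // block_size], token_idx % block_size

block_table = [7, 2, 9]  # hypothetical: logical blocks 0, 1, 2 live in physical blocks 7, 2, 9
print(blocks_needed(40))            # 3
print(translate(block_table, 0))    # (7, 0)
print(translate(block_table, 37))   # (9, 5)
```

Only the last block of a sequence can be partly empty, so per-sequence waste is at most `block_size - 1` token slots (here, 40 tokens in 3 blocks leave 8 unused).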

In [None]:
class PagedKVCache:
    """
    Simplified PagedAttention memory manager.
    Manages KV cache in fixed-size blocks.
    """
    def __init__(self, num_blocks, block_size, d_model):
        """
        Args:
            num_blocks: Total number of blocks available
            block_size: Tokens per block (e.g., 16)
            d_model: Model dimension
        """
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.d_model = d_model
        
        # Physical memory (all blocks)
        self.blocks = torch.zeros(num_blocks, block_size, d_model)
        self.free_blocks = set(range(num_blocks))
        
        # Mapping: request_id -> list of block indices
        self.request_blocks = {}
        self.request_tokens = {}  # Tokens stored per request
    
    def allocate_request(self, request_id):
        """Allocate blocks for a new request."""
        if request_id not in self.request_blocks:
            self.request_blocks[request_id] = []
            self.request_tokens[request_id] = 0
    
    def allocate_block(self, request_id):
        """Allocate one block for a request."""
        if not self.free_blocks:
            raise RuntimeError("Out of memory: no free blocks")
        
        block_idx = self.free_blocks.pop()
        self.request_blocks[request_id].append(block_idx)
        return block_idx
    
    def append_token(self, request_id, kv_data):
        """
        Append KV for one token to request's cache.
        Allocates new block if needed.
        """
        if request_id not in self.request_blocks:
            self.allocate_request(request_id)
        
        token_idx = self.request_tokens[request_id]
        block_idx_in_seq = token_idx // self.block_size
        offset_in_block = token_idx % self.block_size
        
        # Allocate new block if needed
        if block_idx_in_seq >= len(self.request_blocks[request_id]):
            self.allocate_block(request_id)
        
        # Write to block
        physical_block = self.request_blocks[request_id][block_idx_in_seq]
        self.blocks[physical_block, offset_in_block] = kv_data
        
        self.request_tokens[request_id] += 1
    
    def free_request(self, request_id):
        """Free all blocks for a request."""
        if request_id in self.request_blocks:
            for block_idx in self.request_blocks[request_id]:
                self.free_blocks.add(block_idx)
            del self.request_blocks[request_id]
            del self.request_tokens[request_id]
    
    def get_stats(self):
        """Get memory utilization statistics."""
        used_blocks = self.num_blocks - len(self.free_blocks)
        total_tokens = sum(self.request_tokens.values())
        allocated_capacity = used_blocks * self.block_size
        
        utilization = total_tokens / allocated_capacity if allocated_capacity > 0 else 0
        
        return {
            'used_blocks': used_blocks,
            'free_blocks': len(self.free_blocks),
            'total_blocks': self.num_blocks,
            'utilization': utilization,
            'total_tokens': total_tokens,
            'allocated_capacity': allocated_capacity
        }

# Example usage
cache = PagedKVCache(num_blocks=100, block_size=16, d_model=256)

# Simulate adding tokens for different requests
print("=== PagedAttention Memory Management ===")
print()

# Request 1: 50 tokens
cache.allocate_request(request_id=1)
for i in range(50):
    cache.append_token(1, torch.randn(256))

stats = cache.get_stats()
print(f"After Request 1 (50 tokens):")
print(f"  Used blocks: {stats['used_blocks']} / {stats['total_blocks']}")
print(f"  Utilization: {stats['utilization']*100:.1f}%")
print(f"  Tokens: {stats['total_tokens']} / {stats['allocated_capacity']} capacity")

# Request 2: 30 tokens
cache.allocate_request(request_id=2)
for i in range(30):
    cache.append_token(2, torch.randn(256))

stats = cache.get_stats()
print(f"\nAfter Request 2 (30 tokens):")
print(f"  Used blocks: {stats['used_blocks']} / {stats['total_blocks']}")
print(f"  Utilization: {stats['utilization']*100:.1f}%")
print(f"  Tokens: {stats['total_tokens']} / {stats['allocated_capacity']} capacity")

# Free request 1
cache.free_request(1)
stats = cache.get_stats()
print(f"\nAfter freeing Request 1:")
print(f"  Used blocks: {stats['used_blocks']} / {stats['total_blocks']}")
print(f"  Utilization: {stats['utilization']*100:.1f}%")
print(f"  Free blocks: {stats['free_blocks']} (can serve new requests)")

print("\nBenefits of PagedAttention:")
print("  ✓ No external fragmentation (only the last block of each sequence can be partly empty)")
print("  ✓ Allocate on-demand (no over-allocation)")
print("  ✓ Easy to free completed requests")
print("  ✓ Can share blocks for identical prefixes (not implemented in this sketch)")
print("  ✓ Much higher KV cache utilization than contiguous max-length reservation")
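The `PagedKVCache` above does not implement block sharing. One way to support shared prefixes is reference counting: forked sequences point at the same physical blocks, and a block returns to the free pool only when its last user releases it. This is a hypothetical sketch; a real system also needs copy-on-write before a sequence writes into a block it shares, which is not shown.

```python
from typing import Dict, List

class RefCountedBlocks:
    """Minimal block allocator where identical prompt prefixes share physical blocks."""
    def __init__(self, num_blocks: int):
        self.free: List[int] = list(range(num_blocks))
        self.refcount: Dict[int, int] = {}

    def allocate(self) -> int:
        if not self.free:
            raise RuntimeError("out of KV blocks")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def share(self, blocks: List[int]) -> List[int]:
        """Fork a sequence: the child reuses the parent's blocks instead of copying."""
        for b in blocks:
            self.refcount[b] += 1
        return list(blocks)

    def release(self, blocks: List[int]) -> None:
        """Drop one reference per block; free blocks nobody references any more."""
        for b in blocks:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                del self.refcount[b]
                self.free.append(b)

pool = RefCountedBlocks(num_blocks=8)
prefix = [pool.allocate(), pool.allocate()]   # shared system prompt, 2 blocks
seq_a = pool.share(prefix) + [pool.allocate()]
seq_b = pool.share(prefix) + [pool.allocate()]
pool.release(prefix)       # the prompt owner no longer needs it
print(len(pool.free))      # 4: 2 shared prefix blocks + 2 private blocks still in use
pool.release(seq_a)
pool.release(seq_b)
print(len(pool.free))      # 8
```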

## 7. Request Scheduling and Prioritization

With continuous batching, we can implement sophisticated scheduling:
- **FCFS**: First-come, first-served
- **Priority-based**: VIP users, urgent requests
- **Shortest-job-first**: Reduce average latency (requires an estimate of output length)
- **Fair queuing**: Prevent starvation
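Shortest-job-first can starve long requests indefinitely under steady load. A common remedy is aging: the longer a request waits, the better its score. The sketch below assumes `max_tokens`-style length estimates and a hypothetical `aging_rate` knob; it uses a linear scan because time-dependent scores would invalidate a fixed heap key.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Job:
    id: int
    est_tokens: int   # expected output length (e.g. max_tokens)
    arrival: float

class AgingSJFScheduler:
    """SJF with aging: score = est_tokens - aging_rate * (now - arrival); lowest score wins."""
    def __init__(self, aging_rate: float):
        self.aging_rate = aging_rate  # hypothetical knob: tokens of credit per second waited
        self.jobs: List[Job] = []

    def add(self, job: Job) -> None:
        self.jobs.append(job)

    def next(self, now: float) -> Optional[Job]:
        if not self.jobs:
            return None
        # Linear scan: scores change with time, so a static heap key would go stale.
        # Ties are broken by earlier arrival.
        best = min(self.jobs, key=lambda j: (j.est_tokens - self.aging_rate * (now - j.arrival), j.arrival))
        self.jobs.remove(best)
        return best

sched = AgingSJFScheduler(aging_rate=10.0)
sched.add(Job(id=1, est_tokens=200, arrival=0.0))   # long job, waiting since t=0
sched.add(Job(id=2, est_tokens=20, arrival=19.0))   # short job, just arrived
print(sched.next(now=20.0).id)  # 1: score 200 - 10*20 = 0 beats 20 - 10*1 = 10
print(sched.next(now=20.0).id)  # 2
```

With `aging_rate=0` this reduces to plain SJF, which would pick job 2 first.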

In [None]:
class PriorityScheduler:
    """
    Priority-based request scheduler.
    """
    def __init__(self):
        self.queue = []  # Min-heap: (priority, insertion counter, request)
        self.counter = 0
    
    def add_request(self, request: Request, priority: int = 0):
        """
        Add request with priority (lower number = higher priority).
        """
        # Use counter for tie-breaking (FCFS within same priority)
        heapq.heappush(self.queue, (priority, self.counter, request))
        self.counter += 1
    
    def get_next(self) -> Optional[Request]:
        """Get highest priority request."""
        if self.queue:
            _, _, request = heapq.heappop(self.queue)
            return request
        return None
    
    def size(self):
        return len(self.queue)

class ShortestJobFirstScheduler:
    """
    Shortest-job-first scheduler (reduces average latency when length estimates are good).
    """
    def __init__(self):
        self.queue = []  # Min-heap: (estimated_length, insertion counter, request)
        self.counter = 0
    
    def add_request(self, request: Request):
        """Add request (sorted by expected length)."""
        # max_tokens is only an upper bound on output length, used here as a proxy
        estimated_length = request.max_tokens
        heapq.heappush(self.queue, (estimated_length, self.counter, request))
        self.counter += 1
    
    def get_next(self) -> Optional[Request]:
        if self.queue:
            _, _, request = heapq.heappop(self.queue)
            return request
        return None
    
    def size(self):
        return len(self.queue)

# Compare schedulers
def test_schedulers():
    print("=== Request Scheduling Comparison ===")
    print()
    
    # Generate requests with varying lengths
    requests = [
        Request(id=1, prompt_tokens=[0]*10, max_tokens=100, generated_tokens=[], arrival_time=0),
        Request(id=2, prompt_tokens=[0]*10, max_tokens=20, generated_tokens=[], arrival_time=0.1),
        Request(id=3, prompt_tokens=[0]*10, max_tokens=200, generated_tokens=[], arrival_time=0.2),
        Request(id=4, prompt_tokens=[0]*10, max_tokens=30, generated_tokens=[], arrival_time=0.3),
        Request(id=5, prompt_tokens=[0]*10, max_tokens=150, generated_tokens=[], arrival_time=0.4),
    ]
    
    # FCFS (baseline)
    print("FCFS Order:")
    for req in requests:
        print(f"  Request {req.id}: {req.max_tokens} tokens")
    
    # Shortest-Job-First
    sjf = ShortestJobFirstScheduler()
    for req in requests:
        sjf.add_request(req)
    
    print("\nShortest-Job-First Order:")
    while sjf.size() > 0:
        req = sjf.get_next()
        print(f"  Request {req.id}: {req.max_tokens} tokens")
    
    # Priority-based
    priority_sched = PriorityScheduler()
    priorities = [1, 0, 2, 0, 1]  # 0 = high, 1 = medium, 2 = low
    for req, prio in zip(requests, priorities):
        priority_sched.add_request(req, priority=prio)
    
    print("\nPriority-Based Order:")
    while priority_sched.size() > 0:
        req = priority_sched.get_next()
        print(f"  Request {req.id}: {req.max_tokens} tokens (priority {priorities[req.id-1]})")
    
    print("\nKey Insights:")
    print("  • FCFS: Simple but can cause head-of-line blocking")
    print("  • SJF: Minimizes average latency but may starve long jobs")
    print("  • Priority: Flexible but needs careful tuning")
    print("  • Production: Often use hybrid approach")

test_schedulers()

## 8. Production Inference Pipeline

Complete inference system with all optimizations.

In [None]:
class ProductionInferenceEngine:
    """
    Production-grade inference engine combining:
    - Continuous batching
    - PagedAttention memory management
    - Request scheduling
    - Performance monitoring
    """
    def __init__(self, 
                 max_batch_size=32,
                 num_memory_blocks=1000,
                 block_size=16,
                 d_model=256):
        self.max_batch_size = max_batch_size
        self.d_model = d_model
        
        # Memory management
        self.kv_cache = PagedKVCache(num_memory_blocks, block_size, d_model)
        
        # Scheduling: requests wait in `pending` until they arrive, then in the SJF queue
        self.pending = []  # Min-heap: (arrival_time, request id, request)
        self.scheduler = ShortestJobFirstScheduler()
        
        # Active batch
        self.active_batch: List[Request] = []
        self.completed: List[Request] = []
        
        # Metrics
        self.current_time = 0
        self.total_tokens = 0
        self.batch_sizes = []
        self.memory_utilization = []  # Sampled every step, while blocks are still live
    
    def add_request(self, request: Request):
        """Register a request; it becomes schedulable once current_time >= arrival_time."""
        heapq.heappush(self.pending, (request.arrival_time, request.id, request))
    
    def step(self):
        """Execute one inference iteration. Returns False when all work is done."""
        # Release requests that have arrived into the scheduler
        while self.pending and self.pending[0][0] <= self.current_time:
            _, _, req = heapq.heappop(self.pending)
            self.scheduler.add_request(req)
        
        # Fill free batch slots from the scheduler
        while len(self.active_batch) < self.max_batch_size and self.scheduler.size() > 0:
            req = self.scheduler.get_next()
            req.start_time = self.current_time
            self.kv_cache.allocate_request(req.id)
            self.active_batch.append(req)
        
        if not self.active_batch:
            if self.pending:
                # Idle: jump the clock to the next arrival
                self.current_time = self.pending[0][0]
                return True
            return False
        
        # Record batch size
        self.batch_sizes.append(len(self.active_batch))
        
        # Simulate one decode iteration
        self.current_time += 0.01  # 10ms per step (assumed constant)
        
        # Generate one token per active request and cache its KV entry
        # (prefill of prompt tokens is not modeled)
        to_remove = []
        for req in self.active_batch:
            req.generated_tokens.append(0)
            self.total_tokens += 1
            self.kv_cache.append_token(req.id, torch.randn(self.d_model))
            if req.is_complete:
                req.finish_time = self.current_time
                to_remove.append(req)
        
        # Sample memory utilization before finished requests release their blocks
        self.memory_utilization.append(self.kv_cache.get_stats()['utilization'])
        
        # Remove completed requests and free their KV blocks
        for req in to_remove:
            self.kv_cache.free_request(req.id)
            self.active_batch.remove(req)
            self.completed.append(req)
        
        return True
    
    def run_until_complete(self):
        """Run until all requests complete."""
        while self.step():
            pass
        
        return self.get_stats()
    
    def get_stats(self):
        """Comprehensive performance statistics."""
        if not self.completed:
            return {}
        
        latencies = [r.latency for r in self.completed]
        
        return {
            'total_time': self.current_time,
            'total_requests': len(self.completed),
            'total_tokens': self.total_tokens,
            'throughput': self.total_tokens / self.current_time,
            'avg_latency': np.mean(latencies),
            'p50_latency': np.percentile(latencies, 50),
            'p95_latency': np.percentile(latencies, 95),
            'p99_latency': np.percentile(latencies, 99),
            'avg_batch_size': np.mean(self.batch_sizes),
            'max_batch_size': max(self.batch_sizes),
            # Average over steps; sampling at the end would always read 0 after all frees
            'memory_utilization': np.mean(self.memory_utilization),
        }

# Test production engine
print("=== Production Inference Engine ===")
print()

engine = ProductionInferenceEngine(
    max_batch_size=16,
    num_memory_blocks=500,
    block_size=16
)

# Add requests
requests = generate_requests(50, avg_output_len=40)
for req in requests:
    engine.add_request(req)

# Run
stats = engine.run_until_complete()

print(f"Performance Metrics:")
print(f"  Total time: {stats['total_time']:.2f}s")
print(f"  Requests completed: {stats['total_requests']}")
print(f"  Total tokens: {stats['total_tokens']}")
print(f"  Throughput: {stats['throughput']:.1f} tokens/s")
print()
print(f"Latency:")
print(f"  Average: {stats['avg_latency']:.2f}s")
print(f"  P50: {stats['p50_latency']:.2f}s")
print(f"  P95: {stats['p95_latency']:.2f}s")
print(f"  P99: {stats['p99_latency']:.2f}s")
print()
print(f"Resource Utilization:")
print(f"  Avg batch size: {stats['avg_batch_size']:.1f} / {engine.max_batch_size}")
print(f"  Max batch size: {stats['max_batch_size']}")
print(f"  Memory utilization: {stats['memory_utilization']*100:.1f}%")
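The engine above relies on `PagedKVCache.allocate_block`, which raises `RuntimeError` when no blocks are free. A serving system would preempt work instead. The sketch below evicts the most recently admitted sequence and re-queues it for later recomputation, similar in spirit to the latest-first preemption described in the vLLM paper; this counter-only allocator and its names are hypothetical, not vLLM's code.

```python
from collections import deque
from typing import Deque, Dict, List

class PreemptingAllocator:
    """Counts KV blocks per request; when the pool is exhausted, evicts the newest request."""
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free = num_blocks
        self.running: List[int] = []        # request ids in admission order
        self.blocks: Dict[int, int] = {}    # request id -> blocks held
        self.tokens: Dict[int, int] = {}    # request id -> tokens cached
        self.requeue: Deque[int] = deque()  # preempted ids, to be recomputed later

    def admit(self, rid: int) -> None:
        self.running.append(rid)
        self.blocks[rid] = 0
        self.tokens[rid] = 0

    def append_token(self, rid: int) -> bool:
        """Cache one token for rid; return False if rid itself had to be preempted."""
        if self.tokens[rid] % self.block_size == 0:  # no block yet, or current block full
            while self.free == 0:
                victim = self.running[-1]            # latest-admitted request goes first
                self._preempt(victim)
                if victim == rid:
                    return False
            self.free -= 1
            self.blocks[rid] += 1
        self.tokens[rid] += 1
        return True

    def _preempt(self, rid: int) -> None:
        self.free += self.blocks.pop(rid)
        del self.tokens[rid]
        self.running.remove(rid)
        self.requeue.append(rid)

alloc = PreemptingAllocator(num_blocks=2, block_size=2)
alloc.admit(1)
alloc.admit(2)
alloc.append_token(1)       # request 1 takes the first block
alloc.append_token(2)       # request 2 takes the second block; pool is now empty
alloc.append_token(1)       # fits in request 1's existing block
ok = alloc.append_token(1)  # needs a new block, so request 2 (newest) is preempted
print(ok, alloc.running, list(alloc.requeue))  # True [1] [2]
```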

## 9. Real-World Best Practices

In [None]:
print("=== Continuous Batching: Production Guide ===")
print()

print("1. Batch Size Configuration:")
print("   • Start with (GPU memory - model weights - activation headroom) / (KV cache per request)")
print("   • The right value depends on model size, context length, and GPU memory")
print("   • Monitor memory utilization and adjust")
print("   • Consider using dynamic batch size")
print()

print("2. Memory Management:")
print("   • Use PagedAttention (vLLM implementation)")
print("   • Block size: 16-32 tokens typical")
print("   • Reserve memory for new requests")
print("   • Implement OOM handling (reject or queue)")
print()

print("3. Scheduling Strategy:")
print("   • FCFS for fairness")
print("   • SJF for minimum latency")
print("   • Priority queues for tiered service")
print("   • Consider: preemption for long-running requests")
print()

print("4. Performance Tuning:")
print("   • Target: 60-90% GPU utilization")
print("   • Monitor: P95, P99 latency")
print("   • Profile: time per iteration, memory usage")
print("   • A/B test: different batch sizes, schedulers")
print()

print("5. Production Checklist:")
print("   ✓ Implement timeout handling")
print("   ✓ Add request cancellation")
print("   ✓ Monitor queue depth")
print("   ✓ Log performance metrics")
print("   ✓ Implement graceful degradation")
print("   ✓ Set up alerting for latency spikes")
print()

print("6. Expected Improvements (workload-dependent; measure on your own traffic):")
print("   • Throughput: 5-10× vs static batching")
print("   • Latency: 30-70% reduction")
print("   • GPU utilization: 60-90% (vs 10-20%)")
print("   • Memory utilization: 80-95% (with PagedAttention)")
print()

print("7. Common Pitfalls:")
print("   ❌ Batch size too large → OOM")
print("   ❌ Batch size too small → low throughput")
print("   ❌ No request timeout → hung requests")
print("   ❌ Poor scheduling → starvation")
print("   ❌ Memory fragmentation → crashes")
print()

print("8. Integration with Other Optimizations:")
print("   • Continuous batching + FlashAttention")
print("   • Continuous batching + Speculative decoding")
print("   • Continuous batching + Quantization")
print("   • PagedAttention + KV cache quantization")
print("   Gains from these techniques can compound, but measure them on your workload")
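The "reserve memory" and "OOM handling" advice can be implemented as an admission check: admit a request only if its prompt fits in the KV pool while keeping some blocks in reserve for running requests to grow into. A minimal sketch; `AdmissionController` and `reserve_blocks` are hypothetical names.

```python
import math
from dataclasses import dataclass

@dataclass
class AdmissionController:
    """Admit a request only if the KV pool can hold its prompt plus a safety reserve."""
    free_blocks: int
    block_size: int
    reserve_blocks: int  # headroom kept for running requests to keep decoding

    def try_admit(self, prompt_tokens: int) -> bool:
        needed = math.ceil(prompt_tokens / self.block_size)
        if self.free_blocks - needed < self.reserve_blocks:
            return False  # leave it queued (or reject upstream)
        self.free_blocks -= needed
        return True

ctrl = AdmissionController(free_blocks=10, block_size=16, reserve_blocks=2)
print(ctrl.try_admit(100))  # True: needs 7 blocks, 3 left
print(ctrl.try_admit(40))   # False: needs 3 blocks, would leave 0 < 2
print(ctrl.try_admit(16))   # True: needs 1 block, leaves exactly 2
```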

## 10. Summary

### Key Concepts

1. **Static Batching Problems**
   - Fixed batch size
   - Wait for batch to fill
   - Padding waste
   - Low GPU utilization

2. **Continuous Batching Solution**
   - Dynamic batch size
   - Add/remove requests at each step
   - No padding
   - Higher GPU utilization

3. **PagedAttention**
   - Fixed-size memory blocks
   - Non-contiguous allocation
   - On-demand allocation
   - Waste limited to each sequence's last, partly filled block

4. **Request Scheduling**
   - FCFS: Simple, fair
   - SJF: Low average latency (needs length estimates)
   - Priority: Flexible service tiers
   - Hybrid: Combine policies (e.g., priority tiers with aging)

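### Performance Impact

Illustrative orders of magnitude, not measurements from this notebook; real gains depend heavily on model, hardware, and traffic pattern.
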
| Metric | Static Batching | Continuous Batching | Improvement |
|--------|----------------|--------------------|--------------|
| Throughput | 100 tok/s | 500-1000 tok/s | **5-10×** |
| Latency | 5.0s | 1.5-3.5s | **30-70%** |
| GPU Util | 10-20% | 60-90% | **3-9×** |
| Memory Util | 40% | 80-95% | **2-2.4×** |

### Real-World Systems

- **vLLM**: PagedAttention + continuous batching
- **TensorRT-LLM**: In-flight batching
- **Text Generation Inference (TGI)**: Continuous batching
- **DeepSpeed-FastGen**: Dynamic SplitFuse
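Dynamic SplitFuse and chunked prefill share an idea: give each iteration a fixed token budget and split long prompts across iterations so they do not stall decoding requests. The sketch below is a simplified illustration of that idea, not either system's actual scheduler; `token_budget` and the decode-first policy are assumptions.

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class Seq:
    id: int
    prompt_left: int  # prompt tokens not yet prefilled; 0 means the sequence is decoding

def plan_step(seqs: List[Seq], token_budget: int) -> List[Tuple[int, int]]:
    """Return (seq id, tokens to process this iteration) under a fixed token budget.

    Decoding sequences are served first (1 token each) to keep their latency low;
    whatever budget remains is spent on prefill chunks, splitting long prompts.
    """
    plan = []
    budget = token_budget
    for s in seqs:
        if s.prompt_left == 0 and budget > 0:
            plan.append((s.id, 1))
            budget -= 1
    for s in seqs:
        if s.prompt_left > 0 and budget > 0:
            chunk = min(s.prompt_left, budget)
            plan.append((s.id, chunk))
            s.prompt_left -= chunk  # a sequence that finishes prefill decodes next step
            budget -= chunk
    return plan

seqs = [Seq(0, 0), Seq(1, 0), Seq(2, 500), Seq(3, 100)]
print(plan_step(seqs, token_budget=256))  # [(0, 1), (1, 1), (2, 254)]
print(plan_step(seqs, token_budget=256))  # [(0, 1), (1, 1), (2, 246), (3, 8)]
```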

### When to Use

**Use continuous batching for:**
- ✅ Production LLM serving
- ✅ High throughput requirements
- ✅ Variable request patterns
- ✅ Cost optimization

**Essential for:**
- Hosted LLM API serving
- Multi-user applications
- Batch processing with time constraints
- Any production LLM deployment

### Key Takeaways

1. **Continuous batching is standard** for production LLMs
2. **Combine with PagedAttention** for maximum efficiency
3. **Choose scheduler** based on use case
4. **Monitor metrics** and tune dynamically
5. **Large throughput gains** over static batching (workload-dependent)
6. **Critical for cost-effective** LLM serving

Continuous batching is a cornerstone of efficient production LLM serving.