# Performance Deep Dive: torch.compile Optimization Strategies

This notebook provides an in-depth analysis of the performance optimization strategies available through `torch.compile`: kernel fusion, memory access patterns, batch-size scaling, and hardware-specific kernel autotuning (the `max-autotune` mode).

In [None]:
import gc
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from triton.testing import do_bench

## Performance Measurement Framework

In [None]:
@dataclass
class PerformanceMetrics:
    """Benchmark results for a single execution mode.

    Units follow AdvancedBenchmarker: times in milliseconds, memory in GB,
    throughput in samples per second.
    """
    execution_time: float      # steady-state time per call (ms)
    memory_usage: float        # peak allocated CUDA memory (GB); 0.0 without CUDA
    throughput: float          # samples per second
    compilation_time: float = 0.0  # ms; stays 0.0 for eager mode
    first_run_time: float = 0.0    # ms; stays 0.0 for eager mode

class AdvancedBenchmarker:
    """Benchmark a model in eager mode and under each ``torch.compile`` mode.

    Steady-state latency is measured with ``triton.testing.do_bench``, which
    treats its ``warmup`` and ``rep`` arguments as time budgets in milliseconds
    (not iteration counts). ``warmup_runs`` / ``benchmark_runs`` are forwarded
    to it unchanged, so they act as millisecond budgets.

    NOTE(review): ``do_bench`` times with GPU events, so the CPU fallback chosen
    in ``__init__`` is presumably not usable for timing -- confirm before
    running without CUDA.
    """

    def __init__(self, warmup_runs: int = 100, benchmark_runs: int = 1000):
        # Forwarded to do_bench as ``warmup`` / ``rep``.
        self.warmup_runs = warmup_runs
        self.benchmark_runs = benchmark_runs
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _infer(model, input_data):
        """Run one forward pass with autograd disabled (inference benchmark)."""
        with torch.no_grad():
            return model(input_data)

    def _compile_and_time(self, model: nn.Module, input_data: torch.Tensor, mode: str):
        """Compile ``model`` and time both the wrapper creation and the first call.

        ``torch.compile`` is lazy: it returns a wrapper immediately and traces and
        compiles on the first call. Most of the real compilation cost therefore
        lands in the first-run time, not the compilation time.

        Returns:
            ``(compiled_model, compilation_time_s, first_run_time_s)``.
        """
        # Start from a clean Dynamo cache so earlier compilations are not reused.
        torch._dynamo.reset()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        start_time = time.perf_counter()
        compiled_model = torch.compile(model, mode=mode)
        compilation_time = time.perf_counter() - start_time

        # First execution (includes tracing and kernel compilation).
        start_time = time.perf_counter()
        self._infer(compiled_model, input_data)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        first_run_time = time.perf_counter() - start_time

        return compiled_model, compilation_time, first_run_time

    def measure_compilation_overhead(self, model: nn.Module, input_data: torch.Tensor, mode: str = "default") -> Tuple[float, float]:
        """Measure compilation time and first execution time, both in seconds.

        Args:
            model: module to compile.
            input_data: input for the first (compiling) forward pass.
            mode: ``torch.compile`` mode string.

        Returns:
            ``(compilation_time_s, first_run_time_s)``.
        """
        _, compilation_time, first_run_time = self._compile_and_time(model, input_data, mode)
        return compilation_time, first_run_time

    def comprehensive_benchmark(self, model: nn.Module, input_data: torch.Tensor, batch_size: Optional[int] = None) -> Dict[str, PerformanceMetrics]:
        """Benchmark ``model`` in eager mode and in every ``torch.compile`` mode.

        Args:
            model: module to benchmark (used as-is for eager mode).
            input_data: input tensor passed to every forward call.
            batch_size: samples per call for the throughput computation. Defaults to
                ``input_data.shape[0]``, or 1 for a 0-d tensor.

        Returns:
            Mapping from mode name ("eager", "default", "reduce-overhead",
            "max-autotune") to its PerformanceMetrics. Times are in ms and memory in GB.
        """
        if batch_size is None:
            batch_size = input_data.shape[0] if input_data.dim() > 0 else 1

        modes = {
            "eager": None,
            "default": "default",
            "reduce-overhead": "reduce-overhead",
            "max-autotune": "max-autotune"
        }

        results = {}

        for mode_name, compile_mode in modes.items():
            print(f"Benchmarking {mode_name} mode...")

            if mode_name == "eager":
                test_model = model
                compilation_time = 0.0
                first_run_time = 0.0
            else:
                # Benchmark the same compiled model whose first call was just timed,
                # so compilation is not paid a second time inside do_bench's warmup.
                test_model, compilation_time, first_run_time = self._compile_and_time(
                    model, input_data, compile_mode
                )

            # Steady-state timing under no_grad, matching the first-run and memory
            # measurements. Otherwise eager records autograd state and compiled
            # models recompile for grad mode.
            exec_time = do_bench(
                lambda: self._infer(test_model, input_data),
                warmup=self.warmup_runs,
                rep=self.benchmark_runs
            )

            # Peak memory of one isolated forward pass.
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
                self._infer(test_model, input_data)
                memory_usage = torch.cuda.max_memory_allocated() / 1024**3  # GB
            else:
                memory_usage = 0.0

            # exec_time is in ms, hence the factor 1000 to get samples per second.
            throughput = (batch_size * 1000) / exec_time

            results[mode_name] = PerformanceMetrics(
                execution_time=exec_time,
                memory_usage=memory_usage,
                throughput=throughput,
                compilation_time=compilation_time * 1000,  # Convert to ms
                first_run_time=first_run_time * 1000  # Convert to ms
            )

        return results

    @staticmethod
    def _annotated_bar(ax, labels, values, colors, ylabel, title, fmt):
        """Draw one bar per label on ``ax`` and print each value above its bar."""
        bars = ax.bar(labels, values, color=colors)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.tick_params(axis='x', rotation=45)
        offset = max(values) * 0.01  # lift labels by 1% of the tallest bar
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + offset,
                    format(value, fmt), ha='center', va='bottom', fontsize=10)

    def plot_performance_comparison(self, results: Dict[str, PerformanceMetrics], title: str = "Performance Comparison"):
        """Plot time, throughput, memory and compile overhead per mode; print speedups.

        Args:
            results: output of ``comprehensive_benchmark``.
            title: figure super-title.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        modes = list(results.keys())
        colors = ['red', 'blue', 'green', 'orange'][:len(modes)]

        self._annotated_bar(ax1, modes, [results[m].execution_time for m in modes], colors,
                            'Execution Time (ms)', 'Execution Time by Mode', '.2f')
        self._annotated_bar(ax2, modes, [results[m].throughput for m in modes], colors,
                            'Throughput (samples/sec)', 'Throughput by Mode', '.0f')
        self._annotated_bar(ax3, modes, [results[m].memory_usage for m in modes], colors,
                            'Memory Usage (GB)', 'Memory Usage by Mode', '.3f')

        # Compilation overhead (eager has none, so it is excluded).
        compile_modes = [mode for mode in modes if mode != 'eager']
        compile_times = [results[mode].compilation_time for mode in compile_modes]
        first_run_times = [results[mode].first_run_time for mode in compile_modes]

        x_pos = np.arange(len(compile_modes))
        width = 0.35

        ax4.bar(x_pos - width/2, compile_times, width, label='Compilation Time', color='lightblue')
        ax4.bar(x_pos + width/2, first_run_times, width, label='First Run Time', color='lightcoral')

        ax4.set_ylabel('Time (ms)')
        ax4.set_title('Compilation Overhead')
        ax4.set_xticks(x_pos)
        ax4.set_xticklabels(compile_modes, rotation=45)
        ax4.legend()

        plt.suptitle(title, fontsize=16)
        plt.tight_layout()
        plt.show()

        # Speedup summary relative to eager mode.
        if 'eager' not in results:
            print("(no 'eager' baseline in results; speedup summary skipped)")
            return
        eager_time = results['eager'].execution_time
        print("\n" + "="*50)
        print("SPEEDUP SUMMARY")
        print("="*50)
        for mode in compile_modes:
            speedup = eager_time / results[mode].execution_time
            print(f"{mode:15s}: {speedup:6.2f}x speedup")
        print("="*50)

# Shared benchmarker used by the fusion analysis below; the two budgets are
# forwarded to do_bench as its ``warmup`` / ``rep`` arguments.
benchmarker = AdvancedBenchmarker(warmup_runs=100, benchmark_runs=1000)

## Kernel Fusion Analysis

In [None]:
class FusionAnalyzer:
    """Compare eager and torch.compile performance on models with different fusion opportunities."""

    def __init__(self):
        # NOTE(review): not read or written anywhere else in this class.
        self.fusion_patterns = []

    def create_fusion_test_models(self, input_dim: int = 1024):
        """Build one small model per fusion pattern.

        Args:
            input_dim: feature width expected by the Linear-based models
                ("linear_activation_fusion" and "no_fusion"). The batch-norm model
                always expects 512 features. The element-wise model accepts any shape.

        Returns:
            Dict mapping pattern name to a freshly constructed nn.Module.
        """

        # Model 1: Element-wise operations (high fusion potential)
        class ElementWiseFusionModel(nn.Module):
            def forward(self, x):
                # Chain of point-wise ops: candidates for fusion into a single kernel
                x = torch.relu(x)
                x = x + 1.0
                x = x * 2.0
                x = torch.sigmoid(x)
                return x

        # Model 2: Linear + activation fusion
        class LinearActivationFusionModel(nn.Module):
            def __init__(self, input_dim=1024, hidden_dim=512, output_dim=256):
                super().__init__()
                self.linear1 = nn.Linear(input_dim, hidden_dim)
                self.linear2 = nn.Linear(hidden_dim, output_dim)

            def forward(self, x):
                # Linear followed by a point-wise activation (fusion candidate)
                x = F.relu(self.linear1(x))
                x = F.gelu(self.linear2(x))
                return x

        # Model 3: Batch norm + activation fusion
        class BatchNormFusionModel(nn.Module):
            def __init__(self, num_features=512):
                super().__init__()
                self.bn1 = nn.BatchNorm1d(num_features)
                self.bn2 = nn.BatchNorm1d(num_features)

            def forward(self, x):
                # BatchNorm followed by ReLU, then a second BatchNorm
                x = F.relu(self.bn1(x))
                x = self.bn2(x)
                return x

        # Model 4: No fusion opportunities (control)
        class NoFusionModel(nn.Module):
            def __init__(self, input_dim=1024, output_dim=512):
                super().__init__()
                self.linear = nn.Linear(input_dim, output_dim)

            def forward(self, x):
                # A single linear op: minimal fusion opportunity
                return self.linear(x)

        # Construction order is kept stable: it fixes RNG consumption and plot order.
        return {
            "elementwise_fusion": ElementWiseFusionModel(),
            "linear_activation_fusion": LinearActivationFusionModel(input_dim=input_dim),
            "batchnorm_fusion": BatchNormFusionModel(),
            "no_fusion": NoFusionModel(input_dim=input_dim)
        }

    def analyze_fusion_benefits(self, batch_size: int = 64, input_dim: int = 1024):
        """Benchmark every fusion model in all modes using the global ``benchmarker``.

        Args:
            batch_size: rows in each random input batch.
            input_dim: feature width for the non-batch-norm models (the models are
                built to match it).

        Returns:
            Dict mapping model name to ``benchmarker.comprehensive_benchmark`` results.
        """
        models = self.create_fusion_test_models(input_dim=input_dim)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        fusion_results = {}

        for model_name, model in models.items():
            print(f"\nAnalyzing {model_name}...")

            model = model.to(device)
            model.eval()

            # The batch-norm model has a fixed feature width of 512.
            if "batchnorm" in model_name:
                input_tensor = torch.randn(batch_size, 512, device=device)
            else:
                input_tensor = torch.randn(batch_size, input_dim, device=device)

            results = benchmarker.comprehensive_benchmark(model, input_tensor, batch_size)
            fusion_results[model_name] = results

            eager_time = results['eager'].execution_time
            compiled_time = results['max-autotune'].execution_time
            speedup = eager_time / compiled_time
            print(f"  Speedup with max-autotune: {speedup:.2f}x")

        return fusion_results

    def plot_fusion_comparison(self, fusion_results: Dict):
        """Plot execution time and speedup per fusion pattern, then print max speedups.

        Args:
            fusion_results: output of ``analyze_fusion_benefits``.
        """
        model_names = list(fusion_results.keys())
        modes = ['eager', 'default', 'reduce-overhead', 'max-autotune']
        compiled_modes = modes[1:]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        width = 0.2
        x = np.arange(len(model_names))
        tick_labels = [name.replace('_', '\n') for name in model_names]

        # Execution time comparison (grouped bars, one per mode)
        for i, mode in enumerate(modes):
            times = [fusion_results[model][mode].execution_time for model in model_names]
            ax1.bar(x + i * width, times, width, label=mode, alpha=0.8)

        ax1.set_xlabel('Model Type')
        ax1.set_ylabel('Execution Time (ms)')
        ax1.set_title('Execution Time by Fusion Pattern')
        ax1.set_xticks(x + width * (len(modes) - 1) / 2)  # center of each bar group
        ax1.set_xticklabels(tick_labels, rotation=45)
        ax1.legend()
        ax1.set_yscale('log')

        # Speedup over eager, shaped (mode, model) for grouped plotting
        speedup_array = np.array([
            [fusion_results[name]['eager'].execution_time / fusion_results[name][mode].execution_time
             for mode in compiled_modes]
            for name in model_names
        ]).T

        for i, mode in enumerate(compiled_modes):
            ax2.bar(x + i * width, speedup_array[i], width, label=mode, alpha=0.8)

        ax2.set_xlabel('Model Type')
        ax2.set_ylabel('Speedup (x)')
        ax2.set_title('Speedup by Fusion Pattern')
        ax2.set_xticks(x + width * (len(compiled_modes) - 1) / 2)
        ax2.set_xticklabels(tick_labels, rotation=45)
        ax2.legend()
        ax2.axhline(y=1.0, color='black', linestyle='--', alpha=0.5)

        plt.tight_layout()
        plt.show()

        print("\n" + "="*60)
        print("FUSION ANALYSIS SUMMARY")
        print("="*60)

        for model_name in model_names:
            eager_time = fusion_results[model_name]['eager'].execution_time
            best_compiled_time = min(fusion_results[model_name][mode].execution_time for mode in compiled_modes)
            max_speedup = eager_time / best_compiled_time

            print(f"{model_name:25s}: {max_speedup:6.2f}x max speedup")

        print("="*60)

# Benchmark every fusion pattern, then plot and summarize the comparison
fusion_analyzer = FusionAnalyzer()
fusion_results = fusion_analyzer.analyze_fusion_benefits(batch_size=64, input_dim=1024)
fusion_analyzer.plot_fusion_comparison(fusion_results)

## Memory Pattern Optimization

In [None]:
class MemoryPatternAnalyzer:
    """Compare eager and torch.compile time and memory across memory-access patterns."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def create_memory_test_models(self):
        """Build one model per memory-access pattern.

        Returns:
            Dict mapping pattern name to a freshly constructed nn.Module. Expected
            inputs (as built in ``analyze_memory_patterns``):
            "random_access" takes (batch, seq) int token ids;
            "memory_intensive" takes (batch, 64, H, W) images;
            the others take (batch, 1024) features.
        """

        # Model 1: Sequential memory access (cache-friendly)
        class SequentialAccessModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = nn.ModuleList([
                    nn.Linear(1024, 1024) for _ in range(4)
                ])

            def forward(self, x):
                for layer in self.layers:
                    x = F.relu(layer(x))
                return x

        # Model 2: Random memory access (cache-unfriendly)
        class RandomAccessModel(nn.Module):
            def __init__(self, vocab_size=10000, embed_dim=512):
                super().__init__()
                self.embedding = nn.Embedding(vocab_size, embed_dim)
                self.linear = nn.Linear(embed_dim, 256)

            def forward(self, x):
                # Gather rows of the embedding table at data-dependent indices
                embedded = self.embedding(x)
                return self.linear(embedded.mean(dim=1))

        # Model 3: Memory-intensive operations
        class MemoryIntensiveModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(64, 128, 3, padding=1)
                self.conv2 = nn.Conv2d(128, 256, 3, padding=1)
                self.pool = nn.AdaptiveAvgPool2d(1)
                self.fc = nn.Linear(256, 10)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                x = F.relu(self.conv2(x))
                x = self.pool(x).flatten(1)
                return self.fc(x)

        # Model 4: In-place operations (memory efficient)
        class InPlaceModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = nn.ModuleList([
                    nn.Linear(1024, 1024) for _ in range(4)
                ])

            def forward(self, x):
                for layer in self.layers:
                    x = layer(x)
                    x.relu_()  # In-place ReLU: no extra activation tensor
                return x

        return {
            "sequential_access": SequentialAccessModel(),
            "random_access": RandomAccessModel(),
            "memory_intensive": MemoryIntensiveModel(),
            "inplace_ops": InPlaceModel()
        }

    def analyze_memory_patterns(self, batch_size: int = 32):
        """Benchmark every memory-pattern model in all modes.

        Args:
            batch_size: leading dimension of each random input.

        Returns:
            Dict mapping model name to ``detailed_memory_benchmark`` results.
        """
        models = self.create_memory_test_models()
        memory_results = {}

        for model_name, model in models.items():
            print(f"\nAnalyzing memory pattern: {model_name}")

            model = model.to(self.device)
            model.eval()

            if model_name == "random_access":
                # randint's upper bound is exclusive: cover every embedding row.
                vocab_size = model.embedding.num_embeddings
                input_tensor = torch.randint(0, vocab_size, (batch_size, 50), device=self.device)
            elif model_name == "memory_intensive":
                input_tensor = torch.randn(batch_size, 64, 32, 32, device=self.device)
            else:
                input_tensor = torch.randn(batch_size, 1024, device=self.device)

            results = self.detailed_memory_benchmark(model, input_tensor, batch_size)
            memory_results[model_name] = results

        return memory_results

    @staticmethod
    def _infer(model, input_tensor):
        """Run one forward pass with autograd disabled."""
        with torch.no_grad():
            return model(input_tensor)

    def _peak_forward_memory(self, model, input_tensor):
        """Return ``(peak_gb, current_gb)`` for one isolated no-grad forward pass.

        Returns ``(0.0, 0.0)`` when CUDA is unavailable. This is measured separately
        from the timing run because ``do_bench`` allocates its own scratch buffer
        (used to flush the L2 cache between repetitions). That buffer would
        otherwise be counted in the peak.
        """
        if not torch.cuda.is_available():
            return 0.0, 0.0
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        self._infer(model, input_tensor)  # output is discarded immediately
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated() / 1024**3
        current_memory = torch.cuda.memory_allocated() / 1024**3
        return peak_memory, current_memory

    def detailed_memory_benchmark(self, model: nn.Module, input_tensor: torch.Tensor, batch_size: int) -> Dict:
        """Time ``model`` and measure its memory in eager and every compile mode.

        Args:
            model: module to benchmark.
            input_tensor: input passed to every call.
            batch_size: samples per call, used for throughput.

        Returns:
            Dict mapping mode to a dict with execution_time (ms), peak_memory and
            current_memory (GB), throughput (samples/s) and memory_efficiency
            (samples/s per GB).
        """
        results = {}
        modes = ["eager", "default", "reduce-overhead", "max-autotune"]

        for mode in modes:
            if mode == "eager":
                test_model = model
            else:
                torch._dynamo.reset()
                test_model = torch.compile(model, mode=mode)

            # Warmup (also triggers compilation for compiled modes)
            with torch.no_grad():
                for _ in range(10):
                    test_model(input_tensor)
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Timing under no_grad, consistent with the warmup, to avoid a grad-mode recompile
            exec_time = do_bench(
                lambda: self._infer(test_model, input_tensor),
                warmup=50,
                rep=200
            )

            peak_memory, current_memory = self._peak_forward_memory(test_model, input_tensor)

            throughput = (batch_size * 1000) / exec_time  # exec_time is in ms
            # Clamp the denominator so a zero peak (no CUDA) does not divide by zero
            memory_efficiency = throughput / max(peak_memory, 0.001)

            results[mode] = {
                'execution_time': exec_time,
                'peak_memory': peak_memory,
                'current_memory': current_memory,
                'throughput': throughput,
                'memory_efficiency': memory_efficiency
            }

        return results

    @staticmethod
    def _plot_grouped_bars(ax, memory_results, model_names, modes, colors,
                           metric, ylabel, title, log_scale=False):
        """Draw one group of per-mode bars for each model, showing ``metric``."""
        width = 0.2
        x = np.arange(len(model_names))
        for i, mode in enumerate(modes):
            values = [memory_results[model][mode][metric] for model in model_names]
            ax.bar(x + i * width, values, width, label=mode, color=colors[i], alpha=0.8)

        ax.set_xlabel('Memory Pattern')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(x + width * (len(modes) - 1) / 2)  # center of each bar group
        ax.set_xticklabels([name.replace('_', '\n') for name in model_names], rotation=45)
        ax.legend()
        if log_scale:
            ax.set_yscale('log')

    def plot_memory_analysis(self, memory_results: Dict):
        """Plot time, peak memory, throughput and efficiency per pattern; print a summary.

        Args:
            memory_results: output of ``analyze_memory_patterns``.
        """
        model_names = list(memory_results.keys())
        modes = ['eager', 'default', 'reduce-overhead', 'max-autotune']
        colors = ['red', 'blue', 'green', 'orange']

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        panels = [
            (ax1, 'execution_time', 'Execution Time (ms)', 'Execution Time by Memory Pattern', True),
            (ax2, 'peak_memory', 'Peak Memory (GB)', 'Peak Memory Usage by Pattern', False),
            (ax3, 'throughput', 'Throughput (samples/sec)', 'Throughput by Memory Pattern', False),
            (ax4, 'memory_efficiency', 'Samples/sec per GB', 'Memory Efficiency by Pattern', False),
        ]
        for ax, metric, ylabel, panel_title, log_scale in panels:
            self._plot_grouped_bars(ax, memory_results, model_names, modes, colors,
                                    metric, ylabel, panel_title, log_scale)

        plt.tight_layout()
        plt.show()

        print("\n" + "="*70)
        print("MEMORY OPTIMIZATION SUMMARY")
        print("="*70)

        for model_name in model_names:
            print(f"\n{model_name.upper().replace('_', ' ')}:")
            eager_results = memory_results[model_name]['eager']
            best_mode = max(modes[1:], key=lambda m: memory_results[model_name][m]['memory_efficiency'])
            best_results = memory_results[model_name][best_mode]

            speedup = eager_results['execution_time'] / best_results['execution_time']
            efficiency_gain = best_results['memory_efficiency'] / eager_results['memory_efficiency']

            print(f"  Best mode: {best_mode}")
            print(f"  Speedup: {speedup:.2f}x")
            eager_peak = eager_results['peak_memory']
            if eager_peak > 0:
                memory_reduction = (eager_peak - best_results['peak_memory']) / eager_peak * 100
                print(f"  Memory reduction: {memory_reduction:.1f}%")
            else:
                # No CUDA memory statistics (peak recorded as 0.0)
                print("  Memory reduction: n/a")
            print(f"  Efficiency gain: {efficiency_gain:.2f}x")

        print("="*70)

# Benchmark each memory-access pattern, then plot and summarize the results
memory_analyzer = MemoryPatternAnalyzer()
memory_results = memory_analyzer.analyze_memory_patterns(batch_size=32)
memory_analyzer.plot_memory_analysis(memory_results)

## Batch Size Scaling Analysis

In [None]:
class BatchScalingAnalyzer:
    """Analyze how torch.compile performance scales with batch size."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _infer(model, input_tensor):
        """Run one forward pass with autograd disabled."""
        with torch.no_grad():
            return model(input_tensor)

    def _peak_forward_memory_gb(self, model, input_tensor) -> float:
        """Peak allocated CUDA memory (GB) of one isolated no-grad forward pass.

        Returns 0.0 without CUDA. This is measured apart from the timing run
        because ``do_bench`` allocates its own scratch buffer (used to flush the
        L2 cache), which would otherwise inflate the peak.
        """
        if not torch.cuda.is_available():
            return 0.0
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        self._infer(model, input_tensor)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / 1024**3

    def _benchmark_mode(self, model: nn.Module, input_tensor: torch.Tensor, batch_size: int, mode: str) -> Dict:
        """Time and measure one (batch size, mode) configuration.

        Returns:
            Dict with execution_time (ms), throughput (samples/s),
            latency_per_sample (ms), peak_memory (GB) and memory_per_sample (MB).
        """
        if mode == "eager":
            test_model = model
        else:
            # Fresh compile per batch size, so each run gets its own specialized graph
            torch._dynamo.reset()
            test_model = torch.compile(model, mode=mode)

        # Warmup (also triggers compilation for compiled modes)
        with torch.no_grad():
            for _ in range(5):
                test_model(input_tensor)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Timing under no_grad, consistent with the warmup
        exec_time = do_bench(
            lambda: self._infer(test_model, input_tensor),
            warmup=20,
            rep=100
        )

        peak_memory = self._peak_forward_memory_gb(test_model, input_tensor)

        return {
            'execution_time': exec_time,
            'throughput': (batch_size * 1000) / exec_time,   # samples/second
            'latency_per_sample': exec_time / batch_size,    # ms per sample
            'peak_memory': peak_memory,
            'memory_per_sample': peak_memory / batch_size * 1024  # MB per sample
        }

    def analyze_batch_scaling(self, model: nn.Module, input_shape: Tuple, batch_sizes: Optional[List[int]] = None):
        """Benchmark ``model`` in eager/default/max-autotune mode at each batch size.

        Args:
            model: module to benchmark. It is moved to ``self.device`` and put in eval mode.
            input_shape: per-sample input shape (without the batch dimension).
            batch_sizes: batch sizes to test. Defaults to powers of two from 1 to 256.

        Returns:
            Dict mapping batch_size to {mode: metrics dict, or None if that run failed}.
        """
        if batch_sizes is None:
            batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

        model = model.to(self.device)
        model.eval()

        scaling_results = {}
        modes = ["eager", "default", "max-autotune"]

        for batch_size in batch_sizes:
            print(f"Testing batch size: {batch_size}")

            input_tensor = torch.randn(batch_size, *input_shape, device=self.device)

            batch_results = {}
            for mode in modes:
                try:
                    batch_results[mode] = self._benchmark_mode(model, input_tensor, batch_size, mode)
                except Exception as e:
                    # Best effort: one failing configuration (e.g. OOM, compile error)
                    # should not abort the sweep. Plots skip None entries.
                    print(f"  Failed for {mode} mode: {e}")
                    batch_results[mode] = None

            scaling_results[batch_size] = batch_results

        return scaling_results

    @staticmethod
    def _collect_series(scaling_results: Dict, batch_sizes: List[int], mode: str, key: str):
        """Return ``(batch_sizes, values)`` of ``key`` for ``mode``, skipping failed (None) runs."""
        valid_batch_sizes, values = [], []
        for batch_size in batch_sizes:
            entry = scaling_results[batch_size][mode]
            if entry is not None:
                valid_batch_sizes.append(batch_size)
                values.append(entry[key])
        return valid_batch_sizes, values

    def plot_scaling_analysis(self, scaling_results: Dict, title: str = "Batch Size Scaling Analysis"):
        """Plot throughput, latency, memory and speedup vs batch size; print a summary.

        Args:
            scaling_results: output of ``analyze_batch_scaling``.
            title: figure super-title.
        """
        batch_sizes = list(scaling_results.keys())
        modes = ["eager", "default", "max-autotune"]
        colors = ['red', 'blue', 'green']

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Metric panels: (axis, metric key, y label, panel title, log-scale y)
        panels = [
            (ax1, 'throughput', 'Throughput (samples/sec)', 'Throughput vs Batch Size', True),
            (ax2, 'latency_per_sample', 'Latency per Sample (ms)', 'Latency per Sample vs Batch Size', True),
            (ax3, 'peak_memory', 'Peak Memory (GB)', 'Memory Usage vs Batch Size', False),
        ]
        for ax, key, ylabel, panel_title, log_y in panels:
            for i, mode in enumerate(modes):
                xs, ys = self._collect_series(scaling_results, batch_sizes, mode, key)
                if ys:
                    ax.plot(xs, ys, 'o-', label=mode, color=colors[i], linewidth=2, markersize=6)
            ax.set_xlabel('Batch Size')
            ax.set_ylabel(ylabel)
            ax.set_title(panel_title)
            ax.set_xscale('log', base=2)
            if log_y:
                ax.set_yscale('log')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Speedup over eager mode (needs both runs to have succeeded)
        for i, mode in enumerate(modes[1:], 1):
            speedups = []
            valid_batch_sizes = []
            for batch_size in batch_sizes:
                eager_entry = scaling_results[batch_size]['eager']
                mode_entry = scaling_results[batch_size][mode]
                if eager_entry is not None and mode_entry is not None:
                    speedups.append(eager_entry['execution_time'] / mode_entry['execution_time'])
                    valid_batch_sizes.append(batch_size)

            if speedups:
                ax4.plot(valid_batch_sizes, speedups, 'o-', label=mode, color=colors[i], linewidth=2, markersize=6)

        ax4.set_xlabel('Batch Size')
        ax4.set_ylabel('Speedup over Eager (x)')
        ax4.set_title('Speedup vs Batch Size')
        ax4.set_xscale('log', base=2)
        ax4.axhline(y=1.0, color='black', linestyle='--', alpha=0.5)
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.suptitle(title, fontsize=16)
        plt.tight_layout()
        plt.show()

        print("\n" + "="*60)
        print("BATCH SCALING EFFICIENCY ANALYSIS")
        print("="*60)

        for mode in modes:
            print(f"\n{mode.upper()} MODE:")

            valid_sizes, throughputs = self._collect_series(scaling_results, batch_sizes, mode, 'throughput')
            if not throughputs:
                continue

            # First batch size reaching the maximum throughput
            best_idx = max(range(len(throughputs)), key=throughputs.__getitem__)
            best_batch_size = valid_sizes[best_idx]
            best_throughput = throughputs[best_idx]

            print(f"  Optimal batch size for throughput: {best_batch_size}")
            print(f"  Peak throughput: {best_throughput:.0f} samples/sec")

            # Scaling efficiency: throughput gain relative to batch-size gain
            small_batch = min(batch_sizes)
            if scaling_results[small_batch][mode] is not None:
                small_throughput = scaling_results[small_batch][mode]['throughput']
                throughput_ratio = best_throughput / small_throughput
                batch_ratio = best_batch_size / small_batch
                efficiency = throughput_ratio / batch_ratio * 100
                print(f"  Scaling efficiency: {efficiency:.1f}% (ideal would be 100%)")

        print("="*60)

# Batch-size sweep (powers of two, 1..128) on a small three-layer MLP
scaling_analyzer = BatchScalingAnalyzer()

scaling_model = nn.Sequential(
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 10),
)

scaling_results = scaling_analyzer.analyze_batch_scaling(
    scaling_model,
    input_shape=(1024,),
    batch_sizes=[2 ** k for k in range(8)],
)
scaling_analyzer.plot_scaling_analysis(scaling_results, title="Deep Neural Network Batch Scaling")