# Advanced Graph Optimization Techniques

This notebook explores advanced graph optimization techniques beyond basic torch.compile, including custom operators, graph transformations, and manual optimization strategies.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule, symbolic_trace
from torch.fx.passes import shape_prop
import torch._dynamo as dynamo
from triton.testing import do_bench
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple, Any, Callable
import time

## Custom Operator Development

In [None]:
class CustomOperatorBuilder:
    """Reference "fused" operators plus a small benchmark harness.

    The operators are ordinary PyTorch compositions written as single
    functions; none of them launches a hand-written kernel. Any actual
    fusion is left to ``torch.compile``.
    """

    @staticmethod
    def fused_linear_gelu(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Apply a linear projection followed by GELU.

        Args:
            input_tensor: Input forwarded to ``F.linear``.
            weight: Weight forwarded to ``F.linear``.
            bias: Bias forwarded to ``F.linear``.

        Returns:
            ``F.gelu(F.linear(input_tensor, weight, bias))``.
        """
        # This would typically be implemented in C++/CUDA for best performance
        linear_output = F.linear(input_tensor, weight, bias)
        return F.gelu(linear_output)

    @staticmethod
    def fused_conv_bn_relu(input_tensor: torch.Tensor, conv_weight: torch.Tensor,
                          bn_weight: torch.Tensor, bn_bias: torch.Tensor,
                          bn_mean: torch.Tensor, bn_var: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        """Apply convolution, inference-mode batch norm and ReLU in sequence.

        Args:
            input_tensor: Input forwarded to ``F.conv2d``.
            conv_weight: Convolution weight. ``F.conv2d`` is called with only
                the input and weight, so its remaining arguments keep their
                defaults.
            bn_weight: Batch-norm scale.
            bn_bias: Batch-norm shift.
            bn_mean: Running mean used for normalisation.
            bn_var: Running variance used for normalisation.
            eps: Value forwarded to ``F.batch_norm`` as ``eps``.

        Returns:
            The ReLU-activated, normalised convolution output.
        """
        conv_out = F.conv2d(input_tensor, conv_weight)

        # training=False: normalise with the supplied running statistics
        # instead of batch statistics.
        bn_out = F.batch_norm(conv_out, bn_mean, bn_var, bn_weight, bn_bias,
                             training=False, eps=eps)

        return F.relu(bn_out)

    @staticmethod
    def optimized_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                           mask: torch.Tensor = None) -> torch.Tensor:
        """Scaled dot-product attention with optional masking.

        Only the last two dimensions are interpreted, so any number of
        leading dimensions (for example batch and heads) is accepted. The
        full score matrix is materialised; no memory-saving tiling is done.

        Args:
            query: Query tensor; its last dimension sets the scaling factor.
            key: Key tensor; its last two dimensions are transposed before
                the matmul with ``query``.
            value: Value tensor multiplied by the attention weights.
            mask: Optional tensor broadcastable to the score matrix.
                Positions where ``mask == 0`` are excluded from attention.
                ``None`` disables masking.

        Returns:
            ``softmax(query @ key^T / sqrt(d_k)) @ value``.
        """
        d_k = query.shape[-1]

        scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

        if mask is not None:
            # Use the dtype's own minimum rather than a literal such as -1e9,
            # which is not representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def create_custom_op_model(self, input_dim: int = 512, hidden_dim: int = 1024, output_dim: int = 256):
        """Build a two-layer model whose first layer uses ``fused_linear_gelu``.

        Args:
            input_dim: Number of input features.
            hidden_dim: Width of the hidden layer.
            output_dim: Number of output features.

        Returns:
            A freshly constructed ``nn.Module``.
        """

        class CustomOpModel(nn.Module):
            def __init__(self):
                super().__init__()
                # The first layer is kept as raw parameters so they can be
                # handed directly to the fused operator.
                self.linear1_weight = nn.Parameter(torch.randn(hidden_dim, input_dim))
                self.linear1_bias = nn.Parameter(torch.randn(hidden_dim))
                self.linear2 = nn.Linear(hidden_dim, output_dim)

            def forward(self, x):
                x = CustomOperatorBuilder.fused_linear_gelu(x, self.linear1_weight, self.linear1_bias)
                x = self.linear2(x)
                return x

        return CustomOpModel()

    def benchmark_custom_ops(self, batch_size: int = 64, input_dim: int = 512):
        """Time a standard MLP against the custom-operator model.

        Both models are measured eagerly and after ``torch.compile`` using
        ``do_bench``.

        Args:
            batch_size: Number of rows in the random benchmark input.
            input_dim: Feature width of the benchmark input and both models.

        Returns:
            ``{'standard': {...}, 'custom': {...}}`` where each inner dict
            holds the ``'eager'`` and ``'compiled'`` values returned by
            ``do_bench``.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Hidden/output sizes mirror the defaults of create_custom_op_model so
        # the two models do the same amount of work.
        standard_model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.GELU(),
            nn.Linear(1024, 256)
        ).to(device)

        custom_model = self.create_custom_op_model(input_dim).to(device)

        input_tensor = torch.randn(batch_size, input_dim, device=device)

        results = {}

        print("Benchmarking standard model...")
        standard_eager_time = do_bench(lambda: standard_model(input_tensor), warmup=50, rep=200)
        standard_compiled = torch.compile(standard_model)
        standard_compiled_time = do_bench(lambda: standard_compiled(input_tensor), warmup=50, rep=200)

        results['standard'] = {
            'eager': standard_eager_time,
            'compiled': standard_compiled_time
        }

        print("Benchmarking custom operator model...")
        custom_eager_time = do_bench(lambda: custom_model(input_tensor), warmup=50, rep=200)
        custom_compiled = torch.compile(custom_model)
        custom_compiled_time = do_bench(lambda: custom_compiled(input_tensor), warmup=50, rep=200)

        results['custom'] = {
            'eager': custom_eager_time,
            'compiled': custom_compiled_time
        }

        return results

# Run the custom-operator benchmark and report eager vs. compiled timings.
custom_op_builder = CustomOperatorBuilder()
custom_op_results = custom_op_builder.benchmark_custom_ops()

print("\nCustom Operator Benchmark Results:")
print("="*50)
for model_type, times in custom_op_results.items():
    eager_time, compiled_time = times['eager'], times['compiled']
    # Speedup > 1 means the compiled model was faster than eager.
    speedup = eager_time / compiled_time
    print(
        f"{model_type.capitalize()} model:\n"
        f"  Eager: {eager_time:.3f} ms\n"
        f"  Compiled: {compiled_time:.3f} ms\n"
        f"  Speedup: {speedup:.2f}x\n"
    )

## Graph Pattern Matching and Replacement

In [None]:
class GraphPatternOptimizer:
    """Pattern matching and rewriting on ``torch.fx`` graphs.

    Patterns are registered as ``(matcher, replacer)`` pairs. A matcher takes
    a ``GraphModule`` and returns a list of node tuples; a replacer takes the
    ``GraphModule`` followed by the nodes of one match and rewrites the graph
    in place.
    """

    def __init__(self):
        # Registered patterns, each a dict with 'name', 'matcher', 'replacer'.
        self.optimization_patterns = []
        # Names of every pattern application performed through
        # apply_graph_optimizations, in order.
        self.applied_optimizations = []

    def register_pattern(self, pattern_name: str, matcher: Callable, replacer: Callable):
        """Register a custom optimization pattern.

        Args:
            pattern_name: Label used in log output and summaries.
            matcher: Callable ``matcher(graph_module) -> list of node tuples``.
            replacer: Callable ``replacer(graph_module, *match)``.
        """
        self.optimization_patterns.append({
            'name': pattern_name,
            'matcher': matcher,
            'replacer': replacer
        })

    @staticmethod
    def _fused_linear_relu(input, weight, bias=None):
        """Compute ``relu(linear(input, weight, bias))`` as one graph node."""
        # Parameter names mirror F.linear so keyword arguments recorded on a
        # traced F.linear node can be forwarded unchanged.
        return F.relu(F.linear(input, weight, bias))

    @staticmethod
    def _is_linear_node(graph: GraphModule, node) -> bool:
        """Return True when ``node`` calls ``nn.Linear`` or ``F.linear``."""
        if node.op == 'call_module':
            return isinstance(graph.get_submodule(node.target), nn.Linear)
        if node.op == 'call_function':
            return node.target is F.linear
        return False

    @staticmethod
    def _is_relu_node(graph: GraphModule, node) -> bool:
        """Return True when ``node`` applies an out-of-place ReLU."""
        if node.op == 'call_module':
            return isinstance(graph.get_submodule(node.target), nn.ReLU)
        if node.op == 'call_function':
            return node.target in (F.relu, torch.relu)
        if node.op == 'call_method':
            return node.target == 'relu'
        return False

    def linear_relu_pattern_matcher(self, graph: GraphModule) -> List[Tuple]:
        """Find Linear -> ReLU patterns in the graph.

        A match is a ReLU node whose first positional input is a linear node
        that has no other consumer. Matching follows the data flow rather
        than node order, and recognises both module calls and functional
        calls.

        Args:
            graph: Traced module to search.

        Returns:
            List of ``(linear_node, relu_node)`` tuples.
        """
        matches = []

        for node in graph.graph.nodes:
            if not node.args or not self._is_relu_node(graph, node):
                continue

            producer = node.args[0]
            if not isinstance(producer, torch.fx.Node):
                continue

            # The linear output must feed only this ReLU; otherwise removing
            # the linear node would break its other consumers.
            if self._is_linear_node(graph, producer) and len(producer.users) == 1:
                matches.append((producer, node))

        return matches

    def fused_linear_relu_replacer(self, graph: GraphModule, linear_node, relu_node):
        """Replace a Linear + ReLU pair with one fused node.

        Args:
            graph: Traced module that owns both nodes; modified in place.
            linear_node: Node calling ``nn.Linear`` or ``F.linear``.
            relu_node: ReLU node consuming ``linear_node``.

        Returns:
            The newly created fused node.
        """
        fx_graph = graph.graph

        # Insert before the ReLU so the new nodes sit after every input of
        # the linear node and before every consumer of the ReLU.
        with fx_graph.inserting_before(relu_node):
            if linear_node.op == 'call_module':
                linear_module = graph.get_submodule(linear_node.target)
                if linear_node.args:
                    linear_input = linear_node.args[0]
                else:
                    linear_input = linear_node.kwargs['input']

                # Module parameters become explicit get_attr inputs of the
                # fused function call.
                weight = fx_graph.get_attr(f"{linear_node.target}.weight")
                bias = None
                if linear_module.bias is not None:
                    bias = fx_graph.get_attr(f"{linear_node.target}.bias")

                fused_args = (linear_input, weight, bias)
                fused_kwargs = {}
            else:
                fused_args = linear_node.args
                fused_kwargs = dict(linear_node.kwargs)

            fused_node = fx_graph.call_function(
                self._fused_linear_relu,
                args=fused_args,
                kwargs=fused_kwargs
            )

        relu_node.replace_all_uses_with(fused_node)

        # Remove the original nodes; the linear node is kept if something
        # else still consumes it.
        fx_graph.erase_node(relu_node)
        if not linear_node.users:
            fx_graph.erase_node(linear_node)

        return fused_node

    @staticmethod
    def _dtype_size(dtype: torch.dtype) -> int:
        """Return the size in bytes of one element of ``dtype``."""
        return torch.empty((), dtype=dtype).element_size()

    def analyze_graph_structure(self, model: nn.Module, input_tensor: torch.Tensor) -> Tuple[Dict[str, Any], GraphModule]:
        """Analyze the structure of a model's computation graph.

        Args:
            model: Module to trace with ``symbolic_trace``.
            input_tensor: Example input used for shape propagation.

        Returns:
            ``(analysis, traced_model)``. ``analysis`` holds the node count,
            a histogram of node types, detected optimization opportunities
            and a memory estimate in MB summed over every node that carries
            tensor metadata. ``'compute_intensity'`` is a placeholder that is
            always 0.
        """
        traced_model = symbolic_trace(model)

        # Attach shape/dtype metadata to every node.
        shape_prop.ShapeProp(traced_model).propagate(input_tensor)

        nodes = list(traced_model.graph.nodes)

        analysis = {
            'total_nodes': len(nodes),
            'node_types': {},
            'optimization_opportunities': [],
            'memory_usage_estimate': 0,
            'compute_intensity': 0
        }

        node_types = analysis['node_types']
        for node in nodes:
            node_type = f"{node.op}:{getattr(node.target, '__name__', str(node.target))}"
            node_types[node_type] = node_types.get(node_type, 0) + 1

        linear_relu_matches = self.linear_relu_pattern_matcher(traced_model)
        if linear_relu_matches:
            analysis['optimization_opportunities'].append({
                'pattern': 'Linear -> ReLU fusion',
                'count': len(linear_relu_matches),
                'potential_speedup': '10-20%'
            })

        total_bytes = 0
        for node in nodes:
            tensor_meta = node.meta.get('tensor_meta')
            shape = getattr(tensor_meta, 'shape', None)
            dtype = getattr(tensor_meta, 'dtype', None)
            # Nodes without a single-tensor result carry no usable metadata.
            if shape is None or dtype is None:
                continue

            numel = 1
            for dim in shape:
                numel *= dim

            total_bytes += numel * self._dtype_size(dtype)

        # Convert bytes to MB
        analysis['memory_usage_estimate'] = total_bytes / (1024 * 1024)

        return analysis, traced_model

    def apply_graph_optimizations(self, model: nn.Module, input_tensor: torch.Tensor) -> Tuple[nn.Module, Dict]:
        """Apply registered optimization patterns to a model.

        Args:
            model: Module to trace and optimize.
            input_tensor: Example input used for the pre-optimization
                analysis.

        Returns:
            ``(optimized_module, summary)`` where ``summary`` contains the
            original analysis, the names of the applied patterns (one entry
            per successful replacement) and their count.
        """
        analysis, traced_model = self.analyze_graph_structure(model, input_tensor)

        optimizations_applied = []

        for pattern in self.optimization_patterns:
            matches = pattern['matcher'](traced_model)

            if matches:
                print(f"Applying optimization: {pattern['name']} ({len(matches)} matches)")

                for match in matches:
                    # Best effort: a failing replacement is reported and the
                    # remaining matches are still attempted.
                    try:
                        pattern['replacer'](traced_model, *match)
                        optimizations_applied.append(pattern['name'])
                    except Exception as e:
                        print(f"Failed to apply {pattern['name']}: {e}")

        # Regenerate the module's forward() from the edited graph.
        traced_model.recompile()

        self.applied_optimizations.extend(optimizations_applied)

        optimization_summary = {
            'original_analysis': analysis,
            'applied_optimizations': optimizations_applied,
            'optimization_count': len(optimizations_applied)
        }

        return traced_model, optimization_summary

    def visualize_graph(self, model: nn.Module, input_tensor: torch.Tensor, title: str = "Model Graph"):
        """Print a textual listing of the model's traced graph.

        Args:
            model: Module to trace with ``symbolic_trace``.
            input_tensor: Unused by this method.
            title: Heading printed above the listing.

        Returns:
            The traced ``GraphModule``.
        """
        traced_model = symbolic_trace(model)
        nodes = list(traced_model.graph.nodes)

        print(f"\n{title}")
        print("=" * len(title))
        print(f"Total nodes: {len(nodes)}")
        print("\nGraph structure:")

        for i, node in enumerate(nodes):
            indent = "  " * (i % 3)  # Simple indentation for readability
            node_info = f"{node.name}: {node.op}"
            if hasattr(node.target, '__name__'):
                node_info += f" -> {node.target.__name__}"
            else:
                target_text = str(node.target)
                # Mark the text with an ellipsis only when it was truncated.
                suffix = "..." if len(target_text) > 50 else ""
                node_info += f" -> {target_text[:50]}{suffix}"

            print(f"{indent}{i:2d}. {node_info}")

        return traced_model

# Example usage: build an optimizer and register the Linear -> ReLU fusion.
graph_optimizer = GraphPatternOptimizer()
graph_optimizer.register_pattern(
    pattern_name="linear_relu_fusion",
    matcher=graph_optimizer.linear_relu_pattern_matcher,
    replacer=graph_optimizer.fused_linear_relu_replacer,
)

# Small MLP with two Linear -> ReLU pairs to exercise the matcher.
test_model = nn.Sequential(
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

input_tensor = torch.randn(32, 512)

# Trace the untouched model and summarise what the analysis found.
print("Original Model Analysis:")
analysis, traced_model = graph_optimizer.analyze_graph_structure(test_model, input_tensor)

print(
    f"Total nodes: {analysis['total_nodes']}\n"
    f"Memory estimate: {analysis['memory_usage_estimate']:.2f} MB\n"
    "\nOptimization opportunities:"
)
for opp in analysis['optimization_opportunities']:
    print(f"  - {opp['pattern']}: {opp['count']} instances, {opp['potential_speedup']} improvement")

# Print the node listing of the traced graph.
graph_optimizer.visualize_graph(test_model, input_tensor, title="Original Model Graph")

## Advanced Compilation Strategies

In [None]:
class AdvancedCompilationStrategies:
    """Implement and benchmark several ``torch.compile`` strategies.

    All benchmarks run the forward pass under ``torch.no_grad()`` so the
    timed call uses the same grad mode as the warm-up runs.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Maps a backend name to the callable registered for it.
        self.compilation_strategies = {}

    def register_custom_backend(self, name: str, backend_fn: Callable):
        """Register a custom compilation backend under ``name``."""
        self.compilation_strategies[name] = backend_fn

    @staticmethod
    def _bench_inference(model: Callable, input_tensor: torch.Tensor, warmup: int = 20, rep: int = 100):
        """Time ``model(input_tensor)`` with ``do_bench`` under ``no_grad``."""
        def _forward():
            with torch.no_grad():
                return model(input_tensor)

        return do_bench(_forward, warmup=warmup, rep=rep)

    @staticmethod
    def _device_time_total(event_avg) -> float:
        """Return the accelerator time of a profiler average entry."""
        # Newer profiler entries expose device_time_total; fall back to
        # cuda_time_total when that attribute is missing.
        device_time = getattr(event_avg, 'device_time_total', None)
        if device_time is None:
            device_time = getattr(event_avg, 'cuda_time_total', 0)
        return device_time

    def shape_specialized_compilation(self, model: nn.Module, input_shapes: List[Tuple]):
        """Compile the model once per input shape with static shapes.

        Args:
            model: Module to compile.
            input_shapes: Shapes to specialise for; each is also used as the
                key of the returned dict.

        Returns:
            Dict mapping each shape to its warmed-up compiled module.
        """
        specialized_models = {}

        # Reset once, before the loop. Resetting per shape would throw away
        # the compiled code of every shape handled earlier and make its
        # warm-up pointless.
        dynamo.reset()

        for shape in input_shapes:
            print(f"Compiling for shape {shape}...")

            sample_input = torch.randn(*shape, device=self.device)

            specialized_model = torch.compile(
                model,
                mode="max-autotune",
                dynamic=False  # Force static shapes
            )

            # Warm up the compiled model
            with torch.no_grad():
                for _ in range(5):
                    _ = specialized_model(sample_input)

            specialized_models[shape] = specialized_model

        return specialized_models

    def dynamic_shape_compilation(self, model: nn.Module, min_shape: Tuple, max_shape: Tuple):
        """Compile the model with dynamic shape support.

        Args:
            model: Module to compile.
            min_shape: Smallest expected input shape. Its trailing dimensions
                are used for every warm-up input.
            max_shape: Largest expected input shape; only its first
                dimension is used.

        Returns:
            The compiled module after warm-up at the minimum, midpoint and
            maximum batch size.
        """
        print(f"Compiling with dynamic shapes: {min_shape} to {max_shape}")

        dynamo.reset()
        dynamic_model = torch.compile(
            model,
            mode="reduce-overhead",
            dynamic=True
        )

        # Warm up with different batch sizes
        smallest, largest = min_shape[0], max_shape[0]
        for batch_size in [smallest, (smallest + largest) // 2, largest]:
            test_input = torch.randn(batch_size, *min_shape[1:], device=self.device)
            with torch.no_grad():
                _ = dynamic_model(test_input)

        return dynamic_model

    def profile_guided_optimization(self, model: nn.Module, profiling_inputs: List[torch.Tensor]):
        """Use profiling data to choose a compilation mode.

        The eager model is profiled on ``profiling_inputs``. If any operator
        accumulates more than 1000 of the profiler's time units on the
        device, the model is compiled with ``max-autotune``; otherwise with
        ``default``.

        Args:
            model: Module to profile and compile.
            profiling_inputs: Inputs run through the model while profiling.

        Returns:
            ``(optimized_model, profile_data)`` where
            ``profile_data['bottlenecks']`` lists the slow operators sorted
            by descending device time.
        """
        print("Running profile-guided optimization...")

        profile_data = {}

        # Only request CUDA activity when a CUDA device is present.
        activities = [torch.profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(torch.profiler.ProfilerActivity.CUDA)

        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=True
        ) as prof:
            with torch.no_grad():
                for input_tensor in profiling_inputs:
                    _ = model(input_tensor)

        key_averages = prof.key_averages(group_by_input_shape=True)

        bottlenecks = []
        for avg in key_averages:
            device_time = self._device_time_total(avg)
            if device_time > 1000:  # Operations taking more than 1ms
                bottlenecks.append({
                    'name': avg.key,
                    'cuda_time': device_time,
                    'cpu_time': avg.cpu_time_total,
                    'count': avg.count
                })

        profile_data['bottlenecks'] = sorted(bottlenecks, key=lambda x: x['cuda_time'], reverse=True)

        # Spend autotuning effort only when the profile shows hot operators.
        mode = "max-autotune" if profile_data['bottlenecks'] else "default"
        optimized_model = torch.compile(model, mode=mode)

        return optimized_model, profile_data

    def multi_stage_compilation(self, model: nn.Module, input_tensor: torch.Tensor):
        """Compile and time the model under each standard compile mode.

        Args:
            model: Module to compile.
            input_tensor: Input used for the first call and the benchmark.

        Returns:
            Dict keyed by stage name. Each entry holds compilation time,
            first-execution time, their sum (all converted to ms), the
            steady-state value returned by ``do_bench``, and the stage's
            mode and description.
        """
        compilation_stages = {
            'stage1_fast': {'mode': 'default', 'description': 'Fast compilation'},
            'stage2_balanced': {'mode': 'reduce-overhead', 'description': 'Balanced optimization'},
            'stage3_aggressive': {'mode': 'max-autotune', 'description': 'Aggressive optimization'}
        }

        stage_results = {}

        for stage_name, config in compilation_stages.items():
            print(f"Running {stage_name}: {config['description']}")

            # Reset compilation state so stages do not share compiled code.
            dynamo.reset()

            compile_start = time.perf_counter()
            compiled_model = torch.compile(model, mode=config['mode'])
            compile_time = time.perf_counter() - compile_start

            # Measure first execution time (includes lazy compilation)
            first_exec_start = time.perf_counter()
            with torch.no_grad():
                compiled_model(input_tensor)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            first_exec_time = time.perf_counter() - first_exec_start

            steady_state_time = self._bench_inference(
                compiled_model, input_tensor, warmup=20, rep=100
            )

            stage_results[stage_name] = {
                'compilation_time': compile_time * 1000,  # ms
                'first_execution_time': first_exec_time * 1000,  # ms
                'steady_state_time': steady_state_time,  # ms
                'total_setup_time': (compile_time + first_exec_time) * 1000,  # ms
                'mode': config['mode'],
                'description': config['description']
            }

        return stage_results

    def benchmark_compilation_strategies(self, model: nn.Module, input_shapes: List[Tuple]):
        """Benchmark every compilation strategy of this class.

        Args:
            model: Module to benchmark; it is moved to ``self.device`` and
                put in eval mode.
            input_shapes: Input shapes to test. The first is used for the
                standard-mode and profile-guided runs, the first three for
                the specialized and dynamic runs, the first two for
                profiling. Must contain at least one shape.

        Returns:
            Dict with the keys ``'standard_modes'``, ``'shape_specialized'``,
            ``'dynamic_shapes'`` and ``'profile_guided'``.
        """
        model = model.to(self.device)
        model.eval()

        benchmark_results = {}

        # 1. Standard compilation modes
        print("\n1. Testing standard compilation modes...")
        standard_input = torch.randn(*input_shapes[0], device=self.device)
        benchmark_results['standard_modes'] = self.multi_stage_compilation(model, standard_input)

        # 2. Shape-specialized compilation
        print("\n2. Testing shape-specialized compilation...")
        specialized_models = self.shape_specialized_compilation(model, input_shapes[:3])

        shape_specialized_results = {}
        for shape, specialized_model in specialized_models.items():
            test_input = torch.randn(*shape, device=self.device)
            shape_specialized_results[str(shape)] = self._bench_inference(specialized_model, test_input)

        benchmark_results['shape_specialized'] = shape_specialized_results

        # 3. Dynamic shape compilation
        print("\n3. Testing dynamic shape compilation...")
        min_shape = min(input_shapes, key=lambda x: x[0])
        max_shape = max(input_shapes, key=lambda x: x[0])

        dynamic_model = self.dynamic_shape_compilation(model, min_shape, max_shape)

        dynamic_results = {}
        for shape in input_shapes[:3]:
            test_input = torch.randn(*shape, device=self.device)
            dynamic_results[str(shape)] = self._bench_inference(dynamic_model, test_input)

        benchmark_results['dynamic_shapes'] = dynamic_results

        # 4. Profile-guided optimization
        print("\n4. Testing profile-guided optimization...")
        profiling_inputs = [torch.randn(*shape, device=self.device) for shape in input_shapes[:2]]
        pgo_model, profile_data = self.profile_guided_optimization(model, profiling_inputs)

        pgo_time = self._bench_inference(pgo_model, standard_input)
        benchmark_results['profile_guided'] = {
            'execution_time': pgo_time,
            'bottlenecks_found': len(profile_data['bottlenecks'])
        }

        return benchmark_results

    def plot_strategy_comparison(self, benchmark_results: Dict):
        """Plot a 2x2 comparison of the benchmarked strategies.

        Args:
            benchmark_results: Dict in the layout returned by
                ``benchmark_compilation_strategies``. ``'standard_modes'``
                (including its ``'stage2_balanced'`` entry) is required; the
                other sections are plotted when present.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        standard_modes = benchmark_results['standard_modes']
        modes = list(standard_modes.keys())
        mode_labels = [mode.replace('stage', '').replace('_', '\n') for mode in modes]
        # Separate x positions for the mode plots and the shape plot; sharing
        # one variable breaks the third panel when their lengths differ.
        mode_x = np.arange(len(modes))
        width = 0.35

        # 1. Standard modes comparison
        setup_times = [standard_modes[mode]['total_setup_time'] for mode in modes]
        steady_times = [standard_modes[mode]['steady_state_time'] for mode in modes]

        ax1.bar(mode_x - width/2, setup_times, width, label='Setup Time', alpha=0.8)
        ax1.bar(mode_x + width/2, steady_times, width, label='Execution Time', alpha=0.8)
        ax1.set_xlabel('Compilation Mode')
        ax1.set_ylabel('Time (ms)')
        ax1.set_title('Standard Compilation Modes')
        ax1.set_xticks(mode_x)
        ax1.set_xticklabels(mode_labels, rotation=45)
        ax1.legend()
        ax1.set_yscale('log')

        # 2. Shape specialization vs dynamic shapes
        if 'shape_specialized' in benchmark_results and 'dynamic_shapes' in benchmark_results:
            specialized_data = benchmark_results['shape_specialized']
            dynamic_data = benchmark_results['dynamic_shapes']

            shapes = list(specialized_data.keys())
            specialized_times = list(specialized_data.values())
            # A shape missing from the dynamic results is drawn as 0 so both
            # bar series keep the same length.
            dynamic_times = [dynamic_data.get(shape, 0) for shape in shapes]

            shape_x = np.arange(len(shapes))
            ax2.bar(shape_x - width/2, specialized_times, width, label='Shape Specialized', alpha=0.8)
            ax2.bar(shape_x + width/2, dynamic_times, width, label='Dynamic Shapes', alpha=0.8)
            ax2.set_xlabel('Input Shape')
            ax2.set_ylabel('Execution Time (ms)')
            ax2.set_title('Shape Specialized vs Dynamic Compilation')
            ax2.set_xticks(shape_x)
            ax2.set_xticklabels([shape.replace('(', '').replace(')', '').replace(', ', '\nx') for shape in shapes], rotation=45)
            ax2.legend()

        # 3. Compilation overhead analysis
        compile_times = [standard_modes[mode]['compilation_time'] for mode in modes]
        first_exec_times = [standard_modes[mode]['first_execution_time'] for mode in modes]

        ax3.bar(mode_x - width/2, compile_times, width, label='Compilation', alpha=0.8)
        ax3.bar(mode_x + width/2, first_exec_times, width, label='First Execution', alpha=0.8)
        ax3.set_xlabel('Compilation Mode')
        ax3.set_ylabel('Time (ms)')
        ax3.set_title('Compilation Overhead Breakdown')
        ax3.set_xticks(mode_x)
        ax3.set_xticklabels(mode_labels, rotation=45)
        ax3.legend()
        ax3.set_yscale('log')

        # 4. Overall strategy comparison (missing or empty sections plot as 0)
        strategy_names = ['Standard', 'Specialized', 'Dynamic', 'PGO']
        strategy_times = [
            standard_modes['stage2_balanced']['steady_state_time'],
            min(benchmark_results.get('shape_specialized', {}).values(), default=0),
            min(benchmark_results.get('dynamic_shapes', {}).values(), default=0),
            benchmark_results['profile_guided']['execution_time'] if 'profile_guided' in benchmark_results else 0
        ]

        bars = ax4.bar(strategy_names, strategy_times, alpha=0.8, color=['blue', 'green', 'orange', 'red'])
        ax4.set_ylabel('Execution Time (ms)')
        ax4.set_title('Strategy Comparison (Best Case)')

        # Add value labels slightly above each non-empty bar
        label_offset = max(strategy_times) * 0.01
        for bar, elapsed in zip(bars, strategy_times):
            if elapsed > 0:
                ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                        f'{elapsed:.2f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.show()

# Example usage: benchmark every compilation strategy on one MLP.
compilation_strategies = AdvancedCompilationStrategies()

# Representative model mixing linear layers, activations, batch norm and dropout.
strategy_test_model = nn.Sequential(
    nn.Linear(1024, 512), nn.ReLU(), nn.BatchNorm1d(512),
    nn.Linear(512, 256), nn.GELU(), nn.Dropout(0.1),
    nn.Linear(256, 10),
)

# Same feature width, increasing batch size.
test_shapes = [(batch, 1024) for batch in (16, 32, 64, 128)]

# Run the full benchmark, then plot the collected timings.
strategy_results = compilation_strategies.benchmark_compilation_strategies(
    model=strategy_test_model,
    input_shapes=test_shapes,
)

compilation_strategies.plot_strategy_comparison(benchmark_results=strategy_results)