# Interactive Compiler Exploration

This notebook provides an interactive environment for exploring the internals of `torch.compile`: TorchDynamo (graph capture), AOTAutograd (forward/backward graph generation), PrimTorch (operator decompositions), and TorchInductor (code generation).

In [None]:
import torch
import torch.nn as nn
import torch._dynamo as dynamo
from torch._inductor import config as inductor_config
from torch.fx import GraphModule
import matplotlib.pyplot as plt
import networkx as nx
from typing import Dict, List, Any
import graphviz
import tempfile
import os

## Interactive Model Creation

In [None]:
class InteractiveModelBuilder:
    """Build ``nn.Sequential`` models step by step for compiler exploration.

    Every ``add_*``/``set_input_dim``/``reset`` method returns ``self`` so calls
    can be chained. ``current_dim`` tracks the feature width produced by the
    last sized layer and is used to size the next Linear/BatchNorm1d layer.
    """

    # Supported activation names -> module constructors. "none" is handled
    # separately (no activation layer is appended).
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
    }

    def __init__(self):
        self.layers = []          # modules, in the order they will run
        self.current_dim = None   # output width of the last sized layer

    def add_linear(self, output_dim: int, activation: str = "relu"):
        """Append ``Linear(current_dim, output_dim)`` and an optional activation.

        Args:
            output_dim: Number of output features of the new Linear layer.
            activation: One of "relu", "gelu", "tanh" or "none".

        Returns:
            ``self``, for chaining.

        Raises:
            ValueError: If the input dimension has not been set, or if
                ``activation`` is not recognised. Both checks run before
                anything is appended, so a failed call leaves the builder
                unchanged.
        """
        if self.current_dim is None:
            raise ValueError("Set input dimension first with set_input_dim()")
        if activation != "none" and activation not in self._ACTIVATIONS:
            raise ValueError(f"Unknown activation: {activation}")

        in_dim = self.current_dim
        self.layers.append(nn.Linear(in_dim, output_dim))
        if activation != "none":
            self.layers.append(self._ACTIVATIONS[activation]())

        self.current_dim = output_dim
        # Report the recorded input width instead of indexing self.layers:
        # with activation="none" the element before the Linear may be a
        # BatchNorm1d/Dropout (no ``in_features``) or may not exist at all.
        print(f"Added Linear({in_dim} -> {output_dim}) + {activation}")
        return self

    def add_dropout(self, p: float = 0.1):
        """Append ``Dropout(p)``; the feature width is unchanged."""
        self.layers.append(nn.Dropout(p))
        print(f"Added Dropout(p={p})")
        return self

    def add_batchnorm(self):
        """Append ``BatchNorm1d(current_dim)``.

        Raises:
            ValueError: If the current feature width is unknown.
        """
        if self.current_dim is None:
            raise ValueError("Cannot add BatchNorm without knowing current dimension")
        self.layers.append(nn.BatchNorm1d(self.current_dim))
        print(f"Added BatchNorm1d({self.current_dim})")
        return self

    def set_input_dim(self, dim: int):
        """Set the feature width the first Linear/BatchNorm layer will consume."""
        self.current_dim = dim
        print(f"Set input dimension to {dim}")
        return self

    def build(self):
        """Return an ``nn.Sequential`` of the layers added so far.

        Raises:
            ValueError: If no layers have been added.
        """
        if not self.layers:
            raise ValueError("No layers added")
        model = nn.Sequential(*self.layers)
        print(f"\nBuilt model with {len(self.layers)} layers")
        print(model)
        return model

    def reset(self):
        """Forget all layers and the current dimension."""
        self.layers = []
        self.current_dim = None
        print("Model builder reset")
        return self

# Interactive model builder
builder = InteractiveModelBuilder()

# Example: a 784 -> 512 -> 256 -> 10 MLP with dropout and batch norm,
# assembled one call at a time instead of as a single chain.
builder.set_input_dim(784)
builder.add_linear(512, "relu")
builder.add_dropout(0.2)
builder.add_linear(256, "relu")
builder.add_batchnorm()
builder.add_linear(10, "none")
model = builder.build()

## TorchDynamo Graph Extraction

In [None]:
class GraphExtractor:
    """Capture the FX graphs TorchDynamo hands to a backend and inspect them.

    ``extract_graph_backend`` is used as a ``torch.compile`` backend: it
    records every graph it receives and then delegates to Inductor's
    ``compile_fx`` so the model still runs compiled.
    """

    def __init__(self):
        # One dict per graph the backend received, with keys
        # 'graph_module', 'example_inputs', 'nodes' and 'code'.
        self.extracted_graphs = []
        # Not used by the methods in this class; kept for compatibility.
        self.original_backend = None

    def extract_graph_backend(self, graph_module: GraphModule, example_inputs):
        """Record ``graph_module`` and return its Inductor-compiled callable."""
        self.extracted_graphs.append({
            'graph_module': graph_module,
            'example_inputs': example_inputs,
            'nodes': list(graph_module.graph.nodes),
            'code': graph_module.code,
        })

        # Still compile normally
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph_module, example_inputs)

    def extract_graphs(self, model: nn.Module, input_tensor: torch.Tensor):
        """Compile ``model`` with the capturing backend and run it once.

        Returns:
            The list of captured graph dicts (also kept on
            ``self.extracted_graphs``). Earlier captures are discarded.
        """
        self.extracted_graphs = []

        compiled_model = torch.compile(model, backend=self.extract_graph_backend)

        # The backend only runs when the compiled model is first called.
        with torch.no_grad():
            _ = compiled_model(input_tensor)

        return self.extracted_graphs

    def _get_graph_info(self, graph_idx: int):
        """Return the captured graph dict at ``graph_idx``.

        Prints a message and returns ``None`` if nothing has been captured
        or the index is out of range. Negative indices count from the end.
        """
        if not self.extracted_graphs:
            print("No graphs extracted. Run extract_graphs() first.")
            return None
        count = len(self.extracted_graphs)
        if not -count <= graph_idx < count:
            print(f"Graph index {graph_idx} out of range. Available: 0-{count-1}")
            return None
        return self.extracted_graphs[graph_idx]

    def print_graph_info(self, graph_idx: int = 0):
        """Print node counts per FX op kind and a per-node listing."""
        from collections import Counter

        graph_info = self._get_graph_info(graph_idx)
        if graph_info is None:
            return
        nodes = graph_info['nodes']

        print(f"Graph {graph_idx} Information:")
        print(f"Total nodes: {len(nodes)}")

        # Counter keeps first-seen order, matching the original dict tally.
        node_types = Counter(node.op for node in nodes)
        print("Node types:")
        for node_type, count in node_types.items():
            print(f"  {node_type}: {count}")

        print("\nNode details:")
        for node in nodes:
            if node.op == 'call_function':
                print(f"  {node.name}: {getattr(node.target, '__name__', node.target)}")
            else:
                print(f"  {node.name}: {node.op}")

    def visualize_graph(self, graph_idx: int = 0, save_path: str = None):
        """Draw the data-flow graph of a captured graph with networkx.

        Args:
            graph_idx: Index into ``self.extracted_graphs``.
            save_path: If given, the figure is also saved there (300 dpi).
        """
        graph_info = self._get_graph_info(graph_idx)
        if graph_info is None:
            return
        nodes = graph_info['nodes']

        G = nx.DiGraph()

        for node in nodes:
            label = node.name
            if node.op == 'call_function' and hasattr(node.target, '__name__'):
                label += f"\n{node.target.__name__}"
            G.add_node(node.name, label=label)

        # all_input_nodes covers args AND kwargs, including nodes nested in
        # tuples/lists (e.g. the output node's args), which a scan of the
        # top-level node.args misses.
        for node in nodes:
            for input_node in node.all_input_nodes:
                G.add_edge(input_node.name, node.name)

        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(G, k=2, iterations=50)

        # Pass the computed labels; with_labels alone would only show names.
        nx.draw(G, pos, labels=nx.get_node_attributes(G, 'label'),
                with_labels=True, node_color='lightblue',
                node_size=3000, font_size=8, font_weight='bold',
                arrows=True, arrowsize=20, edge_color='gray')

        plt.title(f"Graph {graph_idx} Structure")

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

# Example usage: capture graphs for the model built above on a random batch
extractor = GraphExtractor()
input_tensor = torch.randn(32, 784)

# Extract graphs and report how many the backend received
graphs = extractor.extract_graphs(model, input_tensor)
print("Extracted {} graphs".format(len(graphs)))

# Inspect the first captured graph
extractor.print_graph_info(graph_idx=0)

## AOTAutograd Exploration

In [None]:
class AOTAutogradExplorer:
    """Capture the forward and backward graphs produced by AOTAutograd.

    ``aot_backend`` is a TorchDynamo backend that routes each Dynamo graph
    through ``torch._dynamo.backends.common.aot_autograd``. AOTAutograd
    traces it to ATen ops and, when gradients are required, produces a
    forward and a backward graph. Each one is handed to a capture hook below,
    recorded, and executed without further compilation.
    """

    def __init__(self):
        # Each entry: {'graph_module', 'nodes', 'is_backward'}.
        self.forward_graphs = []
        self.backward_graphs = []

    def _record_graph(self, graph_module: GraphModule, is_backward: bool):
        """Append ``graph_module`` to the forward or backward list."""
        bucket = self.backward_graphs if is_backward else self.forward_graphs
        nodes = list(graph_module.graph.nodes)
        bucket.append({
            'graph_module': graph_module,
            'nodes': nodes,
            'is_backward': is_backward,
        })
        kind = "backward" if is_backward else "forward"
        print(f"AOT {kind} graph captured: {len(nodes)} nodes")

    def _capture_forward(self, graph_module: GraphModule, example_inputs):
        """Forward/inference compiler hook: record the graph, run it as-is."""
        from functorch.compile import make_boxed_func
        self._record_graph(graph_module, is_backward=False)
        # AOTAutograd invokes compiled graphs with a single list of args;
        # make_boxed_func adapts gm.forward to that calling convention.
        return make_boxed_func(graph_module.forward)

    def _capture_backward(self, graph_module: GraphModule, example_inputs):
        """Backward compiler hook: record the graph, run it as-is."""
        from functorch.compile import make_boxed_func
        self._record_graph(graph_module, is_backward=True)
        return make_boxed_func(graph_module.forward)

    def aot_backend(self, graph_module: GraphModule, example_inputs):
        """TorchDynamo backend that runs AOTAutograd and captures its graphs."""
        from torch._dynamo.backends.common import aot_autograd

        print(f"AOT Backend called with graph: {len(list(graph_module.graph.nodes))} nodes")
        backend = aot_autograd(
            fw_compiler=self._capture_forward,
            bw_compiler=self._capture_backward,
        )
        return backend(graph_module, example_inputs)

    def explore_aot_autograd(self, model: nn.Module, input_tensor: torch.Tensor):
        """Compile ``model`` with ``aot_backend`` and run forward + backward.

        Side effect: parameter ``.grad`` fields of ``model`` are populated by
        the backward pass. ``input_tensor`` itself is not modified.

        Returns:
            Captured forward graph dicts followed by backward graph dicts
            (distinguish them with the ``'is_backward'`` key).
        """
        self.forward_graphs = []
        self.backward_graphs = []

        print("Analyzing AOT Autograd behavior...")

        compiled_model = torch.compile(model, backend=self.aot_backend)

        # Work on a detached copy so the caller's tensor is not mutated.
        inputs = input_tensor.detach().clone().requires_grad_(True)
        output = compiled_model(inputs)

        # Run backward so the backward graph is compiled (it may be compiled
        # lazily, on the first backward call) and therefore captured.
        loss = output.sum()
        loss.backward()

        print(f"Captured {len(self.forward_graphs)} forward and "
              f"{len(self.backward_graphs)} backward graphs")

        return self.forward_graphs + self.backward_graphs

    def analyze_gradient_flow(self):
        """Print, per captured graph, ops whose name mentions grad/backward."""
        all_graphs = self.forward_graphs + self.backward_graphs
        if not all_graphs:
            print("No graphs captured. Run explore_aot_autograd() first.")
            return

        keywords = ('grad', 'backward', 'autograd')
        for i, graph_info in enumerate(all_graphs):
            nodes = graph_info['nodes']
            kind = "backward" if graph_info['is_backward'] else "forward"
            print(f"\nGraph {i} ({kind}):")

            # Name-based heuristic: counts call_function targets whose string
            # form contains one of the keywords (e.g. aten.threshold_backward).
            grad_ops = 0
            for node in nodes:
                if node.op != 'call_function':
                    continue
                func_name = str(node.target)
                if any(keyword in func_name.lower() for keyword in keywords):
                    grad_ops += 1
                    print(f"  Gradient op: {func_name}")

            print(f"  Total gradient operations: {grad_ops}")
            print(f"  Total nodes: {len(nodes)}")
# Example usage
aot_explorer = AOTAutogradExplorer()

# A small two-layer MLP used to exercise the backward pass
grad_model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

input_tensor = torch.randn(32, 784, requires_grad=True)

# Explore AOT autograd, then summarize the captured graphs
aot_graphs = aot_explorer.explore_aot_autograd(grad_model, input_tensor)
aot_explorer.analyze_gradient_flow()

## PrimTorch Operation Analysis

In [None]:
class PrimTorchAnalyzer:
    """Count the ATen/prims operations a model lowers to under torch.compile.

    The graphs TorchDynamo passes to a backend contain torch-level calls.
    ATen/prims ops only appear once AOTAutograd traces those graphs, so the
    analysis backend runs AOTAutograd with Inductor's decomposition table
    and inspects the lowered graph. The lowered graph is then executed as-is
    rather than compiled by Inductor.
    """

    def __init__(self):
        self.operation_counts = {}   # op name -> number of call_function nodes
        self.primitive_ops = set()   # op names containing 'aten' or 'prims'

    @staticmethod
    def _op_name(target) -> str:
        """Return a stable, readable name for an FX ``call_function`` target.

        ``str()`` of a builtin such as ``torch.relu`` embeds a memory address,
        which makes names differ between runs. So only operator overloads
        (whose ``str`` is e.g. ``"aten.addmm.default"``) use ``str``; other
        callables use ``module.name``.
        """
        if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
            return str(target)
        name = getattr(target, '__name__', None)
        if name is None:
            return str(target)
        module = getattr(target, '__module__', None)
        return f"{module}.{name}" if module else name

    def _record_ops(self, graph_module, counts: dict, primitives: set):
        """Tally ``graph_module``'s call_function targets into the given containers."""
        for node in graph_module.graph.nodes:
            if node.op != 'call_function':
                continue
            op_name = self._op_name(node.target)
            counts[op_name] = counts.get(op_name, 0) + 1
            if 'aten' in op_name or 'prims' in op_name:
                primitives.add(op_name)

    def analyze_primitives(self, model: nn.Module, input_tensor: torch.Tensor):
        """Lower ``model`` through AOTAutograd and report the resulting ops.

        Counts from a previous call are discarded first.
        """
        from functorch.compile import make_boxed_func
        from torch._decomp import decomposition_table
        from torch._dynamo.backends.common import aot_autograd
        from torch._inductor.decomposition import select_decomp_table

        self.operation_counts = {}
        self.primitive_ops = set()

        inductor_decomps = select_decomp_table()
        print(f"Available decompositions: {len(decomposition_table)}")
        print(f"Inductor decompositions: {len(inductor_decomps)}")

        def capture_lowered_graph(graph_module, example_inputs):
            # Receives the AOTAutograd (ATen-level) graph.
            self._record_ops(graph_module, self.operation_counts, self.primitive_ops)
            # AOTAutograd calls compiled graphs with one list of args.
            return make_boxed_func(graph_module.forward)

        backend = aot_autograd(fw_compiler=capture_lowered_graph,
                               decompositions=inductor_decomps)
        compiled_model = torch.compile(model, backend=backend)

        # Compilation (and therefore capture) happens on the first call.
        with torch.no_grad():
            _ = compiled_model(input_tensor)

        self.print_primitive_analysis()

    def print_primitive_analysis(self):
        """Print the most frequent ops and a per-namespace breakdown."""
        print("\n=== Primitive Operation Analysis ===")
        print(f"Total unique operations: {len(self.operation_counts)}")
        print(f"Primitive operations: {len(self.primitive_ops)}")

        sorted_ops = sorted(self.operation_counts.items(), key=lambda x: x[1], reverse=True)

        print("\nMost frequent operations:")
        for op, count in sorted_ops[:10]:
            is_primitive = "[PRIMITIVE]" if op in self.primitive_ops else ""
            print(f"  {op}: {count} {is_primitive}")

        # First matching substring wins, in this order.
        categories = {'aten': [], 'prims': [], 'torch': [], 'other': []}
        for op in self.operation_counts:
            if 'aten' in op:
                categories['aten'].append(op)
            elif 'prims' in op:
                categories['prims'].append(op)
            elif 'torch' in op:
                categories['torch'].append(op)
            else:
                categories['other'].append(op)

        print("\nOperation categories:")
        for category, ops in categories.items():
            if ops:
                print(f"  {category}: {len(ops)} operations")

    def compare_before_after_decomposition(self, model: nn.Module, input_tensor: torch.Tensor):
        """Compare Dynamo-level ops with the ops left after decomposition.

        "Before" is the torch-level graph TorchDynamo captures; "after" is the
        lowered graph from ``analyze_primitives``.
        """
        ops_before = {}

        def capture_dynamo_graph(graph_module, example_inputs):
            # Record the Dynamo graph and run it unchanged.
            self._record_ops(graph_module, ops_before, set())
            return graph_module.forward

        print("Analyzing WITHOUT decomposition...")
        # Drop cached compilations so each pass compiles afresh with its own
        # backend instead of reusing a previously compiled entry.
        torch._dynamo.reset()
        basic_compiled = torch.compile(model, backend=capture_dynamo_graph)
        with torch.no_grad():
            _ = basic_compiled(input_tensor)

        print("\nAnalyzing WITH full compilation...")
        torch._dynamo.reset()
        self.analyze_primitives(model, input_tensor)

        ops_after = dict(self.operation_counts)

        print("\n=== Comparison ===")
        print(f"Operations before: {len(ops_before)}")
        print(f"Operations after: {len(ops_after)}")

        new_ops = set(ops_after) - set(ops_before)
        if new_ops:
            print(f"\nNew operations after decomposition ({len(new_ops)}):")
            for op in sorted(new_ops):
                print(f"  {op}: {ops_after[op]}")

        lowered_ops = set(ops_before) - set(ops_after)
        if lowered_ops:
            print(f"\nOperations no longer present after decomposition ({len(lowered_ops)}):")
            for op in sorted(lowered_ops):
                print(f"  {op}: {ops_before[op]}")

# Example usage
prim_analyzer = PrimTorchAnalyzer()

# A deeper model mixing LayerNorm, GELU, Dropout and BatchNorm1d so the
# analysis has a variety of ops to report on
complex_model = nn.Sequential(
    nn.Linear(784, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1),
    nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
    nn.Linear(256, 10),
)

input_tensor = torch.randn(32, 784)
prim_analyzer.analyze_primitives(complex_model, input_tensor)

## TorchInductor Code Generation

In [None]:
class InductorExplorer:
    """Capture and summarize the source files TorchInductor generates."""

    # File extensions treated as generated source.
    _SOURCE_SUFFIXES = ('.py', '.cpp', '.cu')

    def __init__(self):
        # One dict per captured file: 'filename', 'path' (relative to the
        # debug directory), 'content' and 'size' (characters).
        self.generated_code = []
        # Not populated by the methods in this class.
        self.kernel_info = []

    @classmethod
    def _collect_files(cls, root_dir: str):
        """Read every generated source file under ``root_dir`` (sorted order)."""
        collected = []
        for root, dirs, files in os.walk(root_dir):
            dirs.sort()  # deterministic traversal order
            for name in sorted(files):
                if not name.endswith(cls._SOURCE_SUFFIXES):
                    continue
                file_path = os.path.join(root, name)
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        content = f.read()
                except (OSError, UnicodeDecodeError) as e:
                    print(f"Could not read {name}: {e}")
                    continue
                collected.append({
                    'filename': name,
                    'path': os.path.relpath(file_path, root_dir),
                    'content': content,
                    'size': len(content),
                })
        return collected

    @staticmethod
    def _count_kernels(content: str) -> dict:
        """Count kernel definitions in generated source by textual markers."""
        return {
            'triton': content.count('@triton.jit'),
            'cuda': content.count('__global__'),
            # Inductor's generated wrapper typically defines C++ kernels via
            # async_compile.cpp(...) / async_compile.cpp_pybinding(...).
            'cpu': content.count('async_compile.cpp'),
        }

    def capture_generated_code(self, model: nn.Module, input_tensor: torch.Tensor,
                               mode: str = "max-autotune"):
        """Compile ``model`` with Inductor tracing on and read back the output files.

        Args:
            model: Module to compile.
            input_tensor: Example input used to trigger compilation.
            mode: ``torch.compile`` mode.

        Note: resets TorchDynamo's caches so a fresh compilation (and hence
        code generation) is guaranteed.
        """
        import torch._inductor.config as inductor_config

        overrides = {"debug": True, "trace.enabled": True}
        # A warm FX-graph cache could skip code generation, leaving no files
        # to read. The key only exists in some PyTorch versions.
        if hasattr(inductor_config, "fx_graph_cache"):
            overrides["fx_graph_cache"] = False

        with tempfile.TemporaryDirectory() as temp_dir:
            overrides["trace.debug_dir"] = temp_dir
            # patch() restores every overridden setting on exit, so the
            # config never points at the deleted temp directory afterwards.
            with inductor_config.patch(overrides):
                torch._dynamo.reset()
                compiled_model = torch.compile(model, mode=mode)
                # Compilation happens on the first call, inside the patch.
                with torch.no_grad():
                    _ = compiled_model(input_tensor)
            self.generated_code = self._collect_files(temp_dir)

        print(f"Captured {len(self.generated_code)} generated files")
        for file_info in self.generated_code:
            print(f"  {file_info['filename']}: {file_info['size']} characters")

    def analyze_generated_kernels(self):
        """Print kernel counts and optimization hints for each captured file."""
        if not self.generated_code:
            print("No generated code available. Run capture_generated_code() first.")
            return

        print("\n=== Generated Kernel Analysis ===")

        for file_info in self.generated_code:
            content = file_info['content']
            print(f"\nFile: {file_info['filename']}")

            counts = self._count_kernels(content)
            print(f"  Triton kernels: {counts['triton']}")
            print(f"  CUDA kernels: {counts['cuda']}")
            print(f"  CPU kernels: {counts['cpu']}")

            # Substring heuristics only; they do not parse the code.
            lowered = content.lower()
            if 'vectorized' in lowered:
                print("  ✓ Vectorization detected")
            if 'fused' in lowered:
                print("  ✓ Operator fusion detected")
            if 'tl.load' in content or 'tl.store' in content:
                print("  ✓ Triton memory operations detected")

    def show_sample_kernel(self, file_index: int = 0, max_lines: int = 50):
        """Print the first ``max_lines`` lines of a captured file, numbered."""
        if not self.generated_code or file_index >= len(self.generated_code):
            print("No code available or invalid index")
            return

        file_info = self.generated_code[file_index]
        lines = file_info['content'].split('\n')

        print(f"\n=== Sample from {file_info['filename']} ===")
        print(f"Total lines: {len(lines)}")
        print(f"Showing first {min(max_lines, len(lines))} lines:")
        print("-" * 60)

        for i, line in enumerate(lines[:max_lines], 1):
            print(f"{i:3d}: {line}")

        if len(lines) > max_lines:
            print(f"... ({len(lines) - max_lines} more lines)")

        print("-" * 60)

    def compare_compilation_modes(self, model: nn.Module, input_tensor: torch.Tensor):
        """Compile ``model`` under several modes and compare Dynamo statistics.

        Note: resets TorchDynamo's caches and clears its global counters.
        """
        modes = ["default", "reduce-overhead", "max-autotune"]
        mode_results = {}
        counters = torch._dynamo.utils.counters

        for mode in modes:
            print(f"\nAnalyzing mode: {mode}")

            torch._dynamo.reset()
            # counters is a process-wide accumulator; clear it so each mode
            # reports only its own compilation.
            counters.clear()

            compiled_model = torch.compile(model, mode=mode)
            with torch.no_grad():
                _ = compiled_model(input_tensor)

            mode_results[mode] = {
                'frames_ok': counters.get('frames', {}).get('ok', 0),
                'graph_breaks': sum(counters.get('graph_break', {}).values()),
            }

        print("\n=== Mode Comparison ===")
        for mode, stats in mode_results.items():
            print(f"{mode}:")
            print(f"  Successful frames: {stats['frames_ok']}")
            print(f"  Graph breaks: {stats['graph_breaks']}")

# Example usage
inductor_explorer = InductorExplorer()

# Linear layers interleaved with ReLU and GELU activations
inductor_model = nn.Sequential(
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.GELU(),
    nn.Linear(256, 10),
)

input_tensor = torch.randn(64, 784)

# Capture the generated code, then summarize the kernels it contains
inductor_explorer.capture_generated_code(inductor_model, input_tensor)
inductor_explorer.analyze_generated_kernels()

# Print the first 30 lines of the first captured file, if there is one
if inductor_explorer.generated_code:
    inductor_explorer.show_sample_kernel(file_index=0, max_lines=30)

## Interactive Compilation Dashboard

In [None]:
class CompilationDashboard:
    """Single entry point bundling the explorers defined in this notebook."""

    def __init__(self):
        self.model_builder = InteractiveModelBuilder()
        self.graph_extractor = GraphExtractor()
        self.aot_explorer = AOTAutogradExplorer()
        self.prim_analyzer = PrimTorchAnalyzer()
        self.inductor_explorer = InductorExplorer()

    def full_analysis(self, model: nn.Module, input_tensor: torch.Tensor):
        """Run graph capture, AOT, primitive and codegen analyses in sequence.

        Returns:
            Dict with keys 'graphs', 'aot_graphs', 'primitives' and
            'generated_code' holding each stage's results.
        """
        print("🔍 TORCH.COMPILE FULL ANALYSIS")
        print("=" * 60)

        # 1. Graph extraction
        print("\n1. 📊 Extracting computation graphs...")
        graphs = self.graph_extractor.extract_graphs(model, input_tensor)
        self.graph_extractor.print_graph_info(0)

        # 2. AOT Autograd analysis (on a detached copy that requires grad)
        print("\n2. 🔄 Analyzing AOT Autograd...")
        aot_graphs = self.aot_explorer.explore_aot_autograd(model, input_tensor.clone().detach().requires_grad_(True))
        self.aot_explorer.analyze_gradient_flow()

        # 3. Primitive operation analysis
        print("\n3. 🔧 Analyzing primitive operations...")
        self.prim_analyzer.analyze_primitives(model, input_tensor)

        # 4. Code generation analysis
        print("\n4. ⚡ Analyzing generated code...")
        self.inductor_explorer.capture_generated_code(model, input_tensor)
        self.inductor_explorer.analyze_generated_kernels()

        print("\n✅ Analysis complete!")

        return {
            'graphs': graphs,
            'aot_graphs': aot_graphs,
            'primitives': dict(self.prim_analyzer.operation_counts),
            'generated_code': self.inductor_explorer.generated_code,
        }

    def interactive_model_builder(self):
        """Print the builder's chainable commands and return the builder."""
        print("🏗️  INTERACTIVE MODEL BUILDER")
        print("Available commands:")
        print("  .set_input_dim(dim)")
        print("  .add_linear(output_dim, activation='relu')")
        print("  .add_dropout(p=0.1)")
        print("  .add_batchnorm()")
        print("  .build()")
        print("  .reset()")

        return self.model_builder

    @staticmethod
    def _benchmark_ms(fn, warmup: int = 50, rep: int = 100, use_cuda: bool = False) -> float:
        """Return the mean runtime of ``fn()`` in milliseconds.

        With ``use_cuda`` and triton importable, this delegates to
        ``triton.testing.do_bench``. Otherwise it times on the host with
        ``time.perf_counter``, treating ``warmup`` and ``rep`` as time
        budgets in ms. The CUDA device is synchronized when ``use_cuda``.
        """
        if use_cuda:
            try:
                from triton.testing import do_bench
            except ImportError:
                do_bench = None
            if do_bench is not None:
                return do_bench(fn, warmup=warmup, rep=rep)

        import time

        def run(n):
            for _ in range(n):
                fn()
            if use_cuda:
                torch.cuda.synchronize()

        run(1)  # first call may trigger compilation; keep it out of the estimate
        start = time.perf_counter()
        run(5)
        estimate_ms = max((time.perf_counter() - start) * 1e3 / 5, 1e-6)

        run(max(1, int(warmup / estimate_ms)))
        n_rep = max(1, int(rep / estimate_ms))
        start = time.perf_counter()
        run(n_rep)
        return (time.perf_counter() - start) * 1e3 / n_rep

    def quick_comparison(self, model: nn.Module, input_tensor: torch.Tensor):
        """Benchmark eager execution against several ``torch.compile`` modes.

        Works for CPU and CUDA inputs: triton's ``do_bench`` is used only
        for CUDA inputs, with a host-timer fallback otherwise.

        Returns:
            Dict mapping mode name ("eager", "default", ...) to mean ms.
        """
        print("⚡ QUICK PERFORMANCE COMPARISON")
        print("=" * 50)

        modes = ["eager", "default", "reduce-overhead", "max-autotune"]
        results = {}
        use_cuda = input_tensor.is_cuda

        for mode in modes:
            if mode == "eager":
                compiled_model = model
            else:
                torch._dynamo.reset()
                compiled_model = torch.compile(model, mode=mode)

            # Bind the current model as a default to avoid late binding.
            time_ms = self._benchmark_ms(lambda m=compiled_model: m(input_tensor),
                                         warmup=50, rep=100, use_cuda=use_cuda)
            results[mode] = time_ms

            print(f"{mode:15s}: {time_ms:8.3f} ms")

        eager_time = results["eager"]
        print("\nSpeedups:")
        for mode in modes[1:]:
            speedup = eager_time / results[mode]
            print(f"{mode:15s}: {speedup:8.2f}x")

        return results

# Create dashboard
dashboard = CompilationDashboard()

# Example: a three-layer MLP and a matching random batch
test_model = nn.Sequential(
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 10),
)
test_input = torch.randn(64, 1024)

# Run quick comparison
perf_results = dashboard.quick_comparison(test_model, test_input)

# Closing banner with usage hints (one print, same text as four separate ones)
print("\n".join([
    "\n" + "=" * 60,
    "🎯 Use dashboard.full_analysis(model, input) for complete analysis",
    "🏗️  Use dashboard.interactive_model_builder() for building models",
    "=" * 60,
]))