# Debugging torch.compile Errors and Fallbacks

This notebook provides comprehensive guidance on debugging compilation errors, understanding fallback mechanisms, and troubleshooting common issues with `torch.compile`.

In [None]:
import torch
import torch.nn as nn
import torch._dynamo as dynamo
import logging
import warnings
from torch._dynamo.utils import CompileProfiler
from contextlib import contextmanager

## Setting Up Debugging Environment

In [None]:
# Enable detailed logging.
# Some PyTorch releases expose the Dynamo log level as a config attribute,
# others only through torch._logging.set_logs. Assigning a config attribute
# that does not exist raises AttributeError, so probe before assigning to keep
# this cell runnable on both kinds of release.
if hasattr(torch._dynamo.config, "log_level"):
    torch._dynamo.config.log_level = logging.INFO
else:
    import torch._logging

    torch._logging.set_logs(dynamo=logging.INFO)
torch._dynamo.config.verbose = True

# Configure debugging flags
torch._dynamo.config.suppress_errors = False
# This flag is not available on every release; skip it when it is missing
# instead of aborting the whole cell.
if hasattr(torch._dynamo.config, "print_specializations"):
    torch._dynamo.config.print_specializations = True

@contextmanager
def debug_compile():
    """Temporarily switch Dynamo to DEBUG-level, verbose output.

    On entry the Dynamo log level is raised to ``logging.DEBUG`` and
    ``torch._dynamo.config.verbose`` is set to ``True``. On exit -- whether the
    ``with`` body finished normally or raised -- both are put back to the
    values they had on entry.

    The log level is changed through ``torch._dynamo.config.log_level`` when
    that attribute exists. Otherwise the level of the ``"torch._dynamo"``
    standard-library logger is changed instead.
    NOTE(review): the fallback assumes Dynamo emits its messages through the
    ``"torch._dynamo"`` logger hierarchy -- confirm against the torch._logging
    documentation for the installed release.

    Yields:
        None. Exceptions raised inside the ``with`` body propagate unchanged
        after the settings have been restored.
    """
    config = torch._dynamo.config
    dynamo_logger = logging.getLogger("torch._dynamo")
    # Reading or assigning a config attribute that does not exist raises
    # AttributeError, so decide once which mechanism to use.
    uses_config_level = hasattr(config, "log_level")

    saved_verbose = config.verbose
    saved_level = config.log_level if uses_config_level else dynamo_logger.level

    try:
        # The changes happen inside the try so that a failure half-way through
        # still runs the restore below.
        if uses_config_level:
            config.log_level = logging.DEBUG
        else:
            dynamo_logger.setLevel(logging.DEBUG)
        config.verbose = True
        yield
    finally:
        if uses_config_level:
            config.log_level = saved_level
        else:
            dynamo_logger.setLevel(saved_level)
        config.verbose = saved_verbose

def compilation_counter():
    """Print the Dynamo ``"frames"`` counters, then clear all counters.

    The report is a header line followed by one indented ``name: value`` line
    per entry of ``dynamo.utils.counters["frames"]``. Afterwards every counter
    group (not only ``"frames"``) is cleared, so the next report starts from
    zero. Returns ``None``.
    """
    frame_counts = dynamo.utils.counters["frames"]

    report = ["Compilation Statistics:"]
    report.extend(f"  {name}: {count}" for name, count in frame_counts.items())
    print("\n".join(report))

    # Wipe every counter group so later reports only reflect later work.
    dynamo.utils.counters.clear()

## Common Problematic Patterns

In [None]:
class ProblematicModel(nn.Module):
    """Model with patterns that can cause compilation issues.

    Each ``forward_*`` method applies the same ``nn.Linear(10, 5)`` layer and
    differs only in what it does with the output, so the methods can be handed
    to ``torch.compile`` one at a time and compared. ``forward_good`` is the
    tensor-only reference variant. The class defines no ``forward`` method;
    callers pass the bound ``forward_*`` methods directly.
    """
    
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        # Grows by one Python float per forward_with_python_list call and is
        # never cleared by this class.
        self.dynamic_list = []
    
    def forward_with_python_list(self, x):
        """Issue: Dynamic Python list modification.

        Applies the linear layer, appends the sum of the output (converted to
        a Python number with ``.item()``) to ``self.dynamic_list`` and returns
        the layer output unchanged.
        """
        result = self.linear(x)
        # NOTE(review): both the .item() conversion and the mutation of module
        # state are candidates for what the compiler objects to -- confirm
        # which one with torch._dynamo.explain.
        self.dynamic_list.append(result.sum().item())  # Problematic
        return result
    
    def forward_with_control_flow(self, x):
        """Issue: Data-dependent control flow.

        Returns the layer output doubled when its sum is positive, otherwise
        the layer output negated.
        """
        result = self.linear(x)
        # The branch taken depends on tensor values, not on shapes or types.
        if result.sum() > 0:  # Data-dependent condition
            return result * 2
        else:
            return result * -1
    
    def forward_with_external_function(self, x):
        """Issue: External function call.

        Applies the linear layer, computes the sine of the output with NumPy
        and returns it as a tensor on ``x``'s device. The output is detached
        before the NumPy round trip.
        """
        import numpy as np  # Import inside function
        result = self.linear(x)
        # Convert to numpy and back (unsupported pattern)
        # detach() and cpu() are required before .numpy() can be called.
        np_result = result.detach().cpu().numpy()
        return torch.from_numpy(np.sin(np_result)).to(x.device)
    
    def forward_good(self, x):
        """Good pattern: Pure tensor operations.

        Same computation as ``forward_with_external_function`` (sine of the
        layer output) but expressed with ``torch.sin`` only.
        """
        result = self.linear(x)
        return torch.sin(result)  # Use torch.sin instead of numpy

# Shared fixtures for the cells below: one model instance and a random batch
# of 2 samples x 10 features (10 matches the nn.Linear(10, 5) input size).
model = ProblematicModel()
x = torch.randn(2, 10)

## Debugging Compilation Failures

In [None]:
def test_compilation(model_fn, x, description):
    """Compile ``model_fn``, call it once on ``x`` and print what happened.

    Args:
        model_fn: Callable to wrap with ``torch.compile`` (default options).
        x: Single argument passed to the compiled callable.
        description: Human-readable label used in every printed line.

    On success a success line and the compilation statistics are printed. Any
    ``Exception`` is caught and reported together with the statistics and, if
    the message contains a known keyword, a one-line hint. Returns ``None``.
    """
    print(f"\n=== Testing: {description} ===")

    try:
        # Reset dynamo state
        dynamo.reset()

        compiled = torch.compile(model_fn)

        # The call (where compilation is actually triggered) runs with the
        # debug logging context active.
        with debug_compile():
            compiled(x)

        print(f"✅ Success: {description}")
        compilation_counter()

    except Exception as e:
        print(f"❌ Error: {description}")
        print(f"   Exception: {type(e).__name__}: {e}")

        # Print compilation statistics even on failure
        compilation_counter()

        # Keyword -> hint table, checked in order; only the first hit is shown.
        message = str(e).lower()
        hints = (
            ("graph break", "   This is likely a graph break issue"),
            ("unsupported", "   This operation is not supported by the compiler"),
            ("dynamic", "   This might be related to dynamic shapes or control flow"),
        )
        first_hint = next((hint for keyword, hint in hints if keyword in message), None)
        if first_hint is not None:
            print(first_hint)

# Test different patterns: (bound method, label) pairs, run in this order
for _method, _label in (
    (model.forward_with_python_list, "Python list modification"),
    (model.forward_with_control_flow, "Data-dependent control flow"),
    (model.forward_with_external_function, "External function call"),
    (model.forward_good, "Good pattern"),
):
    test_compilation(_method, x, _label)

## Understanding Graph Breaks

In [None]:
def analyze_graph_breaks(model_fn, x, description):
    """Compile ``model_fn`` with graph breaks allowed and report the breaks seen.

    Args:
        model_fn: Callable to wrap with ``torch.compile(..., fullgraph=False)``.
        x: Single argument passed to the compiled callable.
        description: Human-readable label printed in the section header.

    The report is read from ``dynamo.utils.counters["graph_break"]`` after one
    call of the compiled function: one ``reason: N times`` line per entry, or
    a "no graph breaks" line when the group is empty or missing. Any
    ``Exception`` raised while compiling or running is caught and printed.
    All Dynamo counters are cleared before and after the run. Returns ``None``.
    """
    print(f"\n=== Graph Break Analysis: {description} ===")

    dynamo.reset()
    # Discard counts left over from earlier compilations; without this, breaks
    # recorded by previous runs would be reported as belonging to model_fn.
    dynamo.utils.counters.clear()

    try:
        compiled_fn = torch.compile(model_fn, fullgraph=False)  # Allow graph breaks
        compiled_fn(x)

        # Check for graph breaks in compilation
        break_reasons = dynamo.utils.counters.get("graph_break", {})
        if break_reasons:
            print("Graph breaks detected:")
            for reason, count in break_reasons.items():
                print(f"  {reason}: {count} times")
        else:
            print("No graph breaks detected")

    except Exception as e:
        print(f"Error during compilation: {e}")

    finally:
        # Leave the counters empty for whoever inspects them next.
        dynamo.utils.counters.clear()

# Analyze graph breaks for different patterns: (bound method, label) pairs
for _method, _label in (
    (model.forward_with_python_list, "Python list modification"),
    (model.forward_with_control_flow, "Data-dependent control flow"),
    (model.forward_with_external_function, "External function call"),
):
    analyze_graph_breaks(_method, x, _label)

## Fallback Mechanisms and Workarounds

In [None]:
class FallbackDemoModel(nn.Module):
    """Two variants of the same computation: Python branching vs ``torch.where``.

    Both methods apply ``nn.Linear(10, 5)`` and return the output doubled when
    its sum is positive, otherwise unchanged. ``problematic_forward`` decides
    with a Python ``if`` (and prints); ``workaround_forward`` expresses the same
    selection as a tensor operation and has no print.
    """
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
    
    def problematic_forward(self, x):
        """Method that will trigger fallbacks.

        Returns ``linear(x) * 2`` (after printing a message) when the sum of
        the layer output is positive, otherwise ``linear(x)``.
        """
        result = self.linear(x)
        
        # This will cause a graph break
        # (the condition depends on tensor values, and the branch has a print)
        if result.sum() > 0:
            print("Positive sum detected")  # Side effect
            return result * 2
        return result
    
    def workaround_forward(self, x):
        """Workaround version.

        Computes the same values as ``problematic_forward`` without Python
        branching and without the print side effect.
        """
        result = self.linear(x)
        
        # Use torch.where instead of if-else
        # `condition` is a single boolean value held in a tensor; torch.where
        # applies it to every element of the two candidate outputs.
        condition = result.sum() > 0
        return torch.where(condition, result * 2, result)

fallback_model = FallbackDemoModel()

print("Testing fallback behavior:")

# Test with different configurations
# Each entry is (label, wrapper). The wrapper takes a callable and returns the
# callable to run: unchanged, compiled with graph breaks allowed, or compiled
# with fullgraph=True.
configs = [
    ("No compilation", lambda fn: fn),
    ("Compiled (allow breaks)", lambda fn: torch.compile(fn, fullgraph=False)),
    ("Compiled (require fullgraph)", lambda fn: torch.compile(fn, fullgraph=True))
]

for config_name, compile_fn in configs:
    print(f"\n--- {config_name} ---")
    
    try:
        # Reset Dynamo before each configuration.
        dynamo.reset()
        compiled_model = compile_fn(fallback_model.problematic_forward)
        result = compiled_model(x)
        print(f"✅ Success with {config_name}")
        
    except Exception as e:
        print(f"❌ Failed with {config_name}: {e}")
        
        # Try workaround
        # Only attempted when the problematic variant raised; it goes through
        # the same wrapper as the failed attempt.
        try:
            dynamo.reset()
            compiled_workaround = compile_fn(fallback_model.workaround_forward)
            result = compiled_workaround(x)
            print(f"✅ Workaround successful")
        except Exception as e2:
            print(f"❌ Workaround also failed: {e2}")

## Performance Impact of Fallbacks

In [None]:
from triton.testing import do_bench

def benchmark_fallback_impact(model_fn, workaround_fn, x):
    """Benchmark the performance impact of fallbacks.

    Times three variants and prints the timings and their ratios:

    1. ``model_fn`` called directly (eager),
    2. ``model_fn`` compiled with ``fullgraph=False`` (graph breaks allowed),
    3. ``workaround_fn`` compiled with ``fullgraph=True``.

    Args:
        model_fn: Callable expected to compile with graph breaks.
        workaround_fn: Replacement callable that must compile as one graph;
            ``torch.compile(..., fullgraph=True)`` raises if it cannot.
        x: Single argument passed to every variant.

    Timing uses ``do_bench`` when CUDA is available. Otherwise a wall-clock
    loop is used, because ``do_bench`` needs a GPU device and would fail on a
    CPU-only machine. Returns ``None``.
    NOTE(review): the printed "ms" unit assumes ``do_bench`` returns a single
    number in milliseconds -- confirm for the installed triton release.
    """
    import time  # local import: only the CPU fallback timer needs it

    def _time_ms(fn, warmup=3, iters=20):
        """Return the runtime of ``fn()`` in milliseconds."""
        if torch.cuda.is_available():
            return do_bench(fn)
        # CPU fallback. The warm-up calls absorb one-off costs such as the
        # compilation triggered by the first call of a compiled function.
        for _ in range(warmup):
            fn()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        return (time.perf_counter() - start) * 1e3 / iters

    # Eager mode
    eager_time = _time_ms(lambda: model_fn(x))

    # Compiled with fallbacks
    dynamo.reset()
    compiled_with_breaks = torch.compile(model_fn, fullgraph=False)
    compiled_time = _time_ms(lambda: compiled_with_breaks(x))

    # Compiled workaround (should be faster)
    dynamo.reset()
    compiled_workaround = torch.compile(workaround_fn, fullgraph=True)
    workaround_time = _time_ms(lambda: compiled_workaround(x))

    print(f"Eager mode: {eager_time:.4f} ms")
    print(f"Compiled with breaks: {compiled_time:.4f} ms")
    print(f"Compiled workaround: {workaround_time:.4f} ms")

    print(f"\nSpeedup analysis:")
    print(f"Compiled vs Eager: {eager_time/compiled_time:.2f}x")
    print(f"Workaround vs Eager: {eager_time/workaround_time:.2f}x")
    print(f"Workaround vs Compiled: {compiled_time/workaround_time:.2f}x")

# Run performance comparison
# The two FallbackDemoModel variants are passed as bound methods; x is the
# shared random input batch created earlier in the notebook.
print("Performance impact of graph breaks:")
benchmark_fallback_impact(
    fallback_model.problematic_forward,
    fallback_model.workaround_forward,
    x
)

## Debugging Tools and Utilities

In [None]:
class CompilationDebugger:
    """Keyword-driven helper that prints hints for compilation error messages.

    Both helpers lower-case the message, look for the first known phrase it
    contains (in table order) and print the text stored for that phrase.
    Neither helper returns a value.
    """

    # (phrase, explanation) pairs, checked in order; the first hit wins.
    _EXPLANATIONS = (
        ("graph break", "The compiler encountered an operation it cannot handle and fell back to eager mode."),
        ("dynamic shape", "The tensor shape varies at runtime, making compilation difficult."),
        ("unsupported operator", "This PyTorch operation is not yet supported by the compiler."),
        ("control flow", "Data-dependent branching cannot be compiled efficiently."),
        ("side effect", "Operations with side effects (like print) cause graph breaks."),
        ("python container", "Modifying Python lists/dicts during forward pass is not supported."),
    )

    # (phrase, workaround list) pairs, checked in order; the first hit wins.
    _WORKAROUNDS = (
        (
            "graph break",
            (
                "Use torch operations instead of Python operations",
                "Move side effects outside the compiled function",
                "Use torch.where instead of if-else statements",
            ),
        ),
        (
            "dynamic shape",
            (
                "Use torch._dynamo.mark_dynamic to mark dynamic dimensions",
                "Pad tensors to fixed sizes",
                "Use bucketing for similar-sized inputs",
            ),
        ),
        (
            "unsupported operator",
            (
                "Use alternative PyTorch operations",
                "Implement custom operations using torch.library",
                "Exclude the operation from compilation",
            ),
        ),
    )

    @staticmethod
    def explain_error(error_msg):
        """Print a one-line explanation for the first known phrase in ``error_msg``.

        Matching is case-insensitive. When no phrase matches, a generic
        "not recognized" line is printed instead.
        """
        lowered = error_msg.lower()
        for phrase, explanation in CompilationDebugger._EXPLANATIONS:
            if phrase not in lowered:
                continue
            print(f"Likely issue: {explanation}")
            return

        print("Error pattern not recognized. Check torch._dynamo documentation.")

    @staticmethod
    def suggest_workarounds(error_msg):
        """Print a numbered workaround list for the first known phrase in ``error_msg``.

        Matching is case-insensitive. When no phrase matches, a generic
        fallback line is printed instead.
        """
        lowered = error_msg.lower()
        tips = next(
            (steps for phrase, steps in CompilationDebugger._WORKAROUNDS if phrase in lowered),
            None,
        )
        if tips is None:
            print("No specific workarounds available. Try simplifying the model.")
            return

        print("Suggested workarounds:")
        for number, tip in enumerate(tips, start=1):
            print(f"  {number}. {tip}")

# Example usage: feed the same sample message to both helpers, explanation first
debugger = CompilationDebugger()
print("Example error analysis:")
for _report in (debugger.explain_error, debugger.suggest_workarounds):
    _report("Encountered graph break due to unsupported control flow")

## Best Practices Summary

### ✅ Do:
1. Use pure tensor operations when possible
2. Replace data-dependent control flow with `torch.where`
3. Move side effects (print, logging) outside compiled functions
4. Use torch operations instead of numpy/python equivalents
5. Test with `fullgraph=True` to catch graph breaks early

### ❌ Avoid:
1. Modifying Python containers (lists, dicts) in forward pass
2. Data-dependent control flow (if statements on tensor values)
3. Side effects like print statements or file I/O
4. Converting to/from numpy arrays
5. Dynamic imports inside compiled functions

### 🔧 Debugging Tools:
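1. Enable verbose logging with `torch._dynamo.config.verbose = True` (on newer PyTorch releases, raise the log level with `torch._logging.set_logs(dynamo=logging.DEBUG)` or the `TORCH_LOGS` environment variable)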
2. Use `fullgraph=True` to require full compilation
3. Check `torch._dynamo.utils.counters` for statistics
4. Use `torch._dynamo.explain()` for detailed analysis
5. Profile with PyTorch profiler for performance insights