# PyTorch + Triton Fundamentals

This notebook provides a comprehensive introduction to PyTorch's compilation system and how it leverages Triton for GPU optimization. We'll focus on understanding the fundamentals before diving into advanced kernel optimization.

## 🎯 Learning Objectives

By the end of this notebook, you'll understand:
- How PyTorch's compilation system works internally
- Why compilation has overhead and how to manage it
- How to use environment variables for debugging and optimization
- Best practices for production deployment
- How to troubleshoot common compilation issues

## 📚 What You'll Learn

### Core Concepts
1. **PyTorch Compilation Pipeline**: From Python code to optimized GPU kernels
2. **Environment Variables**: Powerful debugging and monitoring tools
3. **Performance Patterns**: Understanding compilation overhead vs execution benefits
4. **Production Deployment**: Best practices for real-world applications

### Practical Skills
- Setting up optimal development environments
- Debugging compilation issues effectively
- Measuring and analyzing performance impacts
- Deploying compiled models in production

Let's start with the fundamentals!

In [None]:
# Environment Setup and Foundation
import os
import torch
import time
import gc
from pathlib import Path
from typing import Dict, List, Tuple

def setup_pytorch_triton_environment():
    """
    Configure PyTorch and Triton for educational exploration
    
    This function demonstrates how to set up environment variables
    that provide deep insights into PyTorch's compilation process.
    """
    
    print("🚀 Setting up PyTorch + Triton Learning Environment")
    print("=" * 60)
    
    # Core environment variables for understanding compilation
    educational_settings = {
        # Show generated kernel code - see what Triton creates
        "TORCH_LOGS": "output_code",
        
        # Display autotuning process - see optimization in action
        "TRITON_PRINT_AUTOTUNING": "1", 
        
        # Show cache statistics - understand reuse patterns
        "TRITON_PRINT_CACHE_STATS": "1",
        
        # Additional debugging (optional)
        # "TORCH_LOGS": "output_code,dynamo,inductor",  # More detailed logs
        # "TRITON_PRINT_CACHE_DIR": "1",  # Show cache directory
    }
    
    for key, value in educational_settings.items():
        os.environ[key] = value
        print(f"✅ {key} = '{value}'")
    
    print(f"\n📖 What these variables reveal:")
    print(f"  • TORCH_LOGS: Shows actual generated Triton kernel source code")
    print(f"  • TRITON_PRINT_AUTOTUNING: Displays different configurations being tested")
    print(f"  • TRITON_PRINT_CACHE_STATS: Shows kernel cache hits vs misses")
    
    return educational_settings

def detect_and_configure_device():
    """
    Detect GPU capabilities and configure for optimal learning
    """
    
    print(f"\n🔍 Device Detection and Configuration")
    print("=" * 40)
    
    print(f"PyTorch version: {torch.__version__}")
    
    if torch.cuda.is_available():
        device = "cuda"
        print(f"✅ CUDA GPU available: {torch.cuda.get_device_name(0)}")
        print(f"   Device count: {torch.cuda.device_count()}")
        print(f"   CUDA version: {torch.version.cuda}")
        print(f"   Compute capability: {torch.cuda.get_device_capability(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        
        # Check Triton availability
        try:
            import triton
            print(f"✅ Triton available: {triton.__version__}")
        except ImportError:
            print(f"⚠️  Triton not available - install with: pip install triton")
            
    else:
        device = "cpu"
        print("⚠️  CUDA not available - using CPU")
        print("   Note: Many optimizations are GPU-specific")
    
    print(f"\n🎯 Selected device: {device.upper()}")
    return device

# Initialize the learning environment
settings = setup_pytorch_triton_environment()
device = detect_and_configure_device()

print(f"\n✅ Environment ready for PyTorch + Triton exploration!")

## Understanding PyTorch Compilation Pipeline

### 🧠 How PyTorch Compilation Works

When you wrap a function or module with `torch.compile()`, PyTorch runs it through several stages. Let's explore each one:

#### 1. **Graph Capture** 📊
- PyTorch traces your Python code execution
- Creates a computation graph (nodes = operations, edges = data flow)
- Captures control flow and data dependencies

#### 2. **Graph Optimization** ⚡
- Fusion opportunities identified (combine multiple ops)
- Dead code elimination (remove unused computations)
- Constant folding (precompute constant expressions)
- Memory layout optimization

#### 3. **Backend Selection** 🎯
- TorchInductor is the default backend; it targets Triton for GPU operations
- Different backends and code generators for different hardware (CPU, GPU, TPU)
- Backend-specific optimization passes

#### 4. **Kernel Generation** 🔧
- TorchInductor generates Triton kernel source code for the GPU
- Automatic memory management and parallelization
- Hardware-specific optimizations applied

#### 5. **Compilation** ⚙️
- Triton kernels compiled to GPU machine code
- CUDA compilation pipeline invoked
- Binary kernels ready for execution

#### 6. **Caching** 💾
- Compiled kernels cached for reuse
- Cache keys based on input shapes and types
- Avoids recompilation for identical patterns

### 🎓 Key Insight: Two-Phase Performance

This pipeline explains the fundamental performance pattern:
- **First Run**: Slow (includes all compilation overhead)
- **Subsequent Runs**: Fast (uses cached compiled kernels)
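
To make the two-phase pattern concrete, here is a toy model of a compile cache keyed on input shape and dtype. It is a plain-Python sketch, not how TorchInductor is implemented: `ToyCompiler`, `fake_compile`, and the `(shape, dtype)` key are hypothetical names chosen for illustration, and the `sleep` only stands in for compilation cost.

```python
import time
from typing import Callable, Dict, List, Tuple

# Simplified, assumed cache key: (shape, dtype). Real cache keys carry more information.
CacheKey = Tuple[Tuple[int, ...], str]


class ToyCompiler:
    """Toy stand-in for a compile cache; NOT TorchInductor's real logic."""

    def __init__(self, compile_cost_s: float = 0.05) -> None:
        self.compile_cost_s = compile_cost_s
        self.cache: Dict[CacheKey, Callable[[List[float]], List[float]]] = {}
        self.hits = 0
        self.misses = 0

    def fake_compile(self, key: CacheKey) -> Callable[[List[float]], List[float]]:
        # Pretend to compile: pay a fixed cost, then return a "kernel" (ReLU + 1).
        time.sleep(self.compile_cost_s)
        return lambda values: [max(v, 0.0) + 1.0 for v in values]

    def run(self, values: List[float], shape: Tuple[int, ...], dtype: str = "float32") -> List[float]:
        key = (shape, dtype)
        kernel = self.cache.get(key)
        if kernel is None:
            # Cache miss: compile once and remember the result for this key.
            self.misses += 1
            kernel = self.fake_compile(key)
            self.cache[key] = kernel
        else:
            # Cache hit: reuse the already compiled kernel.
            self.hits += 1
        return kernel(values)


compiler = ToyCompiler()
data = [-1.0, 0.5, 2.0]

start = time.perf_counter()
first = compiler.run(data, shape=(3,))   # first run: pays the compile cost
first_s = time.perf_counter() - start

start = time.perf_counter()
second = compiler.run(data, shape=(3,))  # same key: served from the cache
second_s = time.perf_counter() - start

third = compiler.run(data + [4.0], shape=(4,))  # new shape: compiles again

print(f"first run:  {first_s * 1000:.2f} ms")
print(f"second run: {second_s * 1000:.2f} ms")
print(f"hits={compiler.hits}, misses={compiler.misses}")
```

The third call shows why changing input shapes can bring the slow phase back: a new key means a new compilation.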

In [None]:
# Demonstrating Compilation Overhead vs Execution Speed
import torch.nn as nn
import torch.nn.functional as F
import os

def demonstrate_compilation_phases():
    """
    Practical demonstration of compilation overhead vs execution speed
    
    This function shows the two-phase performance pattern that's
    fundamental to understanding PyTorch compilation.
    """
    
    print("🧪 DEMONSTRATION: Compilation Phases")
    print("=" * 50)
    
    # Enable verbose compilation output to see Triton kernel generation
    old_verbose = os.environ.get('TORCH_COMPILE_DEBUG', '0')
    os.environ['TORCH_COMPILE_DEBUG'] = '1'
    
    # Also enable TorchInductor debug output
    import torch._inductor.config as config
    old_debug = config.debug
    config.debug = True
    
    print("🔧 Enabled verbose compilation output - you should now see Triton kernel generation!")
    
    try:
        # Create a simple but representative model
        class SimpleModel(nn.Module):
            def __init__(self, hidden_size=512):
                super().__init__()
                self.layer_norm = nn.LayerNorm(hidden_size)
                
            def forward(self, x):
                # Simple pattern: normalize then activate
                normalized = self.layer_norm(x)
                return F.gelu(normalized)
        
        # Initialize model and test data
        model = SimpleModel().to(device)
        test_input = torch.randn(32, 128, 512, device=device)
        
        print(f"\n📊 Test configuration:")
        print(f"   Model: LayerNorm + GELU")
        print(f"   Input shape: {test_input.shape}")
        print(f"   Device: {device}")
        
        # Phase 1: Measure baseline (uncompiled) performance
        print(f"\n🔍 Phase 1: Baseline (Uncompiled) Performance")
        
        # Warmup
        for _ in range(5):
            with torch.no_grad():
                _ = model(test_input)
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        # Measure baseline
        baseline_times = []
        for _ in range(10):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            start = time.perf_counter()
            with torch.no_grad():
                output = model(test_input)
            
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            baseline_times.append(time.perf_counter() - start)
        
        baseline_avg = sum(baseline_times) / len(baseline_times)
        print(f"   Average time: {baseline_avg*1000:.3f} ms")
        
        # Phase 2: Compile the model
        print(f"\n🔧 Phase 2: Compiling Model (Watch for Triton Output Below)")
        print(f"   Note: With debug enabled, you should see detailed Triton kernel generation")
        print("=" * 60)
        
        compiled_model = torch.compile(model, mode="default")
        
        print("=" * 60)
        print("🔚 End of compilation output")
        
        # Phase 3: First run (compilation + execution)
        print(f"\n⏱️  Phase 3: First Run (Compilation + Execution)")
        print(f"   Note: Additional Triton kernels may be generated during first execution")
        print("-" * 40)
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        start = time.perf_counter()
        with torch.no_grad():
            compiled_output = compiled_model(test_input)
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        first_run_time = time.perf_counter() - start
        print("-" * 40)
        print(f"   First run time: {first_run_time*1000:.3f} ms")
        print(f"   Overhead factor: {first_run_time/baseline_avg:.1f}x slower than baseline")
        
        # Phase 4: Subsequent runs (cached execution)
        print(f"\n⚡ Phase 4: Subsequent Runs (Cached Kernels)")
        
        cached_times = []
        for i in range(10):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            start = time.perf_counter()
            with torch.no_grad():
                _ = compiled_model(test_input)
            
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            cached_times.append(time.perf_counter() - start)
        
        cached_avg = sum(cached_times) / len(cached_times)
        print(f"   Average cached time: {cached_avg*1000:.3f} ms")
        print(f"   Speedup vs baseline: {baseline_avg/cached_avg:.2f}x")
        print(f"   Speedup vs first run: {first_run_time/cached_avg:.1f}x")
        
        # Verify correctness
        max_diff = (output - compiled_output).abs().max().item()
        print(f"\n✅ Correctness check: Max difference = {max_diff:.2e}")
        
        return {
            'baseline_avg': baseline_avg,
            'first_run_time': first_run_time, 
            'cached_avg': cached_avg,
            'compilation_overhead': first_run_time / baseline_avg,
            'speedup': baseline_avg / cached_avg
        }
    
    finally:
        # Restore original settings
        os.environ['TORCH_COMPILE_DEBUG'] = old_verbose
        config.debug = old_debug
        print(f"\n🔧 Restored original debug settings")

# Run the demonstration
results = demonstrate_compilation_phases()

print(f"\n🎓 Key Takeaways:")
print(f"   • Compilation adds significant overhead to first run")
print(f"   • Subsequent runs benefit from cached optimized kernels")
print(f"   • The break-even point depends on how many times you'll run the model")
print(f"   • In production, you want to 'warm up' during initialization")

In [None]:
# DIRECT APPROACH: Capture and display Triton compilation output
import torch
import sys
import io
import contextlib
import logging
import os

# Clear any previous compilations
torch._dynamo.reset()

print("🎯 DIRECT TRITON KERNEL GENERATION CAPTURE")
print("=" * 50)

# Set up comprehensive logging
logging.basicConfig(level=logging.DEBUG, force=True)

# Enable ALL debug output
os.environ.update({
    'TORCH_COMPILE_DEBUG': '1',
    'TORCH_LOGS': '+dynamo,+inductor,+aot',
    'TORCHINDUCTOR_VERBOSE': '1',
    'TRITON_PRINT_AUTOTUNING': '1',
    'TRITON_DEBUG': '1'
})

# Configure inductor for maximum verbosity
import torch._inductor.config as config
config.debug = True
config.trace.enabled = True
config.verbose_progress = True

print("🔧 Environment configured for maximum compilation visibility")

# Create a simple model that will definitely generate Triton kernels
def triton_demo_model(x):
    # This pattern should trigger multiple Triton kernels
    y = torch.relu(x)           # Pointwise operation
    z = y * 2.0 + 1.0          # Fused arithmetic
    return torch.sum(z, dim=-1) # Reduction operation

# Test data
x = torch.randn(512, 512, device=device, requires_grad=False)

print(f"\n📊 Input: {x.shape} on {device}")
print("\n🚀 COMPILING MODEL - Watch for Triton output below:")
print("=" * 60)

# Capture stdout and stderr during compilation
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()

try:
    with contextlib.redirect_stdout(stdout_capture), \
         contextlib.redirect_stderr(stderr_capture):
        
        # Compile the model
        compiled_model = torch.compile(triton_demo_model, mode="default")
        
        # First execution (triggers kernel generation)
        result = compiled_model(x)
        
    # Get captured output
    stdout_output = stdout_capture.getvalue()
    stderr_output = stderr_capture.getvalue()
    
    print("=" * 60)
    print("🔚 END OF COMPILATION")
    
    # Display captured output
    if stdout_output:
        print(f"\n📝 CAPTURED STDOUT ({len(stdout_output)} chars):")
        print("-" * 40)
        print(stdout_output[:2000])  # Show first 2000 chars
        if len(stdout_output) > 2000:
            print(f"... ({len(stdout_output) - 2000} more characters)")
    
    if stderr_output:
        print(f"\n📝 CAPTURED STDERR ({len(stderr_output)} chars):")
        print("-" * 40)
        print(stderr_output[:2000])  # Show first 2000 chars
        if len(stderr_output) > 2000:
            print(f"... ({len(stderr_output) - 2000} more characters)")
    
    print(f"\n✅ Compilation successful!")
    print(f"   Result shape: {result.shape}")
    print(f"   Result: {result[:5]}")
    
except Exception as e:
    print(f"❌ Error during compilation: {e}")
    # Still show captured output even if there was an error
    stdout_output = stdout_capture.getvalue()
    stderr_output = stderr_capture.getvalue()
    
    if stdout_output:
        print(f"\n📝 PARTIAL STDOUT:")
        print(stdout_output[:1000])
    if stderr_output:
        print(f"\n📝 PARTIAL STDERR:")
        print(stderr_output[:1000])

In [None]:
# EXAMINE GENERATED TRITON KERNELS DIRECTLY
import os
import glob
from pathlib import Path

print("🔍 EXAMINING GENERATED TRITON KERNELS")
print("=" * 45)

# Check the debug trace directory (by default, ./torch_compile_debug under the working directory)
debug_base = os.path.join(os.getcwd(), "torch_compile_debug")
if os.path.exists(debug_base):
    # Find the latest run directory
    run_dirs = glob.glob(f"{debug_base}/run_*")
    if run_dirs:
        latest_run = max(run_dirs, key=os.path.getctime)
        inductor_dir = os.path.join(latest_run, "torchinductor")
        
        print(f"📂 Latest debug run: {os.path.basename(latest_run)}")
        print(f"📂 Inductor directory: {inductor_dir}")
        
        if os.path.exists(inductor_dir):
            # Find all generated files
            all_files = []
            for root, dirs, files in os.walk(inductor_dir):
                for file in files:
                    full_path = os.path.join(root, file)
                    all_files.append(full_path)
            
            print(f"\n📄 Found {len(all_files)} generated files:")
            
            # Categorize files
            py_files = [f for f in all_files if f.endswith('.py')]
            cpp_files = [f for f in all_files if f.endswith(('.cpp', '.h'))]
            other_files = [f for f in all_files if not f.endswith(('.py', '.cpp', '.h', '.lock'))]
            
            print(f"   🐍 Python files: {len(py_files)}")
            print(f"   🔧 C++ files: {len(cpp_files)}")
            print(f"   📋 Other files: {len(other_files)}")
            
            # Show Python files (likely Triton kernels)
            if py_files:
                print(f"\n🐍 PYTHON/TRITON KERNEL FILES:")
                for f in py_files:
                    rel_path = os.path.relpath(f, inductor_dir)
                    size = os.path.getsize(f)
                    print(f"   📄 {rel_path} ({size} bytes)")
                
                # Show content of the first substantial Python file
                substantial_py = [f for f in py_files if os.path.getsize(f) > 100]
                if substantial_py:
                    print(f"\n📝 KERNEL SOURCE CODE ({os.path.basename(substantial_py[0])}):")
                    print("-" * 50)
                    try:
                        with open(substantial_py[0], 'r') as file:
                            content = file.read()
                            lines = content.split('\n')
                            
                            # Show the content with line numbers
                            for i, line in enumerate(lines[:50], 1):  # First 50 lines
                                print(f"{i:3d}: {line}")
                            
                            if len(lines) > 50:
                                print(f"... ({len(lines) - 50} more lines)")
                            
                            # Look for Triton-specific keywords
                            triton_keywords = ['@triton.jit', 'tl.program_id', 'tl.load', 'tl.store', 'BLOCK_SIZE']
                            found_keywords = [kw for kw in triton_keywords if kw in content]
                            
                            if found_keywords:
                                print(f"\n🎯 TRITON KEYWORDS FOUND: {', '.join(found_keywords)}")
                            else:
                                print(f"\nℹ️  This appears to be generated wrapper code, not raw Triton kernel")
                                
                    except Exception as e:
                        print(f"❌ Could not read file: {e}")
        else:
            print(f"❌ Inductor directory not found: {inductor_dir}")
    else:
        print("❌ No debug run directories found")
else:
    print(f"❌ Debug base directory not found: {debug_base}")

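# Also check the kernel cache (default: <tmp>/torchinductor_<username>,
# unless TORCHINDUCTOR_CACHE_DIR overrides it)
import getpass
import tempfile

print(f"\n📁 CHECKING KERNEL CACHE")
cache_dir = os.environ.get(
    "TORCHINDUCTOR_CACHE_DIR",
    os.path.join(tempfile.gettempdir(), f"torchinductor_{getpass.getuser()}"),
)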
if os.path.exists(cache_dir):
    cache_files = []
    for root, dirs, files in os.walk(cache_dir):
        for file in files:
            if file.endswith('.py'):
                cache_files.append(os.path.join(root, file))
    
    print(f"🔧 Found {len(cache_files)} cached Python files")
    
    if cache_files:
        # Show the most recent cache file
        latest_cache = max(cache_files, key=os.path.getctime)
        print(f"\n📝 LATEST CACHED KERNEL ({os.path.basename(latest_cache)}):")
        print("-" * 50)
        
        try:
            with open(latest_cache, 'r') as file:
                content = file.read()
                lines = content.split('\n')
                
                for i, line in enumerate(lines[:30], 1):  # First 30 lines
                    print(f"{i:3d}: {line}")
                
                if len(lines) > 30:
                    print(f"... ({len(lines) - 30} more lines)")
                
                # Check for Triton signatures
                if '@triton.jit' in content:
                    print("\n✅ This is a genuine Triton kernel!")
                elif 'triton' in content.lower():
                    print("\n🔧 This file references Triton")
                else:
                    print("\nℹ️  This appears to be wrapper/helper code")
                    
        except Exception as e:
            print(f"❌ Could not read cache file: {e}")
else:
    print(f"❌ Cache directory not found: {cache_dir}")

In [None]:
# FINAL APPROACH: Direct FX compilation to show Triton kernel generation
import torch
import torch.fx
from torch._inductor import compile_fx
import sys
from io import StringIO

print("🚀 FINAL APPROACH: Direct FX Compilation")
print("=" * 45)

# Clear everything
torch._dynamo.reset()

# Create a simple function that will generate Triton kernels
def kernel_demo(x, y):
    # Multiple operations that should each generate kernels
    z1 = torch.relu(x)              # Pointwise
    z2 = z1 + y                     # Pointwise fusion
    z3 = z2 * 2.0                   # More fusion
    z4 = torch.sum(z3, dim=0)       # Reduction
    return z4

# Create test inputs
x = torch.randn(256, 256, device=device)
y = torch.randn(256, 256, device=device)

print(f"📊 Inputs: x={x.shape}, y={y.shape} on {device}")

# Enable verbose mode
import torch._inductor.config as config
config.debug = True
config.verbose_progress = True

# Capture the FX graph
print("\n🔍 Step 1: Capturing FX Graph...")
traced = torch.fx.symbolic_trace(kernel_demo)
print(f"✅ Graph captured with {len(list(traced.graph.nodes))} nodes")

# Show the graph
print("\n📊 FX Graph Structure:")
print(traced.graph)

print("\n🔧 Step 2: Compiling with Inductor (Watch for Triton output)...")
print("=" * 50)

# Redirect stdout to capture compilation output
old_stdout = sys.stdout
output_capture = StringIO()

try:
    sys.stdout = output_capture
    
    # Compile using inductor directly
    compiled_fn = compile_fx(traced, [x, y])
    
    # Restore stdout
    sys.stdout = old_stdout
    
    # Get the captured output
    compilation_output = output_capture.getvalue()
    
    print("=" * 50)
    print("🔚 Compilation Complete")
    
    # Show compilation output
    if compilation_output:
        print(f"\n📝 COMPILATION OUTPUT ({len(compilation_output)} characters):")
        print("-" * 40)
        lines = compilation_output.split('\n')
        for i, line in enumerate(lines[:100]):  # First 100 lines
            if line.strip():  # Skip empty lines
                print(f"{i+1:3d}: {line}")
        
        if len(lines) > 100:
            print(f"... ({len(lines) - 100} more lines)")
        
        # Look for Triton-specific content (case-insensitive, so 'BLOCK_SIZE' can match too)
        triton_indicators = ['triton', 'kernel', '@jit', 'tl.', 'BLOCK_SIZE']
        output_lower = compilation_output.lower()
        found_indicators = [
            indicator for indicator in triton_indicators
            if indicator.lower() in output_lower
        ]
        
        if found_indicators:
            print(f"\n🎯 TRITON INDICATORS FOUND: {', '.join(found_indicators)}")
        else:
            print("\nℹ️  No obvious Triton indicators in compilation output")
    else:
        print("\n⚠️  No compilation output captured")
    
    # Test the compiled function
    print(f"\n⚡ Step 3: Testing Compiled Function...")
    result = compiled_fn(x, y)
    print(f"✅ Result shape: {result.shape}")
    print(f"   Sample values: {result[:5]}")
    
except Exception as e:
    sys.stdout = old_stdout
    print(f"❌ Compilation failed: {e}")
    
    # Show partial output
    partial_output = output_capture.getvalue()
    if partial_output:
        print(f"\n📝 PARTIAL OUTPUT:")
        print(partial_output[:1000])

print(f"\n🎓 Summary:")
print(f"   • FX graph successfully traced and compiled")
print(f"   • Check the output above for Triton kernel generation details")
print(f"   • Generated kernels are cached in /tmp/torchinductor_<username> by default")

Now let's set environment variables to see Triton compilation output.

In [None]:

import os
import logging

# Clear any cached compilations first
torch._dynamo.reset()

# Set environment variables to show detailed compilation info
os.environ['TORCH_COMPILE_DEBUG'] = '1'
os.environ['TORCH_LOGS'] = '+dynamo,+inductor'
os.environ['TORCHINDUCTOR_VERBOSE'] = '1'

# Enable all relevant logging
logging.basicConfig(level=logging.DEBUG)

print("🔧 Environment variables set for maximum verbosity:")
print(f"   TORCH_COMPILE_DEBUG = {os.environ.get('TORCH_COMPILE_DEBUG')}")
print(f"   TORCH_LOGS = {os.environ.get('TORCH_LOGS')}")
print(f"   TORCHINDUCTOR_VERBOSE = {os.environ.get('TORCHINDUCTOR_VERBOSE')}")
print("\nNow run the compilation demonstration above to see Triton kernel generation!")

In [None]:
# Minimal example to trigger Triton kernel generation with maximum visibility
import torch
torch._dynamo.reset()  # Clear cache

# Enable the most specific debugging available
import torch._inductor.config as config
config.debug = True
config.trace.enabled = True

# Additional environment variables for Triton visibility
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
os.environ['TRITON_DEBUG'] = '1'

print("🎯 FOCUSED TRITON DEMONSTRATION")
print("=" * 40)

# Simple operation that will definitely trigger Triton compilation
def simple_operation(x):
    return torch.relu(x) + 1.0

# Create input
x = torch.randn(1024, 1024, device=device)

print("📝 About to compile a simple ReLU + addition...")
print("   Look for compilation messages in the output below:")
print("-" * 40)

# Compile with maximum verbosity
compiled_fn = torch.compile(simple_operation, mode="default")

# Trigger compilation
print("🚀 First execution (triggers kernel generation):")
result = compiled_fn(x)

print("-" * 40)
print("✅ Compilation completed!")
print(f"   Result shape: {result.shape}")
print(f"   Result mean: {result.mean():.4f}")

# Show some kernel information if available
import torch._inductor.codecache as codecache
cache_dir = codecache.cache_dir()
print(f"📁 Kernel cache directory: {cache_dir}")

# Try to list generated files
import glob
triton_files = glob.glob(f"{cache_dir}/*triton*")
if triton_files:
    print(f"🔧 Found {len(triton_files)} Triton-related cache files")
    for f in triton_files[:3]:  # Show first 3
        print(f"   - {os.path.basename(f)}")
else:
    print("ℹ️  No Triton cache files found yet")

In [None]:
# Explore the generated Triton kernels
import os
import glob

print("🔍 EXPLORING GENERATED TRITON KERNELS")
print("=" * 45)

# Find the debug trace directory
debug_dirs = glob.glob("./torch_compile_debug/*/torchinductor/")
if debug_dirs:
    latest_debug_dir = max(debug_dirs, key=os.path.getctime)
    print(f"📂 Latest debug directory: {latest_debug_dir}")
    
    # Look for generated kernel files
    kernel_files = []
    for ext in ['*.py', '*.cpp', '*.h']:
        kernel_files.extend(glob.glob(os.path.join(latest_debug_dir, "**", ext), recursive=True))
    
    print(f"\n🔧 Found {len(kernel_files)} generated files:")
    for f in kernel_files:
        rel_path = os.path.relpath(f, latest_debug_dir)
        print(f"   📄 {rel_path}")
    
    # Try to show some Triton kernel source
    triton_files = [f for f in kernel_files if 'triton' in f.lower() or f.endswith('.py')]
    if triton_files:
        print(f"\n📝 Triton kernel source (first 30 lines of {os.path.basename(triton_files[0])}):")
        print("-" * 50)
        try:
            with open(triton_files[0], 'r') as file:
                lines = file.readlines()
                for i, line in enumerate(lines[:30]):
                    print(f"{i+1:2d}: {line.rstrip()}")
                if len(lines) > 30:
                    print(f"... ({len(lines) - 30} more lines)")
        except Exception as e:
            print(f"❌ Could not read file: {e}")
    else:
        print("ℹ️  No Triton kernel source files found yet")
else:
    print("❌ No debug directories found")

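# Also check the kernel cache (default: <tmp>/torchinductor_<username>,
# unless TORCHINDUCTOR_CACHE_DIR overrides it)
import getpass
import tempfile

cache_dir = os.environ.get(
    "TORCHINDUCTOR_CACHE_DIR",
    os.path.join(tempfile.gettempdir(), f"torchinductor_{getpass.getuser()}"),
)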
print(f"\n📁 Checking kernel cache: {cache_dir}")
if os.path.exists(cache_dir):
    cache_files = glob.glob(f"{cache_dir}/**/*", recursive=True)
    py_files = [f for f in cache_files if f.endswith('.py') and os.path.isfile(f)]
    
    print(f"🐍 Found {len(py_files)} Python files in cache:")
    for f in py_files[:5]:  # Show first 5
        rel_path = os.path.relpath(f, cache_dir)
        print(f"   📄 {rel_path}")
    
    # Show content of a kernel file if available
    if py_files:
        print(f"\n📝 Sample kernel content (first 20 lines):")
        print("-" * 40)
        try:
            with open(py_files[0], 'r') as file:
                lines = file.readlines()
                for i, line in enumerate(lines[:20]):
                    print(f"{i+1:2d}: {line.rstrip()}")
        except Exception as e:
            print(f"❌ Could not read file: {e}")
else:
    print("❌ Cache directory not found")

## Environment Variables: Your Debugging Toolkit

Environment variables are your window into PyTorch's compilation process. Let's explore the most important ones and what they reveal.

### 🔍 Essential Environment Variables

| Variable | Purpose | What You'll See | When to Use |
|----------|---------|-----------------|-------------|
| `TORCH_LOGS=output_code` | Shows generated kernel code | Actual Triton source code | Understanding optimizations |
| `TRITON_PRINT_AUTOTUNING=1` | Displays autotuning process | Different block sizes tested | Performance debugging |
| `TRITON_PRINT_CACHE_STATS=1` | Shows cache statistics | Cache hits vs misses | Cache optimization |
| `TORCH_LOGS=dynamo` | Shows graph capture | Python → graph conversion | Debugging capture issues |
| `TORCH_LOGS=inductor` | Shows backend compilation | Optimization passes | Backend debugging |

### 🎯 Debug Levels

You can combine multiple log types:
```python
# Comprehensive debugging (verbose!)
os.environ["TORCH_LOGS"] = "output_code,dynamo,inductor"

# Focus on specific areas
os.environ["TORCH_LOGS"] = "output_code"  # Just kernel code
```
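
One caveat worth checking on your version: logging variables such as `TORCH_LOGS` are typically read when `torch` starts up, so changing `os.environ` after `import torch` may not affect the current process. A more predictable option is to set them for a fresh process. The sketch below assumes that behavior; `build_debug_env` is a hypothetical helper, and the child command only echoes the variable so the example runs without PyTorch installed.

```python
import os
import subprocess
import sys
from typing import Dict, Iterable


def build_debug_env(log_artifacts: Iterable[str], autotuning: bool = False) -> Dict[str, str]:
    """Return a copy of the current environment with debug variables added.

    Hypothetical helper for illustration; the parent process is left unchanged.
    """
    env = dict(os.environ)
    env["TORCH_LOGS"] = ",".join(log_artifacts)
    if autotuning:
        env["TRITON_PRINT_AUTOTUNING"] = "1"
    return env


env = build_debug_env(["output_code", "dynamo"], autotuning=True)

# Stand-in child process: it only prints the variable it received.
child_code = "import os; print(os.environ['TORCH_LOGS'])"
completed = subprocess.run(
    [sys.executable, "-c", child_code],
    env=env,
    capture_output=True,
    text=True,
    check=True,
)
echoed = completed.stdout.strip()
print(f"Child process sees TORCH_LOGS={echoed}")
```

In practice you would replace the stand-in command with your own script, so the variables are already in place when that process imports `torch`.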

### 💡 Production vs Development

**Development Environment:**
- Enable detailed logging for learning and debugging
- Use cache statistics to understand reuse patterns
- Monitor autotuning to see optimization decisions

**Production Environment:**
- Minimal logging for performance
- Cache kernels to avoid recompilation
- Pre-warm models during initialization

In [None]:
# Exploring Environment Variables in Action
def explore_environment_variables():
    """
    Demonstrate how different environment variables provide insights
    into the compilation process.
    """
    
    print("🔍 EXPLORING ENVIRONMENT VARIABLES")
    print("=" * 50)
    
    # Create a model that will trigger interesting optimizations
    def fusion_example(x):
        # Multiple operations that can be fused
        y = torch.relu(x)
        z = y * 2.0
        w = z + 1.0
        return torch.tanh(w)
    
    test_data = torch.randn(1000, device=device)
    
    print("📊 Test case: Multi-operation fusion example")
    print("   Operations: ReLU → Multiply → Add → Tanh")
    print("   Expected: These should fuse into a single kernel")
    
    # Demonstrate different logging levels
    scenarios = [
        ("minimal", {}),
        ("output_code", {"TORCH_LOGS": "output_code"}),
        ("with_autotuning", {
            "TORCH_LOGS": "output_code",
            "TRITON_PRINT_AUTOTUNING": "1"
        }),
        ("comprehensive", {
            "TORCH_LOGS": "output_code,dynamo,inductor",
            "TRITON_PRINT_AUTOTUNING": "1",
            "TRITON_PRINT_CACHE_STATS": "1"
        })
    ]
    
    for scenario_name, env_vars in scenarios:
        print(f"\n🎯 Scenario: {scenario_name.upper()}")
        print("-" * 30)
        
        # Temporarily set environment variables
        original_env = {}
        for key, value in env_vars.items():
            original_env[key] = os.environ.get(key)
            os.environ[key] = value
            print(f"   {key} = {value}")
        
        if not env_vars:
            print("   No special logging enabled")
        
        print(f"\n   Compiling and running...")
        
        # Clear compilation cache to force recompilation
        torch._dynamo.reset()
        
        # Compile and run
        compiled_fn = torch.compile(fusion_example)
        
        # Time the execution
        start = time.perf_counter()
        result = compiled_fn(test_data)
        execution_time = time.perf_counter() - start
        
        print(f"   ✅ Execution time: {execution_time*1000:.3f} ms")
        
        # Restore original environment
        for key in env_vars:
            if original_env[key] is not None:
                os.environ[key] = original_env[key]
            else:
                os.environ.pop(key, None)
        
        print(f"   🔄 Environment restored")
    
    print(f"\n🎓 Observations:")
    print(f"   • 'minimal': Clean output, no compilation details")
    print(f"   • 'output_code': Shows generated Triton kernel source")
    print(f"   • 'with_autotuning': Shows performance optimization process")
    print(f"   • 'comprehensive': Full insight into entire pipeline")
    
    # Restore our educational settings
    for key, value in settings.items():
        os.environ[key] = value

# Run the exploration
explore_environment_variables()

print(f"\n💡 Pro Tips:")
print(f"   • Start with TORCH_LOGS=output_code for learning")
print(f"   • Add autotuning logs when optimizing performance")
print(f"   • Use comprehensive logging only when debugging issues")
print(f"   • Turn off logging in production for best performance")
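
The save/set/restore sequence used in `explore_environment_variables()` can be wrapped in a context manager, so the restore step also runs when compilation raises. The sketch below is a minimal version; `temporary_env` is a hypothetical helper name, not part of PyTorch or the standard library. Some libraries read their environment variables only once at import time, so a change made this way may not affect a module that is already loaded.

```python
import os
from contextlib import contextmanager

@contextmanager
def temporary_env(overrides):
    """Set environment variables for the duration of a with-block, then restore them."""
    # Remember the previous value of every key we are about to change (None = unset)
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions alike
        for key, old_value in saved.items():
            if old_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old_value

with temporary_env({"TRITON_PRINT_AUTOTUNING": "1"}):
    inside = os.environ["TRITON_PRINT_AUTOTUNING"]
print(inside)
```

Putting the restore in a `finally` clause is the main design choice here: a failed compile inside the block cannot leave debugging variables set for the rest of the session.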

## Performance Patterns and Optimization Strategies

Understanding PyTorch compilation performance patterns is crucial for effective optimization. Let's explore the key patterns and how to leverage them.

### 📊 Performance Pattern Analysis

#### The Break-Even Point
```
Total Time = Compilation Time + (Execution Time × Number of Runs)

Uncompiled Total = Baseline Time × Number of Runs
Compiled Total = Compilation Time + (Optimized Time × Number of Runs)

Break-even when: Compilation Time = (Baseline Time - Optimized Time) × Number of Runs
Runs to break even = Compilation Time / (Baseline Time - Optimized Time)
```
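
The break-even relation above translates directly into code. The following is a minimal sketch: `break_even_runs` is an illustrative helper name, the timings are made-up example values rather than measurements, and rounding up to a whole number of runs is a choice made here.

```python
import math

def break_even_runs(compile_s, baseline_s, optimized_s):
    """Return the number of runs after which compilation has paid for itself.

    Returns math.inf when the compiled version is not faster than the baseline.
    """
    saving_per_run = baseline_s - optimized_s
    if saving_per_run <= 0:
        return math.inf
    # Round up: a fraction of a run still requires one more full run
    return math.ceil(compile_s / saving_per_run)

# Hypothetical timings: 2 s to compile, 1.0 ms baseline, 0.5 ms optimized
runs = break_even_runs(compile_s=2.0, baseline_s=1.0e-3, optimized_s=0.5e-3)
print(f"Compilation pays off after about {runs} runs")
```

With a saving of half a millisecond per run, a two-second compile needs on the order of thousands of runs to pay off, which is why short scripts and one-off experiments often do not benefit.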

#### Factors Affecting Performance

1. **Model Complexity**: More operations → more fusion opportunities
2. **Input Size**: Larger tensors → fixed per-call overhead is a smaller share of each run
3. **Hardware**: Results depend on the GPU → measure on the hardware you will deploy on
4. **Pattern Recognition**: Common operator patterns → more likely to be fused well

### 🎯 Optimization Strategies

#### Strategy 1: Warm-up in Development
```python
# During model initialization
model = MyModel()
compiled_model = torch.compile(model)

# Warm-up with dummy data
dummy_input = torch.randn(typical_batch_size, ...)
_ = compiled_model(dummy_input)  # Triggers compilation

# Now ready for production use
```

#### Strategy 2: Selective Compilation
```python
# Compile only the critical paths
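class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # Initialize nn.Module state (e.g. self.training)
        # forward_critical and forward_simple are methods you define on the model
        self.critical_path = torch.compile(self.forward_critical)
        self.non_critical = self.forward_simple
    
    def forward(self, x):
        if self.training:
            return self.critical_path(x)  # Compiled: many calls amortize the compile cost
        else:
            return self.non_critical(x)   # Eager: no compilation overhead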
```

#### Strategy 3: Cache Management
```python
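# torch.save does not persist compiled kernels: compiled_model.state_dict()
# holds the same weights as model.state_dict() (keys prefixed with "_orig_mod."),
# so save the weights once
torch.save({'model_state': model.state_dict()}, 'model.pt')

# Compiled artifacts live in Inductor's on-disk cache. Point it at a persistent
# directory (placeholder path) before the first compilation, ideally before the
# process starts, to reuse them across runs
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/inductor_cache"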
```

In [None]:
# Performance Pattern Analysis and Break-Even Calculation
import torch.nn as nn
import torch.nn.functional as F

def analyze_performance_patterns():
    """
    Analyze when compilation pays off and develop optimization strategies
    """
    
    print("📊 PERFORMANCE PATTERN ANALYSIS")
    print("=" * 50)
    
    # Test different scenarios
    scenarios = [
        ("Small Model", 32, 64, 256),      # Small: batch=32, seq=64, hidden=256
        ("Medium Model", 16, 128, 512),    # Medium: batch=16, seq=128, hidden=512  
        ("Large Model", 8, 256, 1024),     # Large: batch=8, seq=256, hidden=1024
    ]
    
    results = []
    
    for scenario_name, batch_size, seq_len, hidden_size in scenarios:
        print(f"\n🧪 Scenario: {scenario_name}")
        print(f"   Configuration: B={batch_size}, S={seq_len}, H={hidden_size}")
        
        # Create model and data
        class TestModel(nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.norm1 = nn.LayerNorm(hidden_size)
                self.norm2 = nn.LayerNorm(hidden_size)
                
            def forward(self, x):
                x = F.gelu(self.norm1(x))
                x = F.relu(self.norm2(x))
                return x
        
        model = TestModel(hidden_size).to(device)
        test_input = torch.randn(batch_size, seq_len, hidden_size, device=device)
        
        # Measure baseline performance
        print(f"   📏 Measuring baseline...")
        
        # Warmup
        for _ in range(5):
            with torch.no_grad():
                _ = model(test_input)
        
        # Measure
        baseline_times = []
        for _ in range(20):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            start = time.perf_counter()
            with torch.no_grad():
                _ = model(test_input)
            
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            baseline_times.append(time.perf_counter() - start)
        
        baseline_avg = sum(baseline_times) / len(baseline_times)
        
        # Measure compilation overhead
        print(f"   ⚙️  Measuring compilation...")
        
        torch._dynamo.reset()  # Clear cache
        compiled_model = torch.compile(model, mode="default")
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        start = time.perf_counter()
        with torch.no_grad():
            _ = compiled_model(test_input)
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        compilation_time = time.perf_counter() - start
        
        # Measure optimized performance
        print(f"   ⚡ Measuring optimized performance...")
        
        optimized_times = []
        for _ in range(20):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            start = time.perf_counter()
            with torch.no_grad():
                _ = compiled_model(test_input)
            
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            optimized_times.append(time.perf_counter() - start)
        
        optimized_avg = sum(optimized_times) / len(optimized_times)
        
        # Calculate break-even point
        if baseline_avg > optimized_avg:
            break_even = compilation_time / (baseline_avg - optimized_avg)
        else:
            break_even = float('inf')  # Never breaks even
        
        # Store results
        scenario_results = {
            'name': scenario_name,
            'baseline_ms': baseline_avg * 1000,
            'optimized_ms': optimized_avg * 1000,
            'compilation_ms': compilation_time * 1000,
            'speedup': baseline_avg / optimized_avg if optimized_avg > 0 else 0,
            'break_even_runs': break_even
        }
        
        results.append(scenario_results)
        
        # Print results for this scenario
        print(f"   📊 Results:")
        print(f"      Baseline: {scenario_results['baseline_ms']:.3f} ms")
        print(f"      Optimized: {scenario_results['optimized_ms']:.3f} ms")
        print(f"      Compilation: {scenario_results['compilation_ms']:.1f} ms")
        print(f"      Speedup: {scenario_results['speedup']:.2f}x")
        if break_even != float('inf'):
            print(f"      Break-even: {break_even:.1f} runs")
        else:
            print(f"      Break-even: Never (compiled version is not faster)")
    
    # Summary analysis
    print(f"\n📈 SUMMARY ANALYSIS")
    print("=" * 40)
    
    print(f"{'Scenario':<15} {'Speedup':<8} {'Break-even':<12} {'Recommendation':<20}")
    print("-" * 65)
    
    for result in results:
        speedup_str = f"{result['speedup']:.2f}x"
        
        if result['break_even_runs'] == float('inf'):
            breakeven_str = "Never"
            recommendation = "Skip compilation"
        elif result['break_even_runs'] < 5:
            breakeven_str = f"{result['break_even_runs']:.1f} runs"
            recommendation = "Always compile"
        elif result['break_even_runs'] < 20:
            breakeven_str = f"{result['break_even_runs']:.1f} runs"
            recommendation = "Compile for training"
        else:
            breakeven_str = f"{result['break_even_runs']:.1f} runs"
            recommendation = "Selective compilation"
        
        print(f"{result['name']:<15} {speedup_str:<8} {breakeven_str:<12} {recommendation:<20}")
    
    return results

# Run the analysis
performance_results = analyze_performance_patterns()

print(f"\n🎓 Key Insights:")
print(f"   • Larger models often benefit more from compilation; check the table above")
print(f"   • Break-even point varies significantly by model size")
print(f"   • Consider your use case: training vs inference vs experimentation")
print(f"   • Measure your specific workloads - patterns vary!")
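
The synchronize/time/synchronize pattern repeated in `analyze_performance_patterns()` can be factored into one helper. The sketch below is one possible shape for it, with hypothetical names (`time_callable`, `sync`): `sync` is any zero-argument callable, for example `torch.cuda.synchronize` on a GPU, and defaults to a no-op. It reports the median alongside the mean because a single slow iteration can skew the mean.

```python
import statistics
import time

def time_callable(fn, *, sync=lambda: None, warmup=5, iters=20):
    """Time fn() over several iterations and return (mean_s, median_s)."""
    for _ in range(warmup):
        fn()  # Untimed runs so one-time setup costs are excluded
    samples = []
    for _ in range(iters):
        sync()  # Finish pending work before starting the clock
        start = time.perf_counter()
        fn()
        sync()  # Wait for fn's work to finish before stopping the clock
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples), statistics.median(samples)

# CPU-only usage with a stand-in workload
mean_s, median_s = time_callable(lambda: sum(range(10_000)))
print(f"mean={mean_s * 1e3:.3f} ms, median={median_s * 1e3:.3f} ms")
```

Using the same helper for the baseline and the compiled model keeps the two measurements comparable, since both get identical warm-up and synchronization.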

## Debugging Common Compilation Issues

Even with PyTorch's sophisticated compilation system, issues can arise. Let's explore common problems and their solutions.

### 🐛 Common Issues and Solutions

#### 1. **Compilation Failures**
```python
# Common symptom: inputs with changing shapes trigger repeated recompilation
# (torch.compile specializes on the shapes it sees and recompiles when they change)

# Solution: Use torch.compile with dynamic=True or fix shapes
compiled_fn = torch.compile(fn, dynamic=True)
```
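
One way to "fix shapes" is to pad variable-length inputs up to a small set of bucket sizes, so the compiled function only ever sees a few distinct shapes. The sketch below shows only the bucket arithmetic; `bucket_length` is a hypothetical helper, and doubling buckets starting at 16 are an assumption chosen for illustration, not something PyTorch requires.

```python
def bucket_length(length, min_bucket=16):
    """Round length up to the next bucket in min_bucket, 2*min_bucket, 4*min_bucket, ..."""
    if length < 1:
        raise ValueError("length must be positive")
    bucket = min_bucket
    while bucket < length:
        bucket *= 2
    return bucket

# Many distinct lengths collapse onto a few padded lengths
lengths = [10, 17, 30, 33, 64, 100]
print(sorted({bucket_length(n) for n in lengths}))  # → [16, 32, 64, 128]
```

Padding the tensor itself (for example with `torch.nn.functional.pad`) and masking the padded positions is left out here. The trade-off is some wasted computation on padding in exchange for fewer recompilations.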

#### 2. **Performance Regressions**
```python
# Issue: Compiled version slower than baseline
# Causes: Small models, wrong compilation mode, graph breaks

# Solutions:
# 1. Try different modes
compiled_fn = torch.compile(fn, mode="reduce-overhead")  # vs "default"

# 2. Check for graph breaks
# torch._dynamo.explain traces fn and reports where and why the graph was split
explanation = torch._dynamo.explain(fn)(input)
print(explanation.graph_break_count, explanation.break_reasons)
```

#### 3. **Memory Issues**
```python
# Issue: Out of memory during compilation
# Solution: Reduce compilation scope or use checkpointing
# (mode="reduce-overhead" uses CUDA graphs, which can increase memory use,
# so the default mode is used here)
@torch.compile
def smaller_function(x):
    # Break large functions into smaller ones
    return partial_computation(x)
```

#### 4. **Unsupported Operations**
```python
# Issue: Some operations don't support compilation
# Solution: Selective compilation or fallbacks

class HybridModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.compiled_part = torch.compile(self.core_computation)
        
    def forward(self, x):
        # Compiled part
        x = self.compiled_part(x)
        
        # Unsupported operations run normally
        x = unsupported_operation(x)
        
        return x
```

### 🔧 Debugging Toolkit

1. **Environment Variables**: Use detailed logging
2. **Graph Breaks**: Monitor for optimization barriers
3. **Profiling**: Use torch.profiler for detailed analysis
4. **Selective Compilation**: Isolate problematic areas

In [None]:
# Debugging Compilation Issues
def demonstrate_debugging_techniques():
    """
    Show common compilation issues and how to debug them
    """
    
    print("🐛 DEBUGGING COMPILATION ISSUES")
    print("=" * 50)
    
    # Issue 1: Graph breaks
    print("🔍 Issue 1: Graph Breaks")
    print("-" * 30)
    
    def problematic_function(x):
        y = x * 2
        
        # Branching on a tensor value is data-dependent control flow:
        # Dynamo cannot trace both branches into one graph, so it breaks here
        if x.sum() > 0:  # Dynamic condition
            z = y + 1
        else:
            z = y - 1
            
        return z
    
    test_input = torch.randn(100, device=device)
    
    # torch.compile handles graph breaks silently by default;
    # torch._dynamo.explain reports how many occurred
    import torch._dynamo as dynamo
    
    print("   Compiling function with potential graph breaks...")
    
    try:
        explanation = dynamo.explain(problematic_function)(test_input)
        print(f"   Graph breaks detected: {explanation.graph_break_count}")
        
        compiled_problematic = torch.compile(problematic_function)
        result = compiled_problematic(test_input)
        print("   ✅ Compilation succeeded despite graph breaks")
    except Exception as e:
        print(f"   ❌ Compilation failed: {e}")
    
    # Issue 2: Dynamic shapes
    print(f"\n🔍 Issue 2: Dynamic Shapes")
    print("-" * 30)
    
    def shape_sensitive_function(x):
        # Function that's sensitive to input shapes
        return x.view(-1, x.shape[-1] // 2, 2).sum(dim=-1)
    
    # Each new shape may trigger a recompilation rather than an error
    inputs_different_shapes = [
        torch.randn(10, 20, device=device),
        torch.randn(15, 30, device=device),  # Different shape
        torch.randn(20, 40, device=device),  # Another different shape
    ]
    
    print("   Testing with different input shapes...")
    
    try:
        compiled_shape_sensitive = torch.compile(shape_sensitive_function)
        
        for i, inp in enumerate(inputs_different_shapes):
            result = compiled_shape_sensitive(inp)
            print(f"   ✅ Shape {inp.shape}: Success")
            
    except Exception as e:
        print(f"   ❌ Dynamic shapes issue: {e}")
        print("   💡 Solution: Use dynamic=True in torch.compile")
        
        # Try with dynamic compilation
        try:
            compiled_dynamic = torch.compile(shape_sensitive_function, dynamic=True)
            for i, inp in enumerate(inputs_different_shapes):
                result = compiled_dynamic(inp)
                print(f"   ✅ Dynamic shape {inp.shape}: Success")
        except Exception as e2:
            print(f"   ❌ Still failing: {e2}")
    
    # Issue 3: Performance regression detection
    print(f"\n🔍 Issue 3: Performance Regression Detection")
    print("-" * 30)
    
    def potentially_slow_function(x):
        # Simple function that might not benefit from compilation
        return x + 1
    
    simple_input = torch.randn(100, device=device)
    
    # Measure baseline
    times_baseline = []
    for _ in range(50):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        _ = potentially_slow_function(simple_input)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times_baseline.append(time.perf_counter() - start)
    
    baseline_avg = sum(times_baseline) / len(times_baseline)
    
    # Measure compiled version
    torch._dynamo.reset()
    compiled_simple = torch.compile(potentially_slow_function)
    
    # First run (compilation)
    start = time.perf_counter()
    _ = compiled_simple(simple_input)
    compilation_time = time.perf_counter() - start
    
    # Subsequent runs
    times_compiled = []
    for _ in range(50):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        _ = compiled_simple(simple_input)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times_compiled.append(time.perf_counter() - start)
    
    compiled_avg = sum(times_compiled) / len(times_compiled)
    
    print(f"   Baseline: {baseline_avg*1000:.3f} ms")
    print(f"   Compiled: {compiled_avg*1000:.3f} ms")
    print(f"   Compilation overhead: {compilation_time*1000:.3f} ms")
    
    if compiled_avg > baseline_avg:
        print("   ⚠️  Performance regression detected!")
        print("   💡 Recommendations:")
        print("      • Try different compilation modes")
        print("      • Consider skipping compilation for simple operations")
        print("      • Check for graph breaks")
    else:
        speedup = baseline_avg / compiled_avg
        print(f"   ✅ Speedup achieved: {speedup:.2f}x")

# Run debugging demonstration
demonstrate_debugging_techniques()

print(f"\n🎓 Debugging Best Practices:")
print(f"   • Always measure performance before and after compilation")
print(f"   • Use environment variables to understand what's happening")
print(f"   • Start with simple cases and add complexity gradually")
print(f"   • Monitor for graph breaks and dynamic shape issues")
print(f"   • Consider selective compilation for problematic functions")

## Production Deployment Best Practices

Deploying compiled PyTorch models in production requires careful consideration of performance, reliability, and maintainability.

### 🚀 Production Deployment Strategy

#### Phase 1: Development and Testing
1. **Profile Your Workloads**: Measure baseline performance
2. **Identify Compilation Candidates**: Focus on hot paths
3. **Test Thoroughly**: Verify correctness and performance
4. **Benchmark Different Modes**: Find optimal compilation settings

#### Phase 2: Staging and Validation
1. **Warm-up Strategy**: Pre-compile during initialization
2. **Error Handling**: Graceful fallbacks for compilation failures
3. **Monitoring**: Track performance and compilation success rates
4. **A/B Testing**: Compare compiled vs uncompiled in production

#### Phase 3: Production Rollout
1. **Gradual Rollout**: Start with small traffic percentage
2. **Performance Monitoring**: Track latency and throughput
3. **Fallback Mechanisms**: Quick rollback if issues arise
4. **Cache Management**: Optimize kernel reuse
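
A gradual rollout needs a stable way to decide which requests take the compiled path. The sketch below hashes a request or user identifier into one of 100 buckets and compares the bucket with the rollout percentage. The function name `use_compiled_path` and the idea of keying on a string ID are assumptions for illustration, not part of PyTorch.

```python
import hashlib

def use_compiled_path(request_id, rollout_percent):
    """Deterministically route about rollout_percent% of IDs to the compiled model."""
    digest = hashlib.sha256(request_id.encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") % 100  # Stable bucket in [0, 100)
    return bucket < rollout_percent

# Roughly 10% of IDs should be routed to the compiled model
ids = [f"user-{i}" for i in range(10_000)]
share = sum(use_compiled_path(i, 10) for i in ids) / len(ids)
print(f"{share:.1%} of requests use the compiled model")
```

Because the decision depends only on the ID, the same caller always gets the same path, which makes latency comparisons between the compiled and uncompiled groups easier to interpret. Raising `rollout_percent` only adds callers to the compiled group and never moves existing ones out.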

### 💡 Production Patterns

#### Pattern 1: Initialization-Time Compilation
```python
class ProductionModel:
    def __init__(self):
        self.model = MyModel()
        
        # Compile during initialization
        self.compiled_model = torch.compile(self.model)
        
        # Warm-up with typical inputs
        self._warmup()
    
    def _warmup(self):
        dummy_input = torch.randn(typical_batch_size, ...)
        _ = self.compiled_model(dummy_input)
```

#### Pattern 2: Conditional Compilation
```python
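class AdaptiveModel:
    def __init__(self, example_input, enable_compilation=True):
        self.model = MyModel()
        self.forward = self.model.forward  # Eager fallback by default
        self.compiled = False
        
        if enable_compilation:
            try:
                compiled_forward = torch.compile(self.model.forward)
                # torch.compile is lazy: compilation, and most compilation
                # errors, happen on the first call, so trigger it here
                _ = compiled_forward(example_input)
                self.forward = compiled_forward
                self.compiled = True
            except Exception:
                pass  # Keep the eager forward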
```

#### Pattern 3: Performance Monitoring
```python
class MonitoredModel:
    def __init__(self):
        self.model = torch.compile(MyModel())
        self.performance_metrics = {
            'total_calls': 0,
            'total_time': 0,
            'compilation_failures': 0
        }
    
    def forward(self, x):
        start_time = time.perf_counter()
        try:
            result = self.model(x)
            self.performance_metrics['total_calls'] += 1
            self.performance_metrics['total_time'] += time.perf_counter() - start_time
            return result
        except Exception as e:
            self.performance_metrics['compilation_failures'] += 1
            # Fallback to uncompiled
            return self.model._orig_mod(x)
```

In [None]:
# Production Deployment Template
class ProductionModelTemplate:
    """
    Template for production deployment of compiled PyTorch models
    
    This class demonstrates best practices for:
    - Safe compilation with fallbacks
    - Performance monitoring
    - Warm-up strategies
    - Error handling
    """
    
    def __init__(self, model_class, model_args=None, compilation_config=None):
        """
        Initialize production model with compilation
        
        Args:
            model_class: The PyTorch model class to instantiate
            model_args: Arguments for model initialization
            compilation_config: Configuration for torch.compile
        """
        
        print("🚀 Initializing Production Model")
        print("=" * 40)
        
        # Default configurations
        model_args = model_args or {}
        compilation_config = compilation_config or {
            'mode': 'default',
            'dynamic': True,  # Handle dynamic shapes
            'fullgraph': False  # Allow graph breaks
        }
        
        # Initialize base model
        # Move the model to the same device as the inputs it will receive
        self.model = model_class(**model_args).to(device)
        self.compilation_config = compilation_config
        
        # Performance tracking
        self.metrics = {
            'total_calls': 0,
            'total_time': 0.0,
            'compilation_failures': 0,
            'warmup_time': 0.0,
            'compiled': False
        }
        
        # Attempt compilation
        self._attempt_compilation()
        
        # Warm-up if compilation succeeded
        if self.metrics['compiled']:
            self._warmup()
    
    def _attempt_compilation(self):
        """Safely attempt model compilation with fallback"""
        
        print("🔧 Attempting model compilation...")
        
        try:
            # Create compiled version
            self.compiled_model = torch.compile(
                self.model,
                **self.compilation_config
            )
            
            # Test compilation with dummy input
            dummy_input = self._create_dummy_input()
            
            start_time = time.perf_counter()
            _ = self.compiled_model(dummy_input)
            compilation_time = time.perf_counter() - start_time
            
            self.metrics['compiled'] = True
            self.metrics['compilation_time'] = compilation_time
            
            print(f"✅ Compilation successful")
            print(f"   Compilation time: {compilation_time*1000:.1f} ms")
            
        except Exception as e:
            print(f"❌ Compilation failed: {e}")
            print("   Falling back to uncompiled model")
            
            self.compiled_model = self.model
            self.metrics['compiled'] = False
            self.metrics['compilation_failures'] += 1
    
    def _create_dummy_input(self):
        """Create dummy input for testing and warm-up"""
        # This should be overridden based on your model's expected input
        return torch.randn(1, 128, device=device)
    
    def _warmup(self, num_warmup_runs=5):
        """Warm up the compiled model"""
        
        print(f"🔥 Warming up compiled model ({num_warmup_runs} runs)...")
        
        dummy_input = self._create_dummy_input()
        
        start_time = time.perf_counter()
        
        for i in range(num_warmup_runs):
            try:
                with torch.no_grad():
                    _ = self.compiled_model(dummy_input)
            except Exception as e:
                print(f"   ⚠️  Warmup run {i+1} failed: {e}")
        
        warmup_time = time.perf_counter() - start_time
        self.metrics['warmup_time'] = warmup_time
        
        print(f"✅ Warmup complete")
        print(f"   Total warmup time: {warmup_time*1000:.1f} ms")
        print(f"   Average per run: {warmup_time/num_warmup_runs*1000:.1f} ms")
    
    def forward(self, x):
        """Production forward pass with monitoring"""
        
        start_time = time.perf_counter()
        
        try:
            if self.metrics['compiled']:
                result = self.compiled_model(x)
            else:
                result = self.model(x)
            
            # Update metrics
            execution_time = time.perf_counter() - start_time
            self.metrics['total_calls'] += 1
            self.metrics['total_time'] += execution_time
            
            return result
            
        except Exception as e:
            print(f"⚠️  Forward pass failed: {e}")
            
            # Fallback to uncompiled model
            if self.metrics['compiled']:
                print("   Falling back to uncompiled model")
                self.metrics['compilation_failures'] += 1
                result = self.model(x)
            else:
                raise  # Re-raise if uncompiled model also fails
            
            execution_time = time.perf_counter() - start_time
            self.metrics['total_calls'] += 1
            self.metrics['total_time'] += execution_time
            
            return result
    
    def get_performance_report(self):
        """Generate performance report"""
        
        if self.metrics['total_calls'] == 0:
            return "No calls made yet"
        
        avg_time = self.metrics['total_time'] / self.metrics['total_calls']
        
        report = f"""
📊 Performance Report
{'='*30}
Model Status: {'Compiled' if self.metrics['compiled'] else 'Uncompiled'}
Total Calls: {self.metrics['total_calls']:,}
Total Time: {self.metrics['total_time']*1000:.1f} ms
Average Time: {avg_time*1000:.3f} ms per call
Compilation Failures: {self.metrics['compilation_failures']}
Success Rate: {(1 - self.metrics['compilation_failures']/max(1, self.metrics['total_calls']))*100:.1f}%
"""
        
        if self.metrics.get('compilation_time'):
            report += f"Initial Compilation: {self.metrics['compilation_time']*1000:.1f} ms\n"
        
        if self.metrics.get('warmup_time'):
            report += f"Warmup Time: {self.metrics['warmup_time']*1000:.1f} ms\n"
        
        return report.strip()

# Demonstration of production deployment
def demonstrate_production_deployment():
    """Demonstrate production deployment patterns"""
    
    print("🏭 PRODUCTION DEPLOYMENT DEMONSTRATION")
    print("=" * 50)
    
    # Example model for demonstration
    class ExampleModel(nn.Module):
        def __init__(self, hidden_size=512):
            super().__init__()
            self.norm = nn.LayerNorm(hidden_size)
            self.linear = nn.Linear(hidden_size, hidden_size)
            
        def forward(self, x):
            return F.gelu(self.linear(self.norm(x)))
    
    # Custom production model with proper dummy input
    class ProductionExampleModel(ProductionModelTemplate):
        def _create_dummy_input(self):
            return torch.randn(16, 64, 512, device=device)
    
    # Deploy model
    production_model = ProductionExampleModel(
        model_class=ExampleModel,
        model_args={'hidden_size': 512},
        compilation_config={
            'mode': 'default',
            'dynamic': True
        }
    )
    
    # Simulate production usage
    print(f"\n📈 Simulating Production Usage")
    print("-" * 30)
    
    test_inputs = [
        torch.randn(16, 64, 512, device=device),
        torch.randn(32, 128, 512, device=device),  # Different shape
        torch.randn(8, 32, 512, device=device),    # Another shape
    ]
    
    for i, test_input in enumerate(test_inputs):
        print(f"   Processing batch {i+1} (shape: {test_input.shape})...")
        
        with torch.no_grad():
            result = production_model.forward(test_input)
        
        print(f"   ✅ Success - Output shape: {result.shape}")
    
    # Generate performance report
    print(f"\n{production_model.get_performance_report()}")
    
    return production_model

# Run production deployment demonstration
prod_model = demonstrate_production_deployment()

print(f"\n🎓 Production Best Practices Summary:")
print(f"   ✅ Always include fallback mechanisms")
print(f"   ✅ Monitor performance and failure rates")
print(f"   ✅ Warm up models during initialization")
print(f"   ✅ Handle dynamic shapes appropriately")
print(f"   ✅ Test thoroughly before production deployment")

## 🎓 Summary and Next Steps

### What We've Learned

In this notebook, we've explored the fundamental aspects of PyTorch + Triton compilation:

#### 🧠 **Core Understanding**
- **Compilation Pipeline**: How PyTorch transforms Python code into optimized GPU kernels
- **Two-Phase Performance**: Why first runs are slow but subsequent runs are fast
- **Environment Variables**: Powerful tools for debugging and understanding optimizations
- **Performance Patterns**: When compilation helps and when it doesn't

#### 🔧 **Practical Skills**
- **Environment Setup**: Configuring optimal development environments
- **Performance Analysis**: Measuring and understanding compilation benefits
- **Debugging Techniques**: Solving common compilation issues
- **Production Deployment**: Best practices for real-world applications

#### 📊 **Key Insights**
- Compilation overhead is significant but amortizes over multiple runs
- Different model sizes and patterns have different break-even points
- Environment variables provide deep insights into the compilation process
- Production deployment requires careful error handling and monitoring

### 🚀 Next Steps

#### Immediate Actions
1. **Experiment with Your Models**: Apply `torch.compile()` to your existing PyTorch models
2. **Measure Performance**: Use the techniques from this notebook to analyze benefits
3. **Set Up Environment**: Configure development environment with appropriate logging

#### Advanced Learning
1. **Kernel Optimization**: Dive deeper into specific kernel fusion patterns
2. **Custom Triton Kernels**: Learn to write hand-optimized kernels
3. **Production Deployment**: Implement robust compilation strategies in your applications

#### Continue Your Journey
- **Next Notebook**: "Optimizing PyTorch Kernels with Triton" - Focus on specific optimization patterns
- **Documentation**: Explore PyTorch's compilation documentation
- **Community**: Join discussions about PyTorch optimization techniques

### 💡 Key Takeaways

1. **Compilation is an Investment**: Upfront cost, long-term benefits
2. **Measurement is Critical**: Always profile before optimizing
3. **Environment Variables are Powerful**: Use them to understand and debug
4. **Production Needs Planning**: Robust deployment requires careful design
5. **Start Simple**: Begin with basic patterns and gradually increase complexity

**You now have a solid foundation in PyTorch + Triton fundamentals. Ready to dive deeper into kernel optimization!**