<a href="https://colab.research.google.com/github/kobejean/torch-ops/blob/main/test_roll_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Custom PyTorch Roll Operation Test

This notebook tests a custom implementation of `torch.roll` (`custom_ops.roll`) on both CPU and CUDA, checking its results against PyTorch's built-in `torch.roll`.
The implementation includes optimized CUDA kernels inspired by TensorFlow's roll operation.

## Setup Environment

In [None]:
# Environment report: PyTorch build, CUDA visibility and, when a GPU is present,
# its device name plus the CUDA toolkit version PyTorch was built against.
import torch

print("PyTorch version: " + str(torch.__version__))
print("CUDA available: " + str(torch.cuda.is_available()))
if torch.cuda.is_available():
    print("CUDA device: " + str(torch.cuda.get_device_name(0)))
    print("CUDA version: " + str(torch.version.cuda))

## Clone Repository and Build Extension

In [None]:
# Clone the repository
# NOTE(review): not idempotent — `%cd torch-ops` is relative to the current
# directory, so re-running this cell from inside torch-ops clones a nested
# copy (torch-ops/torch-ops) and changes into it.
!git clone https://github.com/kobejean/torch-ops.git
%cd torch-ops

In [None]:
# Install the extension
!pip install -e .

## Import and Basic Tests

In [None]:
import torch
import custom_ops  # Import the custom_ops module directly
import numpy as np
import time

# Show the names the compiled extension exports (the notebook relies on `roll`).
print("Custom ops available:", dir(custom_ops), sep="\n")

## Test CPU Roll Operation

In [None]:
def test_cpu_roll():
    print("=== CPU Roll Tests ===")
    
    # Test 1: Basic 2D roll
    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    print(f"Original tensor:\n{x}")
    
    # Roll along dimension 0
    result = custom_ops.roll(x, [1], [0])
    expected = torch.roll(x, [1], [0])
    print(f"\nRoll by 1 along dim 0:")
    print(f"Custom: \n{result}")
    print(f"PyTorch:\n{expected}")
    print(f"Match: {torch.allclose(result, expected)}")
    
    # Test 2: Multi-dimensional roll
    result = custom_ops.roll(x, [1, 2], [0, 1])
    expected = torch.roll(x, [1, 2], [0, 1])
    print(f"\nRoll by [1, 2] along dims [0, 1]:")
    print(f"Custom: \n{result}")
    print(f"PyTorch:\n{expected}")
    print(f"Match: {torch.allclose(result, expected)}")
    
    # Test 3: Negative shifts
    result = custom_ops.roll(x, [-1, -2], [0, 1])
    expected = torch.roll(x, [-1, -2], [0, 1])
    print(f"\nRoll by [-1, -2] along dims [0, 1]:")
    print(f"Custom: \n{result}")
    print(f"PyTorch:\n{expected}")
    print(f"Match: {torch.allclose(result, expected)}")
    
    # Test 4: Zero shifts (should be no-op)
    result = custom_ops.roll(x, [0, 0], [0, 1])
    print(f"\nZero shift test:")
    print(f"Match original: {torch.allclose(result, x)}")
    
    return True

# Run the CPU checks; the return value is displayed as the cell output.
test_cpu_roll()

## Test CUDA Roll Operation (if available)

In [None]:
def test_cuda_roll():
    """Check ``custom_ops.roll`` against ``torch.roll`` on CUDA tensors.

    Prints every case with a match flag, as well as the maximum absolute
    difference for the larger random tensor.

    Returns:
        bool: False when CUDA is unavailable (tests skipped) or when any case
        mismatched; True only if every case matched exactly.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, skipping CUDA tests")
        return False

    print("=== CUDA Roll Tests ===")

    x_cpu = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    x_cuda = x_cpu.cuda()
    print(f"Original tensor:\n{x_cpu}")

    # (shifts, dims, heading): single-dim and multi-dim rolls on the device.
    cases = [
        ([1], [0], "CUDA Roll by 1 along dim 0:"),
        ([1, 2], [0, 1], "CUDA Roll by [1, 2] along dims [0, 1]:"),
    ]

    matches = []
    for shifts, dims, heading in cases:
        result_cuda = custom_ops.roll(x_cuda, shifts, dims)
        expected_cuda = torch.roll(x_cuda, shifts, dims)
        # Roll is a pure permutation, so exact equality is expected.
        match = torch.equal(result_cuda, expected_cuda)
        print(f"\n{heading}")
        print(f"Custom: \n{result_cuda.cpu()}")
        print(f"PyTorch:\n{expected_cuda.cpu()}")
        print(f"Match: {match}")
        matches.append(match)

    # A larger random tensor than the 3x4 cases above.
    large_tensor = torch.randn(100, 100, device='cuda')
    result_large = custom_ops.roll(large_tensor, [10, 20], [0, 1])
    expected_large = torch.roll(large_tensor, [10, 20], [0, 1])
    match = torch.equal(result_large, expected_large)
    print("\nLarge tensor (100x100) test:")
    print(f"Match: {match}")
    # Subtraction would raise on a shape mismatch; only report when comparable.
    if result_large.shape == expected_large.shape:
        print(f"Max difference: {torch.max(torch.abs(result_large - expected_large)).item()}")
    matches.append(match)

    return all(matches)

# Run the CUDA checks (prints a skip message and returns False without a GPU).
test_cuda_roll()

## Performance Benchmark

In [None]:
def benchmark_roll(configs=None):
    """Time ``custom_ops.roll`` against ``torch.roll`` on CPU and, if present, CUDA.

    For each configuration, prints wall-clock totals for a fixed number of
    iterations and the PyTorch/custom time ratio (>1 means custom is faster).

    Args:
        configs: optional list of ``(size, shifts, dims, description)`` tuples.
            ``None`` (the default) runs the built-in 2D/3D suite below.

    Raises:
        AssertionError: if the custom result differs from ``torch.roll`` for
            any configuration.
    """
    print("=== Performance Benchmark ===")

    if configs is None:
        # Test configurations: (size, shifts, dims, description)
        configs = [
            # 2D benchmarks with different shift patterns
            ((100, 100), [10, 20], [0, 1], "2D multi-dim shifts"),
            ((500, 500), [50, 100], [0, 1], "2D multi-dim shifts"),
            ((1000, 1000), [100, 200], [0, 1], "2D multi-dim shifts"),

            # Single dimension shifts for comparison
            ((100, 100), [10], [0], "2D single-dim shift"),
            ((500, 500), [50], [0], "2D single-dim shift"),
            ((1000, 1000), [100], [0], "2D single-dim shift"),

            # 3D benchmarks
            ((50, 50, 50), [5, 10, 15], [0, 1, 2], "3D multi-dim shifts"),
            ((100, 100, 100), [10, 20, 30], [0, 1, 2], "3D multi-dim shifts"),

            # Large 2D tensors for scaling test
            ((5000, 5000), [500, 1000], [0, 1], "Large 2D multi-dim"),
        ]

    def timed(roll_fn, x, shifts, dims, iterations, sync_cuda):
        """Call ``roll_fn`` ``iterations`` times; return (seconds, last output)."""
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(iterations):
            output = roll_fn(x, shifts, dims)
        if sync_cuda:
            # Kernel launches are asynchronous; wait for them before stopping the clock.
            torch.cuda.synchronize()
        return time.perf_counter() - start, output

    def ratio(reference, candidate):
        """Return reference/candidate, guarding against a zero timing."""
        return reference / candidate if candidate > 0 else float('inf')

    for size, shifts, dims, description in configs:
        print(f"\n--- {description}: {size} ---")

        # CPU benchmark
        x_cpu = torch.randn(size, dtype=torch.float32)
        for _ in range(3):  # warmup
            custom_ops.roll(x_cpu, shifts, dims)
            torch.roll(x_cpu, shifts, dims)

        num_iterations = 20 if max(size) < 1000 else 10
        custom_time, result_custom = timed(custom_ops.roll, x_cpu, shifts, dims, num_iterations, False)
        pytorch_time, result_pytorch = timed(torch.roll, x_cpu, shifts, dims, num_iterations, False)
        print(f"CPU: Custom={custom_time:.4f}s, PyTorch={pytorch_time:.4f}s, Ratio={ratio(pytorch_time, custom_time):.2f}x")

        # Roll is a pure permutation, so results must match exactly.
        assert torch.equal(result_custom, result_pytorch), f"CPU mismatch for {description}"

        # CUDA benchmark (if available)
        if torch.cuda.is_available():
            x_cuda = x_cpu.cuda()
            for _ in range(5):  # warmup
                custom_ops.roll(x_cuda, shifts, dims)
                torch.roll(x_cuda, shifts, dims)
            torch.cuda.synchronize()

            num_iterations = 50 if max(size) < 1000 else 20
            custom_time, result_custom = timed(custom_ops.roll, x_cuda, shifts, dims, num_iterations, True)
            pytorch_time, result_pytorch = timed(torch.roll, x_cuda, shifts, dims, num_iterations, True)
            print(f"CUDA: Custom={custom_time:.4f}s, PyTorch={pytorch_time:.4f}s, Ratio={ratio(pytorch_time, custom_time):.2f}x")

            assert torch.equal(result_custom, result_pytorch), f"CUDA mismatch for {description}"

    print("\n=== Summary ===")
    print("✅ All correctness tests passed!")
    print("📊 Benchmarks completed for various tensor sizes and shift patterns")
    print("🔧 Optimized implementation uses:")
    print("   - Constant memory for dimension data")
    # NOTE(review): optimization_summary() later in this notebook says the kernel
    # uses "Direct memory access without shared memory layer" — one of the two
    # claims is stale; confirm against the kernel source.
    print("   - Shared memory for fast block-local access")
    print("   - Branch-free kernel execution")

# Run the full CPU/CUDA benchmark suite (the 5000x5000 CPU case may take a while).
benchmark_roll()

## Edge Case Tests

In [None]:
def test_edge_cases():
    print("=== Edge Case Tests ===")
    
    # Test 1: 1D tensor
    x1d = torch.arange(5, dtype=torch.float32)
    result = custom_ops.roll(x1d, [2], [0])
    expected = torch.roll(x1d, [2], [0])
    print(f"1D tensor test: {torch.allclose(result, expected)}")
    
    # Test 2: 3D tensor
    x3d = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    result = custom_ops.roll(x3d, [1, 1, 2], [0, 1, 2])
    expected = torch.roll(x3d, [1, 1, 2], [0, 1, 2])
    print(f"3D tensor test: {torch.allclose(result, expected)}")
    
    # Test 3: Empty tensor
    empty = torch.empty((0, 5), dtype=torch.float32)
    result = custom_ops.roll(empty, [1], [1])
    expected = torch.roll(empty, [1], [1])
    print(f"Empty tensor test: {result.shape == expected.shape and torch.allclose(result, expected)}")
    
    # Test 4: Large shifts (wrapping)
    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    result = custom_ops.roll(x, [10], [0])  # 10 > 3, should wrap
    expected = torch.roll(x, [10], [0])
    print(f"Large shift test: {torch.allclose(result, expected)}")
    
    # Test 5: Negative dimension indices
    result = custom_ops.roll(x, [1], [-1])  # -1 should be dimension 1
    expected = torch.roll(x, [1], [-1])
    print(f"Negative dim test: {torch.allclose(result, expected)}")
    
    # Test 6: Different dtypes
    for dtype in [torch.int32, torch.int64, torch.float64]:
        x_dtype = x.to(dtype)
        result = custom_ops.roll(x_dtype, [1], [0])
        expected = torch.roll(x_dtype, [1], [0])
        print(f"{dtype} test: {torch.allclose(result, expected)}")
    
    print("All edge case tests completed!")

# Run the edge-case comparisons against torch.roll.
test_edge_cases()

## CUDA Profiling and Analysis

In [None]:
# Check available profiling tools in Colab
import subprocess
import os

print("=== CUDA Profiling Setup ===")

# Check NVIDIA tools availability: (executable, human-readable description)
tools_to_check = [
    ('nvprof', 'NVIDIA profiler'),
    ('ncu', 'Nsight Compute'),
    ('nvidia-smi', 'NVIDIA System Management'),
]

available_tools = []
for tool, description in tools_to_check:
    try:
        result = subprocess.run([tool, '--version'], capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.TimeoutExpired):
        # Missing executable (FileNotFoundError), not runnable, or hung past the
        # timeout. Catching these explicitly instead of a bare `except:` keeps
        # KeyboardInterrupt (kernel interrupt) working.
        print(f"❌ {description} ({tool}): Not found")
        continue
    if result.returncode == 0:
        print(f"✅ {description} ({tool}): Available")
        available_tools.append(tool)
    else:
        print(f"❌ {description} ({tool}): Not available")

# Check GPU info (with a timeout so a wedged driver cannot hang the cell)
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
except (OSError, subprocess.TimeoutExpired):
    print("❌ nvidia-smi not available")
else:
    if result.returncode == 0:
        print("\n=== GPU Information ===")
        for line in result.stdout.splitlines():
            if any(keyword in line for keyword in ('Tesla', 'Quadro', 'GeForce', 'GPU')):
                print(line.strip())
    else:
        print("❌ Could not get GPU information")

print("\n=== Available Profiling Methods ===")
if available_tools:
    print("We can use the following profiling approaches:")
    if 'nvprof' in available_tools:
        print("1. nvprof - Legacy profiler (deprecated but often available)")
    if 'ncu' in available_tools:
        print("2. ncu (Nsight Compute) - Modern profiler")
    print("3. PyTorch Profiler - Built into PyTorch")
    print("4. CUDA Events - Manual timing within code")
else:
    print("Limited profiling tools available, will use PyTorch profiler and CUDA events")

In [None]:
# PyTorch Profiler Analysis
import torch.profiler

def profile_roll_operations():
    """Profile our custom roll vs PyTorch roll using PyTorch's built-in profiler.

    Runs each implementation 10 times on a 1000x1000 float32 CUDA tensor
    (synchronizing after every call). For each, prints the profiler's top-10
    table sorted by CUDA time. It then prints the memory the profiler
    attributed to events whose name contains "roll".
    Returns ``None`` immediately (after a message) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("CUDA not available for profiling")
        return

    print("=== PyTorch Profiler Analysis ===")

    # Test tensor
    x = torch.randn(1000, 1000, device='cuda', dtype=torch.float32)
    shifts = [100, 200]
    dims = [0, 1]

    def run_profiled(region_name, roll_fn):
        """Profile 10 synchronized ``roll_fn`` calls inside a named region."""
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            with torch.profiler.record_function(region_name):
                for _ in range(10):
                    roll_fn(x, shifts, dims)
                    torch.cuda.synchronize()
        return prof

    def print_roll_memory(events):
        """Print non-zero CPU/CUDA memory for events whose key mentions 'roll'."""
        for event in events:
            if 'roll' not in event.key.lower():
                continue
            cpu_mem = getattr(event, 'cpu_memory_usage', 0)
            # Newer PyTorch releases presumably name this `device_memory_usage`
            # (older ones `cuda_memory_usage`); prefer it when present.
            cuda_mem = getattr(event, 'device_memory_usage', None)
            if cuda_mem is None:
                cuda_mem = getattr(event, 'cuda_memory_usage', 0)
            if cpu_mem:
                print(f"  {event.key}: CPU={cpu_mem/1024/1024:.2f} MB")
            if cuda_mem:
                print(f"  {event.key}: CUDA={cuda_mem/1024/1024:.2f} MB")

    print("\n--- Profiling Custom Roll ---")
    prof = run_profiled("custom_roll", custom_ops.roll)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    print("\n--- Profiling PyTorch Roll ---")
    prof_pytorch = run_profiled("pytorch_roll", torch.roll)
    print(prof_pytorch.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # Memory usage comparison
    print("\n=== Memory Analysis ===")
    print("Custom implementation memory usage:")
    print_roll_memory(prof.key_averages())
    print("\nPyTorch implementation memory usage:")
    print_roll_memory(prof_pytorch.key_averages())

    # Alternative: use tensor size for memory estimation (dtype-aware)
    tensor_size_mb = x.numel() * x.element_size() / (1024 * 1024)
    print(f"\nTensor size: {tensor_size_mb:.2f} MB")
    print("Note: Roll operations typically require 2x tensor size (input + output)")

# Profile both implementations (prints a message and returns without CUDA).
profile_roll_operations()

In [None]:
# CUDA Events for Detailed Timing
def detailed_cuda_timing(sizes=None, iterations=20):
    """Use CUDA events for precise per-call timing of both roll implementations.

    Prints the mean time per call, the implied bandwidth and the speedup
    (PyTorch time / custom time) for each tensor size.

    Args:
        sizes: tensor shapes to test; defaults to 500², 1000² and 2000².
        iterations: timed calls per implementation (must be >= 1).

    Raises:
        AssertionError: if the two implementations disagree for a size.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    print("=== CUDA Events Timing Analysis ===")

    if sizes is None:
        sizes = [(500, 500), (1000, 1000), (2000, 2000)]

    shifts = [50, 100]
    dims = [0, 1]
    # Events are reused; each measurement synchronizes before reading them.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    def mean_ms(roll_fn, x):
        """Return (mean milliseconds per call, last output) for ``roll_fn``."""
        start_event.record()
        for _ in range(iterations):
            output = roll_fn(x, shifts, dims)
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event) / iterations, output

    for size in sizes:
        print(f"\n--- Tensor Size: {size} ---")

        x = torch.randn(size, device='cuda', dtype=torch.float32)

        for _ in range(5):  # warmup
            custom_ops.roll(x, shifts, dims)
            torch.roll(x, shifts, dims)
        torch.cuda.synchronize()

        custom_time, result_custom = mean_ms(custom_ops.roll, x)
        pytorch_time, result_pytorch = mean_ms(torch.roll, x)

        # Assumes one read and one write per element (no extra traffic).
        bytes_transferred = x.numel() * x.element_size() * 2
        custom_bandwidth = (bytes_transferred / 1e9) / (custom_time / 1000)  # GB/s
        pytorch_bandwidth = (bytes_transferred / 1e9) / (pytorch_time / 1000)  # GB/s

        print(f"Custom:   {custom_time:.3f}ms, {custom_bandwidth:.1f} GB/s")
        print(f"PyTorch:  {pytorch_time:.3f}ms, {pytorch_bandwidth:.1f} GB/s")
        print(f"Speedup:  {pytorch_time/custom_time:.2f}x")

        # Roll is a pure permutation, so results must match exactly.
        assert torch.equal(result_custom, result_pytorch), f"Results don't match for {size}"

# Time both implementations with CUDA events (prints a message and returns without CUDA).
detailed_cuda_timing()

In [None]:
# nvprof Integration (if available)
def nvprof_analysis():
    """Try to use nvprof for detailed kernel analysis.

    Writes a tiny script that calls ``custom_ops.roll`` 5 times to a private
    temporary file, runs it under nvprof with the current interpreter, and
    prints nvprof's output. Any failure to launch is reported, not raised.
    The temporary script is always removed afterwards.
    """
    import os
    import subprocess
    import sys
    import tempfile

    print("=== nvprof Analysis (if available) ===")

    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    # Create a simple test script that we can profile
    test_script = '''
import torch
import custom_ops

# Simple test
x = torch.randn(1000, 1000, device='cuda', dtype=torch.float32)
shifts = [100, 200]
dims = [0, 1]

# Run operations
for _ in range(5):
    result = custom_ops.roll(x, shifts, dims)
    torch.cuda.synchronize()
    '''

    # A unique temp file avoids clobbering/being clobbered via a fixed /tmp path.
    fd, script_path = tempfile.mkstemp(prefix='profile_test_', suffix='.py')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(test_script)

        try:
            # sys.executable ensures the child uses this kernel's interpreter,
            # i.e. the environment where custom_ops was installed.
            result = subprocess.run([
                'nvprof',
                '--print-gpu-trace',
                '--metrics', 'achieved_occupancy,gld_efficiency,gst_efficiency',
                sys.executable, script_path
            ], capture_output=True, text=True, timeout=30)

            if result.returncode == 0:
                print("nvprof output:")
                print(result.stdout)
                if result.stderr:
                    print("nvprof stderr:")
                    print(result.stderr)
            else:
                print("nvprof failed or not available")
                print(f"Return code: {result.returncode}")
                print(f"Error: {result.stderr}")

        except Exception as e:
            print(f"Could not run nvprof: {e}")
            print("This is normal in Colab - nvprof is often not available")
            print("Use the PyTorch profiler and CUDA events instead")
    finally:
        try:
            os.remove(script_path)
        except OSError:
            pass  # best-effort cleanup of the temp script

# Occupancy Analysis
def analyze_occupancy(threads_per_block=256):
    """Estimate the theoretical occupancy of the roll kernel from device properties.

    Only the threads-per-SM and (when reported) blocks-per-SM limits are
    considered; register and shared-memory limits are ignored, so the figure
    is an upper bound.

    Args:
        threads_per_block: launch block size to analyse. The default of 256 is
            the block size this notebook states the kernel uses.
    """
    print("=== Occupancy Analysis ===")

    if not torch.cuda.is_available():
        print("CUDA not available for occupancy analysis")
        return

    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)

    def first_attr(*names, default=None):
        """Return the first attribute of ``props`` named in ``names``, else ``default``."""
        for attr_name in names:
            if hasattr(props, attr_name):
                return getattr(props, attr_name)
        return default

    # Device-property names vary between PyTorch releases (some spell it
    # `max_threads_per_multi_processor`). NOTE(review): confirm against the
    # installed version. Missing values are skipped instead of raising.
    max_threads_per_mp = first_attr('max_threads_per_multiprocessor', 'max_threads_per_multi_processor')
    max_blocks_per_mp = first_attr('max_blocks_per_multiprocessor', 'max_blocks_per_multi_processor')
    # Assumption: fall back to the 32-thread NVIDIA warp when not reported.
    warp_size = first_attr('warp_size', default=32)

    print(f"GPU: {props.name}")
    print(f"Compute Capability: {props.major}.{props.minor}")
    print(f"Multiprocessors: {props.multi_processor_count}")
    if max_threads_per_mp is not None:
        print(f"Max threads per MP: {max_threads_per_mp}")
    if max_blocks_per_mp is not None:
        print(f"Max blocks per MP: {max_blocks_per_mp}")
    print(f"Warp size: {warp_size}")
    print(f"Total memory: {props.total_memory / 1024**3:.1f} GB")

    # These attributes might not be available in all PyTorch versions
    if hasattr(props, 'max_threads_per_block'):
        print(f"Max threads per block: {props.max_threads_per_block}")
    if hasattr(props, 'max_shared_memory_per_block'):
        print(f"Shared memory per block: {props.max_shared_memory_per_block / 1024:.1f} KB")
    if hasattr(props, 'max_registers_per_block'):
        print(f"Registers per block: {props.max_registers_per_block}")

    occupancy_percentage = None
    if max_threads_per_mp:
        # A partial warp still occupies a whole warp slot -> ceiling division.
        warps_per_block = -(-threads_per_block // warp_size)
        max_warps_per_mp = max_threads_per_mp // warp_size

        # Maximum resident blocks limited by threads, then by the block cap.
        max_active_blocks = max_threads_per_mp // threads_per_block
        if max_blocks_per_mp:
            max_active_blocks = min(max_active_blocks, max_blocks_per_mp)

        active_warps = max_active_blocks * warps_per_block
        occupancy_percentage = (active_warps / max_warps_per_mp) * 100

        print(f"\n=== Kernel Configuration Analysis ===")
        print(f"Threads per block: {threads_per_block}")
        print(f"Warps per block: {warps_per_block}")
        print(f"Max active blocks per MP: {max_active_blocks}")
        print(f"Active warps per MP: {active_warps}")
        print(f"Max warps per MP: {max_warps_per_mp}")
        print(f"Theoretical occupancy: {occupancy_percentage:.1f}%")
    else:
        print("\nMax threads per MP not reported by this PyTorch build; skipping occupancy estimate")

    # Memory bandwidth estimation
    memory_clock = getattr(props, 'memory_clock_rate', None)  # in kHz
    memory_bus_width = getattr(props, 'memory_bus_width', None)  # in bits

    if memory_clock and memory_bus_width:
        # Convert to GB/s: (clock_rate * bus_width * 2) / 8 / 10^6 (2 = double data rate)
        theoretical_bandwidth = (memory_clock * memory_bus_width * 2) / 8 / 1e6
        print(f"\n=== Memory Bandwidth ===")
        print(f"Memory clock: {memory_clock/1000:.0f} MHz")
        print(f"Memory bus width: {memory_bus_width} bits")
        print(f"Theoretical bandwidth: {theoretical_bandwidth:.1f} GB/s")

    # Performance recommendations (only meaningful when occupancy was computed)
    if occupancy_percentage is not None:
        print(f"\n=== Performance Analysis ===")
        if occupancy_percentage < 50:
            print("⚠️  Low occupancy - consider optimizing:")
            print("   • Reduce register usage per thread")
            print("   • Adjust block size (try 128 or 512 threads)")
            print("   • Check shared memory usage")
        elif occupancy_percentage < 75:
            print("✅ Reasonable occupancy")
            print("   • Performance is likely memory-bound")
            print("   • Focus on memory access patterns")
        else:
            print("🚀 High occupancy achieved!")
            print("   • Good thread utilization")
            print("   • Optimize memory access patterns for further gains")

    # Tuple comparison: the old `major >= 7 and minor >= 5` missed CC 8.0/9.0.
    if (props.major, props.minor) >= (7, 5):
        print(f"\n📝 Note: GPU has compute capability {props.major}.{props.minor}")
        print("   nvprof is not supported for CC 7.5+")
        print("   Use Nsight Compute (ncu) or Nsight Systems (nsys) instead")

# Try nvprof first, then print the occupancy estimate; a rule of '=' separates them.
nvprof_analysis()
print(f"\n{'=' * 60}\n")
analyze_occupancy()

## Memory Bandwidth Analysis

In [None]:
def estimate_peak_bandwidth_gbps(gpu_name):
    """Return a rough peak DRAM bandwidth (GB/s) for a GPU based on its name.

    Specific models are matched first (as whole words, case-insensitive) so
    that, e.g., a Tesla T4 (~320 GB/s) is not lumped in with the generic
    "tesla" estimate. Figures are approximate published values.
    """
    import re

    name = gpu_name.lower()
    known_models = (
        ('v100', 900),
        ('a100', 1555),
        ('p100', 732),
        ('t4', 320),
    )
    for model, bandwidth in known_models:
        if re.search(rf'\b{model}\b', name):
            return bandwidth
    if 'rtx' in name or 'geforce' in name:
        return 400  # conservative estimate
    if 'tesla' in name:
        return 500  # conservative estimate
    return 300  # very conservative for unknown GPUs


def analyze_memory_patterns(sizes=None, iterations=10):
    """Analyze memory access patterns and bandwidth utilization.

    For each 2D size, times both roll implementations with CUDA events. It
    prints the achieved bandwidth and that bandwidth as a percentage of the
    estimated peak for this GPU.

    Args:
        sizes: 2D tensor shapes to test; defaults to 100² through 4000².
        iterations: timed calls per implementation (must be >= 1).

    Raises:
        AssertionError: if the two implementations disagree for a size.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    print("=== Memory Bandwidth Analysis ===")

    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    peak_bandwidth = estimate_peak_bandwidth_gbps(props.name)

    print(f"GPU: {props.name}")
    print(f"Estimated Peak Bandwidth: {peak_bandwidth} GB/s")

    if sizes is None:
        # Test different tensor sizes to analyze scaling
        sizes = [
            (100, 100),      # Small: 40 KB
            (500, 500),      # Medium: 1 MB
            (1000, 1000),    # Large: 4 MB
            (2000, 2000),    # XL: 16 MB
            (4000, 4000),    # XXL: 64 MB
        ]

    print(f"\n{'Size':<12} {'Elements':<10} {'Memory':<8} {'Custom (ms)':<12} {'PyTorch (ms)':<13} "
          f"{'Custom BW':<10} {'PyTorch BW':<11} {'Custom Eff':<11} {'PyTorch Eff':<11}")
    print("-" * 106)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    def mean_ms(roll_fn, x, shifts, dims):
        """Return (mean milliseconds per call, last output) for ``roll_fn``."""
        start.record()
        for _ in range(iterations):
            output = roll_fn(x, shifts, dims)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iterations, output

    for size in sizes:
        try:
            x = torch.randn(size, device='cuda', dtype=torch.float32)
            elements = x.numel()
            memory_mb = elements * x.element_size() / (1024 * 1024)
            shifts = [size[0] // 10, size[1] // 10]  # 10% shifts
            dims = [0, 1]

            for _ in range(3):  # warmup
                custom_ops.roll(x, shifts, dims)
                torch.roll(x, shifts, dims)
            torch.cuda.synchronize()

            custom_time, result_custom = mean_ms(custom_ops.roll, x, shifts, dims)
            pytorch_time, result_pytorch = mean_ms(torch.roll, x, shifts, dims)
            assert torch.equal(result_custom, result_pytorch), f"Results don't match for {size}"

            # Bandwidth assumes one read + one write per element.
            bytes_transferred = elements * x.element_size() * 2
            custom_bw = (bytes_transferred / 1e9) / (custom_time / 1000)
            pytorch_bw = (bytes_transferred / 1e9) / (pytorch_time / 1000)

            # Efficiency relative to the estimated peak
            custom_eff = custom_bw / peak_bandwidth * 100
            pytorch_eff = pytorch_bw / peak_bandwidth * 100

            print(f"{str(size):<12} {elements:<10} {memory_mb:<8.1f} {custom_time:<12.3f} {pytorch_time:<13.3f} "
                  f"{custom_bw:<10.1f} {pytorch_bw:<11.1f} {custom_eff:<10.1f}% {pytorch_eff:<10.1f}%")

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"{str(size):<12} {'OOM':<10}")
                # Release cached blocks so later cells are not starved.
                torch.cuda.empty_cache()
                break
            raise

    print(f"\n=== Memory Access Pattern Analysis ===")
    print("Roll operation characteristics:")
    print("✓ Read: Sequential access to input tensor")
    print("✓ Write: Scattered access to output tensor (depends on shift pattern)")
    print("✓ Memory traffic: 2x tensor size (read input + write output)")
    print("✓ No intermediate storage needed (constant memory for metadata)")

    print(f"\nOptimization opportunities:")
    print("• Coalesced memory access depends on shift alignment")
    print("• Large shifts may cause poor cache locality")
    print("• Small tensors may be bandwidth-limited rather than compute-limited")

def compare_shift_patterns(iterations=20):
    """Compare performance with different shift patterns on a 1000x1000 CUDA tensor.

    Each pattern is timed with CUDA events. The mean time per call and the
    speedup (PyTorch time / custom time) are reported.

    Args:
        iterations: timed calls per implementation (must be >= 1).

    Raises:
        AssertionError: if the custom result differs from ``torch.roll``.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    print("=== Shift Pattern Analysis ===")

    size = (1000, 1000)
    x = torch.randn(size, device='cuda', dtype=torch.float32)

    # Different shift patterns: (shifts, dims, description)
    patterns = [
        ([1, 1], [0, 1], "Small shifts (cache-friendly)"),
        ([100, 100], [0, 1], "Medium shifts"),
        ([500, 500], [0, 1], "Large shifts (worst case)"),
        ([0, 100], [0, 1], "Single dimension shift"),
        ([13, 17], [0, 1], "Prime number shifts"),
    ]

    print(f"{'Pattern':<30} {'Custom (ms)':<12} {'PyTorch (ms)':<13} {'Speedup':<8}")
    print("-" * 65)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    def mean_ms(roll_fn, shifts, dims):
        """Return (mean milliseconds per call, last output) for ``roll_fn``."""
        start.record()
        for _ in range(iterations):
            output = roll_fn(x, shifts, dims)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iterations, output

    for shifts, dims, description in patterns:
        for _ in range(3):  # warmup
            custom_ops.roll(x, shifts, dims)
            torch.roll(x, shifts, dims)
        torch.cuda.synchronize()

        custom_time, result_custom = mean_ms(custom_ops.roll, shifts, dims)
        pytorch_time, result_pytorch = mean_ms(torch.roll, shifts, dims)

        # A fast but wrong result is not a speedup: verify before reporting.
        assert torch.equal(result_custom, result_pytorch), f"Results don't match for {description}"

        speedup = pytorch_time / custom_time
        print(f"{description:<30} {custom_time:<12.3f} {pytorch_time:<13.3f} {speedup:<8.2f}x")

# Bandwidth table first, then the shift-pattern comparison, separated by a '=' rule.
analyze_memory_patterns()
print(f"\n{'=' * 60}\n")
compare_shift_patterns()

## Optimization Summary and Recommendations

In [None]:
def optimization_summary():
    """Print the fixed optimization report: features, findings, recommendations and next steps.

    Purely informational: every line is a constant, and nothing is measured here.
    """
    report_lines = [
        "=== Optimization Summary ===",
        "",
        "🔧 Current Implementation Features:",
        "   ✓ Constant memory for dimension metadata (fast access)",
        "   ✓ Direct memory access without shared memory layer",
        "   ✓ Branch-free kernel execution",
        "   ✓ Efficient dimension filtering (skip zero shifts)",
        "   ✓ Support for up to 8 dimensions",
        "",
        "📊 Key Findings from Profiling:",
        "   • PyTorch uses tensor slicing/concatenation rather than custom kernels",
        "   • Our custom kernel is optimized for memory bandwidth",
        "   • Performance depends heavily on:",
        "     - Tensor size (small tensors: overhead dominates)",
        "     - Shift patterns (large shifts: poor cache locality)",
        "     - Memory coalescing (depends on dimension strides)",
        "",
        "🚀 Optimization Recommendations:",
        "",
        "1. **Hybrid Approach**: Consider switching to PyTorch's tensor operations",
        "   for smaller tensors where kernel launch overhead dominates",
        "",
        "2. **Memory Coalescing**: For better coalescing:",
        "   • Process contiguous dimensions first",
        "   • Consider transpose operations for better access patterns",
        "",
        "3. **Kernel Optimization**:",
        "   • Current kernel uses 256 threads/block - good for most GPUs",
        "   • Consider dynamic block sizing based on tensor dimensions",
        "   • Investigate vectorized loads (float4) for better bandwidth",
        "",
        "4. **When to Use Custom vs PyTorch**:",
        "   • Use custom kernel for: Large tensors (>1M elements)",
        "   • Use PyTorch's roll for: Small tensors, complex multi-dim patterns",
        "",
        "🎯 Next Steps for Further Optimization:",
        "   1. Implement vectorized memory access (float2/float4)",
        "   2. Add tensor size-based dispatch (custom vs PyTorch)",
        "   3. Optimize for specific shift patterns",
        "   4. Consider texture memory for read-only input",
        "   5. Implement occupancy-driven block size selection",
        "",
        "📋 Usage Guidelines:",
        "   • For research/prototyping: Use PyTorch's roll (reliable, well-tested)",
        "   • For production with large tensors: Use custom kernel",
        "   • For mixed workloads: Implement size-based dispatch",
        "   • Always profile your specific use case!",
    ]
    # One newline-joined write; print() adds the final newline, matching one print per line.
    print("\n".join(report_lines))

def final_benchmark_summary(iterations=10):
    """Run a final comprehensive benchmark and show results with a recommendation.

    Times both roll implementations with CUDA events on three 2D sizes. The
    recommendation comes from the speedup (PyTorch time / custom time). A
    custom result that does not match ``torch.roll`` is always reported as a
    mismatch, regardless of speed.

    Args:
        iterations: timed calls per implementation (must be >= 1).
    """
    if not torch.cuda.is_available():
        print("CUDA not available for final benchmark")
        return

    print("=== Final Benchmark Summary ===")

    # Representative test cases
    test_cases = [
        ((100, 100), "Small tensor"),
        ((1000, 1000), "Medium tensor"),
        ((3000, 3000), "Large tensor"),
    ]

    print(f"{'Case':<15} {'Custom (ms)':<12} {'PyTorch (ms)':<13} {'Speedup':<8} {'Recommendation':<20}")
    print("-" * 75)

    # Events are reused; each measurement synchronizes before reading them.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for size, description in test_cases:
        try:
            x = torch.randn(size, device='cuda', dtype=torch.float32)
            shifts = [size[0] // 10, size[1] // 10]
            dims = [0, 1]

            for _ in range(3):  # warmup
                custom_ops.roll(x, shifts, dims)
                torch.roll(x, shifts, dims)
            torch.cuda.synchronize()

            # Custom
            start.record()
            for _ in range(iterations):
                result_custom = custom_ops.roll(x, shifts, dims)
            end.record()
            torch.cuda.synchronize()
            custom_time = start.elapsed_time(end) / iterations

            # PyTorch
            start.record()
            for _ in range(iterations):
                result_pytorch = torch.roll(x, shifts, dims)
            end.record()
            torch.cuda.synchronize()
            pytorch_time = start.elapsed_time(end) / iterations

            speedup = pytorch_time / custom_time

            # A wrong result disqualifies the custom kernel regardless of speed.
            if not torch.equal(result_custom, result_pytorch):
                recommendation = "❌ Mismatch - use PyTorch"
            elif speedup > 1.2:
                recommendation = "✅ Use Custom"
            elif speedup > 0.8:
                recommendation = "⚖️ Either works"
            else:
                recommendation = "❌ Use PyTorch"

            print(f"{description:<15} {custom_time:<12.3f} {pytorch_time:<13.3f} {speedup:<8.2f}x {recommendation:<20}")

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"{description:<15} {'OOM':<12}")
                # Release cached blocks so the remaining cases can still run.
                torch.cuda.empty_cache()
            else:
                raise

    print()
    print("✅ Profiling complete! Use the analysis above to guide your optimization decisions.")

# Static optimization report, a '=' rule, then the measured final benchmark table.
optimization_summary()
print(f"\n{'=' * 60}\n")
final_benchmark_summary()