<a href="https://colab.research.google.com/github/kobejean/torch-ops/blob/main/test_roll_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Custom PyTorch Roll Operation Test

This notebook tests `custom_ops.roll`, a custom implementation of `torch.roll` with both CPU and CUDA support, by comparing its results against PyTorch's built-in `torch.roll`.
The implementation includes optimized CUDA kernels inspired by TensorFlow's roll operation.

## Setup Environment

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to this runtime
import torch

print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    # Device details are only queried when a GPU is actually present.
    print(
        f"CUDA device: {torch.cuda.get_device_name(0)}",
        f"CUDA version: {torch.version.cuda}",
        sep="\n",
    )

## Clone Repository and Build Extension

In [None]:
# Clone the repository, then make the checkout the notebook's working
# directory so that later cells run relative to the repository root.
!git clone https://github.com/kobejean/torch-ops.git
%cd torch-ops

In [None]:
# Install the extension
!pip install -e .

## Import and Basic Tests

In [None]:
import torch
import custom_ops  # Import the custom_ops module directly
import numpy as np
import time

# List everything the extension module exposes, as a quick check that the import worked.
print("Custom ops available:", dir(custom_ops), sep="\n")

## Test CPU Roll Operation

In [None]:
def test_cpu_roll():
    print("=== CPU Roll Tests ===")
    
    # Test 1: Basic 2D roll
    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    print(f"Original tensor:\n{x}")
    
    # Roll along dimension 0
    result = custom_ops.roll(x, [1], [0])
    expected = torch.roll(x, [1], [0])
    print(f"\nRoll by 1 along dim 0:")
    print(f"Custom: \n{result}")
    print(f"PyTorch:\n{expected}")
    print(f"Match: {torch.allclose(result, expected)}")
    
    # Test 2: Multi-dimensional roll
    result = custom_ops.roll(x, [1, 2], [0, 1])
    expected = torch.roll(x, [1, 2], [0, 1])
    print(f"\nRoll by [1, 2] along dims [0, 1]:")
    print(f"Custom: \n{result}")
    print(f"PyTorch:\n{expected}")
    print(f"Match: {torch.allclose(result, expected)}")
    
    # Test 3: Negative shifts
    result = custom_ops.roll(x, [-1, -2], [0, 1])
    expected = torch.roll(x, [-1, -2], [0, 1])
    print(f"\nRoll by [-1, -2] along dims [0, 1]:")
    print(f"Custom: \n{result}")
    print(f"PyTorch:\n{expected}")
    print(f"Match: {torch.allclose(result, expected)}")
    
    # Test 4: Zero shifts (should be no-op)
    result = custom_ops.roll(x, [0, 0], [0, 1])
    print(f"\nZero shift test:")
    print(f"Match original: {torch.allclose(result, x)}")
    
    return True

# Run the CPU checks now so their printed comparisons appear as this cell's output.
test_cpu_roll()

## Test CUDA Roll Operation (if available)

In [None]:
def test_cuda_roll():
    """Compare ``custom_ops.roll`` with ``torch.roll`` on CUDA tensors.

    The checks are skipped when no CUDA device is available. Otherwise each
    case prints both results and a ``Match`` line; mismatches are collected and
    reported by raising after all cases have been printed.

    Returns:
        bool: ``False`` when CUDA is unavailable and the checks were skipped,
        ``True`` when every case matches the reference.

    Raises:
        AssertionError: if any case disagrees with the reference result.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, skipping CUDA tests")
        return False

    print("=== CUDA Roll Tests ===")

    x_cpu = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    x_cuda = x_cpu.cuda()
    print(f"Original tensor:\n{x_cpu}")

    # (heading, shifts, dims); torch.roll on the same CUDA tensor is the reference.
    reference_cases = [
        ("CUDA Roll by 1 along dim 0", [1], [0]),
        ("CUDA Roll by [1, 2] along dims [0, 1]", [1, 2], [0, 1]),
    ]

    failed = []
    for heading, shifts, dims in reference_cases:
        result_cuda = custom_ops.roll(x_cuda, shifts, dims)
        expected_cuda = torch.roll(x_cuda, shifts, dims)
        matches = torch.allclose(result_cuda, expected_cuda)
        print(f"\n{heading}:")
        # Results are copied to the host only for display.
        print(f"Custom: \n{result_cuda.cpu()}")
        print(f"PyTorch:\n{expected_cuda.cpu()}")
        print(f"Match: {matches}")
        if not matches:
            failed.append(heading)

    # A larger random tensor, checked without printing its contents.
    large_tensor = torch.randn(100, 100, device='cuda')
    result_large = custom_ops.roll(large_tensor, [10, 20], [0, 1])
    expected_large = torch.roll(large_tensor, [10, 20], [0, 1])
    large_matches = torch.allclose(result_large, expected_large)

    print(f"\nLarge tensor (100x100) test:")
    print(f"Match: {large_matches}")
    print(f"Max difference: {torch.max(torch.abs(result_large - expected_large)).item()}")
    if not large_matches:
        failed.append("Large tensor (100x100) test")

    if failed:
        raise AssertionError(
            "custom_ops.roll disagrees with the reference on CUDA: " + "; ".join(failed)
        )
    return True

# Run the CUDA checks now; the function itself skips them when no GPU is available.
test_cuda_roll()

## Performance Benchmark

In [None]:
def _time_roll_calls(roll_fn, x, shifts, dims, iterations, synchronize):
    """Time ``iterations`` consecutive calls of ``roll_fn(x, shifts, dims)``.

    Returns a ``(elapsed_seconds, last_result)`` tuple. When ``synchronize`` is
    true, ``torch.cuda.synchronize()`` is called before the clock is stopped.
    """
    result = None
    start = time.perf_counter()
    for _ in range(iterations):
        result = roll_fn(x, shifts, dims)
    if synchronize:
        torch.cuda.synchronize()
    return time.perf_counter() - start, result


def _benchmark_on_device(sizes, device, warmup, iterations, mismatch_label):
    """Benchmark custom vs. built-in roll for each size on one device and verify the results."""
    on_cuda = device == 'cuda'
    shifts = [10, 20]
    dims = [0, 1]

    for size in sizes:
        x = torch.randn(size, dtype=torch.float32, device=device)

        # Warmup, so one-time costs are not attributed to the timed loops.
        for _ in range(warmup):
            _ = custom_ops.roll(x, shifts, dims)
            _ = torch.roll(x, shifts, dims)
        if on_cuda:
            torch.cuda.synchronize()

        custom_time, result_custom = _time_roll_calls(
            custom_ops.roll, x, shifts, dims, iterations, on_cuda
        )
        pytorch_time, result_pytorch = _time_roll_calls(
            torch.roll, x, shifts, dims, iterations, on_cuda
        )

        # Guard the division: a zero reading from the clock must not abort the benchmark.
        ratio = pytorch_time / custom_time if custom_time > 0 else float("inf")
        print(f"Size {size}: Custom={custom_time:.4f}s, PyTorch={pytorch_time:.4f}s, Ratio={ratio:.2f}x")

        # Verify correctness
        if not torch.allclose(result_custom, result_pytorch):
            raise AssertionError(f"{mismatch_label} for size {size}")


def benchmark_roll(sizes=((100, 100), (500, 500), (1000, 1000)),
                   cpu_iterations=50, cuda_iterations=100):
    """Time ``custom_ops.roll`` against ``torch.roll`` and verify both agree.

    For every shape in ``sizes`` a random float32 tensor is rolled by
    ``[10, 20]`` along dims ``[0, 1]``. One line per shape is printed with the
    total time of each implementation and the ratio PyTorch time / custom time.
    The CPU pass always runs; the CUDA pass runs only when CUDA is available.

    Args:
        sizes: iterable of 2-D tensor shapes to benchmark.
        cpu_iterations: number of timed calls per implementation on CPU.
        cuda_iterations: number of timed calls per implementation on CUDA.

    Returns:
        None. Results are printed.

    Raises:
        ValueError: if an iteration count is smaller than 1.
        AssertionError: if the two implementations disagree for any shape.
    """
    if cpu_iterations < 1 or cuda_iterations < 1:
        raise ValueError("iteration counts must be at least 1")

    # Materialise once so a generator argument can feed both passes.
    sizes = list(sizes)

    print("=== Performance Benchmark ===")

    print("\nCPU Benchmark:")
    _benchmark_on_device(sizes, 'cpu', 5, cpu_iterations, "Mismatch")

    if torch.cuda.is_available():
        print("\nCUDA Benchmark:")
        _benchmark_on_device(sizes, 'cuda', 10, cuda_iterations, "CUDA mismatch")

# Run the benchmark now; timings are printed as this cell's output.
benchmark_roll()

## Edge Case Tests

In [None]:
def test_edge_cases():
    print("=== Edge Case Tests ===")
    
    # Test 1: 1D tensor
    x1d = torch.arange(5, dtype=torch.float32)
    result = custom_ops.roll(x1d, [2], [0])
    expected = torch.roll(x1d, [2], [0])
    print(f"1D tensor test: {torch.allclose(result, expected)}")
    
    # Test 2: 3D tensor
    x3d = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    result = custom_ops.roll(x3d, [1, 1, 2], [0, 1, 2])
    expected = torch.roll(x3d, [1, 1, 2], [0, 1, 2])
    print(f"3D tensor test: {torch.allclose(result, expected)}")
    
    # Test 3: Empty tensor
    empty = torch.empty((0, 5), dtype=torch.float32)
    result = custom_ops.roll(empty, [1], [1])
    expected = torch.roll(empty, [1], [1])
    print(f"Empty tensor test: {result.shape == expected.shape and torch.allclose(result, expected)}")
    
    # Test 4: Large shifts (wrapping)
    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    result = custom_ops.roll(x, [10], [0])  # 10 > 3, should wrap
    expected = torch.roll(x, [10], [0])
    print(f"Large shift test: {torch.allclose(result, expected)}")
    
    # Test 5: Negative dimension indices
    result = custom_ops.roll(x, [1], [-1])  # -1 should be dimension 1
    expected = torch.roll(x, [1], [-1])
    print(f"Negative dim test: {torch.allclose(result, expected)}")
    
    # Test 6: Different dtypes
    for dtype in [torch.int32, torch.int64, torch.float64]:
        x_dtype = x.to(dtype)
        result = custom_ops.roll(x_dtype, [1], [0])
        expected = torch.roll(x_dtype, [1], [0])
        print(f"{dtype} test: {torch.allclose(result, expected)}")
    
    print("All edge case tests completed!")

# Run the edge-case checks now so their per-case results appear as this cell's output.
test_edge_cases()

## Summary

In [None]:
# Closing banner followed by the feature list, emitted one entry per line.
# NOTE(review): this cell prints unconditionally; it does not inspect the
# results of the test cells.
print(
    "🎉 All tests completed successfully!",
    "\nCustom torch.roll implementation features:",
    "✅ CPU and CUDA support",
    "✅ Multi-dimensional rolling",
    "✅ Negative shifts and dimensions",
    "✅ Multiple data types (float32, float64, int32, int64)",
    "✅ Optimized CUDA kernel with branch-free execution",
    "✅ Edge case handling (empty tensors, large shifts, etc.)",
    "✅ Compatible with PyTorch's torch.roll interface",
    sep="\n",
)