# CUDA Transformer Attention - Google Colab Execution

This notebook compiles and tests the CUDA attention kernels on Google Colab's GPU.

## Setup Steps:
1. **Enable GPU**: Runtime → Change runtime type → GPU (T4 or better)
2. **Clone repository** (or upload files)
3. **Compile CUDA extension**
4. **Run tests and benchmarks**

## Step 1: Check GPU Availability

In [None]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GiB")
else:
    print("⚠️ WARNING: GPU not available! Please enable GPU in Runtime settings.")

## Step 2: Clone Repository

If you have pushed the code to GitHub, uncomment Option 1 and replace `<your_username>` with your GitHub username. If you uploaded the files instead, uncomment Option 2.

In [None]:
# Option 1: Clone from GitHub
# !git clone https://github.com/<your_username>/cuda-transformer-attention.git
# %cd cuda-transformer-attention

# Option 2: If files are already uploaded, just navigate to the directory
# %cd cuda-transformer-attention

# For testing, let's check if we're in the right directory
!pwd
!ls -la

## Step 3: Install Dependencies

In [None]:
# Install required packages (ninja is the build backend that torch.utils.cpp_extension.load relies on)
!pip install pytest ninja -q

## Step 4: Compile CUDA Extension

This step uses `torch.utils.cpp_extension.load` to JIT-compile the C++/CUDA sources and import the result as a Python module.
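
Relative paths in `sources` are resolved against the current working directory, so running the compile cell from the wrong directory is an easy way to get a confusing build error. As a minimal sketch, a pre-flight check like the one below could be run first. The helper name `missing_sources` is hypothetical; it is not part of the repository or of PyTorch.

```python
from pathlib import Path


def missing_sources(sources, root="."):
    """Return the entries of `sources` that are not regular files under `root`."""
    root = Path(root)
    return [src for src in sources if not (root / src).is_file()]


# Same file list as the compile cell
SOURCES = [
    "cuda/attention_qk.cu",
    "cuda/attention_softmax.cu",
    "cuda/attention_av.cu",
    "cuda/attention_fused.cu",
    "cpp/attention_binding.cpp",
]

missing = missing_sources(SOURCES)
if missing:
    print("Missing source files (wrong working directory?):")
    for src in missing:
        print(f"  {src}")
else:
    print("All source files found.")
```

If anything is reported missing, adjust the `%cd` or `os.chdir` call before compiling.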

In [None]:
from torch.utils.cpp_extension import load
import os

# Ensure we're in the project directory
# os.chdir('/content/cuda-transformer-attention')  # Adjust path as needed

print("Compiling CUDA extension... This may take 2-5 minutes.")
print("="*70)

cuda_attn = load(
    name="cuda_attn",
    sources=[
        "cuda/attention_qk.cu",
        "cuda/attention_softmax.cu",
        "cuda/attention_av.cu",
        "cuda/attention_fused.cu",
        "cpp/attention_binding.cpp"
    ],
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",
        "-std=c++17"  # PyTorch 2.x headers require C++17
    ],
    verbose=True
)

print("="*70)
print("✅ Compilation successful!")
print(f"Extension module: {cuda_attn}")

## Step 5: Basic Functionality Test
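
The cell below compares each kernel mode against `reference_attention`, whose source is not shown in this notebook. Assuming it computes standard scaled dot-product attention, `softmax(Q K^T / sqrt(D)) V`, the following dependency-free sketch shows the same computation for a single head. The function name `attention_single_head` is hypothetical, and the causal branch assumes queries and keys share the same positions.

```python
import math


def attention_single_head(Q, K, V, is_causal=False):
    """Scaled dot-product attention for one head; Q, K, V are lists of rows."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        # With causal masking, query i only sees keys 0..i
        n_keys = i + 1 if is_causal else len(K)
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(n_keys)]
        # Subtract the row maximum before exponentiating for numerical stability
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * V[j][c] for j, w in enumerate(weights)) / total
            for c in range(len(V[0]))
        ])
    return out


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
causal = attention_single_head(Q, K, V, is_causal=True)
print(causal[0])  # the first query attends only to key 0, so this equals V[0]
```

The CUDA kernels are expected to match this up to floating-point error, which is what `max_diff` and `mean_diff` report.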

In [None]:
import sys
sys.path.insert(0, '.')

from python.reference_attention import reference_attention
from python.cuda_attention import cuda_attention_forward

# Create test inputs
B, H, S, D = 2, 4, 128, 64
Q = torch.randn(B, H, S, D, device='cuda')
K = torch.randn(B, H, S, D, device='cuda')
V = torch.randn(B, H, S, D, device='cuda')

print(f"Test configuration: B={B}, H={H}, S={S}, D={D}")
print("\nTesting different kernel modes...\n")

# Test reference
output_ref = reference_attention(Q, K, V)
print(f"✓ Reference:    shape={output_ref.shape}")

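# Test CUDA kernels
# Loose tolerance (an assumption, adjust as needed): the kernels are built with --use_fast_math
tol = 1e-3
all_ok = True
for mode in ['naive', 'tiled', 'fused']:
    output = cuda_attention_forward(Q, K, V, mode=mode)

    # Check correctness against the reference
    max_diff = (output - output_ref).abs().max().item()
    mean_diff = (output - output_ref).abs().mean().item()
    ok = max_diff < tol
    all_ok = all_ok and ok

    print(f"{'✓' if ok else '✗'} {mode.capitalize():10s}: shape={output.shape}, "
          f"max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}")

if all_ok:
    print("\n✅ All modes match the reference within tolerance.")
else:
    print(f"\n❌ At least one mode exceeds the tolerance ({tol:g}).")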

## Step 6: Test Causal Masking

In [None]:
print("Testing causal (autoregressive) attention...\n")

output_ref_causal = reference_attention(Q, K, V, is_causal=True)
print(f"✓ Reference (causal):  shape={output_ref_causal.shape}")

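# Loose tolerance (an assumption, adjust as needed): the kernels are built with --use_fast_math
causal_tol = 1e-3
causal_ok = True
for mode in ['naive', 'tiled', 'fused']:
    output = cuda_attention_forward(Q, K, V, mode=mode, is_causal=True)
    max_diff = (output - output_ref_causal).abs().max().item()
    mean_diff = (output - output_ref_causal).abs().mean().item()
    ok = max_diff < causal_tol
    causal_ok = causal_ok and ok

    print(f"{'✓' if ok else '✗'} {mode.capitalize():10s} (causal): max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}")

if causal_ok:
    print("\n✅ Causal masking matches the reference within tolerance.")
else:
    print(f"\n❌ At least one causal mode exceeds the tolerance ({causal_tol:g}).")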

## Step 7: Performance Benchmark
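
Raw milliseconds are hard to compare across configurations. One common way to normalize them is to convert each timing into achieved throughput. The sketch below counts only the two matrix multiplications (`Q K^T` and the attention-weighted sum over `V`) and ignores the softmax, so it is an approximation. The helper names `attention_flops` and `achieved_tflops` are hypothetical.

```python
def attention_flops(B, H, S, D):
    """Approximate FLOPs of one forward pass, counting only the two matmuls."""
    # Each matmul performs B*H*S*S*D multiply-adds, i.e. 2*B*H*S*S*D FLOPs
    return 4 * B * H * S * S * D


def achieved_tflops(B, H, S, D, time_ms):
    """Convert a runtime in milliseconds into TFLOP/s."""
    return attention_flops(B, H, S, D) / (time_ms * 1e-3) / 1e12


# "Large" configuration from the benchmark cell
print(f"{attention_flops(2, 8, 1024, 64) / 1e9:.2f} GFLOP per forward pass")
print(f"{achieved_tflops(2, 8, 1024, 64, time_ms=1.0):.2f} TFLOP/s at 1 ms")
```

Passing the measured `mean_t` as `time_ms` gives a figure that can be compared against the GPU's theoretical peak.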

In [None]:
def benchmark(func, *args, warmup=5, repeat=20, **kwargs):
    """Return (mean, min, max) runtime in milliseconds, timed with CUDA events."""
    # Warmup
    for _ in range(warmup):
        func(*args, **kwargs)
    torch.cuda.synchronize()
    
    # Timing
    times = []
    for _ in range(repeat):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        
        start.record()
        func(*args, **kwargs)
        end.record()
        
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    
    return sum(times) / len(times), min(times), max(times)

# Benchmark configuration
configs = [
    (4, 8, 256, 64, "Small"),
    (4, 8, 512, 64, "Medium"),
    (2, 8, 1024, 64, "Large"),
]

print("\n" + "="*80)
print(" "*25 + "PERFORMANCE BENCHMARK")
print("="*80)

for B, H, S, D, name in configs:
    print(f"\n{name}: B={B}, H={H}, S={S}, D={D}")
    print("-" * 80)
    
    # Create inputs
    Q = torch.randn(B, H, S, D, device='cuda')
    K = torch.randn(B, H, S, D, device='cuda')
    V = torch.randn(B, H, S, D, device='cuda')
    
    # Benchmark reference
    mean_t, min_t, max_t = benchmark(reference_attention, Q, K, V)
    ref_time = mean_t
    print(f"Reference:  {mean_t:7.3f} ms  (min: {min_t:6.3f}, max: {max_t:6.3f})")
    
    # Benchmark CUDA modes
    for mode in ['naive', 'tiled', 'fused']:
        mean_t, min_t, max_t = benchmark(cuda_attention_forward, Q, K, V, mode=mode)
        speedup = ref_time / mean_t if mean_t > 0 else 0
        print(f"{mode.capitalize():10s}:  {mean_t:7.3f} ms  (min: {min_t:6.3f}, max: {max_t:6.3f})  "
              f"Speedup: {speedup:.2f}x")

print("\n" + "="*80)

## Step 8: Run Test Suite

In [None]:
# Run correctness tests
print("Running correctness tests...\n")
!pytest tests/test_correctness.py -v --tb=short

print("\n" + "="*80)
print("Running mask tests...\n")
!pytest tests/test_masks.py -v --tb=short

## Step 9: Memory Usage Comparison
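
Before measuring, a back-of-envelope estimate helps set expectations. A kernel that materializes the full attention matrix needs at least `B * H * S * S` elements for the scores, while the output needs only `B * H * S * D`. The sketch below assumes float32 inputs (4 bytes per element, the default for `torch.randn`); the helper name `attention_memory_mib` is hypothetical.

```python
def attention_memory_mib(B, H, S, D, bytes_per_elem=4):
    """Estimate the size in MiB of the S x S score matrix and of the output tensor."""
    mib = 1024 ** 2
    scores = B * H * S * S * bytes_per_elem / mib
    output = B * H * S * D * bytes_per_elem / mib
    return scores, output


# Configuration used in the cell below
scores_mib, output_mib = attention_memory_mib(2, 8, 1024, 64)
print(f"score matrix: {scores_mib:.1f} MiB, output: {output_mib:.1f} MiB")
# score matrix: 64.0 MiB, output: 4.0 MiB
```

The score matrix grows quadratically with `S`, so the gap between a materializing kernel and a fused one should widen as sequences get longer.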

In [None]:
print("Memory Usage Comparison\n" + "="*80)

B, H, S, D = 2, 8, 1024, 64
Q = torch.randn(B, H, S, D, device='cuda')
K = torch.randn(B, H, S, D, device='cuda')
V = torch.randn(B, H, S, D, device='cuda')

print(f"Configuration: B={B}, H={H}, S={S}, D={D}\n")

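for mode in ['naive', 'tiled', 'fused']:
    # Start from a clean baseline: finish pending work, then reset the peak counter
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    start_mem = torch.cuda.memory_allocated()
    output = cuda_attention_forward(Q, K, V, mode=mode)
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated()

    # Only memory allocated through PyTorch's caching allocator is counted
    mem_used = (peak_mem - start_mem) / (1024**2)  # MiB, includes the output tensor
    print(f"{mode.capitalize():10s}: {mem_used:8.2f} MiB")

    del output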

print("\n" + "="*80)
print("Note: the fused kernel should use less memory for long sequences")
print("because it avoids materializing the full S x S attention matrix.")

## Summary

This notebook demonstrated:
1. ✅ Successful compilation of CUDA kernels on Google Colab
2. ✅ Correctness validation against PyTorch reference
3. ✅ Causal masking support
4. ✅ Performance benchmarking
5. ✅ Memory usage comparison

### Key Takeaways:
- **Naive kernel**: Simpler but slower, good for debugging
- **Tiled kernel**: Better performance through shared memory optimization
- **Fused kernel**: FlashAttention-style, memory-efficient for long sequences
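
To illustrate why a FlashAttention-style kernel does not need the full score matrix, the sketch below computes a softmax-weighted sum one block of scores at a time, using the online softmax recurrence: it keeps a running maximum, a running normalizer, and a running accumulator, and rescales the latter two whenever the maximum changes. This is a scalar, pure-Python illustration of the general technique and not the repository's kernel. The function names are hypothetical, and it assumes at least one finite score.

```python
import math


def softmax_weighted_sum(scores, values):
    """Two-pass baseline: needs all scores at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def streaming_softmax_weighted_sum(scores, values, block=2):
    """Same result, but processes `block` scores at a time with O(1) state."""
    running_max = -math.inf
    normalizer = 0.0
    accumulator = 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Rescale earlier partial sums to the new maximum (exp(-inf) is 0.0)
        rescale = math.exp(running_max - new_max)
        exps = [math.exp(s - new_max) for s in s_blk]
        normalizer = normalizer * rescale + sum(exps)
        accumulator = accumulator * rescale + sum(e * v for e, v in zip(exps, v_blk))
        running_max = new_max
    return accumulator / normalizer


scores = [0.5, 2.0, -1.0, 3.5, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
diff = abs(softmax_weighted_sum(scores, values)
           - streaming_softmax_weighted_sum(scores, values))
print(f"difference between two-pass and streaming: {diff:.2e}")
```

In a real kernel the values are rows of `V` rather than scalars and the blocks live in shared memory, but the per-row state is the same three quantities.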

### Next Steps:
- Experiment with different configurations
- Test on your own data
- Compare with PyTorch's native attention (`torch.nn.functional.scaled_dot_product_attention`)
- Profile with Nsight Systems for detailed analysis