# CUDA Transformer Attention - Google Colab

⚠️ **FIRST: Enable a GPU!**
- Runtime → Change runtime type → select a GPU hardware accelerator (e.g. T4 GPU) → Save

Then run all cells.

In [None]:
# Verify that PyTorch can see a CUDA GPU before doing anything else; every
# later cell needs one.
import torch

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    raise RuntimeError("❌ GPU not enabled! Runtime → Change runtime type → GPU")
print(f"✅ GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Clone repo (only if not already cloned)
import os
if not os.path.exists('/content/cuda-transformer-attention'):
    !git clone https://github.com/isahan78/cuda-transformer-attention.git /content/cuda-transformer-attention
else:
    print("Repository already cloned")

%cd /content/cuda-transformer-attention
!pwd

In [None]:
!pip install pytest ninja -q
print("‚úÖ Dependencies installed")

In [None]:
# JIT-compile the CUDA kernels and the C++ binding into the `cuda_attn` module.
import os
import shutil

import torch
from torch.utils.cpp_extension import load

# Build only for the GPU this runtime actually has. A hard-coded arch list
# compiles several targets (slow) and, without "+PTX", yields a module that
# cannot run on a GPU whose compute capability is not in the list.
major, minor = torch.cuda.get_device_capability(0)
os.environ['TORCH_CUDA_ARCH_LIST'] = f"{major}.{minor}"

# Match the C++ ABI this PyTorch build was compiled with; hard-coding 0 breaks
# loading (undefined symbols) on PyTorch builds that use the CXX11 ABI.
abi_flag = f"-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}"

# Remove the whole torch extension cache so the module is rebuilt from the
# current sources instead of reusing a stale build.
cache = os.path.expanduser('~/.cache/torch_extensions')
if os.path.exists(cache):
    shutil.rmtree(cache, ignore_errors=True)

print("🔨 Compiling CUDA extension... (~2-5 min)\n" + "="*70)

# Source paths are relative: this assumes the current working directory is
# the repository root.
cuda_attn = load(
    name="cuda_attn",
    sources=[
        "cuda/attention_qk.cu",
        "cuda/attention_softmax.cu",
        "cuda/attention_av.cu",
        "cuda/attention_fused.cu",
        "cpp/attention_binding.cpp",
    ],
    extra_cuda_cflags=["-O2", "-std=c++17", abi_flag],
    extra_cflags=["-O2", "-std=c++17", abi_flag],
    verbose=True,
)

print("="*70 + "\n✅ Compilation successful!")

In [None]:
# Compare every CUDA kernel variant against the PyTorch reference
# implementation and fail loudly if any of them disagrees.
import sys

import torch

sys.path.insert(0, '.')  # make the repo's `python` package importable

from python.reference_attention import reference_attention
from python.cuda_attention import cuda_attention_forward

# Maximum allowed elementwise |output - reference|. Generous for float32;
# loosen it if a kernel's different accumulation order fails marginally.
ATOL = 1e-3

torch.manual_seed(0)  # reproducible inputs across runs
B, H, S, D = 2, 4, 128, 64
Q = torch.randn(B, H, S, D, device='cuda')
K = torch.randn(B, H, S, D, device='cuda')
V = torch.randn(B, H, S, D, device='cuda')

print(f"Test: B={B}, H={H}, S={S}, D={D}\n")

output_ref = reference_attention(Q, K, V)
print(f"✓ Reference: {output_ref.shape}")

failed = []
for mode in ['naive', 'tiled', 'fused']:
    output = cuda_attention_forward(Q, K, V, mode=mode)
    diff = (output - output_ref).abs().max().item()
    # `diff <= ATOL` is False for NaN, so NaN outputs are reported as failures.
    ok = output.shape == output_ref.shape and diff <= ATOL
    mark = "✓" if ok else "✗"
    print(f"{mark} {mode.capitalize():10s}: {output.shape}, diff={diff:.2e}")
    if not ok:
        failed.append(mode)

if failed:
    raise AssertionError(
        f"Kernels disagree with the reference (atol={ATOL}): {', '.join(failed)}"
    )
print("\n✅ All kernels working!")

In [None]:
def bench(func, *args, **kwargs):
    """Time ``func(*args, **kwargs)`` with CUDA events; return the mean in ms.

    The call is first run 5 times untimed as a warm-up. It is then run 20
    times, each run bracketed by a fresh pair of timing events. The result is
    the arithmetic mean of the 20 ``elapsed_time`` readings, in milliseconds.
    """

    def timed_run():
        # One start/stop event pair per run, recorded around a single call.
        start, stop = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        func(*args, **kwargs)
        stop.record()
        # Wait for the device to finish so elapsed_time() can be read.
        torch.cuda.synchronize()
        return start.elapsed_time(stop)

    for _ in range(5):
        func(*args, **kwargs)
    torch.cuda.synchronize()

    samples = [timed_run() for _ in range(20)]
    return sum(samples) / len(samples)

# Problem size used for timing.
B, H, S, D = 4, 8, 512, 64
Q, K, V = (torch.randn(B, H, S, D, device='cuda') for _ in range(3))

rule = "=" * 70
print(f"\nBenchmark: B={B}, H={H}, S={S}, D={D}")
print(rule)

# Baseline: the pure-PyTorch reference implementation.
ref = bench(reference_attention, Q, K, V)
print(f"Reference:  {ref:7.3f} ms")

# Each CUDA variant is reported alongside its speedup over the baseline.
for mode in ('naive', 'tiled', 'fused'):
    t = bench(cuda_attention_forward, Q, K, V, mode=mode)
    print(f"{mode.capitalize():10s}:  {t:7.3f} ms  ({ref/t:.2f}x speedup)")

print(rule)

## ✅ Done!

All CUDA kernels compiled and tested successfully. You can now:
- Test with your own data
- Run full test suite: `!pytest tests/ -v`
- Experiment with different sequence lengths