# Dense vs Sparse Attention: Correctness and Performance

This notebook verifies the implementation of Block-Sparse Attention and compares its performance against standard Dense Attention.

## Goals
1. **Verify Correctness**: Check that the optimized CUDA/CPU kernels match a masked dense reference implementation to within floating-point tolerance.
2. **Benchmark Performance**: Measure the speedup of Sparse Attention for long sequences.
3. **Visualize Sparsity**: Plot the local, global, and mixed block-sparse attention patterns.

In [2]:
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Set project path - ADJUST THIS TO YOUR PATH
project_path = '/workspace/manav/dl_sys_project/'
os.chdir(project_path)

# Check GPU
print("Checking GPU availability...")
try:
    import subprocess
    gpu_info = subprocess.check_output(['nvidia-smi'], stderr=subprocess.STDOUT)
    print("✓ GPU Available")
except:
    print("✗ No GPU detected - will use CPU")

# Rebuild project
print("\nRebuilding project...")
!make clean
!make

# Setup paths
sys.path.insert(0, os.path.join(project_path, 'python'))
sys.path.insert(0, os.path.join(project_path, 'apps'))
import sys
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import needle as ndl
import needle.nn as nn
from needle.nn.nn_sparse_attention import BlockSparseMultiHeadAttention, BlockSparsePattern

# Set device
try:
    device = ndl.cuda()
    print("Using CUDA (GPU)")
except:
    device = ndl.cpu()
    print("Using CPU")

print(f"Needle backend: {ndl.backend_selection.BACKEND}")

Checking GPU availability...
✓ GPU Available

Rebuilding project...
rm -rf build python/needle/backend_ndarray/ndarray_backend*.so
-- The C compiler identification is GNU 13.3.0
-- The CXX compiler identification is GNU 13.3.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /usr/local/bin/python (found version "3.12.3") found components: Development Interpreter Development.Module Development.Embed 
-- Performing Test HAS_FLTO_AUTO
-- Performing Test HAS_FLTO_AUTO - Success
-- Found pybind11: /usr/local/lib/python3.12/dist-packages/pybind11/include (found version "3.0.1")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-

## 1. Correctness Verification

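We compare the output of the optimized Sparse Attention kernel against a dense attention computation that simulates sparsity by adding the module's block mask to the attention scores before the softmax. The check passes if the maximum absolute difference between the two outputs is below 1e-4.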

In [3]:
def verify_correctness(seq_len=128, block_size=16, num_heads=4, dim=32,
                       atol=1e-4, seed=None):
    """Compare the sparse attention forward pass with a masked dense reference.

    Random q/k/v tensors of shape (1, num_heads, seq_len, dim) are fed through
    a ``BlockSparseMultiHeadAttention`` module using the "local" pattern. The
    result is compared with attention computed step by step from the module's
    own ``matmul``/``softmax`` helpers. For the reference, the module's block
    mask is added to the scaled scores before the softmax.

    Args:
        seq_len: Sequence length of the random inputs.
        block_size: Block size of the sparse module.
        num_heads: Number of heads in the random inputs.
        dim: Per-head feature size; scores are scaled by 1/sqrt(dim).
        atol: Largest max-absolute-difference still reported as a pass.
        seed: Optional seed for reproducible inputs. When None, the global
            NumPy RNG is used, as before.

    Returns:
        True if the max absolute difference is below ``atol`` (a NaN
        difference counts as a failure), otherwise False.
    """
    print(f"Verifying with SeqLen={seq_len}, BlockSize={block_size}...")

    # A private RandomState leaves the global NumPy RNG untouched when seeding.
    rng = np.random if seed is None else np.random.RandomState(seed)
    shape = (1, num_heads, seq_len, dim)
    q = ndl.Tensor(rng.randn(*shape), device=device)
    k = ndl.Tensor(rng.randn(*shape), device=device)
    v = ndl.Tensor(rng.randn(*shape), device=device)

    sparse_attn = BlockSparseMultiHeadAttention(
        device=device,
        block_size=block_size,
        sparse_pattern="local",
        dropout=0.0,
    )

    # 1. Sparse path: the module's forward (uses the optimized kernel when available).
    out_sparse, _ = sparse_attn(q, k, v)

    # 2. Reference path: dense attention with the block mask added to the scores.
    scores = sparse_attn.matmul(q, k) / (dim ** 0.5)
    mask = sparse_attn.create_block_mask(seq_len, device)
    mask_tensor = ndl.Tensor(mask, device=device, requires_grad=False)
    mask_broadcast = mask_tensor.reshape((1, 1, seq_len, seq_len)).broadcast_to(scores.shape)
    probs = sparse_attn.softmax(scores + mask_broadcast)
    # matmul is given a transposed v, mirroring how it is called on (q, k)
    # above -- presumably it computes a @ b^T; confirm in nn_sparse_attention.
    v_transpose = ndl.ops.transpose(v, axes=(2, 3))
    out_dense_masked = sparse_attn.matmul(probs, v_transpose)

    diff = (out_sparse - out_dense_masked).numpy()
    max_diff = np.abs(diff).max()
    print(f"Max difference: {max_diff:.6f}")

    passed = bool(max_diff < atol)
    if passed:
        print("✓ Verification PASSED")
    else:
        print("✗ Verification FAILED")
    return passed

verify_correctness()

Verifying with SeqLen=128, BlockSize=16...
Max difference: 0.000001
✓ Verification PASSED


## 2. Performance Benchmark

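We measure the average forward-pass time of full (unmasked) dense attention and of block-sparse attention (local pattern, block size 32) for increasing sequence lengths, using batch size 4, 4 heads, and a head dimension of 32. Each implementation gets one warm-up call before it is timed over 10 runs.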

In [4]:
def _time_forward_ms(forward, iters):
    """Return the mean wall-clock time of ``forward()`` in milliseconds.

    ``forward`` must return a needle Tensor. The last result is copied to the
    host with ``.numpy()`` before the clock stops, so work that the backend
    queued asynchronously is not under-counted. On CUDA this is presumably a
    blocking device-to-host copy. The same procedure is applied to every
    implementation being compared, so the timings stay symmetric.

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    start = time.perf_counter()
    out = None
    for _ in range(iters):
        out = forward()
    out.numpy()  # materialize the final result before stopping the clock
    return (time.perf_counter() - start) / iters * 1000


def benchmark_performance(seq_lens=(128, 256, 512, 1024, 2048), block_size=32, iters=10):
    """Time full dense attention against block-sparse attention.

    For each sequence length, random q/k/v tensors of shape
    (4, 4, seq_len, 32) are created. Both implementations get one warm-up call
    and are then timed over ``iters`` forward passes. The dense baseline is
    unmasked attention built from the sparse module's ``matmul``/``softmax``
    helpers.

    Args:
        seq_lens: Iterable of sequence lengths to benchmark.
        block_size: Block size of the sparse module ("local" pattern).
        iters: Number of timed forward passes per implementation.

    Returns:
        ``(seq_lens, dense_times, sparse_times)``. ``seq_lens`` is a new list
        and the times are mean milliseconds per forward pass.
    """
    seq_lens = list(seq_lens)  # fresh list: never hand out a shared default object
    dense_times = []
    sparse_times = []

    num_heads = 4
    dim = 32
    batch_size = 4
    scale = dim ** 0.5

    print(f"Benchmarking (Batch={batch_size}, Heads={num_heads}, Dim={dim})...")
    print(f"{'SeqLen':<10} {'Dense (ms)':<15} {'Sparse (ms)':<15} {'Speedup':<10}")
    print("-"*50)

    for seq_len in seq_lens:
        shape = (batch_size, num_heads, seq_len, dim)
        q = ndl.Tensor(np.random.randn(*shape), device=device)
        k = ndl.Tensor(np.random.randn(*shape), device=device)
        v = ndl.Tensor(np.random.randn(*shape), device=device)

        # dropout=0.0, as in verify_correctness, so that whatever the module's
        # default is, the timed sparse path does no extra dropout work.
        sparse_attn = BlockSparseMultiHeadAttention(
            device=device, block_size=block_size, sparse_pattern="local", dropout=0.0
        )

        def run_dense():
            scores = sparse_attn.matmul(q, k) / scale
            probs = sparse_attn.softmax(scores)
            v_T = ndl.ops.transpose(v, axes=(2, 3))
            return sparse_attn.matmul(probs, v_T)

        def run_sparse():
            out, _ = sparse_attn(q, k, v)
            return out

        # Warm-up, materialized so that it finishes before timing starts.
        run_dense().numpy()
        run_sparse().numpy()

        # Timed identically for both paths. The old per-iteration "sync" wrote
        # into q's buffer through the raw CUDA Compact call and ran only for dense.
        dense_time = _time_forward_ms(run_dense, iters)
        sparse_time = _time_forward_ms(run_sparse, iters)

        speedup = dense_time / sparse_time if sparse_time > 0 else float("inf")

        dense_times.append(dense_time)
        sparse_times.append(sparse_time)

        speedup_str = f"{speedup:.2f}x"
        print(f"{seq_len:<10} {dense_time:<15.2f} {sparse_time:<15.2f} {speedup_str:<10}")

    return seq_lens, dense_times, sparse_times

seq_lens, dense_times, sparse_times = benchmark_performance()

# Mean forward time vs. sequence length for the dense and sparse implementations.
fig, ax = plt.subplots(figsize=(10, 6))
series = (
    (dense_times, 'o-', 'Dense Attention'),
    (sparse_times, 's-', 'Sparse Attention'),
)
for times, fmt, label in series:
    ax.plot(seq_lens, times, fmt, label=label)
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Time (ms)')
ax.set_title('Dense vs Sparse Attention Performance')
ax.legend()
ax.grid(True)
plt.show()

Benchmarking (Batch=4, Heads=4, Dim=32)...
SeqLen     Dense (ms)      Sparse (ms)     Speedup   
--------------------------------------------------
Falling back to slow implementation: Kernel execution failed: an illegal memory access was encountered
Attempting CUDA device reset to recover from error...
CUDA reset failed: module 'needle.backend_ndarray.ndarray_backend_cuda' has no attribute 'cuda_reset'


RuntimeError: CUDA Malloc failed: an illegal memory access was encountered

## 3. Sparsity Pattern Visualization

We plot the local, global, and mixed block-sparse patterns provided by `BlockSparsePattern`.

In [None]:
def visualize_pattern(pattern_name, seq_len=64, block_size=8):
    """Plot one of the ``BlockSparsePattern`` masks as a binary image.

    Args:
        pattern_name: One of "local", "global" or "mixed".
        seq_len: Sequence length passed to the pattern builder.
        block_size: Block size passed to the pattern builder.

    Raises:
        ValueError: if ``pattern_name`` is not one of the supported patterns.
    """
    if pattern_name == "local":
        mask = BlockSparsePattern.local_pattern(seq_len, block_size)
    elif pattern_name == "global":
        mask = BlockSparsePattern.global_pattern(seq_len, block_size)
    elif pattern_name == "mixed":
        mask = BlockSparsePattern.mixed_pattern(seq_len, block_size)
    else:
        # Fail fast: otherwise `mask` is unbound and the error surfaces later
        # as an unrelated UnboundLocalError.
        raise ValueError(
            f"Unknown pattern {pattern_name!r}; expected 'local', 'global' or 'mixed'"
        )

    plt.figure(figsize=(6, 6))
    plt.imshow(mask, cmap='binary', interpolation='nearest')
    plt.title(f"{pattern_name.capitalize()} Pattern (Blocks)")
    plt.xlabel("Key Blocks")
    plt.ylabel("Query Blocks")
    # Grid lines are drawn at the axis tick positions.
    plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5)
    plt.show()

# Render every supported pattern, one figure each.
for pattern in ("local", "global", "mixed"):
    visualize_pattern(pattern)