# Problem 4: Causal Flash Attention

Implementation of FlashAttention-2 forward pass using Triton for causal attention.

This implementation uses a two-phase approach:

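1. **Off-diagonal blocks**: Process key/value blocks that lie entirely before the query block, so every query index is greater than every key index (no masking needed)
2. **Diagonal blocks**: Process key/value blocks whose index range overlaps the query block's, where a key may lie after a query (causal masking required)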


In [None]:
import torch
import triton
import triton.language as tl
import math

@triton.jit
def _flash_attention_forward_causal_kernel(
    # Pointers to Tensors
    Q_ptr, K_ptr, V_ptr, O_ptr,
    # Stride information for tensors
    q_stride_b, q_stride_h, q_stride_s,
    k_stride_b, k_stride_h, k_stride_s,
    v_stride_b, v_stride_h, v_stride_s,
    # Kernel parameters
    softmax_scale,
    SEQ_LEN,
    N_HEADS,
    # Constexpr tile sizes
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Triton kernel for the forward pass of causal FlashAttention.

    Each program handles one block of BLOCK_M queries for one (batch, head)
    pair: program axis 0 selects the query block, axis 1 the flattened
    batch * N_HEADS index. Key/value blocks of BLOCK_N rows are consumed with
    an online (streaming) softmax in two phases:

    1. Off-diagonal blocks strictly before the query block (no masking).
    2. Diagonal blocks covering the query block's own index range, with the
       causal mask ``q_idx >= k_idx`` applied.

    Layout requirements, as implied by the pointer arithmetic below:
      * head-dim offsets are added without a stride, so the last dimension of
        Q, K, V and O must have unit stride;
      * O is addressed with Q's batch/head/sequence strides, so O must share
        Q's strides;
      * BLOCK_M must be a multiple of BLOCK_N, otherwise the last unmasked
        block of phase 1 would cross the causal diagonal.
    """
    # Phase 1 walks [0, q_block_idx * BLOCK_M) in BLOCK_N steps without masking;
    # that is only valid if the phase boundary is BLOCK_N-aligned.
    tl.static_assert(BLOCK_M % BLOCK_N == 0, "BLOCK_M must be a multiple of BLOCK_N")

    # 1. Identify the block of queries and the batch/head to be processed.
    q_block_idx = tl.program_id(axis=0)
    batch_head_idx = tl.program_id(axis=1)

    batch_idx = batch_head_idx // N_HEADS
    head_idx = batch_head_idx % N_HEADS

    # Loop-invariant addressing: per-(batch, head) base pointers and head-dim offsets.
    offs_d = tl.arange(0, HEAD_DIM)
    q_base = Q_ptr + batch_idx * q_stride_b + head_idx * q_stride_h
    k_base = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h
    v_base = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h

    # 2. Initialize the online-softmax state: running row max, running
    #    denominator, and the un-normalized output accumulator.
    m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # 3. Load the block of queries (Q_i); rows past SEQ_LEN are read as zeros.
    q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M))
    q_ptrs = q_base + (q_offsets[:, None] * q_stride_s + offs_d[None, :])
    q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)
    q_block = q_block.to(tl.float32)  # Convert to float32 for computation

    # The softmax is evaluated with exp2, so log2(e) ~= 1.44269504 is folded
    # into the scale: exp(x) == exp2(x * log2(e)).
    qk_scale = softmax_scale * 1.44269504

    # --- Phase 1: Accumulate in Off-Diagonal Blocks (No Masking) ---
    # Process key/value blocks that are strictly in the past (q_idx > k_idx).
    for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N):
        # Load K_j, addressed transposed as (HEAD_DIM, BLOCK_N) so the dot
        # below yields Q_i @ K_j^T directly.
        k_offsets = start_n + tl.arange(0, BLOCK_N)
        k_ptrs = k_base + (k_offsets[None, :] * k_stride_s + offs_d[:, None])
        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
        k_block = k_block.to(tl.float32)

        # Compute attention scores S_ij = Q_i * K_j^T (in the log2 domain).
        s_ij = tl.dot(q_block, k_block)
        s_ij *= qk_scale

        # Load V_j as (BLOCK_N, HEAD_DIM).
        v_ptrs = v_base + (k_offsets[:, None] * v_stride_s + offs_d[None, :])
        v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)
        v_block = v_block.to(tl.float32)

        # Online softmax update (same as non-causal case).
        m_ij = tl.max(s_ij, 1)  # Row-wise max of current scores
        m_new = tl.maximum(m_i, m_ij)  # Element-wise max with previous running max

        # Rescale previously accumulated values to the new running max.
        alpha = tl.exp2(m_i - m_new)
        acc = acc * alpha[:, None]
        l_i = l_i * alpha

        # Compute probabilities and update accumulator.
        p_ij = tl.exp2(s_ij - m_new[:, None])
        acc += tl.dot(p_ij, v_block)
        l_i += tl.sum(p_ij, 1)

        # Update running maximum.
        m_i = m_new

    # --- Phase 2: Run on the Diagonal Blocks (With Masking) ---
    # Process the blocks where query and key indices can overlap.
    diag_start = q_block_idx * BLOCK_M
    for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
        # Load K_j (transposed addressing, as in phase 1).
        k_offsets = start_n + tl.arange(0, BLOCK_N)
        k_ptrs = k_base + (k_offsets[None, :] * k_stride_s + offs_d[:, None])
        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
        k_block = k_block.to(tl.float32)

        # Compute attention scores S_ij = Q_i * K_j^T.
        s_ij = tl.dot(q_block, k_block)
        s_ij *= qk_scale

        # Apply causal mask: keep s_ij[i, j] only where q_offsets[i] >= k_offsets[j].
        # Keys past SEQ_LEN are always after every in-range query, so this
        # mask also removes them for every row that is eventually stored.
        causal_mask = q_offsets[:, None] >= k_offsets[None, :]
        s_ij = tl.where(causal_mask, s_ij, -float('inf'))

        # Load V_j.
        v_ptrs = v_base + (k_offsets[:, None] * v_stride_s + offs_d[None, :])
        v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)
        v_block = v_block.to(tl.float32)

        # Online softmax update with masked scores. The first diagonal block
        # always contains key index diag_start <= every query index in this
        # block, so m_new is finite from here on and a fully masked row in a
        # later block contributes exp2(-inf) == 0 rather than NaN.
        m_ij = tl.max(s_ij, 1)  # Row-wise max of current scores
        m_new = tl.maximum(m_i, m_ij)  # Element-wise max with previous running max

        # Rescale previously accumulated values to the new running max.
        alpha = tl.exp2(m_i - m_new)
        acc = acc * alpha[:, None]
        l_i = l_i * alpha

        # Compute probabilities and update accumulator.
        p_ij = tl.exp2(s_ij - m_new[:, None])
        acc += tl.dot(p_ij, v_block)
        l_i += tl.sum(p_ij, 1)

        # Update running maximum.
        m_i = m_new

    # 4. Normalize and write the final output block.
    # The small epsilon guards the division against a zero denominator.
    l_i_safe = l_i[:, None] + 1e-6
    acc = acc / l_i_safe

    # O is addressed with Q's strides (see docstring).
    o_ptrs = O_ptr + batch_idx * q_stride_b + head_idx * q_stride_h + \
             (q_offsets[:, None] * q_stride_s + offs_d[None, :])

    tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=q_offsets[:, None] < SEQ_LEN)

def flash_attention_forward(q, k, v, is_causal=True):
    """
    Python wrapper for the single-kernel, two-phase causal FlashAttention.

    Args:
        q, k, v: 4-D tensors of identical shape
            ``(batch, n_heads, seq_len, head_dim)``. Grouped-query layouts
            (fewer K/V heads than Q heads) are not supported by this kernel.
        is_causal: Must be True; the non-causal case is not implemented here.

    Returns:
        A tensor with the same shape and dtype as ``q`` holding the causal
        attention output.

    Raises:
        NotImplementedError: If ``is_causal`` is False.
        ValueError: If ``q`` is not 4-D or ``k``/``v`` do not match ``q``'s shape.
    """
    if not is_causal:
        raise NotImplementedError("This implementation is for causal attention. Use solution_3 for non-causal.")

    if q.dim() != 4:
        raise ValueError(
            f"Expected q of shape (batch, n_heads, seq_len, head_dim), got {tuple(q.shape)}"
        )
    # The kernel indexes K and V with Q's batch/head/sequence indices, so any
    # shape mismatch (e.g. GQA-style K/V) would read out of bounds.
    if k.shape != q.shape or v.shape != q.shape:
        raise ValueError(
            "q, k and v must have identical shapes; got "
            f"q={tuple(q.shape)}, k={tuple(k.shape)}, v={tuple(v.shape)}"
        )

    # The kernel adds head-dim offsets without a stride, so the last dimension
    # must be unit-stride.
    if q.stride(-1) != 1:
        q = q.contiguous()
    if k.stride(-1) != 1:
        k = k.contiguous()
    if v.stride(-1) != 1:
        v = v.contiguous()

    batch, n_heads, seq_len, head_dim = q.shape
    o = torch.empty_like(q)
    if o.stride() != q.stride():
        # The kernel writes O using Q's strides; fall back to a contiguous
        # copy of Q so both tensors share one layout.
        q = q.contiguous()
        o = torch.empty_like(q)

    if q.numel() == 0:
        # Nothing to compute; avoid launching a zero-sized grid.
        return o

    softmax_scale = 1.0 / math.sqrt(head_dim)

    # BLOCK_M must stay a multiple of BLOCK_N: the kernel's unmasked phase
    # relies on the diagonal boundary being BLOCK_N-aligned.
    BLOCK_M, BLOCK_N = 128, 64
    grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_heads)

    _flash_attention_forward_causal_kernel[grid](
        q, k, v, o,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        softmax_scale,
        seq_len,
        n_heads,
        HEAD_DIM=head_dim,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    return o

In [None]:
# Autograder for Problem 4: Causal Flash Attention
import torch
import torch.nn.functional as F
import math
import time

# Element dtype for every tensor the tests create: bfloat16 when
# torch.cuda.is_bf16_supported() reports support, otherwise float16.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

def repeat_kv(x, num_groups):
    """Duplicate each K/V head ``num_groups`` times along the head axis.

    ``x`` is a 4-D tensor ``(B, H_kv, N, D)``; the result is
    ``(B, H_kv * num_groups, N, D)`` with the copies of each head adjacent.
    With ``num_groups == 1`` the input tensor itself is returned.
    """
    if num_groups == 1:
        return x
    batch, kv_heads, seq, dim = x.shape
    # Insert a group axis right after the head axis, broadcast it, then fold
    # (head, group) back into a single head axis.
    grouped = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq, dim)
    return grouped.flatten(1, 2)

def create_mask_bool(
    seq_len: int,
    window_size: int,
    sink_size: int,
    device=None
    ) -> torch.Tensor:
    """Build a ``(seq_len, seq_len)`` boolean attention mask.

    Entry ``[i, j]`` is True when query ``i`` may attend to key ``j``: the key
    is not in the future (``j <= i``) and it is either within the last
    ``window_size`` positions (``j >= i - (window_size - 1)``) or one of the
    first ``sink_size`` positions (``j < sink_size``).
    """
    positions = torch.arange(seq_len, device=device)
    key_pos = positions[None, :]                 # (1, seq_len)
    lag = positions[:, None] - key_pos           # query index minus key index

    not_future = lag >= 0
    in_window = lag <= window_size - 1
    is_sink = key_pos < sink_size

    # Both the sliding-window and the sink terms are restricted to j <= i,
    # so the causal condition can be factored out.
    return not_future & (in_window | is_sink)

def naive_attention(Q, K, V, is_causal=False, window_size=None, sink_size=None):
    """
    Reference attention in plain PyTorch, used as ground truth for comparison.

    Q is ``(batch, heads_q, seq, dim)``; K and V are ``(batch, heads_kv, seq, dim)``.
    When the head counts differ, K/V heads are repeated ``heads_q // heads_kv``
    times (GQA). With ``is_causal`` set, a causal mask is applied, optionally
    narrowed to a sliding window (``window_size``) and widened by attention
    sinks (``sink_size``, only honoured together with ``window_size``).
    ``window_size`` and ``sink_size`` are ignored when ``is_causal`` is False.
    The softmax is evaluated in float32 and cast back to Q's dtype.
    """
    _batch_size, num_heads_q, _, _ = Q.shape
    # Sequence length and head dim are taken from K, as in the score matmul.
    _, num_heads_kv, seq_len, head_dim = K.shape

    if num_heads_q != num_heads_kv:
        group_count = num_heads_q // num_heads_kv
        K = repeat_kv(K, group_count)
        V = repeat_kv(V, group_count)

    scores = (Q @ K.transpose(-1, -2)) * (1.0 / math.sqrt(head_dim))

    if is_causal:
        if window_size is None:
            # Causal only: the window spans the whole sequence.
            window, sinks = seq_len, 0
        elif sink_size is None:
            # Sliding window without sinks.
            window, sinks = window_size, 0
        else:
            # Sliding window plus attention sinks.
            window, sinks = window_size, sink_size
        allowed = create_mask_bool(seq_len=seq_len, window_size=window, sink_size=sinks, device=Q.device)
        scores.masked_fill_(~allowed, -float('inf'))

    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(Q.dtype)
    return probs @ V

def run_correctness_test(test_case, student_fn, is_causal=False, is_gqa=False, is_swa=False, problem_num=None):
    """Run one correctness case and report pass/fail.

    ``test_case`` is one of:
      * ``(batch, n_heads, seq_len, head_dim)``
      * ``(batch, n_heads, n_heads_kv, seq_len, head_dim)``
      * ``(batch, n_heads, n_heads_kv, seq_len, head_dim, window_size, sink_size)``

    Random CUDA inputs are generated, ``student_fn`` is compared against
    ``naive_attention`` with ``atol = rtol = 1e-2``, and the outcome is
    printed. ``window_size``/``sink_size`` are forwarded to ``student_fn``
    only when ``is_swa`` is set; ``is_gqa`` does not change the call.

    Returns True on a match, False on a mismatch or if ``student_fn`` raises.
    Raises ValueError for any other ``test_case`` length.
    """
    n_fields = len(test_case)
    window_size = sink_size = None
    if n_fields == 4:
        batch, n_heads, seq_len, head_dim = test_case
        n_heads_kv = n_heads
    elif n_fields == 5:
        batch, n_heads, n_heads_kv, seq_len, head_dim = test_case
    elif n_fields == 7:
        (batch, n_heads, n_heads_kv, seq_len, head_dim,
         window_size, sink_size) = test_case
    else:
        raise ValueError(f"Invalid test case format: {test_case}")

    # Configuration tag for report lines; exception reports always use the short form.
    short_tag = f"(B={batch}, H={n_heads}, L={seq_len}, D={head_dim})"
    if n_fields == 4:
        tag = short_tag
    elif n_fields == 5:
        tag = f"(B={batch}, H_Q={n_heads}, H_KV={n_heads_kv}, L={seq_len}, D={head_dim})"
    else:
        tag = f"(B={batch}, H_Q={n_heads}, H_KV={n_heads_kv}, L={seq_len}, D={head_dim}, W={window_size}, S={sink_size})"

    def random_input(heads):
        return torch.randn(batch, heads, seq_len, head_dim, dtype=DTYPE, device='cuda')

    # Generation order (Q, K, V) matters for reproducibility under a fixed seed.
    Q = random_input(n_heads)
    K = random_input(n_heads_kv)
    V = random_input(n_heads_kv)

    expected = naive_attention(Q, K, V, is_causal=is_causal, window_size=window_size, sink_size=sink_size)

    swa_kwargs = {"window_size": window_size, "sink_size": sink_size} if is_swa else {}
    try:
        actual = student_fn(Q, K, V, is_causal=is_causal, **swa_kwargs)
    except Exception as e:
        # Any failure inside the implementation under test counts as a failed case.
        print(f"❌ P{problem_num} Test Failed! {short_tag} - Error: {e}")
        return False

    if not torch.allclose(actual, expected, atol=1e-2, rtol=1e-2):
        print(f"❌ P{problem_num} Test Failed! {tag} - Results do not match.")
        print(f"Max difference: {torch.max(torch.abs(actual - expected)).item():.6f}")
        return False

    print(f"✅ P{problem_num} Correctness Test Passed! {tag}")
    return True

def benchmark_attention(triton_fn, pytorch_fn, test_case, is_causal=False, is_gqa=False, is_swa=False):
    """Time and memory-profile ``triton_fn`` against ``pytorch_fn`` on CUDA.

    ``test_case`` uses the same 4-, 5- or 7-field layouts as the correctness
    tests. Both functions are warmed up for 10 rounds, then each is timed over
    100 calls with CUDA events, and finally each is run once more to record
    peak allocated memory. A results table is printed; nothing is returned.

    ``window_size``/``sink_size`` are forwarded to ``triton_fn`` only when
    ``is_swa`` is set (``is_gqa`` does not change the call); ``pytorch_fn``
    always receives them. Raises ValueError for any other ``test_case`` length.
    """
    n_fields = len(test_case)
    window_size = sink_size = None
    if n_fields == 4:
        batch, n_heads, seq_len, head_dim = test_case
        n_heads_kv = n_heads
    elif n_fields == 5:
        batch, n_heads, n_heads_kv, seq_len, head_dim = test_case
    elif n_fields == 7:
        (batch, n_heads, n_heads_kv, seq_len, head_dim,
         window_size, sink_size) = test_case
    else:
        raise ValueError(f"Invalid test case format: {test_case}")

    print("\n--- Running Performance Benchmark ---")
    if n_fields == 4:
        config = f"B={batch}, H={n_heads}, L={seq_len}, D={head_dim}"
    elif n_fields == 5:
        config = f"B={batch}, H_Q={n_heads}, H_KV={n_heads_kv}, L={seq_len}, D={head_dim}"
    else:
        config = f"B={batch}, H_Q={n_heads}, H_KV={n_heads_kv}, L={seq_len}, D={head_dim}, W={window_size}, S={sink_size}"
    print(f"Benchmark Config: {config}, Causal={is_causal}")

    Q = torch.randn(batch, n_heads, seq_len, head_dim, dtype=DTYPE, device='cuda')
    K = torch.randn(batch, n_heads_kv, seq_len, head_dim, dtype=DTYPE, device='cuda')
    V = torch.randn(batch, n_heads_kv, seq_len, head_dim, dtype=DTYPE, device='cuda')

    swa_kwargs = {"window_size": window_size, "sink_size": sink_size} if is_swa else {}

    def call_triton():
        return triton_fn(Q, K, V, is_causal=is_causal, **swa_kwargs)

    def call_pytorch():
        return pytorch_fn(Q, K, V, is_causal=is_causal, window_size=window_size, sink_size=sink_size)

    # `last_out` deliberately keeps the most recent result alive at function
    # scope, so the peak-memory readings below see the same live tensors as a
    # straight-line script would.
    # Warm up
    for _ in range(10):
        last_out = call_triton()
        last_out = call_pytorch()

    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Benchmark Triton
    triton_times = []
    for _ in range(100):
        start_event.record()
        last_out = call_triton()
        end_event.record()
        torch.cuda.synchronize()
        triton_times.append(start_event.elapsed_time(end_event))

    # Benchmark PyTorch
    pytorch_times = []
    for _ in range(100):
        start_event.record()
        last_out = call_pytorch()
        end_event.record()
        torch.cuda.synchronize()
        pytorch_times.append(start_event.elapsed_time(end_event))

    triton_avg = sum(triton_times) / len(triton_times)
    pytorch_avg = sum(pytorch_times) / len(pytorch_times)

    # Memory usage (bytes -> GB)
    bytes_per_gb = 1024**3

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    last_out = call_triton()
    triton_memory = torch.cuda.max_memory_allocated() / bytes_per_gb

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    last_out = call_pytorch()
    pytorch_memory = torch.cuda.max_memory_allocated() / bytes_per_gb

    speedup = pytorch_avg / triton_avg
    memory_reduction = pytorch_memory / triton_memory

    rule = "-" * 70
    print("\n--- Benchmark Results ---")
    print(f"{'Implementation':<25} | {'Avg Time (ms)':<20} | {'Peak Memory (GB)':<20}")
    print(rule)
    print(f"{'PyTorch (Naive)':<25} | {pytorch_avg:<20.4f} | {pytorch_memory:<20.4f}")
    print(f"{'Triton (Flash)':<25} | {triton_avg:<20.4f} | {triton_memory:<20.4f}")
    print(rule)
    print(f"Triton is {speedup:.2f}x faster than PyTorch (Naive).")
    print(f"Triton uses {memory_reduction:.2f}x less memory.")

def check_problem_4():
    """Autograde Problem 4 (Causal Flash Attention).

    Seeds the RNG, runs every correctness case against
    ``flash_attention_forward`` (all cases run even after a failure), and, only
    if all of them pass, benchmarks the largest case against ``naive_attention``.
    """
    problem_num = 4
    print(f"\n--- Running Autograder for Problem {problem_num}: Causal Flash Attention ---")

    torch.manual_seed(45)
    # (batch, n_heads, seq_len, head_dim)
    test_cases = [
        (1, 8, 512, 16),
        (1, 8, 1024, 16),
        (1, 16, 2048, 16),
        (1, 16, 4096, 16),
    ]

    outcomes = []
    for case in test_cases:
        passed = run_correctness_test(
            case,
            flash_attention_forward,
            is_causal=True,
            is_gqa=False,
            is_swa=False,
            problem_num=problem_num,
        )
        outcomes.append(passed)

    if not all(outcomes):
        return

    print(f"\nAll P{problem_num} correctness tests passed!")
    benchmark_attention(flash_attention_forward, naive_attention, test_cases[-1], is_causal=True, is_gqa=False)

# Run the autograder.
# NOTE: this call executes as soon as the cell/module is run.
check_problem_4()