# Problem 3: Non-Causal Flash Attention

Implementation of the FlashAttention-2 forward pass in Triton, for non-causal attention.


In [None]:
import torch
import triton
import triton.language as tl
import math

@triton.jit
def _flash_attention_forward_kernel(
    # Pointers to Tensors
    Q_ptr, K_ptr, V_ptr, O_ptr,
    # Stride information for tensors
    q_stride_b, q_stride_h, q_stride_s,
    k_stride_b, k_stride_h, k_stride_s,
    v_stride_b, v_stride_h, v_stride_s,
    # Kernel parameters
    softmax_scale,
    SEQ_LEN,
    N_HEADS,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Triton kernel for the forward pass of FlashAttention-2 (non-causal).

    Each program computes BLOCK_M output rows for one (batch, head) pair:
    grid axis 0 selects the query block, grid axis 1 the flattened
    batch * N_HEADS + head index. Keys/values are streamed in BLOCK_N-row
    tiles and combined with an online (running-max) softmax, so the full
    SEQ_LEN x SEQ_LEN score matrix is never materialized.

    The kernel assumes (it does not check):
    - the head dimension is the innermost, unit-stride axis of Q/K/V/O;
    - O has the same strides as Q (O is addressed with the q_stride_* values);
    - HEAD_DIM is a power of two (Triton's tl.arange requirement) and large
      enough for tl.dot (>= 16).
    """
    # 1. Identify the block of queries and the batch/head to be processed.
    q_block_idx = tl.program_id(axis=0)
    batch_head_idx = tl.program_id(axis=1)

    # Widen to int64 before multiplying by strides so the per-(batch, head)
    # base offsets cannot overflow int32 on large tensors.
    batch_idx = (batch_head_idx // N_HEADS).to(tl.int64)
    head_idx = (batch_head_idx % N_HEADS).to(tl.int64)

    # Base pointers of this (batch, head) slice; invariant across the K/V loop.
    q_base = Q_ptr + batch_idx * q_stride_b + head_idx * q_stride_h
    k_base = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h
    v_base = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h
    o_base = O_ptr + batch_idx * q_stride_b + head_idx * q_stride_h

    dim_offsets = tl.arange(0, HEAD_DIM)

    # 2. Online-softmax state, kept in float32:
    #    m_i - running row max of the (log2-scaled) scores
    #    l_i - running softmax denominator, relative to m_i
    #    acc - running un-normalized output, relative to m_i
    m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # 3. Load the block of queries (Q_i); rows past SEQ_LEN are zero-filled.
    q_offsets = q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)
    q_mask = q_offsets[:, None] < SEQ_LEN
    q_ptrs = q_base + q_offsets[:, None] * q_stride_s + dim_offsets[None, :]
    q_block = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)

    # exp(x) == exp2(x * log2(e)); folding log2(e) (~1.44269504) into the
    # scale lets the loop use exp2.
    qk_scale = softmax_scale * 1.44269504

    # 4. Main loop: stream blocks of keys (K_j) and values (V_j).
    for start_n in range(0, SEQ_LEN, BLOCK_N):
        k_offsets = start_n + tl.arange(0, BLOCK_N)
        k_valid = k_offsets < SEQ_LEN

        # K_j is loaded transposed, shape (HEAD_DIM, BLOCK_N), so that
        # tl.dot(q_block, k_block) yields Q_i @ K_j^T directly.
        k_ptrs = k_base + k_offsets[None, :] * k_stride_s + dim_offsets[:, None]
        k_block = tl.load(k_ptrs, mask=k_valid[None, :], other=0.0).to(tl.float32)

        # S_ij = Q_i @ K_j^T, scaled into the log2 domain.
        s_ij = tl.dot(q_block, k_block) * qk_scale
        # Padded key columns (k >= SEQ_LEN) were zero-filled, which would give
        # them a score of 0 and a nonzero softmax weight. Force them to -inf so
        # they contribute nothing to m_i, l_i or acc.
        s_ij = tl.where(k_valid[None, :], s_ij, float('-inf'))

        # Load V_j, shape (BLOCK_N, HEAD_DIM); padded rows are zero-filled.
        v_ptrs = v_base + k_offsets[:, None] * v_stride_s + dim_offsets[None, :]
        v_block = tl.load(v_ptrs, mask=k_valid[:, None], other=0.0).to(tl.float32)

        # --- STUDENT IMPLEMENTATION ---
        # Online softmax update. Every tile holds at least one valid key
        # (start_n < SEQ_LEN), so m_ij and therefore m_new are finite.
        m_ij = tl.max(s_ij, 1)
        m_new = tl.maximum(m_i, m_ij)
        # Factor that rescales the previous state to the new max
        # (0 on the first tile, since m_i starts at -inf).
        alpha = tl.exp2(m_i - m_new)
        # Un-normalized probabilities of the current tile w.r.t. the new max.
        p_ij = tl.exp2(s_ij - m_new[:, None])

        acc = acc * alpha[:, None] + tl.dot(p_ij, v_block)
        l_i = l_i * alpha + tl.sum(p_ij, 1)
        m_i = m_new
        # --- END OF STUDENT IMPLEMENTATION ---

    # 5. Normalize and write back. The row max always contributes exp2(0) = 1
    # to l_i, so l_i >= 1 and the division needs no epsilon.
    acc = acc / l_i[:, None]
    o_ptrs = o_base + q_offsets[:, None] * q_stride_s + dim_offsets[None, :]
    tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=q_mask)

def flash_attention_forward(q, k, v, is_causal=False):
    """
    Minimal Python wrapper for the FlashAttention-2 forward pass (non-causal).

    Args:
        q, k, v: tensors of identical shape (batch, n_heads, seq_len, head_dim)
            on a device the Triton kernel can run on. head_dim must be a power
            of two and at least 16.
        is_causal: kept for interface compatibility with the other attention
            functions; only ``False`` is supported.

    Returns:
        A tensor with the shape and dtype of ``q`` holding
        softmax(q @ k^T / sqrt(head_dim)) @ v.

    Raises:
        NotImplementedError: if ``is_causal`` is true.
        ValueError: if the inputs are not 4-D with identical shapes, or
            head_dim is not supported by the kernel.
    """
    if is_causal:
        # Silently computing non-causal attention would return wrong results.
        raise NotImplementedError("flash_attention_forward only supports is_causal=False")
    if q.dim() != 4:
        raise ValueError(f"expected q of shape (batch, heads, seq, dim), got {tuple(q.shape)}")
    if k.shape != q.shape or v.shape != q.shape:
        raise ValueError(
            "q, k and v must have identical shapes, got "
            f"{tuple(q.shape)}, {tuple(k.shape)}, {tuple(v.shape)}"
        )

    batch, n_heads, seq_len, head_dim = q.shape
    # tl.arange needs a power-of-two extent and tl.dot needs dims >= 16.
    if head_dim < 16 or head_dim & (head_dim - 1):
        raise ValueError(f"head_dim must be a power of two >= 16, got {head_dim}")

    # The kernel assumes a unit-stride head dimension and addresses O with Q's
    # strides; contiguous inputs (a no-op if already contiguous) guarantee both.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    o = torch.empty_like(q)
    if o.numel() == 0:
        return o  # nothing to compute; avoid launching an empty grid

    softmax_scale = 1.0 / math.sqrt(head_dim)
    BLOCK_M, BLOCK_N = 128, 64
    grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_heads)

    _flash_attention_forward_kernel[grid](
        q, k, v, o,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        softmax_scale,
        seq_len,
        n_heads,
        HEAD_DIM=head_dim,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    return o

In [None]:
# Autograder for Problem 3: Non-Causal Flash Attention
import torch
import torch.nn.functional as F
import math
import time

# Test/benchmark dtype: bfloat16 when the GPU supports it, else float16.
# CUDA availability is checked first because is_bf16_supported() may raise
# (rather than return False) on some PyTorch versions when CUDA is absent.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

def repeat_kv(x, num_groups):
    """Tile each K/V head ``num_groups`` times along dim 1 for the naive GQA path.

    Input head ``h`` becomes output heads ``h * num_groups`` through
    ``h * num_groups + num_groups - 1``. With ``num_groups == 1`` the input
    tensor object itself is returned.
    """
    if num_groups != 1:
        batch, kv_heads, seq, dim = x.shape
        tiled = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq, dim)
        x = tiled.reshape(batch, kv_heads * num_groups, seq, dim)
    return x

def create_mask_bool(
    seq_len: int,
    window_size: int,
    sink_size: int,
    device=None
    ) -> torch.Tensor:
    """Build a (seq_len, seq_len) boolean attention mask.

    Entry [i, j] is True when query i may attend key j: the key must be
    causal (j <= i) and either inside the sliding window
    (j >= i - (window_size - 1)) or one of the first ``sink_size`` keys.
    """
    positions = torch.arange(seq_len, device=device)
    query_pos = positions[:, None]   # (seq_len, 1)
    key_pos = positions[None, :]     # (1, seq_len)

    causal = key_pos <= query_pos
    in_window = key_pos >= query_pos - (window_size - 1)
    is_sink = key_pos < sink_size

    # (causal & in_window) | (causal & is_sink), factored.
    return causal & (in_window | is_sink)

def naive_attention(Q, K, V, is_causal=False, window_size=None, sink_size=None):
    """Reference attention in plain PyTorch, used as ground truth by the autograder.

    Q is unpacked as (batch, heads_q, seq, dim) and K as (batch, heads_kv, seq, dim);
    when heads_q != heads_kv, K and V heads are repeated heads_q // heads_kv
    times (divisibility is assumed, not checked). Masking happens only when
    ``is_causal`` is true:
    - ``window_size is None``: plain causal mask (``sink_size`` is ignored);
    - otherwise: causal sliding window of ``window_size`` keys, plus the first
      ``sink_size`` keys when ``sink_size`` is given.

    Returns:
        (output, logsumexp): output = softmax(scores) @ V, with the softmax
        computed in float32 and cast back to Q.dtype; logsumexp is the float32
        row-wise logsumexp of the scaled, masked scores.
    """
    _, heads_q, _, _ = Q.shape
    _, heads_kv, seq_len, head_dim = K.shape

    if heads_q != heads_kv:
        groups = heads_q // heads_kv
        K = repeat_kv(K, groups)
        V = repeat_kv(V, groups)

    scores = (Q @ K.transpose(-1, -2)) * (1.0 / math.sqrt(head_dim))

    if is_causal:
        if window_size is None:
            # Plain causal: a window spanning the whole sequence, no sinks.
            eff_window, eff_sink = seq_len, 0
        else:
            eff_window = window_size
            eff_sink = 0 if sink_size is None else sink_size
        keep = create_mask_bool(seq_len, window_size=eff_window, sink_size=eff_sink, device=Q.device)
        scores.masked_fill_(~keep, -float('inf'))

    probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(Q.dtype)
    lse = torch.logsumexp(scores.to(torch.float32), dim=-1)
    return probs @ V, lse

def benchmark_attention(triton_func, naive_func, test_params, is_causal, is_gqa=False, is_swa=False):
    """Benchmark a Triton attention function against a naive reference implementation.

    Args:
        triton_func: callable ``(q, k, v, is_causal=..., [window_size=..., sink_size=...])``;
            the window/sink keywords are passed only when set, matching how
            ``run_correctness_test`` calls student functions.
        naive_func: reference callable with ``naive_attention``'s signature;
            only its first return value is kept.
        test_params: shape tuple whose layout depends on the flags:
            MHA ``(B, H, L, D)``; GQA ``(B, Hq, Hkv, L, D)``;
            SWA ``(B, Hq, Hkv, L, D, W)`` or ``(B, Hq, Hkv, L, D, W, S)``.
        is_causal: forwarded to both implementations.
        is_gqa, is_swa: select how ``test_params`` is unpacked.

    Prints the average wall-clock time per call (5 warm-up + 20 timed calls)
    and the peak allocated CUDA memory for each implementation. Requires CUDA.
    """
    print("\n--- Running Performance Benchmark ---")
    window_size, sink_size = None, None
    if is_gqa and not is_swa: # GQA only 
        batch, heads_q, heads_kv, seq_len, dim = test_params
        config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}"
    elif is_swa: # GQA + SWA
        batch, heads_q, heads_kv, seq_len, dim, *window_params = test_params
        if len(window_params) == 1:
            window_size = window_params[0]
            config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}"
        else:
            window_size, sink_size = window_params
            config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}, S={sink_size}"
    else:
        batch, heads_q, seq_len, dim = test_params
        heads_kv = heads_q
        config_str = f"B={batch}, H={heads_q}, L={seq_len}, D={dim}"

    print(f"Benchmark Config: {config_str}, Causal={is_causal}")
    
    q = torch.randn(batch, heads_q, seq_len, dim, device='cuda', dtype=DTYPE)
    k = torch.randn(batch, heads_kv, seq_len, dim, device='cuda', dtype=DTYPE)
    v = torch.randn(batch, heads_kv, seq_len, dim, device='cuda', dtype=DTYPE)

    # Give the Triton function the same window/sink configuration as the
    # reference, otherwise the two implementations time different workloads.
    triton_kwargs = {}
    if window_size is not None:
        triton_kwargs["window_size"] = window_size
        if sink_size is not None:
            triton_kwargs["sink_size"] = sink_size

    def _call_triton():
        return triton_func(q, k, v, is_causal=is_causal, **triton_kwargs)

    def _call_naive():
        # Keep only the attention output; the logsumexp is not needed here.
        return naive_func(q, k, v, is_causal=is_causal, window_size=window_size, sink_size=sink_size)[0]

    def _run_benchmark(run_once, n_warmup=5, n_iters=20):
        """Return (average ms per call, peak allocated GB) for ``run_once``."""
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        for _ in range(n_warmup):
            _ = run_once()

        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(n_iters):
            _ = run_once()
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        avg_time_ms = elapsed * 1000 / n_iters
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
        return avg_time_ms, peak_mem_gb

    triton_time, triton_mem = _run_benchmark(_call_triton)
    torch_time, torch_mem = _run_benchmark(_call_naive)

    print("\n--- Benchmark Results ---")
    print(f"{'Implementation':<25} | {'Avg Time (ms)':<20} | {'Peak Memory (GB)':<20}")
    print("-" * 70)
    print(f"{'PyTorch (Naive)':<25} | {torch_time:<20.4f} | {torch_mem:<20.4f}")
    print(f"{'Triton (Flash)':<25} | {triton_time:<20.4f} | {triton_mem:<20.4f}")
    print("-" * 70)
    
    # Highlight improvements
    speedup = torch_time / triton_time if triton_time > 0 else float('inf')
    mem_saving = torch_mem / triton_mem if triton_mem > 0 else float('inf')

    print(f"Triton is {speedup:.2f}x faster than PyTorch (Naive).")
    print(f"Triton uses {mem_saving:.2f}x less memory.")

def run_correctness_test(test_params, student_func, is_causal, is_gqa=False, is_swa=False, problem_num=None):
    """Check one student attention implementation against ``naive_attention``.

    ``test_params`` is unpacked according to the flags:
    SWA ``(B, Hq, Hkv, L, D, W)`` or ``(B, Hq, Hkv, L, D, W, S)``;
    GQA ``(B, Hq, Hkv, L, D)``; otherwise MHA ``(B, H, L, D)``.
    Random CUDA inputs of dtype ``DTYPE`` are fed to both implementations and
    the outputs are compared with ``torch.allclose(rtol=5e-2, atol=5e-2)``.
    Window/sink keywords are passed to ``student_func`` only when set.

    Returns:
        True if the outputs match, False otherwise (a message is printed either way).

    Raises:
        ValueError: if SWA parameters carry neither one nor two window values.
    """
    window_size = sink_size = None

    if is_swa:
        batch, heads_q, heads_kv, seq_len, dim, *window_params = test_params
        n_window = len(window_params)
        if n_window == 1:
            (window_size,) = window_params
            param_str = f"(B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size})"
        elif n_window == 2:
            window_size, sink_size = window_params
            param_str = f"(B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}, S={sink_size})"
        else:
            raise ValueError(f"Invalid window_params length: {len(window_params)}")
    elif is_gqa:
        batch, heads_q, heads_kv, seq_len, dim = test_params
        param_str = f"(B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim})"
    else:
        batch, heads_q, seq_len, dim = test_params
        heads_kv = heads_q
        param_str = f"(B={batch}, H={heads_q}, L={seq_len}, D={dim})"

    def _random_heads(n_heads):
        return torch.randn(batch, n_heads, seq_len, dim, device='cuda', dtype=DTYPE)

    # Generated in q, k, v order so the RNG stream matches across runs.
    q, k, v = _random_heads(heads_q), _random_heads(heads_kv), _random_heads(heads_kv)

    torch_result, _ = naive_attention(q, k, v, is_causal=is_causal, window_size=window_size, sink_size=sink_size)

    extra_kwargs = {}
    if window_size is not None:
        extra_kwargs["window_size"] = window_size
        if sink_size is not None:
            extra_kwargs["sink_size"] = sink_size
    triton_result = student_func(q, k, v, is_causal=is_causal, **extra_kwargs)

    if not torch.allclose(torch_result, triton_result, rtol=5e-2, atol=5e-2):
        print(f"❌ P{problem_num} Correctness Test Failed! {param_str}")
        print(f" Max diff: {(torch_result - triton_result).abs().max()}")
        return False

    print(f"✅ P{problem_num} Correctness Test Passed! {param_str}")
    return True

def check_problem_3():
    """Run the autograder for Problem 3 (non-causal flash attention).

    Runs every correctness case for ``flash_attention_forward`` and, only if
    all of them pass, benchmarks it on the last (largest) case against
    ``naive_attention``.
    """
    problem_num = 3
    print(f"\n--- Running Autograder for Problem {problem_num}: Non-Causal Flash Attention ---")

    torch.manual_seed(44)
    # Each case is (batch, heads, seq_len, head_dim).
    test_cases = [
        (1, 8, 512, 16),
        (1, 8, 1024, 16),
        (1, 16, 2048, 16),
        (1, 16, 4096, 16),
    ]

    # Evaluate every case (no short-circuit) so each one prints its result.
    outcomes = [
        run_correctness_test(
            case,
            flash_attention_forward,
            is_causal=False,
            is_gqa=False,
            is_swa=False,
            problem_num=problem_num,
        )
        for case in test_cases
    ]
    if not all(outcomes):
        return

    print(f"\nAll P{problem_num} correctness tests passed!")
    benchmark_attention(flash_attention_forward, naive_attention, test_cases[-1], is_causal=False, is_gqa=False)

# Run the autograder
check_problem_3()