# GStar Assignment 1: FlashAttention2 Implementation

This notebook implements FlashAttention-2 forward pass in PyTorch for Problem 1 of the GStar Bootcamp assignment. We'll:

1. **Set Up Dependencies**: Install the required packages for Google Colab
2. **Implement FlashAttention2**: Complete the PyTorch implementation with online softmax algorithm
3. **Test with Autograder**: Validate correctness against reference implementations

## About FlashAttention2

FlashAttention-2 is a memory-efficient attention mechanism that uses tiled computation and online softmax to reduce memory usage from O(N²) to O(N) while maintaining mathematical exactness. This implementation uses:

- **Tiled computation**: Process attention in blocks to fit in GPU memory
- **Online softmax**: Compute softmax incrementally without storing the full attention matrix
- **Causal masking**: Support for autoregressive models like GPT

Let's get started! 🚀


In [None]:
# ============================================================================
# Block 2: Problem 1 - FlashAttention2 PyTorch Implementation
# ============================================================================

import torch
import torch.nn as nn
import math

class FlashAttention2Function(torch.autograd.Function):
    """
    A pure PyTorch implementation of the FlashAttention-2 forward pass.

    Queries and keys/values are processed in fixed-size tiles and combined with
    the online softmax recurrence, so the full (N_Q x N_K) score matrix is never
    materialised; only one tile of scores exists at a time. All batches and
    heads are handled together as batched tensor operations.
    """

    @staticmethod
    def forward(ctx, Q, K, V, is_causal=False):
        """
        Compute softmax(Q K^T / sqrt(D)) V tile by tile.

        Args:
            ctx: autograd context; receives the tensors saved for backward.
            Q: query tensor laid out as (B, H, N_Q, D).
            K: key tensor laid out as (B, H, N_K, D).
            V: value tensor laid out as (B, H, N_K, D).
            is_causal: when True, query position i only attends to key
                positions j <= i (indices counted from the start of each
                sequence).

        Returns:
            A tuple ``(O, L)`` where ``O`` has the shape and dtype of ``Q`` and
            ``L`` is the float32 row-wise log-sum-exp of the scaled scores with
            shape (B, H, N_Q).
        """
        # Get dimensions from input tensors following the (B, H, N, D) convention
        B, H, N_Q, D_H = Q.shape
        _, _, N_K, _ = K.shape

        # Define tile sizes
        Q_TILE_SIZE = 128
        K_TILE_SIZE = 128

        neg_inf = -float('inf')
        scale = 1.0 / math.sqrt(D_H)

        # Initialize final output tensors
        O_final = torch.zeros_like(Q)
        L_final = torch.zeros((B, H, N_Q), device=Q.device, dtype=torch.float32)

        # Loop over query tiles; every batch and head is processed at once.
        for q_start in range(0, N_Q, Q_TILE_SIZE):
            q_end = min(q_start + Q_TILE_SIZE, N_Q)
            q_len = q_end - q_start
            Q_tile = Q[:, :, q_start:q_end, :]

            # float32 accumulators for this query tile: unnormalised output,
            # running softmax denominator and running row maximum.
            o_i = torch.zeros((B, H, q_len, D_H), device=Q.device, dtype=torch.float32)
            l_i = torch.zeros((B, H, q_len), device=Q.device, dtype=torch.float32)
            m_i = torch.full((B, H, q_len), neg_inf, device=Q.device, dtype=torch.float32)

            # Query positions are the same for every key tile, so build them once.
            q_indices = None
            if is_causal:
                q_indices = torch.arange(q_start, q_end, device=Q.device).unsqueeze(1)  # [q_len, 1]

            # Inner loop over key/value tiles
            for k_start in range(0, N_K, K_TILE_SIZE):
                # Every key in this tile (and in all later tiles) lies strictly
                # after every query of the current tile: it would contribute
                # exactly zero, so stop here.
                if is_causal and k_start >= q_end:
                    break

                k_end = min(k_start + K_TILE_SIZE, N_K)
                K_tile = K[:, :, k_start:k_end, :]
                V_tile = V[:, :, k_start:k_end, :]

                # Scores are computed in the input dtype, then accumulated in float32.
                S_ij = ((Q_tile @ K_tile.transpose(-1, -2)) * scale).to(torch.float32)

                # 1. Causal masking. Only tiles that cross the diagonal need a
                #    mask; tiles entirely below it are fully visible.
                if is_causal and k_end - 1 > q_start:
                    k_indices = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0)  # [1, k_len]
                    S_ij = S_ij.masked_fill(q_indices < k_indices, neg_inf)

                # 2. New running maximum per query row.
                m_new = torch.maximum(m_i, S_ij.max(dim=-1).values)

                # A row that has seen no visible key yet still has m_new == -inf;
                # (-inf) - (-inf) would be NaN, so use 0 as the reference there.
                # All exponentials for such a row are exp(-inf) == 0 regardless.
                m_ref = torch.where(m_new == neg_inf, torch.zeros_like(m_new), m_new)

                # 3. Rescale the previous accumulators to the new maximum.
                alpha = torch.exp(m_i - m_ref)
                o_i = o_i * alpha.unsqueeze(-1)
                l_i = l_i * alpha

                # 4. Unnormalised probabilities of the current tile.
                P_ij = torch.exp(S_ij - m_ref.unsqueeze(-1))

                # 5. Accumulate this tile's contribution in float32.
                o_i = o_i + (P_ij @ V_tile.to(torch.float32))
                l_i = l_i + P_ij.sum(dim=-1)

                # 6. Carry the running maximum into the next iteration.
                m_i = m_new

            # Normalise; rows whose denominator is zero are written as zeros
            # instead of dividing by zero.
            l_i_reciprocal = torch.where(l_i > 0, 1.0 / l_i, torch.zeros_like(l_i))
            o_i_normalized = o_i * l_i_reciprocal.unsqueeze(-1)

            # Log-sum-exp of each score row: m + log(sum(exp(s - m))).
            L_tile = m_i + torch.log(l_i)

            # Write results for this tile back to the final output tensors
            O_final[:, :, q_start:q_end, :] = o_i_normalized.to(Q.dtype)
            L_final[:, :, q_start:q_end] = L_tile

        ctx.save_for_backward(Q, K, V, O_final, L_final)
        ctx.is_causal = is_causal

        return O_final, L_final

    @staticmethod
    def backward(ctx, grad_out, grad_L):
        """Backward pass placeholder; always raises NotImplementedError."""
        raise NotImplementedError("Backward pass not yet implemented for FlashAttention2Function")

# Cell summary: one line per message, emitted by a single print call.
print(
    "✅ FlashAttention2Function implemented successfully!",
    "🎯 Key features implemented:",
    "   - Tiled computation with Q_TILE_SIZE=128, K_TILE_SIZE=128",
    "   - Online softmax algorithm with running maximum",
    "   - Causal masking support for autoregressive models",
    "   - Memory-efficient O(N) implementation",
    "🔧 Fixed: Corrected online softmax algorithm per FlashAttention-2 paper",
    sep="\n",
)

In [None]:
# ============================================================================
# Block 3: Autograder Testing for Problem 1
# ============================================================================

# Helper functions and autograder implementation
def repeat_kv(x, num_groups):
    """
    Duplicate each K/V head ``num_groups`` times along the head axis.

    ``x`` is unpacked as (B, H_kv, N, D); the result has shape
    (B, H_kv * num_groups, N, D) with the copies of one head adjacent to each
    other. When ``num_groups`` is 1 the input tensor itself is returned.
    """
    if num_groups != 1:
        batch, kv_heads, seq, dim = x.shape
        # Add a group axis right after the head axis, broadcast it without
        # copying, then merge (head, group) into a single head axis.
        tiled = x[:, :, None].expand(batch, kv_heads, num_groups, seq, dim)
        x = tiled.flatten(start_dim=1, end_dim=2)
    return x

def create_mask_bool(
    seq_len: int,
    window_size: int,
    sink_size: int,
    device=None
    ) -> torch.Tensor:
    """
    Build a (seq_len, seq_len) boolean attention mask; True means "may attend".

    Query row i may attend to key column j when j <= i and either
    j lies within the last ``window_size`` positions ending at i, or
    j is one of the first ``sink_size`` positions.
    """
    positions = torch.arange(seq_len, device=device)
    query_pos = positions.unsqueeze(1)   # (seq_len, 1)
    key_pos = positions.unsqueeze(0)     # (1, seq_len)

    # Both the window and the sink only ever look backwards.
    not_future = key_pos <= query_pos

    # Sliding window: i - (window_size - 1) <= j
    in_window = key_pos >= query_pos - (window_size - 1)

    # Attention sink: j < sink_size
    is_sink = key_pos < sink_size

    return (not_future & in_window) | (is_sink & not_future)

def naive_attention(Q, K, V, is_causal=False, window_size=None, sink_size=None):
    """
    Reference attention that materialises the full score matrix.

    Q and K are unpacked as (batch, heads, seq_len, head_dim). If Q has more
    heads than K, the K/V heads are repeated with ``repeat_kv`` (GQA). Masking
    is applied only when ``is_causal`` is True: plain causal when
    ``window_size`` is None, otherwise a sliding window of ``window_size``
    plus, if ``sink_size`` is given, the first ``sink_size`` key positions.

    Returns ``(O, L)``: the attention output and the float32 row-wise
    log-sum-exp of the (masked) scaled scores.
    """
    _, q_heads, _, _ = Q.shape
    # The sequence length and head dim used below are the ones read from K.
    _, kv_heads, kv_len, dim = K.shape

    if q_heads != kv_heads:
        groups = q_heads // kv_heads
        K = repeat_kv(K, groups)
        V = repeat_kv(V, groups)

    scores = (Q @ K.transpose(-1, -2)) * (1.0 / math.sqrt(dim))

    if is_causal:
        if window_size is None:
            # Causal only: a window as long as the sequence, no sink.
            window, sink = kv_len, 0
        else:
            # Sliding window, with an optional attention sink.
            window = window_size
            sink = 0 if sink_size is None else sink_size
        visible = create_mask_bool(seq_len=kv_len, window_size=window, sink_size=sink, device=Q.device)
        scores.masked_fill_(~visible, -float('inf'))

    # Softmax is evaluated in float32 and cast back to the input dtype.
    probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(Q.dtype)
    attn_out = probs @ V
    lse = torch.logsumexp(scores.to(torch.float32), dim=-1)

    return attn_out, lse

def check_problem_1():
    """
    Checks Problem 1: PyTorch Tiled Attention.

    Runs ``FlashAttention2Function`` and ``naive_attention`` on random CUDA
    tensors for a fixed list of shapes, compares both the output ``O`` and the
    log-sum-exp ``L`` with ``torch.allclose`` (rtol = atol = 5e-2), and prints
    a pass/fail line per case. If every case passes, a short timing benchmark
    is printed as well.

    Returns:
        True when all correctness cases pass, False otherwise.
    """
    # Imported locally so the function does not rely on another cell having
    # imported ``time``.
    import time

    problem_num = 1
    print(f"\n--- Running Autograder for Problem {problem_num}: Tiled Flash Attention ---")

    torch.manual_seed(42)
    # Each case is (B, H, N_Q, N_K, D_H, is_causal).
    test_cases = [
        (1, 8, 512, 512, 16, False),
        (1, 8, 1024, 1024, 16, True),
        (1, 16, 2048, 2048, 16, True),
        (1, 16, 4096, 4096, 16, True),
    ]

    # Custom test runner for P1 which checks both O and L
    def run_p1_test(B, H, N_Q, N_K, D_H, is_causal):
        # NOTE(review): DTYPE is not defined in this part of the notebook;
        # presumably it is set in an earlier cell - confirm.
        q = torch.randn(B, H, N_Q, D_H, device='cuda', dtype=DTYPE)
        k = torch.randn(B, H, N_K, D_H, device='cuda', dtype=DTYPE)
        v = torch.randn(B, H, N_K, D_H, device='cuda', dtype=DTYPE)

        naive_O, naive_L = naive_attention(q, k, v, is_causal=is_causal)
        student_O, student_L = FlashAttention2Function.apply(q, k, v, is_causal)

        o_match = torch.allclose(naive_O, student_O, rtol=5e-2, atol=5e-2)
        l_match = torch.allclose(naive_L, student_L, rtol=5e-2, atol=5e-2)

        param_str = f"(B={B}, H={H}, Nq={N_Q}, Nk={N_K}, D={D_H}, Causal={is_causal})"
        if o_match and l_match:
            print(f"✅ P{problem_num} Correctness Test Passed! {param_str}")
            return True

        print(f"❌ P{problem_num} Correctness Test Failed! {param_str}")
        if not o_match:
            print(f"   Output 'O' mismatch. Max diff: {(naive_O - student_O).abs().max()}")
        if not l_match:
            print(f"   Logsumexp 'L' mismatch. Max diff: {(naive_L - student_L).abs().max()}")
        return False

    # A list (not a generator) so that every case runs even after a failure.
    results = [run_p1_test(*case) for case in test_cases]
    passed = all(results)

    if passed:
        print(f"\n🎉 All P{problem_num} correctness tests passed!")
        print("🚀 Your FlashAttention2 implementation is working correctly!")

        # Run a simple performance benchmark
        print("\n--- Performance Test ---")
        B, H, N, D = 1, 16, 2048, 16
        q = torch.randn(B, H, N, D, device='cuda', dtype=DTYPE)
        k = torch.randn(B, H, N, D, device='cuda', dtype=DTYPE)
        v = torch.randn(B, H, N, D, device='cuda', dtype=DTYPE)

        # Warmup
        for _ in range(3):
            _ = FlashAttention2Function.apply(q, k, v, True)

        # CUDA kernels run asynchronously: synchronise before reading the
        # clock on both sides of the timed region.
        num_runs = 10
        torch.cuda.synchronize()
        start_time = time.perf_counter()  # monotonic, high-resolution clock

        for _ in range(num_runs):
            output, logsumexp = FlashAttention2Function.apply(q, k, v, True)

        torch.cuda.synchronize()
        end_time = time.perf_counter()

        avg_time = (end_time - start_time) * 1000 / num_runs
        print(f"⚡ Average execution time: {avg_time:.2f} ms")
        print(f"📊 Test config: B={B}, H={H}, N={N}, D={D}, Causal=True")

    else:
        print("\n❌ Some tests failed. Please check your implementation.")

    return passed

# Run the autograder (GPU only: the checks allocate their tensors on 'cuda').
if not torch.cuda.is_available():
    print(
        "❌ CUDA not available. Cannot run GPU tests.",
        "Please ensure you're running on a CUDA-enabled environment.",
        sep="\n",
    )
else:
    print("🔥 Starting FlashAttention2 Autograder Tests...")
    success = check_problem_1()

    if not success:
        # Failure summary framed by 60-character rules.
        print(
            "\n" + "="*60,
            "❌ Some tests failed.",
            "Please review your implementation and try again.",
            "Check the error messages above for specific issues.",
            "="*60,
            sep="\n",
        )
    else:
        # Success summary framed by 60-character rules.
        print(
            "\n" + "="*60,
            "🎉 CONGRATULATIONS! 🎉",
            "Your FlashAttention2 implementation passes all tests!",
            "You've successfully implemented:",
            "✅ Tiled computation for memory efficiency",
            "✅ Online softmax algorithm",
            "✅ Causal masking for autoregressive models",
            "✅ Mathematical exactness with reference implementation",
            "="*60,
            sep="\n",
        )