# In [None]:
import triton
import triton.language as tl
import torch
import math
import torch
import torch.nn.functional as F
import time
import numpy as np
import matplotlib.pyplot as plt
import gc
from typing import Tuple, List
import psutil
import os

@triton.jit
def flash_attn_v2_fwd_kernel(
    Q, K, V, O, LSE,
    stride_qz, stride_qh, stride_qm, stride_qk,  # Q strides: Batch, Head, SeqLen, HeadDim
    stride_kz, stride_kh, stride_kn, stride_kk,  # K strides: Batch, Head, SeqLen, HeadDim
    stride_vz, stride_vh, stride_vn, stride_vk,  # V strides: Batch, Head, SeqLen, HeadDim
    stride_oz, stride_oh, stride_om, stride_ok,  # O strides: Batch, Head, SeqLen, HeadDim
    stride_lse_z, stride_lse_h, stride_lse_m,    # LSE strides: Batch, Head, SeqLen
    Z, H, N_CTX, D_HEAD,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  # Must equal D_HEAD: the head dimension is not tiled
    IF_CAUSAL_MASK: tl.constexpr,
):
    """Tiled attention forward pass with an online (streaming) softmax.

    Each program handles one block of BLOCK_SIZE_M query rows (grid axis 0)
    for one (batch, head) pair (grid axis 1). It walks over the keys/values
    in blocks of BLOCK_SIZE_N, keeping a running row maximum ``m_i``, a
    running softmax denominator ``l_i`` and an un-normalised output
    accumulator ``acc``. At the end it writes the normalised output to ``O``
    and the natural-log log-sum-exp of the scaled scores to ``LSE``.

    The scores are scaled by ``1 / sqrt(D_HEAD)``. Rows beyond ``N_CTX`` are
    computed on zero-filled queries but never stored.
    """
    # Program indices: query block along axis 0, fused (batch, head) along axis 1.
    start_m = tl.program_id(axis=0)
    batch_head_id = tl.program_id(axis=1)
    batch_id = batch_head_id // H
    head_id = batch_head_id % H

    # Query-row offsets handled by this program and head-dimension offsets.
    offs_m = start_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Q tile pointers, shape (BLOCK_SIZE_M, BLOCK_SIZE_K).
    q_ptrs = (Q + batch_id * stride_qz + head_id * stride_qh +
              offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)

    # K/V base pointers for this (batch, head); sequence offsets are added in the loop.
    k_base_ptr = K + batch_id * stride_kz + head_id * stride_kh
    v_base_ptr = V + batch_id * stride_vz + head_id * stride_vh

    # Online-softmax state, kept in float32 for accuracy.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
    m_i = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - float('inf')
    l_i = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)

    # Softmax scale 1 / sqrt(D_HEAD).
    qk_scale = tl.rsqrt(D_HEAD.to(tl.float32))

    # Load the Q tile once; rows past the end of the sequence read as zero.
    q_mask = offs_m[:, None] < N_CTX
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # With a causal mask no key beyond the last query row of this block can
    # contribute, so stop there -- but never run past the real sequence end.
    end_n = N_CTX
    if IF_CAUSAL_MASK:
        end_n = tl.minimum((start_m + 1) * BLOCK_SIZE_M, N_CTX)

    for start_n in range(0, end_n, BLOCK_SIZE_N):
        offs_n = start_n + tl.arange(0, BLOCK_SIZE_N)

        # K and V tiles are both addressed as (BLOCK_SIZE_N, BLOCK_SIZE_K);
        # K is transposed in registers right before the matmul.
        k_ptrs = (k_base_ptr +
                  offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = (v_base_ptr +
                  offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)

        # The mask must follow the tile layout: sequence positions are the rows.
        kv_mask = offs_n[:, None] < N_CTX
        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

        # Scaled scores, shape (BLOCK_SIZE_M, BLOCK_SIZE_N).
        qk = tl.dot(q, tl.trans(k)) * qk_scale

        # Padded key columns were loaded as zeros; without this they would
        # contribute exp(0 - m) to the denominator whenever N_CTX is not a
        # multiple of BLOCK_SIZE_N.
        qk = tl.where(offs_n[None, :] < N_CTX, qk, float('-inf'))

        if IF_CAUSAL_MASK:
            # Query row i may only attend to key columns j <= i.
            qk = tl.where(offs_m[:, None] >= offs_n[None, :], qk, float('-inf'))

        # --- Online softmax update ---
        # New running maximum per row.
        m_i_new = tl.maximum(m_i, tl.max(qk, axis=1))
        # Natural exponent so that the result is a true softmax and
        # LSE = m + log(l) below is in consistent (natural-log) units.
        p_ij = tl.exp(qk - m_i_new[:, None])
        # Factor that re-bases the previous partial sums onto the new maximum.
        scale = tl.exp(m_i - m_i_new)

        acc = acc * scale[:, None]
        acc += tl.dot(p_ij.to(v.dtype), v)
        l_i = l_i * scale + tl.sum(p_ij, axis=1)
        m_i = m_i_new

    # --- Epilogue ---
    # Log-sum-exp of the scaled scores per row.
    lse_final = m_i + tl.log(l_i)
    # Normalise the accumulator by the softmax denominator.
    acc_o = acc / l_i[:, None]

    o_ptrs = (O + batch_id * stride_oz + head_id * stride_oh +
              offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)
    o_mask = offs_m[:, None] < N_CTX
    # Cast to the element type of the output buffer before storing.
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty), mask=o_mask)

    lse_ptrs = (LSE + batch_id * stride_lse_z + head_id * stride_lse_h +
                offs_m * stride_lse_m)
    lse_mask = offs_m < N_CTX
    tl.store(lse_ptrs, lse_final, mask=lse_mask)


# --- Wrapper Function (Example) ---
def flash_attn_v2_fwd(q, k, v, causal=False):
    """Launch ``flash_attn_v2_fwd_kernel`` on (Z, H, N, D) tensors.

    Args:
        q, k, v: 4-D tensors whose batch, head, sequence and head-dim sizes
            must all match (checked with assertions).
        causal: forwarded to the kernel as ``IF_CAUSAL_MASK``.

    Returns:
        ``(o, lse)`` where ``o`` has the shape and dtype of ``q`` and ``lse``
        is a float32 tensor of shape (Z, H, N).
    """
    Z, H, N_CTX, D_HEAD = q.shape
    assert D_HEAD == k.shape[-1] and D_HEAD == v.shape[-1]
    for other in (k, v):
        assert other.shape[0] == Z and other.shape[1] == H and other.shape[2] == N_CTX

    # Buffers the kernel writes into.
    o = torch.empty_like(q)
    lse = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32)

    # Fixed tile sizes; the head dimension is processed in a single tile.
    BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64
    BLOCK_SIZE_K = D_HEAD

    # Only warn for unusual head sizes; the launch is still attempted.
    if D_HEAD not in (16, 32, 64, 128):
         print(f"Warning: D_HEAD={D_HEAD} might not be optimal or supported by all Triton configurations.")
         assert BLOCK_SIZE_K == D_HEAD, "This kernel requires BLOCK_SIZE_K == D_HEAD"

    # One program per (query block, batch * head).
    grid = (triton.cdiv(N_CTX, BLOCK_SIZE_M), Z * H)

    # Strides in the order the kernel signature expects: q, k, v, o, then lse.
    stride_args = [t.stride(d) for t in (q, k, v, o) for d in range(4)]
    stride_args += [lse.stride(d) for d in range(3)]

    flash_attn_v2_fwd_kernel[grid](
        q, k, v, o, lse,
        *stride_args,
        Z, H, N_CTX, D_HEAD,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        IF_CAUSAL_MASK=causal,
    )

    return o, lse

@triton.jit
def flash_attn_v1_fwd_kernel(
    Q, K, V, O, M, L,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_mz, stride_mh, stride_mm,
    stride_lz, stride_lh, stride_lm,
    Z, H, N_CTX,
    SOFTMAX_SCALE,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_HEAD_DIM: tl.constexpr,
    IF_CAUSAL_MASK: tl.constexpr,
):
    """Tiled attention forward pass that also exports the softmax statistics.

    Each program handles BLOCK_SIZE_M query rows (grid axis 0) of one
    (batch, head) pair (grid axis 1) and streams over keys/values in blocks
    of BLOCK_SIZE_N using an online softmax. It stores the normalised output
    in ``O``, the final running row maximum in ``M`` and the final softmax
    denominator in ``L``. Scores are multiplied by ``SOFTMAX_SCALE``.
    """
    start_m = tl.program_id(0)
    batch_head_id = tl.program_id(1)

    batch_id = batch_head_id // H
    head_id = batch_head_id % H

    offs_m = start_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = tl.arange(0, BLOCK_SIZE_HEAD_DIM)

    q_ptrs = (Q + batch_id * stride_qz + head_id * stride_qh +
              offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)

    # Rows past the end of the sequence read as zero and are never stored.
    q_mask = offs_m[:, None] < N_CTX
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # Online-softmax state, kept in float32.
    l_i = tl.zeros([BLOCK_SIZE_M], dtype=tl.float32)
    m_i = tl.zeros([BLOCK_SIZE_M], dtype=tl.float32) + float('-inf')
    acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_HEAD_DIM], dtype=tl.float32)

    # With a causal mask no key beyond the last query row of this block can
    # contribute, so stop there -- but never run past the real sequence end.
    end_n = N_CTX
    if IF_CAUSAL_MASK:
        end_n = tl.minimum((start_m + 1) * BLOCK_SIZE_M, N_CTX)

    for start_n in range(0, end_n, BLOCK_SIZE_N):
        offs_n = start_n + tl.arange(0, BLOCK_SIZE_N)

        # K is addressed already transposed: (BLOCK_SIZE_HEAD_DIM, BLOCK_SIZE_N).
        k_ptrs = (K + batch_id * stride_kz + head_id * stride_kh +
                  offs_k[:, None] * stride_kk + offs_n[None, :] * stride_kn)
        # V is addressed as (BLOCK_SIZE_N, BLOCK_SIZE_HEAD_DIM).
        v_ptrs = (V + batch_id * stride_vz + head_id * stride_vh +
                  offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)

        # Sequence positions are the columns of the K tile and the rows of the V tile.
        k_mask = offs_n[None, :] < N_CTX
        v_mask = offs_n[:, None] < N_CTX

        k = tl.load(k_ptrs, mask=k_mask, other=0.0)
        v = tl.load(v_ptrs, mask=v_mask, other=0.0)

        # Scaled scores, shape (BLOCK_SIZE_M, BLOCK_SIZE_N).
        qk = tl.dot(q, k)
        qk *= SOFTMAX_SCALE

        # Padded key columns were loaded as zeros; without this they would
        # contribute exp(0 - m) to the denominator whenever N_CTX is not a
        # multiple of BLOCK_SIZE_N.
        qk = tl.where(offs_n[None, :] < N_CTX, qk, float('-inf'))

        if IF_CAUSAL_MASK:
            # Query row i may only attend to key columns j <= i.
            qk = tl.where(offs_m[:, None] >= offs_n[None, :], qk, float('-inf'))

        # Online softmax: re-base the previous partial sums onto the new maximum.
        m_i_new = tl.maximum(m_i, tl.max(qk, axis=1))
        p_ij = tl.exp(qk - m_i_new[:, None])
        scale = tl.exp(m_i - m_i_new)

        acc = acc * scale[:, None]
        acc += tl.dot(p_ij.to(v.dtype), v)

        l_i = l_i * scale + tl.sum(p_ij, axis=1)
        m_i = m_i_new

    # Output and statistics pointers.
    O_ptrs = (O + batch_id * stride_oz + head_id * stride_oh +
              offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)
    M_ptrs = (M + batch_id * stride_mz + head_id * stride_mh + offs_m * stride_mm)
    L_ptrs = (L + batch_id * stride_lz + head_id * stride_lh + offs_m * stride_lm)

    # Normalise the accumulator by the softmax denominator.
    acc_o = acc / l_i[:, None]

    o_mask = offs_m[:, None] < N_CTX
    m_mask = offs_m < N_CTX

    # Cast explicitly to the element type of the output buffer.
    tl.store(O_ptrs, acc_o.to(O.dtype.element_ty), mask=o_mask)
    tl.store(M_ptrs, m_i, mask=m_mask)
    tl.store(L_ptrs, l_i, mask=m_mask)


def flash_attention_v1_forward(q, k, v, causal=False):
    """Launch ``flash_attn_v1_fwd_kernel`` on 4-D (batch, heads, seq, dim) tensors.

    Args:
        q, k, v: input tensors; sizes are taken from ``q.shape``.
        causal: forwarded to the kernel as ``IF_CAUSAL_MASK``.

    Returns:
        ``(output, m, l)``: ``output`` has the shape and dtype of ``q``;
        ``m`` and ``l`` are float32 tensors of shape (batch, heads, seq)
        filled by the kernel.
    """
    batch_size, num_heads, seq_len, head_dim = q.shape

    # Tile sizes are fixed here rather than chosen by a tuning heuristic.
    block_m = 64
    block_n = 64

    # Buffers the kernel writes into.
    output = torch.empty_like(q)
    stats_shape = (batch_size, num_heads, seq_len)
    m = torch.empty(stats_shape, device=q.device, dtype=torch.float32)
    l = torch.empty(stats_shape, device=q.device, dtype=torch.float32)

    softmax_scale = 1 / math.sqrt(head_dim)

    # One program per (query block, batch * head).
    grid = (triton.cdiv(seq_len, block_m), batch_size * num_heads)

    # Strides in the order the kernel signature expects: q, k, v, output, m, l.
    stride_args = [t.stride(d) for t in (q, k, v, output) for d in range(4)]
    stride_args += [t.stride(d) for t in (m, l) for d in range(3)]

    flash_attn_v1_fwd_kernel[grid](
        q, k, v, output, m, l,
        *stride_args,
        batch_size, num_heads, seq_len,
        SOFTMAX_SCALE=softmax_scale,
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_N=block_n,
        BLOCK_SIZE_HEAD_DIM=head_dim,
        IF_CAUSAL_MASK=causal,
    )

    return output, m, l


# In [None]:

# Benchmark utilities
def get_gpu_memory():
    """Return the memory currently allocated on the CUDA device, in MB.

    Returns 0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    bytes_per_mb = 1024**2
    return torch.cuda.memory_allocated() / bytes_per_mb

def benchmark_function(func, *args, num_warmup=5, num_runs=10):
    """Time ``func(*args)`` and report its peak extra GPU memory.

    The function is called ``num_warmup`` times untimed, then ``num_runs``
    times under a wall-clock timer, all inside ``torch.no_grad()``. When CUDA
    is available the device is synchronised around the timed region and the
    peak allocation above the pre-run baseline is reported; without CUDA the
    memory figure is 0.

    Args:
        func: callable to benchmark.
        *args: positional arguments forwarded to ``func``.
        num_warmup: number of untimed warm-up calls.
        num_runs: number of timed calls; must be at least 1.

    Returns:
        ``(avg_time, memory_used, result)``: average seconds per call, peak
        memory above the baseline in MB, and the value returned by the last
        timed call.

    Raises:
        ValueError: if ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    use_cuda = torch.cuda.is_available()

    # Warm-up: triggers JIT compilation and caches so they are not timed.
    with torch.no_grad():
        for _ in range(num_warmup):
            func(*args)
    if use_cuda:
        torch.cuda.synchronize()
        # A before/after delta of memory_allocated() only sees the tensor that
        # is still alive afterwards; tracking the peak also captures
        # intermediates such as a materialised score matrix.
        torch.cuda.reset_peak_memory_stats()
        baseline_bytes = torch.cuda.memory_allocated()

    result = None
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            # Release the previous output first so the peak reflects one call.
            result = None
            result = func(*args)
    if use_cuda:
        # Kernels are launched asynchronously; wait before stopping the clock.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

    if use_cuda:
        memory_used = (torch.cuda.max_memory_allocated() - baseline_bytes) / 1024**2
    else:
        memory_used = 0

    return elapsed / num_runs, memory_used, result



# Attention implementations to compare
def pytorch_attention(q, k, v, causal=False):
    """Reference attention that materialises the full score matrix.

    Computes ``softmax(q @ k^T / sqrt(d)) @ v`` over the last two dimensions,
    where ``d`` is the size of the last dimension of ``q``. With
    ``causal=True`` every position above the main diagonal of the score
    matrix is set to -inf before the softmax.
    """
    head_dim = q.size(-1)
    scale = 1.0 / (head_dim ** 0.5)
    logits = torch.matmul(q, k.transpose(-2, -1)) * scale

    if causal:
        n = q.size(-2)
        # True strictly above the diagonal, i.e. for "future" key positions.
        future = torch.triu(torch.ones(n, n, device=q.device), diagonal=1).bool()
        logits.masked_fill_(future, float('-inf'))

    return torch.matmul(F.softmax(logits, dim=-1), v)

def pytorch_sdpa(q, k, v, causal=False):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    ``causal`` is forwarded as the ``is_causal`` flag.
    """
    attended = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return attended

def triton_flash_attention_v2(q, k, v, causal=False):
    """Run ``flash_attn_v2_fwd`` and return only the attention output.

    The log-sum-exp tensor that ``flash_attn_v2_fwd`` also returns is dropped.
    """
    attn_out, _lse = flash_attn_v2_fwd(q, k, v, causal=causal)
    return attn_out

def triton_flash_attention_v1(q, k, v, causal=False):
    """Run ``flash_attention_v1_forward`` and return only the attention output.

    The two statistics tensors that ``flash_attention_v1_forward`` also
    returns are dropped.
    """
    attn_out, _row_max, _row_sum = flash_attention_v1_forward(q, k, v, causal=causal)
    return attn_out

class AttentionBenchmark:
    """Runs several attention implementations and collects time/memory figures.

    ``self.implementations`` maps a display name to a callable taking
    ``(q, k, v, causal)``. ``self.results`` maps the same names to three
    parallel lists: ``'times'`` (milliseconds), ``'memory'`` (MB) and
    ``'seq_lens'``. A failed run is recorded as NaN for time and memory.
    """

    def __init__(self):
        self.implementations = {
            'PyTorch Standard': pytorch_attention,
            'PyTorch SDPA': pytorch_sdpa,
            'Triton Flash Attn V1': triton_flash_attention_v1,
            'Triton Flash Attn V2': triton_flash_attention_v2
        }
        # One independent record per implementation.
        self.results = {}
        for impl_name in self.implementations:
            self.results[impl_name] = {'times': [], 'memory': [], 'seq_lens': []}

    def run_benchmark(self, seq_lens: List[int], batch_size: int = 2,
                     num_heads: int = 8, head_dim: int = 64,
                     causal: bool = True, dtype=torch.float16):
        """Benchmark every implementation at each sequence length.

        Random q/k/v tensors of shape (batch_size, num_heads, seq_len,
        head_dim) are created on the 'cuda' device for each length. Results
        are appended to ``self.results``; an implementation that raises is
        reported and recorded as NaN so the lists stay aligned.
        """
        print(f"Running benchmark with:")
        print(f"  Batch size: {batch_size}")
        print(f"  Num heads: {num_heads}")
        print(f"  Head dim: {head_dim}")
        print(f"  Causal: {causal}")
        print(f"  Data type: {dtype}")
        print("-" * 50)

        def log_point(impl_name, time_ms, memory_mb, length):
            # Append one data point to all three parallel lists.
            record = self.results[impl_name]
            record['times'].append(time_ms)
            record['memory'].append(memory_mb)
            record['seq_lens'].append(length)

        for seq_len in seq_lens:
            print(f"\nSequence length: {seq_len}")

            # Fresh inputs per sequence length, drawn in q, k, v order.
            shape = (batch_size, num_heads, seq_len, head_dim)
            q, k, v = (
                torch.randn(*shape, dtype=dtype, device='cuda', requires_grad=False)
                for _ in range(3)
            )

            for name, func in self.implementations.items():
                try:
                    # Start each measurement from a clean allocator state.
                    torch.cuda.empty_cache()
                    gc.collect()

                    print(f"  Testing {name}...")
                    avg_time, memory_used, _ = benchmark_function(func, q, k, v, causal)

                    log_point(name, avg_time * 1000, memory_used, seq_len)

                    print(f"    Time: {avg_time*1000:.2f}ms, Memory: {memory_used:.4f}MB")

                except Exception as e:
                    print(f"    Error: {str(e)}")
                    # Record a placeholder so every list keeps one entry per run.
                    log_point(name, float('nan'), float('nan'), seq_len)

            # Free the inputs before moving to the next length.
            del q, k, v
            torch.cuda.empty_cache()

    def plot_results(self):
        """Plot time and memory against sequence length, save, show, summarise.

        The figure is written to 'attention_benchmark.png'. NaN entries are
        skipped. ``print_summary`` is called at the end.
        """
        _, (time_ax, mem_ax) = plt.subplots(1, 2, figsize=(15, 6))

        colors = ['blue', 'red', 'green', 'purple']
        markers = ['o', 's', '^', 'D']

        def draw_series(axis, metric):
            # One line per implementation, using only its non-NaN points.
            for i, (name, data) in enumerate(self.results.items()):
                points = [
                    (length, value)
                    for length, value in zip(data['seq_lens'], data[metric])
                    if not np.isnan(value)
                ]
                if not points:
                    continue
                axis.plot([p[0] for p in points], [p[1] for p in points],
                          marker=markers[i % len(markers)],
                          color=colors[i % len(colors)],
                          label=name, linewidth=2, markersize=6)
            axis.legend()

        # Left panel: execution time on log-log axes.
        time_ax.set_title('Execution Time vs Sequence Length')
        time_ax.set_xlabel('Sequence Length')
        time_ax.set_ylabel('Time (ms)')
        time_ax.set_xscale('log', base=2)
        time_ax.set_yscale('log')
        time_ax.grid(True, alpha=0.3)
        draw_series(time_ax, 'times')

        # Right panel: memory usage, log scale on x only.
        mem_ax.set_title('Memory Usage vs Sequence Length')
        mem_ax.set_xlabel('Sequence Length')
        mem_ax.set_ylabel('Memory Usage (MB)')
        mem_ax.set_xscale('log', base=2)
        mem_ax.grid(True, alpha=0.3)
        draw_series(mem_ax, 'memory')

        plt.tight_layout()
        plt.savefig('attention_benchmark.png', dpi=300, bbox_inches='tight')
        plt.show()

        self.print_summary()

    def print_summary(self):
        """Print a comparison table at the largest fully successful length.

        A sequence length qualifies when every implementation recorded it
        and none of the times at that position is NaN. Speedup is the first
        implementation's time divided by the row's time.
        """
        rule = "=" * 80
        print("\n" + rule)
        print("BENCHMARK SUMMARY")
        print(rule)

        # Lengths are enumerated from the 'PyTorch Standard' record.
        reference_lens = self.results['PyTorch Standard']['seq_lens']
        common_seq_lens = []
        for seq_len in reference_lens:
            if not all(seq_len in self.results[name]['seq_lens'] for name in self.results.keys()):
                continue
            # The reference index is reused for every record's 'times' list.
            seq_idx = reference_lens.index(seq_len)
            if any(np.isnan(self.results[name]['times'][seq_idx]) for name in self.results.keys()):
                continue
            common_seq_lens.append(seq_len)

        if common_seq_lens:
            seq_len = max(common_seq_lens)
            print(f"\nComparison at sequence length {seq_len}:")
            print("-" * 50)

            baseline_time = None
            for name, data in self.results.items():
                seq_idx = data['seq_lens'].index(seq_len)
                time_ms = data['times'][seq_idx]
                memory_mb = data['memory'][seq_idx]

                # The first implementation visited defines the baseline.
                is_baseline = baseline_time is None
                if is_baseline:
                    baseline_time = time_ms
                speedup = 1.0 if is_baseline else baseline_time / time_ms

                print(f"{name:20} | {time_ms:8.2f}ms | {memory_mb:8.1f}MB | {speedup:.2f}x speedup")

        print("\nNote: Speedup is relative to the first implementation")

def main():
    """Run the benchmark suite end to end, then plot and summarise.

    Returns immediately with a message when CUDA is unavailable. Otherwise
    runs the standard sequence lengths, then each extended length in its own
    guarded call, and finally plots the results. Any exception from the
    benchmark phase is printed with its traceback instead of propagating.
    """
    import traceback

    if not torch.cuda.is_available():
        print("CUDA is not available. This benchmark requires GPU.")
        return

    divider = "=" * 50
    print("Flash Attention Benchmark Suite")
    print(divider)
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")

    # Standard sweep.
    seq_lens = [128, 256, 512, 1024, 2048]

    # Long sequences are run one at a time with a smaller batch and fewer
    # heads, so that a failure at one length does not stop the others.
    extended_seq_lens = [8192, 50000]

    benchmark = AttentionBenchmark()

    try:
        benchmark.run_benchmark(seq_lens, batch_size=2, num_heads=8, head_dim=64, causal=True)

        print("\n" + divider)
        print("Testing extended sequence lengths (some may fail due to memory)")
        print(divider)

        for long_len in extended_seq_lens:
            try:
                benchmark.run_benchmark([long_len], batch_size=1, num_heads=4, head_dim=64, causal=True)
            except Exception as e:
                print(f"Extended test failed at seq_len={long_len}: {str(e)}")

        benchmark.plot_results()

    except Exception as e:
        print(f"Benchmark failed: {str(e)}")
        traceback.print_exc()

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()