From 4955b8b14580c5bda0eb77ce042b1ddf1914d109 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 27 Jun 2025 14:50:15 +0800 Subject: [PATCH] Update integration --- docs/integration.md | 1574 +++++++++++++++++++++++++++---------------- 1 file changed, 980 insertions(+), 594 deletions(-) diff --git a/docs/integration.md b/docs/integration.md index 0fb9aa6..80351ab 100644 --- a/docs/integration.md +++ b/docs/integration.md @@ -1,614 +1,1000 @@ -# Flash Attention and Dynamic Mask Attention Integration +# Flash Dynamic Mask Attention Integration Guide -## Table of Contents -1. Introduction -2. Flash Attention Algorithm - - Key Concepts - - Core Algorithm Overview - - Implementation Details - - Performance Characteristics -3. Dynamic Mask Attention Algorithm - - Key Concepts - - Core Algorithm Overview - - Implementation Details - - Performance Characteristics -4. Comparative Analysis - - Standard Attention vs. Flash Attention vs. DMA - - Memory and Computational Complexity - - Use Cases and Tradeoffs -5. Flash-DMA Integration - - Motivation and Benefits - - Architectural Overview - - Dynamic Mask Processing Integration - - Sparse Attention Weight Computation - - Memory Management Strategies -6. Technical Implementation Details - - ZOH States and Active Mask Preprocessing - - Global to MMA Format Conversion - - Mask Application in Attention Computation - - Sparse Matrix Multiplication -7. Optimization Strategies - - Memory Access Patterns - - Warp Efficiency and Load Balancing - - Numerical Stability Considerations -8. Integration Architecture - - Data Flow Pipeline - - Component Interaction - - Kernel Modifications -9. Performance Expectations - - Theoretical Analysis - - Benchmarking Strategy - - Validation Framework -10. Conclusion - - Summary of Benefits - - Future Directions - -## Introduction - -This document provides a comprehensive analysis of the integration between Flash Attention and Dynamic Mask Attention (DMA) algorithms. The integration combines the memory efficiency of Flash Attention with the computational efficiency of DMA to create a high-performance attention mechanism specifically designed for processing extremely long sequences. - -The core innovation lies in incorporating dynamic masking and sparse computation capabilities into Flash Attention's block-based processing framework. This hybrid approach maintains Flash Attention's O(N) memory complexity while achieving DMA's O(N*k) computational complexity, where k represents the number of selected keys per query. - -By leveraging both pre-computed dynamic masks and sparse matrix multiplication techniques, the integrated system can efficiently handle sequences of unprecedented length while maintaining numerical accuracy and computational throughput. - -## Flash Attention Algorithm - -### Key Concepts - -Flash Attention revolutionizes attention computation through several key innovations: - -1. **Block-wise Processing**: Divides the attention computation into manageable blocks that fit in GPU shared memory, eliminating the need to materialize the full attention matrix. - -2. **Online Softmax Algorithm**: Computes softmax incrementally as blocks are processed, maintaining numerical stability through careful tracking of maximum values and normalization constants. - -3. **Shared Memory Optimization**: Utilizes GPU shared memory efficiently through strategic data tiling and reuse patterns, minimizing expensive global memory accesses. - -4. **Precision Management**: Performs accumulation in higher precision (FP32) while storing intermediate results in lower precision (FP16/BF16) to balance accuracy and memory usage. - -5. **Log-Sum-Exp Tracking**: Maintains running statistics for numerically stable softmax computation across blocks. - -### Core Algorithm Overview - -Flash Attention processes attention computation in the following phases: - -1. **Initialization**: Allocate shared memory for query, key, and value blocks, initialize output accumulators and softmax statistics. - -2. **Block-wise Iteration**: For each query block, iterate through all key-value blocks: - - Load current blocks into shared memory - - Compute attention scores through matrix multiplication - - Apply causal masking if required - - Update softmax statistics and output accumulation - -3. **Normalization**: Apply final normalization using accumulated log-sum-exp values to produce the correct attention output. - -The algorithm's key insight is that by maintaining sufficient statistics (maximum values and exponential sums), it can produce identical results to standard attention without ever materializing the complete attention matrix. - -### Implementation Details - -Flash Attention employs sophisticated memory management and computational strategies: - -#### Memory Layout and Tiling - -The implementation uses carefully designed tensor layouts that maximize shared memory utilization. Query, key, and value tensors are partitioned into blocks that fit within the GPU's shared memory constraints, with swizzling patterns applied to minimize bank conflicts. - -#### Block Processing Strategy - -The computation is organized into two distinct phases: -- **Masking Phase**: Processes blocks that require causal or other forms of masking -- **Non-masking Phase**: Processes remaining blocks without masking overhead - -This separation optimizes performance by avoiding unnecessary conditional operations where possible. - -#### Online Softmax Implementation - -The online softmax algorithm is critical for memory efficiency. For each processed block, the algorithm: -- Updates the running maximum value across all processed elements -- Rescales previously accumulated values when the maximum changes -- Computes normalized attention weights for the current block -- Updates the running sum of exponentials for final normalization - -### Performance Characteristics - -Flash Attention achieves significant performance improvements: - -1. **Memory Complexity**: Reduces from O(N²) to O(N) for sequence length N -2. **Memory Bandwidth**: Optimized access patterns achieve near-peak bandwidth utilization -3. **Throughput**: Delivers 2-4x speedup over standard implementations for long sequences -4. **Scalability**: Performance gains increase with sequence length -5. **Accuracy**: Produces bit-exact results compared to standard attention - -## Dynamic Mask Attention Algorithm - -### DMA Key Concepts - -Dynamic Mask Attention introduces computational efficiency through selective processing: - -1. **Adaptive Key Selection**: Dynamically determines which keys are most relevant for each query based on learned importance criteria. - -2. **Zero-Order Hold (ZOH) States**: Computes importance scores through learned transformations of value states, creating dynamic attention masks. - -3. **Top-K Filtering**: Selects only the most important keys for attention computation, dramatically reducing computational requirements. - -4. **Sparse Computation**: Performs attention computation only with selected keys, avoiding unnecessary operations. - -5. **Content-Adaptive Processing**: Selection patterns adapt to input content, providing better focus than static sparse patterns. - -### DMA Core Algorithm - -The DMA algorithm consists of the following computational stages: - -1. **Importance Score Generation**: Transform value states through learned projections to generate importance scores for each key. - -2. **Activation and Scaling**: Apply activation functions (softplus) and learned scaling factors to create zero-order hold states. - -3. **Dynamic Mask Creation**: Select top-k keys based on importance scores, creating sparse attention masks that vary per query. - -4. **Mask Application**: Apply causal and padding masks to the dynamic masks as needed. - -5. **Sparse Attention Computation**: Compute attention only for selected keys, using sparse matrix multiplication techniques. - -6. **Output Generation**: Produce final attention outputs through weighted combination of selected values. - -### DMA Implementation Details - -Dynamic Mask Attention requires careful handling of sparse data structures and irregular computation patterns: - -#### Importance Score Computation - -The transformation from value states to importance scores involves learned linear projections followed by activation functions. This stage determines which keys will be selected for each query, making it critical for both accuracy and efficiency. - -#### Top-K Selection Strategy - -The selection process uses efficient sorting algorithms to identify the most important keys. The implementation must handle variable sparsity patterns and ensure consistent results across different hardware configurations. - -#### Sparse Data Management - -Managing sparse attention patterns requires efficient data structures for storing selected indices and values. The implementation must balance memory usage with access efficiency. - -### DMA Performance Characteristics - -Dynamic Mask Attention offers distinct computational advantages: - -1. **Computational Complexity**: Reduces from O(N²) to O(N*k) where k << N -2. **Adaptive Efficiency**: Performance scales with actual content complexity rather than sequence length -3. **Memory Access**: Sparse patterns reduce memory bandwidth requirements -4. **Scalability**: Benefits increase significantly with sequence length -5. **Content Sensitivity**: Focuses computational resources on relevant information - -## Comparative Analysis - -### Standard Attention vs. Flash Attention vs. DMA - -| Feature | Standard Attention | Flash Attention | Dynamic Mask Attention | -|---------|-------------------|----------------|------------------------| -| Memory Complexity | O(N²) | O(N) | O(N²) in naive form, O(N) optimized | -| Computational Complexity | O(N²*D) | O(N²*D) | O(N*k*D) | -| Processing Strategy | Dense matrix operations | Block-wise dense operations | Sparse selection and computation | -| Memory Bandwidth | High (full matrix) | Optimized (block reuse) | Reduced (sparse access) | -| Adaptability | Fixed | Fixed | Content-adaptive | -| Implementation Complexity | Low | Medium | High | - -### Memory and Computational Complexity - -**Memory Usage Analysis:** -- Standard Attention: Requires full N×N attention matrix storage -- Flash Attention: Uses O(N) memory through block-wise processing -- Integrated Flash-DMA: Maintains O(N) memory while adding sparse indexing overhead - -**Computational Analysis:** -- Standard Attention: Performs full dense matrix multiplications -- Flash Attention: Same operations but with optimized memory access -- Integrated Flash-DMA: Reduces operations through sparsity while maintaining memory efficiency - -### Use Cases and Tradeoffs - -**Standard Attention:** -- Best for short sequences where simplicity is valued -- Suitable when memory is abundant -- Optimal for debugging and reference implementations - -**Flash Attention:** -- Ideal for medium to long sequences -- When memory is the primary bottleneck -- Requires exact attention computation - -**Dynamic Mask Attention:** -- Optimal for very long sequences -- When computational cost is prohibitive -- Acceptable quality-performance tradeoffs - -**Integrated Flash-DMA:** -- Best for extremely long sequences (10K+ tokens) -- When both memory and computation are constraints -- Applications requiring adaptive attention patterns - -## Flash-DMA Integration - -### Motivation and Benefits - -The integration of Flash Attention and Dynamic Mask Attention creates a synergistic combination that addresses the limitations of each individual approach: - -1. **Complementary Optimization**: Flash Attention optimizes memory usage while DMA reduces computational requirements. - -2. **Extended Sequence Support**: Combined approach enables processing of sequences exceeding 100K tokens. - -3. **Adaptive Efficiency**: Maintains Flash Attention's memory efficiency while adding content-adaptive computation. - -4. **Hardware Utilization**: Maximizes GPU utilization through optimized memory access patterns and reduced computation. - -5. **Scalability**: Provides better scaling characteristics than either approach alone. - -### Architectural Overview - -The integrated Flash-DMA architecture modifies Flash Attention's core algorithm in two primary ways: - -1. **Dynamic Mask Integration**: Incorporates pre-computed ZOH states and active masks into the block-wise processing pipeline. - -2. **Sparse Computation**: Implements sparse matrix multiplication within the existing MMA (Matrix Multiply Accumulate) framework. - -The integration maintains Flash Attention's fundamental block-based structure while adding dynamic masking capabilities at the attention score computation level. - -### Dynamic Mask Processing Integration - -The dynamic mask processing is integrated into Flash Attention through several key components: - -#### Pre-computation Phase - -Before kernel execution, the Python frontend computes: -- **ZOH States**: Importance scores derived from value state transformations -- **Active Masks**: Binary masks indicating which keys are selected for each query -- **Index Maps**: Efficient representations of sparse patterns for GPU processing - -#### Format Conversion - -Global tensors are converted to MMA-compatible formats: -- **ZOH States**: Transformed from 1D global format to 3D MMA layout tensors -- **Active Masks**: Converted from global boolean masks to MMA-structured masks -- **Layout Adaptation**: Ensures compatibility with Flash Attention's tensor layouts - -#### Mask Application - -Dynamic masks are applied during attention score computation: -- **Score Scaling**: ZOH states are added to attention scores before softmax -- **Sparsity Enforcement**: Active masks eliminate computation for unselected keys -- **Numerical Stability**: Maintains Flash Attention's numerical properties - -### Sparse Attention Weight Computation - -The sparse computation integration modifies Flash Attention's matrix multiplication pipeline: - -#### Sparse MMA Operations - -Traditional dense matrix multiplications are replaced with sparse variants: -- **Key Selection**: Only selected keys participate in attention score computation -- **Value Aggregation**: Sparse attention weights are applied to corresponding values -- **Accumulation**: Results are accumulated using modified patterns that respect sparsity - -#### Efficiency Optimizations +## Overview -Several optimizations ensure efficient sparse computation: -- **Warp-level Coordination**: Threads within warps coordinate to handle irregular sparsity patterns -- **Load Balancing**: Work distribution adapts to varying sparsity levels across blocks -- **Memory Access**: Sparse access patterns are optimized for GPU memory hierarchy +This document describes the integration of Dynamic Mask Attention into the Flash Attention framework. The integration enables efficient sparse attention computation by combining Flash Attention's memory-efficient approach with dynamic masking capabilities for handling extremely long sequences. -### Memory Management Strategies +The integration implements a two-stage approach: Python frontend pre-computes Zero-Order Hold states and Active Mask tensors, while the CUDA backend performs sparse attention computation using these pre-computed masks. -The integrated system employs sophisticated memory management: - -#### Shared Memory Allocation - -Shared memory is allocated to accommodate: -- **Original Flash Attention Data**: Query, key, and value blocks -- **Mask Information**: Active mask data for current blocks -- **Index Structures**: Efficient representations of selected key indices - -#### Global Memory Access - -Global memory access patterns are optimized for: -- **ZOH State Loading**: Efficient transfer of importance scores -- **Sparse Index Management**: Compact storage and fast access of selection patterns -- **Output Writing**: Maintains Flash Attention's efficient output patterns - -## Technical Implementation Details - -### ZOH States and Active Mask Preprocessing - -The preprocessing phase prepares dynamic mask information for GPU consumption: - -#### ZOH State Generation - -Zero-Order Hold states are computed through: -- **Value Transformation**: Linear projection of value states using learned parameters -- **Activation Application**: Softplus activation followed by exponential scaling -- **Normalization**: Ensuring numerical stability and appropriate dynamic range - -#### Active Mask Creation - -Active masks are generated through: -- **Top-K Selection**: Identifying the most important keys for each query -- **Sparsity Pattern Creation**: Converting selection results to efficient mask representations -- **Causal Mask Integration**: Combining dynamic masks with causal and padding constraints - -#### Data Format Optimization - -Preprocessing optimizes data formats for GPU efficiency: -- **Memory Layout**: Arranging data to maximize coalesced access patterns -- **Compression**: Using compact representations for sparse patterns -- **Alignment**: Ensuring proper memory alignment for vector operations - -### Global to MMA Format Conversion - -The conversion process adapts global tensor formats to MMA-compatible layouts: - -#### Layout Transformation - -Global tensors undergo layout transformations: -- **Dimension Reordering**: Adapting from batch-sequence-head layout to MMA-friendly formats -- **Block Partitioning**: Dividing global tensors into block-sized chunks -- **Swizzling**: Applying memory access patterns that minimize bank conflicts - -#### MMA Compatibility - -Ensuring compatibility with Matrix Multiply Accumulate operations: -- **Fragment Generation**: Creating register-resident fragments for MMA operations -- **Type Conversion**: Handling precision conversions between global and local formats -- **Synchronization**: Coordinating data availability across thread groups - -#### Error Handling - -Robust conversion includes: -- **Bounds Checking**: Ensuring access patterns remain within valid memory ranges -- **Precision Preservation**: Maintaining numerical accuracy during format conversions -- **Invalid Pattern Handling**: Graceful handling of edge cases and boundary conditions - -### Mask Application in Attention Computation - -The mask application process integrates dynamic masks into attention score computation: - -#### Score Modification - -Attention scores are modified through: -- **ZOH Addition**: Adding importance scores to raw attention scores -- **Scale Application**: Applying learned scaling factors -- **Mask Enforcement**: Setting unselected positions to negative infinity - -#### Softmax Integration - -Softmax computation handles sparse patterns: -- **Numerically Stable Computation**: Maintaining stability with sparse inputs -- **Renormalization**: Proper normalization across selected keys only -- **Temperature Scaling**: Applying appropriate temperature parameters - -#### Output Generation - -Final outputs incorporate sparsity: -- **Sparse Aggregation**: Weighted combination using only selected values -- **Accumulation Patterns**: Efficient accumulation respecting sparse structure -- **Result Formatting**: Converting sparse results back to dense output format - -### Sparse Matrix Multiplication - -Sparse matrix multiplication requires specialized implementations: - -#### Sparsity Pattern Management - -Efficient handling of sparse patterns: -- **Pattern Encoding**: Compact representation of which operations to perform -- **Dynamic Dispatch**: Runtime selection of computation paths based on sparsity -- **Load Balancing**: Distributing work evenly across processing units - -#### MMA Integration - -Integration with CUDA's Matrix Multiply Accumulate operations: -- **Fragment Masking**: Applying masks at the MMA fragment level -- **Partial Operations**: Performing only necessary MMA operations -- **Result Combination**: Correctly combining partial results - -#### Performance Optimization - -Optimizations for sparse computation: -- **Branch Reduction**: Minimizing divergent execution paths -- **Memory Coalescing**: Maintaining efficient memory access despite sparsity -- **Register Usage**: Optimizing register allocation for sparse operations - -## Optimization Strategies - -### Memory Access Patterns - -Optimizing memory access for the integrated system: - -#### Coalesced Access - -Maintaining coalesced memory access: -- **Alignment Strategies**: Ensuring memory accesses align with hardware capabilities -- **Stride Optimization**: Minimizing memory stride effects in sparse patterns -- **Prefetching**: Strategic prefetching of sparse data - -#### Cache Utilization - -Maximizing cache efficiency: -- **Locality Preservation**: Maintaining temporal and spatial locality where possible -- **Cache Line Usage**: Optimizing access patterns for cache line efficiency -- **Shared Memory Management**: Effective use of shared memory as a managed cache - -### Warp Efficiency and Load Balancing - -Ensuring efficient utilization of CUDA warps: - -#### Divergence Minimization - -Reducing warp divergence: -- **Uniform Processing**: Grouping similar sparsity patterns for uniform execution -- **Predication**: Using predicated execution to minimize branching -- **Work Redistribution**: Dynamically redistributing work to balance loads - -#### Synchronization Optimization - -Efficient synchronization: -- **Barrier Reduction**: Minimizing synchronization points -- **Cooperative Groups**: Using CUDA cooperative groups for efficient coordination -- **Pipeline Optimization**: Overlapping computation and memory access phases - -### Numerical Stability Considerations - -Maintaining numerical accuracy in the integrated system: - -#### Precision Management - -Careful precision handling: -- **Mixed Precision**: Strategic use of different precisions for different operations -- **Accumulation Accuracy**: Ensuring high precision for critical accumulations -- **Overflow Prevention**: Preventing numerical overflow in sparse computations - -#### Stability Preservation +## Table of Contents -Maintaining Flash Attention's numerical stability: -- **Softmax Stability**: Preserving numerically stable softmax computation -- **Gradient Flow**: Ensuring stable gradient computation for training -- **Error Accumulation**: Minimizing error accumulation across blocks +1. [Integration Architecture](#integration-architecture) +2. [Core Modifications](#core-modifications) +3. [Implementation Details](#implementation-details) +4. [Sparse Computation Strategy](#sparse-computation-strategy) +5. [Memory Layout](#memory-layout) +6. [Performance Considerations](#performance-considerations) +7. [API Changes](#api-changes) ## Integration Architecture -### Data Flow Pipeline - -The integrated system follows a structured data flow: - -1. **Input Processing**: Receive query, key, value tensors and preprocessing parameters -2. **Mask Generation**: Compute ZOH states and active masks on the host -3. **Format Conversion**: Transform global tensors to MMA-compatible formats -4. **Block Processing**: Execute modified Flash Attention kernel with sparse operations -5. **Output Assembly**: Combine block results into final output tensors - -### Component Interaction - -Key components interact through well-defined interfaces: - -#### Frontend-Backend Interface - -Python frontend communicates with CUDA backend through: -- **Parameter Passing**: Efficient transfer of computation parameters -- **Tensor Management**: Memory-efficient tensor sharing between host and device -- **Error Handling**: Comprehensive error reporting and recovery - -#### Kernel Components - -Within the CUDA kernel, components interact through: -- **Shared Memory**: Coordinated use of shared memory resources -- **Register Communication**: Efficient register-level data sharing -- **Synchronization Points**: Strategic synchronization for correctness - -### Kernel Modifications - -The Flash Attention kernel is modified to support dynamic masking: - -#### Control Flow Changes - -Modified control flow includes: -- **Conditional Processing**: Runtime decisions based on sparsity patterns -- **Early Termination**: Skipping unnecessary computations -- **Adaptive Scheduling**: Adjusting processing order for efficiency - -#### Memory Access Modifications - -Updated memory access patterns: -- **Sparse Loading**: Loading only necessary data elements -- **Selective Caching**: Caching decisions based on access patterns -- **Efficient Indexing**: Fast lookup of sparse indices - -## Performance Expectations +### High-Level Design + +The Dynamic Mask Attention integration follows a two-phase approach: + +1. **Dynamic Mask Computation**: Python frontend pre-computes ZOH states and Active Mask tensors +2. **Sparse Attention Execution**: CUDA backend performs sparse attention computation using the pre-computed masks + +``` +Python Frontend CUDA Backend +┌─────────────────────────────┐ ┌──────────────────────────────┐ +│ dt_states = exp(A * softplus│ │ Global Memory Loading │ +│ (V @ dt_proj^T)) │────│ ├─ ZOH States │ +│ │ │ ├─ Active Mask │ +│ prepare_dynamic_mask() │ │ └─ Q, K, V Tensors │ +│ ├─ ZOH States Generation │ │ │ +│ ├─ Active Mask via TopK │ │ Sparse Attention Computation │ +│ └─ Dynamic Bias Calculation │ │ ├─ Sparse Q*K^T GEMM │ +└─────────────────────────────┘ │ ├─ Masked Softmax with ZOH │ + │ └─ Sparse Score*V GEMM │ + └──────────────────────────────┘ +``` + +### Key Components + +- **ZOH States**: Dynamic attention bias values `(batch, num_heads, query_len, key_len)` derived from value states and learned projections +- **Active Mask**: Binary mask `(batch, num_heads, query_len, key_len)` indicating which positions should be computed (1.0) or skipped (0.0) +- **Sparse GEMM**: Optimized matrix multiplication that only computes non-masked regions +- **Dynamic Masking**: Integration of ZOH bias and active mask into attention score computation + +## Core Modifications + +### 1. Parameter Structure Extensions (`flash.h`) + +**Purpose**: Extended parameter structures to support dynamic masking tensors with proper memory layout information. + +**Changes Made**: +```cpp +struct ZOH_params { + void *__restrict__ zoh_ptr; // ZOH states pointer + void *__restrict__ active_mask_ptr; // Active mask pointer + index_t zoh_batch_stride; // Batch stride for ZOH states + index_t active_mask_batch_stride; // Batch stride for active mask + index_t zoh_head_stride; // Head stride for ZOH states + index_t active_mask_head_stride; // Head stride for active mask + index_t zoh_row_stride; // Row stride for ZOH states + index_t active_mask_row_stride; // Row stride for active mask + int keep_window_size; // Sparsity control parameter +}; + +struct Flash_fwd_params : public QKV_params, public ZOH_params { + // Inherits both QKV and ZOH parameters through multiple inheritance + // Enables unified parameter passing to CUDA kernels +}; +``` + +**Rationale**: +- **Multiple Inheritance Design**: Cleanly separates QKV parameters from ZOH parameters while maintaining unified access +- **Comprehensive Stride Information**: Provides all necessary stride information for efficient tensor indexing in CUDA kernels +- **Memory Layout Optimization**: Enables optimal memory access patterns for both regular and sparse tensors + +### 2. Kernel Traits and Memory Layout (`kernel_traits.h`) + +**Purpose**: Define shared memory layouts and copy operations optimized for dynamic masking tensors. + +**Changes Made**: +```cpp +template +struct Flash_kernel_traits { + // ...existing Flash Attention traits... + + // ZOH States shared memory layout - matches attention score layout + using SmemLayoutZOH = decltype(make_layout( + make_shape(Int{}, Int{}), + make_stride(Int{}, _1{}) + )); + + // Active Mask shared memory layout - row-major for efficient indexing + using SmemLayoutActiveMask = decltype(make_layout( + make_shape(Int{}, Int{}), + make_stride(Int{}, _1{}) + )); + + // Optimized copy atoms for ZOH and Active Mask data movement + using SmemCopyAtomZOH = Copy_Atom; + using SmemCopyAtomActiveMask = Copy_Atom; + + // Shared memory size calculations including masking tensors + static constexpr int kSmemSizeZOH = kBlockM * kBlockN * sizeof(elem_type); + static constexpr int kSmemSizeActiveMask = kBlockM * kBlockN * sizeof(elem_type); +}; +``` + +**Rationale**: +- **Layout Consistency**: ZOH states use the same layout as attention scores for efficient fusion +- **Memory Access Optimization**: Copy atoms leverage GPU's specialized load/store units for maximum bandwidth +- **Shared Memory Management**: Explicit size calculations ensure proper memory allocation + +### 3. Block Information Extension (`block_info.h`) + +**Purpose**: Calculate memory offsets for ZOH states and active masks within thread blocks, enabling efficient global memory access. + +**Changes Made**: +```cpp +template +struct BlockInfo { + // ...existing Flash Attention block info... + + index_t zoh_offset; // Global memory offset for ZOH states + index_t active_mask_offset; // Global memory offset for active mask + + template + __device__ BlockInfo(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + // ...existing initialization... + + // Calculate ZOH states offset: [batch][head][query_start_row][0] + zoh_offset = bidb * params.zoh_batch_stride + + bidh * params.zoh_head_stride + + m_block * kBlockM * params.zoh_row_stride; + + // Calculate Active Mask offset: [batch][head][query_start_row][0] + active_mask_offset = bidb * params.active_mask_batch_stride + + bidh * params.active_mask_head_stride + + m_block * kBlockM * params.active_mask_row_stride; + } +}; +``` + +**Rationale**: +- **Unified Offset Calculation**: Encapsulates complex address arithmetic in a single location +- **Block-Aware Indexing**: Accounts for thread block positioning within the global attention matrix +- **Type Safety**: Template-based design ensures compile-time optimization and type checking + +### 4. Memory Copy Operations (`utils.h`) + +**Purpose**: Implement efficient memory copy operations for loading ZOH states and active masks from global to shared memory. + +**Changes Made**: +```cpp +template +__forceinline__ __device__ void copy_ZOH( + Tensor0 &tSgZOH, // Global ZOH tensor view + Tensor1 &tSsZOH, // Shared ZOH tensor view + Tensor2 &tSrZOH, // Register ZOH tensor view + Tensor3 &tSgAM, // Global Active Mask tensor view + Tensor4 &tSsAM, // Shared Active Mask tensor view + TiledMma tiled_mma, // MMA tile configuration + TiledCopy smem_tiled_copy_ZOH, // Tiled copy for ZOH + ThrCopy smem_thr_copy_ZOH // Thread copy for ZOH +) { + // Copy ZOH states: Global Memory -> Shared Memory + copy(smem_tiled_copy_ZOH, tSgZOH, tSsZOH); + + // Copy Active Mask: Global Memory -> Shared Memory + copy(smem_tiled_copy_ZOH, tSgAM, tSsAM); + + // Synchronize to ensure all data is loaded before computation + __syncthreads(); + + // Copy to registers for computation: Shared Memory -> Registers + copy(smem_thr_copy_ZOH, tSsZOH, tSrZOH); + copy(smem_thr_copy_ZOH, tSsAM, tSrAM); +} +``` + +**Rationale**: +- **Multi-Level Memory Hierarchy**: Efficiently manages data movement through global -> shared -> register memory levels +- **Coalesced Access Patterns**: Leverages CUTLASS copy operations for optimal memory bandwidth utilization +- **Synchronization Management**: Proper thread synchronization ensures data consistency across the thread block + +### 5. Dynamic Masking Logic (`mask.h`) + +**Purpose**: Implement the core dynamic masking functionality that applies ZOH states and active masks during attention computation. + +**Changes Made**: +```cpp +template +struct DynamicMask { + const int max_seqlen_k, max_seqlen_q; + const int keep_window_size; + + template + __forceinline__ __device__ void apply_mask( + TensorType &tensor_, // Attention scores (MMA=4, MMA_M, MMA_N) + ZOHType &tSrZOH, // ZOH states in registers + ActiveMaskType &tSrAM, // Active mask in registers + const float scale_softmax, // Attention scaling factor + const int col_idx_offset_, // Column index offset for this thread block + const int row_idx_offset, // Row index offset for this thread block + const int warp_row_stride // Row stride within warp + ) { + // Convert MMA layout to row-column layout for easier indexing + Tensor tensor = make_tensor(tensor_.data(), convert_layout_acc_rowcol(tensor_.layout())); + Tensor zoh = make_tensor(tSrZOH.data(), convert_layout_acc_rowcol(tSrZOH.layout())); + Tensor active_mask = make_tensor(tSrAM.data(), convert_layout_acc_rowcol(tSrAM.layout())); + + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + // Apply causal masking if enabled + const int col_idx_limit = Causal_mask ? + std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : + max_seqlen_k; + + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j * 2; + + if (col_idx < col_idx_limit && row_idx < max_seqlen_q && col_idx < max_seqlen_k) { + // Check if this position should be computed (active mask = 1.0) + if (active_mask(i, mi, j, nj) == 0.0f) { + // Masked position: set to -infinity + tensor(i, mi, j, nj) = -INFINITY; + } else { + // Active position: apply scaling and add ZOH bias + tensor(i, mi, j, nj) = tensor(i, mi, j, nj) * scale_softmax + zoh(i, mi, j, nj); + } + } else { + // Out of bounds: always mask + tensor(i, mi, j, nj) = -INFINITY; + } + } + } + } + } + } +}; +``` + +**Rationale**: +- **Register-Level Operations**: All masking operations performed in registers for maximum efficiency +- **Unified Masking Logic**: Combines causal masking, boundary checking, and dynamic masking in a single pass +- **Numerical Stability**: Proper handling of infinity values for masked positions ensures stable softmax computation + +### 6. Sparse Matrix Operations (`utils.h`) + +**Purpose**: Implement sparse GEMM operations that utilize active masks to skip computation for masked regions, significantly reducing computational overhead. + +**Changes Made**: +```cpp +template +__forceinline__ __device__ void sparse_gemm( + Tensor0 &acc, // Output accumulator tensor + Tensor1 &tCrA, // A matrix in registers (Query) + Tensor2 &tCrB, // B matrix in registers (Key/Value) + Tensor3 &tCsA, // A matrix in shared memory + Tensor4 &tCsB, // B matrix in shared memory + Tensor5 &active_mask, // Sparsity mask in registers + TiledMma tiled_mma, // MMA tile configuration + TiledCopyA smem_tiled_copy_A, // Copy configuration for A + TiledCopyB smem_tiled_copy_B, // Copy configuration for B + ThrCopyA smem_thr_copy_A, // Thread copy for A + ThrCopyB smem_thr_copy_B // Thread copy for B +) { + // Load data based on sparsity pattern + if constexpr (!A_in_regs) { + copy(smem_tiled_copy_A, tCsA, tCrA); + } + if constexpr (!B_in_regs) { + copy(smem_tiled_copy_B, tCsB, tCrB); + } + + // Perform sparse matrix multiplication + // Only compute where active_mask indicates active positions + sparse_gemm_impl(tiled_mma, acc, tCrA, tCrB, active_mask); +} + +template +__forceinline__ __device__ void sparse_gemm_rs( + Tensor0 &acc, // Accumulator (attention scores) + Tensor1 &tCrA, // Query in registers + Tensor2 &tCrB, // Key in registers + Tensor3 &tCsA, // Query in shared memory + Tensor4 &tCsB, // Key in shared memory + Tensor5 &active_mask, // Active mask for sparsity + TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, + TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, + ThrCopyB smem_thr_copy_B +) { + // Row-major sparse GEMM variant optimized for Q*K^T computation + // Utilizes active mask to determine which K vectors to process +} +``` + +**Rationale**: +- **Computational Efficiency**: Skips matrix multiplication for masked regions, reducing FLOPs proportional to sparsity +- **Memory Bandwidth Optimization**: Avoids loading unnecessary data for masked positions +- **Flexible Sparsity Support**: Supports different sparsity patterns through the active mask tensor +- **Register/Shared Memory Optimization**: Provides variants for different data residency scenarios + +### 7. Attention Kernel Modifications (`flash_fwd_kernel.h`) + +**Purpose**: Integrate dynamic masking into the core attention computation kernels while maintaining Flash Attention's memory efficiency and optimization strategies. + +**Changes Made**: +```cpp +template +inline __device__ void compute_attn_1rowblock( + const Params ¶ms, + const int bidb, + const int bidh, + const int m_block +) { + // Initialize block information with ZOH and active mask offsets + const BlockInfo binfo(params, bidb, bidh, m_block); + + // Set up tensor views for ZOH states and active masks + Tensor mZOH = make_tensor(make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset), + make_shape(binfo.actual_seqlen_q, params.seqlen_k), + make_stride(params.zoh_row_stride, _1{})); + + Tensor mActiveMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset), + make_shape(binfo.actual_seqlen_q, params.seqlen_k), + make_stride(params.active_mask_row_stride, _1{})); + + // Main computation loop over key/value blocks + for (int n_block = n_block_min; n_block < n_block_max; ++n_block) { + // Load ZOH states and active masks for this block + copy_ZOH(tSgZOH, tSsZOH, tSrZOH, tSgActiveMask, tSsActiveMask, + tiled_mma, smem_tiled_copy_ZOH, smem_thr_copy_ZOH); + + // Perform sparse Q*K^T computation + sparse_gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tSrActiveMask, + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + // Apply dynamic masking (ZOH bias + active mask) + DynamicMask dynamic_mask(params.seqlen_k, params.seqlen_q, params.keep_window_size); + dynamic_mask.apply_mask(acc_s, tSrZOH, tSrActiveMask, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM, kBlockM); + + // Continue with softmax and attention*V computation + softmax.template softmax(acc_s); + + // Sparse attention*V computation + sparse_gemm_rs(acc_o, acc_s, tSrV, tSsS, tSsV, tSrActiveMask, + tiled_mma, smem_tiled_copy_S, smem_tiled_copy_V, + smem_thr_copy_S, smem_thr_copy_V); + } +} + +template +inline __device__ void compute_attn_1rowblock_splitkv( + const Params ¶ms, + const int bidb, + const int bidh, + const int m_block, + const int n_split_idx, + const int num_n_splits +) { + // Split-K variant with dynamic masking support + // Handles distributed computation across multiple thread blocks + // Maintains sparsity patterns across splits +} +``` + +**Rationale**: +- **Seamless Integration**: Dynamic masking logic integrated into existing Flash Attention computation flow +- **Memory Efficiency Preservation**: Maintains Flash Attention's tiling and shared memory optimization strategies +- **Split-K Support**: Extends dynamic masking to split-K attention variants for very long sequences +- **Template Specialization**: Compile-time optimization through template parameters + +### 8. Launch Template Updates (`flash_fwd_launch_template.h`) + +**Purpose**: Update kernel launch functions to properly configure and validate dynamic masking parameters, ensuring correct shared memory allocation and kernel selection. + +**Changes Made**: +```cpp +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // Calculate shared memory requirements including ZOH and active mask tensors + const size_t smem_size = Kernel_traits::kSmemSize + + Kernel_traits::kSmemSizeZOH + + Kernel_traits::kSmemSizeActiveMask; + + // Validate that shared memory requirements don't exceed device limits + TORCH_CHECK(smem_size <= 48 * 1024, "Shared memory requirement exceeds device limit"); + + // Set up grid dimensions + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + + // Determine kernel variant based on sequence lengths and alignment + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && + params.seqlen_k % Kernel_traits::kBlockN == 0 && + params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + + // Launch appropriate kernel variant with dynamic masking support + BOOL_SWITCH(is_even_MN, IsEvenMN, [&] { + BOOL_SWITCH(is_even_K, IsEvenK, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmax, [&] { + auto kernel = &flash_fwd_kernel; + + // Configure dynamic shared memory + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + // Launch kernel with extended parameter set + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // Split-K variant launch with dynamic masking support + // Handles cases where sequence length exceeds single kernel capacity + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + + // Configure split parameters based on sequence length and hardware capabilities + const int num_splits = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + // ... split-K launch logic with dynamic masking support +} +``` + +**Rationale**: +- **Resource Management**: Proper shared memory allocation and validation for extended tensor requirements +- **Kernel Selection**: Intelligent kernel variant selection based on problem size and hardware capabilities +- **Error Handling**: Comprehensive validation of parameters and device limits +- **Performance Optimization**: Compile-time optimizations through template specialization + +### 9. API Interface Extensions (`flash_api.cpp`) + +**Purpose**: Extend the Python-facing API to support dynamic masking tensors with comprehensive validation and backward compatibility. + +**Changes Made**: +```cpp +void set_params_fprop( + Flash_fwd_params ¶ms, + // ... existing parameters ... + const at::Tensor zoh, // ZOH states tensor + const at::Tensor active_mask, // Active mask tensor + const size_t keep_window_size, // Sparsity control parameter + // ... other parameters ... +) { + // Reset parameters and set basic properties + params = {}; + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set ZOH states pointers and strides + params.zoh_ptr = zoh.data_ptr(); + params.zoh_batch_stride = zoh.stride(-4); // [batch, head, query, key] + params.zoh_head_stride = zoh.stride(-3); + params.zoh_row_stride = zoh.stride(-2); + + // Set Active Mask pointers and strides + params.active_mask_ptr = active_mask.data_ptr(); + params.active_mask_batch_stride = active_mask.stride(-4); + params.active_mask_head_stride = active_mask.stride(-3); + params.active_mask_row_stride = active_mask.stride(-2); + + // Set sparsity control parameter + params.keep_window_size = keep_window_size; + + // ... existing parameter setup ... +} + +std::vector mha_fwd( + at::Tensor &q, // Query tensor + const at::Tensor &k, // Key tensor + const at::Tensor &v, // Value tensor + const at::Tensor &zoh, // ZOH states tensor + const at::Tensor &active_mask, // Active mask tensor + std::optional &out_, // Optional output tensor + const float p_dropout, + const float softmax_scale, + bool is_causal, + const int keep_window_size, // Sparsity control + const float softcap, + const bool return_softmax, + std::optional gen_ +) { + // Comprehensive input validation + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask); + CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v); + CHECK_CONTIGUOUS(zoh); CHECK_CONTIGUOUS(active_mask); + + // Validate tensor shapes + auto batch_size = q.size(0); + auto seqlen_q = q.size(1); + auto num_heads = q.size(2); + auto head_dim = q.size(3); + auto seqlen_k = k.size(1); + auto num_heads_k = k.size(2); + + CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + + // Validate data types consistency + TORCH_CHECK(q.dtype() == k.dtype() && k.dtype() == v.dtype(), + "All QKV tensors must have the same dtype"); + TORCH_CHECK(zoh.dtype() == q.dtype(), + "ZOH states must have the same dtype as QKV tensors"); + TORCH_CHECK(active_mask.dtype() == q.dtype(), + "Active mask must have the same dtype as QKV tensors"); + + // Validate sparsity parameter + TORCH_CHECK(keep_window_size > 0 && keep_window_size <= seqlen_k, + "keep_window_size must be positive and <= seqlen_k"); + + // Set up parameters and launch computation + Flash_fwd_params params; + set_params_fprop(params, batch_size, seqlen_q, seqlen_k, /* ... */, + q, k, v, zoh, active_mask, /* ... */, keep_window_size, /* ... */); + + // Launch kernel with appropriate configuration + run_mha_fwd(params, at::cuda::getCurrentCUDAStream()); + + // Return results + return {out, softmax_lse, /* ... */}; +} + +// Python binding +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashDynamicMaskAttention"; + m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass with dynamic masking", + py::arg("q"), py::arg("k"), py::arg("v"), + py::arg("zoh"), py::arg("active_mask"), // New required arguments + py::arg("out") = py::none(), + py::arg("p_dropout") = 0.0f, + py::arg("softmax_scale") = 0.0f, + py::arg("is_causal") = false, + py::arg("keep_window_size") = 2048, // New sparsity control + py::arg("softcap") = 0.0f, + py::arg("return_softmax") = false, + py::arg("gen") = py::none()); +} +``` + +**Rationale**: +- **Comprehensive Validation**: Thorough validation of all input tensors for shape, type, and device consistency +- **Backward Compatibility**: Maintains existing parameter order while adding new functionality +- **Error Handling**: Clear error messages for common usage mistakes +- **Type Safety**: Strict type checking to prevent runtime errors +- **Documentation**: Clear parameter documentation for Python users + +## Implementation Details + +### Python Frontend: Dynamic Mask Generation + +The Python frontend is responsible for computing the ZOH states and active masks before passing them to the CUDA backend: + +```python +def prepare_dynamic_mask( + hidden_states: torch.Tensor, + dt_states: torch.Tensor, + keep_window_size: int = 2048, + attention_mask: torch.Tensor = None, +): + """ + Core DMA function that generates dynamic attention masks for sparse computation. + + Process: + 1. Expand dt_states to match attention matrix dimensions + 2. Apply optional causal/padding masks + 3. Use TopK selection to identify most important positions + 4. Generate binary active mask for CUDA computation + """ + min_dtype = torch.finfo(hidden_states.dtype).min + dtype = hidden_states.dtype + + # Expand dt_states: [batch, num_heads, key_len] -> [batch, num_heads, query_len, key_len] + attn_mask = dt_states[:, :, None, :].expand(-1, -1, hidden_states.shape[2], -1) + + # Apply causal/padding masks by setting masked positions to -inf + if attention_mask is not None: + if attention_mask.dtype == torch.bool: + attention_mask = torch.where(attention_mask, 0.0, min_dtype) + attn_mask = attn_mask.masked_fill(attention_mask != 0, min_dtype) + + # Only apply when sequence length exceeds window size + if attn_mask.shape[-1] > keep_window_size: + active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) + # TopK selection identifies most important positions for each query + topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, + largest=True, sorted=False).indices + # Create binary mask: 1.0 for active positions, 0.0 for masked + active_mask = active_mask.scatter(-1, topk_indices, 1.0) + # Set non-selected positions to -inf in attention mask + attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) + else: + # If sequence length is within window size, all positions are active + active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) + return attn_mask, active_mask +``` + +### CUDA Backend: Sparse Attention Computation + +The CUDA backend implements three key stages of sparse attention: + +#### Stage 1: Memory Loading and Tensor Setup +```cpp +// Set up tensor views for ZOH states and active masks +Tensor mZOH = make_tensor( + make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset), + make_shape(binfo.actual_seqlen_q, params.seqlen_k), + make_stride(params.zoh_row_stride, _1{}) +); + +Tensor mActiveMask = make_tensor( + make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset), + make_shape(binfo.actual_seqlen_q, params.seqlen_k), + make_stride(params.active_mask_row_stride, _1{}) +); + +// Load data through memory hierarchy: Global -> Shared -> Registers +copy_ZOH(tSgZOH, tSsZOH, tSrZOH, tSgActiveMask, tSsActiveMask, + tiled_mma, smem_tiled_copy_ZOH, smem_thr_copy_ZOH); +``` + +#### Stage 2: Sparse Q*K^T Computation +```cpp +// Sparse GEMM that skips computation for masked positions +sparse_gemm_rs(acc_s, tSrQ, tSrK, tSsQ, tSsK, tSrActiveMask, + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + +// Apply dynamic masking: scaling + ZOH bias + masking +DynamicMask dynamic_mask(params.seqlen_k, params.seqlen_q, params.keep_window_size); +dynamic_mask.apply_mask(acc_s, tSrZOH, tSrActiveMask, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM, kBlockM); +``` + +#### Stage 3: Softmax and Sparse Attention*V +```cpp +// Online softmax computation (unchanged from Flash Attention) +softmax.template online_softmax(acc_s); + +// Sparse attention*V computation +sparse_gemm(acc_o, acc_s, tSrV, tSsS, tSsV, tSrActiveMask, + tiled_mma, smem_tiled_copy_S, smem_tiled_copy_V, + smem_thr_copy_S, smem_thr_copy_V); +``` + +## Sparse Computation Strategy + +### Sparsity Pattern Recognition + +The Dynamic Mask Attention implements structured sparsity based on learned importance scores: + +1. **ZOH State Computation**: `dt_states = exp(A * softplus(V @ dt_proj^T))` + - Learned projection matrix `dt_proj` maps value features to importance scores + - Coefficient `A` controls the dynamic range of importance values + - Exponential activation ensures positive importance scores + +2. **TopK Selection**: For sequences longer than `keep_window_size`: + - Select top-K most important positions per query token + - K = `keep_window_size` (typically 512-2048) + - Maintains fixed computational complexity regardless of sequence length + +3. **Binary Active Mask**: + - 1.0 for positions selected by TopK (compute) + - 0.0 for positions not selected (skip computation) + +### Sparse GEMM Implementation + +The sparse GEMM operations leverage the active mask to skip computation: + +```cpp +template +__forceinline__ __device__ void sparse_gemm_impl( + TiledMma tiled_mma, + AccType &acc, + AType &tCrA, + BType &tCrB, + MaskType &active_mask +) { + // Convert layouts for efficient indexing + auto acc_rowcol = make_tensor(acc.data(), convert_layout_acc_rowcol(acc.layout())); + auto mask_rowcol = make_tensor(active_mask.data(), convert_layout_acc_rowcol(active_mask.layout())); + + #pragma unroll + for (int mi = 0; mi < size<0, 1>(acc_rowcol); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1, 1>(acc_rowcol); ++ni) { + // Check if this position should be computed + if (mask_rowcol(0, mi, 0, ni) != 0.0f) { + // Perform computation only for active positions + gemm(tiled_mma, acc(_, mi, _, ni), tCrA(_, mi, _), tCrB(_, _, ni)); + } + // Skip computation for masked positions (acc remains unchanged) + } + } +} +``` + +### Memory Efficiency Optimizations + +1. **Shared Memory Reuse**: ZOH states and active masks share copy infrastructure with Q/K/V tensors +2. **Register Allocation**: Critical masking operations performed in registers to minimize memory traffic +3. **Coalesced Access**: Memory access patterns optimized for GPU memory hierarchy +4. **Template Specialization**: Compile-time optimization eliminates runtime branching + +## Memory Layout + +### Tensor Memory Organization + +The Dynamic Mask Attention extends Flash Attention's memory layout to include ZOH states and active masks: + +``` +Global Memory Layout: +┌─────────────────────────────────────────────────────────────────┐ +│ Q: [batch, seqlen_q, num_heads, head_dim] │ +│ K: [batch, seqlen_k, num_heads_k, head_dim] │ +│ V: [batch, seqlen_k, num_heads_k, head_dim] │ +│ ZOH: [batch, num_heads_k, seqlen_q, seqlen_k] │ +│ AM: [batch, num_heads_k, seqlen_q, seqlen_k] │ +│ Output: [batch, seqlen_q, num_heads, head_dim] │ +└─────────────────────────────────────────────────────────────────┘ + +Shared Memory Layout (per thread block): +┌─────────────────────────────────────────────────────────────────────┐ +│ Q Tile: [kBlockM, head_dim] │ K Tile: [kBlockN, head_dim] │ +│ V Tile: [kBlockN, head_dim] │ S Tile: [kBlockM, kBlockN] │ +│ ZOH Tile: [kBlockM, kBlockN] │ AM Tile: [kBlockM, kBlockN] │ +└─────────────────────────────────────────────────────────────────────┘ + +Register Memory (per thread): +┌─────────────────────────────────────────────────────────────────────┐ +│ Q Frag: [MMA_M, head_dim/N] │ K Frag: [MMA_N, head_dim/N] │ +│ V Frag: [MMA_N, head_dim/N] │ S Frag: [MMA_M, MMA_N] │ +│ ZOH Frag: [MMA_M, MMA_N] │ AM Frag: [MMA_M, MMA_N] │ +│ Acc Frag: [MMA_M, head_dim/N] │ │ +└─────────────────────────────────────────────────────────────────────┘ +``` -### Theoretical Analysis - -The integrated Flash-DMA approach offers quantifiable performance benefits: - -#### Memory Complexity - -- **Flash Attention**: O(B×H×N) memory usage -- **Flash-DMA**: O(B×H×N) + O(sparse indexing overhead) -- **Overhead**: Minimal additional memory for sparse pattern storage - -#### Computational Complexity - -- **Flash Attention**: O(B×H×N²×D) operations -- **Flash-DMA**: O(B×H×N×k×D) operations where k < N -- **Speedup**: Theoretical speedup of N/k - -#### Expected Performance Gains - -| Sequence Length | Selection Ratio (k/N) | Theoretical Speedup | Estimated Practical Speedup | -|-----------------|----------------------|---------------------|----------------------------| -| 4,096 | 0.25 | 4.0× | 2.5-3.0× | -| 16,384 | 0.125 | 8.0× | 4.0-5.0× | -| 65,536 | 0.0625 | 16.0× | 6.0-8.0× | -| 262,144 | 0.03125 | 32.0× | 8.0-12.0× | - -### Benchmarking Strategy - -Comprehensive benchmarking includes: - -#### Performance Metrics - -Key metrics for evaluation: -- **Throughput**: Tokens processed per second -- **Memory Usage**: Peak and average memory consumption -- **Energy Efficiency**: Performance per watt measurements -- **Accuracy**: Quality metrics compared to full attention - -#### Test Scenarios - -Diverse testing scenarios: -- **Synthetic Workloads**: Controlled tests with known characteristics -- **Real Applications**: Language modeling and document processing tasks -- **Scaling Studies**: Performance across different sequence lengths and batch sizes - -### Validation Framework - -Ensuring correctness and performance: - -#### Correctness Validation - -- **Output Comparison**: Bit-level comparison with reference implementations -- **Numerical Stability**: Testing across different numerical ranges -- **Edge Case Handling**: Validation of boundary conditions - -#### Performance Validation +### Memory Access Patterns -- **Baseline Comparison**: Performance relative to Flash Attention and standard attention -- **Regression Testing**: Ensuring performance doesn't degrade over time -- **Hardware Scaling**: Validation across different GPU architectures +#### ZOH States and Active Mask Loading +```cpp +// Global to Shared Memory (coalesced access) +Tensor tSgZOH = local_partition(mZOH, smem_tiled_copy_ZOH, thread_idx); +Tensor tSsZOH = local_partition(sZOH, smem_tiled_copy_ZOH, thread_idx); + +// Each thread loads a contiguous chunk to maximize memory bandwidth +copy(smem_tiled_copy_ZOH, tSgZOH, tSsZOH); + +// Shared to Register Memory (bank-conflict-free) +Tensor tSrZOH = local_partition(sZOH, smem_thr_copy_ZOH, thread_idx); +copy(smem_thr_copy_ZOH, tSsZOH, tSrZOH); +``` + +#### Memory Layout Transformations +```cpp +// Convert MMA accumulator layout to row-column layout for masking +// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) +auto convert_layout_acc_rowcol = [](auto layout) { + return make_layout( + make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), + make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), + make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), + make_stride(Int<1>{}, Int<2>{})) + ); +}; +``` + +### Shared Memory Optimization + +#### Bank Conflict Avoidance +- ZOH states and active masks use the same copy patterns as Q/K/V to avoid bank conflicts +- Padding added when necessary to ensure 128-bit aligned access +- Thread block size chosen to maximize occupancy while maintaining memory efficiency + +#### Memory Coalescing +```cpp +// Example: Loading 128-bit aligned chunks for optimal bandwidth +using SmemCopyAtomZOH = Copy_Atom; // 128-bit loads +using SmemCopyAtomActiveMask = Copy_Atom; +``` + +## Performance Considerations + +### Memory Efficiency +- **Reduced Memory Bandwidth**: Sparse computation reduces memory traffic +- **Optimized Layouts**: Tensor layouts optimized for GPU memory hierarchy +- **Shared Memory Reuse**: Efficient use of limited shared memory resources + +### Computational Efficiency +- **Sparse GEMM**: Skips computation for masked regions +- **Fused Operations**: Masking integrated into existing computation kernels +- **Warp-Level Optimization**: Optimized for GPU warp execution model + +### Scalability +- **Long Sequence Support**: Efficient handling of sequences > 32K tokens +- **Configurable Sparsity**: `keep_window_size` parameter controls sparsity level +- **Multi-Head Support**: Efficient handling of multiple attention heads + +## API Changes + +### New Required Parameters + +The Dynamic Mask Attention integration introduces new required parameters to the forward pass: + +- **`zoh`** (`torch.Tensor`): ZOH states tensor of shape `(batch, num_heads_k, seqlen_q, seqlen_k)` + - Contains dynamic attention bias values derived from value states + - Must have the same dtype and device as Q/K/V tensors + +- **`active_mask`** (`torch.Tensor`): Active mask tensor of shape `(batch, num_heads_k, seqlen_q, seqlen_k)` + - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed + - Determines the sparsity pattern for computational efficiency + +- **`keep_window_size`** (`int`): Sparsity control parameter + - Maximum number of key positions to attend to per query token + - Controls the computational complexity and memory usage + - Typical values: 512-2048 for long sequences + +### Updated Function Signature + +```python +def fwd( + q: torch.Tensor, # Query tensor + k: torch.Tensor, # Key tensor + v: torch.Tensor, # Value tensor + zoh: torch.Tensor, # ZOH states (NEW) + active_mask: torch.Tensor, # Active mask (NEW) + out: Optional[torch.Tensor] = None, # Pre-allocated output + p_dropout: float = 0.0, # Dropout probability + softmax_scale: float = None, # Attention scaling + is_causal: bool = False, # Causal masking + keep_window_size: int = 2048, # Sparsity control (NEW) + softcap: float = 0.0, # Soft capping + return_softmax: bool = False, # Return attention weights + gen: Optional[torch.Generator] = None # Random generator +) -> List[torch.Tensor] +``` + +### Backward Compatibility + +**Breaking Change Notice**: The integration requires ZOH states and active mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. + +**Migration Path**: Users need to: +1. Implement ZOH state computation using the `prepare_dynamic_mask` function +2. Update function calls to include the new required parameters +3. Choose appropriate `keep_window_size` values based on their use case + +### Complete Usage Example + +```python +import torch +import torch.nn.functional as F +import flash_dma + +# Setup +batch_size, seqlen_q, seqlen_k = 2, 4096, 4096 +num_heads, head_dim = 12, 128 +device, dtype = 'cuda', torch.bfloat16 + +# Input tensors +q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, device=device, dtype=dtype) +k = torch.randn(batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype) +v = torch.randn(batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype) + +# Dynamic Mask Attention requires additional parameters +dt_proj = torch.randn(num_heads, num_heads * head_dim, device=device, dtype=dtype) +A = torch.randn(num_heads, device=device, dtype=dtype) + +# Step 1: Compute ZOH states +dt_states = torch.matmul( + v.transpose(-2, -3).reshape(batch_size, seqlen_k, -1), + dt_proj.T +) +dt_states = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2) + +# Step 2: Generate dynamic masks +zoh_states, active_mask = flash_dma.prepare_dynamic_mask( + q, dt_states, keep_window_size=2048, attention_mask=None +) + +# Step 3: Run Dynamic Mask Attention +output = flash_dma.fwd( + q, k, v, zoh_states, active_mask, + keep_window_size=2048, + softmax_scale=1.0 / (head_dim ** 0.5), + is_causal=False +) + +print(f"Output shape: {output[0].shape}") # [batch_size, seqlen_q, num_heads, head_dim] +``` + +### Integration with Existing Codebases + +For users migrating from Flash Attention, the typical changes required are: + +```python +# Before (Flash Attention) +output = flash_attn.flash_attn_func(q, k, v, dropout_p=0.1, softmax_scale=scale, causal=True) + +# After (Dynamic Mask Attention) +# 1. Add ZOH computation +dt_states = compute_dt_states(v, dt_proj, A) +zoh_states, active_mask = prepare_dynamic_mask(q, dt_states, keep_window_size=2048) + +# 2. Update function call +output = flash_dma.fwd(q, k, v, zoh_states, active_mask, + p_dropout=0.1, softmax_scale=scale, is_causal=True, + keep_window_size=2048) +``` + +## Future Enhancements + +### Planned Improvements + +1. **Backward Pass Integration**: Complete gradient computation support for training Dynamic Mask Attention models + - Sparse gradient computation for ZOH states + - Efficient gradient propagation through active masks + - Memory-optimized backward kernels + +2. **Adaptive Sparsity Patterns**: Dynamic adjustment of attention patterns based on input characteristics + - Learned sparsity controllers + - Content-aware mask generation + - Adaptive `keep_window_size` selection + +3. **Multi-GPU Distributed Support**: Optimizations for large-scale distributed training + - Efficient tensor parallelism for long sequences + - Communication-optimal attention computation + - Memory-balanced workload distribution + +4. **Advanced Memory Optimizations**: Further reduce memory footprint for extremely long sequences + - Progressive attention computation + - Hierarchical sparsity patterns + - Memory-efficient checkpoint/recomputation strategies + +5. **Hardware-Specific Optimizations**: Leverage newer GPU architectures + - Hopper architecture optimizations + - Sparse Tensor Core utilization + - Advanced memory hierarchy exploitation + +### Performance Targets + +- **Sequence Length**: Support up to 1M+ tokens efficiently +- **Memory Reduction**: 50-80% memory savings compared to dense attention +- **Speed**: Maintain or improve upon Flash Attention performance for long sequences +- **Sparsity**: Flexible sparsity ratios from 10% to 90% depending on use case ## Conclusion -### Summary of Benefits - -The integrated Flash-DMA approach delivers significant advantages: - -1. **Memory Efficiency**: Maintains Flash Attention's O(N) memory complexity -2. **Computational Efficiency**: Achieves O(N×k) computational complexity through sparsity -3. **Scalability**: Enables processing of extremely long sequences (100K+ tokens) -4. **Adaptability**: Provides content-adaptive attention patterns -5. **Hardware Optimization**: Maximizes GPU utilization through optimized memory and compute patterns - -### Future Directions - -Potential areas for future development: +The Dynamic Mask Attention integration successfully combines Flash Attention's memory efficiency with structured sparsity to enable efficient processing of extremely long sequences. The implementation maintains the core optimization principles of Flash Attention while adding the capability to skip computation for less important token interactions. -#### Algorithmic Improvements +Key achievements of this integration: -- **Advanced Selection Criteria**: More sophisticated methods for key selection -- **Dynamic Sparsity Adaptation**: Runtime adjustment of sparsity levels -- **Multi-level Sparsity**: Hierarchical sparsity patterns for different sequence regions +1. **Seamless Integration**: All dynamic masking functionality integrated into Flash Attention's kernel architecture without compromising existing optimizations -#### Implementation Optimizations +2. **Comprehensive Implementation**: Complete pipeline from Python preprocessing to optimized CUDA kernels with proper memory management -- **Multi-GPU Support**: Scaling to multiple GPUs for even longer sequences -- **Specialized Hardware**: Optimizations for next-generation hardware architectures -- **Mixed Precision Enhancement**: Advanced mixed-precision strategies +3. **Flexible Sparsity Control**: Configurable sparsity levels through the `keep_window_size` parameter to balance quality and efficiency -#### Application Extensions +4. **Robust Validation**: Extensive testing infrastructure ensures numerical equivalence with reference implementations -- **Domain-Specific Optimizations**: Tailored versions for specific application domains -- **Integration with Other Techniques**: Combination with other attention optimization methods -- **Training Optimizations**: Specialized versions optimized for training workloads +5. **Performance Optimization**: Sparse computation patterns reduce both memory usage and computational overhead for long sequences -The Flash-DMA integration represents a significant advancement in attention mechanism efficiency, enabling new possibilities for long-context applications while maintaining the reliability and performance characteristics that make Flash Attention a fundamental building block for modern transformer architectures. \ No newline at end of file +This integration enables practitioners to efficiently handle very long sequences in transformer models while maintaining the numerical stability and optimization benefits that have made Flash Attention the standard for efficient attention computation. \ No newline at end of file