Conversation

@LoserCheems
Collaborator

This pull request introduces a new header file, flash.h, which defines the core structures and functions for a CUDA-based multi-head attention mechanism. The changes include the addition of parameter structures for forward and backward passes, utility constants, and function templates for executing the attention mechanism.

Core functionality for multi-head attention:

  • Definition of parameter structures:

    • Added QKV_params to encapsulate query, key, and value tensor pointers, strides, and head-related dimensions.
    • Added ZeroHold_params to manage zero-hold states, causal masks, and associated strides for attention mechanisms.
    • Introduced Flash_fwd_params and Flash_bwd_params to extend the above structures for forward and backward pass parameters, including dropout, scaling factors, and random state handling.
  • Function templates for execution:

    • Added templates run_mha_fwd_, run_mha_fwd_splitkv_dispatch, and run_mha_bwd_ for executing forward and backward multi-head attention operations with CUDA streams.
  • Namespace organization:

    • Encapsulated all additions within FLASH_NAMESPACE for modularity and clarity.
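Putting the pieces above together, the overall shape of the header might look like the sketch below. Only the struct names, template names, and the `FLASH_NAMESPACE` wrapper come from this PR description; every field name, stride, and the inheritance layout are illustrative assumptions, not the actual contents of `flash.h`. A `void*` stand-in replaces `cudaStream_t` so the sketch compiles without the CUDA toolkit.

```cpp
#include <cstdint>

// Stand-in so this sketch compiles without <cuda_runtime.h>;
// the real header would use the CUDA toolkit's cudaStream_t.
using cudaStream_t = void*;

#define FLASH_NAMESPACE flash  // placeholder; the actual macro value is defined elsewhere

namespace FLASH_NAMESPACE {

// Query/key/value tensor pointers, strides, and head dimensions (fields are guesses).
struct QKV_params {
    void *q_ptr = nullptr, *k_ptr = nullptr, *v_ptr = nullptr;  // device pointers
    int64_t q_row_stride = 0, k_row_stride = 0, v_row_stride = 0;
    int64_t q_head_stride = 0, k_head_stride = 0, v_head_stride = 0;
    int h = 0;    // number of query heads
    int h_k = 0;  // number of key/value heads
    int d = 0;    // head dimension
};

// Zero-hold states, causal masking, and associated strides (fields are guesses).
struct ZeroHold_params : public QKV_params {
    void *zero_hold_ptr = nullptr;       // zero-hold state tensor
    int64_t zero_hold_row_stride = 0;
    bool is_causal = false;              // whether to apply a causal mask
};

// Forward-pass parameters: output, softmax scaling, dropout, RNG state.
struct Flash_fwd_params : public ZeroHold_params {
    void *o_ptr = nullptr;               // output tensor
    float scale_softmax = 0.0f;          // softmax scaling factor
    float p_dropout = 0.0f;              // dropout probability
    uint64_t rng_seed = 0, rng_offset = 0;  // random state for dropout
};

// Backward-pass parameters add gradient tensors on top of the forward set.
struct Flash_bwd_params : public Flash_fwd_params {
    void *do_ptr = nullptr;                              // gradient of the output
    void *dq_ptr = nullptr, *dk_ptr = nullptr, *dv_ptr = nullptr;  // input gradients
};

// Dispatch templates, presumably specialized per element type and head dimension,
// launching kernels on the given CUDA stream.
template <typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

template <typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template <typename T, int Headdim>
void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

}  // namespace FLASH_NAMESPACE
```

The inheritance chain (`QKV_params` → `ZeroHold_params` → `Flash_fwd_params` → `Flash_bwd_params`) mirrors the "extend the above structures" wording in the summary: each pass reuses the tensor layout of the previous level and layers its own state on top.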

@LoserCheems LoserCheems merged commit 8890fe7 into main May 14, 2025


3 participants