
Conversation

@LoserCheems
Collaborator

Implements bias gradient calculation in the backward pass by adding a new column-block kernel that computes DBias alongside existing DQ, DK, and DV gradients.

Updates the initialization function to support multiple tensor names and modifies the autotuning configurations to zero-initialize both the DQ and DBias tensors.

Includes extensive masking logic and memory access patterns to handle various sequence length and head dimension configurations while maintaining numerical stability.
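
For reference, the math behind the new DBias term (a sketch, assuming the usual flash-attention formulation with the scale applied inside the scores): because the bias enters the scores additively before the softmax, its gradient is exactly the gradient of the scores.

$$
S = \frac{QK^\top}{\sqrt{d}} + B, \qquad P = \operatorname{softmax}(S), \qquad O = PV
$$

$$
dB = dS = P \circ \bigl(dP - D\bigr), \qquad dP = dO\,V^\top, \qquad D_i = \textstyle\sum_j dO_{ij}\,O_{ij}
$$

Here $D$ is broadcast across each row of $dP$, which is exactly what the kernel's `dbias = p * (dp - Di[:, None])` line (quoted in the review below) computes per block.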

Copilot AI review requested due to automatic review settings July 7, 2025 08:45
@LoserCheems LoserCheems requested review from Evanwu1125, SNHuan and wubingheng111 and removed request for Copilot July 7, 2025 08:45
Contributor

Copilot AI left a comment


Pull Request Overview

Implements bias gradient computation in the Triton backward kernel, updates the tensor initialization helper to zero multiple tensors, and adjusts autotuning configurations to include the new DBias tensor.

  • Introduces _bwd_kernel_one_col_block to compute DBias alongside existing gradients.
  • Replaces the single-name init_to_zero helper with a version that accepts multiple tensor names.
  • Updates autotuning pre_hook entries to zero both DQ and DBias.
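
As a concrete illustration of the last two bullets, here is a minimal sketch of the multi-name helper and an autotune entry that uses it (block sizes here are illustrative, not the tuned values from the file):

```python
import triton

def init_to_zero(names):
    # Accept a single tensor name or a list of names; zero each named
    # kernel argument before the run so atomic accumulation starts clean.
    names = [names] if isinstance(names, str) else names
    return lambda nargs: [nargs[name].zero_() for name in names]

# DQ and DBias are built up with atomic adds across column blocks, so
# both must be zeroed before every autotuned configuration is timed.
config = triton.Config(
    {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
    num_warps=8,
    num_stages=1,
    pre_hook=init_to_zero(["DQ", "DBias"]),
)
```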
Comments suppressed due to low confidence (2)

flash_dmattn/flash_dmattn_triton.py:339

  • Consider adding a docstring above the _bwd_kernel_one_col_block Triton kernel to describe its purpose, parameters, and outputs, which will improve readability and onboarding for new contributors.
@triton.jit
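
A possible shape for such a docstring, sketched below; the parameter list is abbreviated and illustrative rather than the kernel's real signature:

```python
import triton
import triton.language as tl

@triton.jit
def _bwd_kernel_one_col_block(
    start_n, Q, K, V, Bias, DO, DQ, DK, DV, DBias, LSE, D,
    softmax_scale, seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """Process one column block (index start_n) of the backward pass.

    Loads one BLOCK_N-wide tile of K, V, and Bias, then loops over the
    BLOCK_M-tall row blocks of Q/DO that attend to it, accumulating dK
    and dV locally and scattering dQ and dBias contributions back to
    global memory (atomically when ATOMIC_ADD is set).
    """
```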

flash_dmattn/flash_dmattn_triton.py:541

  • The new bias gradient computation introduces both atomic and non-atomic code paths; add unit tests covering these scenarios (varying ATOMIC_ADD, block sizes, and mask configurations) to ensure correctness across configurations.
        dbias = (p * (dp - Di[:, None]))
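
A hedged sketch of such a test, assuming a public entry point `flash_dmattn_func(q, k, v, bias)` (the actual API exported by this repo may differ); it checks the fused dbias against an unfused fp32 PyTorch reference. Parametrizing it over block sizes and SEQUENCE_PARALLEL, which toggles the atomic path, would cover the configurations mentioned above:

```python
import torch
from flash_dmattn.flash_dmattn_triton import flash_dmattn_func  # hypothetical import

def test_dbias_matches_reference():
    torch.manual_seed(0)
    b, h, m, n, d = 2, 4, 128, 128, 64
    def make(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.float16, requires_grad=True)
    q, k, v, bias = make(b, h, m, d), make(b, h, n, d), make(b, h, n, d), make(b, h, m, n)
    do = torch.randn(b, h, m, d, device="cuda", dtype=torch.float16)

    # Fused Triton path.
    out = flash_dmattn_func(q, k, v, bias)
    (dbias_triton,) = torch.autograd.grad(out, bias, do)

    # Unfused fp32 reference: dBias = dS = P * (dP - D).
    s = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * d**-0.5 + bias.float()
    p = s.softmax(dim=-1)
    o_ref = torch.einsum("bhmn,bhnd->bhmd", p, v.float())
    (dbias_ref,) = torch.autograd.grad(o_ref, bias, do.float())

    assert torch.allclose(dbias_triton.float(), dbias_ref.float(), atol=2e-2, rtol=0)
```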

```python
def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()

@triton.jit
def _bwd_kernel_one_col_block(
```

Copilot AI Jul 7, 2025


[nitpick] This kernel function spans 300+ lines and handles multiple responsibilities; consider breaking it into smaller helper functions (e.g., pointer setup, mask application, gradient computation) to improve readability and ease future maintenance.
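
One low-risk way to start is pulling the per-tile math into small @triton.jit helpers, which Triton allows jitted kernels to call; a minimal sketch with illustrative names:

```python
import triton
import triton.language as tl

@triton.jit
def _compute_dbias(p, dp, Di):
    # dBias = dS = P * (dP - D): the gradient of the pre-softmax scores.
    return p * (dp - Di[:, None])

@triton.jit
def _accumulate_dv(p, do):
    # dV += P^T @ dO for the current row block of queries.
    return tl.dot(tl.trans(p.to(do.dtype)), do)
```

The main kernel body would then read as a sequence of named steps (pointer setup, mask application, gradient computation) instead of one 300-line block.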

Comment on lines 675 to +679
```python
# # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
```

Copilot AI Jul 7, 2025


[nitpick] There are several commented-out Triton configuration lines that include DBias; once DBias support is verified, consider removing or updating these stale lines to reduce clutter.

Updates the function alias to better reflect its Triton-based implementation, improving code readability and making the backend technology more explicit for developers.

Creates a more convenient function name that follows the module's naming convention and improves code readability.
@LoserCheems LoserCheems merged commit 2dd4d26 into main Jul 7, 2025