@LoserCheems
Collaborator

Refactor bias initialization in mha_bwd to use torch::zeros, removing the separate zeroing step. Enhance FlashDMAttnFunc to track the bias's original key sequence length and adjust the computation of dbias accordingly.

Copilot AI review requested due to automatic review settings October 12, 2025 15:56
@LoserCheems LoserCheems merged commit 4dd087d into main Oct 12, 2025
1 check passed
Contributor

Copilot AI left a comment

Pull Request Overview

This PR refactors bias initialization and extends the dbias computation in FlashDMAttnFunc. It replaces torch::empty followed by explicit zeroing with a single torch::zeros call, and adds logic to record the bias's original key sequence length so that the bias gradient can be reduced back to that shape.

  • Replace torch::empty + .zero_() with direct torch::zeros for bias initialization
  • Track original bias sequence length (seqlen_k_bias_og) in the forward pass
  • Sum bias gradients over the key dimension when the original bias had a key sequence length of 1
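
The reduction described in the last bullet can be sketched in plain Python (nested lists stand in for tensors, so no PyTorch is needed). This is an illustration, not the project's code: `reduce_dbias`, `loss`, and the sample values are hypothetical, and the extra zero column is an assumed stand-in for whatever padding makes the `[..., :k.shape[1]]` slice necessary. It shows that when a bias with key length 1 is broadcast across keys, its gradient is the sum of the per-key gradients.

```python
def reduce_dbias(dbias_full, seqlen_k, seqlen_k_bias_og):
    """Trim extra key columns, then collapse them if the bias was broadcast.

    dbias_full: list of rows, each a list of per-key gradients.
    """
    trimmed = [row[:seqlen_k] for row in dbias_full]
    if seqlen_k_bias_og == 1:
        # Plain-Python analogue of .sum(dim=-1, keepdim=True) on a 2-D tensor.
        return [[sum(row)] for row in trimmed]
    return trimmed


def loss(bias_col, scores, weights):
    """Scalar loss sum(w * (s + b)), with one bias value per row broadcast over keys."""
    return sum(
        w * (s + bias_col[i][0])
        for i, (s_row, w_row) in enumerate(zip(scores, weights))
        for s, w in zip(s_row, w_row)
    )


scores = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
weights = [[1.0, 2.0, 3.0], [-1.0, 0.5, 4.0]]  # upstream gradient d(loss)/d(out)
bias_col = [[0.25], [-0.75]]                   # bias with key length 1

# Per-key gradient with one assumed padding column of zeros appended.
dbias_full = [w_row + [0.0] for w_row in weights]

dbias = reduce_dbias(dbias_full, seqlen_k=3, seqlen_k_bias_og=1)
print(dbias)  # [[6.0], [3.5]]

# Cross-check against a finite-difference estimate of d(loss)/d(bias).
eps = 1e-6
for i in range(len(bias_col)):
    bumped = [row[:] for row in bias_col]
    bumped[i][0] += eps
    numeric = (loss(bumped, scores, weights) - loss(bias_col, scores, weights)) / eps
    assert abs(numeric - dbias[i][0]) < 1e-4
```

When `seqlen_k_bias_og` equals the key length, the helper only trims, so the gradient keeps one entry per key.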

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| flash_dmattn/flash_dmattn_interface.py | Adds tracking of original bias sequence length and conditional summing of bias gradients |
| csrc/flash_dmattn/flash_api.cpp | Optimizes bias initialization by using torch::zeros directly |


  if dbias is not None:
-     dbias = dbias[..., : k.shape[1]]
+     dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]

Copilot AI Oct 12, 2025

[nitpick] This line is quite complex and hard to read. Consider breaking it into multiple lines or extracting the conditional logic into a separate variable for better readability.

Suggested change
- dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]
+ if ctx.seqlen_k_bias_og == 1:
+     dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True)
+ else:
+     dbias = dbias[..., :k.shape[1]]

