Conversation

@LoserCheems
Collaborator

Remove local attention and ALiBi support to simplify the backward pass. Introduce mask and bias tensor support with updated memory layouts and efficient data transfer operations. Rename variables for clarity and improve code readability. Fix tensor references for consistent device and shape alignment.

Simplifies the backward pass implementation by removing support for local attention patterns and Attention with Linear Biases (ALiBi).

Updates the copyright notice to include a new contributor, and removes the now-unnecessary template parameters and conditional logic blocks that handled these specialized attention mechanisms.

Streamlines the kernel to focus on core functionality without the complexity of windowed attention and positional bias computation.

Introduces shared memory layouts and copy operations for mask and bias tensors in the backward kernel configuration.

Updates memory size calculations to account for the additional mask and bias storage requirements.

Adds specialized copy atoms with 64-byte alignment for efficient mask and bias data transfers.
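
As a rough sketch (not the PR's actual code), trait definitions along these lines would match the description above, assuming CuTe conventions similar to the existing kernel_traits.h; the struct name, swizzle parameters, tile shape, and the choice of copy instruction are placeholder assumptions:

    #include <cute/tensor.hpp>
    #include <cute/algorithm/copy.hpp>
    using namespace cute;

    template <int kBlockM, int kBlockN, typename Element>
    struct BwdMaskBiasTraits {  // hypothetical helper, for illustration only
        // (kBlockM, kBlockN) row-major tiles of the mask and bias in shared memory,
        // swizzled to reduce bank conflicts (swizzle parameters assumed).
        using SmemLayoutAtomMaskBias = decltype(
            composition(Swizzle<3, 3, 3>{},
                        Layout<Shape<_8, Int<kBlockN>>,
                               Stride<Int<kBlockN>, _1>>{}));
        using SmemLayoutMask = decltype(tile_to_shape(
            SmemLayoutAtomMaskBias{},
            make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
        using SmemLayoutBias = SmemLayoutMask;

        // Copy atom for moving bias tiles between shared memory and registers;
        // the concrete instruction here is an assumption, not the PR's choice.
        using SmemCopyAtomBias = Copy_Atom<SM75_U32x4_LDSM_N, Element>;

        // The shared-memory budget grows by the two extra tiles.
        static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element);
        static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element);
    };
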
Updates variable and parameter names from "dzoh" (dZeroHold) to "dbias" throughout the Flash backward parameters structure to better reflect their actual purpose as bias gradients.

Improves code readability and maintains consistency with naming conventions.
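
For illustration only, the renamed members might look like the excerpt below; the exact field set and the index_t typedef are assumptions based on the naming pattern, not a copy of flash.h:

    using index_t = int64_t;  // assumed to match the stride typedef used by the params structs

    struct Flash_bwd_params /* : public Flash_fwd_params */ {
        // Gradient of the attention bias (formerly the dzoh_* / "dZeroHold" fields).
        void *__restrict__ dbias_ptr;   // was: dzoh_ptr
        index_t dbias_batch_stride;     // was: dzoh_batch_stride
        index_t dbias_head_stride;      // was: dzoh_head_stride
        index_t dbias_row_stride;       // was: dzoh_row_stride
    };
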
Introduces tensor definitions for attention mask, bias, and bias gradient to enable masked attention and bias computations in the backward kernel.

Calculates proper memory offsets for mask and bias tensors based on batch, head, and block dimensions to ensure correct data access patterns during gradient computation.
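
A minimal sketch of that offset arithmetic is shown below; the stride-field and index names are assumptions chosen to mirror the kernel's other global-memory offsets:

    // Hypothetical helper: element (not byte) offset of the mask tile for block
    // indices (bidb, bidh, m_block, n_block), assuming the last dimension is contiguous.
    template <typename Params>
    __device__ __forceinline__ int64_t
    mask_tile_offset(Params const &params, int bidb, int bidh,
                     int m_block, int n_block, int kBlockM, int kBlockN) {
        return int64_t(bidb) * params.mask_batch_stride
             + int64_t(bidh) * params.mask_head_stride
             + int64_t(m_block) * kBlockM * params.mask_row_stride
             + int64_t(n_block) * kBlockN;
    }
    // The bias and bias-gradient offsets would follow the same pattern with bias_*/dbias_* strides.
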
Corrects the tensor reference used for creating the attention mask when not using dynamic masking, ensuring consistent device and shape alignment with the attention bias tensor.
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR refactors the backward attention kernel by removing local-attention and ALiBi support, and introducing explicit mask and bias tensor support throughout the backward pass.

  • Adds shared-memory layouts, copy-atom definitions, and smem-size calculations for attention masks and biases
  • Removes Is_local and Has_alibi template parameters and associated ALiBi/local masking code
  • Renames backward parameters in Flash_bwd_params (e.g., dzoh_* → dbias_*) for clarity

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

  • csrc/src/kernel_traits.h: Define SmemLayoutMask/SmemLayoutBias and SmemCopyAtomBias; update kSmemSize to include mask & bias
  • csrc/src/flash_bwd_kernel.h: Drop ALiBi and local-attention code, load gMask/gBias/gdBias, adjust template signatures
  • csrc/src/flash.h: Rename dzoh_* fields and pointers in Flash_bwd_params to dbias_*
  • benchmarks/benchmark_forward_equivalence.py: Change default mask init to use attn_bias

Comments suppressed due to low confidence (1)

csrc/src/flash_bwd_kernel.h:138

  • The loaded gMask, gBias, and gdBias tensors are never used later in the kernel, and no bias is applied to scores nor is gdBias written back to global memory. You should integrate mask/bias into the score computation (e.g., scores += gBias; scores = scores * gMask;) and add a final copy for gdBias, or remove these unused allocations.
    Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
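
For orientation, the integration the reviewer is asking for could look roughly like the device helper below; the names and the masking convention (zero mask entries pushed to -inf) are hypothetical, not code from this PR:

    #include <cmath>
    #include <cute/tensor.hpp>

    // Hypothetical: add the bias to the score fragment and send masked positions to -inf.
    // scores, mask, and bias are register fragments with matching layouts.
    template <typename TensorS, typename TensorM, typename TensorB>
    __device__ __forceinline__ void apply_mask_bias(TensorS &scores,
                                                    TensorM const &mask,
                                                    TensorB const &bias) {
        #pragma unroll
        for (int i = 0; i < cute::size(scores); ++i) {
            scores(i) += static_cast<float>(bias(i));
            if (static_cast<float>(mask(i)) == 0.f) { scores(i) = -INFINITY; }
        }
    }

A final copy of the accumulated dbias tile back to global memory would also be needed, which this sketch does not show.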

    attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
else:
-   attn_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
+   attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)

Copilot AI commented on Jul 3, 2025

Initializing attn_mask from attn_bias may yield an incorrect shape if the two tensors differ. Consider using torch.ones_like(attn_mask) or explicitly matching dimensions to ensure the mask has the intended shape.

Suggested change
-   attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+   attn_mask = torch.ones(
+       attn_bias.shape, dtype=dtype, device=attn_bias.device
+   )

@LoserCheems merged commit a0160ea into main on Jul 3, 2025