Merged
23 changes: 12 additions & 11 deletions flash_dmattn/flash_dmattn_flex.py
@@ -8,14 +8,17 @@ def flex_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
attention_bias: torch.Tensor,
attn_mask: torch.Tensor,
attn_bias: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = True,
scaling: Optional[float] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
attn_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_bias = attention_bias[:, :, :, : key.shape[-2]]
query = query.transpose(1, 2).contiguous() # [B, H, Q_LEN, D]
key = key.transpose(1, 2).contiguous() # [B, H, KV_LEN, D]
value = value.transpose(1, 2).contiguous() # [B, H, KV_LEN, D]
attn_mask = attn_mask[:, :, :, : key.shape[-2]]
attn_bias = attn_bias[:, :, :, : key.shape[-2]]
Comment on lines +20 to +21

Copilot AI (Aug 7, 2025)

The tensor slicing uses key.shape[-2] but after transposition on line 18, the key tensor shape has changed. This should use the sequence length dimension from the transposed tensor, which would be key.shape[2] instead of key.shape[-2].

Suggested change:
- attn_mask = attn_mask[:, :, :, : key.shape[-2]]
- attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+ attn_mask = attn_mask[:, :, :, : key.shape[2]]
+ attn_bias = attn_bias[:, :, :, : key.shape[2]]
Copilot AI (Aug 7, 2025)

Similar to the attn_mask slicing, this uses key.shape[-2] but should use key.shape[2] after the tensor transposition performed on line 18.

Suggested change:
- attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+ attn_bias = attn_bias[:, :, :, : key.shape[2]]
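For context on the two review comments above, here is a minimal standalone sketch (not part of the PR; all sizes are illustrative assumptions). For the 4-D [B, H, KV_LEN, D] layout produced by the transpose, index 2 and index -2 address the same sequence-length dimension, so both spellings slice the mask and bias to the key length.

import torch

# Illustrative sizes only (assumptions, not taken from the PR).
B, H, Q_LEN, KV_LEN, D = 2, 4, 8, 16, 32

# Inputs arrive as [B, S, H, D] and are transposed to [B, H, S, D], as in the diff above.
key = torch.randn(B, KV_LEN, H, D).transpose(1, 2).contiguous()  # -> [B, H, KV_LEN, D]

# For a 4-D tensor, shape[2] and shape[-2] refer to the same dimension.
assert key.shape[2] == key.shape[-2] == KV_LEN

attn_mask = torch.ones(B, H, Q_LEN, KV_LEN)
attn_bias = torch.zeros(B, H, Q_LEN, KV_LEN)

# Either spelling produces the same slice of the mask/bias.
assert torch.equal(attn_mask[..., : key.shape[-2]], attn_mask[..., : key.shape[2]])
assert torch.equal(attn_bias[..., : key.shape[-2]], attn_bias[..., : key.shape[2]])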

def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
@@ -44,23 +47,21 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"num_stages": 1,
"num_warps": 8,
}
attn_output, attention_weights = compile_friendly_flex_attention(
attn_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask if is_causal else None,
scale=scaling,
scale=scale,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
return_lse=False,
training=False,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attention_weights
return attn_output

flex_dmattn_func = flex_attention_forward
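For reference, a hedged usage sketch of the renamed flex entry point, based only on the signature shown in this diff. The sizes, device, and dtype are assumptions for illustration, and it presumes flash_dmattn is installed on a machine whose PyTorch flex attention backend supports the kernel options used here.

import torch
from flash_dmattn.flash_dmattn_flex import flex_dmattn_func

# Illustrative sizes, device, and dtype (assumptions, not from the PR).
B, H, Q_LEN, KV_LEN, D = 1, 8, 128, 128, 64
device, dtype = "cuda", torch.bfloat16

# Per the signature above: q/k/v come in as [B, S, H, D]; mask/bias as [B, H, Q_LEN, KV_LEN].
query = torch.randn(B, Q_LEN, H, D, device=device, dtype=dtype)
key = torch.randn(B, KV_LEN, H, D, device=device, dtype=dtype)
value = torch.randn(B, KV_LEN, H, D, device=device, dtype=dtype)
attn_mask = torch.ones(B, H, Q_LEN, KV_LEN, device=device, dtype=dtype)
attn_bias = torch.zeros(B, H, Q_LEN, KV_LEN, device=device, dtype=dtype)

# After this change the function returns only attn_output (no attention weights).
attn_output = flex_dmattn_func(
    query, key, value, attn_mask, attn_bias, scale=D ** -0.5, is_causal=True
)
print(attn_output.shape)  # [B, Q_LEN, H, D] after the final transpose back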
71 changes: 36 additions & 35 deletions flash_dmattn/flash_dmattn_triton.py
@@ -846,7 +846,7 @@ def _bwd_kernel(
)


def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
# shape constraints
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
@@ -919,7 +919,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
causal,
is_causal,
BLOCK_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
@@ -930,7 +930,7 @@


def _flash_attn_backward(
do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None
do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
):
# Make sure that the last dimension is contiguous
if do.stride(-1) != 1:
@@ -1040,7 +1040,7 @@ def _flash_attn_backward(
seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
causal,
is_causal,
BLOCK_HEADDIM,
# SEQUENCE_PARALLEL=False,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1052,63 +1052,64 @@

class FlashDMAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None):
def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
"""
q: (batch_size, seqlen_q, nheads, headdim)
k: (batch_size, seqlen_k, nheads, headdim)
v: (batch_size, seqlen_k, nheads, headdim)
mask: optional, (batch, nheads, seqlen_q, seqlen_k)
bias: optional, (batch, nheads, seqlen_q, seqlen_k)
causal: bool, whether to apply causal masking
query: (batch_size, seqlen_q, nheads, headdim)
key: (batch_size, seqlen_k, nheads, headdim)
value: (batch_size, seqlen_k, nheads, headdim)
attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
softmax_scale: float, scaling factor for attention scores
is_causal: bool, whether to apply causal masking
"""
batch, seqlen_q, nheads, _ = q.shape
_, seqlen_k, _, _ = k.shape
if mask is not None:
if mask.dtype == torch.bool:
mask = torch.where(mask, 1.0, 0.0)
batch, seqlen_q, nheads, _ = query.shape
_, seqlen_k, _, _ = key.shape
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_mask = torch.where(attn_mask, 1.0, 0.0)
else:
mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
if bias is None:
bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
if attn_bias is None:
attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)

# Make sure that the last dimension is contiguous
q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]]
query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
o, lse, ctx.softmax_scale = _flash_attn_forward(
q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale
query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
)
ctx.save_for_backward(q, k, v, o, lse, mask, bias)
ctx.causal = causal
ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
ctx.is_causal = is_causal
return o

@staticmethod
def backward(ctx, do):
q, k, v, o, lse, mask, bias = ctx.saved_tensors
query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbias = torch.empty_like(bias)
dq = torch.empty_like(query)
dk = torch.empty_like(key)
dv = torch.empty_like(value)
dbias = torch.empty_like(attn_bias)
_flash_attn_backward(
do,
q,
k,
v,
mask,
bias,
query,
key,
value,
attn_mask,
attn_bias,
o,
lse,
dq,
dk,
dv,
dbias,
causal=ctx.causal,
softmax_scale=ctx.softmax_scale,
is_causal=ctx.is_causal,
)
return dq, dk, dv, None, dbias, None, None


triton_dmattn_func = FlashDMAttnFunc.apply
def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
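Below is a hedged usage sketch of the new triton_dmattn_func wrapper, following the shapes documented in the forward docstring above. The sizes are illustrative assumptions, and it presumes a CUDA device with Triton available.

import torch
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func

# Illustrative sizes (assumptions, not taken from the PR).
batch, seqlen_q, seqlen_k, nheads, headdim = 1, 128, 128, 8, 64
device, dtype = "cuda", torch.float16

# Shapes per the docstring: q/k/v are (batch, seqlen, nheads, headdim),
# mask/bias are (batch, nheads, seqlen_q, seqlen_k).
query = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
key = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
value = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
attn_mask = torch.ones(batch, nheads, seqlen_q, seqlen_k, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, nheads, seqlen_q, seqlen_k, device=device, dtype=dtype, requires_grad=True)

out = triton_dmattn_func(query, key, value, attn_mask, attn_bias, scale=headdim ** -0.5, is_causal=True)
out.sum().backward()  # gradients flow to q/k/v and attn_bias; the mask itself gets no gradient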