diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py
index a7b7e2a..dfaa74a 100644
--- a/flash_dmattn/flash_dmattn_flex.py
+++ b/flash_dmattn/flash_dmattn_flex.py
@@ -8,14 +8,17 @@ def flex_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attention_mask: torch.Tensor,
-    attention_bias: torch.Tensor,
+    attn_mask: torch.Tensor,
+    attn_bias: torch.Tensor,
+    scale: Optional[float] = None,
     is_causal: bool = True,
-    scaling: Optional[float] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    attn_mask = attention_mask[:, :, :, : key.shape[-2]]
-    attn_bias = attention_bias[:, :, :, : key.shape[-2]]
+) -> torch.Tensor:
+    query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
+    key = key.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
+    value = value.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
+    attn_mask = attn_mask[:, :, :, : key.shape[-2]]
+    attn_bias = attn_bias[:, :, :, : key.shape[-2]]
 
     def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
         score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
@@ -44,23 +47,21 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         "num_stages": 1,
         "num_warps": 8,
     }
-    attn_output, attention_weights = compile_friendly_flex_attention(
+    attn_output = compile_friendly_flex_attention(
         query,
         key,
         value,
         score_mod=score_mod,
         block_mask=block_mask if is_causal else None,
-        scale=scaling,
+        scale=scale,
         kernel_options=kernel_options,
-        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-        # For simplification, we thus always return it as no additional computations are introduced.
-        return_lse=True,
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse internally,
+        # but we no longer surface it, so skip materializing the extra return value.
+        return_lse=False,
         training=False,
     )
-    # lse is returned in float32
-    attention_weights = attention_weights.to(value.dtype)
     attn_output = attn_output.transpose(1, 2).contiguous()
-    return attn_output, attention_weights
+    return attn_output
 
 
 flex_dmattn_func = flex_attention_forward
\ No newline at end of file
diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py
index 619042e..7a05066 100644
--- a/flash_dmattn/flash_dmattn_triton.py
+++ b/flash_dmattn/flash_dmattn_triton.py
@@ -846,7 +846,7 @@ def _bwd_kernel(
     )
 
 
-def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
+def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -919,7 +919,7 @@ def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None):
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
@@ -930,7 +930,7 @@
 
 
 def _flash_attn_backward(
-    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None
+    do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
@@ -1040,7 +1040,7 @@ def _flash_attn_backward(
         seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal,
+        is_causal,
         BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
@@ -1052,63 +1052,64 @@ def _flash_attn_backward(
 
 class FlashDMAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None):
+    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
         """
-        q: (batch_size, seqlen_q, nheads, headdim)
-        k: (batch_size, seqlen_k, nheads, headdim)
-        v: (batch_size, seqlen_k, nheads, headdim)
-        mask: optional, (batch, nheads, seqlen_q, seqlen_k)
-        bias: optional, (batch, nheads, seqlen_q, seqlen_k)
-        causal: bool, whether to apply causal masking
+        query: (batch_size, seqlen_q, nheads, headdim)
+        key: (batch_size, seqlen_k, nheads, headdim)
+        value: (batch_size, seqlen_k, nheads, headdim)
+        attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
+        attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
         softmax_scale: float, scaling factor for attention scores
+        is_causal: bool, whether to apply causal masking
         """
-        batch, seqlen_q, nheads, _ = q.shape
-        _, seqlen_k, _, _ = k.shape
-        if mask is not None:
-            if mask.dtype == torch.bool:
-                mask = torch.where(mask, 1.0, 0.0)
+        batch, seqlen_q, nheads, _ = query.shape
+        _, seqlen_k, _, _ = key.shape
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_mask = torch.where(attn_mask, 1.0, 0.0)
         else:
-            mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
-        if bias is None:
-            bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype)
+            attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+        if attn_bias is None:
+            attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
 
         # Make sure that the last dimension is contiguous
-        q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]]
+        query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]]
         o, lse, ctx.softmax_scale = _flash_attn_forward(
-            q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale
+            query, key, value, attn_mask, attn_bias, softmax_scale=softmax_scale, is_causal=is_causal
         )
-        ctx.save_for_backward(q, k, v, o, lse, mask, bias)
-        ctx.causal = causal
+        ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias)
+        ctx.is_causal = is_causal
         return o
 
     @staticmethod
     def backward(ctx, do):
-        q, k, v, o, lse, mask, bias = ctx.saved_tensors
+        query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors
         assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet"
         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
         with torch.inference_mode():
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
-            dbias = torch.empty_like(bias)
+            dq = torch.empty_like(query)
+            dk = torch.empty_like(key)
+            dv = torch.empty_like(value)
+            dbias = torch.empty_like(attn_bias)
             _flash_attn_backward(
                 do,
-                q,
-                k,
-                v,
-                mask,
-                bias,
+                query,
+                key,
+                value,
+                attn_mask,
+                attn_bias,
                 o,
                 lse,
                 dq,
                 dk,
                 dv,
                 dbias,
-                causal=ctx.causal,
                 softmax_scale=ctx.softmax_scale,
+                is_causal=ctx.is_causal,
             )
         return dq, dk, dv, None, dbias, None, None
 
 
-triton_dmattn_func = FlashDMAttnFunc.apply
+def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
+    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
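
For quick verification of the renamed entry points, a minimal usage sketch follows (not part of the patch). It assumes a CUDA device, fp16 inputs, and that the two modules are importable under the package layout shown in the file paths above; shapes follow the forward docstring.

# Usage sketch for the renamed entry points (assumption: CUDA device, fp16
# inputs; import paths mirror the files touched by this patch).
import torch
from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func

B, H, Q, K, D = 2, 8, 128, 128, 64
q = torch.randn(B, Q, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, K, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, K, H, D, device="cuda", dtype=torch.float16)
mask = torch.ones(B, H, Q, K, device="cuda", dtype=torch.float16)   # 1.0 keeps a position
bias = torch.zeros(B, H, Q, K, device="cuda", dtype=torch.float16)  # added to attention scores

# Triton path: keyword names match FlashDMAttnFunc.forward; `scale` is
# forwarded positionally to `softmax_scale`.
out_triton = triton_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias,
                                scale=D ** -0.5, is_causal=True)

# Flex path: same (batch, seqlen, nheads, headdim) input layout; the wrapper
# transposes to (batch, nheads, seqlen, headdim) internally and, after this
# patch, returns only the attention output.
out_flex = flex_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias,
                            scale=D ** -0.5, is_causal=True)

With the shared query/key/value, attn_mask/attn_bias, scale, and is_causal naming, the two backends can now be swapped behind a single call signature.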