diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py
index da66343..024674f 100644
--- a/benchmarks/backward_equivalence.py
+++ b/benchmarks/backward_equivalence.py
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask
diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py
index c08600b..afccef9 100644
--- a/benchmarks/backward_performance.py
+++ b/benchmarks/backward_performance.py
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask
diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py
index 97e80f9..94fc0cf 100644
--- a/benchmarks/forward_equivalence.py
+++ b/benchmarks/forward_equivalence.py
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask
diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py
index 967e009..b540bf6 100644
--- a/benchmarks/forward_performance.py
+++ b/benchmarks/forward_performance.py
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask
diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp
index 93cde0e..b4a657e 100644
--- a/csrc/flash_dmattn/flash_api.cpp
+++ b/csrc/flash_dmattn/flash_api.cpp
@@ -361,7 +361,7 @@ mha_fwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias);
@@ -512,7 +512,7 @@ mha_varlen_fwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
@@ -749,7 +749,7 @@ mha_bwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -951,7 +951,7 @@ mha_varlen_bwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -1136,7 +1136,7 @@ mha_varlen_bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashDynamicMaskAttention";
     m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
-    m.def("varlen_fwd",
&FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
+    // m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
     m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
-    m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
+    // m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
 }
diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h
index 643cfcd..723265b 100644
--- a/csrc/flash_dmattn/src/flash_bwd_kernel.h
+++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -144,7 +144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         make_stride(params.v_row_stride, _1{})
     );
     Tensor gMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast(params.mask_ptr) + row_offset_mask),
+        make_gmem_ptr(reinterpret_cast(params.mask_ptr) + row_offset_mask),
         Shape, Int>{},
         make_stride(params.mask_row_stride, _1{})
     );
@@ -552,14 +552,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     //
     // if (cute::thread(1, 0)) { print(tKrK); }
-    FLASH_NAMESPACE::copy_MN(
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask, tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
+    __syncthreads();
     // Do OR-reduce on the mask to see if any active threads
     Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
@@ -807,14 +808,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (m_block > m_block_min) {
         // Advance gMask
         tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-        FLASH_NAMESPACE::copy_MN(
+        FLASH_NAMESPACE::copy_MN(
             gmem_tiled_copy_MaskBias,
             tMaskgMask, tMasksMask, tMaskcMask,
             binfo.actual_seqlen_q - (m_block - 1) * kBlockM,
             binfo.actual_seqlen_k - n_block * kBlockN
         );
-        FLASH_NAMESPACE::cp_async_fence();
-        FLASH_NAMESPACE::cp_async_wait<0>();
+        // FLASH_NAMESPACE::cp_async_fence();
+        // FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();
         // Do OR-reduce on the mask to see if any active threads for next iteration
         any_active_local_next = false;
diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h
index c98e3b1..79dcbfb 100644
--- a/csrc/flash_dmattn/src/flash_fwd_kernel.h
+++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h
@@ -169,7 +169,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         make_coord(_, 0)
     );  // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
     );
@@ -344,15 +344,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }
     // Reverse iteration over N blocks
     int n_block = n_block_max - 1;
-
-    FLASH_NAMESPACE::copy_MN(
+
+    FLASH_NAMESPACE::copy_MN(
         gmem_tiled_copy_MaskBias,
         tMaskgMask(_, _, _, n_block), tMasksMask, tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads @@ -470,14 +470,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -593,14 +593,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask(_, _, _, n_block - 1), tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -873,7 +873,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons make_stride(params.v_row_stride, _1{}) ); Tensor gMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), + make_gmem_ptr(reinterpret_cast(params.mask_ptr) + col_offset_mask), Shape, Int>{}, make_stride(params.mask_row_stride, _1{}) ); @@ -999,14 +999,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -1146,14 +1146,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - cute::cp_async_fence(); - FLASH_NAMESPACE::cp_async_wait<0>(); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration @@ -1287,12 +1287,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + 
(block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( + FLASH_NAMESPACE::copy_MN( gmem_tiled_copy_MaskBias, tMaskgMask, tMasksMask, tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); // Do OR-reduce on the mask to see if any active threads for next iteration any_active_local_next = false; diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index 81c716a..be28c10 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -521,7 +521,7 @@ __forceinline__ __device__ void copy( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=true, + bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2 > @@ -543,7 +543,14 @@ __forceinline__ __device__ void copy_MN( #pragma unroll for (int n = 0; n < size<2>(S); ++n) { if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + if constexpr (Bool_to_Element) { + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = static_cast(S(i, m, n)) ? To_type(1) : To_type(0); + } + } else { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); } diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 1acdfdd..c4ac38d 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -126,89 +126,6 @@ def _flash_dmattn_forward_fake( _wrapped_flash_dmattn_forward = _flash_dmattn_forward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_forward", mutates_args=(), device_types="cuda") -def _flash_dmattn_varlen_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float = 0.0, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( - q, - k, - v, - mask, - bias, - None, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - zero_tensors, - is_causal, - softcap, - return_softmax, - ) - _sanitize_tensors(out) - return out, softmax_lse, S_dmask - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") -def _flash_dmattn_varlen_forward_fake( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - 
is_causal: bool, - softcap: float = 0.0, - return_softmax: bool = False, - block_table: Optional[torch.Tensor] = None, - leftpad_k: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - zero_tensors: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] - paged_kv = block_table is not None - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - out = torch.empty_like(q) - softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) - p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = round_multiple(max_seqlen_k, 128) - if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - return out, softmax_lse, p - - -_wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward - - @_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") def _flash_dmattn_backward( dout: torch.Tensor, @@ -294,108 +211,6 @@ def _flash_dmattn_backward_fake( _wrapped_flash_dmattn_backward = _flash_dmattn_backward -@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") -def _flash_dmattn_varlen_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)] - ( - dq, - dk, - dv, - dbias, - softmax_d, - ) = flash_dmattn_gpu.varlen_bwd( - dout, - q, - k, - v, - mask, - bias, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - zero_tensors, - is_causal, - softcap, - deterministic, - ) - _sanitize_tensors(dq, dk, dv, dbias) - return softmax_d - - -@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_backward") -def _flash_dmattn_varlen_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dbias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - is_causal: bool, - softcap: float, - deterministic: bool, - zero_tensors: bool = False, -) -> torch.Tensor: - dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)] - batch_size = cu_seqlens_q.numel() - 1 - total_q, num_heads, _ = q.shape - - if dq is None: - dq = torch.empty_like(q) - if dk is None: - dk = torch.empty_like(k) - if dv is None: - dv = torch.empty_like(v) - if dbias is None: - dbias = 
torch.empty_like(bias) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) - - return softmax_d - - -_wrapped_flash_dmattn_varlen_backward = _flash_dmattn_varlen_backward - - class FlashDMAttnFunc(torch.autograd.Function): @staticmethod def forward( @@ -413,17 +228,10 @@ def forward( is_grad_enabled: bool = True, ): # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size) - batch_size, seqlen_k, num_heads_k, _ = k.shape - seqlen_q = q.shape[1] + seqlen_k = k.shape[1] is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] ) - if mask is None: - mask = torch.ones((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) - return_dbias = True - if bias is None: - return_dbias = False - bias = torch.zeros((batch_size, num_heads_k, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if is_causal is None: @@ -444,9 +252,10 @@ def forward( if seqlen_k % 128 != 0: k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=0.0) - bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) - + if mask is not None: + mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False) + if bias is not None: + bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -467,9 +276,9 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic - ctx.return_dbias = return_dbias out = out_padded[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod @@ -513,162 +322,8 @@ def backward( dk = dk[:, : ctx.seqlen_k, :, :] dv = dv[:, : ctx.seqlen_k, :, :] dbias = dbias[..., : ctx.seqlen_k] - if ctx.return_dbias: - return dq, dk, dv, None, dbias, None, None, None, None, None, None - return dq, dk, dv, None, None, None, None, None, None, None, None - -class FlashDMAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float], - is_causal: Optional[bool], - softcap: Optional[float], - deterministic: Optional[bool], - return_softmax: Optional[bool], - block_table: Optional[torch.Tensor] = None, - is_grad_enabled: bool = True, - ): - dtype = q.dtype - min_dtype = torch.finfo(dtype).min - # q, k, v are expected to be of shape (total, num_heads, head_size) - total_q = q.shape[0] - num_heads_k = k.shape[1] - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if mask is None: - mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device) - return_dbias = True - if bias is None: - bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device) - return_dbias = False - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if is_causal is None: - is_causal = False - if softcap is None: - softcap = 0.0 - if deterministic is None: - deterministic = True - if 
return_softmax is None: - return_softmax = False - - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - # FDMA requires the max sequence length to be a multiple of 128 - max_seqlen_q_og = max_seqlen_q - max_seqlen_k_og = max_seqlen_k - aligned_max_seqlen_q = ((max_seqlen_q + 128 - 1) // 128) * 128 - aligned_max_seqlen_k = ((max_seqlen_k + 128 - 1) // 128) * 128 - need_pad_q = aligned_max_seqlen_q != max_seqlen_q - need_pad_k = aligned_max_seqlen_k != max_seqlen_k - if need_pad_k: - pad_cols = aligned_max_seqlen_k - max_seqlen_k - mask = torch.nn.functional.pad(mask, [0, pad_cols], value=0.0) - bias = torch.nn.functional.pad(bias, [0, pad_cols], value=min_dtype) - max_seqlen_k = aligned_max_seqlen_k - if need_pad_q: - max_seqlen_q = aligned_max_seqlen_q - - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( - q, - k, - v, - mask, - bias, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - is_causal=is_causal, - softcap=softcap, - return_softmax=return_softmax, - block_table=block_table, - ) - - if is_grad: - ctx.save_for_backward( - q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k - ) - ctx.seqlen_q_og = max_seqlen_q_og - ctx.seqlen_k_og = max_seqlen_k_og - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.is_causal = is_causal - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.return_dbias = return_dbias - - out = out_padded[..., :head_size_og] - if return_softmax: - if max_seqlen_k != max_seqlen_k_og: - S_dmask = S_dmask[..., :max_seqlen_k_og] - return out, softmax_lse, S_dmask - return out - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - dout: torch.Tensor, - *args: Any, - ): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias) - - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - _wrapped_flash_dmattn_varlen_backward( - dout_padded, - q, - k, - v, - mask, - bias, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.softmax_scale, - ctx.is_causal, - ctx.softcap, - ctx.deterministic, - ) - - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - - if ctx.seqlen_k_og != ctx.max_seqlen_k: - dbias = dbias[:, :, :ctx.seqlen_k_og] - - if ctx.return_dbias: - return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, dbias, None, None, None, None, None, None def flash_dmattn_func( @@ -702,12 +357,14 @@ def flash_dmattn_func( If the row of the mask is all zero, the output will be zero. Arguments: - query: (batch_size, seqlen, nheads, headdim) - key: (batch_size, seqlen, nheads_k, headdim) - value: (batch_size, seqlen, nheads_k, headdim) - attn_mask: (batch_size, nheads_k, seqlen_q, seqlen_k). Attention mask to apply to the attention scores. + query: torch.Tensor. 
The query tensor of shape (batch_size, seqlen, nheads, headdim) + key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) + value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) + attn_mask: torch.Tensor, optional. The attention mask boolean tensor of + shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores. If None, no mask is applied. - attn_bias: (batch_size, nheads_k, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. + attn_bias: torch.Tensor, optional. The attention bias float tensor of + shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores. If None, no bias is applied. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. @@ -738,89 +395,3 @@ def flash_dmattn_func( return_attn_probs, torch.is_grad_enabled(), ) - - -def flash_dmattn_varlen_func( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_bias: Optional[torch.Tensor] = None, - cu_seqlens_q: torch.Tensor = None, - cu_seqlens_k: torch.Tensor = None, - max_seqlen_q: int = None, - max_seqlen_k: int = None, - scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, - deterministic: Optional[bool] = None, - return_attn_probs: Optional[bool] = None, - block_table: Optional[torch.Tensor] = None, -): - """ - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - Arguments: - query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - attn_mask: (total_q, nheads_k, max_seqlen_k). Attention mask to apply to the attention scores. - If None, no mask is applied. - attn_bias: (total_q, nheads_k, max_seqlen_k). Attention Bias to add to the attention scores. - If None, no bias is applied. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - softcap: float. Anything > 0 activates softcapping attention. - deterministic: bool. 
Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). - """ - return FlashDMAttnVarlenFunc.apply( - query, - key, - value, - attn_mask, - attn_bias, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - scale, - is_causal, - softcap, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - ) \ No newline at end of file diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index bf3ff7d..7d718e2 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -30,7 +30,7 @@ def flash_dynamic_mask_attention_forward( key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len). - attention_bias (Optional[torch.Tensor]): The attention bias tensor of shape (batch_size, num_kv_heads, query_len, key_len). + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len). scaling (Optional[float]): The scaling factor for the attention scores. softcap (Optional[float]): The softcap value for the attention scores. **kwargs: Additional keyword arguments. diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index 8f86542..e3a8a3b 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -93,17 +93,15 @@ def _flash_dynamic_mask_attention_forward( ~attention_mask, min_dtype ) - attention_mask = attention_mask.to(dtype) if keep_window_size is not None: if key_length > keep_window_size: topk_values, topk_indices = torch.topk( attention_bias, keep_window_size, dim=-1, largest=True, sorted=False ) - valid_topk = (topk_values != min_dtype).to(dtype) - attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device) - attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk) - attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype) + attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device) + attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype) + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) out = flash_fn( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
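
Note: the benchmark helpers and the HF integration above all share the same top-k masking step, now built on a boolean mask. Below is a minimal self-contained sketch of that updated preparation; the function signature is an assumption for illustration, and only the top-k / masked_fill block mirrors the hunks in this diff.

    import torch

    def prepare_dynamic_mask(attn_bias: torch.Tensor, keep_window_size: int = 2048):
        # attn_bias: (batch, num_kv_heads, query_len, key_len), floating point
        min_dtype = torch.finfo(attn_bias.dtype).min
        if keep_window_size is not None and attn_bias.shape[-1] > keep_window_size:
            topk_values, topk_indices = torch.topk(
                attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
            )
            valid_topk = topk_values != min_dtype
            # Boolean keep mask instead of the previous 0.0/1.0 float mask.
            attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
            attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
            attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
        else:
            attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
        return attn_bias, attn_mask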
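
A hypothetical end-to-end call against the updated interface: the mask is now torch.bool, the bias stays in the query dtype, and both may be omitted. The import path is an assumption based on the module edited in this diff; adjust it to however the package re-exports flash_dmattn_func. Sequence lengths that are multiples of 128 avoid the internal padding path.

    import torch
    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func

    batch, seqlen, nheads, nheads_k, headdim = 2, 512, 8, 2, 64
    device, dtype = "cuda", torch.bfloat16

    q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch, seqlen, nheads_k, headdim, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch, seqlen, nheads_k, headdim, device=device, dtype=dtype, requires_grad=True)

    # Boolean keep/drop mask and float bias, both (batch, nheads_k, seqlen_q, seqlen_k).
    attn_mask = torch.rand(batch, nheads_k, seqlen, seqlen, device=device) > 0.1
    attn_bias = torch.randn(batch, nheads_k, seqlen, seqlen, device=device, dtype=dtype)

    out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
    out.sum().backward()  # gradients flow to q, k, v (and to attn_bias if it requires grad)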
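
Callers that previously built 0.0/1.0 float masks now have to hand the extension a boolean tensor, since flash_api.cpp rejects anything but torch.bool for the mask. A hypothetical migration helper plus an illustrative Python-side mirror of the tightened checks (the extension performs these checks itself; the helper names are not part of the API):

    import torch

    def to_bool_mask(attn_mask: torch.Tensor) -> torch.Tensor:
        # Old-style 0/1 float or int masks become boolean keep masks; bool passes through.
        return attn_mask if attn_mask.dtype == torch.bool else attn_mask != 0

    def check_dmattn_inputs(q, k, v, mask, bias):
        assert q.dtype in (torch.float16, torch.bfloat16), "only fp16 and bf16 are supported"
        assert k.dtype == q.dtype and v.dtype == q.dtype, "q, k, v must share one dtype"
        assert mask.dtype == torch.bool, "mask must have dtype bool"
        assert bias.dtype == q.dtype, "bias must have the same dtype as the inputs"
        assert all(t.is_cuda for t in (q, k, v, mask, bias)), "all tensors must live on CUDA"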