diff --git a/flake.lock b/flake.lock index 15d19a1..494ac87 100644 --- a/flake.lock +++ b/flake.lock @@ -58,11 +58,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1759736557, - "narHash": "sha256-O+mWjaNlpFpf8VIOl9rAoSvix4Fq6t/wQ344tQvVOqM=", + "lastModified": 1762268370, + "narHash": "sha256-gf3TJcaiHdw3dvLL7RF6hc/5BLzQDQj5oakFrKZkOZo=", "owner": "huggingface", "repo": "hf-nix", - "rev": "a887e0a831c37663426cd0f084d6c5d79a495fa8", + "rev": "25c23c765a907d1a5528c5ce65c58a73e974603f", "type": "github" }, "original": { diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index 74b4e4f..64c8777 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -121,16 +121,22 @@ backend = "xpu" src = [ "flash_attn_xpu/flash_api.cpp", - "flash_attn_xpu/src/prefill.hpp", + "flash_attn_xpu/src/fixed.hpp", + "flash_attn_xpu/src/varlen.hpp", "flash_attn_xpu/src/fmha_utils.hpp", "flash_attn_xpu/src/compat_wrapper.hpp", "flash_attn_xpu/src/collective/fmha_fusion.hpp", - "flash_attn_xpu/src/collective/xe_flash_attn_prefill_epilogue.hpp", - "flash_attn_xpu/src/collective/xe_flash_attn_prefill_mma.hpp", - "flash_attn_xpu/src/collective/xe_flash_attn_prefill_softmax_epilogue.hpp", + "flash_attn_xpu/src/collective/fixed_epilogue.hpp", + "flash_attn_xpu/src/collective/fixed_mma.hpp", + "flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp", + "flash_attn_xpu/src/collective/varlen_epilogue.hpp", + "flash_attn_xpu/src/collective/varlen_mma.hpp", + "flash_attn_xpu/src/collective/varlen_softmax_epilogue.hpp", - "flash_attn_xpu/src/kernel/tile_scheduler.hpp", - "flash_attn_xpu/src/kernel/xe_flash_attn_prefill.hpp", + "flash_attn_xpu/src/kernel/fixed_scheduler.hpp", + "flash_attn_xpu/src/kernel/fixed_kernel.hpp", + "flash_attn_xpu/src/kernel/varlen_scheduler.hpp", + "flash_attn_xpu/src/kernel/varlen_kernel.hpp", ] depends = ["torch", "cutlass_sycl"] \ No newline at end of file diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index 539e30d..153b9a2 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -1,15 +1,19 @@ #include -#include "src/prefill.hpp" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#include "src/fixed.hpp" +#include "src/varlen.hpp" namespace FLASH_NAMESPACE { inline int round_multiple(int x, int m) { - return (x + m - 1) / m * m; + int pad_res = (x + m - 1) / m * m; + if (pad_res == 224) + pad_res = 256; + return pad_res; +} + +inline at::Tensor ensure_contiguous(const at::Tensor& tensor) { + return tensor.is_contiguous() ? 
tensor : tensor.contiguous(); } std::vector @@ -32,6 +36,7 @@ mha_fwd( COMPAT::select_device(device_idx); // check inputs + q = ensure_contiguous(q); const auto sizes = q.sizes(); const int batch_size = sizes[0]; const int seqlen_q = sizes[1]; @@ -40,9 +45,11 @@ mha_fwd( const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); // XPU requires head_size to be a multiple of 32 const int head_size_padded = round_multiple(head_size_og, 32); @@ -72,7 +79,10 @@ mha_fwd( out_padded = torch::zeros_like(q_padded); } - cutlass_prefill_fixed_impl(q_padded, k_padded, v_padded, out_padded, softmax_scale, is_causal); + q_padded = ensure_contiguous(q_padded); + k_padded = ensure_contiguous(k_padded); + v_padded = ensure_contiguous(v_padded); + cutlass::flash_attention::fixed::cutlass_fixed_impl(q_padded, k_padded, v_padded, out_padded, softmax_scale, is_causal); // Remove padding from output at::Tensor out = out_padded; @@ -80,7 +90,7 @@ mha_fwd( out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}); } - out = out.contiguous(); + out = ensure_contiguous(out); // TODO: current do not support store softmax_lse out // hard code to return empty tensor for softmax_lse, S_dmask, rng_state @@ -118,16 +128,7 @@ mha_varlen_fwd( COMPAT::select_device(device_idx); // check inputs - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - CHECK_DEVICE(cu_seqlens_q); - CHECK_DEVICE(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_q.dim() == 1, "cu_seqlens_q must be 1-dimensional, but got ", cu_seqlens_q.dim(), " dimensions"); - TORCH_CHECK(cu_seqlens_k.dim() == 1, "cu_seqlens_k must be 1-dimensional, but got ", cu_seqlens_k.dim(), " dimensions"); - CHECK_CONTIGUOUS(cu_seqlens_q); - CHECK_CONTIGUOUS(cu_seqlens_k); - - // Extract dimensions + q = ensure_contiguous(q); const auto sizes = q.sizes(); const int total_q = sizes[0]; const int num_heads = sizes[1]; @@ -136,11 +137,11 @@ mha_varlen_fwd( const int num_heads_k = k.size(1); const int batch_size = cu_seqlens_q.numel() - 1; - CHECK_SHAPE(q, total_q, num_heads, head_size_og); - CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); // XPU requires head_size to be a multiple of 32 const int head_size_padded = round_multiple(head_size_og, 32); @@ -169,8 +170,11 @@ mha_varlen_fwd( } else { out_padded = torch::zeros_like(q_padded); } - - cutlass_prefill_varlen_impl(q_padded, k_padded, v_padded, 
out_padded, + + q_padded = ensure_contiguous(q_padded); + k_padded = ensure_contiguous(k_padded); + v_padded = ensure_contiguous(v_padded); + cutlass::flash_attention::varlen::cutlass_varlen_impl(q_padded, k_padded, v_padded, out_padded, block_table_, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal); @@ -181,7 +185,7 @@ mha_varlen_fwd( out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}); } - out = out.contiguous(); + out = ensure_contiguous(out); // TODO: current do not support store softmax_lse out // hard code to return empty tensor for softmax_lse, S_dmask, rng_state diff --git a/flash-attn2/flash_attn_xpu/src/collective/xe_flash_attn_prefill_epilogue.hpp b/flash-attn2/flash_attn_xpu/src/collective/fixed_epilogue.hpp similarity index 100% rename from flash-attn2/flash_attn_xpu/src/collective/xe_flash_attn_prefill_epilogue.hpp rename to flash-attn2/flash_attn_xpu/src/collective/fixed_epilogue.hpp diff --git a/flash-attn2/flash_attn_xpu/src/collective/xe_flash_attn_prefill_mma.hpp b/flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp similarity index 100% rename from flash-attn2/flash_attn_xpu/src/collective/xe_flash_attn_prefill_mma.hpp rename to flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp diff --git a/flash-attn2/flash_attn_xpu/src/collective/xe_flash_attn_prefill_softmax_epilogue.hpp b/flash-attn2/flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp similarity index 100% rename from flash-attn2/flash_attn_xpu/src/collective/xe_flash_attn_prefill_softmax_epilogue.hpp rename to flash-attn2/flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp diff --git a/flash-attn2/flash_attn_xpu/src/collective/varlen_epilogue.hpp b/flash-attn2/flash_attn_xpu/src/collective/varlen_epilogue.hpp new file mode 100644 index 0000000..9923e0b --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/collective/varlen_epilogue.hpp @@ -0,0 +1,316 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class FlashChunkPrefillEpilogue { + static_assert(cutlass::detail::dependent_false, + "Could not find an epilogue specialization."); +}; + +template +class FlashChunkPrefillEpilogue< + epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, + ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> { + public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using ElementO = ElementO_; + using StrideO = StrideO_; + using ElementLSE = ElementLSE_; + using CopyOpO = CopyOpO_; + using SubgroupLayout = SubgroupLayout_; + using TileShapeOutput = TileShapeOutput_; + using TiledMmaOutput = + typename TiledMMAHelper, Layout, + SubgroupLayout>::TiledMMA; + using GmemTiledCopyO = CopyOpO; + using ElementOutput = ElementO_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementCompute_; + using SubgroupTileShape = + decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape()))); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert( + cute::rank(TileShapeOutput{}) == 3, + "TileShapeOutput must be rank-3: [CTA_M_QO, CTA_N_VO, CTA_K_PV]"); + static_assert( + cute::rank(StrideO{}) == 3, + "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]"); + + using CopyThreadShape = Shape<_1, Int>; + + using traits_store_O = Copy_Traits; + using atom_load_O = Copy_Atom; + using val_layout_load_O = decltype(make_layout( + shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_O = decltype(make_tiled_copy( + atom_load_O{}, Layout{}, val_layout_load_O{})); + + private: + constexpr static bool is_destination_supported = + not cute::is_void_v; + + public: + using EmptyType = cute::tuple<>; + + struct TensorStorageImpl : cute::tuple {}; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + ElementO const* ptr_O; + StrideO dO; + }; + + // Device side epilogue params + struct Params { + XE_Copy_O xe_store_o; + }; + + // + // Methods + // + template + CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename 
Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + [[maybe_unused]] void* workspace) { + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv_cache, + head_size_qk, head_size_vo] = problem_shape; + auto q_group_size = num_heads_q / num_heads_kv; + auto q_group_num = num_heads_q / q_group_size; + auto tensorO = + make_tensor(make_gmem_ptr(static_cast(args.ptr_O)), + make_layout(make_shape(seq_len_qo * q_group_size, + head_size_vo, batch * q_group_num), + args.dO)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return { + xe_store_o, + }; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, + Arguments const& args) { + return 0; + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, Arguments const& args, void* workspace, + cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillEpilogue(Params const& params_, TensorStorage const&) + : params(params_) {} + + template + CUTLASS_DEVICE void operator()(ProblemShape problem_shape, + SequenceLengthShape sequence_length_shape, + TileCoord tile_coord, FragOut& out, + FragMax const& max, FragSum& sum) { + using namespace cute; + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<2, ProblemShape>>; + + using FragOutLayout = typename FragOut::layout_type; + + constexpr int Vec = shape<0>(FragOutLayout{}); + constexpr int FragsM = shape<1>(FragOutLayout{}); + constexpr int FragsN = size(select<2, 3>(shape(FragOutLayout{}))); + + auto sg = COMPAT::get_nd_item<1>().get_sub_group(); + auto out_reg = make_tensor(static_cast(out).data(), + Shape, Int, Int>{}); + + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragsM; y++) { + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < Vec; x++) { + int index = y * Vec + x; + auto cur_sum = reduce_over_group(sg, sum(index), sycl::plus<>()); + if (cur_sum == 0.f || cur_sum != cur_sum) { + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + out_reg(x, y, z) = ElementCompute{0}; + } + } else { + auto cur_scale = sycl::native::recip(cur_sum); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + out_reg(x, y, z) *= cur_scale; + } + } + } + } + + // Indexing variables + auto [batch, num_heads_q, num_heads_kv, head_size_vo] = + select<0, 1, 2, 6>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + // Represent the full output tensor + auto q_group_size = num_heads_q / num_heads_kv; + auto q_group_nums = num_heads_q / q_group_size; + // Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo * q_group_size + // , head_size_vo, (is_var_len ? 
1: batch) * q_group_nums)); + Tensor mO_mnl = + cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1)); + + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; + // Tile the output tensor per WG + Tensor g_wg_O = + local_tile(mO_mnl, select<0, 1>(TileShapeOutput{}), + make_coord(m_coord, n_coord, 0)); // (BLK_M,BLK_N,m,n,l) + static constexpr auto ATOM_N = + get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape()); + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + // Tile the output tensor per SG + Tensor gO = + local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg, n_sg, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX()); + Tensor tOgO = thread_xe_store_o.partition_D(gO); + + Tensor final_out_reg = make_fragment_like(out_reg); + // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the + // right conversion so we call copy() which internally performs a + // static_cast op on the data. for ElementOutput == bf16 | fp16, + // convert_type calls relevant NumericConverter specialization. + if constexpr (cute::is_same_v) { + copy(out_reg, final_out_reg); + } else { + Tensor temp = convert_type(out_reg); + copy(temp, final_out_reg); + } + copy(params.xe_store_o, final_out_reg, tOgO); + } + + // SequenceLengthShapeType = Shape + // For Fixed Sequence Length, ProblemShapeType = Shape For Variable Sequence Length, ProblemShapeType = + // Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, ProblemShapeType const& problem_shape, + SequenceLengthShapeType const& sequence_length_shape, int const& l_coord, + int const& q_head_coord) { + auto [num_heads_q, num_heads_kv, head_size_vo] = + select<1, 2, 6>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + int offset_o = 0; + if constexpr (VarLen) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord] + + q_head_coord * head_size_vo; + } else { + offset_o = num_heads_q * head_size_vo * seq_len_qo * l_coord + + q_head_coord * head_size_vo; + } + auto store_traits = static_cast(params.xe_store_o); + ElementO* base_ptr = (ElementO*)store_traits.base_ptr; + auto shape_o = + make_shape(static_cast(seq_len_qo), num_heads_q * head_size_vo, 1); + StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o); + auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), + make_layout(shape_o, stride_o)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return Params{xe_store_o}; + } + + private: + Params const& params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/flash-attn2/flash_attn_xpu/src/collective/varlen_mma.hpp b/flash-attn2/flash_attn_xpu/src/collective/varlen_mma.hpp new file mode 100644 index 0000000..b275742 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/collective/varlen_mma.hpp @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "fmha_fusion.hpp" + +//////////////////////////////////////////////////////////// +namespace {} +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::flash_attention::collective { +using namespace cute; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FlashChunkPrefillMma { + static_assert(cutlass::detail::dependent_false, + "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FlashChunkPrefillMma< + gemm::MainloopIntelXeXMX16, ProblemShapeType_, ElementQ_, StrideQ_, + ElementK_, StrideK_, ElementV_, StrideV_, MMAOperation_, TileShapeQK_, + TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, + GmemTiledCopyV_, CausalMask_, LocalMask_, PagedKV_> { + // + // Type Aliases + // + using DispatchPolicy = gemm::MainloopIntelXeXMX16; + using TileShapeQK = TileShapeQK_; + using TileShapePV = TileShapePV_; + using SubgroupLayout = SubgroupLayout_; + using ProblemShapeType = ProblemShapeType_; + using ElementQ = ElementQ_; + using StrideQ = StrideQ_; + using ElementK = ElementK_; + using StrideK = StrideK_; + using ElementV = ElementV_; + using StrideV = StrideV_; + using GmemTiledCopyQ = GmemTiledCopyQ_; + using GmemTiledCopyK = GmemTiledCopyK_; + using GmemTiledCopyV = GmemTiledCopyV_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaAtom = MMA_Atom; + + using TiledMmaQK = 
typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; + + using TiledMmaPV = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; + using ElementAccumulator = typename TiledMmaQK::ValTypeC; + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + static constexpr bool PagedKV = PagedKV_; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename MmaAtom::Shape_MNK; + + static constexpr auto PV_ATOM_M = + decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = + decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = + decltype(get<2>(SubgroupLayout{}.shape()))::value; + + using SubgroupTileShapePV = + decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); + static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); + static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); + + // This TiledMma is only required to serve the specific tiling requirements + // for matrix K. This is due to the consumption of matrix K by all subgroups + // within a workgroup. + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + + using SubgroupTileShapeQK = decltype(cute::shape_div( + TileShapeQK{}, + SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + + static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<3, ProblemShapeType>>; + + using FragsShapeS = decltype(cute::shape_div( + take<0, 2>(SubgroupTileShapeQK{}), + take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + static constexpr int FragsM = get<0>(FragsShapeS{}); + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + + static constexpr uint32_t MaxThreadsPerBlock = + size(SubgroupLayout{}) * SubgroupSize; + using CopyThreadShape = Shape<_1, Int>; + + using traits_load_Q = Copy_Traits; + using atom_load_Q = Copy_Atom; + using val_layout_load_Q = decltype(make_layout( + shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy( + atom_load_Q{}, Layout{}, val_layout_load_Q{})); + + using traits_load_K = Copy_Traits; + using atom_load_K = Copy_Atom; + using val_layout_load_K = decltype(make_layout( + shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy( + atom_load_K{}, Layout{}, val_layout_load_K{})); + + using traits_load_V = Copy_Traits; + using atom_load_V = Copy_Atom; + using val_layout_load_V = decltype(make_layout( + shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy( + atom_load_V{}, Layout{}, val_layout_load_V{})); + + // Host side kernel arguments + struct Arguments { + ElementQ const* ptr_Q; + StrideQ dQ; + ElementK const* ptr_K_cache; + StrideK dK_cache; + ElementV const* ptr_V_cache; + StrideV dV_cache; + // Paged KV Cache + int const* ptr_page_table; + int page_size; + int const max_pages_per_seq; + int const total_seqlen_k; + int window_left; + int window_right; + 
}; + + struct Params { + XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K gmem_tiled_copy_k_cache; + XE_Copy_V gmem_tiled_copy_v_cache; + // Paged KV Cache + int const* ptr_page_table; + int page_size; + int const max_pages_per_seq; + int const total_seqlen_k; + int window_left; + int window_right; + }; + + // + // Methods + // + + FlashChunkPrefillMma() = default; + + static constexpr Params to_underlying_arguments( + ProblemShapeType const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv_cache, + head_size_qk, head_size_vo] = problem_shape; + auto q_group_size = num_heads_q / num_heads_kv; + + auto tensorQ = make_tensor( + make_gmem_ptr(args.ptr_Q), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), + args.dQ)); + auto tensorK_cache = make_tensor( + make_gmem_ptr(args.ptr_K_cache), + make_layout( + make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch), + args.dK_cache)); + auto tensorV_cache = make_tensor( + make_gmem_ptr(args.ptr_V_cache), + make_layout( + make_shape(head_size_vo * num_heads_kv, seq_len_kv_cache, batch), + args.dV_cache)); + + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + + return Params{copyQ, + copyK_cache, + copyV_cache, + args.ptr_page_table, + args.page_size, + args.max_pages_per_seq, + args.total_seqlen_k, + args.window_left, + args.window_right}; + } + + template + CUTLASS_DEVICE void mmaQK(FragQccum& accum, TensorQ gQ, TensorK gK, + FragSrc const& frag_src, int const& k_tile_count, + Params const& params) { + auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; + + int thread_idx = static_cast(ThreadIdxX()); + auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); + auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx); + // Instantiate the MMA object + TiledMmaQK tiled_mma; + // To make all threads in a warp have the same global tensors pass in the + // index of thread 0 in each warp + auto sg = COMPAT::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); + auto thread_mma_k = tiled_mma.get_slice(0); + + Tensor tCgQ = thread_mma_q.partition_A(gQ); + Tensor tCgK = thread_mma_k.partition_B(gK); + + // Create fragments + // TODO(Codeplay): fix this, this is probably not general + Tensor tCrQ = make_tensor(make_fragment_layout( + params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); + Tensor tCrK = make_tensor( + make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); + + // Retile registers for copies + Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); + Tensor tKrK = thr_copy_K.retile_D(tCrK); + + // Retile global tile for copies + Tensor tQgQ = thr_copy_Q.retile_S(tCgQ); + Tensor tKgK = thr_copy_K.retile_S(tCgK); + + // + // Mainloop + // + + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { + copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); + copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); +#if 0 + #define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(0, 0)) { + print("======================= Q: \n"); + PRINT(gQ); + PRINT(tCrQ); + PRINT(tCgQ); + PRINT(tQrQ); + PRINT(tQgQ); + + print("===================== K :\n"); + PRINT(gK); + PRINT(tCrK); + PRINT(tCgK); + 
PRINT(tKrK); + PRINT(tKgK); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapeQK{}); + } + #undef PRINT +#endif + } + } + + template + CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + CUTLASS_DEVICE void mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, + FragSrc const& frag_src, Params const& params) { + auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache; + + int thread_idx = static_cast(ThreadIdxX()); + // Instantiate the MMA object + TiledMmaPV tiled_mma; + // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid + // Register spill + Tensor gV_ = take<0, 3>( + local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); + auto sg = COMPAT::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + Tensor tCgV = thread_mma.partition_B(gV_); + Tensor tCrV = make_tensor( + make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); + + // Partition the copying of A and B tiles across the threads + auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); + Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV); + Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + #define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("===================== V :\n"); + PRINT(gV); + PRINT(tCrV); + PRINT(tCgV); + PRINT(tVrV); + PRINT(tVgV); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapePV{}); + } + #undef PRINT +#endif + + // 7) Convert S to P (FP32 -> BF16) + Tensor tPr = convert_type(tSr); + // + // Mainloop + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tile_count; i++) { + copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); + // if (cute::thread(0, 0)) { + // print("V:\n"); + // print_tensor(tVrV); + // } + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + } + } + + // SequenceLengthShape = Shape + // For Variable Sequence Length, ProblemShape = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, ProblemShape const& problem_shape, + SequenceLengthShape const& sequence_length_shape, int const& l_coord, + int const& q_head_coord = 0) { + auto [num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = + select<1, 2, 5, 6>(problem_shape); + auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; + if constexpr (PagedKV) { + seq_len_kv_cache = params.total_seqlen_k; + } + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, + offset_v_cache = 0; + if constexpr (is_var_len) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + auto kv_cached_cumulative_length = + get<4>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + + q_head_coord * head_size_qk; + offset_k_cache = seq_len_kv_cache == 0 ? 0 + : PagedKV + ? 
kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * + kv_cached_cumulative_length[l_coord] + + kv_head_coord * head_size_qk; + offset_v_cache = seq_len_kv_cache == 0 ? 0 + : PagedKV + ? kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * + kv_cached_cumulative_length[l_coord] + + kv_head_coord * head_size_vo; + } else { + offset_q = num_heads_q /*q_group_nums * q_group_size*/ * head_size_qk * + seq_len_qo * l_coord + + q_head_coord * head_size_qk; + offset_k_cache = + seq_len_kv_cache == 0 ? 0 + : PagedKV ? kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * seq_len_kv_cache * l_coord + + kv_head_coord * head_size_qk; + offset_v_cache = + seq_len_kv_cache == 0 ? 0 + : PagedKV ? kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * seq_len_kv_cache * l_coord + + kv_head_coord * head_size_vo; + } + + auto q_traits = static_cast(params.gmem_tiled_copy_q); + const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr; + auto k_traits_cache = + static_cast(params.gmem_tiled_copy_k_cache); + const ElementK* k_cache_ptr = (const ElementK*)k_traits_cache.base_ptr; + auto v_traits_cache = + static_cast(params.gmem_tiled_copy_v_cache); + const ElementV* v_cache_ptr = (const ElementV*)v_traits_cache.base_ptr; + + auto shape_q = + make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); + auto shape_k_cache = make_shape(static_cast(seq_len_kv_cache), + head_size_qk * num_heads_kv, 1); + StrideK stride_k_cache = + cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + auto shape_v_cache = make_shape(head_size_vo * num_heads_kv, + static_cast(seq_len_kv_cache), 1); + StrideV stride_v_cache = + cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), + make_layout(shape_q, stride_q)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), + make_layout(shape_k_cache, stride_k_cache)); + auto tensorV_cache = + make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), + make_layout(shape_v_cache, stride_v_cache)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + return Params{copyQ, + copyK_cache, + copyV_cache, + params.ptr_page_table, + params.page_size, + params.max_pages_per_seq, + params.total_seqlen_k, + params.window_left, + params.window_right}; + } +}; + +} // namespace cutlass::flash_attention::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/flash-attn2/flash_attn_xpu/src/collective/varlen_softmax_epilogue.hpp b/flash-attn2/flash_attn_xpu/src/collective/varlen_softmax_epilogue.hpp new file mode 100644 index 0000000..0353662 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/collective/varlen_softmax_epilogue.hpp @@ -0,0 +1,235 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing online softmax. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class FlashChunkPrefillSoftmaxEpilogue { + static_assert(cutlass::detail::dependent_false, + "Could not find an epilogue specialization."); +}; + +template +class FlashChunkPrefillSoftmaxEpilogue { + public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using Element = Element_; + + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + + using GmemTiledCopyOut = void; + + // Host side epilogue arguments + struct Arguments { + Element const scale; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + static constexpr Params to_underlying_arguments(Arguments const& args) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + Element val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + template + static size_t get_workspace_size() { + return 0; + } + + template + static cutlass::Status initialize_workspace() { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement() { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillSoftmaxEpilogue(Params const& params_) : params(params_) {} + + template + CUTLASS_DEVICE void scale_exp_log2(FragAcc& frag_s, FragMax const& max, + FragSum& sum) { + auto g = COMPAT::get_nd_item<1>().get_sub_group(); + const auto max_scale = max * params.scale; + CUTLASS_PRAGMA_UNROLL + for (int index = 0; index < Vec * FragsM; index++) { + const auto max_scale_bcast = group_broadcast(g, max_scale, index); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto 
base_index = index + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_index)) && frag_s(base_index) < 0)) { + frag_s(base_index) = 0.f; + // continue; + } else { + Element eq = frag_s(base_index) - max_scale_bcast; + frag_s(base_index) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_index) - max_scale_bcast; + frag_s(base_index) = sycl::native::exp2(eq); + } + sum(index) += frag_s(base_index); + } + } + } + + template + CUTLASS_DEVICE void reduce_max(FragSrc& src, FragMax& max) { + auto sg = COMPAT::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for (int index = 0; index < Vec * FragsM; index++) { + auto maxptr = group_broadcast(sg, max, index); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_index = index + (z * Vec * FragsM); + maxptr = sycl::max(maxptr, src(base_index)); + src(base_index) *= params.scale; + } + maxptr = reduce_over_group(sg, maxptr, sycl::maximum<>()); + if (index == sg.get_local_id()[0]) { + max = maxptr; + } + } + } + + template + CUTLASS_DEVICE void operator()(bool is_first, FragAcc& frag_s, FragMax& max, + FragSum& sum, FragOut& out) { + auto max_prev = max; + using FragAccLayout = typename FragAcc::layout_type; + using FragOutLayout = typename FragOut::layout_type; + constexpr int Vec = get<0>(FragAccLayout{}.shape()); + constexpr int FragsM = get<1>(FragAccLayout{}.shape()); + constexpr int FragsNAcc = get<2>(FragAccLayout{}.shape()); + constexpr int FragsNOut = size(select<2, 3>(FragOutLayout{}.shape())); + reduce_max(frag_s, max); + // if (max == INFINITY) { + // max = 0.f; + // } + static_assert(Vec * FragsM % 8 == 0, + " No. of attention rows per subgroup should be >= 1 MMA Atom " + "worth of rows."); + if (!is_first) { + auto sg = COMPAT::get_nd_item<1>().get_sub_group(); + Element max_scale{max * params.scale}; + Element exp_scale; + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || + (std::isinf(max_prev) && max_prev < 0)) { + exp_scale = 0.f; + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + + CUTLASS_PRAGMA_UNROLL + for (int index = 0; index < Vec * FragsM; index++) { + auto max_scale_bcast = group_broadcast(sg, max_scale, index); + auto exp_scale_bcast = group_broadcast(sg, exp_scale, index); + sum(index) *= exp_scale_bcast; + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNAcc; z++) { + auto base_index = index + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_index)) && frag_s(base_index) < 0)) { + frag_s(base_index) = 0.f; + // continue; + } else { + Element eq = frag_s(base_index) - max_scale_bcast; + frag_s(base_index) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_index) - max_scale_bcast; + // eq = eq < -65400.f ? 
0.f : eq; + frag_s(base_index) = sycl::native::exp2(eq); + } + sum(index) += frag_s(base_index); + } + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNOut; z++) { + auto base_index = index + (z * Vec * FragsM); + out(base_index) *= exp_scale_bcast; + } + } + } else { + scale_exp_log2(frag_s, max, sum); + } + } + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/flash-attn2/flash_attn_xpu/src/prefill.hpp b/flash-attn2/flash_attn_xpu/src/fixed.hpp similarity index 56% rename from flash-attn2/flash_attn_xpu/src/prefill.hpp rename to flash-attn2/flash_attn_xpu/src/fixed.hpp index 89766a5..556336d 100644 --- a/flash-attn2/flash_attn_xpu/src/prefill.hpp +++ b/flash-attn2/flash_attn_xpu/src/fixed.hpp @@ -7,18 +7,20 @@ #include #include "./compat_wrapper.hpp" -#include "./kernel/tile_scheduler.hpp" -#include "./kernel/xe_flash_attn_prefill.hpp" +#include "./kernel/fixed_scheduler.hpp" +#include "./kernel/fixed_kernel.hpp" #include "./collective/fmha_fusion.hpp" -#include "./collective/xe_flash_attn_prefill_epilogue.hpp" -#include "./collective/xe_flash_attn_prefill_softmax_epilogue.hpp" +#include "./collective/fixed_epilogue.hpp" +#include "./collective/fixed_softmax_epilogue.hpp" #include "fmha_utils.hpp" +namespace cutlass::flash_attention::fixed { + using namespace cute; -// Base structure for common arguments -struct prefill_args_base_t { +// Fixed length specific arguments +struct prefill_args_fixed_t { void* query; void* key; void* value; @@ -29,25 +31,11 @@ struct prefill_args_base_t { int head_size; bool is_causal; int batch_size; -}; - -// Variable length specific arguments -struct prefill_args_varlen_t : public prefill_args_base_t { - void* cu_seqlens_q; - void* cu_seqlens_k; - int max_seqlen_q; - int max_seqlen_k; - int total_seqlen_q; - int total_seqlen_k; -}; - -// Fixed length (non-varlen) specific arguments -struct prefill_args_fixed_t : public prefill_args_base_t { int seq_len_q; int seq_len_k; }; -template +template struct KernelLauncher { using StrideQ = typename FMHAPrefillKernel::StrideQ; using StrideK = typename FMHAPrefillKernel::StrideK; @@ -69,37 +57,7 @@ struct KernelLauncher { StrideV stride_V; StrideO stride_O; - // Specialization for variable length - template - typename std::enable_if_t - initialize(const prefill_args_varlen_t& args) { - auto problem_shape_out = cute::make_tuple( - args.batch_size, args.num_heads_q, args.num_heads_kv, - cutlass::fmha::collective::VariableLength{args.max_seqlen_q, nullptr}, // cu_q - cutlass::fmha::collective::VariableLength{args.max_seqlen_k, nullptr}, // cu_kv - args.head_size, args.head_size); - - stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, - cute::make_shape(args.total_seqlen_q, args.head_size, args.num_heads_q)); - stride_K = cutlass::make_cute_packed_stride(StrideK{}, - cute::make_shape(args.total_seqlen_k, args.head_size, args.num_heads_kv)); - stride_V = cutlass::make_cute_packed_stride(StrideV{}, - cute::make_shape(args.head_size, args.total_seqlen_k, args.num_heads_kv)); - stride_O = cutlass::make_cute_packed_stride(StrideO{}, - cute::make_shape(args.total_seqlen_q, args.head_size, args.num_heads_q)); - - cute::get<3>(problem_shape_out).cumulative_length = - reinterpret_cast(args.cu_seqlens_q); - cute::get<4>(problem_shape_out).cumulative_length = - 
reinterpret_cast(args.cu_seqlens_k); - - return problem_shape_out; - } - - // Specialization for fixed length - template - typename std::enable_if_t - initialize(const prefill_args_fixed_t& args) { + ProblemShapeType initialize(const prefill_args_fixed_t& args) { auto problem_shape = cute::make_tuple( args.batch_size, args.num_heads_q, args.num_heads_kv, args.seq_len_q, args.seq_len_k, args.head_size, args.head_size); @@ -116,30 +74,7 @@ struct KernelLauncher { return problem_shape; } - // Run function for variable length - template - typename std::enable_if_t - run(const prefill_args_varlen_t& args, const cutlass::KernelHardwareInfo& hw_info) { - ProblemShapeType problem_size = initialize(args); - - typename FMHAPrefillKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {reinterpret_cast(args.query), stride_Q, - reinterpret_cast(args.key), stride_K, - reinterpret_cast(args.value), stride_V, - }, // window_left, window_right for local mask (not supported currently) - {args.softmax_scale}, - {reinterpret_cast(args.out), stride_O}, - hw_info}; - - return run_kernel(arguments); - } - - // Run function for fixed length - template - typename std::enable_if_t - run(const prefill_args_fixed_t& args, const cutlass::KernelHardwareInfo& hw_info) { + cutlass::Status run(const prefill_args_fixed_t& args, const cutlass::KernelHardwareInfo& hw_info) { ProblemShapeType problem_size = initialize(args); typename FMHAPrefillKernel::Arguments arguments{ @@ -230,8 +165,8 @@ template struct FMHAKernel { - template - static void run_impl(const ArgsType& args) { + template + static void run_impl(const prefill_args_fixed_t& args) { cutlass::KernelHardwareInfo hw_info; using LayoutQ = cutlass::layout::RowMajor; @@ -253,12 +188,7 @@ struct FMHAKernel { cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue< Causal, EpilogueDispatchPolicy, ElementAccumulator>; - using ProblemShape = typename std::conditional< - std::is_same::value, - cute::tuple, - cute::tuple - >::type; + using ProblemShape = cute::tuple; using CollectiveMainloop = cutlass::flash_attention::collective::FlashPrefillMma< @@ -273,23 +203,21 @@ struct FMHAKernel { ProblemShape, CollectiveMainloop, CollectiveSoftmaxEpilogue, CollectiveEpilogue, Scheduler>; - constexpr bool isVarLen = std::is_same::value; - KernelLauncher launcher; + KernelLauncher launcher; launcher.run(args, hw_info); } - template - static void dispatch(const ArgsType& args) { + static void dispatch(const prefill_args_fixed_t& args) { if (args.is_causal) { - run_impl(args); + run_impl(args); } else { - run_impl(args); + run_impl(args); } } }; -template -void policy_dispatch(CutlassType cuType, const ArgsType& args) { +template +void policy_dispatch(CutlassType cuType, const prefill_args_fixed_t& args) { constexpr int PipelineStages = 2; if (cuType == CutlassType::half) { @@ -309,62 +237,26 @@ void policy_dispatch(CutlassType cuType, const ArgsType& args) { } } -class TensorRearranger { -public: - // [total_seq, heads, head_size] -> [total_seq * heads, head_size] - static void to_block_layout( - const at::Tensor& input, at::Tensor& output, - const at::Tensor& cu_seqlens, int batch_size, int num_heads) { - - int offset = 0; - for (int b = 0; b < batch_size; ++b) { - const int start = cu_seqlens[b].item(); - const int end = cu_seqlens[b + 1].item(); - const int seq_len = end - start; - - for (int h = 0; h < num_heads; ++h) { - output.slice(0, offset, offset + seq_len).copy_( - input.slice(0, start, end).select(1, h)); - offset += seq_len; - 
} - } - } - - // [total_seq * heads, head_size] -> [total_seq, heads, head_size] - static void from_block_layout( - const at::Tensor& input, at::Tensor& output, - const at::Tensor& cu_seqlens, int batch_size, int num_heads) { - - int offset = 0; - for (int b = 0; b < batch_size; ++b) { - const int start = cu_seqlens[b].item(); - const int end = cu_seqlens[b + 1].item(); - const int seq_len = end - start; - - for (int h = 0; h < num_heads; ++h) { - output.slice(0, start, end).select(1, h).copy_( - input.slice(0, offset, offset + seq_len)); - offset += seq_len; - } - } - } -}; - -template -void dispatch_by_head_size(CutlassType cuType, const ArgsType& args) { +void dispatch_by_head_size(CutlassType cuType, const prefill_args_fixed_t& args) { const int h = args.head_size; - if (h <= 64) { + if (h <= 32) { + policy_dispatch(cuType, args); + } + else if (h <= 64) { policy_dispatch(cuType, args); - } + } else if (h <= 96) { policy_dispatch(cuType, args); - } + } else if (h <= 128) { policy_dispatch(cuType, args); - } + } + else if (h <= 160) { + policy_dispatch(cuType, args); + } else if (h <= 192) { policy_dispatch(cuType, args); - } + } else if (h <= 256) { policy_dispatch(cuType, args); } @@ -373,52 +265,8 @@ void dispatch_by_head_size(CutlassType cuType, const ArgsType& args) { } } -// Variable length implementation -void cutlass_prefill_varlen_impl( - const at::Tensor& query, // [total_seq_q, heads, head_size] B*S, H, D - const at::Tensor& key, // [total_seq_k, heads, head_size] - const at::Tensor& value, // [total_seq_k, heads, head_size] - at::Tensor& out, // [total_seq_q, heads, head_size] - const at::Tensor& cu_seqlens_q, - const at::Tensor& cu_seqlens_k, - int max_seqlen_q, int max_seqlen_k, - double softmax_scale, bool is_causal) { - - int num_heads_q = query.size(1); - int num_heads_kv = key.size(1); - int head_size = query.size(2); - int batch_size = cu_seqlens_q.numel() - 1; - int total_seqlen_q = query.size(0); - int total_seqlen_k = key.size(0); - - auto cu_q = cu_seqlens_q.to(torch::kInt32); - auto cu_k = cu_seqlens_k.to(torch::kInt32); - - // Create block layouts - auto q_block = torch::empty({total_seqlen_q * num_heads_q, head_size}, query.options()); - auto k_block = torch::empty({total_seqlen_k * num_heads_kv, head_size}, key.options()); - auto v_block = torch::empty({total_seqlen_k * num_heads_kv, head_size}, value.options()); - auto out_block = torch::empty({total_seqlen_q * num_heads_q, head_size}, query.options()); - - // Rearrange tensors - TensorRearranger::to_block_layout(query, q_block, cu_q, batch_size, num_heads_q); - TensorRearranger::to_block_layout(key, k_block, cu_k, batch_size, num_heads_kv); - TensorRearranger::to_block_layout(value, v_block, cu_k, batch_size, num_heads_kv); - - // Prepare arguments - prefill_args_varlen_t args{ - {q_block.data_ptr(), k_block.data_ptr(), v_block.data_ptr(), out_block.data_ptr(), - static_cast(softmax_scale), num_heads_q, num_heads_kv, head_size, is_causal, batch_size}, - cu_seqlens_q.data_ptr(), cu_seqlens_k.data_ptr(), - max_seqlen_q, max_seqlen_k, total_seqlen_q, total_seqlen_k - }; - - dispatch_by_head_size(aten_to_Cutlass_dtype(query), args); - TensorRearranger::from_block_layout(out_block, out, cu_q, batch_size, num_heads_q); -} - // Fixed length implementation -void cutlass_prefill_fixed_impl( +void cutlass_fixed_impl( const at::Tensor& query, // [batch, seq_q, heads, head_size] B S H D const at::Tensor& key, // [batch, seq_k, heads, head_size] const at::Tensor& value, // [batch, seq_k, heads, head_size] @@ -440,12 
+288,14 @@ void cutlass_prefill_fixed_impl( // Prepare arguments prefill_args_fixed_t args{ - {q_reshaped.data_ptr(), k_reshaped.data_ptr(), v_reshaped.data_ptr(), - out_temp.data_ptr(), static_cast(softmax_scale), - num_heads_q, num_heads_kv, head_size, is_causal, batch_size}, + q_reshaped.data_ptr(), k_reshaped.data_ptr(), v_reshaped.data_ptr(), + out_temp.data_ptr(), static_cast(softmax_scale), + num_heads_q, num_heads_kv, head_size, is_causal, batch_size, seq_len_q, seq_len_k }; dispatch_by_head_size(aten_to_Cutlass_dtype(query), args); out.copy_(out_temp.transpose(1, 2)); } + +} // namespace cutlass::flash_attention::fixed \ No newline at end of file diff --git a/flash-attn2/flash_attn_xpu/src/fmha_utils.hpp b/flash-attn2/flash_attn_xpu/src/fmha_utils.hpp index d918983..71ed2af 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_utils.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_utils.hpp @@ -2,11 +2,6 @@ #include "torch/all.h" #include -#define HEAD_SIZE_LIMIT_0 64 -#define HEAD_SIZE_LIMIT_1 96 -#define HEAD_SIZE_LIMIT_2 128 -#define HEAD_SIZE_LIMIT_3 192 - enum class CutlassType { half, bfloat16, @@ -26,6 +21,14 @@ inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { } using namespace cute; + +struct prefill_policy_head32 { + using ShapeQK = Shape<_64, _64, _32>; + using ShapePV = Shape<_64, _32, _64>; + using ShapeOutPut = Shape<_64, _32, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + struct prefill_policy_head64 { using ShapeQK = Shape<_128, _64, _64>; using ShapePV = Shape<_128, _32, _64>; @@ -47,6 +50,13 @@ struct prefill_policy_head128 { using SubgroupLayout = Layout, Stride<_1, _1, _1>>; }; +struct prefill_policy_head160 { + using ShapeQK = Shape<_256, _64, _32>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _160, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + struct prefill_policy_head192 { using ShapeQK = Shape<_256, _64, _64>; using ShapePV = Shape<_256, _32, _64>; @@ -55,8 +65,8 @@ struct prefill_policy_head192 { }; struct prefill_policy_head256 { - using ShapeQK = Shape<_256, _64, _64>; - using ShapePV = Shape<_256, _32, _64>; - using ShapeOutPut = Shape<_256, _256, _64>; - using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _256, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; }; \ No newline at end of file diff --git a/flash-attn2/flash_attn_xpu/src/kernel/xe_flash_attn_prefill.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp similarity index 98% rename from flash-attn2/flash_attn_xpu/src/kernel/xe_flash_attn_prefill.hpp rename to flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp index b5ceb71..ec1db8e 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/xe_flash_attn_prefill.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp @@ -35,7 +35,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/kernel_hardware_info.hpp" -#include "../collective/xe_flash_attn_prefill_mma.hpp" +#include "../collective/fixed_mma.hpp" namespace cutlass::flash_attention::kernel { @@ -77,10 +77,10 @@ class FMHAPrefill { using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; - static_assert(cute::is_void_v or cute::is_same_v or - cute::is_same_v, "Unsupported TileScheduler for Intel Xe."); + static_assert(cute::is_void_v or cute::is_same_v or + cute::is_same_v, "Unsupported 
TileScheduler for Intel Xe."); using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector::Scheduler; + using TileScheduler = typename detail::fixed::TileSchedulerSelector::Scheduler; using TileSchedulerParams = typename TileScheduler::Params; // Epilogue derived types diff --git a/flash-attn2/flash_attn_xpu/src/kernel/tile_scheduler.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fixed_scheduler.hpp similarity index 94% rename from flash-attn2/flash_attn_xpu/src/kernel/tile_scheduler.hpp rename to flash-attn2/flash_attn_xpu/src/kernel/fixed_scheduler.hpp index 34345d8..884ba78 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/tile_scheduler.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fixed_scheduler.hpp @@ -40,7 +40,7 @@ namespace cutlass::flash_attention { -namespace kernel { +namespace kernel::fixed { struct XeFlashIndividualTileScheduler { @@ -226,15 +226,14 @@ struct XeFlashPersistentTileScheduler { } }; - -//////////////////////////////////////////////////////////////////////////////// -} // namespace kernel - struct IndividualScheduler{}; struct PersistentScheduler{}; struct FlashDecodeIndividualScheduler{}; - namespace detail +//////////////////////////////////////////////////////////////////////////////// +} // namespace kernel::fixed + + namespace detail::fixed { template < @@ -255,37 +254,37 @@ struct XeFlashPersistentTileScheduler { cute::enable_if_t>> { using Scheduler = typename TileSchedulerSelector< - IndividualScheduler, + kernel::fixed::IndividualScheduler, ArchTag>::Scheduler; }; template struct TileSchedulerSelector< - IndividualScheduler, + kernel::fixed::IndividualScheduler, ArchTag, cute::enable_if_t>> { - using Scheduler = kernel::XeFlashIndividualTileScheduler; + using Scheduler = kernel::fixed::XeFlashIndividualTileScheduler; }; template struct TileSchedulerSelector< - PersistentScheduler, + kernel::fixed::PersistentScheduler, ArchTag, cute::enable_if_t>> { - using Scheduler = kernel::XeFlashPersistentTileScheduler; + using Scheduler = kernel::fixed::XeFlashPersistentTileScheduler; }; template struct TileSchedulerSelector< - FlashDecodeIndividualScheduler, + kernel::fixed::FlashDecodeIndividualScheduler, ArchTag, cute::enable_if_t>> { - using Scheduler = kernel::XeFlashDecodeIndividualTileScheduler; + using Scheduler = kernel::fixed::XeFlashDecodeIndividualTileScheduler; }; - } // namespace detail + } // namespace detail::fixed //////////////////////////////////////////////////////////////////////////////// diff --git a/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp b/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp new file mode 100644 index 0000000..1de8aa8 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp @@ -0,0 +1,555 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "../collective/varlen_mma.hpp" +namespace cutlass::flash_attention::kernel { + +template +class FMHAPrefillChunk; +/////////////////////////////////////////////////////////////////////////////// +template +class FMHAPrefillChunk { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // ProblemShape: + static_assert( + rank(ProblemShape{}) == 7, + "ProblemShape{} should be "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using TiledMmaQK = typename CollectiveMainloop::TiledMmaQK; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementQ = typename CollectiveMainloop::ElementQ; + using StrideQ = typename CollectiveMainloop::StrideQ; + using ElementK = typename CollectiveMainloop::ElementK; + using StrideK = typename CollectiveMainloop::StrideK; + using ElementV = typename CollectiveMainloop::ElementV; + using StrideV = typename CollectiveMainloop::StrideV; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_; + using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; + using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; + + static_assert(cute::is_void_v or + cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::varlen::TileSchedulerSelector::Scheduler; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementO = typename CollectiveEpilogue::ElementO; + using StrideO = typename CollectiveEpilogue::StrideO; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using 
EpilogueParams = typename CollectiveEpilogue::Params; + using TileShapeOutput = typename CollectiveEpilogue::TileShapeOutput; + using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput; + + static_assert( + cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + // MSVC requires the cast to fix a warning-as-error. + static constexpr int SharedStorageSize = 0; + + static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr bool LocalMask = CollectiveMainloop::LocalMask; + + static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + + static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; + static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; + static constexpr int QK_BLK_K = CollectiveMainloop::QK_BLK_K; + + static constexpr int QK_ATOM_N = CollectiveMainloop::QK_ATOM_N; + static constexpr int QK_ATOM_K = CollectiveMainloop::QK_ATOM_K; + + static constexpr int QK_SG_M = CollectiveMainloop::QK_SG_M; + + static constexpr int Epilogue_BLK_N = get<1>(TileShapeOutput{}); + static constexpr int Epilogue_BLK_K = get<2>(TileShapeOutput{}); + + static constexpr int PV_ATOM_M = CollectiveMainloop::PV_ATOM_M; + static constexpr int PV_ATOM_N = CollectiveMainloop::PV_ATOM_N; + static constexpr int PV_ATOM_K = CollectiveMainloop::PV_ATOM_K; + + static constexpr auto Num_SGs = PV_ATOM_N * PV_ATOM_M * PV_ATOM_K; + static constexpr int Vec = CollectiveMainloop::Vec; + static constexpr int FragsM = CollectiveMainloop::FragsM; + // The FragsN here used for Creation of S matrix so we use the FragsN for S + // shape + static constexpr int FragsN = CollectiveMainloop::FragsNS; + + static constexpr int VSlicer = + get<1>(TileShapeOutput{}) / + (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = decltype(make_shape( + Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), + Int{})); + + static constexpr bool is_var_len = CollectiveMainloop::is_var_len; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + // Device side arguments + struct Arguments { + gemm::GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + SoftmaxArguments softmax{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + gemm::GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + SoftmaxParams softmax; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. 
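// Illustrative sketch (not part of this kernel): for the variable-length path,
// get_sequence_length_shape(...) resolves the per-batch (seq_len_qo,
// seq_len_kv_cache) pair from the cumulative-length arrays attached to the
// problem shape. A host-side analogue, assuming int32 cu_seqlens arrays with
// batch_size + 1 entries as set up in varlen.hpp:
//
//   #include <utility>
//   #include <vector>
//
//   std::pair<int, int> per_batch_lengths(const std::vector<int>& cu_q,  // cu_seqlens_q
//                                         const std::vector<int>& cu_k,  // cu_seqlens_k
//                                         int batch) {
//     const int seq_len_qo       = cu_q[batch + 1] - cu_q[batch];
//     const int seq_len_kv_cache = cu_k[batch + 1] - cu_k[batch];
//     return {seq_len_qo, seq_len_kv_cache};
//   }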
+ static Params to_underlying_arguments(Arguments const& args, + void* workspace) { + (void)workspace; + return {args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, args.hw_info, TileShapeOutput{})}; + } + + static bool can_implement(Arguments const& args) { + bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or + (args.mode == gemm::GemmUniversalMode::kBatched && + rank(ProblemShape{}) == 4); + return mode_implementable; + } + + static int get_workspace_size(Arguments const& args) { return 0; } + + static cutlass::Status initialize_workspace( + Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, + int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length( + select<3, 4>(problem_shape), batch); + } else { + return select<3, 4>(problem_shape); + } + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // "ProblemShape{} should be "; + auto batch = get<0>(params.problem_shape); + auto num_heads_q = get<1>(params.problem_shape); + auto num_heads_kv = get<2>(params.problem_shape); + auto q_group_size = num_heads_q / num_heads_kv; + + auto& head_size_qk = get<5>(params.problem_shape); + auto& head_size_vo = get<6>(params.problem_shape); + // Preconditions + static_assert(cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + static_assert(cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert(cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); + + int thread_idx = int(ThreadIdxX()); + auto sub_group_id = get_sub_group_id(); + auto local_id = get_sub_group_local_id(); + + TileScheduler tile_scheduler{params.scheduler}; + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + // head_size_blk_idx, seq_len_blk_idx, batch_blk_idx, num_heads_blk_idx + auto blk_coord = tile_scheduler.get_block_coord(); + + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto blk_n_coord = 0; // nums_head_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_head_idx + + auto batch_coord = get<2>(blk_coord); // batch_blk_idx + + // for both fixed length and varlen + auto sequence_length_shape = + get_sequence_length_shape(params.problem_shape, batch_coord); + + auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; + + // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) + // and check if it is still within bounds of the actual seq_len_qo + // (get<0>(sequence_length_shape)). 
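// Illustrative sketch (simplified host-side analogue, not part of this
// kernel): the paged-KV lookups used below translate a logical KV tile index
// into a physical tile index through the per-sequence page table, assuming the
// page size is a multiple of the KV tile size QK_BLK_N:
//
//   // page_table is [batch, max_pages_per_seq], flattened row-major
//   int physical_kv_tile(const int* page_table, int batch_idx,
//                        int max_pages_per_seq, int logical_tile,
//                        int page_size, int qk_blk_n) {
//     const int tiles_per_page = page_size / qk_blk_n;
//     const int logical_page   = logical_tile / tiles_per_page;
//     const int physical_page  =
//         page_table[batch_idx * max_pages_per_seq + logical_page];
//     return physical_page * tiles_per_page + logical_tile % tiles_per_page;
//   }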
+ if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { + continue; + } + + // loop kv by QK_BLK_N + const int kv_splits_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N); + + int tiles_per_page = params.mainloop.page_size / QK_BLK_N; + + Tensor mQ_mkl = cute::get_xe_tensor( + make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + Tensor mK_cache_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) + + Tensor mQ_mk = mQ_mkl(_, _, 0); + Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) + Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) + + auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), + Step<_1, X, _1>{}); + auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, + make_coord(_, _, _), Step{}); + auto gV_cache = + local_tile(mV_cache_nk, TileShapeOutput{}, + make_coord(_, blk_n_coord, _), Step{}); + + auto mainloop_params = CollectiveMainloop::get_updated_copies( + params.mainloop, params.problem_shape, sequence_length_shape, + batch_coord, q_head_coord); + + // we limit the horizontal size to two subgroup, the empirical results + // show that reading the two cacheline side by side in gives better + // performance and anything after that does not have an effect on + // performance. // (64 here for float b float when possible and loop over + // to cover all the data needed) + auto tiled_prefetch_q = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_q); + auto tiled_prefetch_k_cache = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k_cache); + auto tiled_prefetch_v_cache = cute::prefetch_selector< + Shape, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v_cache); + auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); + auto thr_prefetch_K = tiled_prefetch_k_cache.get_slice(thread_idx); + auto thr_prefetch_V = tiled_prefetch_v_cache.get_slice(thread_idx); + + auto pQgQ = thr_prefetch_Q.partition_S(gQ); + auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache); + auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + + auto& prefetch_K = tiled_prefetch_k_cache; + auto& pKgK1_ = pKgK_cache; + + int cached_nblock = 0; + if constexpr (PagedKV) { + if (seq_len_kv_cache != 0) { + int batch_offset = batch_coord * mainloop_params.max_pages_per_seq; + cached_nblock = + mainloop_params.ptr_page_table[batch_offset] * tiles_per_page; + } + } + // The headsize for both cached and non-cached version is the same + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK1_); j++) { + prefetch(prefetch_K, pKgK1_(_, _, _, cached_nblock, j)); + } + + // Allocate the tiled_mma and the accumulators for the (M,N) + // workgroup_shape + Tensor out_reg = make_tensor(AccumeShape{}); + + // There are 16 workitem and 16 max per subgroup, each worktime contain 1 + // max and cumulatively, they calculate the max per subgroup + ElementAccumulator max_reg{-INFINITY}; + // The sum reg each contains a 2d tesnor for 8 x 2 This is number of + // sequence length process per subgroup + Tensor sum_reg = + make_tensor(Shape, Int>{}); + + clear(sum_reg); + clear(out_reg); + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + + // 2 for wg level, 3 for sg level + static constexpr int barrier_scope = 
CausalMask ? 3 : 2; + + int q_start_coord = blk_m_coord * QK_BLK_M; + int q_end_coord = cute::min(q_start_coord + QK_BLK_M, seq_len_qo); + int seq_diff = seq_len_kv_cache - seq_len_qo; + + const int seq_coord = cute::min( + seq_len_qo, + (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % + seq_len_qo); + + CUTLASS_PRAGMA_UNROLL + for (int split = 0; split < kv_splits_cache; split++) { + barrier_arrive(barrier_scope); + + int kv_start_coord = split * QK_BLK_N; + + if constexpr (CausalMask) { + if (kv_start_coord >= q_end_coord + seq_diff) { + break; + } + } + + // 1) Load KV (performed inside mmaQK) + auto gK_ = gK_cache(_, _, cached_nblock, _); + auto gV_ = gV_cache(_, _, cached_nblock); + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK(tSr, gQ, gK_, tSr, + ceil_div(head_size_qk, QK_BLK_K), mainloop_params); + + // mask padding + int col_start = local_id + kv_start_coord; + int col_end = col_start + (FragsN - 1) * get<1>(MmaAtomShape()); + if (col_end >= seq_len_kv_cache) { + int col_idx = col_start; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + if (col_idx >= seq_len_kv_cache) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + + if constexpr (CausalMask) { + int row_start = q_start_coord + sub_group_id * QK_SG_M; + if (row_start + seq_diff < col_end) { + int col_idx = col_start; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + if (col_idx > row_start + seq_diff) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + int row_idx = row_start + m * Vec + row; + if (row_idx + seq_diff < col_idx) + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + } + + if constexpr (LocalMask) { + // mask the elements of each tile where j - left > i || j + right < i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id; + col_idx += split * cute::min(QK_BLK_N, seq_len_kv_cache); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + bool left_mask = + col_idx < cute::max(0, row + row_idx + seq_len_kv_cache - + mainloop_params.window_left); + bool right_mask = + col_idx > cute::min(seq_len_kv_cache, + row + row_idx + seq_len_kv_cache + + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + + auto& tiled_prefetch_v_ = tiled_prefetch_v_cache; + auto& pVgV_ = pVgV_cache; + int v_prefetch_idx = PagedKV ? 
cached_nblock : split; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(pVgV_); i++) { + prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx)); + } + int next_cached_nblock = split + 1; + if constexpr (PagedKV) { + int curr_batch_pages = mainloop_params.max_pages_per_seq; + int next_page_logical_idx = + next_cached_nblock * QK_BLK_N / params.mainloop.page_size; + int batch_offset = batch_coord * mainloop_params.max_pages_per_seq; + bool valid_page = next_page_logical_idx < curr_batch_pages; + // get physical page idx from page table + if (valid_page) { + next_cached_nblock = + params.mainloop + .ptr_page_table[batch_offset + next_page_logical_idx] * + tiles_per_page + + next_cached_nblock % tiles_per_page; + } else { + // if not valid, set to the end page + next_cached_nblock = curr_batch_pages * tiles_per_page; + } + } + + // 4) Fused softmax + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(split == 0, tSr, max_reg, sum_reg, out_reg); + // 5) Perform GEMM O = S*V + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, + mainloop_params); + + // ... prefetch next tile ... + // Prefetch the next Q tile + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + + cached_nblock = next_cached_nblock; + // Prefetch the next K tile + // there is no need to guard it with if statement as prefetch will + // ignore out of bound reading + auto& prefetch_k_selector = tiled_prefetch_k_cache; + auto& pKgK_ = pKgK_cache; + int k_prefetch_idx = + PagedKV ? cached_nblock : split + DispatchPolicy::Stages; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_); j++) { + prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j)); + } + barrier_wait(barrier_scope); + } + + // Epilogue + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, params.problem_shape, sequence_length_shape, + batch_coord, q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0); + epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, + out_reg, max_reg, sum_reg); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention::kernel diff --git a/flash-attn2/flash_attn_xpu/src/kernel/varlen_scheduler.hpp b/flash-attn2/flash_attn_xpu/src/kernel/varlen_scheduler.hpp new file mode 100644 index 0000000..ede0489 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/kernel/varlen_scheduler.hpp @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::flash_attention { + +namespace kernel::varlen { + +struct XeFlashIndividualTileScheduler { + struct Params { + dim3 grid; + // FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const& problem_size, + KernelHardwareInfo hw_info, + TileShape const& tile_shape) { + using namespace cute; + + dim3 grid(size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + size(shape<1>(problem_size)), size(shape<0>(problem_size))); + return Params{grid}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { return valid_; } + + CUTLASS_DEVICE + auto get_block_coord() { + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct XeFlashPersistentTileScheduler { + struct Params { + int num_blocks; + FastDivmod divmod_seq_len_block; + FastDivmod divmod_head_size_block; + FastDivmod divmod_num_heads; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler(Params const& params) + : block_idx(BlockIdxX()), params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const& problem_size, + KernelHardwareInfo hw_info, + TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + + CUTLASS_TRACE_HOST( + "to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + hw_info.sm_count = sm_count; + + // problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, + // seq_len_kv_cache, head_size_qk, head_size_vo] 
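// Illustrative sketch (not part of this scheduler): the persistent scheduler
// below linearizes (batch, head, seq-len block, head-size block) into a single
// block index and recovers the coordinates with successive div/mod, which the
// device code performs with FastDivmod. A plain-integer analogue:
//
//   #include <tuple>
//
//   // returns {head_size_block, seq_len_block, batch, head}
//   std::tuple<int, int, int, int> decode_block(int block_idx,
//                                               int num_head_size_blocks,
//                                               int num_seq_len_blocks,
//                                               int num_heads) {
//     const int head_size_block = block_idx % num_head_size_blocks;
//     block_idx /= num_head_size_blocks;
//     const int seq_len_block = block_idx % num_seq_len_blocks;
//     block_idx /= num_seq_len_blocks;
//     const int head  = block_idx % num_heads;
//     const int batch = block_idx / num_heads;
//     return {head_size_block, seq_len_block, batch, head};
//   }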
+ int num_head_size_blocks = + size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))); + int num_seq_len_blocks = + size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))); + int num_blocks = num_seq_len_blocks * num_head_size_blocks * + size(shape<0>(problem_size) * shape<1>(problem_size)); + + return Params{num_blocks, + {num_seq_len_blocks}, + {num_head_size_blocks}, + {shape<1>(problem_size)}, + hw_info}; + } + + template + static dim3 get_grid_shape(Params const& params) { + auto queue = COMPAT::get_default_queue(); + auto dev = queue.get_device(); + const size_t maxSubgroups = + dev.template get_info(); + // TODO (Codeplay): revert this back to std::min(params.num_blocks, + // params.hw_info.sm_count) once performance issue is fixed. + dim3 grid( + std::min(params.num_blocks, + ceil_div(params.hw_info.sm_count * maxSubgroups, Num_SGs)), + 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { return block_idx < params.num_blocks; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int seq_len_block, head_size_block, bidh; + params.divmod_head_size_block(block_decode, head_size_block, block_decode); + params.divmod_seq_len_block(block_decode, seq_len_block, block_decode); + params.divmod_num_heads(block_decode, bidh, block_decode); + return make_coord(head_size_block, seq_len_block, block_decode, bidh); + } + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler& operator++() { + block_idx += GridDimX(); + return *this; + } +}; + +struct IndividualScheduler {}; +struct PersistentScheduler {}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace kernel::varlen + +namespace detail::varlen { + +template +struct TileSchedulerSelector { + static_assert(cutlass::detail::dependent_false, + "Could not select a tile scheduler for given parameters."); +}; + +// Default (void) maps to XeFlashIndividualTileScheduler +template +struct TileSchedulerSelector< + void, ArchTag, + cute::enable_if_t>> { + using Scheduler = + typename TileSchedulerSelector::Scheduler; +}; + +template +struct TileSchedulerSelector< + kernel::varlen::IndividualScheduler, ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::varlen::XeFlashIndividualTileScheduler; +}; + +template +struct TileSchedulerSelector< + kernel::varlen::PersistentScheduler, ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::varlen::XeFlashPersistentTileScheduler; +}; +} // namespace detail::varlen + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention diff --git a/flash-attn2/flash_attn_xpu/src/varlen.hpp b/flash-attn2/flash_attn_xpu/src/varlen.hpp new file mode 100644 index 0000000..7729522 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/varlen.hpp @@ -0,0 +1,365 @@ +#pragma once + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/util/device_memory.h" +#include + +#include "./compat_wrapper.hpp" +#include "./kernel/varlen_scheduler.hpp" +#include "./kernel/varlen_kernel.hpp" +#include "./collective/fmha_fusion.hpp" +#include "./collective/varlen_epilogue.hpp" +#include "./collective/varlen_softmax_epilogue.hpp" + +#include "fmha_utils.hpp" + +namespace cutlass::flash_attention::varlen { + +using namespace cute; + +struct chunk_prefill_args_t { + void* query; + void* key; + void* value; + void* out; + void* block_table; + void* 
cu_seqlens_q; + void* cu_seqlens_k; + int max_queries; + int max_keys; + int total_seqlen_q; + int total_seqlen_k; + float sm_scale; + int batch_size; + int num_heads_q; + int num_heads_k; + int head_size; + int max_blocks_per_seq; + int block_size; + bool is_causal; + bool use_paged_kv; +}; + +template +struct KernelLauncher { + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = + typename FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + /// Initialization + StrideQ stride_Q; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + ProblemShapeType initialize(const chunk_prefill_args_t& args) { + auto problem_shape = cute::make_tuple( + 1, args.num_heads_q, args.num_heads_k, args.total_seqlen_q, + args.total_seqlen_k, args.head_size, args.head_size); + auto problem_shape_out = cute::make_tuple( + args.batch_size, args.num_heads_q, args.num_heads_k, + cutlass::fmha::collective::VariableLength{args.max_queries}, // cu_q + cutlass::fmha::collective::VariableLength{ + args.max_keys}, // cu_kv_cache + args.head_size, args.head_size); + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv_cache, + head_size_qk, head_size_vo] = problem_shape; + auto group_q_size = num_heads_q / num_heads_kv; + auto group_q_num = num_heads_q / group_q_size; + + stride_Q = cutlass::make_cute_packed_stride( + StrideQ{}, + cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K_cache = cutlass::make_cute_packed_stride( + StrideK{}, + cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride( + StrideV{}, + cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv_cache, batch)); + + stride_O = cutlass::make_cute_packed_stride( + StrideO{}, cute::make_shape(seq_len_qo * group_q_size, + group_q_num * head_size_vo, batch)); + + get<3>(problem_shape_out).cumulative_length = + reinterpret_cast(args.cu_seqlens_q); + get<4>(problem_shape_out).cumulative_length = + reinterpret_cast(args.cu_seqlens_k); + + return problem_shape_out; + } + + cutlass::Status run(const chunk_prefill_args_t& args, + const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = initialize(args); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {reinterpret_cast(args.query), stride_Q, + reinterpret_cast(args.key), stride_K_cache, + reinterpret_cast(args.value), stride_V_cache, + static_cast(args.block_table), args.block_size, + args.max_blocks_per_seq, args.total_seqlen_k, -1, -1}, + {args.sm_scale}, + {reinterpret_cast(args.out), stride_O}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = + 
FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()); + + // Convert host-side arguments to device-side arguments to be passed to the + // kernel + auto params = FMHAChunkPrefillKernel::to_underlying_arguments( + arguments, workspace.get()); + + // Run the Flash Attention implementation. + run(params); + + return cutlass::Status::kSuccess; + } + + static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = COMPAT::dim3(block.x, block.y, block.z); + const auto sycl_grid = COMPAT::dim3(grid.x, grid.y, grid.z); + + COMPAT::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + COMPAT::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size< + FMHAChunkPrefillKernel::DispatchPolicy::SubgroupSize>}; + COMPAT::experimental::launch_policy policy{sycl_grid, sycl_block, + launch_props, kernel_props}; +#if defined(OLD_API) + auto event = COMPAT::experimental::launch>(policy, params); +#else + auto event = COMPAT::experimental::launch, FMHAChunkPrefillKernel>(policy, params); +#endif + // EventManager::getInstance().addEvent(event); + } +}; + +template +struct FMHAKernel { + template + static void run(const chunk_prefill_args_t& args) { + cutlass::KernelHardwareInfo hw_info; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using ElementInputKV = ElementInputQ; + using ElementOutput = ElementInputQ; + + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = + cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, + SubgroupLayout, ElementComputeEpilogue, ElementOutput, + cutlass::gemm::TagToStrideC_t, ElementOutput, + GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = + cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue< + Causal, Local, EpilogueDispatchPolicy, ElementAccumulator>; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = + cute::tuple; + using ProblemShapeType = + std::conditional_t; + + // Mainloop + using CollectiveMainloop = + cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, + cutlass::gemm::TagToStrideA_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, + TileShapePV, SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, Local, PagedKV>; + + using FMHAChunkPrefillKernel = + cutlass::flash_attention::kernel::FMHAPrefillChunk< + ProblemShapeType, CollectiveMainloop, CollectiveSoftmaxEpilogue, + 
CollectiveEpilogue, Scheduler>; + + KernelLauncher launcher; + + launcher.run(args, hw_info); + } + + static void dispatch(const chunk_prefill_args_t& args) { + if (args.use_paged_kv) { + if (args.is_causal) { + run(args); + } else { + run(args); + } + } else { + if (args.is_causal) { + run(args); + } else { + run(args); + } + } + } +}; + +template +void policy_dispatch(CutlassType cuType, const chunk_prefill_args_t& args) { + const int PipelineStages = 2; + + if (cuType == CutlassType::half) { + FMHAKernel::dispatch(args); + } else { + FMHAKernel::dispatch(args); + } +} + +template +void dispatch_by_head_size(CutlassType cuType, const ArgsType& args) { + const int h = args.head_size; + if (h <= 32) { + policy_dispatch(cuType, args); + } + else if (h <= 64) { + policy_dispatch(cuType, args); + } + else if (h <= 96) { + policy_dispatch(cuType, args); + } + else if (h <= 128) { + policy_dispatch(cuType, args); + } + else if (h <= 160) { + policy_dispatch(cuType, args); + } + else if (h <= 192) { + policy_dispatch(cuType, args); + } + else if (h <= 256) { + policy_dispatch(cuType, args); + } + else { + throw std::runtime_error("Unsupported head_size: " + std::to_string(h) + ". Max supported head_size is 256"); + } +} + +void cutlass_varlen_impl( + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, at::Tensor& out, + const std::optional& block_table, const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + double sm_scale, bool is_causal) { + int num_heads_q = query.size(1); + int head_size = query.size(2); + int batch_size = cu_seqlens_q.numel() - 1; + int total_seqlen_q = query.size(0); + bool use_paged_kv = block_table.has_value() && block_table->defined(); + int num_block, block_size, num_heads_kv, max_blocks_per_seq, total_seqlen_k; + + if (use_paged_kv) { + num_block = key_cache.size(0); + block_size = key_cache.size(1); + num_heads_kv = key_cache.size(2); + max_blocks_per_seq = block_table->size(1); + total_seqlen_k = num_block * block_size; + } else { + // [total_seqlen_k, heads, head_size] + num_block = 0; + block_size = 0; + max_blocks_per_seq = 0; + num_heads_kv = key_cache.size(1); + total_seqlen_k = key_cache.size(0); + } + + chunk_prefill_args_t args = {query.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + out.data_ptr(), + block_table.has_value() ? block_table->data_ptr() : nullptr, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + static_cast(sm_scale), + batch_size, + num_heads_q, + num_heads_kv, + head_size, + max_blocks_per_seq, + block_size, + is_causal, + use_paged_kv}; + + CutlassType cuType = aten_to_Cutlass_dtype(query); + + dispatch_by_head_size(cuType, args); +} + +} // namespace cutlass::flash_attention::varlen \ No newline at end of file diff --git a/flash-attn2/tests/test_flash_attn.py b/flash-attn2/tests/test_flash_attn.py index 9166c52..2cd234f 100644 --- a/flash-attn2/tests/test_flash_attn.py +++ b/flash-attn2/tests/test_flash_attn.py @@ -1670,8 +1670,6 @@ def test_flash_attn_varlen_causal( if device == "xpu": if local: pytest.skip("local attention not supported on xpu currently") - if paged_kv_block_size is not None: - pytest.skip("paged_kv_block_size not supported on xpu currently") if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q
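Both the fixed-length and variable-length entry points pick a prefill policy with the same head-size ladder: the smallest supported bucket in {32, 64, 96, 128, 160, 192, 256} that covers the head dimension, with a runtime error beyond 256. A minimal standalone sketch of that selection, independent of the CUTLASS policy structs (the bucket-to-policy mapping itself stays as written in dispatch_by_head_size):

#include <array>
#include <stdexcept>
#include <string>

// Mirrors the if/else ladder in dispatch_by_head_size: return the smallest
// supported head-size bucket that covers the requested head_size.
inline int head_size_bucket(int head_size) {
  constexpr std::array<int, 7> buckets{32, 64, 96, 128, 160, 192, 256};
  for (int b : buckets) {
    if (head_size <= b) return b;
  }
  throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) +
                           ". Max supported head_size is 256");
}

// e.g. head_size_bucket(80) == 96, so the head-96 prefill policy is selected.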