diff --git a/applications/flash_attention_v2/collective/copy_block_slm.hpp b/applications/flash_attention_v2/collective/copy_block_slm.hpp new file mode 100644 index 0000000000..2e8f4b64d1 --- /dev/null +++ b/applications/flash_attention_v2/collective/copy_block_slm.hpp @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +namespace cute { + +/* Flat copies */ +template +CUTE_HOST_DEVICE +void +copy_block_r2s(Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem_v && is_smem_v, "Expected rmem->smem copy"); + + auto atom_r2s = Copy_Atom, float>{}; // TODO: larger block messages + + auto atom_shape = make_shape(_1{}, size(src)); + auto src_v = src.compose(make_layout(atom_shape)); + auto dst_v = dst.compose(make_layout(atom_shape, Stride<_0, _16>{})); + + copy(atom_r2s, src_v, dst_v); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + Tensor & dst) +{ + static_assert(is_smem_v && is_rmem_v, "Expected smem->rmem copy"); + + auto atom_s2r = Copy_Atom, float>{}; + + auto atom_shape = make_shape(_1{}, size(dst)); + auto src_v = src.compose(make_layout(atom_shape, Stride<_0, _16>{})); + auto dst_v = dst.compose(make_layout(atom_shape)); + + copy(atom_s2r, src_v, dst_v); +} + +/* Coordinate-aware copies */ +template +CUTE_HOST_DEVICE +void +copy_block_r2s(SubgroupTensor const& src, + Tensor & dst, + DstCoordLayout const& dst_c) +{ + static_assert(is_rmem_v && is_smem_v, "Expected rmem->smem copy"); + + auto atom_r2s = Copy_Atom, float>{}; // TODO: larger block messages + + auto atom_shape = make_shape(_1{}, size(SrcLayout{})); + + auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}), make_layout(atom_shape, Stride<_0, _16>{})); + auto rlayout = composition(right_inverse(project_strides(dst_c)), src_c_wi0); + + auto src_v = src.compose(make_layout(atom_shape)); + auto dst_v = dst.compose(rlayout); + + copy(atom_r2s, src_v, dst_v); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + SrcCoordLayout const& src_c, + SubgroupTensor & dst) +{ + static_assert(is_smem_v && is_rmem_v, "Expected smem->rmem copy"); + + auto atom_s2r = Copy_Atom, float>{}; + + auto atom_shape = make_shape(_1{}, size(DstLayout{})); + + auto dst_c_wi0 = composition(project_strides(DstCoordLayout{}), make_layout(atom_shape, Stride<_0, _16>{})); + auto rlayout = composition(right_inverse(project_strides(src_c)), dst_c_wi0); + + auto src_v = src.compose(rlayout); + auto dst_v = dst.compose(make_layout(atom_shape)); + + copy(atom_s2r, src_v, dst_v); +} + +/* Variants accepting rvalue dst */ +template +CUTE_HOST_DEVICE +void +copy_block_r2s(Tensor const& src, + Tensor && dst) +{ + return copy_block_r2s(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + Tensor && dst) +{ + return copy_block_s2r(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_block_r2s(SubgroupTensor const& src, + Tensor && dst, + DstCoordLayout const& dst_c) +{ + return copy_block_r2s(src, dst, dst_c); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + SrcCoordLayout const& src_c, + SubgroupTensor && dst) +{ + return copy_block_s2r(src, dst); +} + +} /* namespace cute */ \ No newline at end of file diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp new file mode 100644 index 0000000000..1a32ba6ef5 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -0,0 +1,288 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/subgroup_algorithms.hpp" +#include "cute/algorithm/tensor_algorithms.hpp" + +#include "copy_block_slm.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template // Optional TiledCopy for loading O +class FMHAFwdEpilogue { + +public: + // + // Type Aliases + // + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + using TileShapeO = TileShapeO_; + using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAPV::ThrLayoutVMNK{})))); + + using TensorO = TensorO_; + using TensorO2D = decltype(TensorO_{}(append>(make_coord(_,_),0))); + using ElementO = typename TensorO_::value_type; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + using ElementA = typename FragA::value_type; + + // Split k-reduced tiles between participating subgroups. + // Assumption: the A tile is contiguous. 
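+  // ReduceK is the number of k-split subgroups whose partial A results must be combined; the
+  // combined tile is divided into a (ReduceSGQ x ReduceSGV) grid of SGTileShapeO subtiles, one
+  // per participating subgroup.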
+ using ReduceK = decltype(size<3>(typename TiledMMAPV::ThrLayoutVMNK{})); + + using SGTileShapeA = decltype(atuple_coshape(FragA{}.tv_layout())); + using ReduceSGQ = decltype(cute::gcd(get<0>(SGTileShapeA{}), ReduceK{})); + using ReduceSGV = decltype(cute::min(get<1>(SGTileShapeA{}) / intel::_SGSize{}, ReduceK{} / ReduceSGQ{})); + using ReduceSGLayout = decltype(make_identity_layout(Shape{})); + + using SGTileShapeO = decltype(shape_div(take<0,2>(SGTileShapeA{}), shape(ReduceSGLayout{}))); + + using ReduceFragA = decltype(make_subgroup_tensor( + make_layout(select<1,0>(SGTileShapeO{}), + Stride, E<0>>{}) + )); + using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{})); + + static auto default_tiled_copy_O_helper() { + if constexpr (ReduceK{} == _1{}) + return make_block_2d_copy_C(TiledMMAPV{}, TensorO2D{}); + else + return make_block_2d_copy_C_subtiled(TiledMMAPV{}, ReduceFragA{}.tv_layout(), ReduceSGLayout{}, TensorO2D{}); + } + + using DefaultTiledCopyO = decltype(default_tiled_copy_O_helper()); + using TiledCopyO = conditional_t, DefaultTiledCopyO, TiledCopyO_>; + + // Stateless design -- no arguments or parameters. + struct Arguments {}; + struct Params {}; + + // Shared memory storage + // Note sum/max tiles are padded to 16 elements, due to limitations in CuTe block load infrastructure. + using AlignedSGTileA_Q = C<((size<0>(SGTileShapeA{}) + intel::sg_size - 1) / intel::sg_size) * intel::sg_size>; + + struct SharedStorageNone {}; + struct SharedStorageReduceK { + cute::array a_data; + cute::array a_sum_data, a_max_data; + }; + + using SharedStorage = conditional_t<(ReduceK{} > _1{}), SharedStorageReduceK, SharedStorageNone>; + +private: + SharedStorage &shared; + +public: + static constexpr + Params to_underlying_arguments(Arguments const &args, void * /* workspace */) { + return {}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) { + return true; + } + + CUTLASS_HOST_DEVICE + FMHAFwdEpilogue(Params const&, SharedStorage& shared_) : shared(shared_) {} + + template + CUTLASS_DEVICE + void + operator()(TensorO2D const& O, // Global O tensor: (q,v) + FragA & tArA, // O accumulator: (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id) { // Work-item ID + + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) + rA_sum(i) = ElementA(1) / rA_sum(i); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) + rA(i) *= broadcast<0>(rA_sum, rA, i); + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // Reduce k-blocks of A and A_sum across WG, if needed. + // Note that each k block has its own scale factor based on A_max, + // so A/A_sum contributions need to be rescaled to match. 
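+  // With per-block row maxima m_k and the global row maximum m = max_k(m_k), the reduction forms
+  //   A_sum = sum_k exp2(m_k - m) * A_sum_k    and    A = sum_k exp2(m_k - m) * A_k.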
+ template + CUTLASS_DEVICE + decltype(auto) + reduce_A(FragA & tArA, // O accumulator: (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + int thr_id) { // Work-item ID + + using namespace sycl::ext::oneapi::this_work_item; + + if constexpr (ReduceK{} == _1{}) { + return std::make_tuple(tArA, tA_sum, true); + } else { + /* Identify A tile ID and k block for this subgroup. */ + auto thr_vak = group<1,3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); + auto a_tile = get<1>(thr_vak); + auto k_blk = get<2>(thr_vak); + + /* Set up SLM tensors and partition A tiles among participating subgroups */ + auto shape_A = append(append(SGTileShapeA{}, ReduceK{}), SGPerWG{}/ReduceK{}); + auto shape_A_row = make_shape(get<0>(SGTileShapeO{}), shape(ReduceSGLayout{}), ReduceK{}, SGPerWG{}/ReduceK{}); + + /* Physical layouts, with subtile modes broken out */ + auto sA_layout = group<2,4>(flat_divide(make_ordered_layout(shape_A, Step<_1,_0,_2,_3>{}), SGTileShapeO{})); + auto sA_row_stride = make_stride(_1{}, make_stride(get<0>(shape_A_row), _0{}), + AlignedSGTileA_Q{}, AlignedSGTileA_Q{} * ReduceK{}); + auto sA_row_layout = make_layout(shape_A_row, sA_row_stride); + + /* Coordinate layouts, with subtile modes broken out */ + auto basis2 = make_basis_like(SGTileShapeO{}); + auto sA_coords = make_layout(append(SGTileShapeO{}, shape(ReduceSGLayout{})), + append(basis2, product_each(zip(SGTileShapeO{}, basis2)))); + + auto sA = make_tensor(make_smem_ptr(&shared.a_data), sA_layout); // (q,v,rblk_dst,rblk_src,a_tile) + auto sA_max = make_tensor(make_smem_ptr(&shared.a_max_data), sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + auto sA_sum = make_tensor(make_smem_ptr(&shared.a_sum_data), sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + + /* Write my contributions to SLM. */ + copy_block_r2s(tA_max, sA_max(_,_,k_blk,a_tile)); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + copy_block_r2s(tA_sum, sA_sum(_,_,k_blk,a_tile)); + copy_block_r2s(tArA, sA(_,_,_,k_blk,a_tile), sA_coords); + + bool active = (k_blk < size(ReduceSGLayout{})) + || (ReduceK{} == size(ReduceSGLayout{})); // help compiler out + + /* Wait for maxima to be available, signal other data available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + + ReduceFragA rA; + ReduceFragARow rA_sum, rA_max, rA_kmax[ReduceK{}]; + + if (active) { + /* Read A_max back from SLM and reduce. */ + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + copy_block_s2r(sA_max(_,k_blk,kr,a_tile), rA_kmax[kr]); + } + + rA_max = rA_kmax[0]; + for (int kr = 1; kr < ReduceK{}; kr++) + cute::transform(rA_max, rA_kmax[kr], rA_max, cute::max_fn{}); + + /* Calculate scale factors for aligning per-block maxima. */ + for (int kr = 0; kr < ReduceK{}; kr++) { + cute::transform(rA_max, rA_kmax[kr], rA_kmax[kr], [](auto gmax, auto kmax) { + return sycl::native::exp2(kmax - gmax); + }); + } + } + + /* Wait for A/A_sum data to be available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + + if (active) { + /* Read A/A_sum back from SLM, align scaling to new maxima, and reduce. 
*/ + clear(rA_sum); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragARow rA_sum_read; + copy_block_s2r(sA_sum(_,k_blk,kr,a_tile), rA_sum_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum_read.size(); i++) { + rA_sum(i) += rA_sum_read(i) * rA_kmax[kr](i); + } + } + + clear(rA); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragA rA_read; + copy_block_s2r(sA(_,_,k_blk,kr,a_tile), sA_coords(_,_,0), rA_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_read.size(); i++) { + rA(i) += rA_read(i) * broadcast<0>(rA_kmax[kr], rA, i); + } + } + } + return std::make_tuple(rA, rA_sum, active); + } + } +}; + + +} // namespace cutlass::fmha::collective diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp new file mode 100644 index 0000000000..0feddae1b7 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -0,0 +1,415 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/algorithm/subgroup_algorithms.hpp" +#include "cute/atom/mma_atom.hpp" +#include "fmha_fusion.hpp" + +namespace cutlass::fmha { + +template class XeDefault {}; // Default FMHA mainloop, P in registers. 
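+// Stages controls the K-block prefetch depth used by the XeDefault mainloop specialization below.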
+ +}; + +namespace cutlass::fmha::collective { + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template // Optional TiledCopy for loading V +struct FMHAFwdMainloop { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FMHAFwdMainloop, CausalMask_, + TiledMMAQK_, TiledMMAPV_, VTiles_, + TensorQ_, TensorK_, TensorV_, + TiledCopyQ_, TiledCopyK_, TiledCopyV_> { + // + // Type Aliases + // + using TiledMMAQK = TiledMMAQK_; + using TiledMMAPV = TiledMMAPV_; + using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk()); + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + static constexpr int VTiles = VTiles_; + + using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); + + using TensorQ = TensorQ_; + using TensorK = TensorK_; + using TensorV = TensorV_; + + using TensorQ2D = decltype(TensorQ_{}(append>(make_coord(_,_),0))); + using TensorK2D = decltype(TensorK_{}(append>(make_coord(_,_),0))); + using TensorV2D = decltype(TensorV_{}(append>(make_coord(_,_),0))); + + using TiledCopyQ = conditional_t, decltype(make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{})), TiledCopyQ_>; + using TiledCopyK = conditional_t, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK2D{})), TiledCopyK_>; + using TiledCopyV = conditional_t, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV2D{})), TiledCopyV_>; + + // TODO: static_asserts on TiledMMAPV here... + + // + // Accumulator types + // + // FragS: accumulator for Q*K MMA + // FragO: accumulator for P*V MMAs. + // Note: v mode may be split into multiple pieces + // to reduce register pressure. + // Frag*Row types are reductions of the corresponding Frag* types + // over rows. 
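+  // FragC is the per-subgroup accumulator fragment covering a TiledMMA's (M,N) tile.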
+ // + template + using FragC = decltype(TiledMMA{}.get_slice(0).partition_sg_fragment_C( + make_identity_tensor(select<0,1>(TiledMMA{}.tile_mnk())))); + + using FragS = FragC; + using FragSRow = decltype(reduce<1>(FragS{}, sycl::plus{})); + using ElementS = typename TiledMMAQK::ValTypeD; + + using SingleFragA = FragC; // (atom val,q',v') + using FragA = expand_sg_fragment_t; // (atom val,q',v',VV) + using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); + using ElementA = typename TiledMMAPV::ValTypeD; + + static constexpr bool CausalMask = CausalMask_; + + // User-facing arguments + struct Arguments { + ElementS const scale; + }; + + // Kernel-facing parameters + using Params = Arguments; + + // SLM data + struct SharedStorage {}; + + Params params; + + // + // Methods + // + + FMHAFwdMainloop(Params const& params_, SharedStorage&) : params(params_) {} + + static constexpr + Params to_underlying_arguments(Arguments const &args, void * /* workspace */) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) + ElementS val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + CUTLASS_HOST_DEVICE static + bool can_implement(Arguments const&) { + return true; + } + + template + CUTLASS_DEVICE + void + operator()(TensorQ2D const& Q_2D, // (q,d) + TensorK2D const& K_2D, // (k,d) + TensorV2D const& V_2D, // (d,k) + FragA & tArA, // Output accumulator (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (Q,V) + int blk_k0, // K block range: [K0,K1) + int blk_k1, + int thr_id) { // Work-item ID + + using namespace sycl::ext::oneapi::this_work_item; + + // Short dimension names: + // q = sequence len dimension for Q + // k = sequence len dimension for K + // d = head size dimension for K/Q + // v = head size dimension for V + // VV = MMA tile indices for V + // Capital letters (Q, K, ...) refer to WG block indices. + // Primed letters (q', k', ...) refer to atom block indices. 
+ + auto tile_shape_v = make_shape(get<1>(TileShapePV{}) * C{}, get<2>(TileShapePV{})); + + /* Create proxy coordinate tensors for Q/K/P/V */ + Tensor cQ = make_identity_tensor(Q_2D.shape()); // (q,d) + Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) + Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) + + /* Partition global tensors into workgroup tiles */ + Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) + Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) + Tensor gV = local_tile(cV, tile_shape_v, make_coord(get<1>(blk_qv),_)); // (v,k,K) + Tensor gV_split = local_tile(gV, TileShapePV{}, make_coord(_,_,0), Step{}); // (v,k,VV,K) + + /* Create global -> register copies */ + TiledCopyQ copy_q{Q_2D}; + TiledCopyK copy_k{K_2D}; + TiledCopyV copy_v{V_2D}; + + /* Create MMAs */ + TiledMMAQK mma_qk{}; + TiledMMAPV mma_pv{}; + + /* Slice TiledCopy/TiledMMA operations down to to work-item level */ + auto thr_copy_q = copy_q.get_slice(thr_id); + auto thr_copy_k = copy_k.get_slice(thr_id); + auto thr_copy_v = copy_v.get_slice(thr_id); + auto thr_mma_qk = mma_qk.get_slice(thr_id); + auto thr_mma_pv = mma_pv.get_slice(thr_id); + + /* Partition coordinate tensors for copy */ + auto tQgQ = thr_copy_q.partition_S(gQ); // (atom_val,q',d',D) + auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) + auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) + + /* Create register fragments for MMA and copies */ + auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0)); + auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_,_,0)); + + auto tKrK = thr_copy_k.partition_sg_fragment_D(gK(_,_,0,0)); + auto tSrK = thr_mma_qk.partition_sg_fragment_B(gK(_,_,0,0)); + + auto tSrS = thr_mma_qk.partition_sg_fragment_C(cP); + auto tArP = thr_mma_pv.partition_sg_fragment_A(cP); + + auto tVrV = thr_copy_v.partition_sg_fragment_D(gV_split(_,_,0,0)); + auto tArV = thr_mma_pv.partition_sg_fragment_B(gV_split(_,_,0,0)); + + /* Create TiledCopy objects for prefetches */ + auto prefetch_q = make_block_2d_prefetch(copy_q); + auto prefetch_k = make_block_2d_prefetch(copy_k); + auto prefetch_v = make_block_2d_prefetch(tile_shape_v, V_2D); + + /* Partition global tensors for prefetch */ + auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ); + auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); + auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV); + + // ------ + // Kernel + // ------ + + /* Initialization steps for first block: Q/K prefetch, O init */ + /* TODO: limit D prefetch for large head size, and reorder K prefetches */ + if (blk_k0 == 0) { + for (int D = 0; D < size<3>(pQgQ); D++) { + prefetch(prefetch_q, pQgQ(_,_,_,D)); + } + + for (int D = 0; D < size<4>(pKgK); D++) { + CUTLASS_PRAGMA_UNROLL + for (int K = 0; K < Stages; K++) { + prefetch(prefetch_k, pKgK(_,_,_,K,D)); + } + } + + clear(tArA); + fill(tA_max, ElementA(-INFINITY)); + clear(tA_sum); + } + + /* Check if */ + bool check_remainder_k = (shape<0>(K_2D) % get<1>(TileShapeQK{}) != 0); + + /* Main loop, blocked in k. 
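     Each iteration computes the S = Q*K^T block for one K tile, applies any masking, performs one
     online-softmax update, and accumulates A += P*V while prefetching the next K tile.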
*/ + for (int K = blk_k0; K < blk_k1; K++) { + /* Split barrier to keep threads together */ + constexpr int barrier_scope = 2; /* WG scope */ + barrier_arrive(barrier_scope); + + /* GEMM 1: S = K * Q */ + clear(tSrS); /* TODO: fuse w/ initial gemm call */ + for (int D = 0; D < size<4>(tKgK); D++) { + copy(copy_q, tQgQ(_,_,_,D), tQrQ); + copy(copy_k, tKgK(_,_,_,K,D), tKrK); + + reorder(tQrQ, tSrQ); + reorder(tKrK, tSrK); + + cute::gemm(mma_qk, tSrQ, tSrK, tSrS); + } + + /* V prefetch for GEMM 2 */ + prefetch(prefetch_v, pVgV(_,_,_,K)); + + /* k masking for remainder tiles */ + if (check_remainder_k && K == blk_k1 - 1) { + FragSRow k_rem_mask; + int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0]; + for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) { + k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY); + } + for (int i = 0; i < tSrS.size(); i++) { + tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i)); + } + } + + /* TODO: causal masking */ + static_assert(!CausalMask, "Causal mask unimplemented"); + + /* Apply softmax and scaling */ + softmax(K == 0, tSrS, tA_max, tA_sum, tArA); +#if 0 + reorder(tSrS, tArP); +#else + for (int i = 0; i < tArP.size(); i++) + tArP(i) = static_cast(tSrS(i)); +#endif + + /* GEMM 2: A += P * V, split in v dimension */ + CUTLASS_PRAGMA_UNROLL + for (int VV = 0; VV < VTiles; VV++) { + copy(copy_v, tVgV(_,_,_,VV,K), tVrV); + reorder(tVrV, tArV); + cute::gemm(mma_pv, tArP, tArV, tArA(_,_,_,VV)); + } + + /* K prefetch */ + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D)); + } + + barrier_wait(barrier_scope); + } + } + + // Single step of blocked softmax. + CUTLASS_DEVICE + void + softmax(bool first_block, // First softmax block? + FragS & tS, // Softmax src/dst block + FragSRow & tS_max, // Softmax row-wise max accumulator + FragSRow & tS_sum, // Softmax row-wise sum accumulator + FragA & tA) { // O accumulator (for rescaling) + + /* Compute row-wise maxima for this block */ + auto tS_bmax = reduce<1>(tS, sycl::maximum{}); + + /* Update (scaled) maxima */ + auto tS_prev_max = tS_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + tS_max(i) = sycl::max(tS_max(i), params.scale * tS_bmax(i)); + } + + /* Scale S and subtract maxima, then exponentiate */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS.size(); i++) + tS(i) = sycl::native::exp2(params.scale * tS(i) - broadcast<0>(tS_max, tS, i)); + + /* Rescale existing S sums and O accumulator */ + if (!first_block) { + FragSRow rescale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + rescale(i) = sycl::native::exp2(tS_prev_max(i) - tS_max(i)); + tS_sum(i) *= rescale(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tA.size(); i++) + tA(i) *= broadcast<0>(rescale, tA, i); + } + + /* Update sums */ + auto tS_bsum = reduce<1>(tS, sycl::plus{}); + for (int i = 0; i < tS_sum.size(); i++) + tS_sum(i) += tS_bsum(i); + } +}; + + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_sg_layout_pv(SGLayoutQK const&) +{ + return make_layout( + get<0>(SGLayoutQK{}), + Layout<_1, _0>{}, + get<1>(SGLayoutQK{}) + ); +} + +// Get a P*V TiledMMA given K*Q tile size and SG configuration, for mainloops +// not supporting S data interchange among subgroups (e.g. XeDefault). 
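+// The P*V subgroup layout keeps the Q*K layout's q split and maps its k split onto the PV k mode
+// (see get_sg_layout_pv above), so each subgroup already holds the S fragment it needs as its A operand.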
+template +CUTLASS_HOST_DEVICE +constexpr auto +get_tiled_mma_pv(MMAOp const&, WGTileQK const& wg_tile_qk, SGLayoutQK const& sg_layout_qk, TileV const&) { + using TileQ = decltype(get<0>(wg_tile_qk)); + using TileK = decltype(get<1>(wg_tile_qk)); + + using WGTilePV = Shape; + using SGLayoutPV = decltype(get_sg_layout_pv(sg_layout_qk)); + + static_assert(size(SGLayoutPV{}) == size(SGLayoutQK{}), + "Q*K cannot be parallelized in the head size dimension"); + + return TiledMMAHelper{}; +} + +} // namespace cutlass::fmha::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp new file mode 100644 index 0000000000..48b27b16c4 --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -0,0 +1,235 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp" +#include "flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +struct FMHAProblemShape { + int batch; + int num_heads_q, num_heads_kv; + int seq_len_qo, seq_len_kv; // -> VariableLen to support variable-length-per-batch cases + int head_size_qk, head_size_vo; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class XeFMHAFwdKernel { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + + // Tile scheduler derived types + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) + : sizeof(SharedStorage); + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ *Q; + StrideQ dQ; + const ElementK *K; + StrideK dK; + const ElementV *V; + StrideV dV; + ElementO *O; + StrideO dO; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + return {args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{})}; + } + + static bool can_implement(Arguments const &args) { + return CollectiveMainloop::can_implement(args.mainloop) + && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) + { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_kv / s.num_heads_q; + + int thr_id = int(ThreadIdxX()); + + TileScheduler tile_scheduler{params.scheduler}; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto blk_qv = make_coord(blk_q, blk_v); + int head_q = head / head_group_q; + + const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{})); + + auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch); + auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch); + auto shape_V = make_shape(s.head_size_vo, s.seq_len_kv, s.num_heads_kv, s.batch); + auto shape_O = make_shape(s.seq_len_qo, s.head_size_vo, s.num_heads_kv, s.batch); + + auto dcQ = const_cast(p.Q); // de-const these for uniformity + auto dcK = const_cast(p.K); + auto dcV = const_cast(p.V); + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, p.dQ)); // (q,d,h,b) + Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, p.dK)); // (k,d,h,b) + Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, p.dV)); // (v,k,h,b) + Tensor O = make_tensor(make_gmem_ptr(p.O), make_layout(shape_O, p.dO)); // (q,v,h,b) + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // Main loop + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + mainloop(Q(_,_,head_q,idx_b), + K(_,_,head,idx_b), + V(_,_,head,idx_b), + tArA, tA_max, tA_sum, + blk_qv, 0, k_blocks, + thr_id); + + if constexpr (!is_empty_v && !is_empty_v) { + 
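        // The mainloop and epilogue share SLM through a union, so the whole workgroup must be done
        // with the mainloop's SLM before the epilogue reuses it.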
sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + epilogue(O(_,_,head,idx_b), + tArA, tA_max, tA_sum, + blk_qv, thr_id); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp new file mode 100644 index 0000000000..a14d6db482 --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +struct XeFHMAIndividualTileScheduler { + + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFHMAIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, KernelHardwareInfo hw_info, + TileShape const& tile_shape) + { + using namespace cute; + + dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + return Params{grid, {shape.num_heads_q}}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int idx_b = BlockIdxZ(); + int head; + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b); + } + + CUTLASS_DEVICE + XeFHMAIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp new file mode 100644 index 0000000000..82a9c3dbd2 --- /dev/null +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Flash Attention V2 Prefill for Intel BMG + + This example constructs and executes a Flash Attention Prefill kernel on Intel BMG. The + definition of the GEMM, options etc for this example are defined in the associated + bmg_flash_attn_runner.hpp header file. + + See https://arxiv.org/pdf/2307.08691 for details of Flash Attention V2 algorithm + + To run this example: + $ ./examples/sycl/06_bmg_flash_attention/06_xe_fmha_fwd --seq_len_qo=512 + --seq_len_kv=512 --head_size_vo=128 --head_size_qk=128 + + To build & run this example (from your build dir): + + $ ninja 06_xe_fmha_fwd + $ ./examples/sycl/06_bmg_flash_attention/06_xe_fmha_fwd + + Call with `--help` for information about available options +*/ + +#include "xe_fmha_fwd_runner.hpp" + +int main(int argc, const char **argv) { + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // Define the work-group tile shape depending on the head-size of the second matmul + +#ifdef PREFILL +#if HEAD_DIM == 16 + /* Tiny config for testing */ + using ShapeQK = Shape<_1, _16, _16>; // (q,k,d) + using ShapePV = Shape<_1, _16, _16>; // (q,v,k) + using ShapeOut = Shape<_1, _16>; // (q,v) + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 64 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOut = Shape<_128, _64>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOut = Shape<_128, _96>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOut = Shape<_128, _128>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOut = Shape<_256, _192>; + using SubgroupLayoutQK = Layout>; + +#endif +#elif defined(DECODE) +#if HEAD_DIM == 16 + /* Tiny config for testing */ + using ShapeQK = Shape<_1, _16, _16>; // (q,k,d) + using ShapePV = Shape<_1, _16, _16>; // (q,v,k) + using ShapeOut = Shape<_1, _16>; // (q,v) + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 64 + using ShapeQK = Shape<_1, _512, _64>; + using ShapePV = Shape<_1, _32, _512>; + using ShapeOut = Shape<_1, _64>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_1, _512, _64>; + using ShapePV = Shape<_1, _32, _512>; + using ShapeOut = Shape<_1, _96>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_1, _512, _64>; + using ShapePV = Shape<_1, _32, _512>; + using ShapeOut = Shape<_1, _128>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_1, _512, _64>; + using ShapePV = Shape<_1, _32, _512>; + using ShapeOut = Shape<_1, _192>; + using SubgroupLayoutQK = Layout>; +#endif +#else +#error Either DECODE or PREFILL should be defined. 
+#endif + +#ifdef DECODE + constexpr int PipelineStages = 1; +#else + constexpr int PipelineStages = 2; +#endif + + return FMHAConfig::run(options); +} diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 39752da4ed..50c0a316b5 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -32,7 +32,17 @@ set(CUTLASS_APPLICATIONS_DIR ${CMAKE_SOURCE_DIR}/applications) set(TEST_NO_PAGED "") set(TEST_PAGED "--use_paged_kv") -foreach(HEAD_DIM 64 96 128 192) +foreach(HEAD_DIM 16 64 96 128 192) + + cutlass_example_add_executable( + 06_xe_fmha_fwd_prefill_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + + cutlass_example_add_executable( + 06_xe_fmha_fwd_decode_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) cutlass_example_add_executable( 06_bmg_prefill_attention_hdim${HEAD_DIM} @@ -72,4 +82,6 @@ foreach(HEAD_DIM 64 96 128 192) target_compile_definitions(06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) + target_compile_definitions(06_xe_fmha_fwd_prefill_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} PREFILL SHOW_DIFF=1) + target_compile_definitions(06_xe_fmha_fwd_decode_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SHOW_DIFF=1) endforeach() diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp new file mode 100644 index 0000000000..0715995d33 --- /dev/null +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -0,0 +1,617 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp" +#include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include +#include + +#include "helper.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" + +#include + +using namespace cute; + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool is_causal; + bool varlen = false; + std::string scheduler; + + int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations; + float softmax_scale; + + Options() + : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), + seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("is_causal")) { + is_causal = true; + } + + if (cmd.check_cmd_line_flag("varlen")) { + varlen = true; + } + + cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual")); + + cmd.get_cmd_line_argument("batch", batch, 32); + cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16); + cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q); + cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, 512); +#ifdef DECODE + cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 1); +#else + cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, seq_len_kv); +#endif + cmd.get_cmd_line_argument("head_size_vo", head_size_vo, HEAD_DIM); + cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + softmax_scale = 1 / sqrt(static_cast(head_size_qk)); + } + + /// Prints the usage statement. 
+ std::ostream &print_usage(std::ostream &out) const { + + out << "Xe FMHA Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --is_causal Apply Causal Mask to the output of first Matmul\n" + << " --varlen Enable variable sequence length\n" + << " --scheduler=\"Value\" Choose between Individual or Persistent Scheduler\n" + << " --batch= Sets the Batch Size of the Multi-Head Self Attention module\n" + << " --num_heads_q= Sets the Number of Attention Heads for Key-Value pair the Multi-Head Self Attention module\n" + << " --num_heads_kv= Sets the Number of Attention Heads for Query input in the Multi-Head Self Attention module\n" + << " --seq_len_qo= Sets the Sequence length of the Query input in Multi-Head Self Attention module\n" + << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" + << " --head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" + << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Helpers + +template +void convert_tensor(const SrcT* d_src, DstT* d_dst, size_t size) { + compat::get_default_queue().parallel_for(size, [=](auto indx) { + d_dst[indx] = static_cast(d_src[indx]); + }).wait(); +} + +template inline auto in_memory(cutlass::DeviceAllocation& in) { + using OutT = cute::conditional_t<(sizeof_bits_v <= 8), half_t, InT>; + if constexpr (!is_same_v) { + cutlass::DeviceAllocation out(in.size()); + convert_tensor(in.get(), out.get(), in.size()); + return out; + } else { + return in; + }; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// 3 input matrices: (K)eys, (Q)ueries and (V)alues. 
+using LayoutQ = cutlass::layout::RowMajor; +using LayoutK = cutlass::layout::ColumnMajor; +using LayoutV = cutlass::layout::RowMajor; +using LayoutO = cutlass::layout::RowMajor; + +template struct ExampleRunner { + + using StrideQ = typename FMHAKernel::StrideQ; + using StrideK = typename FMHAKernel::StrideK; + using StrideV = typename FMHAKernel::StrideV; + using StrideO = typename FMHAKernel::StrideO; + + using ElementQ = typename FMHAKernel::ElementQ; + using ElementK = typename FMHAKernel::ElementK; + using ElementV = typename FMHAKernel::ElementV; + using ElementO = typename FMHAKernel::ElementO; + + using CollectiveMainloop = typename FMHAKernel::CollectiveMainloop; + using ElementS = typename CollectiveMainloop::ElementS; + + using ProblemShapeType = typename FMHAKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_ref_O; + + // + // Methods + // + + bool verify(ProblemShapeType shape, bool is_causal) { + + auto batch = shape.batch; + auto num_heads_q = shape.num_heads_q; + auto num_heads_kv = shape.num_heads_kv; + auto head_size_qk = shape.head_size_qk; + auto head_size_vo = shape.head_size_vo; + auto seq_len_qo = shape.seq_len_qo; + auto seq_len_kv = shape.seq_len_kv; + + auto block_Q_ = in_memory(block_Q); + auto block_K_ = in_memory(block_K); + auto block_V_ = in_memory(block_V); + + using ElementV_ = std::remove_pointer_t; + + int offset_q = 0; + int offset_k = 0; + int offset_v = 0; + int offset_o = 0; + + // loop over the batch dimension to compute the output + // to avoid the risk of running out of device memory + int q_group_size = num_heads_q/num_heads_kv; + for (int b = 0; b < batch; b++) { + int kv_group_update=1; + for (int h = 0; h < num_heads_q; h++) { + cutlass::DeviceAllocation block_S; + block_S.reset(seq_len_qo * seq_len_kv); + + cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); + cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); + cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo})); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + + cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, + cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, + 0.f, ref_S, ref_S, ElementS(0), + 1, // batch_count + seq_len_qo * head_size_qk, // batch_stride_Q + seq_len_kv * head_size_qk, // batch_stride_K + seq_len_qo * seq_len_kv, // batch_stride_S + seq_len_qo * seq_len_kv // batch_stride_S + ); + + compat::wait(); + + std::vector host_S(block_S.size()); + compat::memcpy(host_S.data(), block_S.get(), host_S.size()); + + // delete this memory as it is no longer needed + block_S.reset(); + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + if (is_causal) { + // apply mask to S + for (int row = 0; row < seq_len_qo; row++) { + for (int col = 0; col < seq_len_kv; col++) { + if ((col - full_tile_offset) > (row - discard_seq_coord)) + host_S[col + row * seq_len_kv] = ElementS{-INFINITY}; + } + } + } + + // compute max element per row of S + std::vector 
max_vec(seq_len_qo, ElementS{-INFINITY}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv; + int max_idx = row; + max_vec[max_idx] = host_S[idx++]; + for (int col = 1; col < seq_len_kv; col++, idx++) { + if (max_vec[max_idx] < host_S[idx]) + max_vec[max_idx] = host_S[idx]; + } + } + + // compute exp of S + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv; + int max_idx = row; + for (int col = 0; col < seq_len_kv; col++, idx++) { + /* FIXME: use softmax_scale instead of assuming its value here */ + host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast((head_size_qk)))); + } + } + + // compute sum per row of S + std::vector sum_vec(seq_len_qo, ElementS{0}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv; + int sum_idx = row; + for (int col = 0; col < seq_len_kv; col++, idx++) { + sum_vec[sum_idx] += host_S[idx]; + } + + // scale each row with the sum to compute softmax + idx = row * seq_len_kv; + sum_idx = row; + for (int col = 0; col < seq_len_kv; col++, idx++) { + if(is_causal && row < discard_seq_coord) { + host_S[idx] = 0; + } else { + host_S[idx] /= sum_vec[sum_idx]; + } + } + } + + std::vector host_P(host_S.size()); + for (int p = 0; p < host_P.size(); p++) + host_P[p] = static_cast(host_S[p]); + + cutlass::DeviceAllocation block_P; + block_P.reset(host_P.size()); + + compat::memcpy(block_P.get(), host_P.data(), host_P.size()); + + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + + cutlass::DeviceAllocation block_acc; + block_acc.reset(seq_len_qo * head_size_vo); + cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo})); + + cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementS{1}, ref_P, + cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, + ElementS{0}, ref_acc, ref_acc, ElementS{0}, + 1, // batch_count + seq_len_qo * seq_len_kv, // batch_stride_P + seq_len_kv * head_size_vo, // batch_stride_V + seq_len_qo * head_size_vo, // batch_stride_O + seq_len_qo * head_size_vo // batch_stride_O + ); + + compat::wait(); + // delete this memory as it is no longer needed + block_P.reset(); + + std::vector vec_acc(block_acc.size()); + compat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size()); + + // delete this memory as it is no longer needed + block_acc.reset(); + std::vector vec_out(vec_acc.size()); + for(int i = 0; i < vec_out.size(); i++) { + vec_out[i] = static_cast(vec_acc[i]); + } + compat::memcpy(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size()); + + offset_q += seq_len_qo * head_size_qk; + if(kv_group_update % q_group_size==0) { + offset_k += seq_len_kv * head_size_qk; + offset_v += seq_len_kv * head_size_vo; + } + kv_group_update++; + offset_o += seq_len_qo * head_size_vo; + } + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), + block_O.size(), ElementO{0.005}, ElementO{0.005}); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + ProblemShapeType initialize(const Options &options) { + ProblemShapeType shape; + auto batch = shape.batch = options.batch; + auto num_heads_q = shape.num_heads_q = options.num_heads_q; + auto num_heads_kv = shape.num_heads_kv = options.num_heads_kv; + auto seq_len_qo = shape.seq_len_qo = options.seq_len_qo; + auto 
seq_len_kv = shape.seq_len_kv = options.seq_len_kv; + auto head_size_qk = shape.head_size_qk = options.head_size_qk; + auto head_size_vo = shape.head_size_vo = options.head_size_vo; + + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch)); + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch)); + + block_Q.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_qk); + block_K.reset(static_cast(batch) * num_heads_kv * seq_len_kv * head_size_qk); + block_V.reset(static_cast(batch) * num_heads_kv * seq_len_kv * head_size_vo); + block_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + + initialize_block(block_Q, seed + 2023); + initialize_block(block_K, seed + 2022); + initialize_block(block_V, seed + 2021); + +#if 1 + auto init1 = [](auto &block, auto stride, const char* envMN, const char* envK, const char* envB) { + if (!getenv(envMN) || !getenv(envK)) + return; + auto ptr = block.get(); + auto mn = atoi(getenv(envMN)); + auto k = atoi(getenv(envK)); + bool bfill = getenv(envB) && atoi(getenv(envB)); + auto idx = mn * get<0>(stride) + k * get<1>(stride); + using T = cute::remove_cvref_t; + + sycl::queue Q; + Q.parallel_for(sycl::range<1>(block.size()), [=](auto id) { + if (id == 0) printf("Filling @ %d\n", int(idx)); + ptr[id] = T(bfill ? (id <= idx) : (id == idx)); + }).wait(); + }; + + init1(block_Q, stride_Q, "QS", "QH", "QX"); + init1(block_K, stride_K, "KS", "KH", "KX"); + init1(block_V, stride_V, "VH", "VS", "VX"); +#endif + + return shape; + } + + // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this + // secondary `run` function is required to launch the kernel. 
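+  // Launch path: the kernel reports its own block/grid shapes and shared-memory requirement,
+  // which are passed to SYCL through the work-group scratch memory launch property and the
+  // kernel properties (fixed sub-group size, large 256-entry GRF mode) configured below.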
+ static void run(typename FMHAKernel::Params params) + { + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + dim3 const block = FMHAKernel::get_block_shape(); + dim3 const grid = FMHAKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + + // Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension + compat::experimental::launch_properties launch_props { + syclex::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + syclex::sub_group_size, + intelex::grf_size<256> + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch>(policy, params); + + EventManager::getInstance().addEvent(event); + } + + cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) { + + ProblemShapeType shape = initialize(options); + + typename FMHAKernel::Arguments arguments{ + { + shape, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O + }, + {options.softmax_scale}, + {}, + hw_info + }; + + // Define device-global scratch memory + size_t workspace_size = FMHAKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << + options.seq_len_qo << 'x' << options.seq_len_kv << 'x' << options.head_size_qk << 'x' << options.head_size_vo + << (options.is_causal ? "xCausal" : "xNonCausal") << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + CUTLASS_CHECK(FMHAKernel::initialize_workspace(arguments, workspace.get())); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params = FMHAKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the GEMM + run(params); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(shape, options.is_causal); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + return cutlass::Status::kErrorInternal; + } + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + run(params); + } + compat::wait(); + // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking. + // Following changes will adjust the effective_seq_len_kv when masking applied for such cases + auto offset = cute::min(options.seq_len_qo, options.seq_len_kv); + auto discard_seq_coord = options.seq_len_qo - offset; + auto full_tile_offset = options.seq_len_kv - offset; + // offset + 1 is going to be ceil_div + auto effective_seq_len_kv = options.is_causal ? full_tile_offset + ((offset + 1) / 2.0): options.seq_len_kv; + auto effective_seq_len_qo = options.is_causal ? 
options.seq_len_qo - discard_seq_coord : options.seq_len_qo; + double cute_time = timer.seconds() / options.iterations; + double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk; + double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv; + double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time; + double gbps_qk = options.batch * (sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo * options.head_size_qk + + sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk); + double gbps_pv = sizeof(ElementV) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo + + sizeof(ElementO) * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo; + double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time); + std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo + << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo + << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false") + << "\t Scheduler: " << options.scheduler; + printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +template default */ + int PipelineStages, + typename ElementQ = bfloat16_t, + typename ElementK = bfloat16_t, + typename ElementV = bfloat16_t, + typename ElementO = float, + typename MMAOperation_ = void, /* void -> default */ + typename StrideQ = Stride, + typename StrideK = Stride, + typename StrideV = Stride<_1, int, int, int>, + typename StrideO = Stride, + typename GmemTiledCopyQ = void, /* void -> default block 2D */ + typename GmemTiledCopyK = void, + typename GmemTiledCopyV = void, + typename GmemTiledCopyO = void> +struct FMHAConfig { + + static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); + using MMAOperation = cute::conditional_t, + XE_DPAS_TT, + MMAOperation_>; + using SubgroupLayoutPV = cute::conditional_t, + decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), + SubgroupLayoutPV_>; + + template + static int run(const Options &options) { + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. 
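+    // hw_info is default-constructed here; it is forwarded to the kernel through
+    // FMHAKernel::Arguments in ExampleRunner::run.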
+ cutlass::KernelHardwareInfo hw_info; + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + + using TiledMMAQK = typename TiledMMAHelper, Layout, SubgroupLayoutQK>::TiledMMA; + using TiledMMAPV = typename TiledMMAHelper, Layout, SubgroupLayoutPV>::TiledMMA; + + static_assert(get<0>(TileShapeOutput{}) == get<0>(TileShapePV{}), + "Output tile and P*V tile have different sizes in Q dimension"); + constexpr int VTiles = get<1>(TileShapeOutput{}) / get<1>(TileShapePV{}); + + auto make_dummy_tensor = [&](auto val, auto stride) { + return make_tensor(make_gmem_ptr(&val), + make_layout(repeat>(1), stride)); + }; + + using TensorQ = decltype(make_dummy_tensor(ElementQ{}, StrideQ{})); + using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); + using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); + using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideO{})); + + // Mainloop + using MainloopDispatchPolicy = cutlass::fmha::XeDefault; + using CollectiveMainloop = cutlass::fmha::collective::FMHAFwdMainloop< + MainloopDispatchPolicy, Causal, + TiledMMAQK, TiledMMAPV, VTiles, + TensorQ, TensorK, TensorV, + GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV + >; + + // Epilogue + using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue< + CollectiveMainloop, + TileShapeOutput, + TensorO, + GmemTiledCopyO + >; + + using FMHAKernel = cutlass::fmha::kernel::XeFMHAFwdKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler + >; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + return 0; + } + + static int run(const Options &options) { + return run(options); + } +}; diff --git a/include/cute/algorithm/subgroup_algorithms.hpp b/include/cute/algorithm/subgroup_algorithms.hpp new file mode 100644 index 0000000000..27694ad077 --- /dev/null +++ b/include/cute/algorithm/subgroup_algorithms.hpp @@ -0,0 +1,147 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" +#include "cute/util/sycl_vec.hpp" + +namespace cute { + +// Uniformize a value, in case the compiler cannot prove it is subgroup-uniform. +template +CUTE_HOST_DEVICE +T +assert_uniform(T x) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + return group_broadcast(sg, x, 0); +} + +// Set a value in a single work-item -- x[i] = val. +// WARNING: i _must_ be a compile-time constant. +// No diagnostics/error will be issued by the compiler if it is not. +template +CUTE_HOST_DEVICE void +set_wi_value(T &x, int i, T val) +{ +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) + asm ( + "mov (M1_NM, 1) %0(0,%2)<1> %1(0,0)<1;1,0>" + : "+rw"(x) + : "rw.u"(val), "P"(i) + ); +#else + int lane = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_id()[0]; + if (lane == i) + x = val; +#endif +} + +// Set an element of a 1D SG-shared fragment x. +// WARNING: i _must_ be a compile-time constant. +// No diagnostics/error will be issued by the compiler if it is not. +template +CUTE_HOST_DEVICE void +set_single_value(FragX& x, int i, typename FragX::element_type val) { + set_wi_value(x(i / intel::sg_size), i % intel::sg_size, val); +} + +// Broadcast the element from a 1D SG-shared fragment x +// corresponding to the Mode'th dimension of the logical coordinates of src(val). +template ::value)> +CUTE_HOST_DEVICE +constexpr auto +broadcast(FragX const& x, SGTensorSrc const& src, int val) +{ + auto coord = src.tv_layout()(0, val); + auto coord_i = get(coord); + + constexpr auto TMode = rank(as_arithmetic_tuple(stride<0>(SGTensorSrc{}.tv_layout()))) - 1; + if constexpr (TMode == Mode) { + return x(coord_i / intel::sg_size); + } else { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + return group_broadcast(sg, x(coord_i / intel::sg_size), coord_i % intel::sg_size); + } +} + +// Subgroup-cooperative reduction of a SubgroupTensor. +template +CUTE_HOST_DEVICE +auto +reduce(SubgroupTensor const& src, BinaryOp op) +{ + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + using T = typename Engine::value_type; + using TVToV = Layout, Stride<_0,_1>>; + + /* Retrieve logical coordinate -> (T,V) mapping */ + constexpr auto shape = atuple_coshape(SubgroupTVLayout{}); + constexpr auto coord_to_tv = right_inverse(project_strides(SubgroupTVLayout{})).with_shape(shape); + + /* Move reduction coordinate to mode-0 and group the rest in mode-1. Then, remove work-item modes. 
*/ + constexpr auto rcoord_to_tv = make_layout(select(coord_to_tv), remove(coord_to_tv)); + constexpr auto rcoord_to_v = filter(composition(TVToV{}, rcoord_to_tv), Step<_1,_1>{}); + + /* Regroup input tensor */ + Tensor src_r = make_tensor(src.data(), rcoord_to_v); + + /* Create output tensor */ + Shape rshape = replace(shape, _1{}); + Tensor out = make_subgroup_tensor(make_tensor(ceil_div(size(rshape), intel::_SGSize{})), + make_identity_layout(rshape)); + + /* Check for reduction type */ + constexpr bool horizontal = (size<0>(rcoord_to_tv) == intel::_SGSize{} * size<0>(rcoord_to_v)); + constexpr bool vertical = (size<1>(rcoord_to_tv) == intel::_SGSize{} * size<1>(rcoord_to_v)); + + CUTE_UNROLL + for (int j = 0; j < size<1>(rcoord_to_v); j++) { + T acc = src_r(0, j); + CUTE_UNROLL + for (int i = 1; i < size<0>(rcoord_to_v); i++) { + acc = op(acc, src_r(i, j)); + } + + if constexpr (horizontal) + set_single_value(out, j, reduce_over_group(sg, acc, op)); // TODO: optimize vector usage + else if constexpr (vertical) + out(j) = acc; + else + static_assert("Unimplemented reduction type"); + } + + return out; +} + +} // namespace cute diff --git a/include/cute/arch/copy_xe_legacy.hpp b/include/cute/arch/copy_xe_legacy.hpp index a414885033..fc3da90055 100644 --- a/include/cute/arch/copy_xe_legacy.hpp +++ b/include/cute/arch/copy_xe_legacy.hpp @@ -47,25 +47,4 @@ #include #include -// FIXME: these are not copy-related and should be declared elsewhere. -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); -SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); -#endif - -namespace cute -{ - -// scope = 3 is for subgroup, scop = 2 is for workgroup -CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = 0, int memory_semantics = 0) { -#ifdef __SYCL_DEVICE_ONLY__ - __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics); -#endif -} -CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = 0, int memory_semantics = 0) { -#ifdef __SYCL_DEVICE_ONLY__ - __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics); -#endif -} - -} // end namespace cute +#include \ No newline at end of file diff --git a/include/cute/arch/reorder.hpp b/include/cute/arch/reorder.hpp index a2caa033f0..dd0e32e193 100644 --- a/include/cute/arch/reorder.hpp +++ b/include/cute/arch/reorder.hpp @@ -41,7 +41,7 @@ struct Universal_Reorder_UU { CUTE_HOST_DEVICE static void reorder(SrcType const& src0, DstType& dst0) { - dst0 = src0; + dst0 = DstType(src0); } }; diff --git a/include/cute/arch/reorder_xe.hpp b/include/cute/arch/reorder_xe.hpp index 42e701a4ce..e7c7cd7b06 100644 --- a/include/cute/arch/reorder_xe.hpp +++ b/include/cute/arch/reorder_xe.hpp @@ -1236,7 +1236,61 @@ struct Xe_Reorder } }; +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::float4[1]; + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::float4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 2 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,1)<2> IN_UB(0,0)<1;1,0> 7:uw\n" + "shl (M1_NM, 32) OUT_UW(2,1)<2> IN_UB(0,32)<1;1,0> 7:uw\n" + "add.sat (M1_NM, 32) OUT_UW(0,0)<2> IN_UB(0,0)<1;1,0> -254:w\n" 
+ "add.sat (M1_NM, 32) OUT_UW(2,0)<2> IN_UB(0,32)<1;1,0> -254:w\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; +/****************/ +/* Downconverts */ +/****************/ + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::float2[1]; + using DRegisters = intel::ushort2[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::float2 const& src0, intel::ushort2& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 2 cycles/output register */ + "{\n" + ".decl IN_F v_type=G type=F num_elts=32 alias=<%1,0>\n" + ".decl OUT_BF v_type=G type=BF num_elts=32 alias=<%0,0>\n" + "mov (M1_NM, 32) OUT_BF(0,0)<1> IN_F(0,0)<1;1,0>\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; } // end namespace cute diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp index 2df8ae0a38..a432e72ad4 100644 --- a/include/cute/atom/copy_traits_xe_2d.hpp +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -56,7 +56,7 @@ namespace cute { // Utility to check if a layout belongs to a coordinate tensor. template -static constexpr bool is_counting_layout_v = is_arithmetic_tuple_like::value; +static constexpr bool is_counting_layout_v = is_arithmetic_tuple_like::value || is_constant_v<1, decltype(size(Layout{}))>; @@ -724,12 +724,11 @@ block_2d_selector(CoordLayout const&, GlobalStride const&) } // Helper for make_block_2d_copy_* routines -template CUTE_HOST_DEVICE auto make_block_2d_copy_X(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance Stride const& gstride, // Global memory strides XMode const& x_mode, // x, y modes YMode const& y_mode, @@ -826,7 +825,7 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation make_tile(sg_to_vmk, _)); // (SG,V) -> (M,K) // Derive copy tile layout and create TiledCopy - return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_mk, svA); + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_mk, svA); } template @@ -887,7 +886,7 @@ make_block_2d_copy_B(CopyOp const& op, // Copy operation auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) auto drop_m = make_layout(shape_vmnk, - make_stride(_1{}, _0{}, get<0>(shape_vmnk), _0{}, + make_stride(_1{}, _0{}, get<0>(shape_vmnk), get<0>(shape_vmnk) * get<2>(shape_vmnk))); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrN,ThrK) auto thr_to_vnk = composition(drop_m, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrN,ThrK) @@ -898,7 +897,7 @@ make_block_2d_copy_B(CopyOp const& op, // Copy operation make_tile(sg_to_vnk, _)); // (SG,V) -> (N,K) // Derive copy tile layout and create TiledCopy - return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_nk, svB); + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_nk, svB); } template @@ -971,7 +970,132 @@ make_block_2d_copy_C(CopyOp const& op, // Copy operation make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N) // Derive copy tile layout and create TiledCopy - return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_mn, svC); + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_mn, svC); +} + +// Variants of make_block_2d_copy_C where the C tile is further subdivided by the user. +// (e.g. split-k parallelization). 
+ +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(TiledMMA const& mma, // TiledMMA instance + SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_C_subtiled(mma, stv_layout, ssg_layout, gmem.stride()).with(gmem); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + SubtileShape const& sshape, // Subtile shape: (m,n) + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_C_subtiled(op, sshape, ssg_layout, mma, gmem.stride()).with(gmem); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(TiledMMA const& mma, // TiledMMA instance + SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride) // Global memory strides +{ + using MMAType = typename TiledMMA::ValTypeA; + auto cC = make_identity_tensor(select<0,1>(mma.tile_mnk())); + auto op = block_2d_selector(stv_layout, gstride); + return make_block_2d_copy_C_subtiled(op, mma, atuple_coshape(stv_layout), ssg_layout, gstride); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + SubtileShape const& sshape, // Subtile shape: (m,n) + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_C_subtiled(op, mma, sshape, ssg_layout, gstride, + find_x_mode(gstride), find_y_mode(gstride)); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + SubtileShape const& sshape, // Subtile shape: (m,n) + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) +{ + // Expand subtile layout. + auto xssg_layout = make_layout(shape(ssg_layout), + elem_scale(stride(ssg_layout), sshape)); // SG_K -> (M,N) + + // Retrieve MMA atom's (subgroup, value) -> (M,N) layout. + // Allow cross-MMA tiling. + auto tile_mn = round_up(select<0,1>(mma.tile_mnk()), + atuple_coshape(xssg_layout)); + + auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr + auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) + auto drop_k = replace<3>(make_layout(shape_vmnk), + make_layout(get<3>(shape_vmnk), _0{})); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrM,ThrN) + + auto thr_to_vmn = composition(drop_k, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrM,ThrN) + auto sg_to_vmn = composition(thr_to_vmn, + make_layout(product(take<1,4>(shape_vmnk)), get<0>(shape_vmnk))); // SG -> (0,ThrM,ThrN) + + auto svC = composition(mma.thrfrg_C(make_layout(tile_mn)), + make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N) + + // Add subtile modes. Limitations: + // - ThrK must be covered by a single mode in svC. 
+ // - SubtileSGLayout must have a subtile for each ThrK, OR ThrK must be the last mode. + decltype(coalesce(get<0>(svC))) sC{}; + constexpr auto mode_thr_k = find_if(stride(sC), [](auto const &x) { return C>{}; }); + static_assert(shape(sC) == shape<3>(thr_vmnk), "ThrK split into multiple modes; unsupported"); + + auto k_to_mn = composition(make_layout(tile_mn), xssg_layout); // ThrK -> (M,N) + + static_assert(size(SubtileSGLayout{}) == shape<3>(thr_vmnk) || mode_thr_k + 1 >= rank(sC), + "Unsupported partially occupied ThrK scenario"); + + // Remove subtile value modes. + auto drop_subtiles = make_layout(zip(sshape, shape_div(tile_mn, sshape)), + zip(stride(make_layout(tile_mn)), Stride<_0,_0>{})); + + auto svC_tiled = make_layout(replace(sC, k_to_mn), + coalesce(composition(drop_subtiles, get<1>(svC)))); + + // Derive copy tile layout and create TiledCopy + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_mn, svC_tiled); } // Prefetch selection and creation. @@ -1072,7 +1196,7 @@ make_block_2d_prefetch(PrefetchOp const& op, Int{}); // Tile atom grid across collective op tile. - auto sv_layout = zipped_divide(make_layout(collective_op_tile), atom_shape); + auto sv_layout = zipped_divide(make_layout(atom_shape), collective_op_tile); // Create the TiledCopy object. return make_block_2d_copy(op, stride, x_mode, y_mode, atom_shape, sv_layout); diff --git a/include/cute/atom/reorder_atom_xe.hpp b/include/cute/atom/reorder_atom_xe.hpp index c7bc546fce..99ecce48bc 100644 --- a/include/cute/atom/reorder_atom_xe.hpp +++ b/include/cute/atom/reorder_atom_xe.hpp @@ -115,8 +115,8 @@ constexpr ReorderKind classify_xe_reorder() template -constexpr auto choose_xe_reorder_impl(SLayout const& slayout, // (src thr, src val) -> coord - DLayout const& dlayout) { // (dst thr, dst val) -> coord +auto choose_xe_reorder_impl(SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) { // (dst thr, dst val) -> coord // Calculate data transformation, interleaving WI-owned values: // (thr0,val0) ... (thr15,val0), (thr0,val1), ..., (thr15,val1), ... 
auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 76446f0244..0b6da2e0a8 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -326,6 +326,9 @@ struct is_layout : false_type {}; template struct is_layout> : true_type {}; +template +static constexpr bool is_layout_v = is_layout::value; + // // Layout construction // @@ -682,8 +685,10 @@ CUTE_HOST_DEVICE constexpr auto atuple_coshape(Layout const& layout) { + auto _0E0 = ScaledBasis,0>{}; auto flayout = filter(flatten(layout)); - return inner_product_atuple_max(shape(flayout), stride(flayout)); + auto coshape = inner_product_atuple_max(shape(flayout), stride(flayout)) + _0E0 + _0E0; + return cute::transform(coshape, [](auto a) { return cute::max(a, _1{}); }); } // Return the codomain size of a mode @@ -1062,6 +1067,15 @@ group(Layout const& layout) group(layout.stride())); } +template +CUTE_HOST_DEVICE constexpr +auto +remove(Layout const& layout) +{ + return make_layout(remove(layout.shape()), + remove(layout.stride())); +} + // // Composition of two layouts: lhs o rhs // @post compatible(rhs, result) diff --git a/include/cute/tensor_sg.hpp b/include/cute/tensor_sg.hpp index b128bf4e13..288296abae 100644 --- a/include/cute/tensor_sg.hpp +++ b/include/cute/tensor_sg.hpp @@ -74,7 +74,7 @@ struct SubgroupTensor : Tensor *this = static_cast(base); } - static constexpr int rank = Layout::rank; + static constexpr int rank = Layout::rank; CUTE_HOST_DEVICE constexpr decltype(auto) @@ -89,13 +89,18 @@ struct SubgroupTensor : Tensor } }; +template +struct is_sg_tensor : false_type {}; +template +struct is_sg_tensor> : true_type {}; + template struct is_tensor> : true_type {}; -template::value)> +template ::value)> CUTE_HOST_DEVICE constexpr auto make_subgroup_tensor(Tensor const& tensor, SubgroupTVLayout const&) @@ -105,6 +110,43 @@ make_subgroup_tensor(Tensor const& tensor, SubgroupTVLayout cons return static_cast const&>(tensor); } +template +CUTE_HOST_DEVICE +constexpr auto +make_subgroup_tensor(Layout const& sg_layout) +{ + return make_subgroup_tensor(make_fragment_like(sg_layout(0,_)), sg_layout); +} + +template +CUTE_HOST_DEVICE +constexpr auto +make_subgroup_tensor(Args const&... args) +{ + return make_subgroup_tensor(make_layout(args...)); +} + + +// Replicate a subgroup fragment in a given mode. 
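+// (The replicated mode is appended with zero register stride, so every coordinate along
+// that mode aliases the same underlying fragment data.)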
+template +CUTE_HOST_DEVICE +constexpr auto +expand_sg_fragment_helper(SubgroupTensor const&) +{ + constexpr SubgroupTensor frag; + constexpr int ModeSize = get(atuple_coshape(frag.tv_layout())); + + auto xlayout = append(frag.layout(), + Layout, C>>{}); + auto xv_layout = append(get<1>(frag.tv_layout()), + make_layout(C{}, C{} * E{})); + auto xtv_layout = make_layout(get<0>(frag.tv_layout()), xv_layout); + + return make_subgroup_tensor(make_tensor(xlayout), xtv_layout); +} + +template +using expand_sg_fragment_t = decltype(expand_sg_fragment_helper(SGTensor{})); // // Display utilities diff --git a/include/cute/util/compat/memory.hpp b/include/cute/util/compat/memory.hpp index ed4de5c7bb..2f6c58846f 100644 --- a/include/cute/util/compat/memory.hpp +++ b/include/cute/util/compat/memory.hpp @@ -146,40 +146,6 @@ class pitched_data { size_t _pitch, _x, _y; }; -namespace experimental { -#ifdef SYCL_EXT_ONEAPI_BINDLESS_IMAGES -class image_mem_wrapper; -namespace detail { -static sycl::event memcpy(const image_mem_wrapper *src, - const sycl::id<3> &src_id, pitched_data &dest, - const sycl::id<3> &dest_id, - const sycl::range<3> ©_extend, sycl::queue q); -static sycl::event memcpy(const pitched_data src, const sycl::id<3> &src_id, - image_mem_wrapper *dest, const sycl::id<3> &dest_id, - const sycl::range<3> ©_extend, sycl::queue q); -} // namespace detail -#endif -class image_matrix; -namespace detail { -static pitched_data to_pitched_data(image_matrix *image); -} - -/// Memory copy parameters for 2D/3D memory data. -struct memcpy_parameter { - struct data_wrapper { - pitched_data pitched{}; - sycl::id<3> pos{}; -#ifdef SYCL_EXT_ONEAPI_BINDLESS_IMAGES - experimental::image_mem_wrapper *image_bindless{nullptr}; -#endif - image_matrix *image{nullptr}; - }; - data_wrapper from{}; - data_wrapper to{}; - sycl::range<3> size{}; -}; -} // namespace experimental - namespace detail { class mem_mgr { mem_mgr() { @@ -856,56 +822,6 @@ static sycl::accessor get_access(const void *ptr, } } -namespace experimental { -namespace detail { -static inline std::vector -memcpy(sycl::queue q, const experimental::memcpy_parameter ¶m) { - auto to = param.to.pitched; - auto from = param.from.pitched; -#ifdef SYCL_EXT_ONEAPI_BINDLESS_IMAGES - if (param.to.image_bindless != nullptr && - param.from.image_bindless != nullptr) { - throw std::runtime_error( - "[Compat] memcpy: Unsupported bindless_image API."); - // TODO: Need change logic when sycl support image_mem to image_mem copy. 
- std::vector event_list; - compat::detail::host_buffer buf(param.size.size(), q, event_list); - to.set_data_ptr(buf.get_ptr()); - experimental::detail::memcpy(param.from.image_bindless, param.from.pos, to, - sycl::id<3>(0, 0, 0), param.size, q); - from.set_data_ptr(buf.get_ptr()); - event_list.push_back(experimental::detail::memcpy( - from, sycl::id<3>(0, 0, 0), param.to.image_bindless, param.to.pos, - param.size, q)); - return event_list; - } else if (param.to.image_bindless != nullptr) { - throw std::runtime_error( - "[Compat] memcpy: Unsupported bindless_image API."); - return {experimental::detail::memcpy(from, param.from.pos, - param.to.image_bindless, param.to.pos, - param.size, q)}; - } else if (param.from.image_bindless != nullptr) { - throw std::runtime_error( - "[Compat] memcpy: Unsupported bindless_image API."); - return {experimental::detail::memcpy(param.from.image_bindless, - param.from.pos, to, param.to.pos, - param.size, q)}; - } -#endif - if (param.to.image != nullptr) { - throw std::runtime_error("[Compat] memcpy: Unsupported image API."); - to = experimental::detail::to_pitched_data(param.to.image); - } - if (param.from.image != nullptr) { - throw std::runtime_error("[Compat] memcpy: Unsupported image API."); - from = experimental::detail::to_pitched_data(param.from.image); - } - return compat::detail::memcpy(q, to, param.to.pos, from, param.from.pos, - param.size); -} -} // namespace detail -} // namespace experimental - /// Allocate memory block on the device. /// \param num_bytes Number of bytes to allocate. /// \param q Queue to execute the allocate task. @@ -1240,31 +1156,6 @@ static sycl::event inline fill_async(void *dev_ptr, const T &pattern, return detail::fill(q, dev_ptr, pattern, count); } -namespace experimental { - -/// [UNSUPPORTED] Synchronously copies 2D/3D memory data specified by \p param . -/// The function will return after the copy is completed. -/// -/// \param param Memory copy parameters. -/// \param q Queue to execute the copy task. -/// \returns no return value. -static inline void memcpy(const memcpy_parameter ¶m, - sycl::queue q = get_default_queue()) { - sycl::event::wait(compat::experimental::detail::memcpy(q, param)); -} - -/// [UNSUPPORTED] Asynchronously copies 2D/3D memory data specified by \p param -/// . The return of the function does NOT guarantee the copy is completed. -/// -/// \param param Memory copy parameters. -/// \param q Queue to execute the copy task. -/// \returns no return value. -static inline void memcpy_async(const memcpy_parameter ¶m, - sycl::queue q = get_default_queue()) { - compat::experimental::detail::memcpy(q, param); -} -} // namespace experimental - namespace { /// Synchronously sets \p value to the first \p size bytes starting from \p /// dev_ptr. The function will return after the memset operation is completed. diff --git a/include/cute/util/xe_split_barrier.hpp b/include/cute/util/xe_split_barrier.hpp new file mode 100644 index 0000000000..ad96f8df1f --- /dev/null +++ b/include/cute/util/xe_split_barrier.hpp @@ -0,0 +1,82 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +enum SPIRVScope { + ScopeCrossDevice = 0, + ScopeDevice = 1, + ScopeWorkgroup = 2, + ScopeSubgroup = 3, + ScopeInvocation = 4, +}; + +enum SPIRVMemorySemantics { + SemanticsNone = 0, + SemanticsAcquire = 0x2, + SemanticsRelease = 0x4, + SemanticsAcquireRelease = 0x8, + SemanticsSGMemory = 0x80, + SemanticsWGMemory = 0x100, + SemanticsCrossWGMemory = 0x200, +}; + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); +#endif + +namespace cute +{ + +CUTE_HOST_DEVICE void barrier_arrive(SPIRVScope scope, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierArriveINTEL(scope, scope, memory_semantics); +#endif +} +CUTE_HOST_DEVICE void barrier_wait(SPIRVScope scope, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierWaitINTEL(scope, scope, memory_semantics); +#endif +} + +CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = ScopeCrossDevice, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics); +#endif +} +CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = ScopeCrossDevice, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics); +#endif +} + +} // end namespace cute diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index 811ba152ab..faa3b4aaac 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -108,6 +109,54 @@ make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape return s_copy; } +// Strides with 2 batch modes. +// All this code should be replaced with a generic implementation. + +template +CUTLASS_HOST_DEVICE +auto +make_cute_packed_stride(cute::Stride,int,int> s, + cute::Shape shape) +{ + using namespace cute; + + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + + int batch_count0 = get<2>(shape); + int batch_count1 = get<3>(shape) * batch_count0; + + get<0>(s_copy) = static_cast(get<1>(shape)); + get<2>(s_copy) = (batch_count0 <= 1) ? 0 : product(take<0,2>(shape)); + get<3>(s_copy) = (batch_count1 <= 1) ? 0 : product(take<0,3>(shape)); + + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +auto +make_cute_packed_stride(cute::Stride,IntT,int,int> s, + cute::Shape shape) +{ + using namespace cute; + + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + + int batch_count0 = get<2>(shape); + int batch_count1 = get<3>(shape) * batch_count0; + + get<1>(s_copy) = static_cast(get<0>(shape)); + get<2>(s_copy) = (batch_count0 <= 1) ? 0 : product(take<0,2>(shape)); + get<3>(s_copy) = (batch_count1 <= 1) ? 0 : product(take<0,3>(shape)); + + return s_copy; +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Strides with group mode diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 5cae58e4f1..4a4c94882e 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -102,6 +102,9 @@ __global__ void Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { +#ifdef SHOW_DIFF + printf("[%zu]: %f vs %f\n", idx, (double) a, (double) b); +#endif *equal = 0; return; }