@@ -39,6 +39,7 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "fmha_fusion.hpp"
#include "xe_rotary.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -62,7 +63,7 @@ CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {

template <class DispatchPolicy, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_ = false>
struct FlashPrefillMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};
@@ -71,9 +72,9 @@ struct FlashPrefillMma {

template <int Stages, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_>
struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_,
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> {
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, RopeMask_> {
//
// Type Aliases
//
@@ -97,6 +98,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
using TiledMmaPV = typename TiledMMAHelper<MmaAtom, Layout<TileShapePV>, SubgroupLayout>::TiledMMA;
using ElementAccumulator = typename TiledMmaQK::ValTypeC;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool rope_enabled = RopeMask_;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using MmaAtomShape = typename MmaAtom::Shape_MNK;
@@ -158,12 +160,19 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
StrideK dK;
ElementV const *ptr_V;
StrideV dV;
// RoPE cos/sin tables (read only when RopeMask_ is set)
ElementQ const *ptr_cos = nullptr;
ElementQ const *ptr_sin = nullptr;
};

struct Params {
XE_Copy_Q gmem_tiled_copy_q;
XE_Copy_K gmem_tiled_copy_k;
XE_Copy_V gmem_tiled_copy_v;
XE_Copy_Q gmem_tiled_copy_q_cos;
XE_Copy_Q gmem_tiled_copy_q_sin;
XE_Copy_K gmem_tiled_copy_k_cos;
XE_Copy_K gmem_tiled_copy_k_sin;
};

//
@@ -181,11 +190,21 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv), args.dV));

auto tensorQCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorQSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorKCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorKSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};
XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};

return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
@@ -372,11 +391,32 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k));
auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v));

auto q_traits_cos = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

auto tensorQCos = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), make_layout(shape_q, stride_q));
auto tensorQSin = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), make_layout(shape_q, stride_q));
auto tensorKCos = make_tensor(make_gmem_ptr(base_ptr_k_cos + offset_k), make_layout(shape_k, stride_k));
auto tensorKSin = make_tensor(make_gmem_ptr(base_ptr_k_sin + offset_k), make_layout(shape_k, stride_k));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};
XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};

return Params{copyQ, copyK, copyV};
return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin};
}
}
};
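To make the new aggregate layout concrete, here is a self-contained mock of the extended Arguments ordering. The element/stride types and buffers are stand-ins (not the real CUTLASS types); the field names mirror the struct above, and the RoPE pointers trail the existing Q/K/V fields so existing initializers stay valid.

#include <cstddef>

// Stand-in types so this sketch compiles on its own; in the mainloop these are
// ElementQ/StrideQ etc. from the template parameters.
using Element = float;
using Stride = std::size_t;

struct Arguments { // mirrors FlashPrefillMma<...>::Arguments above
  Element const *ptr_Q; Stride dQ;
  Element const *ptr_K; Stride dK;
  Element const *ptr_V; Stride dV;
  Element const *ptr_cos = nullptr; // RoPE cosine table, read when RopeMask_ is set
  Element const *ptr_sin = nullptr; // RoPE sine table
};

int main() {
  Element q[16]{}, k[16]{}, v[16]{}, cos_t[16]{}, sin_t[16]{};
  Arguments args{q, 16, k, 16, v, 16, cos_t, sin_t}; // trailing RoPE fields
  return args.ptr_cos == cos_t ? 0 : 1;
}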
63 changes: 63 additions & 0 deletions applications/flash_attention_v2/collective/xe_rotary.h
@@ -0,0 +1,63 @@
/***************************************************************************************************
* Copyright (c) 2025 Intel Corporation. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"

namespace cutlass::flash_attention::collective {
using namespace cute;

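// Applies interleaved RoPE in global memory: work-item `thread_idx` rotates
// each (even, odd) column pair of its row of srcTensor by the angle whose
// cosine/sine are read from gCos/gSin at the even column; the result is
// written to destTensor, which may alias srcTensor for an in-place update.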
template <typename Tensor,
typename TensorCos, typename TensorSin, typename TensorOut>
CUTLASS_DEVICE void apply_rope_interleaved_gmem(
int thread_idx,
Tensor const &srcTensor,
TensorCos const &gCos,
TensorSin const &gSin, TensorOut &destTensor) {
if (thread_idx < size<0>(srcTensor)) {
  for (int j = 0; j < size<1>(gCos); j += 2) {
    auto real = static_cast<float>(srcTensor[make_coord(thread_idx, j)]);
    auto imag = static_cast<float>(srcTensor[make_coord(thread_idx, j + 1)]);
    auto cos_val = static_cast<float>(gCos[make_coord(thread_idx, j)]);
    auto sin_val = static_cast<float>(gSin[make_coord(thread_idx, j)]);

    auto new_real = real * cos_val - imag * sin_val;
    auto new_imag = real * sin_val + imag * cos_val;

    destTensor[make_coord(thread_idx, j)] = static_cast<typename Tensor::value_type>(new_real);
    destTensor[make_coord(thread_idx, j + 1)] = static_cast<typename Tensor::value_type>(new_imag);
  }
}
syncthreads();
}


} // namespace cutlass::flash_attention::collective
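The per-pair update above is a complex rotation: treating (x[2i], x[2i+1]) as the real and imaginary parts of a complex number z, it computes z * (cos θ + i sin θ). A self-contained scalar check of that identity, independent of CuTe and illustrative only:

#include <cassert>
#include <cmath>

// Scalar model of one (real, imag) pair inside apply_rope_interleaved_gmem.
void rope_pair(float &real, float &imag, float cos_val, float sin_val) {
  float new_real = real * cos_val - imag * sin_val;
  float new_imag = real * sin_val + imag * cos_val;
  real = new_real;
  imag = new_imag;
}

int main() {
  // Rotating by theta and then by -theta must round-trip the input pair.
  float r = 0.5f, i = -1.25f;
  float theta = 0.3f;
  float r0 = r, i0 = i;
  rope_pair(r, i, std::cos(theta), std::sin(theta));
  rope_pair(r, i, std::cos(-theta), std::sin(-theta));
  assert(std::fabs(r - r0) < 1e-6f && std::fabs(i - i0) < 1e-6f);
  return 0;
}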
131 changes: 131 additions & 0 deletions applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
@@ -72,6 +72,8 @@ class FMHAPrefill {
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using traits_load_Q = typename CollectiveMainloop::traits_load_Q;
using traits_load_K = typename CollectiveMainloop::traits_load_K;

using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_;
using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments;
@@ -132,6 +134,7 @@ class FMHAPrefill {
using AccumeShape = decltype(make_shape(Int<Vec>{}, Int<FragsM>{}, get<1>(TileShapePV{})/get<1>(MmaAtomShape()), Int<VSlicer>{}));

static constexpr bool is_var_len = CollectiveMainloop::is_var_len;
static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled;

// Kernel level shared memory storage
struct SharedStorage {
@@ -272,10 +275,24 @@ class FMHAPrefill {
Tensor mK_nk = mK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
Tensor mV_nk = mV_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)

Tensor mCosQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l)
Tensor mSinQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l)
Tensor mCosK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_heads_kv)); // (n, k, l)
Tensor mSinK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_heads_kv)); // (n, k, l)
Tensor mCosQ_mk = mCosQ_mkl(_, _, blk_l_coord); // (m,k)
Tensor mSinQ_mk = mSinQ_mkl(_, _, blk_l_coord); // (m,k)
Tensor mCosK_nk = mCosK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
Tensor mSinK_nk = mSinK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)

auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
auto gV = local_tile(mV_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});

auto gCosQ = local_tile(mCosQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gSinQ = local_tile(mSinQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gCosK = local_tile(mCosK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
auto gSinK = local_tile(mSinK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});

auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, sequence_length_shape, batch_coord);
// we limit the horizontal size to two subgroups; empirical results show that reading two cachelines side by side gives better performance and
// anything after that does not have an effect on performance. // (64 here for float/bfloat when possible, and loop over to cover all the data needed)
@@ -289,6 +306,120 @@ class FMHAPrefill {
auto pKgK = thr_prefetch_K.partition_S(gK);
auto pVgV = thr_prefetch_V.partition_S(gV);

// RoPE coordinate tensor partitions
auto pCosQgCosQ = thr_prefetch_Q.partition_S(gCosQ);
auto pSinQgSinQ = thr_prefetch_Q.partition_S(gSinQ);
auto pCosKgCosK = thr_prefetch_K.partition_S(gCosK);
auto pSinKgSinK = thr_prefetch_K.partition_S(gSinK);

// for (int i = 0; i < size<3>(pQgQ); i++) {
// prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
// }
// for (int j = 0; j < size<4>(pKgK); j++) {
// CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < DispatchPolicy::Stages; i++) {
// prefetch(tiled_prefetch_k, pKgK(_, _, _ , i, j));
// }
// }

// for (int i = 0; i < size<3>(pQgQ); i++) {
// prefetch(tiled_prefetch_q, pCosQgCosQ(_, _, _, i));
// prefetch(tiled_prefetch_q, pSinQgSinQ(_, _, _, i));
// }
// for (int j = 0; j < size<4>(pKgK); j++) {
// CUTLASS_PRAGMA_UNROLL
// for (int i = 0; i < DispatchPolicy::Stages; i++) {
// prefetch(tiled_prefetch_k, pCosKgCosK(_, _, _ , i, j));
// prefetch(tiled_prefetch_k, pSinKgSinK(_, _, _ , i, j));
// }
// }

if constexpr (rope_enabled) {
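// Recover linear work-group / work-item ids from the 3-D launch: block_id
// selects the staggered K-tile start below, and thread_idx picks the row each
// work-item rotates inside apply_rope_interleaved_gmem.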

int block_idx = static_cast<int>(BlockIdxX());
int block_idy = static_cast<int>(BlockIdxY());
int block_idz = static_cast<int>(BlockIdxZ());
int block_dimx = static_cast<int>(BlockDimX());
int block_dimy = static_cast<int>(BlockDimY());
int block_dimz = static_cast<int>(BlockDimZ());
int thread_idx = static_cast<int>(ThreadIdxX());
int thread_idy = static_cast<int>(ThreadIdxY());
int thread_idz = static_cast<int>(ThreadIdxZ());
int grid_dimx = static_cast<int>(GridDimX());
int grid_dimy = static_cast<int>(GridDimY());
int grid_dimz = static_cast<int>(GridDimZ());
int block_id = block_idx + block_idy * grid_dimx + block_idz * grid_dimx * grid_dimy;
int thread_id = block_id * block_dimx * block_dimy * block_dimz + thread_idz * block_dimx * block_dimy + thread_idy * block_dimx + thread_idx;


// calculate the base_ptr and offset for Q, K.
// also calculate the layout for Q, K.
// then apply RoPE on Q, K accordingly
auto [coord_q_x, coord_q_y, coord_q_z] = *gQ.data();
auto [coord_k_x, coord_k_y, coord_k_z] = *gK.data();

auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = params.problem_shape;

int offset_q = seq_len_qo*head_size_qk*coord_q_z + head_size_qk*coord_q_x + coord_q_y; // row major
// int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_y + coord_k_x; // col major
int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_x + coord_k_y; // row major

auto q_traits = static_cast<traits_load_Q const&>(mainloop_params.gmem_tiled_copy_q);
ElementQ* base_ptr_q = (ElementQ*)q_traits.base_ptr;

auto q_traits_cos = static_cast<traits_load_Q const&>(mainloop_params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(mainloop_params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

// auto layout_q = gQ.layout();
constexpr auto static_shape_q = make_shape(size<0>(gQ), size<1>(gQ));
// constexpr auto layout_q = LayoutQ::packed({size<0>(gQ), size<1>(gQ)});
constexpr auto layout_q = make_layout(static_shape_q, LayoutRight{});

auto k_traits = static_cast<traits_load_K const&>(mainloop_params.gmem_tiled_copy_k);
ElementK* base_ptr_k = (ElementK*)k_traits.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(mainloop_params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(mainloop_params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

constexpr auto static_shape_k = make_shape(size<0>(gK), size<1>(gK));
constexpr auto layout_k = make_layout(static_shape_k, LayoutRight{});

for (int i = 0; i < size<2>(gQ); i++) {
  auto tensorQ = make_tensor(make_gmem_ptr(base_ptr_q + offset_q), layout_q);
  auto tensorCosQ = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), layout_q);
  auto tensorSinQ = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), layout_q);
  cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorQ, tensorCosQ, tensorSinQ, tensorQ);
  offset_q += QK_BLK_M * QK_BLK_K;
}
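// Stagger the starting K tile by (block_id % 4): combined with the stride-4
// inner loop below, work-groups in different residue classes rotate disjoint
// subsets of the K tiles.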
offset_k += (block_id % 4) * QK_BLK_N * QK_BLK_K;

for (int k = 0; k < size<3>(gK); k++) {
  auto new_offset_k = offset_k;
  for (int i = 0; i < size<2>(gK); i += 4) {
    auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k + new_offset_k), layout_k);
    auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos + new_offset_k), layout_k);
    auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin + new_offset_k), layout_k);
    cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
    new_offset_k += 4 * QK_BLK_N * QK_BLK_K;
  }
  offset_k += size<2>(gK) * QK_BLK_N * QK_BLK_K;
}
barrier_arrive(2);
barrier_wait(2);
}

for (int i = 0; i < size<3>(pQgQ); i++) {
prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
}
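As a reading aid for the staggered RoPE loop above, here is a scalar model of which K tiles a given work-group rotates. It is illustrative only: `num_k_slabs` stands in for `size<3>(gK)`, `tiles_per_slab` for `size<2>(gK)`, and `tiles_per_slab` is assumed to be a multiple of 4.

#include <cstdio>

// Mirrors the kernel's offset arithmetic: each work-group starts at residue
// (block_id % 4) within a slab and advances four tiles at a time, so across
// the four residue classes every tile of a slab is visited.
void k_tiles_for_block(int block_id, int num_k_slabs, int tiles_per_slab) {
  int start = block_id % 4;
  for (int k = 0; k < num_k_slabs; ++k) {
    for (int i = start; i < tiles_per_slab; i += 4) {
      std::printf("block %d rotates K tile (slab %d, tile %d)\n", block_id, k, i);
    }
  }
}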