@@ -60,20 +60,20 @@ CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {
/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOp_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_,
class GmemTiledCopyK_, class GmemTiledCopyV_, bool CausalMask_, bool PagedKV_>
class ElementV_, class StrideV_, class ElementSink_, class MMAOp_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_,
class GmemTiledCopyK_, class GmemTiledCopyV_, bool CausalMask_, bool PagedKV_, bool HasSink_>
struct FlashDecodeMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <int Stages, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOp_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_,
class GmemTiledCopyQ_, class GmemTiledCopyK_, class GmemTiledCopyV_, bool CausalMask_, bool PagedKV_>
class ElementV_, class StrideV_, class ElementSink_, class MMAOp_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_,
class GmemTiledCopyQ_, class GmemTiledCopyK_, class GmemTiledCopyV_, bool CausalMask_, bool PagedKV_, bool HasSink_>
struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_,
StrideV_, MMAOp_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_,
GmemTiledCopyV_, CausalMask_, PagedKV_> {
StrideV_, ElementSink_, MMAOp_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_,
GmemTiledCopyV_, CausalMask_, PagedKV_, HasSink_> {
//
// Type Aliases
//
@@ -88,13 +88,15 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
using StrideK = StrideK_;
using ElementV = ElementV_;
using StrideV = StrideV_;
using ElementSink = ElementSink_;
using GmemTiledCopyQ = GmemTiledCopyQ_;
using GmemTiledCopyK = GmemTiledCopyK_;
using GmemTiledCopyV = GmemTiledCopyV_;
using ArchTag = typename DispatchPolicy::ArchTag;

static constexpr bool CausalMask = CausalMask_;
static constexpr bool PagedKV = PagedKV_;
static constexpr bool HasSink = HasSink_;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using MmaAtom = MMA_Atom<MMAOp_>;
@@ -174,6 +176,8 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
int const* ptr_page_table;
int page_size;
int const* num_pages_per_seq;
// attention sink
ElementSink const* ptr_Sink;
};

struct Params {
@@ -186,6 +190,8 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
int const* ptr_page_table;
int page_size;
int const* num_pages_per_seq;
// attention sink
ElementSink const* ptr_Sink;
};

//
@@ -212,7 +218,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};

return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, args.ptr_page_table, args.page_size, args.num_pages_per_seq};
return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, args.ptr_page_table, args.page_size, args.num_pages_per_seq, args.ptr_Sink};
}

template <class FragAccum, class TensorQ, class TensorK, class FragSrc>
@@ -430,7 +436,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};

return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, params.ptr_page_table, params.page_size, params.num_pages_per_seq};
return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, params.ptr_page_table, params.page_size, params.num_pages_per_seq, params.ptr_Sink};
}
}
};
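
Note for reviewers unfamiliar with the feature being threaded through this mainloop: an attention sink is a learned per-head logit that joins the softmax normalization but contributes no value vector, so it damps the attention weights of all real keys. The new ptr_Sink field simply carries that per-head logit from Arguments into Params. The sketch below is a minimal, self-contained host-side illustration of the idea; it is not code from this PR and every identifier in it is hypothetical.

// Illustrative sketch only (not part of this PR): single-query decode attention
// for one head, with an optional per-head sink logit folded into the softmax
// denominator. All identifiers are hypothetical.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> decode_attention_with_sink(
    const std::vector<float>& q,               // [d]
    const std::vector<std::vector<float>>& K,  // [seq_len][d]
    const std::vector<std::vector<float>>& V,  // [seq_len][d]
    float scale, bool has_sink, float sink_logit) {
  const size_t seq_len = K.size(), d = q.size();
  std::vector<float> s(seq_len);
  float m = -INFINITY;
  for (size_t j = 0; j < seq_len; ++j) {       // s_j = scale * (q . k_j)
    float dot = 0.f;
    for (size_t c = 0; c < d; ++c) dot += q[c] * K[j][c];
    s[j] = scale * dot;
    m = std::max(m, s[j]);
  }
  if (has_sink) m = std::max(m, sink_logit);   // the sink joins the running max...
  float denom = has_sink ? std::exp(sink_logit - m) : 0.f;  // ...and the denominator,
  for (size_t j = 0; j < seq_len; ++j) denom += std::exp(s[j] - m);
  std::vector<float> out(d, 0.f);              // ...but contributes no value vector.
  for (size_t j = 0; j < seq_len; ++j) {
    const float p = std::exp(s[j] - m) / denom;
    for (size_t c = 0; c < d; ++c) out[c] += p * V[j][c];
  }
  return out;
}
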
@@ -50,13 +50,13 @@ namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <bool CausalMask_, class DispatchPolicy, class... Args> class FlashDecodeSoftmaxEpilogue {
template <bool CausalMask_, bool HasSink_, class DispatchPolicy, class... Args> class FlashDecodeSoftmaxEpilogue {
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};


template <bool CausalMask_, class Element_>
class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_> {
template <bool CausalMask_, bool HasSink_, class Element_>
class FlashDecodeSoftmaxEpilogue<CausalMask_, HasSink_, epilogue::IntelXeXMX16, Element_> {
public:

//
@@ -66,6 +66,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
using Element = Element_;

static constexpr bool CausalMask = CausalMask_;
static constexpr bool HasSink = HasSink_;

using GmemTiledCopyOut = void;

@@ -149,7 +150,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
}

template <int Num_SGs, class FragAcc, class FragSum, class STensorMax, class FragOut>
CUTLASS_DEVICE void operator()(bool is_first, FragAcc &frag_s, Element& max_val, FragSum& sum,
CUTLASS_DEVICE void operator()(bool is_first, FragAcc &frag_s, Element& max_val, FragSum& sum, FragSum& sink_token,
STensorMax& shmem_tensor_max, FragOut& out) {
using FragAccLayout = typename FragAcc::layout_type;
using FragOutLayout = typename FragOut::layout_type;
@@ -162,6 +163,13 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>

reduce_max<Num_SGs,FragsNS>(frag_s, shmem_tensor_max, max_val);

if constexpr (HasSink) {
if (syclcompat::get_nd_item<3>().get_local_linear_id() == 0) {
Element max_scale{max_val * params.scale};
sum += sycl::native::exp2((sink_token - max_scale));
}
}

if (!is_first) {
auto sg = compat::get_nd_item<1>().get_sub_group();
const int sg_group_id = sg.get_group_id()[0];
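
The sink contribution above is added by a single work-item through sycl::native::exp2, i.e. the epilogue works in base 2. As a generic reminder (this is not a claim about how this kernel distributes log2(e) between params.scale and its accumulated logits): exp(x) equals exp2(x * log2(e)), so once the sink logit and the row maximum are expressed in the same already-scaled units, the sink is just one more exponential term in the denominator.

// Scalar sketch of the base-2 denominator update with a sink term; a generic
// illustration, not a line-for-line model of FlashDecodeSoftmaxEpilogue.
#include <cmath>

float add_sink_to_denominator(float sum, float sink_logit, float row_max) {
  const float log2e = 1.4426950408889634f;     // exp(x) == exp2(x * log2e)
  return sum + std::exp2((sink_logit - row_max) * log2e);
}
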
13 changes: 11 additions & 2 deletions applications/flash_attention_v2/kernel/xe_flash_attn_decode.hpp
@@ -69,6 +69,7 @@ class FMHADecode {
using StrideK = typename CollectiveMainloop::StrideK;
using ElementV = typename CollectiveMainloop::ElementV;
using StrideV = typename CollectiveMainloop::StrideV;
using ElementSink = typename CollectiveMainloop::ElementSink;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using MainloopArguments = typename CollectiveMainloop::Arguments;
@@ -101,6 +102,7 @@ class FMHADecode {

static constexpr bool CausalMask = CollectiveMainloop::CausalMask;
static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
static constexpr bool HasSink = CollectiveMainloop::HasSink;
static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16
@@ -340,7 +342,14 @@ class FMHADecode {
CollectiveMainloop collective_mma;

ElementAccumulator max_reg = ElementAccumulator{-INFINITY};
ElementAccumulator sink_token = ElementAccumulator{ 0 };
auto sum_reg = ElementAccumulator{0};
if constexpr (HasSink) {
if (syclcompat::get_nd_item<3>().get_local_linear_id() == 0) {
max_reg = static_cast<ElementAccumulator>(mainloop_params.ptr_Sink[num_heads_coord]);
sink_token = max_reg;
}
}
Tensor out_reg = make_tensor<ElementAccumulator>(AccumShape{});
clear(out_reg);

@@ -391,7 +400,7 @@ class FMHADecode {
}

CollectiveSoftmaxEpilogue softmax(params.softmax);
softmax.template operator()<Num_SGs>(split == 0, tSr, max_reg, sum_reg, shmem_max_tensor, out_reg);
softmax.template operator()<Num_SGs>(split == 0, tSr, max_reg, sum_reg, sink_token, shmem_max_tensor, out_reg);

collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV, out_reg, mainloop_params, is_KV_cache, curr_kv_tile_idx);

@@ -455,7 +464,7 @@ class FMHADecode {
}

CollectiveSoftmaxEpilogue softmax(params.softmax);
softmax.template operator()<Num_SGs>((kv_splits - 1) == 0, tSr, max_reg, sum_reg, shmem_max_tensor, out_reg);
softmax.template operator()<Num_SGs>((kv_splits - 1) == 0, tSr, max_reg, sum_reg, sink_token, shmem_max_tensor, out_reg);

collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV, out_reg, mainloop_params, false, curr_kv_tile_idx);

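
The decode kernel walks the KV sequence in splits (the is_first argument passed to the softmax epilogue above is split == 0 for the first split), and the partial results are combined afterwards; with a sink, the extra term must end up in the normalization exactly once. The sketch below shows that constraint for a generic split-softmax combine; it is an illustration only and does not reproduce this kernel's reduction.

// Generic illustration (identifiers hypothetical): merging per-split softmax
// statistics when an attention sink is present; the sink enters exactly once.
#include <algorithm>
#include <cmath>
#include <vector>

struct SplitStats { float max; float sum; };   // per-split running max and sum of exponentials

SplitStats merge_splits_with_sink(const std::vector<SplitStats>& splits, float sink_logit) {
  float global_max = sink_logit;               // the sink joins the global max once
  for (const auto& s : splits) global_max = std::max(global_max, s.max);
  float global_sum = std::exp(sink_logit - global_max);   // ...and the denominator once
  for (const auto& s : splits)
    global_sum += s.sum * std::exp(s.max - global_max);   // rescale each partial sum
  return {global_max, global_sum};
}
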
10 changes: 8 additions & 2 deletions examples/06_bmg_flash_attention/06_bmg_decode_attention.cpp
@@ -69,8 +69,14 @@ int run_decode(Options const& options) {

#endif

return options.is_causal ? FMHAConfig<true, PagedKV, ShapeQK, ShapePV, ShapeOutput, SubgroupLayout, Varlen>::run(options)
: FMHAConfig<false, PagedKV, ShapeQK, ShapePV, ShapeOutput, SubgroupLayout, Varlen>::run(options);
if (options.is_causal) {
return options.use_sink_attn ? FMHAConfig<true, PagedKV, ShapeQK, ShapePV, ShapeOutput, SubgroupLayout, Varlen, true>::run(options)
: FMHAConfig<true, PagedKV, ShapeQK, ShapePV, ShapeOutput, SubgroupLayout, Varlen, false>::run(options);
}
else {
return options.use_sink_attn ? FMHAConfig<false, PagedKV, ShapeQK, ShapePV, ShapeOutput, SubgroupLayout, Varlen, true>::run(options)
: FMHAConfig<false, PagedKV, ShapeQK, ShapePV, ShapeOutput, SubgroupLayout, Varlen, false>::run(options);
}
}

int main(int argc, const char **argv) {
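
The example now selects among four instantiations from the two runtime flags is_causal and use_sink_attn (the new path is enabled by passing --use_sink_attn to the example binary). The nested ternaries above are the straightforward way to do this; a hypothetical alternative that keeps the fan-out flat is a small boolean dispatcher such as the sketch below, which is not part of this PR.

// Hypothetical helper, not part of this PR: lift two runtime bools into
// compile-time constants so each flag combination is spelled out only once.
#include <type_traits>

template <class F>
auto dispatch_two_bools(bool a, bool b, F&& f) {
  if ( a &&  b) return f(std::true_type{},  std::true_type{});
  if ( a && !b) return f(std::true_type{},  std::false_type{});
  if (!a &&  b) return f(std::false_type{}, std::true_type{});
  return          f(std::false_type{}, std::false_type{});
}

// Usage sketch, assuming the aliases defined earlier in run_decode():
//   return dispatch_two_bools(options.is_causal, options.use_sink_attn,
//       [&](auto causal, auto sink) {
//         return FMHAConfig<causal.value, PagedKV, ShapeQK, ShapePV, ShapeOutput,
//                           SubgroupLayout, Varlen, sink.value>::run(options);
//       });
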
45 changes: 38 additions & 7 deletions examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp
@@ -60,6 +60,7 @@ struct Options {
bool help;
bool error;
bool is_causal;
bool use_sink_attn;
bool varlen = false;
bool use_paged_kv = false;
std::string scheduler;
Expand All @@ -68,7 +69,7 @@ struct Options {
float softmax_scale;

Options()
: help(false), error(false), is_causal(false), varlen(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(1), head_size_qk(128),
: help(false), error(false), is_causal(false), use_sink_attn(false), varlen(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(1), head_size_qk(128),
seq_len_kv(512), seq_len_kv_cache(0), page_size(128), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {}

// Parses the command line
@@ -84,6 +85,10 @@ struct Options {
is_causal = true;
}

if (cmd.check_cmd_line_flag("use_sink_attn")) {
use_sink_attn = true;
}

if (cmd.check_cmd_line_flag("varlen")) {
varlen = true;
}
@@ -120,6 +125,7 @@ struct Options {
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --is_causal Apply Causal Mask to the output of first Matmul\n"
<< " --use_sink_attn Apply Attention Sink\n"
<< " --varlen Enable variable sequence length\n"
<< " --scheduler Only Individual Scheduler supported\n"
<< " --batch=<int> Sets the Batch Size of the Multi-Head Self Attention module\n"
@@ -155,6 +161,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
using ElementQ = typename FMHAKernel::ElementQ;
using ElementK = typename FMHAKernel::ElementK;
using ElementV = typename FMHAKernel::ElementV;
using ElementSink = typename FMHAKernel::ElementSink;
using ElementAcc = typename FMHAKernel::ElementAccumulator;

using CollectiveEpilogue = typename FMHAKernel::CollectiveEpilogue;
@@ -181,6 +188,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
cutlass::DeviceAllocation<ElementQ> block_Q;
cutlass::DeviceAllocation<ElementK> block_K;
cutlass::DeviceAllocation<ElementV> block_V;
cutlass::DeviceAllocation<ElementSink> block_Sink;
cutlass::DeviceAllocation<ElementK> block_K_cache;
cutlass::DeviceAllocation<ElementV> block_V_cache;
cutlass::DeviceAllocation<ElementOutput> block_O;
@@ -216,7 +224,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
//
// Methods
//
bool verify(ProblemShapeType problem_size, bool is_causal, bool use_kv_cache) {
bool verify(ProblemShapeType problem_size, bool is_causal, bool use_kv_cache, bool sink_attn) {

if constexpr (isVarLen) {
int max_seq_len_q = static_cast<int>(get<3>(problem_size));
@@ -240,6 +248,13 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
int offset_k_cache = 0;
int offset_v_cache = 0;
int offset_o = 0;
std::vector<ElementSink> host_Sink;

if (sink_attn) {
host_Sink.resize(block_Sink.size());
syclcompat::memcpy<ElementSink>(host_Sink.data(), block_Sink.get(), host_Sink.size());
syclcompat::wait();
}

int q_group_size = num_heads_q / num_heads_kv;
// loop over the batch dimension to compute the output
@@ -352,6 +367,11 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
if (max_vec[max_idx] < host_S[idx])
max_vec[max_idx] = host_S[idx];
}
if (sink_attn) {
ElementAccumulator sink_val = static_cast<ElementAccumulator>(host_Sink[h]);
if (max_vec[max_idx] < sink_val)
max_vec[max_idx] = sink_val;
}
}

// compute exp of S
Expand All @@ -372,6 +392,12 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
sum_vec[sum_idx] += host_S[idx];
}

if (sink_attn) {
ElementAccumulator sink_val = static_cast<ElementAccumulator>(host_Sink[h]);
auto exp_sink = expf((sink_val - max_vec[row]) / sqrt(static_cast<ElementAccumulator>((head_size_qk))));
sum_vec[sum_idx] += exp_sink;
}

// scale each row with the sum to compute softmax
idx = row * seq_len_kv_total;
sum_idx = row;
@@ -553,6 +579,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk);
block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk);
block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo);
block_Sink.reset(num_heads_kv);
block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk);
block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo);
block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
@@ -592,6 +619,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
initialize_block(block_Q, seed + 2021);
initialize_block(block_K, seed + 2022);
initialize_block(block_V, seed + 2023);
initialize_block(block_Sink, seed + 2021);
initialize_block(block_K_cache, seed + 2024);
initialize_block(block_V_cache, seed + 2025);

@@ -664,7 +692,8 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
block_V_cache.get(), stride_V_cache,
options.use_paged_kv ? paged_kv_cache.page_table.get() : nullptr,
options.use_paged_kv ? paged_kv_cache.page_size : 0,
options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr},
options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr,
block_Sink.get()},
{options.softmax_scale},
{block_O.get(), stride_O},
hw_info};
@@ -691,7 +720,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {

// Verify that the result is correct
bool use_kv_cache = options.seq_len_kv_cache > 0;
bool passed = verify(problem_size, options.is_causal, use_kv_cache);
bool passed = verify(problem_size, options.is_causal, use_kv_cache, options.use_sink_attn);
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;

if (!passed) {
@@ -742,9 +771,11 @@ template <bool Causal,
typename TileShapeOutput,
typename SubgroupLayout,
bool isVarLen,
bool hasSink,
int PipelineStages = 2,
typename ElementInputQ = bfloat16_t,
typename ElementInputKV = bfloat16_t,
typename ElementInputSink = bfloat16_t,
typename MMAOperation = XE_1x16x16_F32BF16BF16F32_TT,
typename GmemTiledCopyQ = XE_2D_U16x1x16_LD_N,
typename GmemTiledCopyK = XE_2D_U16x16x16_LD_T, // _T designates a transposed block load operation
@@ -768,7 +799,7 @@ template <bool Causal,
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashDecodeEpilogue<
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>,
ElementOutput, GmemTiledCopyStore>;
using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashDecodeSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;
using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashDecodeSoftmaxEpilogue<Causal, hasSink, EpilogueDispatchPolicy, ElementAccumulator>;

using ProblemShapeRegular = cute::tuple<int, int, int, int, int, int, int, int>;
using namespace cutlass::fmha::collective;
@@ -778,9 +809,9 @@ template <bool Causal,
// Mainloop
using CollectiveMainloop = cutlass::flash_attention::collective::FlashDecodeMma<
GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t<LayoutQ>, ElementInputKV,
cutlass::gemm::TagToStrideB_t<LayoutK>, ElementInputKV, cutlass::gemm::TagToStrideB_t<LayoutV>, MMAOperation,
cutlass::gemm::TagToStrideB_t<LayoutK>, ElementInputKV, cutlass::gemm::TagToStrideB_t<LayoutV>, ElementInputSink, MMAOperation,
TileShapeQK, TileShapePV, SubgroupLayout, GmemTiledCopyQ/* Q */, GmemTiledCopyK/* K */,
GmemTiledCopyV/* V */, Causal, PagedKV>;
GmemTiledCopyV/* V */, Causal, PagedKV, hasSink>;

using FMHAKernel = cutlass::flash_attention::kernel::FMHADecode<ProblemShapeType, CollectiveMainloop,
CollectiveSoftmaxEpilogue, CollectiveEpilogue, Scheduler>;