@@ -58,13 +58,20 @@ cutlass_test_unit_add_executable(
xe_flash_decode_fp16_fp32_fp32_h192_1024_nonpaged.cpp
)

cutlass_test_unit_add_executable(
cutlass_test_unit_flash_attention_decode_models_xe
xe_flash_decode_models_fp16_nonpaged.cpp
xe_flash_decode_models_bf16_nonpaged.cpp
)

add_custom_target(
cutlass_test_unit_flash_attention_decode
DEPENDS
cutlass_test_unit_flash_attention_decode_h64_xe
cutlass_test_unit_flash_attention_decode_h96_xe
cutlass_test_unit_flash_attention_decode_h128_xe
cutlass_test_unit_flash_attention_decode_h192_xe
cutlass_test_unit_flash_attention_decode_models_xe
)

add_custom_target(
@@ -74,4 +81,5 @@ add_custom_target(
test_unit_flash_attention_decode_h96_xe
test_unit_flash_attention_decode_h128_xe
test_unit_flash_attention_decode_h192_xe
test_unit_flash_attention_decode_models_xe
)
@@ -798,14 +798,56 @@ struct Testbed3x {
};

template <typename FlashDecode>
bool TestFlashDecodeAll(int head_size) {
bool TestFlashDecodeAll(int head_size, std::string config="default") {
Testbed3x<FlashDecode> testbed;

std::vector<int> problem_size_batch{16};
std::vector<int> problem_size_num_heads{32};
std::vector<int> problem_size_seq_len{1024};
std::vector<int> problem_size_seq_len_cache{0, 1024};
std::vector<int> cache_page_size{64, 128};
std::vector<int> problem_size_batch;
std::vector<int> problem_size_num_heads;
std::vector<int> problem_size_seq_len;
std::vector<int> problem_size_seq_len_cache;
std::vector<int> cache_page_size;
if(config == "whisper_v3_large"){
problem_size_batch = {1, 2, 4};
problem_size_num_heads = {20};
problem_size_seq_len = {512, 1024};
problem_size_seq_len_cache = {0, 1024};
cache_page_size = {64, 128};
}
else if(config == "llama3_8b"){
problem_size_batch = {1, 2, 4};
problem_size_num_heads = {32};
problem_size_seq_len = {512, 1024};
problem_size_seq_len_cache = {0, 1024};
cache_page_size = {64, 128};
}
else if(config == "llama3_405b"){
problem_size_batch = {1, 2};
problem_size_num_heads = {128};
problem_size_seq_len = {512, 1024};
problem_size_seq_len_cache = {0, 1024};
cache_page_size = {64, 128};
}
else if(config == "qwen2_5_72b"){
problem_size_batch = {1, 2};
problem_size_num_heads = {64};
problem_size_seq_len = {512, 1024};
problem_size_seq_len_cache = {0, 1024};
cache_page_size = {64, 128};
}
else if(config == "deepseek_r1"){
problem_size_batch = {1, 2};
problem_size_num_heads = {64};
problem_size_seq_len = {512, 1024};
problem_size_seq_len_cache = {0, 1024};
cache_page_size = {64, 128};
}
else{
problem_size_batch = {16};
problem_size_num_heads = {32};
problem_size_seq_len = {1024};
problem_size_seq_len_cache = {0, 1024};
cache_page_size = {64, 128};
}
std::vector<float> problem_size_softmax_scale{ 1.f / sqrt(static_cast<float>(head_size)) };
bool passed = true;

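For reference, a minimal sketch of how the per-model sweeps introduced above could also be expressed as a single lookup table instead of an if/else chain. It only restates the values already present in the diff; the names ProblemSweep and sweep_for are hypothetical, not part of this patch, and the sketch assumes <map>, <string>, and <vector> are available in the testbed header.

// Hypothetical table-driven equivalent of the if/else dispatch above.
struct ProblemSweep {
  std::vector<int> batch, num_heads, seq_len, seq_len_cache, page_size;
};

inline const ProblemSweep& sweep_for(const std::string& config) {
  // Values mirror the branches in TestFlashDecodeAll.
  static const std::map<std::string, ProblemSweep> sweeps = {
    {"whisper_v3_large", {{1, 2, 4}, {20},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"llama3_8b",        {{1, 2, 4}, {32},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"llama3_405b",      {{1, 2},    {128}, {512, 1024}, {0, 1024}, {64, 128}}},
    {"qwen2_5_72b",      {{1, 2},    {64},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"deepseek_r1",      {{1, 2},    {64},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"default",          {{16},      {32},  {1024},      {0, 1024}, {64, 128}}},
  };
  auto it = sweeps.find(config);
  return it != sweeps.end() ? it->second : sweeps.at("default");
}

TestFlashDecodeAll would then copy the selected sweep into its local problem_size_* vectors; behavior is identical to the branch chain in the diff.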
@@ -29,6 +29,7 @@
*
**************************************************************************************************/


/*! \file
\brief Tests for Xe flash attention decode bf16
*/
@@ -0,0 +1,206 @@
/****************************************************************************
* Copyright (C) 2025 Intel Corporation. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
***************************************************************************/

#include "flash_decode_testbed_3x.hpp"

namespace cutlass {

using MMAOperationBF16 = test::flash_attention::MMAOperationBF16;
using GmemTiledCopyQ = test::flash_attention::GmemTiledCopyQU16;
using GmemTiledCopyK = test::flash_attention::GmemTiledCopyKU16;
using GmemTiledCopyV = test::flash_attention::GmemTiledCopyVU16;
using GmemTiledCopyStore = test::flash_attention::GmemTiledCopyStoreU32;

// 20 tests: 5 models × 4 head sizes, KV512, causal, varlen
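// NOTE (editorial sketch, not part of this patch): every TEST below follows the
// same pattern and differs only in the Shape_h* alias, the head size, and the
// model config string passed to TestFlashDecodeAll. A hypothetical helper that
// captures the shared instantiation (the name run_model_decode_test and the
// ShapeH parameter are assumptions, not existing APIs):
//
// template <class ShapeH, int HeadSize>
// bool run_model_decode_test(const std::string& config) {
//   using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
//       typename ShapeH::ShapeQK, typename ShapeH::ShapePV,
//       typename ShapeH::ShapeOutput, typename ShapeH::SubgroupLayout,
//       MMAOperationBF16, true, true,
//       GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
//   return test::flash_attention::TestFlashDecodeAll<Kernel>(HeadSize, config);
// }
//
// The first test below would then be equivalent to:
//   EXPECT_TRUE((run_model_decode_test<test::flash_attention::Shape_h64<512, 8>, 64>("whisper_v3_large")));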

// h64 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_whisper) {
using Shape_h = test::flash_attention::Shape_h64<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_llama8b) {
using Shape_h = test::flash_attention::Shape_h64<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_llama405b) {
using Shape_h = test::flash_attention::Shape_h64<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_qwen25) {
using Shape_h = test::flash_attention::Shape_h64<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_deepseek) {
using Shape_h = test::flash_attention::Shape_h64<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "deepseek_r1"));
}

// h96 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_whisper) {
using Shape_h = test::flash_attention::Shape_h96<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_llama8b) {
using Shape_h = test::flash_attention::Shape_h96<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_llama405b) {
using Shape_h = test::flash_attention::Shape_h96<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_qwen25) {
using Shape_h = test::flash_attention::Shape_h96<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_deepseek) {
using Shape_h = test::flash_attention::Shape_h96<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "deepseek_r1"));
}

// h128 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_whisper) {
using Shape_h = test::flash_attention::Shape_h128<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_llama8b) {
using Shape_h = test::flash_attention::Shape_h128<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_llama405b) {
using Shape_h = test::flash_attention::Shape_h128<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_qwen25) {
using Shape_h = test::flash_attention::Shape_h128<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_deepseek) {
using Shape_h = test::flash_attention::Shape_h128<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "deepseek_r1"));
}

// h192 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_whisper) {
using Shape_h = test::flash_attention::Shape_h192<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_llama8b) {
using Shape_h = test::flash_attention::Shape_h192<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_llama405b) {
using Shape_h = test::flash_attention::Shape_h192<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_qwen25) {
using Shape_h = test::flash_attention::Shape_h192<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_deepseek) {
using Shape_h = test::flash_attention::Shape_h192<512, 8>;
using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
MMAOperationBF16, true, true,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "deepseek_r1"));
}

} // namespace cutlass