From 6c291113d162dcf70e395be6f4c3c02644078934 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Mon, 5 Aug 2024 23:22:34 +0800 Subject: [PATCH 1/4] XeTLA use mask with zero-passthrough --- include/subgroup/tile/impl/load_xe.hpp | 38 ++++++++++++++++---------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index ace2ed1a4..f9c027f09 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -457,6 +457,7 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t num_channel = payload_t::num_channel; constexpr uint32_t load_elems = num_channel * payload_t::vector_size; constexpr uint32_t pack_factor = payload_t::pack_factor; + const xetla_vector reg_zeros(0); auto channel_offset = payload.channel_offset + payload.base_offset; #pragma unroll @@ -494,28 +495,35 @@ tile_load(tile_t& tile, payload_t& payload) { ? (xetla_vector_gen(offset_ch_dim, 1) < size_ch_dim) : 1; + reg_tmp = xetla_load_global< + load_dtype, + load_elems, + payload_t::vector_size, + L1, + L2>( + payload.base_ptr, + channel_offset + address_offset, + mask, + reg_zeros); + } else { + reg_tmp = xetla_load_global< + load_dtype, + load_elems, + payload_t::vector_size, + L1, + L2>(payload.base_ptr, channel_offset + address_offset, mask); } - reg_tmp = xetla_load_global< - load_dtype, - load_elems, - payload_t::vector_size, - L1, - L2>(payload.base_ptr, channel_offset + address_offset, mask); if constexpr ( payload_t::vector_size > 1 && payload_t::num_channel > 1) { xetla_vector reg_tmp_trans; #pragma unroll for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { - if ((bool)mask[iii]) // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::vector_size) = - reg_tmp.xetla_select< - payload_t::vector_size, - payload_t::num_channel>(iii); - else // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::vector_size) = 0; + reg_tmp_trans.xetla_select( + iii * payload_t::vector_size) = + reg_tmp.xetla_select< + payload_t::vector_size, + payload_t::num_channel>(iii); } reg_sub .xetla_select( From f3b453d021ed1bcbd2d9077f226143394c4eb215 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Tue, 6 Aug 2024 01:49:23 +0000 Subject: [PATCH 2/4] reformat --- include/subgroup/tile/impl/payload_xe.hpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index b39b132a1..2316ef23f 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1655,12 +1655,11 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t< - (!arch_has_2d_load_store) && - (((block_size_y_ != 1 || tile_size_y_ != 1) && - mem_layout_ == mem_layout::row_major) || - ((block_size_x_ != 1 || tile_size_x_ != 1) && - mem_layout_ == mem_layout::col_major))>> { + std::enable_if_t<(!arch_has_2d_load_store)&&( + ((block_size_y_ != 1 || tile_size_y_ != 1) && + mem_layout_ == mem_layout::row_major) || + ((block_size_x_ != 1 || tile_size_x_ != 1) && + mem_layout_ == mem_layout::col_major))>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -1902,10 +1901,9 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t< - (arch_has_2d_load_store) && - (((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || - ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { + std::enable_if_t<(arch_has_2d_load_store)&&( + ((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || + ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; From d7bbf6b874d63dd61e6401d9e779a3eb33331fbc Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Tue, 6 Aug 2024 01:51:51 +0000 Subject: [PATCH 3/4] sync fmha --- tests/integration/fmha/fmha_forward.hpp | 69 +++++++++++-------------- tests/integration/fmha/fmha_utils.h | 23 +++++---- 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/tests/integration/fmha/fmha_forward.hpp b/tests/integration/fmha/fmha_forward.hpp index 03a2bc99a..8efe1f0c5 100644 --- a/tests/integration/fmha/fmha_forward.hpp +++ b/tests/integration/fmha/fmha_forward.hpp @@ -129,12 +129,12 @@ class fmha_forward_t { using comp_attr = group::compute_attr_t; using knobs = group::perf_tuning_knob_t; using compute_policy_BrBc = std::conditional_t< - (arch_tag >= gpu_arch::XeHpg), + (arch_has_xmx), group::compute_policy_default_xmx, group::compute_policy_default_fpu>; // TODO: add k slicing using compute_policy_BrBm = std::conditional_t< - (arch_tag >= gpu_arch::XeHpg), + (arch_has_xmx), group::compute_policy_default_xmx, group::compute_policy_default_fpu>; // ---------------- // Tile shape and Threads // ---------------- // @@ -688,7 +688,7 @@ class fmha_forward_t { uint8_t, mem_desc_Dp_Mask_t::layout, mem_desc_Dp_Mask_t::space>>, - gpu_arch::XeHpc>; + arch_tag>; load_payload_mask_t load_payload_mask(ctx.mem_desc_Dpij); subgroup::tile_load(mask_in, load_payload_mask); matAccSij.reg = matAccSij.reg * mask_in.reg * args.dp_scale; @@ -771,7 +771,7 @@ class fmha_forward_t { uint32_t height = args.uB * args.uN * args.uF; uint32_t offset_height = b * args.uN * args.uF + f * args.uN + n; - if constexpr (arch_tag != gpu_arch::XeHpc) { + if constexpr (!arch_has_2d_load_store) { // offset for curr work item const uint32_t O_offset = offset_height * args.uH + h; const auto ld_c = args.uN * args.uH; @@ -798,30 +798,30 @@ class fmha_forward_t { matOi_store_t matOi_store(mem_desc_Oi); subgroup::tile_store( matOi, matOi_store); - return; - } - - xetla_fill_tdesc( - transpose_tdecs.xetla_format(), - args.O_ptr, - args.uH, - height, - args.uH, - h, - offset_height); - - for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) { - // load data from matAccOi - auto v_acc = matAccOi.reg.xetla_select(i * kSgHm); - v_out = xetla_cvt(v_acc); - - xetla_tstore_global< - scalar_t, - kSgHm, - cache_hint::write_back, - cache_hint::write_back>(transpose_tdecs, v_out); - xetla_update_tdesc_offsety( - transpose_tdecs.xetla_format(), args.uN); + } else { + xetla_fill_tdesc( + transpose_tdecs.xetla_format(), + args.O_ptr, + args.uH, + height, + args.uH, + h, + offset_height); + + for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) { + // load data from matAccOi + auto v_acc = matAccOi.reg.xetla_select(i * kSgHm); + v_out = xetla_cvt(v_acc); + + xetla_tstore_global< + scalar_t, + kSgHm, + cache_hint::write_back, + cache_hint::write_back, + arch_tag>(transpose_tdecs, v_out); + xetla_update_tdesc_offsety( + transpose_tdecs.xetla_format(), args.uN); + } } } // ====================== // preload_Qi // ====================== // @@ -888,16 +888,9 @@ class fmha_forward_t { /// @return The size of local memory required. inline static constexpr uint32_t get_slm_size() { constexpr uint32_t size = slm_size_Qi + slm_size_Pij + slm_size_softmax; - if constexpr (arch_tag == gpu_arch::XeHpc) { - static_assert( - size <= (128 * 1024), - "The local memory size should be less than 128KB!"); - - } else { - static_assert( - size <= (64 * 1024), - "The local memory size should be less than 64KB!"); - } + static_assert( + size <= (arch_attr_t::local_mem_size), + "The local memory size should be less than arch total local memory size"); return size; }; diff --git a/tests/integration/fmha/fmha_utils.h b/tests/integration/fmha/fmha_utils.h index 25e19814c..5e3d491a0 100644 --- a/tests/integration/fmha/fmha_utils.h +++ b/tests/integration/fmha/fmha_utils.h @@ -134,7 +134,7 @@ template < typename mat_t, uint32_t kNumSg, reduce_op reduce_kind, - gpu_arch arch_tag = gpu_arch::XeHpc> + gpu_arch arch_tag> struct group_row_reduce_t { using T = typename mat_t::dtype; static constexpr uint32_t kNum = mat_t::tile_desc::tile_size_y; @@ -215,7 +215,7 @@ enum class add_type : uint8_t { /// @tparam arch_tag Is the hardware architecture tag. template < typename dtype_bias_, - gpu_arch arch_tag = gpu_arch::XeHpc, + gpu_arch arch_tag, add_type add_tag = add_type::single_line> struct bias_add_op_t {}; @@ -324,8 +324,8 @@ struct bias_add_op_t { using base_t = typename mem_desc_bias_t::base_t; struct arguments_t { - shape_t shape; base_t base; + shape_t shape; inline arguments_t() = default; inline arguments_t(base_t base_, shape_t shape_) : base(base_), shape(shape_) {} @@ -351,11 +351,10 @@ struct bias_add_op_t { uint32_t offset = (pos_y + pos_x * args.shape.stride) * sizeof(dtype_bias); auto bias_data_vector = xetla_load_global< dtype_bias, + 16, 1, - data_size::default_size, - cache_hint::cached, cache_hint::cached, - 16>(ptr, offset); + cache_hint::cached>(ptr, offset); dtype_acc bias_data = xetla_cvt(bias_data_vector)[0]; @@ -418,15 +417,19 @@ template < typename mem_desc_c_t_> class epilogue_transp_t {}; -template +template < + typename tile_op_t_, + typename tile_shape_, + typename mem_desc_c_t_, + gpu_arch arch_tag_> class epilogue_transp_t< - epilogue_policy_tile_op, + epilogue_policy_tile_op, tile_shape_, mem_desc_c_t_> { public: using tile_shape = tile_shape_; using mem_desc_c_t = mem_desc_c_t_; - static constexpr gpu_arch arch_tag = gpu_arch::XeHpc; + static constexpr gpu_arch arch_tag = arch_tag_; static constexpr uint32_t barrier_count = 0; static constexpr uint32_t slm_size = 0; @@ -505,7 +508,7 @@ class epilogue_write_back_t< epilogue_policy_default, tile_shape_, mem_desc_c_t_, - std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> { + std::enable_if_t>> { public: using epilogue_policy = epilogue_policy_default; using tile_shape = tile_shape_; From f9d4ad3ad86462c9660245d5cd76710d616e0f02 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Tue, 6 Aug 2024 02:27:33 +0000 Subject: [PATCH 4/4] support pass_thru for 2024.1 --- include/common/core/memory.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index f9e75f6a0..93bedbfe0 100644 --- a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -500,12 +500,23 @@ __XETLA_API xetla_vector xetla_load_global( xetla_vector byte_offsets, xetla_mask mask, xetla_vector pass_thru) { +#if __INTEL_LLVM_COMPILER >= 20240200 __ESIMD_NS::properties props{ __ESIMD_NS::cache_hint_L1, __ESIMD_NS::cache_hint_L2, __ESIMD_NS::alignment}; return __ESIMD_NS::gather(p, byte_offsets, mask, pass_thru, props); +#else + constexpr data_size DS = data_size::default_size; + return __ESIMD_ENS::lsc_gather< + T, + VS, + gpu::xetla::detail::get_data_size(DS), + gpu::xetla::detail::get_cache_hint(L1H), + gpu::xetla::detail::get_cache_hint(L2H), + N / VS>(p, byte_offsets, mask, pass_thru); +#endif } /// template