From b5d8d32d70c9760165056266fb2724007ae5ac83 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Fri, 19 Jul 2024 07:14:43 +0000 Subject: [PATCH 1/2] grf_mode ctrl --- include/common/core/arch_config.hpp | 5 +++++ include/experimental/group/gemm/compute_policy.hpp | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 7a1315ab9..5a8d37e88 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -140,6 +140,11 @@ struct fpu_attr_t { template inline constexpr bool arch_has_fpu = fpu_attr_t::has_fpu; +#define GRF grf_mode::double_grf +#ifdef NORMAL_GRF +#define GRF grf_mode::normal_grf +#endif + template struct register_nums_t { static constexpr uint32_t register_nums = diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 1de706ccb..59bfd8f54 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -136,12 +136,15 @@ struct compute_policy_int4_dequantize< static constexpr bool is_col_major_b = quant_info_.weight_mem_layout == mem_layout::col_major; + using reg_nums_t = register_nums_t; static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16; - static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32; + static constexpr uint32_t block_bytes_x_a = + is_col_major_b ? reg_nums_t::register_nums : 32; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32; - static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32; + static constexpr uint32_t block_bytes_y_b = + is_col_major_b ? reg_nums_t::register_nums : 32; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); From 068dc801a7fee051f27de9f9a5c17f47724c0304 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Fri, 19 Jul 2024 07:15:24 +0000 Subject: [PATCH 2/2] Use 2d when tile_size_x/y=1 --- include/subgroup/tile/impl/payload_xe.hpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index be693b063..0968f4bc6 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1887,10 +1887,8 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t< (arch_tag_ == gpu_arch::XeHpc) && - (((block_size_y_ != 1 || tile_size_y_ != 1) && - mem_layout_ == mem_layout::row_major) || - ((block_size_x_ != 1 || tile_size_x_ != 1) && - mem_layout_ == mem_layout::col_major))>> { + (((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || + ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -2180,10 +2178,8 @@ struct prefetch_payload_t< num_coop_sg_, arch_tag_, std::enable_if_t< - ((block_size_y_ == 1 && tile_size_y_ == 1) && - mem_layout_ == mem_layout::row_major) || - ((block_size_x_ == 1 && tile_size_x_ == 1) && - mem_layout_ == mem_layout::col_major)>> { + ((tile_size_y_ == 1) && mem_layout_ == mem_layout::row_major) || + ((tile_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t;