diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp
index 7a1315ab9..5a8d37e88 100644
--- a/include/common/core/arch_config.hpp
+++ b/include/common/core/arch_config.hpp
@@ -140,6 +140,16 @@ struct fpu_attr_t {
 template
 inline constexpr bool arch_has_fpu = fpu_attr_t::has_fpu;
 
+// Selects the grf_mode that GRF expands to: grf_mode::double_grf by default,
+// grf_mode::normal_grf when NORMAL_GRF is defined. Written as #ifdef/#else so
+// that GRF is defined exactly once; redefining a macro with a different
+// replacement list (without an intervening #undef) is ill-formed.
+#ifdef NORMAL_GRF
+#define GRF grf_mode::normal_grf
+#else
+#define GRF grf_mode::double_grf
+#endif
+
 template
 struct register_nums_t {
   static constexpr uint32_t register_nums =
diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp
index 1de706ccb..59bfd8f54 100644
--- a/include/experimental/group/gemm/compute_policy.hpp
+++ b/include/experimental/group/gemm/compute_policy.hpp
@@ -136,12 +136,15 @@ struct compute_policy_int4_dequantize<
   static constexpr bool is_col_major_b =
       quant_info_.weight_mem_layout == mem_layout::col_major;
 
+  using reg_nums_t = register_nums_t;
   static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
-  static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
+  static constexpr uint32_t block_bytes_x_a =
+      is_col_major_b ? reg_nums_t::register_nums : 32;
   static constexpr uint32_t block_size_x_a =
       block_bytes_x_a / sizeof(dtype_mma_a);
   static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
-  static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
+  static constexpr uint32_t block_bytes_y_b =
+      is_col_major_b ? reg_nums_t::register_nums : 32;
   static constexpr uint32_t block_size_y_b =
       block_bytes_y_b / sizeof(dtype_mma_b);
 
diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp
index be693b063..0968f4bc6 100644
--- a/include/subgroup/tile/impl/payload_xe.hpp
+++ b/include/subgroup/tile/impl/payload_xe.hpp
@@ -1887,10 +1887,8 @@ struct prefetch_payload_t<
     arch_tag_,
     std::enable_if_t<
         (arch_tag_ == gpu_arch::XeHpc) &&
-        (((block_size_y_ != 1 || tile_size_y_ != 1) &&
-          mem_layout_ == mem_layout::row_major) ||
-         ((block_size_x_ != 1 || tile_size_x_ != 1) &&
-          mem_layout_ == mem_layout::col_major))>> {
+        (((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
+         ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
   using dtype = dtype_;
   using mem_desc_t = mem_desc_t;
 
@@ -2180,10 +2178,8 @@ struct prefetch_payload_t<
     num_coop_sg_,
     arch_tag_,
     std::enable_if_t<
-        ((block_size_y_ == 1 && tile_size_y_ == 1) &&
-         mem_layout_ == mem_layout::row_major) ||
-        ((block_size_x_ == 1 && tile_size_x_ == 1) &&
-         mem_layout_ == mem_layout::col_major)>> {
+        ((tile_size_y_ == 1) && mem_layout_ == mem_layout::row_major) ||
+        ((tile_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> {
   using dtype = dtype_;
   using mem_desc_t = mem_desc_t;
 