From 9b3826f762de28b2c35aa8f9249b916973b7b140 Mon Sep 17 00:00:00 2001 From: "Xuxin, Zeng" Date: Tue, 29 Mar 2022 21:07:21 -0700 Subject: [PATCH] cpu: x64: update the methods of computing batch size in brgemm matmul --- src/cpu/x64/matmul/brgemm_matmul.cpp | 13 ++++++++----- src/cpu/x64/matmul/brgemm_matmul.hpp | 11 +++-------- src/cpu/x64/matmul/brgemm_matmul_utils.cpp | 6 ++++++ src/cpu/x64/matmul/brgemm_matmul_utils.hpp | 2 +- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index c573bba0edd..73dfcca845c 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -695,6 +695,12 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { + s8s8_buffer_sz])); } + // Set last_chunk_brgemm_batch_size_ to brgemm_batch_size + // when K_tail = 0 and brgemm_batch_tail_size = 0 + last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_tail_size; + if (bgmmc.K_tail == 0 && last_chunk_brgemm_batch_size_ == 0) + last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_size; + // parallelization parallel_work_amount_ = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks; @@ -1011,11 +1017,7 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { } int get_brgemm_batch_size(int k_chunk_idx) const { - const int last_brgemm_batch_size - = (nstl::max(bgmmc_.K, bgmmc_.K_blk) - - k_chunk_idx * bgmmc_.K_chunk_elems) - / bgmmc_.K_blk; - return is_last_K_chunk(k_chunk_idx) ? last_brgemm_batch_size + return is_last_K_chunk(k_chunk_idx) ? last_chunk_brgemm_batch_size_ : bgmmc_.brgemm_batch_size; } @@ -1071,6 +1073,7 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { // parallelization parameters int parallel_work_amount_; int nthr_, nthr_k_, nthr_bmn_, num_threads_used_; + int last_chunk_brgemm_batch_size_; }; template struct brgemm_matmul_t; diff --git a/src/cpu/x64/matmul/brgemm_matmul.hpp b/src/cpu/x64/matmul/brgemm_matmul.hpp index acea57eae66..dad1a6ad739 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.hpp +++ b/src/cpu/x64/matmul/brgemm_matmul.hpp @@ -56,14 +56,9 @@ inline int get_brg_kernel_index(const brgemm_matmul_conf_t &bgmmc, inline int get_brg_batchsize( const brgemm_matmul_conf_t &bgmmc, bool is_bs_tail, bool is_K_tail) { - auto adj_k_a = bgmmc.use_buffer_a ? utils::rnd_up(bgmmc.K, bgmmc.K_blk) - : bgmmc.K; - auto adj_k_b = utils::rnd_up(bgmmc.wei_k_blk, bgmmc.K); - auto adj_k = nstl::min(adj_k_a, adj_k_b); - auto bs = (is_K_tail) - ? 1 - : ((is_bs_tail) ? (adj_k / bgmmc.K_blk) % bgmmc.brgemm_batch_size - : bgmmc.brgemm_batch_size); + auto bs = (is_K_tail) ? 1 + : ((is_bs_tail) ? bgmmc.brgemm_batch_tail_size + : bgmmc.brgemm_batch_size); return bs; } } // namespace diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index 60bd71b64b9..f74361f3c18 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -915,6 +915,12 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, bgmmc.K_chunks = div_up(bgmmc.K, bgmmc.K_chunk_elems); bgmmc.num_M_blocks = div_up(bgmmc.M, bgmmc.M_blk); bgmmc.num_N_blocks = div_up(bgmmc.N, bgmmc.N_blk); + const int last_chunck_batch_size + = (nstl::max(bgmmc.K, bgmmc.K_blk) + - (bgmmc.K_chunks - 1) * bgmmc.K_chunk_elems) + / bgmmc.K_blk; + bgmmc.brgemm_batch_tail_size + = last_chunck_batch_size % bgmmc.brgemm_batch_size; bgmmc.buffer_c_chunk_sz = bgmmc.acc_dt_sz * bgmmc.LDC * (bgmmc.nthr_k > 1 ? bgmmc.M : bgmmc.M_blk); diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp index 80854960836..5a4c4290ded 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp @@ -82,7 +82,7 @@ struct brgemm_matmul_conf_t { dim_t M_blk, N_blk, K_blk, M_tail, N_tail, K_tail; int M_chunk_size, N_chunk_size; dim_t LDA, LDB, LDC, LDD; - int brgemm_batch_size; + int brgemm_batch_size, brgemm_batch_tail_size; int wei_n_blk, wei_k_blk; brgemm_batch_kind_t brg_type;