cpu: x64: update the methods of computing batch size in brgemm matmul
xuxinzen authored and tprimak committed Mar 30, 2022
1 parent b59b027 commit 9b3826f
Showing 4 changed files with 18 additions and 14 deletions.
13 changes: 8 additions & 5 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -695,6 +695,12 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                             + s8s8_buffer_sz]));
         }
 
+        // Set last_chunk_brgemm_batch_size_ to brgemm_batch_size
+        // when K_tail = 0 and brgemm_batch_tail_size = 0
+        last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_tail_size;
+        if (bgmmc.K_tail == 0 && last_chunk_brgemm_batch_size_ == 0)
+            last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_size;
+
         // parallelization
         parallel_work_amount_ = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks;

@@ -1011,11 +1017,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
     }
 
     int get_brgemm_batch_size(int k_chunk_idx) const {
-        const int last_brgemm_batch_size
-                = (nstl::max(bgmmc_.K, bgmmc_.K_blk)
-                          - k_chunk_idx * bgmmc_.K_chunk_elems)
-                / bgmmc_.K_blk;
-        return is_last_K_chunk(k_chunk_idx) ? last_brgemm_batch_size
+        return is_last_K_chunk(k_chunk_idx) ? last_chunk_brgemm_batch_size_
                                             : bgmmc_.brgemm_batch_size;
     }

@@ -1071,6 +1073,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
     // parallelization parameters
     int parallel_work_amount_;
     int nthr_, nthr_k_, nthr_bmn_, num_threads_used_;
+    int last_chunk_brgemm_batch_size_;
 };
 
 template struct brgemm_matmul_t<avx512_core_bf16_amx_int8>;
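Taken together, the brgemm_matmul.cpp hunks replace a per-call recomputation with a member that is initialized once per execution context. The standalone C++ sketch below (not oneDNN code: conf_t and exec_ctx_t are hypothetical, pared-down stand-ins for brgemm_matmul_conf_t and the exec context, and is_last_K_chunk() is simplified to ignore K-parallel slicing) illustrates the new control flow:

#include <cstdio>

// Hypothetical, minimal stand-in for brgemm_matmul_conf_t.
struct conf_t {
    int K, K_blk, K_tail;
    int K_chunks, brgemm_batch_size, brgemm_batch_tail_size;
};

struct exec_ctx_t {
    explicit exec_ctx_t(const conf_t &bgmmc) : bgmmc_(bgmmc) {
        // Mirrors the added constructor logic: fall back to the full batch
        // size when there is neither a K tail nor a batch tail.
        last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_tail_size;
        if (bgmmc.K_tail == 0 && last_chunk_brgemm_batch_size_ == 0)
            last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_size;
    }

    // Simplified: the real helper also accounts for K-parallel threading.
    bool is_last_K_chunk(int k_chunk_idx) const {
        return k_chunk_idx == bgmmc_.K_chunks - 1;
    }

    // The per-call arithmetic from the removed lines is gone; the value
    // computed once in the constructor is simply reused.
    int get_brgemm_batch_size(int k_chunk_idx) const {
        return is_last_K_chunk(k_chunk_idx) ? last_chunk_brgemm_batch_size_
                                            : bgmmc_.brgemm_batch_size;
    }

    const conf_t &bgmmc_;
    int last_chunk_brgemm_batch_size_;
};

int main() {
    // K = 512, K_blk = 64, 4 blocks per chunk -> 2 full chunks, no tails,
    // so the fallback makes the last chunk use the full batch size.
    conf_t c {512, 64, /*K_tail=*/0, /*K_chunks=*/2,
            /*brgemm_batch_size=*/4, /*brgemm_batch_tail_size=*/0};
    exec_ctx_t ctx(c);
    std::printf("chunk 0: %d, chunk 1 (last): %d\n",
            ctx.get_brgemm_batch_size(0), ctx.get_brgemm_batch_size(1));
    return 0;
}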
11 changes: 3 additions & 8 deletions src/cpu/x64/matmul/brgemm_matmul.hpp
@@ -56,14 +56,9 @@ inline int get_brg_kernel_index(const brgemm_matmul_conf_t &bgmmc,
 
 inline int get_brg_batchsize(
         const brgemm_matmul_conf_t &bgmmc, bool is_bs_tail, bool is_K_tail) {
-    auto adj_k_a = bgmmc.use_buffer_a ? utils::rnd_up(bgmmc.K, bgmmc.K_blk)
-                                      : bgmmc.K;
-    auto adj_k_b = utils::rnd_up<decltype(bgmmc.K)>(bgmmc.wei_k_blk, bgmmc.K);
-    auto adj_k = nstl::min(adj_k_a, adj_k_b);
-    auto bs = (is_K_tail)
-            ? 1
-            : ((is_bs_tail) ? (adj_k / bgmmc.K_blk) % bgmmc.brgemm_batch_size
-                            : bgmmc.brgemm_batch_size);
+    auto bs = (is_K_tail) ? 1
+                          : ((is_bs_tail) ? bgmmc.brgemm_batch_tail_size
+                                          : bgmmc.brgemm_batch_size);
     return bs;
 }
 } // namespace
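After this change the header-side helper only selects among precomputed values instead of re-deriving an adjusted K. A minimal sketch of that selection, assuming a hypothetical stripped-down conf_t with just the two batch-size fields the helper reads:

#include <cstdio>

// Hypothetical stand-in for brgemm_matmul_conf_t.
struct conf_t {
    int brgemm_batch_size; // K blocks per full K chunk
    int brgemm_batch_tail_size; // precomputed in init_aux_values()
};

// Same three-way choice as the updated get_brg_batchsize(): a K-tail kernel
// runs a batch of 1, a batch-tail kernel uses the precomputed tail size, and
// everything else uses the full batch size.
inline int get_brg_batchsize(
        const conf_t &bgmmc, bool is_bs_tail, bool is_K_tail) {
    return is_K_tail ? 1
                     : (is_bs_tail ? bgmmc.brgemm_batch_tail_size
                                   : bgmmc.brgemm_batch_size);
}

int main() {
    const conf_t bgmmc {4, 2};
    std::printf("full: %d, bs_tail: %d, K_tail: %d\n",
            get_brg_batchsize(bgmmc, false, false),
            get_brg_batchsize(bgmmc, true, false),
            get_brg_batchsize(bgmmc, false, true)); // prints 4, 2, 1
    return 0;
}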
6 changes: 6 additions & 0 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -915,6 +915,12 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc,
     bgmmc.K_chunks = div_up(bgmmc.K, bgmmc.K_chunk_elems);
     bgmmc.num_M_blocks = div_up(bgmmc.M, bgmmc.M_blk);
     bgmmc.num_N_blocks = div_up(bgmmc.N, bgmmc.N_blk);
+    const int last_chunck_batch_size
+            = (nstl::max(bgmmc.K, bgmmc.K_blk)
+                      - (bgmmc.K_chunks - 1) * bgmmc.K_chunk_elems)
+            / bgmmc.K_blk;
+    bgmmc.brgemm_batch_tail_size
+            = last_chunck_batch_size % bgmmc.brgemm_batch_size;
 
     bgmmc.buffer_c_chunk_sz = bgmmc.acc_dt_sz * bgmmc.LDC
             * (bgmmc.nthr_k > 1 ? bgmmc.M : bgmmc.M_blk);
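To make the new tail-size arithmetic concrete, here is a small worked example under assumed shapes (the numbers are illustrative only and are not produced by the library's blocking heuristics):

#include <algorithm>
#include <cstdio>

int main() {
    // Assumed blocking: K = 384, K_blk = 64, 4 K-blocks per full chunk.
    const int K = 384, K_blk = 64, brgemm_batch_size = 4;
    const int K_chunk_elems = K_blk * brgemm_batch_size; // 256
    const int K_chunks = (K + K_chunk_elems - 1) / K_chunk_elems; // div_up -> 2

    // Mirrors the added init_aux_values() lines: K elements left over for the
    // last chunk, expressed in whole K_blk blocks ...
    const int last_chunk_batch_size
            = (std::max(K, K_blk) - (K_chunks - 1) * K_chunk_elems) / K_blk; // 2
    // ... then reduced modulo the full batch size to get the tail batch size.
    const int brgemm_batch_tail_size
            = last_chunk_batch_size % brgemm_batch_size; // 2

    std::printf("last chunk holds %d blocks, tail batch size = %d\n",
            last_chunk_batch_size, brgemm_batch_tail_size);
    return 0;
}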
2 changes: 1 addition & 1 deletion src/cpu/x64/matmul/brgemm_matmul_utils.hpp
@@ -82,7 +82,7 @@ struct brgemm_matmul_conf_t {
     dim_t M_blk, N_blk, K_blk, M_tail, N_tail, K_tail;
     int M_chunk_size, N_chunk_size;
    dim_t LDA, LDB, LDC, LDD;
-    int brgemm_batch_size;
+    int brgemm_batch_size, brgemm_batch_tail_size;
     int wei_n_blk, wei_k_blk;
     brgemm_batch_kind_t brg_type;

