Skip to content

Commit

Permalink
cpu: x64: gemm: fix perf of ncf-3 shapes
Browse files Browse the repository at this point in the history
Further reduce the number of cases where the number of threads is limited.
  • Loading branch information
aaraujom authored and tprimak committed Dec 21, 2020
1 parent 2d2f8b8 commit 32c1110
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
9 changes: 3 additions & 6 deletions src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp
Expand Up @@ -1894,8 +1894,9 @@ dnnl_status_t jit_avx512_common_gemm_f32(int nthrs, const char *transa,
return ref_gemm(transa, transb, p_m, p_n, p_k, p_alpha, A, p_lda, B,
p_lda, p_beta, C, p_ldc, bias);

bool is_transa = *transa == 'T';
bool is_transb = *transb == 'T';
int nthr_max = dnnl_get_current_num_threads();
int nthr_to_use = nstl::min(nthrs, nthr_max);

dim_t m = *p_m;
dim_t n = *p_n;
dim_t k = *p_k;
Expand All @@ -1905,10 +1906,6 @@ dnnl_status_t jit_avx512_common_gemm_f32(int nthrs, const char *transa,
float beta = *p_beta;
dim_t MB, NB, KB;

int nthr_max = dnnl_get_current_num_threads();
bool use_max_nthr = !is_transa && !is_transb && m <= 10 && n >= 50;
int nthr_to_use = use_max_nthr ? nthr_max : nstl::min(nthrs, nthr_max);

int nthr_m = 1, nthr_n = 1, nthr_k = 1, nthr_mn = 1;

// Determine threading partitioning
Expand Down
9 changes: 3 additions & 6 deletions src/cpu/x64/gemm/f32/jit_avx_gemm_f32.cpp
Expand Up @@ -2474,8 +2474,9 @@ dnnl_status_t jit_avx_gemm_f32(int nthrs, const char *transa,
return ref_gemm(transa, transb, p_m, p_n, p_k, p_alpha, A, p_lda, B,
p_lda, p_beta, C, p_ldc, bias);

bool is_transa = *transa == 'T';
bool is_transb = *transb == 'T';
int nthr_max = dnnl_get_current_num_threads();
int nthr_to_use = nstl::min(nthrs, nthr_max);

dim_t m = *p_m;
dim_t n = *p_n;
dim_t k = *p_k;
Expand All @@ -2485,10 +2486,6 @@ dnnl_status_t jit_avx_gemm_f32(int nthrs, const char *transa,
float beta = *p_beta;
dim_t MB, NB, KB;

bool use_max_nthr = !is_transa && !is_transb && m <= 3 && n >= 50;

int nthr_max = dnnl_get_current_num_threads();
int nthr_to_use = use_max_nthr ? nthr_max : nstl::min(nthrs, nthr_max);
int nthr_m = 1, nthr_n = 1, nthr_k = 1, nthr_mn = 1;

// Determine threading partitioning
Expand Down
22 changes: 16 additions & 6 deletions src/cpu/x64/gemm/gemm_driver.cpp
Expand Up @@ -1545,17 +1545,27 @@ static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) {
auto veclen = get_vector_length<T>();
const double fp_per_cycle = 2.0 * 2.0 * veclen;

if (mayiuse(avx2) && !mayiuse(avx512_core))
const bool is_f32 = data_traits<T>::data_type == data_type::f32;

const bool is_avx512_mic = mayiuse(avx512_mic);
const bool is_avx512 = mayiuse(avx512_core);
const bool is_avx = mayiuse(avx);
const bool is_only_avx2 = mayiuse(avx2) && !is_avx512;

if (is_avx512_mic) return;

// Some sgemm cases still benefit from using all threads.
const bool use_all_threads = is_f32 && n > 50
&& ((is_avx && m <= 3) || (is_avx512 && m <= 10));
if (use_all_threads) return;

if (is_only_avx2)
if (m > 10 * n && n < *nthrs)
if (m / *nthrs < veclen * 3)
*nthrs = nstl::max(m / veclen / 3, dim_t(1));

double gemm_cycles = m * n * k / fp_per_cycle;
if (data_traits<T>::data_type == data_type::f32) {
gemm_cycles *= 2.0;
} else {
gemm_cycles *= 8.0;
}
gemm_cycles *= is_f32 ? 2.0 : 8.0;

int i = *nthrs;

Expand Down

0 comments on commit 32c1110

Please sign in to comment.