Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 35c6d10

Browse files
authored
mha enhance (#180)
* mha enhance * enable MHA_AVX2 on client
1 parent c71a27d commit 35c6d10

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

bestla/bestla/kernel_avx2.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,12 @@ inline __m256 exp_ps_0_1(const __m256 x) {
11741174
static const auto log2e = _mm256_set1_ps(v_log2e);
11751175
static const auto half = _mm256_set1_ps(.5f);
11761176

1177-
const auto x1 = _mm256_fmadd_ps(x, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f);
1177+
static const auto upper_bound = _mm256_set1_ps(88.722838); // log(max_positive_float)
1178+
static const auto lower_bound = _mm256_set1_ps(-87.336549); // log(min_positive_float)
1179+
__m256 x1 = _mm256_min_ps(x, upper_bound);
1180+
x1 = _mm256_max_ps(x1, lower_bound);
1181+
1182+
x1 = _mm256_fmadd_ps(x1, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f);
11781183
const auto z = _mm256_floor_ps(x1);
11791184
const auto f = _mm256_sub_ps(x1, z); // auto f = x1 - z;
11801185

neural_speed/core/layers/mha_dense.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ bool bestla_reordered_attn_fp32_support(const attn_shape_t* params) {
7474
#endif
7575
// use avx2 and f16c on avx2 platforms
7676
// todo: check avx2 mha on server
77-
return false;
77+
return !_cd->AVX512F() && _cd->AVX2();
7878
}
7979
// kv cache sizes in bytes per layer per batch per beam for:
8080
void bestla_reordered_attn_fp32_batch_kv_info(const kv_shape_t* params, kv_cache_info_t* out) {

0 commit comments

Comments
 (0)