Skip to content

Commit

Permalink
Improve distances_simd.cpp for aarch64 (#2392)
Browse files Browse the repository at this point in the history
Summary:
- use `vaddvq_f32` instead of `vpaddq_f32` and `vdups_laneq_f32` in `fvec_L2sqr` , `fvec_inner_product` , and `fvec_norm_L2sqr`
- ~~implement `fvec_L1` and `fvec_Linf` for ARM SIMD (NEON)~~
    - This causes performance regression, so I've droped it.
- implement `fvec_madd` and `fvec_madd_and_argmin` for ARM SIMD (NEON)

Pull Request resolved: #2392

Reviewed By: patricklabatut

Differential Revision: D38198174

Pulled By: mdouze

fbshipit-source-id: 3488a0cf2db1ded458b3bf73f4bc9665413e3351
  • Loading branch information
wx257osn2 authored and facebook-github-bot committed Aug 31, 2022
1 parent dbc3d1d commit dcbf33c
Showing 1 changed file with 70 additions and 6 deletions.
76 changes: 70 additions & 6 deletions faiss/utils/distances_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,8 +949,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
float32x4_t sq = vsubq_f32(xi, yi);
accux4 = vfmaq_f32(accux4, sq, sq);
}
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
float32_t accux1 = vaddvq_f32(accux4);
for (; i < d; ++i) {
float32_t xi = x[i];
float32_t yi = y[i];
Expand All @@ -969,8 +968,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
float32x4_t yi = vld1q_f32(y + i);
accux4 = vfmaq_f32(accux4, xi, yi);
}
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
float32_t accux1 = vaddvq_f32(accux4);
for (; i < d; ++i) {
float32_t xi = x[i];
float32_t yi = y[i];
Expand All @@ -987,8 +985,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
float32x4_t xi = vld1q_f32(x + i);
accux4 = vfmaq_f32(accux4, xi, xi);
}
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
float32_t accux1 = vaddvq_f32(accux4);
for (; i < d; ++i) {
float32_t xi = x[i];
accux1 += xi * xi;
Expand Down Expand Up @@ -1186,6 +1183,22 @@ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
#endif
}

#elif defined(__aarch64__)

void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
const size_t n_simd = n - (n & 3);
const float32x4_t bfv = vdupq_n_f32(bf);
size_t i;
for (i = 0; i < n_simd; i += 4) {
const float32x4_t ai = vld1q_f32(a + i);
const float32x4_t bi = vld1q_f32(b + i);
const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
vst1q_f32(c + i, ci);
}
for (; i < n; ++i)
c[i] = a[i] + bf * b[i];
}

#else

void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
Expand Down Expand Up @@ -1279,6 +1292,57 @@ int fvec_madd_and_argmin(
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
}

#elif defined(__aarch64__)

int fvec_madd_and_argmin(
size_t n,
const float* a,
float bf,
const float* b,
float* c) {
float32x4_t vminv = vdupq_n_f32(1e20);
uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
size_t i;
{
const size_t n_simd = n - (n & 3);
const uint32_t iota[] = {0, 1, 2, 3};
uint32x4_t iv = vld1q_u32(iota);
const uint32x4_t incv = vdupq_n_u32(4);
const float32x4_t bfv = vdupq_n_f32(bf);
for (i = 0; i < n_simd; i += 4) {
const float32x4_t ai = vld1q_f32(a + i);
const float32x4_t bi = vld1q_f32(b + i);
const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
vst1q_f32(c + i, ci);
const uint32x4_t less_than = vcltq_f32(ci, vminv);
vminv = vminq_f32(ci, vminv);
iminv = vorrq_u32(
vandq_u32(less_than, iv),
vandq_u32(vmvnq_u32(less_than), iminv));
iv = vaddq_u32(iv, incv);
}
}
float vmin = vminvq_f32(vminv);
uint32_t imin;
{
const float32x4_t vminy = vdupq_n_f32(vmin);
const uint32x4_t equals = vceqq_f32(vminv, vminy);
imin = vminvq_u32(vorrq_u32(
vandq_u32(equals, iminv),
vandq_u32(
vmvnq_u32(equals),
vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
}
for (; i < n; ++i) {
c[i] = a[i] + bf * b[i];
if (c[i] < vmin) {
vmin = c[i];
imin = static_cast<uint32_t>(i);
}
}
return static_cast<int>(imin);
}

#else

int fvec_madd_and_argmin(
Expand Down

0 comments on commit dcbf33c

Please sign in to comment.