Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve distances_simd.cpp for aarch64 #2392

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
76 changes: 70 additions & 6 deletions faiss/utils/distances_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,8 +949,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
float32x4_t sq = vsubq_f32(xi, yi);
accux4 = vfmaq_f32(accux4, sq, sq);
}
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
float32_t accux1 = vaddvq_f32(accux4);
for (; i < d; ++i) {
float32_t xi = x[i];
float32_t yi = y[i];
Expand All @@ -969,8 +968,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
float32x4_t yi = vld1q_f32(y + i);
accux4 = vfmaq_f32(accux4, xi, yi);
}
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
float32_t accux1 = vaddvq_f32(accux4);
for (; i < d; ++i) {
float32_t xi = x[i];
float32_t yi = y[i];
Expand All @@ -987,8 +985,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
float32x4_t xi = vld1q_f32(x + i);
accux4 = vfmaq_f32(accux4, xi, xi);
}
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
float32_t accux1 = vaddvq_f32(accux4);
for (; i < d; ++i) {
float32_t xi = x[i];
accux1 += xi * xi;
Expand Down Expand Up @@ -1186,6 +1183,22 @@ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
#endif
}

#elif defined(__aarch64__)

/// AArch64 NEON version of the multiply-add kernel:
/// writes c[k] = a[k] + bf * b[k] for every k in [0, n).
///
/// @param n   number of floats in each of a, b and c
/// @param a   addend vector (n floats)
/// @param bf  scalar factor applied to every element of b
/// @param b   vector scaled by bf (n floats)
/// @param c   destination vector (n floats)
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    // Broadcast the scalar factor so it can feed the vector multiply-add.
    const float32x4_t scale = vdupq_n_f32(bf);
    // Largest multiple of 4 that does not exceed n.
    const size_t vec_end = n - (n & 3);

    size_t k = 0;
    // Main part: four floats per step.
    while (k < vec_end) {
        const float32x4_t lhs = vld1q_f32(a + k);
        const float32x4_t rhs = vld1q_f32(b + k);
        vst1q_f32(c + k, vfmaq_f32(lhs, scale, rhs));
        k += 4;
    }
    // Remainder: at most three trailing elements, handled one by one.
    while (k < n) {
        c[k] = a[k] + bf * b[k];
        ++k;
    }
}

#else

void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
Expand Down Expand Up @@ -1279,6 +1292,57 @@ int fvec_madd_and_argmin(
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
}

#elif defined(__aarch64__)

/// AArch64 NEON version of the fused multiply-add + argmin kernel.
///
/// Writes c[i] = a[i] + bf * b[i] for every i in [0, n) and returns the index
/// of the first smallest c[i].
///
/// @param n   number of floats in each of a, b and c
/// @param a   addend vector (n floats)
/// @param bf  scalar factor applied to every element of b
/// @param b   vector scaled by bf (n floats)
/// @param c   destination vector (n floats)
/// @return    index of the first minimal element of c; the all-ones index
///            (converted to int) when no element compares strictly below the
///            1e20 starting threshold, e.g. when n == 0. NaN elements never
///            compare as smaller and are therefore skipped, in the vector
///            part as well as in the scalar tail.
int fvec_madd_and_argmin(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    // Per-lane running minimum and the index at which it was first reached.
    float32x4_t vminv = vdupq_n_f32(1e20);
    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
    size_t i;
    {
        const size_t n_simd = n - (n & 3);
        const uint32_t iota[] = {0, 1, 2, 3};
        // Element indexes of the four lanes for the current iteration.
        uint32x4_t iv = vld1q_u32(iota);
        const uint32x4_t incv = vdupq_n_u32(4);
        const float32x4_t bfv = vdupq_n_f32(bf);
        for (i = 0; i < n_simd; i += 4) {
            const float32x4_t ai = vld1q_f32(a + i);
            const float32x4_t bi = vld1q_f32(b + i);
            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
            vst1q_f32(c + i, ci);
            const uint32x4_t less_than = vcltq_f32(ci, vminv);
            // Update value and index with the same mask. A plain vminq_f32
            // would let a NaN into the running minimum while the index mask
            // (false for NaN) stays untouched; the lane could then never be
            // updated again and the final reduction would fail to find any
            // lane equal to the minimum.
            vminv = vbslq_f32(less_than, ci, vminv);
            iminv = vbslq_u32(less_than, iv, iminv);
            iv = vaddq_u32(iv, incv);
        }
    }
    // Horizontal reduction: global minimum, then the smallest index among
    // the lanes that hold it (other lanes are masked to the all-ones index).
    float vmin = vminvq_f32(vminv);
    uint32_t imin;
    {
        const uint32x4_t equals = vceqq_f32(vminv, vdupq_n_f32(vmin));
        const uint32x4_t no_index = vdupq_n_u32(static_cast<uint32_t>(-1));
        imin = vminvq_u32(vbslq_u32(equals, iminv, no_index));
    }
    // Scalar tail for the last n % 4 elements; the strict comparison keeps
    // the earliest index on ties.
    for (; i < n; ++i) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = static_cast<uint32_t>(i);
        }
    }
    return static_cast<int>(imin);
}

#else

int fvec_madd_and_argmin(
Expand Down