Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.5"
var vecVersion = "1.0.15"
var vecVersion = "1.0.16"

repositories {
exclusiveContent {
Expand Down
13 changes: 0 additions & 13 deletions libs/simdvec/includes.txt

This file was deleted.

10 changes: 5 additions & 5 deletions libs/simdvec/native/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ model {
gcc(Gcc) {
target("aarch64") {
cCompiler.executable = "/usr/bin/gcc"
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=armv8-a"]) }
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=armv8-a"]) }
}
target("amd64") {
cCompiler.executable = "/usr/bin/gcc"
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2", "-Wno-incompatible-pointer-types"]) }
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=core-avx2", "-Wno-incompatible-pointer-types"]) }
cppCompiler.executable = "/usr/bin/g++"
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) }
}
Expand All @@ -67,11 +67,11 @@ model {
}
clang(Clang) {
target("aarch64") {
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=armv8-a"]) }
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=armv8-a"]) }
}

target("amd64") {
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c99", "-march=core-avx2"]) }
cCompiler.withArguments { args -> args.addAll(["-O3", "-std=c11", "-march=core-avx2"]) }
cppCompiler.withArguments { args -> args.addAll(["-O3", "-march=core-avx2"]) }
}
}
Expand Down Expand Up @@ -134,7 +134,7 @@ tasks.register('buildSharedLibraryAndCopy') {
description = 'Assembles native shared library for the host architecture and copies to libs/native/libraries/build/platform/'
// check if `LOCAL_VEC_BINARY_OS` is set, if not throw an error to prevent accidental overwrites
if (System.getenv("LOCAL_VEC_BINARY_OS") == null){
throw new GradleException("LOCAL_VEC_BINARY_OS is set, skipping copy to prevent overwriting local binary.")
throw new GradleException("LOCAL_VEC_BINARY_OS is not set, skipping copy to prevent overwriting local binary.")
}
dependsOn "buildSharedLibrary"
doLast {
Expand Down
2 changes: 1 addition & 1 deletion libs/simdvec/native/publish_vec_binaries.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi

VERSION="1.0.15"
VERSION="1.0.16"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)

Expand Down
9 changes: 9 additions & 0 deletions libs/simdvec/native/settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@
*/

rootProject.name = 'vec'

buildCache {
local {
enabled = false
}
remote(HttpBuildCache) {
enabled = false
}
}
62 changes: 31 additions & 31 deletions libs/simdvec/native/src/vec/c/aarch64/vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ EXPORT int vec_caps() {
#endif
}

static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) {
static inline int32_t dot7u_inner(int8_t* a, int8_t* b, const int32_t dims) {
// We have contention in the instruction pipeline on the accumulation
// registers if we use too few.
int32x4_t acc1 = vdupq_n_s32(0);
Expand Down Expand Up @@ -82,7 +82,7 @@ static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) {
return vaddvq_s32(vaddq_s32(acc5, acc6));
}

EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
EXPORT int32_t dot7u(int8_t* a, int8_t* b, const int32_t dims) {
int32_t res = 0;
int i = 0;
if (dims > DOT7U_STRIDE_BYTES_LEN) {
Expand All @@ -99,7 +99,7 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
int32_t res = 0;
if (dims > DOT7U_STRIDE_BYTES_LEN) {
const int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
for (size_t c = 0; c < count; c++) {
for (int32_t c = 0; c < count; c++) {
int i = limit;
res = dot7u_inner(a, b, i);
for (; i < dims; i++) {
Expand All @@ -109,9 +109,9 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
a += dims;
}
} else {
for (size_t c = 0; c < count; c++) {
for (int32_t c = 0; c < count; c++) {
res = 0;
for (size_t i = 0; i < dims; i++) {
for (int32_t i = 0; i < dims; i++) {
res += a[i] * b[i];
}
results[c] = (float_t)res;
Expand All @@ -120,7 +120,7 @@ EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int
}
}

static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) {
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
int32x4_t acc1 = vdupq_n_s32(0);
int32x4_t acc2 = vdupq_n_s32(0);
int32x4_t acc3 = vdupq_n_s32(0);
Expand All @@ -145,7 +145,7 @@ static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) {
return vaddvq_s32(vaddq_s32(acc5, acc6));
}

EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) {
EXPORT int32_t sqr7u(int8_t* a, int8_t* b, const int32_t dims) {
int32_t res = 0;
int i = 0;
if (dims > SQR7U_STRIDE_BYTES_LEN) {
Expand All @@ -161,10 +161,10 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) {

// --- single precision floats

// const float *a pointer to the first float vector
// const float *b pointer to the second float vector
// size_t elementCount the number of floating point elements
EXPORT float dotf32(const float *a, const float *b, size_t elementCount) {
// const f32_t *a pointer to the first float vector
// const f32_t *b pointer to the second float vector
// const int32_t elementCount the number of floating point elements
EXPORT f32_t dotf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
float32x4_t sum0 = vdupq_n_f32(0.0f);
float32x4_t sum1 = vdupq_n_f32(0.0f);
float32x4_t sum2 = vdupq_n_f32(0.0f);
Expand All @@ -174,9 +174,9 @@ EXPORT float dotf32(const float *a, const float *b, size_t elementCount) {
float32x4_t sum6 = vdupq_n_f32(0.0f);
float32x4_t sum7 = vdupq_n_f32(0.0f);

size_t i = 0;
int32_t i = 0;
// Each float32x4_t holds 4 floats, so unroll 8x = 32 floats per loop
size_t unrolled_limit = elementCount & ~31UL;
int32_t unrolled_limit = elementCount & ~31UL;
for (; i < unrolled_limit; i += 32) {
sum0 = vfmaq_f32(sum0, vld1q_f32(a + i), vld1q_f32(b + i));
sum1 = vfmaq_f32(sum1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4));
Expand All @@ -192,7 +192,7 @@ EXPORT float dotf32(const float *a, const float *b, size_t elementCount) {
vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)),
vaddq_f32(vaddq_f32(sum4, sum5), vaddq_f32(sum6, sum7))
);
float result = vaddvq_f32(total);
f32_t result = vaddvq_f32(total);

// Handle remaining elements
for (; i < elementCount; ++i) {
Expand All @@ -202,10 +202,10 @@ EXPORT float dotf32(const float *a, const float *b, size_t elementCount) {
return result;
}

// const float *a pointer to the first float vector
// const float *b pointer to the second float vector
// size_t elementCount the number of floating point elements
EXPORT float cosf32(const float *a, const float *b, size_t elementCount) {
// const f32_t *a pointer to the first float vector
// const f32_t *b pointer to the second float vector
// const int32_t elementCount the number of floating point elements
EXPORT f32_t cosf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
float32x4_t sum0 = vdupq_n_f32(0.0f);
float32x4_t sum1 = vdupq_n_f32(0.0f);
float32x4_t sum2 = vdupq_n_f32(0.0f);
Expand All @@ -221,9 +221,9 @@ EXPORT float cosf32(const float *a, const float *b, size_t elementCount) {
float32x4_t norm_b2 = vdupq_n_f32(0.0f);
float32x4_t norm_b3 = vdupq_n_f32(0.0f);

size_t i = 0;
int32_t i = 0;
// Each float32x4_t holds 4 floats, so unroll 4x = 16 floats per loop
size_t unrolled_limit = elementCount & ~15UL;
int32_t unrolled_limit = elementCount & ~15UL;
for (; i < unrolled_limit; i += 16) {
float32x4_t va0 = vld1q_f32(a + i);
float32x4_t vb0 = vld1q_f32(b + i);
Expand Down Expand Up @@ -257,27 +257,27 @@ EXPORT float cosf32(const float *a, const float *b, size_t elementCount) {
float32x4_t norms_a = vaddq_f32(vaddq_f32(norm_a0, norm_a1), vaddq_f32(norm_a2, norm_a3));
float32x4_t norms_b = vaddq_f32(vaddq_f32(norm_b0, norm_b1), vaddq_f32(norm_b2, norm_b3));

float dot = vaddvq_f32(sums);
float norm_a = vaddvq_f32(norms_a);
float norm_b = vaddvq_f32(norms_b);
f32_t dot = vaddvq_f32(sums);
f32_t norm_a = vaddvq_f32(norms_a);
f32_t norm_b = vaddvq_f32(norms_b);

// Handle remaining tail elements
for (; i < elementCount; ++i) {
float va = a[i];
float vb = b[i];
f32_t va = a[i];
f32_t vb = b[i];
dot += va * vb;
norm_a += va * va;
norm_b += vb * vb;
}

float denom = sqrtf(norm_a) * sqrtf(norm_b);
f32_t denom = sqrtf(norm_a) * sqrtf(norm_b);
if (denom == 0.0f) {
return 0.0f;
}
return dot / denom;
}

EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) {
EXPORT f32_t sqrf32(const f32_t *a, const f32_t *b, const int32_t elementCount) {
float32x4_t sum0 = vdupq_n_f32(0.0f);
float32x4_t sum1 = vdupq_n_f32(0.0f);
float32x4_t sum2 = vdupq_n_f32(0.0f);
Expand All @@ -287,9 +287,9 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) {
float32x4_t sum6 = vdupq_n_f32(0.0f);
float32x4_t sum7 = vdupq_n_f32(0.0f);

size_t i = 0;
int32_t i = 0;
// Each float32x4_t holds 4 floats, so unroll 8x = 32 floats per loop
size_t unrolled_limit = elementCount & ~31UL;
int32_t unrolled_limit = elementCount & ~31UL;
for (; i < unrolled_limit; i += 32) {
float32x4_t d0 = vsubq_f32(vld1q_f32(a + i), vld1q_f32(b + i));
float32x4_t d1 = vsubq_f32(vld1q_f32(a + i + 4), vld1q_f32(b + i + 4));
Expand All @@ -314,11 +314,11 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) {
vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)),
vaddq_f32(vaddq_f32(sum4, sum5), vaddq_f32(sum6, sum7))
);
float result = vaddvq_f32(total);
f32_t result = vaddvq_f32(total);

// Handle remaining tail elements
for (; i < elementCount; ++i) {
float diff = a[i] - b[i];
f32_t diff = a[i] - b[i];
result += diff * diff;
}

Expand Down
Loading