Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,36 @@ namespace hnswlib {

#endif

#if defined(USE_SSE) || defined(USE_AVX)
// Inner-product distance for vectors whose length is not a multiple of 16.
//
// The largest multiple-of-16 prefix is handed to InnerProductSIMD16Ext and the
// remaining (qty % 16) trailing components to the scalar InnerProduct.
//
// pVect1v, pVect2v: the two input vectors, read as contiguous floats.
// qty_ptr:          points to a size_t holding the number of components.
// Returns the two partial results added together, minus 1.0f.
static float
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const size_t total = *static_cast<const size_t *>(qty_ptr);
    size_t aligned = total & ~static_cast<size_t>(15);  // round down to a multiple of 16
    size_t remainder = total - aligned;

    const float head = InnerProductSIMD16Ext(pVect1v, pVect2v, &aligned);

    const float *tail1 = static_cast<const float *>(pVect1v) + aligned;
    const float *tail2 = static_cast<const float *>(pVect2v) + aligned;
    const float tail = InnerProduct(tail1, tail2, &remainder);

    // NOTE(review): presumably both helpers return (1 - dot), so the offset is
    // counted twice in the sum and removed once here -- confirm against InnerProduct.
    return head + tail - 1.0f;
}

// Inner-product distance for vectors whose length is not a multiple of 4.
//
// The largest multiple-of-4 prefix is handed to InnerProductSIMD4Ext and the
// remaining (qty % 4) trailing components to the scalar InnerProduct.
//
// pVect1v, pVect2v: the two input vectors, read as contiguous floats.
// qty_ptr:          points to a size_t holding the number of components.
// Returns the two partial results added together, minus 1.0f.
static float
InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const size_t total = *static_cast<const size_t *>(qty_ptr);
    size_t aligned = total & ~static_cast<size_t>(3);  // round down to a multiple of 4
    size_t remainder = total - aligned;

    const float head = InnerProductSIMD4Ext(pVect1v, pVect2v, &aligned);

    const float *tail1 = static_cast<const float *>(pVect1v) + aligned;
    const float *tail2 = static_cast<const float *>(pVect2v) + aligned;
    const float tail = InnerProduct(tail1, tail2, &remainder);

    // NOTE(review): presumably both helpers return (1 - dot), so the offset is
    // counted twice in the sum and removed once here -- confirm against InnerProduct.
    return head + tail - 1.0f;
}
#endif

class InnerProductSpace : public SpaceInterface<float> {

DISTFUNC<float> fstdistfunc_;
Expand All @@ -220,11 +250,15 @@ namespace hnswlib {
// Builds an inner-product space for vectors of `dim` float components.
//
// Starts from the scalar InnerProduct kernel and, when SSE/AVX support is
// compiled in, replaces it according to `dim`:
//   dim % 16 == 0 -> InnerProductSIMD16Ext
//   dim % 4  == 0 -> InnerProductSIMD4Ext
//   dim > 16      -> InnerProductSIMD16ExtResiduals
//   dim > 4       -> InnerProductSIMD4ExtResiduals
//   otherwise     -> scalar InnerProduct is kept
// Also records the dimension and the per-vector size in bytes.
InnerProductSpace(size_t dim) {
    fstdistfunc_ = InnerProduct;
#if defined(USE_AVX) || defined(USE_SSE)
    // One if/else chain inside a single #if/#endif pair, so every `else`
    // has its `if` in the same preprocessor region.
    if (dim % 16 == 0)
        fstdistfunc_ = InnerProductSIMD16Ext;
    else if (dim % 4 == 0)
        fstdistfunc_ = InnerProductSIMD4Ext;
    else if (dim > 16)
        fstdistfunc_ = InnerProductSIMD16ExtResiduals;
    else if (dim > 4)
        fstdistfunc_ = InnerProductSIMD4ExtResiduals;
#endif
    dim_ = dim;
    data_size_ = dim * sizeof(float);
}
Expand Down
79 changes: 50 additions & 29 deletions hnswlib/space_l2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
namespace hnswlib {

// Scalar squared Euclidean (L2) distance between two float vectors.
//
// pVect1v, pVect2v: the two input vectors, read as contiguous floats; each
//                   must hold at least *qty_ptr components.
// qty_ptr:          points to a size_t holding the number of components.
// Returns the sum over all components of (a[i] - b[i])^2; 0 when the count is 0.
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *pVect1 = static_cast<const float *>(pVect1v);
    const float *pVect2 = static_cast<const float *>(pVect2v);
    const size_t qty = *static_cast<const size_t *>(qty_ptr);

    float res = 0;
    // size_t index: the count is a size_t, so a narrower counter could wrap.
    for (size_t i = 0; i < qty; i++) {
        const float t = pVect1[i] - pVect2[i];
        res += t * t;
    }
    return res;
}

#if defined(USE_AVX)
Expand Down Expand Up @@ -49,10 +52,8 @@ namespace hnswlib {
}

_mm256_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];

return (res);
}
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
}

#elif defined(USE_SSE)

Expand All @@ -62,12 +63,9 @@ namespace hnswlib {
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8];
// size_t qty4 = qty >> 2;
size_t qty16 = qty >> 4;

const float *pEnd1 = pVect1 + (qty16 << 4);
// const float* pEnd2 = pVect1 + (qty4 << 2);
// const float* pEnd3 = pVect1 + qty;

__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
Expand Down Expand Up @@ -102,10 +100,24 @@ namespace hnswlib {
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}

_mm_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
#endif

return (res);
#if defined(USE_SSE) || defined(USE_AVX)
// Squared L2 distance for vectors whose length is not a multiple of 16.
//
// The largest multiple-of-16 prefix is handed to L2SqrSIMD16Ext and the
// remaining (qty % 16) trailing components to the scalar L2Sqr; the two
// partial sums are added.
//
// pVect1v, pVect2v: the two input vectors, read as contiguous floats.
// qty_ptr:          points to a size_t holding the number of components.
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const size_t total = *static_cast<const size_t *>(qty_ptr);
    size_t aligned = total & ~static_cast<size_t>(15);  // round down to a multiple of 16
    size_t remainder = total - aligned;

    const float head = L2SqrSIMD16Ext(pVect1v, pVect2v, &aligned);

    const float *tail1 = static_cast<const float *>(pVect1v) + aligned;
    const float *tail2 = static_cast<const float *>(pVect2v) + aligned;
    const float tail = L2Sqr(tail1, tail2, &remainder);

    return head + tail;
}
#endif

Expand All @@ -119,10 +131,9 @@ namespace hnswlib {
size_t qty = *((size_t *) qty_ptr);


// size_t qty4 = qty >> 2;
size_t qty16 = qty >> 2;
size_t qty4 = qty >> 2;

const float *pEnd1 = pVect1 + (qty16 << 2);
const float *pEnd1 = pVect1 + (qty4 << 2);

__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
Expand All @@ -136,9 +147,22 @@ namespace hnswlib {
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
_mm_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}

return (res);
// Squared L2 distance for vectors whose length is not a multiple of 4.
//
// The largest multiple-of-4 prefix is handed to L2SqrSIMD4Ext and the
// remaining (qty % 4) trailing components to the scalar L2Sqr; the two
// partial sums are added.
//
// pVect1v, pVect2v: the two input vectors, read as contiguous floats.
// qty_ptr:          points to a size_t holding the number of components.
static float
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const size_t total = *static_cast<const size_t *>(qty_ptr);
    size_t aligned = total & ~static_cast<size_t>(3);  // round down to a multiple of 4
    size_t remainder = total - aligned;

    const float head = L2SqrSIMD4Ext(pVect1v, pVect2v, &aligned);

    const float *tail1 = static_cast<const float *>(pVect1v) + aligned;
    const float *tail2 = static_cast<const float *>(pVect2v) + aligned;
    const float tail = L2Sqr(tail1, tail2, &remainder);

    return head + tail;
}
#endif

Expand All @@ -151,13 +175,14 @@ namespace hnswlib {
L2Space(size_t dim) {
fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX)
if (dim % 4 == 0)
fstdistfunc_ = L2SqrSIMD4Ext;
if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
/*else{
throw runtime_error("Data type not supported!");
}*/
else if (dim % 4 == 0)
fstdistfunc_ = L2SqrSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
Expand Down Expand Up @@ -185,10 +210,6 @@ namespace hnswlib {
int res = 0;
unsigned char *a = (unsigned char *) pVect1;
unsigned char *b = (unsigned char *) pVect2;
/*for (int i = 0; i < qty; i++) {
int t = int((a)[i]) - int((b)[i]);
res += t*t;
}*/

qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
Expand Down Expand Up @@ -241,4 +262,4 @@ namespace hnswlib {
};


}
}