diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h
index d95d6f88..58eb7607 100644
--- a/hnswlib/hnswlib.h
+++ b/hnswlib/hnswlib.h
@@ -15,8 +15,29 @@
 #ifdef _MSC_VER
 #include <intrin.h>
 #include <stdexcept>
+#include <stdint.h>
+// CPUID/XGETBV wrappers over the MSVC intrinsics. They are `static` because
+// this header is included from several translation units.
+static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
+    __cpuidex(out, eax, ecx);
+}
+static __int64 xgetbv(unsigned int x) {
+    return _xgetbv(x);
+}
 #else
 #include <x86intrin.h>
+#include <cpuid.h>
+#include <stdint.h>
+// CPUID/XGETBV wrappers for GCC/Clang; `static` for the same reason as above.
+static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
+    __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
+}
+// Executes XGETBV for register `index` and returns EDX:EAX as one 64-bit value.
+static uint64_t xgetbv(unsigned int index) {
+    uint32_t eax, edx;
+    __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
+    return ((uint64_t)edx << 32) | eax;
+}
 #endif
 
 #if defined(USE_AVX512)
@@ -30,6 +51,71 @@
 #define PORTABLE_ALIGN32 __declspec(align(32))
 #define PORTABLE_ALIGN64 __declspec(align(64))
 #endif
+
+// Adapted from https://github.com/Mysticial/FeatureDetector
+#define _XCR_XFEATURE_ENABLED_MASK 0
+
+// Returns true when AVX can be used at run time: CPUID leaf 1 must report
+// ECX bit 28 (AVX) and ECX bit 27 (OS uses XSAVE/XRSTOR), and the XGETBV
+// feature mask must have bits 0x6 set.
+static bool AVXCapable() {
+    int cpuInfo[4];
+
+    // CPU support
+    cpuid(cpuInfo, 0, 0);
+    int nIds = cpuInfo[0];
+
+    bool HW_AVX = false;
+    if (nIds >= 0x00000001) {
+        cpuid(cpuInfo, 0x00000001, 0);
+        HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
+    }
+
+    // OS support
+    cpuid(cpuInfo, 1, 0);
+
+    bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
+    bool cpuAVXSupport = (cpuInfo[2] & (1 << 28)) != 0;
+
+    bool avxSupported = false;
+    if (osUsesXSAVE_XRSTORE && cpuAVXSupport) {
+        uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
+        avxSupported = (xcrFeatureMask & 0x6) == 0x6;
+    }
+    return HW_AVX && avxSupported;
+}
+
+// Returns true when AVX-512 Foundation can be used at run time: AVXCapable()
+// must hold, CPUID leaf 7 must report EBX bit 16 (AVX512F), and the XGETBV
+// feature mask must have bits 0xe6 set.
+static bool AVX512Capable() {
+    if (!AVXCapable()) return false;
+
+    int cpuInfo[4];
+
+    // CPU support
+    cpuid(cpuInfo, 0, 0);
+    int nIds = cpuInfo[0];
+
+    bool HW_AVX512F = false;
+    if (nIds >= 0x00000007) {  // AVX512 Foundation
+        cpuid(cpuInfo, 0x00000007, 0);
+        HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
+    }
+
+    // OS support
+    cpuid(cpuInfo, 1, 0);
+
+    bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
+    bool cpuAVXSupport = (cpuInfo[2] & (1 << 28)) != 0;
+
+    bool avx512Supported = false;
+    if (osUsesXSAVE_XRSTORE && cpuAVXSupport) {
+        uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
+        avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
+    }
+    return HW_AVX512F && avx512Supported;
+}
 #endif
 
 #include <queue>
@@ -108,7 +194,6 @@ namespace hnswlib {
 
         return result;
     }
-
 }
 
 #include "space_l2.h"
diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h
index c0029bde..7cd3d020 100644
--- a/hnswlib/space_ip.h
+++ b/hnswlib/space_ip.h
@@ -18,7 +18,7 @@ namespace hnswlib {
 
 // Favor using AVX if available.
     static float
-    InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float PORTABLE_ALIGN32 TmpRes[8];
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
@@ -64,10 +64,12 @@ namespace hnswlib {
         return 1.0f - sum;
     }
 
-#elif defined(USE_SSE)
+#endif
+
+#if defined(USE_SSE)
 
     static float
-    InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float PORTABLE_ALIGN32 TmpRes[8];
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
@@ -128,7 +130,7 @@ namespace hnswlib {
 #if defined(USE_AVX512)
 
     static float
-    InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float PORTABLE_ALIGN64 TmpRes[16];
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
@@ -157,10 +159,12 @@ namespace hnswlib {
         return 1.0f - sum;
     }
 
-#elif defined(USE_AVX)
+#endif
+
+#if defined(USE_AVX)
 
     static float
-    InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float PORTABLE_ALIGN32 TmpRes[8];
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
@@ -195,10 +199,12 @@ namespace hnswlib {
         return 1.0f - sum;
     }
 
-#elif defined(USE_SSE)
+#endif
+
+#if defined(USE_SSE)
 
     static float
-    InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float PORTABLE_ALIGN32 TmpRes[8];
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
@@ -245,6 +251,10 @@ namespace hnswlib {
 #endif
 
 #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+    // Active kernels: start as SSE, re-pointed by the InnerProductSpace constructor after probing the CPU.
+    static DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
+    static DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
+
     static float
     InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         size_t qty = *((size_t *) qty_ptr);
@@ -283,6 +293,20 @@ namespace hnswlib {
         InnerProductSpace(size_t dim) {
             fstdistfunc_ = InnerProduct;
     #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
+        #if defined(USE_AVX512)
+            if (AVX512Capable())
+                InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
+            else if (AVXCapable())
+                InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
+        #elif defined(USE_AVX)
+            if (AVXCapable())
+                InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
+        #endif
+        #if defined(USE_AVX)
+            if (AVXCapable())
+                InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
+        #endif
+
             if (dim % 16 == 0)
                 fstdistfunc_ = InnerProductSIMD16Ext;
             else if (dim % 4 == 0)
diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h
index 3b6a49ef..44135370 100644
--- a/hnswlib/space_l2.h
+++ b/hnswlib/space_l2.h
@@ -23,7 +23,7 @@ namespace hnswlib {
 
 // Favor using AVX512 if available.
     static float
-    L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
         size_t qty = *((size_t *) qty_ptr);
@@ -52,12 +52,13 @@ namespace hnswlib {
 
         return (res);
     }
+#endif
 
-#elif defined(USE_AVX)
+#if defined(USE_AVX)
 
 // Favor using AVX if available.
     static float
-    L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
         size_t qty = *((size_t *) qty_ptr);
@@ -89,10 +90,12 @@ namespace hnswlib {
         return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
     }
 
-#elif defined(USE_SSE)
+#endif
+
+#if defined(USE_SSE)
 
     static float
-    L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float *pVect1 = (float *) pVect1v;
         float *pVect2 = (float *) pVect2v;
         size_t qty = *((size_t *) qty_ptr);
@@ -141,6 +144,9 @@ namespace hnswlib {
 #endif
 
 #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+    // Active kernel: starts as SSE, re-pointed by the L2Space constructor after probing the CPU.
+    static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
+
     static float
     L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         size_t qty = *((size_t *) qty_ptr);
@@ -156,7 +162,7 @@ namespace hnswlib {
 #endif
 
 
-#ifdef USE_SSE
+#if defined(USE_SSE)
     static float
     L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
         float PORTABLE_ALIGN32 TmpRes[8];
@@ -208,7 +214,17 @@ namespace hnswlib {
     public:
         L2Space(size_t dim) {
             fstdistfunc_ = L2Sqr;
-        #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+    #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+        #if defined(USE_AVX512)
+            if (AVX512Capable())
+                L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
+            else if (AVXCapable())
+                L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+        #elif defined(USE_AVX)
+            if (AVXCapable())
+                L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+        #endif
+
             if (dim % 16 == 0)
                 fstdistfunc_ = L2SqrSIMD16Ext;
             else if (dim % 4 == 0)
@@ -217,7 +233,7 @@ namespace hnswlib {
                 fstdistfunc_ = L2SqrSIMD16ExtResiduals;
             else if (dim > 4)
                 fstdistfunc_ = L2SqrSIMD4ExtResiduals;
-        #endif
+    #endif
             dim_ = dim;
             data_size_ = dim * sizeof(float);
         }
diff --git a/setup.py b/setup.py
index 90826dea..ddf50f75 100644
--- a/setup.py
+++ b/setup.py
@@ -74,8 +74,12 @@ class BuildExt(build_ext):
     """A custom build extension for adding compiler-specific options."""
     c_opts = {
         'msvc': ['/EHsc', '/openmp', '/O2'],
-        'unix': ['-O3', '-march=native'],  # , '-w'
+        #'unix': ['-O3', '-march=native'],  # , '-w'
+        'unix': ['-O3'],  # , '-w'
     }
+    if not os.environ.get("HNSWLIB_NO_NATIVE"):
+        c_opts['unix'].append('-march=native')
+
     link_opts = {
         'unix': [],
         'msvc': [],