diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..a6fa563c --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "ThirdParty/zstd"] + path = ThirdParty/zstd + url = https://github.com/facebook/zstd + branch = release diff --git a/AnnService/Aggregator.vcxproj b/AnnService/Aggregator.vcxproj index af24dded..4946c13f 100644 --- a/AnnService/Aggregator.vcxproj +++ b/AnnService/Aggregator.vcxproj @@ -165,6 +165,7 @@ + @@ -177,5 +178,6 @@ + \ No newline at end of file diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt index a61b08a3..470f7d65 100644 --- a/AnnService/CMakeLists.txt +++ b/AnnService/CMakeLists.txt @@ -2,8 +2,10 @@ # Licensed under the MIT License. set(AnnService ${PROJECT_SOURCE_DIR}/AnnService) +set(Zstd ${PROJECT_SOURCE_DIR}/ThirdParty/zstd) include_directories(${AnnService}) +include_directories(${Zstd}/lib) file(GLOB_RECURSE HDR_FILES ${AnnService}/inc/Core/*.h ${AnnService}/inc/Helper/*.h) file(GLOB_RECURSE SRC_FILES ${AnnService}/src/Core/*.cpp ${AnnService}/src/Helper/*.cpp) @@ -32,9 +34,9 @@ if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") endif() add_library (SPTAGLib SHARED ${SRC_FILES} ${HDR_FILES}) -target_link_libraries (SPTAGLib DistanceUtils) +target_link_libraries (SPTAGLib DistanceUtils libzstd_shared) add_library (SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES}) -target_link_libraries (SPTAGLibStatic DistanceUtils) +target_link_libraries (SPTAGLibStatic DistanceUtils libzstd_static) if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") target_compile_options(SPTAGLibStatic PRIVATE -fPIC) endif() diff --git a/AnnService/Client.vcxproj b/AnnService/Client.vcxproj index f88234be..9381af59 100644 --- a/AnnService/Client.vcxproj +++ b/AnnService/Client.vcxproj @@ -132,6 +132,7 @@ + @@ -144,5 +145,6 @@ + \ No newline at end of file diff --git a/AnnService/CoreLibrary.vcxproj b/AnnService/CoreLibrary.vcxproj index 386c0f38..4cd347d3 100644 --- a/AnnService/CoreLibrary.vcxproj +++ b/AnnService/CoreLibrary.vcxproj @@ -160,6 +160,7 @@ + @@ -219,5 +220,12 @@ + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
+ + + \ No newline at end of file diff --git a/AnnService/CoreLibrary.vcxproj.filters b/AnnService/CoreLibrary.vcxproj.filters index 453a4795..f260b078 100644 --- a/AnnService/CoreLibrary.vcxproj.filters +++ b/AnnService/CoreLibrary.vcxproj.filters @@ -214,6 +214,9 @@ Header Files\Core\Common + + Header Files\Core\SPANN + diff --git a/AnnService/IndexBuilder.vcxproj b/AnnService/IndexBuilder.vcxproj index f82825fa..0900590c 100644 --- a/AnnService/IndexBuilder.vcxproj +++ b/AnnService/IndexBuilder.vcxproj @@ -155,6 +155,7 @@ + @@ -167,5 +168,6 @@ + \ No newline at end of file diff --git a/AnnService/IndexSearcher.vcxproj b/AnnService/IndexSearcher.vcxproj index 88214858..6d137837 100644 --- a/AnnService/IndexSearcher.vcxproj +++ b/AnnService/IndexSearcher.vcxproj @@ -156,6 +156,7 @@ + @@ -168,5 +169,6 @@ + \ No newline at end of file diff --git a/AnnService/Quantizer.vcxproj b/AnnService/Quantizer.vcxproj index bdcebcd0..942e55e1 100644 --- a/AnnService/Quantizer.vcxproj +++ b/AnnService/Quantizer.vcxproj @@ -171,6 +171,7 @@ + @@ -183,5 +184,6 @@ + \ No newline at end of file diff --git a/AnnService/SSDServing.vcxproj b/AnnService/SSDServing.vcxproj index bf37ca74..f6676597 100644 --- a/AnnService/SSDServing.vcxproj +++ b/AnnService/SSDServing.vcxproj @@ -1,4 +1,4 @@ - + @@ -27,6 +27,9 @@ + + + 15.0 {217B42B7-8F2B-4323-804C-08992CA2F65E} @@ -177,6 +180,7 @@ + @@ -189,5 +193,6 @@ + \ No newline at end of file diff --git a/AnnService/SSDServing.vcxproj.filters b/AnnService/SSDServing.vcxproj.filters index 8f36f5f3..95d60611 100644 --- a/AnnService/SSDServing.vcxproj.filters +++ b/AnnService/SSDServing.vcxproj.filters @@ -27,4 +27,7 @@ Source Files + + + \ No newline at end of file diff --git a/AnnService/Server.vcxproj b/AnnService/Server.vcxproj index fe9a7c8d..3b38afe4 100644 --- a/AnnService/Server.vcxproj +++ b/AnnService/Server.vcxproj @@ -140,6 +140,7 @@ + @@ -152,5 +153,6 @@ + \ No newline at end of file diff --git a/AnnService/inc/Core/SPANN/Compressor.h b/AnnService/inc/Core/SPANN/Compressor.h new file mode 100644 index 00000000..86100608 --- /dev/null +++ b/AnnService/inc/Core/SPANN/Compressor.h @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SPANN_COMPRESSOR_H_ +#define _SPTAG_SPANN_COMPRESSOR_H_ + +#include +#include "zstd.h" +#include "zdict.h" +#include "../Common.h" + +namespace SPTAG +{ + namespace SPANN + { + class Compressor + { + private: + void CreateCDict() + { + cdict = ZSTD_createCDict((void *)dictBuffer.data(), dictBuffer.size(), compress_level); + if (cdict == NULL) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD_createCDict() failed! \n"); + exit(1); + } + } + + void CreateDDict() + { + ddict = ZSTD_createDDict((void *)dictBuffer.data(), dictBuffer.size()); + if (ddict == NULL) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD_createDDict() failed! \n"); + throw std::runtime_error("ZSTD_createDDict() failed!"); + } + } + + std::string CompressWithDict(const std::string &src) + { + size_t est_compress_size = ZSTD_compressBound(src.size()); + std::string comp_buffer{}; + comp_buffer.resize(est_compress_size); + + ZSTD_CCtx *const cctx = ZSTD_createCCtx(); + if (cctx == NULL) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD_createCCtx() failed! 
\n"); + exit(1); + } + size_t compressed_size = ZSTD_compress_usingCDict(cctx, (void *)comp_buffer.data(), est_compress_size, src.data(), src.size(), cdict); + if (ZSTD_isError(compressed_size)) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD compress error %s, \n", ZSTD_getErrorName(compressed_size)); + exit(1); + } + ZSTD_freeCCtx(cctx); + comp_buffer.resize(compressed_size); + comp_buffer.shrink_to_fit(); + + return comp_buffer; + } + + std::size_t DecompressWithDict(const char* src, size_t srcSize, char* dst, size_t dstCapacity) + { + ZSTD_DCtx* const dctx = ZSTD_createDCtx(); + if (dctx == NULL) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD_createDCtx() failed! \n"); + throw std::runtime_error("ZSTD_createDCtx() failed!"); + } + std::size_t const decomp_size = ZSTD_decompress_usingDDict(dctx, + (void*)dst, dstCapacity, src, srcSize, ddict); + if (ZSTD_isError(decomp_size)) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD decompress error %s, \n", ZSTD_getErrorName(decomp_size)); + throw std::runtime_error("ZSTD decompress failed."); + } + ZSTD_freeDCtx(dctx); + return decomp_size; + } + + std::string CompressWithoutDict(const std::string &src) + { + size_t est_comp_size = ZSTD_compressBound(src.size()); + std::string buffer{}; + buffer.resize(est_comp_size); + size_t compressed_size = ZSTD_compress((void *)buffer.data(), est_comp_size, + src.data(), src.size(), compress_level); + if (ZSTD_isError(compressed_size)) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD compress error %s, \n", ZSTD_getErrorName(compressed_size)); + exit(1); + } + buffer.resize(compressed_size); + buffer.shrink_to_fit(); + + return buffer; + } + + std::size_t DecompressWithoutDict(const char *src, size_t srcSize, char* dst, size_t dstCapacity) + { + std::size_t const decomp_size = ZSTD_decompress( + (void *)dst, dstCapacity, src, srcSize); + if (ZSTD_isError(decomp_size)) + { + LOG(Helper::LogLevel::LL_Error, "ZSTD decompress error %s, \n", ZSTD_getErrorName(decomp_size)); + throw std::runtime_error("ZSTD decompress failed."); + } + + return decomp_size; + } + + public: + Compressor(int level = 0, int bufferCapacity = 102400) + { + compress_level = level; + dictBufferCapacity = bufferCapacity; + cdict = nullptr; + ddict = nullptr; + } + + virtual ~Compressor() {} + + std::size_t TrainDict(const std::string &samplesBuffer, const size_t *samplesSizes, unsigned nbSamples) + { + dictBuffer.resize(dictBufferCapacity); + size_t dictSize = ZDICT_trainFromBuffer((void *)dictBuffer.data(), dictBufferCapacity, (void *)samplesBuffer.data(), &samplesSizes[0], nbSamples); + if (ZDICT_isError(dictSize)) + { + LOG(Helper::LogLevel::LL_Error, "ZDICT_trainFromBuffer() failed: %s \n", ZDICT_getErrorName(dictSize)); + exit(1); + } + dictBuffer.resize(dictSize); + dictBuffer.shrink_to_fit(); + + CreateCDict(); + + return dictSize; + } + + std::string GetDictBuffer() + { + return dictBuffer; + } + + void SetDictBuffer(const std::string &buffer) + { + dictBuffer = buffer; + CreateDDict(); + } + + std::string Compress(const std::string &src, const bool useDict) + { + return useDict ? CompressWithDict(src) : CompressWithoutDict(src); + } + + std::size_t Decompress(const char *src, size_t srcSize, char* dst, size_t dstCapacity, const bool useDict) + { + return useDict ? 
DecompressWithDict(src, srcSize, dst, dstCapacity) : DecompressWithoutDict(src, srcSize, dst, dstCapacity);
+            }
+
+            // return the compressed size
+            size_t GetCompressedSize(const std::string &src, bool useDict)
+            {
+                std::string dst = Compress(src, useDict);
+                return dst.size();
+            }
+
+        private:
+            int compress_level;
+
+            std::string dictBuffer;
+            size_t dictBufferCapacity;
+            ZSTD_CDict *cdict;
+            ZSTD_DDict *ddict;
+        };
+    } // SPANN
+} // SPTAG
+
+#endif // _SPTAG_SPANN_COMPRESSOR_H_
diff --git a/AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h b/AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h
index 0afb0eb7..815530b0 100644
--- a/AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h
+++ b/AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h
@@ -8,11 +8,13 @@
 #include "inc/Helper/AsyncFileReader.h"
 #include "IExtraSearcher.h"
 #include "../Common/TruthSet.h"
+#include "Compressor.h"

 #include
 #include
 #include
 #include
+#include <numeric>

 namespace SPTAG
 {
@@ -79,11 +81,19 @@
         }
     };

-#define ProcessPosting(vectorInfoSize) \
-        for (char *vectorInfo = buffer + listInfo->pageOffset, *vectorInfoEnd = vectorInfo + listInfo->listEleCount * vectorInfoSize; vectorInfo < vectorInfoEnd; vectorInfo += vectorInfoSize) { \
-            int vectorID = *(reinterpret_cast<int*>(vectorInfo)); \
+#define ProcessPosting(p_postingListFullData, vectorInfoSize, m_enablePostingListRearrange, m_enableDeltaEncoding, headVector) \
+        for (int i = 0; i < listInfo->listEleCount; i++) { \
+            if (m_enableDeltaEncoding) { \
+                ValueType* leaf = m_enablePostingListRearrange ? reinterpret_cast<ValueType*>(p_postingListFullData + (vectorInfoSize - sizeof(int)) * i) : reinterpret_cast<ValueType*>(p_postingListFullData + vectorInfoSize * i + sizeof(int)); \
+                for (auto j = 0; j < p_index->GetFeatureDim(); j++) { \
+                    leaf[j] += headVector[j]; \
+                } \
+            } \
+            uint64_t offsetVectorID = m_enablePostingListRearrange ? (vectorInfoSize - sizeof(int)) * listInfo->listEleCount + sizeof(int) * i : vectorInfoSize * i; \
+            int vectorID = *(reinterpret_cast<int*>(p_postingListFullData + offsetVectorID)); \
             if (p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) continue; \
-            auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), vectorInfo + sizeof(int)); \
+            uint64_t offsetVector = m_enablePostingListRearrange ?
(vectorInfoSize - sizeof(int)) * i : vectorInfoSize * i + sizeof(int); \ + auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), p_postingListFullData + offsetVector); \ queryResults.AddPoint(vectorID, distance2leaf); \ } \ @@ -93,6 +103,10 @@ namespace SPTAG public: ExtraFullGraphSearcher() { + m_enableDeltaEncoding = false; + m_enablePostingListRearrange = false; + m_enableDataCompression = false; + m_enableDictTraining = true; } virtual ~ExtraFullGraphSearcher() @@ -121,6 +135,10 @@ namespace SPTAG m_indexFiles.emplace_back(curIndexFile); m_listInfos.emplace_back(0); + m_enableDeltaEncoding = p_opt.m_enableDeltaEncoding; + m_enablePostingListRearrange = p_opt.m_enablePostingListRearrange; + m_enableDataCompression = p_opt.m_enableDataCompression; + m_enableDictTraining = p_opt.m_enableDictTraining; m_totalListCount += LoadingHeadInfo(curFile, p_opt.m_searchPostingPageLimit, m_listInfos.back()); curFile = m_extraFullGraphFile + "_" + std::to_string(m_indexFiles.size()); @@ -136,7 +154,8 @@ namespace SPTAG virtual void SearchIndex(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_queryResults, std::shared_ptr p_index, - SearchStats* p_stats, std::set* truth, std::map>* found) + SearchStats* p_stats, + std::set* truth, std::map>* found) { const uint32_t postingListCount = static_cast(p_exWorkSpace->m_postingIDs.size()); @@ -154,7 +173,6 @@ namespace SPTAG for (uint32_t pi = 0; pi < postingListCount; ++pi) { auto curPostingID = p_exWorkSpace->m_postingIDs[pi]; - int fileid = 0; ListInfo* listInfo; if (oneContext) { @@ -187,11 +205,43 @@ namespace SPTAG #ifdef BATCH_READ // async batch read auto vectorInfoSize = m_vectorInfoSize; - request.m_callback = [&p_exWorkSpace, &queryResults, &p_index, vectorInfoSize](Helper::AsyncReadRequest* request) + + request.m_callback = [&p_exWorkSpace, &queryResults, &p_index, vectorInfoSize, curPostingID, m_enableDeltaEncoding = m_enableDeltaEncoding, m_enablePostingListRearrange = m_enablePostingListRearrange, m_enableDictTraining = m_enableDictTraining, m_enableDataCompression = m_enableDataCompression, &m_pCompressor = m_pCompressor](Helper::AsyncReadRequest *request) { char* buffer = request->m_buffer; ListInfo* listInfo = (ListInfo*)(request->m_payload); - ProcessPosting(vectorInfoSize) + + // decompress posting list + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + p_postingListFullData = (char*)p_exWorkSpace->m_decompressBuffer.GetBuffer(); + if (listInfo->listEleCount != 0) + { + std::size_t sizePostingListFullData; + try { + sizePostingListFullData = m_pCompressor->Decompress(buffer + listInfo->pageOffset, listInfo->listTotalBytes, p_postingListFullData, listInfo->listEleCount * vectorInfoSize, m_enableDictTraining); + } + catch (std::runtime_error &err) { + LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! %s, \n", curPostingID, err.what()); + return; + } + if (sizePostingListFullData != listInfo->listEleCount * vectorInfoSize) + { + LOG(Helper::LogLevel::LL_Error, "PostingList %d decompressed size not match! 
%zu, %d, \n", curPostingID, sizePostingListFullData, listInfo->listEleCount * vectorInfoSize); + return; + } + } + } + + // delta encoding + ValueType* headVector = nullptr; + if (m_enableDeltaEncoding) + { + headVector = (ValueType*)p_index->GetSample(curPostingID); + } + + ProcessPosting(const_cast(p_postingListFullData), vectorInfoSize, m_enablePostingListRearrange, m_enableDeltaEncoding, headVector); }; #else // async read request.m_callback = [&p_exWorkSpace](Helper::AsyncReadRequest* request) @@ -240,9 +290,25 @@ namespace SPTAG ListInfo* listInfo = &(m_listInfos[curPostingID / m_listPerFile][curPostingID % m_listPerFile]); char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[pi]).GetBuffer()); - for (int i = 0; i < listInfo->listEleCount; ++i) { - char* vectorInfo = buffer + listInfo->pageOffset + i * m_vectorInfoSize; - int vectorID = *(reinterpret_cast(vectorInfo)); + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + p_postingListFullData = (char*)p_exWorkSpace->m_decompressBuffer.GetBuffer(); + if (listInfo->listEleCount != 0) + { + try { + m_pCompressor->Decompress(buffer + listInfo->pageOffset, listInfo->listTotalBytes, p_postingListFullData, listInfo->listEleCount * m_vectorInfoSize, m_enableDictTraining); + } + catch (std::runtime_error& err) { + LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! %s, \n", curPostingID, err.what()); + continue; + } + } + } + + for (size_t i = 0; i < listInfo->listEleCount; ++i) { + uint64_t offsetVectorID = m_enablePostingListRearrange ? (m_vectorInfoSize - sizeof(int)) * listInfo->listEleCount + sizeof(int) * i : m_vectorInfoSize * i; \ + int vectorID = *(reinterpret_cast(p_postingListFullData + offsetVectorID)); \ if (truth && truth->count(vectorID)) (*found)[curPostingID].insert(vectorID); } } @@ -256,6 +322,65 @@ namespace SPTAG } } + std::string GetPostingListFullData( + int postingListId, + size_t p_postingListSize, + Selection &p_selections, + std::shared_ptr p_fullVectors, + bool m_enableDeltaEncoding = false, + bool m_enablePostingListRearrange = false, + const ValueType *headVector = nullptr) + { + std::string postingListFullData(""); + std::string vectors(""); + std::string vectorIDs(""); + size_t selectIdx = p_selections.lower_bound(postingListId); + // iterate over all the vectors in the posting list + for (int i = 0; i < p_postingListSize; ++i) + { + if (p_selections[selectIdx].node != postingListId) + { + LOG(Helper::LogLevel::LL_Error, "Selection ID NOT MATCH! 
node:%d offset:%zu\n", postingListId, selectIdx); + exit(1); + } + std::string vectorID(""); + std::string vector(""); + + int vid = p_selections[selectIdx++].tonode; + vectorID.append(reinterpret_cast(&vid), sizeof(int)); + + ValueType *p_vector = reinterpret_cast(p_fullVectors->GetVector(vid)); + if (m_enableDeltaEncoding) + { + DimensionType n = p_fullVectors->Dimension(); + std::vector p_vector_delta(n); + for (auto j = 0; j < n; j++) + { + p_vector_delta[j] = p_vector[j] - headVector[j]; + } + vector.append(reinterpret_cast(&p_vector_delta[0]), p_fullVectors->PerVectorDataSize()); + } + else + { + vector.append(reinterpret_cast(p_vector), p_fullVectors->PerVectorDataSize()); + } + + if (m_enablePostingListRearrange) + { + vectorIDs += vectorID; + vectors += vector; + } + else + { + postingListFullData += (vectorID + vector); + } + } + if (m_enablePostingListRearrange) + { + return vectors + vectorIDs; + } + return postingListFullData; + } bool BuildIndex(std::shared_ptr& p_reader, std::shared_ptr p_headIndex, Options& p_opt) { std::string outputFile = p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdIndex; @@ -267,7 +392,6 @@ namespace SPTAG int numThreads = p_opt.m_iSSDNumberOfThreads; int candidateNum = p_opt.m_internalResultNum; - std::unordered_set headVectorIDS; if (p_opt.m_headIDFile.empty()) { LOG(Helper::LogLevel::LL_Error, "Not found VectorIDTranslate!\n"); @@ -444,6 +568,7 @@ namespace SPTAG auto t4 = std::chrono::high_resolution_clock::now(); LOG(SPTAG::Helper::LogLevel::LL_Info, "Time to perform posting cut:%.2lf sec.\n", ((double)std::chrono::duration_cast(t4 - t3).count()) + ((double)std::chrono::duration_cast(t4 - t3).count()) / 1000); + // number of posting lists per file size_t postingFileSize = (postingListSize.size() + p_opt.m_ssdIndexFileNum - 1) / p_opt.m_ssdIndexFileNum; std::vector selectionsBatchOffset(p_opt.m_ssdIndexFileNum + 1, 0); for (int i = 0; i < p_opt.m_ssdIndexFileNum; i++) { @@ -455,24 +580,102 @@ namespace SPTAG auto fullVectors = p_reader->GetVectorSet(); if (p_opt.m_distCalcMethod == DistCalcMethod::Cosine && !p_reader->IsNormalized() && !p_headIndex->m_pQuantizer) fullVectors->Normalize(p_opt.m_iSSDNumberOfThreads); + + // get compressed size of each posting list + std::vector postingListBytes(headVectorIDS.size()); + + if (p_opt.m_enableDataCompression) + { + m_pCompressor = std::make_unique(p_opt.m_zstdCompressLevel, p_opt.m_dictBufferCapacity); + LOG(Helper::LogLevel::LL_Info, "Getting compressed size of each posting list...\n"); + + LOG(Helper::LogLevel::LL_Info, "Training dictionary...\n"); + std::string samplesBuffer(""); + std::vector samplesSizes; + for (int i = 0; i < postingListSize.size(); i++) { + if (postingListSize[i] == 0) { + continue; + } + ValueType* headVector = nullptr; + if (p_opt.m_enableDeltaEncoding) + { + headVector = (ValueType*)p_headIndex->GetSample(i); + } + std::string postingListFullData = GetPostingListFullData( + i, postingListSize[i], selections, fullVectors, p_opt.m_enableDeltaEncoding, p_opt.m_enablePostingListRearrange, headVector); + + samplesBuffer += postingListFullData; + samplesSizes.push_back(postingListFullData.size()); + if (samplesBuffer.size() > p_opt.m_minDictTraingBufferSize) break; + } + LOG(Helper::LogLevel::LL_Info, "Using the first %zu postingLists to train dictionary... 
\n", samplesSizes.size()); + std::size_t dictSize = m_pCompressor->TrainDict(samplesBuffer, &samplesSizes[0], samplesSizes.size()); + LOG(Helper::LogLevel::LL_Info, "Dictionary trained, dictionary size: %zu \n", dictSize); +#pragma omp parallel for schedule(dynamic) + for (int i = 0; i < postingListSize.size(); i++) { + // do not compress if no data + if (postingListSize[i] == 0) { + postingListBytes[i] = 0; + continue; + } + ValueType* headVector = nullptr; + if (p_opt.m_enableDeltaEncoding) + { + headVector = (ValueType*)p_headIndex->GetSample(i); + } + std::string postingListFullData = GetPostingListFullData( + i, postingListSize[i], selections, fullVectors, p_opt.m_enableDeltaEncoding, p_opt.m_enablePostingListRearrange, headVector); + size_t sizeToCompress = postingListSize[i] * vectorInfoSize; + if (sizeToCompress != postingListFullData.size()) { + LOG(Helper::LogLevel::LL_Error, "Size to compress NOT MATCH! PostingListFullData size: %zu sizeToCompress: %zu \n", postingListFullData.size(), sizeToCompress); + } + postingListBytes[i] = m_pCompressor->GetCompressedSize(postingListFullData, true); + if (i % 10000 == 0 || postingListBytes[i] > static_cast(p_opt.m_postingPageLimit) * PageSize) { + LOG(Helper::LogLevel::LL_Info, "Posting list %d/%d, compressed size: %d, compression ratio: %.4f\n", i, postingListSize.size(), postingListBytes[i], postingListBytes[i] / float(sizeToCompress)); + } + } + LOG(Helper::LogLevel::LL_Info, "Getted compressed size for all the %d posting lists.\n", postingListBytes.size()); + LOG(Helper::LogLevel::LL_Info, "Mean compressed size: %.4f \n", std::accumulate(postingListBytes.begin(), postingListBytes.end(), 0.0) / postingListBytes.size()); + LOG(Helper::LogLevel::LL_Info, "Mean compression ratio: %.4f \n", std::accumulate(postingListBytes.begin(), postingListBytes.end(), 0.0) / (std::accumulate(postingListSize.begin(), postingListSize.end(), 0.0) * vectorInfoSize)); + } + else + { + for (int i = 0; i < postingListSize.size(); i++) + { + postingListBytes[i] = postingListSize[i] * vectorInfoSize; + } + } + + // iterate over files for (int i = 0; i < p_opt.m_ssdIndexFileNum; i++) { size_t curPostingListOffSet = i * postingFileSize; size_t curPostingListEnd = min(postingListSize.size(), (i + 1) * postingFileSize); + // postingListSize: number of vectors in the posting list, type vector std::vector curPostingListSizes( postingListSize.begin() + curPostingListOffSet, postingListSize.begin() + curPostingListEnd); + std::vector curPostingListBytes; + curPostingListBytes.assign( + postingListBytes.begin() + curPostingListOffSet, + postingListBytes.begin() + curPostingListEnd); std::unique_ptr postPageNum; std::unique_ptr postPageOffset; std::vector postingOrderInIndex; - SelectPostingOffset(vectorInfoSize, curPostingListSizes, postPageNum, postPageOffset, postingOrderInIndex); + SelectPostingOffset(curPostingListBytes, postPageNum, postPageOffset, postingOrderInIndex); + // LoadBatch: select vectors for each posting list if (p_opt.m_ssdIndexFileNum > 1) selections.LoadBatch(selectionsBatchOffset[i], selectionsBatchOffset[i + 1]); - OutputSSDIndexFile((i == 0) ? 
outputFile : outputFile + "_" + std::to_string(i),
+                    p_opt.m_enableDeltaEncoding,
+                    p_opt.m_enablePostingListRearrange,
+                    p_opt.m_enableDataCompression,
+                    p_opt.m_enableDictTraining,
                     vectorInfoSize,
                     curPostingListSizes,
+                    curPostingListBytes,
+                    p_headIndex,
                     selections,
                     postPageNum,
                     postPageOffset,
@@ -507,6 +710,8 @@
     private:
         struct ListInfo
         {
+            std::size_t listTotalBytes = 0;
+
             int listEleCount = 0;

             std::uint16_t listPageCount = 0;
@@ -523,6 +728,7 @@
                 LOG(Helper::LogLevel::LL_Error, "Failed to open file: %s\n", p_file.c_str());
                 exit(1);
             }
+            m_pCompressor = std::make_unique<Compressor>(); // the compression level is not needed for decompression

             int m_listCount;
             int m_totalDocumentCount;
@@ -563,6 +769,13 @@
             int pageNum;
             for (int i = 0; i < m_listCount; ++i)
             {
+                if (m_enableDataCompression)
+                {
+                    if (ptr->ReadBinary(sizeof(m_listInfos[i].listTotalBytes), reinterpret_cast<char*>(&(m_listInfos[i].listTotalBytes))) != sizeof(m_listInfos[i].listTotalBytes)) {
+                        LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n");
+                        exit(1);
+                    }
+                }
                 if (ptr->ReadBinary(sizeof(pageNum), reinterpret_cast<char*>(&(pageNum))) != sizeof(pageNum)) {
                     LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n");
                     exit(1);
                 }
@@ -579,10 +792,13 @@
                     LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n");
                     exit(1);
                 }
                 m_listInfos[i].listOffset = (static_cast(m_listPageOffset + pageNum) << PageSizeEx);
-                m_listInfos[i].listEleCount = min(m_listInfos[i].listEleCount, (min(static_cast(m_listInfos[i].listPageCount), p_postingPageLimit) << PageSizeEx) / m_vectorInfoSize);
-                m_listInfos[i].listPageCount = static_cast(ceil((m_vectorInfoSize * m_listInfos[i].listEleCount + m_listInfos[i].pageOffset) * 1.0 / (1 << PageSizeEx)));
+                if (!m_enableDataCompression)
+                {
+                    m_listInfos[i].listTotalBytes = m_listInfos[i].listEleCount * m_vectorInfoSize;
+                    m_listInfos[i].listEleCount = min(m_listInfos[i].listEleCount, (min(static_cast(m_listInfos[i].listPageCount), p_postingPageLimit) << PageSizeEx) / m_vectorInfoSize);
+                    m_listInfos[i].listPageCount = static_cast(ceil((m_vectorInfoSize * m_listInfos[i].listEleCount + m_listInfos[i].pageOffset) * 1.0 / (1 << PageSizeEx)));
+                }

                 totalListElementCount += m_listInfos[i].listEleCount;
                 int pageCount = m_listInfos[i].listPageCount;
@@ -602,6 +818,28 @@
                 }
             }

+            if (m_enableDataCompression && m_enableDictTraining)
+            {
+                size_t dictBufferSize;
+                if (ptr->ReadBinary(sizeof(size_t), reinterpret_cast<char*>(&dictBufferSize)) != sizeof(dictBufferSize)) {
+                    LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n");
+                    exit(1);
+                }
+                char* dictBuffer = new char[dictBufferSize];
+                if (ptr->ReadBinary(dictBufferSize, dictBuffer) != dictBufferSize) {
+                    LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n");
+                    exit(1);
+                }
+                try {
+                    m_pCompressor->SetDictBuffer(std::string(dictBuffer, dictBufferSize));
+                }
+                catch (std::runtime_error& err) {
+                    LOG(Helper::LogLevel::LL_Error, "Failed to read head info file: %s \n", err.what());
+                    exit(1);
+                }
+                delete[] dictBuffer;
+            }
+
             LOG(Helper::LogLevel::LL_Info,
                 "Finish reading header info, list count %d, total doc count %d, dimension %d, list page offset %d.\n",
                 m_listCount,
                 m_totalDocumentCount,
                 m_iDataDimension,
                 m_listPageOffset);

-
             LOG(Helper::LogLevel::LL_Info,
                 "Big page (>4K): list count %zu, total element count %zu.\n",
                 biglistCount,
@@ -625,14 +862,14 @@
             return m_listCount;
         }

-        void SelectPostingOffset(size_t p_spacePerVector,
-            const std::vector&
p_postingListSizes, + void SelectPostingOffset( + const std::vector& p_postingListBytes, std::unique_ptr& p_postPageNum, std::unique_ptr& p_postPageOffset, std::vector& p_postingOrderInIndex) { - p_postPageNum.reset(new int[p_postingListSizes.size()]); - p_postPageOffset.reset(new std::uint16_t[p_postingListSizes.size()]); + p_postPageNum.reset(new int[p_postingListBytes.size()]); + p_postPageOffset.reset(new std::uint16_t[p_postingListBytes.size()]); struct PageModWithID { @@ -652,18 +889,18 @@ namespace SPTAG std::set listRestSize; p_postingOrderInIndex.clear(); - p_postingOrderInIndex.reserve(p_postingListSizes.size()); + p_postingOrderInIndex.reserve(p_postingListBytes.size()); PageModWithID listInfo; - for (size_t i = 0; i < p_postingListSizes.size(); ++i) + for (size_t i = 0; i < p_postingListBytes.size(); ++i) { - if (p_postingListSizes[i] == 0) + if (p_postingListBytes[i] == 0) { continue; } listInfo.id = static_cast(i); - listInfo.rest = static_cast((p_spacePerVector * p_postingListSizes[i]) % PageSize); + listInfo.rest = static_cast(p_postingListBytes[i] % PageSize); listRestSize.insert(listInfo); } @@ -676,8 +913,8 @@ namespace SPTAG while (!listRestSize.empty()) { listInfo.rest = PageSize - currOffset; - auto iter = listRestSize.lower_bound(listInfo); - if (iter == listRestSize.end()) + auto iter = listRestSize.lower_bound(listInfo); // avoid page-crossing + if (iter == listRestSize.end() || (listInfo.rest != PageSize && iter->rest == 0)) { ++currPageNum; currOffset = 0; @@ -702,7 +939,7 @@ namespace SPTAG currOffset = 0; } - currPageNum += static_cast((p_spacePerVector * p_postingListSizes[iter->id]) / PageSize); + currPageNum += static_cast(p_postingListBytes[iter->id] / PageSize); listRestSize.erase(iter); } @@ -711,10 +948,15 @@ namespace SPTAG LOG(Helper::LogLevel::LL_Info, "TotalPageNumbers: %d, IndexSize: %llu\n", currPageNum, static_cast(currPageNum) * PageSize + currOffset); } - void OutputSSDIndexFile(const std::string& p_outputFile, + bool m_enableDeltaEncoding, + bool m_enablePostingListRearrange, + bool m_enableDataCompression, + bool m_enableDictTraining, size_t p_spacePerVector, const std::vector& p_postingListSizes, + const std::vector& p_postingListBytes, + std::shared_ptr p_headIndex, Selection& p_postingSelections, const std::unique_ptr& p_postPageNum, const std::unique_ptr& p_postPageOffset, @@ -728,9 +970,10 @@ namespace SPTAG auto ptr = SPTAG::f_createIO(); int retry = 3; + // open file while (retry > 0 && (ptr == nullptr || !ptr->Initialize(p_outputFile.c_str(), std::ios::binary | std::ios::out))) { - LOG(Helper::LogLevel::LL_Error, "Failed open file %s\n", p_outputFile.c_str()); + LOG(Helper::LogLevel::LL_Error, "Failed open file %s, retrying...\n", p_outputFile.c_str()); retry--; } @@ -738,13 +981,26 @@ namespace SPTAG LOG(Helper::LogLevel::LL_Error, "Failed open file %s\n", p_outputFile.c_str()); exit(1); } - + // meta size of global info std::uint64_t listOffset = sizeof(int) * 4; + // meta size of the posting lists listOffset += (sizeof(int) + sizeof(std::uint16_t) + sizeof(int) + sizeof(std::uint16_t)) * p_postingListSizes.size(); + // write listTotalBytes only when enabled data compression + if (m_enableDataCompression) + { + listOffset += sizeof(size_t) * p_postingListSizes.size(); + } + + // compression dict + if (m_enableDataCompression && m_enableDictTraining) + { + listOffset += sizeof(size_t); + listOffset += m_pCompressor->GetDictBuffer().size(); + } std::unique_ptr paddingVals(new char[PageSize]); memset(paddingVals.get(), 0, sizeof(char) * 
PageSize); - + // paddingSize: bytes left in the last page std::uint64_t paddingSize = PageSize - (listOffset % PageSize); if (paddingSize == PageSize) { @@ -755,37 +1011,39 @@ namespace SPTAG listOffset += paddingSize; } - // Number of lists. + // Number of posting lists int i32Val = static_cast(p_postingListSizes.size()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } - // Number of all documents. + // Number of vectors i32Val = static_cast(p_fullVectors->Count()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } - // Bytes of each vector. + // Vector dimension i32Val = static_cast(p_fullVectors->Dimension()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } - // Page offset of list content section. + // Page offset of list content section i32Val = static_cast(listOffset / PageSize); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } + // Meta of each posting list for (int i = 0; i < p_postingListSizes.size(); ++i) { - int pageNum = 0; + size_t postingListByte = 0; + int pageNum = 0; // starting page number std::uint16_t pageOffset = 0; int listEleCount = 0; std::uint16_t listPageCount = 0; @@ -795,30 +1053,59 @@ namespace SPTAG pageNum = p_postPageNum[i]; pageOffset = static_cast(p_postPageOffset[i]); listEleCount = static_cast(p_postingListSizes[i]); - listPageCount = static_cast((p_spacePerVector * p_postingListSizes[i]) / PageSize); - if (0 != ((p_spacePerVector * p_postingListSizes[i]) % PageSize)) + postingListByte = p_postingListBytes[i]; + listPageCount = static_cast(postingListByte / PageSize); + if (0 != (postingListByte % PageSize)) { ++listPageCount; } } + // Total bytes of the posting list, write only when enabled data compression + if (m_enableDataCompression && ptr->WriteBinary(sizeof(postingListByte), reinterpret_cast(&postingListByte)) != sizeof(postingListByte)) { + LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + exit(1); + } + // Page number of the posting list if (ptr->WriteBinary(sizeof(pageNum), reinterpret_cast(&pageNum)) != sizeof(pageNum)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } + // Page offset if (ptr->WriteBinary(sizeof(pageOffset), reinterpret_cast(&pageOffset)) != sizeof(pageOffset)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } + // Number of vectors in the posting list if (ptr->WriteBinary(sizeof(listEleCount), reinterpret_cast(&listEleCount)) != sizeof(listEleCount)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } + // Page count of the posting list if (ptr->WriteBinary(sizeof(listPageCount), reinterpret_cast(&listPageCount)) != sizeof(listPageCount)) { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } } + // compression dict + if (m_enableDataCompression && m_enableDictTraining) + { + std::string dictBuffer = m_pCompressor->GetDictBuffer(); + // dict size + size_t dictBufferSize = dictBuffer.size(); + if (ptr->WriteBinary(sizeof(size_t), reinterpret_cast(&dictBufferSize)) != sizeof(dictBufferSize)) + { + LOG(Helper::LogLevel::LL_Error, "Failed to write 
SSDIndex File!"); + exit(1); + } + // dict + if (ptr->WriteBinary(dictBuffer.size(), const_cast(dictBuffer.data())) != dictBuffer.size()) + { + LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + exit(1); + } + } + // Write padding vals if (paddingSize > 0) { if (ptr->WriteBinary(paddingSize, reinterpret_cast(paddingVals.get())) != paddingSize) { @@ -838,6 +1125,7 @@ namespace SPTAG listOffset = 0; std::uint64_t paddedSize = 0; + // iterate over all the posting lists for (auto id : p_postingOrderInIndex) { std::uint64_t targetOffset = static_cast(p_postPageNum[id]) * PageSize + p_postPageOffset[id]; @@ -846,7 +1134,7 @@ namespace SPTAG LOG(Helper::LogLevel::LL_Info, "List offset not match, targetOffset < listOffset!\n"); exit(1); } - + // write padding vals before the posting list if (targetOffset > listOffset) { if (targetOffset - listOffset > PageSize) @@ -865,25 +1153,49 @@ namespace SPTAG listOffset = targetOffset; } - std::size_t selectIdx = p_postingSelections.lower_bound(id + (int)p_postingListOffset); - for (int j = 0; j < p_postingListSizes[id]; ++j) + if (p_postingListSizes[id] == 0) { - if (p_postingSelections[selectIdx].node != id + (int)p_postingListOffset) + continue; + } + int postingListId = id + (int)p_postingListOffset; + // get posting list full content and write it at once + ValueType *headVector = nullptr; + if (m_enableDeltaEncoding) + { + headVector = (ValueType *)p_headIndex->GetSample(postingListId); + } + std::string postingListFullData = GetPostingListFullData( + postingListId, p_postingListSizes[id], p_postingSelections, p_fullVectors, m_enableDeltaEncoding, m_enablePostingListRearrange, headVector); + size_t postingListFullSize = p_postingListSizes[id] * p_spacePerVector; + if (postingListFullSize != postingListFullData.size()) + { + LOG(Helper::LogLevel::LL_Error, "posting list full data size NOT MATCH! postingListFullData.size(): %zu postingListFullSize: %zu \n", postingListFullData.size(), postingListFullSize); + exit(1); + } + if (m_enableDataCompression) + { + std::string compressedData = m_pCompressor->Compress(postingListFullData, m_enableDictTraining); + size_t compressedSize = compressedData.size(); + if (compressedSize != p_postingListBytes[id]) { - LOG(Helper::LogLevel::LL_Error, "Selection ID NOT MATCH! node:%d offset:%zu\n", id + (int)p_postingListOffset, selectIdx); + LOG(Helper::LogLevel::LL_Error, "Compressed size NOT MATCH! 
compressed size:%zu, pre-calculated compressed size:%zu\n", compressedSize, p_postingListBytes[id]); exit(1); } - - i32Val = p_postingSelections[selectIdx++].tonode; - if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { + if (ptr->WriteBinary(compressedSize, compressedData.data()) != compressedSize) + { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } - if (ptr->WriteBinary(p_fullVectors->PerVectorDataSize(), reinterpret_cast(p_fullVectors->GetVector(i32Val))) != p_fullVectors->PerVectorDataSize()) { + listOffset += compressedSize; + } + else + { + if (ptr->WriteBinary(postingListFullSize, postingListFullData.data()) != postingListFullSize) + { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } - listOffset += p_spacePerVector; + listOffset += postingListFullSize; } } @@ -900,7 +1212,8 @@ namespace SPTAG if (paddingSize > 0) { - if (ptr->WriteBinary(paddingSize, reinterpret_cast(paddingVals.get())) != paddingSize) { + if (ptr->WriteBinary(paddingSize, reinterpret_cast(paddingVals.get())) != paddingSize) + { LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); exit(1); } @@ -920,6 +1233,11 @@ namespace SPTAG std::vector> m_listInfos; std::vector> m_indexFiles; + std::unique_ptr m_pCompressor; + bool m_enableDeltaEncoding; + bool m_enablePostingListRearrange; + bool m_enableDataCompression; + bool m_enableDictTraining; int m_vectorInfoSize = 0; diff --git a/AnnService/inc/Core/SPANN/IExtraSearcher.h b/AnnService/inc/Core/SPANN/IExtraSearcher.h index 3063d8cc..8b5d5d07 100644 --- a/AnnService/inc/Core/SPANN/IExtraSearcher.h +++ b/AnnService/inc/Core/SPANN/IExtraSearcher.h @@ -108,11 +108,11 @@ namespace SPTAG { ~ExtraWorkSpace() {} ExtraWorkSpace(ExtraWorkSpace& other) { - Initialize(other.m_deduper.MaxCheck(), other.m_deduper.HashTableExponent(), (int)other.m_pageBuffers.size(), (int)(other.m_pageBuffers[0].GetPageSize())); + Initialize(other.m_deduper.MaxCheck(), other.m_deduper.HashTableExponent(), (int)other.m_pageBuffers.size(), (int)(other.m_pageBuffers[0].GetPageSize()), other.m_enableDataCompression); m_spaceID = g_spaceCount++; } - void Initialize(int p_maxCheck, int p_hashExp, int p_internalResultNum, int p_maxPages) { + void Initialize(int p_maxCheck, int p_hashExp, int p_internalResultNum, int p_maxPages, bool enableDataCompression) { m_postingIDs.reserve(p_internalResultNum); m_deduper.Init(p_maxCheck, p_hashExp); m_processIocp.reset(p_internalResultNum); @@ -121,6 +121,10 @@ namespace SPTAG { m_pageBuffers[pi].ReservePageBuffer(p_maxPages); } m_diskRequests.resize(p_internalResultNum); + m_enableDataCompression = enableDataCompression; + if (enableDataCompression) { + m_decompressBuffer.ReservePageBuffer(p_maxPages); + } } void Initialize(va_list& arg) { @@ -128,7 +132,8 @@ namespace SPTAG { int hashExp = va_arg(arg, int); int internalResultNum = va_arg(arg, int); int maxPages = va_arg(arg, int); - Initialize(maxCheck, hashExp, internalResultNum, maxPages); + bool enableDataCompression = bool(va_arg(arg, int)); + Initialize(maxCheck, hashExp, internalResultNum, maxPages, enableDataCompression); } static void Reset() { g_spaceCount = 0; } @@ -141,6 +146,9 @@ namespace SPTAG { std::vector> m_pageBuffers; + bool m_enableDataCompression; + PageBuffer m_decompressBuffer; + std::vector m_diskRequests; int m_spaceID; @@ -155,7 +163,6 @@ namespace SPTAG { { } - virtual ~IExtraSearcher() { } @@ -165,7 +172,9 @@ namespace SPTAG { virtual void SearchIndex(ExtraWorkSpace* p_exWorkSpace, 
QueryResult& p_queryResults, std::shared_ptr p_index, - SearchStats* p_stats, std::set* truth = nullptr, std::map>* found = nullptr) = 0; + SearchStats* p_stats, + std::set* truth = nullptr, + std::map>* found = nullptr) = 0; virtual bool BuildIndex(std::shared_ptr& p_reader, std::shared_ptr p_index, diff --git a/AnnService/inc/Core/SPANN/Options.h b/AnnService/inc/Core/SPANN/Options.h index 133d51a4..7c726c0a 100644 --- a/AnnService/inc/Core/SPANN/Options.h +++ b/AnnService/inc/Core/SPANN/Options.h @@ -79,6 +79,13 @@ namespace SPTAG { bool m_enableSSD; bool m_buildSsdIndex; int m_iSSDNumberOfThreads; + bool m_enableDeltaEncoding; + bool m_enablePostingListRearrange; + bool m_enableDataCompression; + bool m_enableDictTraining; + int m_minDictTraingBufferSize; + int m_dictBufferCapacity; + int m_zstdCompressLevel; // Building int m_replicaCount; diff --git a/AnnService/inc/Core/SPANN/ParameterDefinitionList.h b/AnnService/inc/Core/SPANN/ParameterDefinitionList.h index 72f38fe1..7e3704d6 100644 --- a/AnnService/inc/Core/SPANN/ParameterDefinitionList.h +++ b/AnnService/inc/Core/SPANN/ParameterDefinitionList.h @@ -77,6 +77,13 @@ DefineBuildHeadParameter(m_buildHead, bool, false, "isExecute") DefineSSDParameter(m_enableSSD, bool, false, "isExecute") DefineSSDParameter(m_buildSsdIndex, bool, false, "BuildSsdIndex") DefineSSDParameter(m_iSSDNumberOfThreads, int, 16, "NumberOfThreads") +DefineSSDParameter(m_enableDeltaEncoding, bool, false, "EnableDeltaEncoding") +DefineSSDParameter(m_enablePostingListRearrange, bool, false, "EnablePostingListRearrange") +DefineSSDParameter(m_enableDataCompression, bool, false, "EnableDataCompression") +DefineSSDParameter(m_enableDictTraining, bool, true, "EnableDictTraining") +DefineSSDParameter(m_minDictTraingBufferSize, int, 10240000, "MinDictTrainingBufferSize") +DefineSSDParameter(m_dictBufferCapacity, int, 204800, "DictBufferCapacity") +DefineSSDParameter(m_zstdCompressLevel, int, 0, "ZstdCompressLevel") // Building DefineSSDParameter(m_internalResultNum, int, 64, "InternalResultNum") diff --git a/AnnService/packages.config b/AnnService/packages.config index 991c8539..4ad04937 100644 --- a/AnnService/packages.config +++ b/AnnService/packages.config @@ -2,10 +2,11 @@ + - + \ No newline at end of file diff --git a/AnnService/src/Core/SPANN/SPANNIndex.cpp b/AnnService/src/Core/SPANN/SPANNIndex.cpp index 84efc85c..a31fa80f 100644 --- a/AnnService/src/Core/SPANN/SPANNIndex.cpp +++ b/AnnService/src/Core/SPANN/SPANNIndex.cpp @@ -102,7 +102,7 @@ namespace SPTAG omp_set_num_threads(m_options.m_iSSDNumberOfThreads); m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx); + m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression)); return ErrorCode::Success; } @@ -134,7 +134,7 @@ namespace SPTAG omp_set_num_threads(m_options.m_iSSDNumberOfThreads); m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx); + 
m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression));
             return ErrorCode::Success;
         }
@@ -286,7 +286,7 @@
                     auto_ws->m_postingIDs.emplace_back(res->VID);
                 }

-                m_extraSearcher->SearchIndex(auto_ws.get(), newResults, m_index, p_stats, truth, found);
+                m_extraSearcher->SearchIndex(auto_ws.get(), newResults, m_index, p_stats, truth, found);
             }

             m_workSpacePool->Return(auto_ws);
@@ -714,7 +714,7 @@
             }

             m_workSpacePool.reset(new COMMON::WorkSpacePool());
-            m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx);
+            m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression));
             m_bReady = true;
             return ErrorCode::Success;
         }
@@ -766,7 +766,7 @@
             omp_set_num_threads(m_options.m_iSSDNumberOfThreads);
             m_index->UpdateIndex();
             m_workSpacePool.reset(new COMMON::WorkSpacePool());
-            m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx);
+            m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression));
             return ErrorCode::Success;
         }
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 96713278..f1b97a45 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -92,6 +92,9 @@ endif()

 option(GPU "GPU" ON)
 option(LIBRARYONLY "LIBRARYONLY" OFF)
+
+add_subdirectory (ThirdParty/zstd/build/cmake)
+
 add_subdirectory (AnnService)
 add_subdirectory (Test)
 add_subdirectory (GPUSupport)
diff --git a/Dockerfile b/Dockerfile
index 576fa9a8..00ec1f84 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -14,5 +14,6 @@ COPY AnnService ./AnnService/
 COPY Test ./Test/
 COPY Wrappers ./Wrappers/
 COPY GPUSupport ./GPUSupport/
+COPY ThirdParty ./ThirdParty/

-RUN mkdir build && cd build && cmake .. && make -j && cd ..
+RUN mkdir build && cd build && cmake .. && make -j$(nproc) && cd ..
diff --git a/Dockerfile.cuda b/Dockerfile.cuda
index 6e984fdf..bfd26510 100644
--- a/Dockerfile.cuda
+++ b/Dockerfile.cuda
@@ -14,5 +14,6 @@ COPY AnnService ./AnnService/
 COPY Test ./Test/
 COPY Wrappers ./Wrappers/
 COPY GPUSupport ./GPUSupport/
+COPY ThirdParty ./ThirdParty/

-RUN mkdir build && cd build && cmake .. && make -j && cd ..
+RUN mkdir build && cd build && cmake .. && make -j$(nproc) && cd ..
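Usage note: the new SSD-index options registered in ParameterDefinitionList.h above are read from the index-build configuration file. A minimal sketch of a build section that turns the feature on; the section name follows the usual SSDServing INI layout and the values are illustrative, only the key names and defaults come from this change:

[BuildSSDIndex]
BuildSsdIndex=true
EnableDeltaEncoding=true
EnablePostingListRearrange=true
EnableDataCompression=true
EnableDictTraining=true
MinDictTrainingBufferSize=10240000
DictBufferCapacity=204800
ZstdCompressLevel=0

ZstdCompressLevel=0 keeps zstd's default level. EnableDictTraining only matters when EnableDataCompression is on, since the trained dictionary is written into the SSD index header by OutputSSDIndexFile and reloaded by LoadingHeadInfo at search time.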
diff --git a/SPTAG.nuspec b/SPTAG.nuspec index 0dd57528..b739a1b7 100644 --- a/SPTAG.nuspec +++ b/SPTAG.nuspec @@ -48,6 +48,7 @@ + diff --git a/Test/Test.vcxproj b/Test/Test.vcxproj index 68e67dac..cba1cb90 100644 --- a/Test/Test.vcxproj +++ b/Test/Test.vcxproj @@ -176,6 +176,7 @@ + @@ -190,5 +191,6 @@ + \ No newline at end of file diff --git a/Test/packages.config b/Test/packages.config index 27713806..8adb5ddb 100644 --- a/Test/packages.config +++ b/Test/packages.config @@ -9,4 +9,5 @@ + \ No newline at end of file diff --git a/ThirdParty/zstd b/ThirdParty/zstd new file mode 160000 index 00000000..e47e674c --- /dev/null +++ b/ThirdParty/zstd @@ -0,0 +1 @@ +Subproject commit e47e674cd09583ff0503f0f6defd6d23d8b718d3 diff --git a/Wrappers/CLRCore.vcxproj b/Wrappers/CLRCore.vcxproj index 9aa9db77..dcbf66bf 100644 --- a/Wrappers/CLRCore.vcxproj +++ b/Wrappers/CLRCore.vcxproj @@ -150,6 +150,7 @@ + @@ -162,5 +163,6 @@ + \ No newline at end of file diff --git a/Wrappers/CsharpClient.vcxproj b/Wrappers/CsharpClient.vcxproj index ed558eb7..1e48b766 100644 --- a/Wrappers/CsharpClient.vcxproj +++ b/Wrappers/CsharpClient.vcxproj @@ -169,6 +169,7 @@ + @@ -189,5 +190,6 @@ + \ No newline at end of file diff --git a/Wrappers/CsharpCore.vcxproj b/Wrappers/CsharpCore.vcxproj index 692d418e..eea06963 100644 --- a/Wrappers/CsharpCore.vcxproj +++ b/Wrappers/CsharpCore.vcxproj @@ -54,6 +54,7 @@ + @@ -132,5 +133,6 @@ This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + \ No newline at end of file diff --git a/Wrappers/PythonClient.vcxproj b/Wrappers/PythonClient.vcxproj index afdfb14e..25f5b729 100644 --- a/Wrappers/PythonClient.vcxproj +++ b/Wrappers/PythonClient.vcxproj @@ -166,6 +166,7 @@ + @@ -184,5 +185,6 @@ + \ No newline at end of file diff --git a/Wrappers/PythonCore.vcxproj b/Wrappers/PythonCore.vcxproj index 64fdc8be..db9b1184 100644 --- a/Wrappers/PythonCore.vcxproj +++ b/Wrappers/PythonCore.vcxproj @@ -117,6 +117,7 @@ + @@ -130,5 +131,6 @@ + \ No newline at end of file diff --git a/Wrappers/packages.config b/Wrappers/packages.config index 784b338f..d4000416 100644 --- a/Wrappers/packages.config +++ b/Wrappers/packages.config @@ -9,4 +9,5 @@ + \ No newline at end of file
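For reviewers unfamiliar with the zstd dictionary API: below is a self-contained sketch of the same train/compress/decompress flow that Compressor.h wraps (TrainDict, CreateCDict/CreateDDict, CompressWithDict, DecompressWithDict), using only the public zstd/zdict headers. The sample corpus is synthetic and the buffer sizes are illustrative; in this PR the samples are real posting lists and the decompress capacity is the known listEleCount * vectorInfoSize.

// Editorial sketch, not part of this patch: zstd dictionary round trip.
#include <cstdio>
#include <string>
#include <vector>
#include <zstd.h>
#include <zdict.h>

int main()
{
    // Training corpus: many small samples concatenated, plus their sizes,
    // exactly the shape TrainDict expects. ZDICT_trainFromBuffer can fail
    // if the corpus is too small or too uniform; real use feeds posting lists.
    std::string samples;
    std::vector<size_t> sampleSizes;
    for (int i = 0; i < 2000; ++i)
    {
        std::string s(128, static_cast<char>('a' + i % 7));
        samples += s;
        sampleSizes.push_back(s.size());
    }

    std::string dict(16 * 1024, '\0'); // cf. DictBufferCapacity
    size_t dictSize = ZDICT_trainFromBuffer(&dict[0], dict.size(), samples.data(),
                                            sampleSizes.data(), (unsigned)sampleSizes.size());
    if (ZDICT_isError(dictSize))
    {
        std::fprintf(stderr, "ZDICT_trainFromBuffer: %s\n", ZDICT_getErrorName(dictSize));
        return 1;
    }
    dict.resize(dictSize);

    // Digested dictionaries, created once and reused for every posting list.
    ZSTD_CDict* cdict = ZSTD_createCDict(dict.data(), dict.size(), 0); // 0 = default level
    ZSTD_DDict* ddict = ZSTD_createDDict(dict.data(), dict.size());
    if (cdict == NULL || ddict == NULL) return 1;

    // Compress one "posting list" with the shared dictionary.
    std::string src(128, 'c');
    std::string dst(ZSTD_compressBound(src.size()), '\0');
    ZSTD_CCtx* cctx = ZSTD_createCCtx();
    size_t csize = ZSTD_compress_usingCDict(cctx, &dst[0], dst.size(),
                                            src.data(), src.size(), cdict);
    if (ZSTD_isError(csize)) { std::fprintf(stderr, "%s\n", ZSTD_getErrorName(csize)); return 1; }

    // Decompress into a caller-owned buffer of known capacity, as the
    // searcher does with its per-workspace m_decompressBuffer.
    std::string out(src.size(), '\0');
    ZSTD_DCtx* dctx = ZSTD_createDCtx();
    size_t dsize = ZSTD_decompress_usingDDict(dctx, &out[0], out.size(),
                                              dst.data(), csize, ddict);
    if (ZSTD_isError(dsize)) { std::fprintf(stderr, "%s\n", ZSTD_getErrorName(dsize)); return 1; }

    std::printf("raw %zu -> compressed %zu -> decompressed %zu, match=%d\n",
                src.size(), csize, dsize, out == src);

    ZSTD_freeCCtx(cctx);
    ZSTD_freeDCtx(dctx);
    ZSTD_freeCDict(cdict);
    ZSTD_freeDDict(ddict);
    return 0;
}

The shared dictionary is what makes this design pay off: each posting list is compressed independently so it can be paged in alone at search time, and zstd compresses small independent blobs poorly without a dictionary.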