diff --git a/benchs/bench_fw/optimize.py b/benchs/bench_fw/optimize.py index 473436ea68..a2653b7144 100644 --- a/benchs/bench_fw/optimize.py +++ b/benchs/bench_fw/optimize.py @@ -226,6 +226,7 @@ def optimize_codec( [ (None, "Flat"), (None, "SQfp16"), + (None, "SQbf16"), (None, "SQ8"), ] + [ (f"OPQ{M}_{M * dim}", f"PQ{M}x{b}") diff --git a/c_api/IndexScalarQuantizer_c.h b/c_api/IndexScalarQuantizer_c.h index 2c5e3f2942..87fe6d3415 100644 --- a/c_api/IndexScalarQuantizer_c.h +++ b/c_api/IndexScalarQuantizer_c.h @@ -26,6 +26,7 @@ typedef enum FaissQuantizerType { QT_fp16, QT_8bit_direct, ///< fast indexing of uint8s QT_6bit, ///< 6 bits per component + QT_bf16, } FaissQuantizerType; // forward declaration diff --git a/contrib/factory_tools.py b/contrib/factory_tools.py index 745dc7f7ff..cfad7c7b5c 100644 --- a/contrib/factory_tools.py +++ b/contrib/factory_tools.py @@ -56,6 +56,8 @@ def get_code_size(d, indexkey): return (d * 6 + 7) // 8 elif indexkey == 'SQfp16': return d * 2 + elif indexkey == 'SQbf16': + return d * 2 mo = re.match('PCAR?(\\d+),(.*)$', indexkey) if mo: @@ -140,6 +142,7 @@ def reverse_index_factory(index): faiss.ScalarQuantizer.QT_4bit: "4", faiss.ScalarQuantizer.QT_6bit: "6", faiss.ScalarQuantizer.QT_fp16: "fp16", + faiss.ScalarQuantizer.QT_bf16: "bf16", } return f"SQ{sqtypes[index.sq.qtype]}" diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 33e1849568..1b0860f3fb 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -183,6 +183,7 @@ set(FAISS_HEADERS invlists/InvertedLists.h invlists/InvertedListsIOHook.h utils/AlignedTable.h + utils/bf16.h utils/Heap.h utils/WorkerThread.h utils/distances.h diff --git a/faiss/IndexScalarQuantizer.cpp b/faiss/IndexScalarQuantizer.cpp index 9203a98932..7ce838db5e 100644 --- a/faiss/IndexScalarQuantizer.cpp +++ b/faiss/IndexScalarQuantizer.cpp @@ -32,7 +32,8 @@ IndexScalarQuantizer::IndexScalarQuantizer( MetricType metric) : IndexFlatCodes(0, d, metric), sq(d, qtype) { is_trained = qtype == ScalarQuantizer::QT_fp16 || - qtype == ScalarQuantizer::QT_8bit_direct; + qtype == ScalarQuantizer::QT_8bit_direct || + qtype == ScalarQuantizer::QT_bf16; code_size = sq.code_size; } diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index e3b29e621d..7ad50189e4 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -496,6 +497,72 @@ struct QuantizerFP16<8> : QuantizerFP16<1> { }; #endif +/******************************************************************* + * BF16 quantizer + *******************************************************************/ + +template +struct QuantizerBF16 {}; + +template <> +struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { + const size_t d; + + QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + ((uint16_t*)code)[i] = encode_bf16(x[i]); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = decode_bf16(((uint16_t*)code)[i]); + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return decode_bf16(((uint16_t*)code)[i]); + } +}; + +#ifdef __AVX2__ + +template <> +struct QuantizerBF16<8> : QuantizerBF16<1> { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE __m256 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); + code_256i = _mm256_slli_epi32(code_256i, 16); + return _mm256_castsi256_ps(code_256i); + } +}; + +#endif + +#ifdef __aarch64__ + +template <> +struct QuantizerBF16<8> : QuantizerBF16<1> { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE float32x4x2_t + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), + vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}; + } +}; +#endif + /******************************************************************* * 8bit_direct quantizer *******************************************************************/ @@ -589,6 +656,8 @@ ScalarQuantizer::SQuantizer* select_quantizer_1( d, trained); case ScalarQuantizer::QT_fp16: return new QuantizerFP16(d, trained); + case ScalarQuantizer::QT_bf16: + return new QuantizerBF16(d, trained); case ScalarQuantizer::QT_8bit_direct: return new Quantizer8bitDirect(d, trained); } @@ -1378,6 +1447,10 @@ SQDistanceComputer* select_distance_computer( return new DCTemplate, Sim, SIMDWIDTH>( d, trained); + case ScalarQuantizer::QT_bf16: + return new DCTemplate, Sim, SIMDWIDTH>( + d, trained); + case ScalarQuantizer::QT_8bit_direct: if (d % 16 == 0) { return new DistanceComputerByte(d, trained); @@ -1426,6 +1499,10 @@ void ScalarQuantizer::set_derived_sizes() { code_size = d * 2; bits = 16; break; + case QT_bf16: + code_size = d * 2; + bits = 16; + break; } } @@ -1462,6 +1539,7 @@ void ScalarQuantizer::train(size_t n, const float* x) { break; case QT_fp16: case QT_8bit_direct: + case QT_bf16: // no training necessary break; } @@ -1791,6 +1869,11 @@ InvertedListScanner* sel1_InvertedListScanner( QuantizerFP16, Similarity, SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_bf16: + return sel2_InvertedListScanner, + Similarity, + SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_8bit_direct: if (sq->d % 16 == 0) { return sel2_InvertedListScanner< diff --git a/faiss/impl/ScalarQuantizer.h b/faiss/impl/ScalarQuantizer.h index 550a979092..49fd42cc31 100644 --- a/faiss/impl/ScalarQuantizer.h +++ b/faiss/impl/ScalarQuantizer.h @@ -32,6 +32,7 @@ struct ScalarQuantizer : Quantizer { QT_fp16, QT_8bit_direct, ///< fast indexing of uint8s QT_6bit, ///< 6 bits per component + QT_bf16, }; QuantizerType qtype = QT_8bit; diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp index 0d61b73ecd..d88fe7b393 100644 --- a/faiss/index_factory.cpp +++ b/faiss/index_factory.cpp @@ -140,8 +140,9 @@ std::map sq_types = { {"SQ4", ScalarQuantizer::QT_4bit}, {"SQ6", ScalarQuantizer::QT_6bit}, {"SQfp16", ScalarQuantizer::QT_fp16}, + {"SQbf16", ScalarQuantizer::QT_bf16}, }; -const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)"; +const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16|SQbf16)"; std::map aq_search_type = { {"_Nfloat", AdditiveQuantizer::ST_norm_float}, diff --git a/faiss/utils/bf16.h b/faiss/utils/bf16.h new file mode 100644 index 0000000000..ff0fbe898b --- /dev/null +++ b/faiss/utils/bf16.h @@ -0,0 +1,36 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { + +namespace { + +union fp32_bits { + uint32_t as_u32; + float as_f32; +}; + +} // namespace + +inline uint16_t encode_bf16(const float f) { + // Round off + fp32_bits fp; + fp.as_f32 = f; + return static_cast((fp.as_u32 + 0x8000) >> 16); +} + +inline float decode_bf16(const uint16_t v) { + fp32_bits fp; + fp.as_u32 = (uint32_t(v) << 16); + return fp.as_f32; +} + +} // namespace faiss diff --git a/tests/test_index.py b/tests/test_index.py index b9f3dbd46b..43db906e47 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -327,7 +327,7 @@ def test_4variants_ivf(self): D, I = index.search(xq, 10) nok['flat'] = (I[:, 0] == I_ref[:, 0]).sum() - for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split(): + for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16 QT_bf16".split(): qtype = getattr(faiss.ScalarQuantizer, qname) index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent, qtype, faiss.METRIC_L2) @@ -349,6 +349,7 @@ def test_4variants_ivf(self): self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform']) self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform']) self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit']) + self.assertGreaterEqual(nok['QT_bf16'], nok['QT_8bit']) def test_4variants(self): d = 32 @@ -364,7 +365,7 @@ def test_4variants(self): nok = {} - for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split(): + for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16 QT_bf16".split(): qtype = getattr(faiss.ScalarQuantizer, qname) index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2) index.train(xt) @@ -377,6 +378,7 @@ def test_4variants(self): self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform']) self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform']) self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit']) + self.assertGreaterEqual(nok['QT_bf16'], nq * 0.9) class TestRangeSearch(unittest.TestCase): diff --git a/tests/test_standalone_codec.py b/tests/test_standalone_codec.py index 2176a12e99..391b88b9dd 100644 --- a/tests/test_standalone_codec.py +++ b/tests/test_standalone_codec.py @@ -173,6 +173,9 @@ def test_SQ2(self): def test_SQ3(self): self.compare_accuracy('SQ8', 'SQfp16') + def test_SQ4(self): + self.compare_accuracy('SQ8', 'SQbf16') + def test_PQ(self): self.compare_accuracy('PQ6x8np', 'PQ8x8np')