Skip to content

Commit

Permalink
QT_bf16 for scalar quantizer for bfloat16 (#3444)
Browse files Browse the repository at this point in the history
Summary:
mdouze Please let me know if any additional unit tests are needed

Pull Request resolved: #3444

Reviewed By: algoriddle

Differential Revision: D57665641

Pulled By: mdouze

fbshipit-source-id: 9bec91306a1c31ea4f1f1d726c9d60ac6415fdfc
  • Loading branch information
alexanderguzhva authored and facebook-github-bot committed May 23, 2024
1 parent 414fd1e commit 6a94c67
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 4 deletions.
1 change: 1 addition & 0 deletions benchs/bench_fw/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def optimize_codec(
[
(None, "Flat"),
(None, "SQfp16"),
(None, "SQbf16"),
(None, "SQ8"),
] + [
(f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
Expand Down
1 change: 1 addition & 0 deletions c_api/IndexScalarQuantizer_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ typedef enum FaissQuantizerType {
QT_fp16,
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
} FaissQuantizerType;

// forward declaration
Expand Down
3 changes: 3 additions & 0 deletions contrib/factory_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def get_code_size(d, indexkey):
return (d * 6 + 7) // 8
elif indexkey == 'SQfp16':
return d * 2
elif indexkey == 'SQbf16':
return d * 2

mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
if mo:
Expand Down Expand Up @@ -140,6 +142,7 @@ def reverse_index_factory(index):
faiss.ScalarQuantizer.QT_4bit: "4",
faiss.ScalarQuantizer.QT_6bit: "6",
faiss.ScalarQuantizer.QT_fp16: "fp16",
faiss.ScalarQuantizer.QT_bf16: "bf16",
}
return f"SQ{sqtypes[index.sq.qtype]}"

Expand Down
1 change: 1 addition & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ set(FAISS_HEADERS
invlists/InvertedLists.h
invlists/InvertedListsIOHook.h
utils/AlignedTable.h
utils/bf16.h
utils/Heap.h
utils/WorkerThread.h
utils/distances.h
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ IndexScalarQuantizer::IndexScalarQuantizer(
MetricType metric)
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
qtype == ScalarQuantizer::QT_8bit_direct;
qtype == ScalarQuantizer::QT_8bit_direct ||
qtype == ScalarQuantizer::QT_bf16;
code_size = sq.code_size;
}

Expand Down
83 changes: 83 additions & 0 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/bf16.h>
#include <faiss/utils/fp16.h>
#include <faiss/utils/utils.h>

Expand Down Expand Up @@ -496,6 +497,72 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
};
#endif

/*******************************************************************
* BF16 quantizer
*******************************************************************/

template <int SIMDWIDTH>
struct QuantizerBF16 {};

template <>
struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
const size_t d;

QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}

void encode_vector(const float* x, uint8_t* code) const final {
for (size_t i = 0; i < d; i++) {
((uint16_t*)code)[i] = encode_bf16(x[i]);
}
}

void decode_vector(const uint8_t* code, float* x) const final {
for (size_t i = 0; i < d; i++) {
x[i] = decode_bf16(((uint16_t*)code)[i]);
}
}

FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
const {
return decode_bf16(((uint16_t*)code)[i]);
}
};

#ifdef __AVX2__

template <>
struct QuantizerBF16<8> : QuantizerBF16<1> {
QuantizerBF16(size_t d, const std::vector<float>& trained)
: QuantizerBF16<1>(d, trained) {}

FAISS_ALWAYS_INLINE __m256
reconstruct_8_components(const uint8_t* code, int i) const {
__m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i));
__m256i code_256i = _mm256_cvtepu16_epi32(code_128i);
code_256i = _mm256_slli_epi32(code_256i, 16);
return _mm256_castsi256_ps(code_256i);
}
};

#endif

#ifdef __aarch64__

template <>
struct QuantizerBF16<8> : QuantizerBF16<1> {
QuantizerBF16(size_t d, const std::vector<float>& trained)
: QuantizerBF16<1>(d, trained) {}

FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
vreinterpretq_f32_u32(
vshlq_n_u32(vmovl_u16(codei.val[1]), 16))};
}
};
#endif

/*******************************************************************
* 8bit_direct quantizer
*******************************************************************/
Expand Down Expand Up @@ -589,6 +656,8 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
d, trained);
case ScalarQuantizer::QT_fp16:
return new QuantizerFP16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_bf16:
return new QuantizerBF16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_8bit_direct:
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
}
Expand Down Expand Up @@ -1378,6 +1447,10 @@ SQDistanceComputer* select_distance_computer(
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
d, trained);

case ScalarQuantizer::QT_bf16:
return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
d, trained);

case ScalarQuantizer::QT_8bit_direct:
if (d % 16 == 0) {
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
Expand Down Expand Up @@ -1426,6 +1499,10 @@ void ScalarQuantizer::set_derived_sizes() {
code_size = d * 2;
bits = 16;
break;
case QT_bf16:
code_size = d * 2;
bits = 16;
break;
}
}

Expand Down Expand Up @@ -1462,6 +1539,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
break;
case QT_fp16:
case QT_8bit_direct:
case QT_bf16:
// no training necessary
break;
}
Expand Down Expand Up @@ -1791,6 +1869,11 @@ InvertedListScanner* sel1_InvertedListScanner(
QuantizerFP16<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_bf16:
return sel2_InvertedListScanner<DCTemplate<
QuantizerBF16<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_8bit_direct:
if (sq->d % 16 == 0) {
return sel2_InvertedListScanner<
Expand Down
1 change: 1 addition & 0 deletions faiss/impl/ScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct ScalarQuantizer : Quantizer {
QT_fp16,
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
};

QuantizerType qtype = QT_8bit;
Expand Down
3 changes: 2 additions & 1 deletion faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
{"SQ4", ScalarQuantizer::QT_4bit},
{"SQ6", ScalarQuantizer::QT_6bit},
{"SQfp16", ScalarQuantizer::QT_fp16},
{"SQbf16", ScalarQuantizer::QT_bf16},
};
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16|SQbf16)";

std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
Expand Down
36 changes: 36 additions & 0 deletions faiss/utils/bf16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace faiss {

namespace {

union fp32_bits {
uint32_t as_u32;
float as_f32;
};

} // namespace

inline uint16_t encode_bf16(const float f) {
// Round off
fp32_bits fp;
fp.as_f32 = f;
return static_cast<uint16_t>((fp.as_u32 + 0x8000) >> 16);
}

inline float decode_bf16(const uint16_t v) {
fp32_bits fp;
fp.as_u32 = (uint32_t(v) << 16);
return fp.as_f32;
}

} // namespace faiss
6 changes: 4 additions & 2 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_4variants_ivf(self):
D, I = index.search(xq, 10)
nok['flat'] = (I[:, 0] == I_ref[:, 0]).sum()

for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split():
for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16 QT_bf16".split():
qtype = getattr(faiss.ScalarQuantizer, qname)
index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
qtype, faiss.METRIC_L2)
Expand All @@ -349,6 +349,7 @@ def test_4variants_ivf(self):
self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit'])
self.assertGreaterEqual(nok['QT_bf16'], nok['QT_8bit'])

def test_4variants(self):
d = 32
Expand All @@ -364,7 +365,7 @@ def test_4variants(self):

nok = {}

for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split():
for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16 QT_bf16".split():
qtype = getattr(faiss.ScalarQuantizer, qname)
index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2)
index.train(xt)
Expand All @@ -377,6 +378,7 @@ def test_4variants(self):
self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit'])
self.assertGreaterEqual(nok['QT_bf16'], nq * 0.9)


class TestRangeSearch(unittest.TestCase):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_standalone_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def test_SQ2(self):
def test_SQ3(self):
self.compare_accuracy('SQ8', 'SQfp16')

def test_SQ4(self):
self.compare_accuracy('SQ8', 'SQbf16')

def test_PQ(self):
self.compare_accuracy('PQ6x8np', 'PQ8x8np')

Expand Down

0 comments on commit 6a94c67

Please sign in to comment.