Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QT_bf16 for scalar quantizer for bfloat16 #3444

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchs/bench_fw/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def optimize_codec(
[
(None, "Flat"),
(None, "SQfp16"),
(None, "SQbf16"),
(None, "SQ8"),
] + [
(f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
Expand Down
1 change: 1 addition & 0 deletions c_api/IndexScalarQuantizer_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ typedef enum FaissQuantizerType {
QT_fp16,
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
} FaissQuantizerType;

// forward declaration
Expand Down
3 changes: 3 additions & 0 deletions contrib/factory_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def get_code_size(d, indexkey):
return (d * 6 + 7) // 8
elif indexkey == 'SQfp16':
return d * 2
elif indexkey == 'SQbf16':
return d * 2

mo = re.match('PCAR?(\\d+),(.*)$', indexkey)
if mo:
Expand Down Expand Up @@ -140,6 +142,7 @@ def reverse_index_factory(index):
faiss.ScalarQuantizer.QT_4bit: "4",
faiss.ScalarQuantizer.QT_6bit: "6",
faiss.ScalarQuantizer.QT_fp16: "fp16",
faiss.ScalarQuantizer.QT_bf16: "bf16",
}
return f"SQ{sqtypes[index.sq.qtype]}"

Expand Down
1 change: 1 addition & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ set(FAISS_HEADERS
invlists/InvertedLists.h
invlists/InvertedListsIOHook.h
utils/AlignedTable.h
utils/bf16.h
utils/Heap.h
utils/WorkerThread.h
utils/distances.h
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ IndexScalarQuantizer::IndexScalarQuantizer(
MetricType metric)
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
qtype == ScalarQuantizer::QT_8bit_direct;
qtype == ScalarQuantizer::QT_8bit_direct ||
qtype == ScalarQuantizer::QT_bf16;
code_size = sq.code_size;
}

Expand Down
83 changes: 83 additions & 0 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/bf16.h>
#include <faiss/utils/fp16.h>
#include <faiss/utils/utils.h>

Expand Down Expand Up @@ -496,6 +497,72 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
};
#endif

/*******************************************************************
* BF16 quantizer
*******************************************************************/

// Primary template, parameterized on the SIMD width. It is intentionally
// empty: usable implementations are supplied as explicit specializations.
template <int SIMDWIDTH>
struct QuantizerBF16 {};

// Scalar (one component per step) bfloat16 codec. Every one of the d
// components is stored as a single 16-bit code, so the code buffer is
// addressed as an array of uint16_t.
template <>
struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
    const size_t d;

    // The second argument is accepted for signature compatibility and ignored.
    QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}

    // Encode the d floats at x into d consecutive 16-bit codes at code.
    void encode_vector(const float* x, uint8_t* code) const final {
        uint16_t* const out = reinterpret_cast<uint16_t*>(code);
        for (size_t j = 0; j != d; ++j) {
            out[j] = encode_bf16(x[j]);
        }
    }

    // Decode d consecutive 16-bit codes at code into the d floats at x.
    void decode_vector(const uint8_t* code, float* x) const final {
        const uint16_t* const in = reinterpret_cast<const uint16_t*>(code);
        for (size_t j = 0; j != d; ++j) {
            x[j] = decode_bf16(in[j]);
        }
    }

    // Decode only the i-th component of the code.
    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
            const {
        const uint16_t* const in = reinterpret_cast<const uint16_t*>(code);
        return decode_bf16(in[i]);
    }
};

#ifdef __AVX2__

// AVX2 variant: reuses the scalar codec for encode/decode and adds an
// 8-components-at-a-time reconstruction.
template <>
struct QuantizerBF16<8> : QuantizerBF16<1> {
    QuantizerBF16(size_t d, const std::vector<float>& trained)
            : QuantizerBF16<1>(d, trained) {}

    // Decode the 8 components starting at index i into one 256-bit register.
    FAISS_ALWAYS_INLINE __m256
    reconstruct_8_components(const uint8_t* code, int i) const {
        // Each component occupies 2 bytes, hence the byte offset 2 * i.
        const __m128i* const src =
                reinterpret_cast<const __m128i*>(code + 2 * i);
        // Unaligned load of 8 x 16-bit codes.
        const __m128i packed = _mm_loadu_si128(src);
        // Zero-extend every code to 32 bits, then move it into the upper half
        // of its lane so that the lane holds the float32 bit pattern.
        const __m256i widened = _mm256_cvtepu16_epi32(packed);
        const __m256i as_f32_bits = _mm256_slli_epi32(widened, 16);
        return _mm256_castsi256_ps(as_f32_bits);
    }
};

#endif

#ifdef __aarch64__

// NEON (aarch64) variant: reuses the scalar codec for encode/decode and adds
// an 8-components-at-a-time reconstruction returned as two 4-lane registers.
template <>
struct QuantizerBF16<8> : QuantizerBF16<1> {
    QuantizerBF16(size_t d, const std::vector<float>& trained)
            : QuantizerBF16<1>(d, trained) {}

    // Decode the 8 components starting at index i.
    FAISS_ALWAYS_INLINE float32x4x2_t
    reconstruct_8_components(const uint8_t* code, int i) const {
        // Each component occupies 2 bytes, hence the byte offset 2 * i.
        const uint16x4x2_t packed =
                vld1_u16_x2((const uint16_t*)(code + 2 * i));

        // Widen each 16-bit code to 32 bits and shift it into the upper half
        // of the lane, which yields the float32 bit pattern.
        const uint32x4_t lo_bits = vshlq_n_u32(vmovl_u16(packed.val[0]), 16);
        const uint32x4_t hi_bits = vshlq_n_u32(vmovl_u16(packed.val[1]), 16);

        float32x4x2_t result;
        result.val[0] = vreinterpretq_f32_u32(lo_bits);
        result.val[1] = vreinterpretq_f32_u32(hi_bits);
        return result;
    }
};
#endif

/*******************************************************************
* 8bit_direct quantizer
*******************************************************************/
Expand Down Expand Up @@ -589,6 +656,8 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
d, trained);
case ScalarQuantizer::QT_fp16:
return new QuantizerFP16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_bf16:
return new QuantizerBF16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_8bit_direct:
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
}
Expand Down Expand Up @@ -1378,6 +1447,10 @@ SQDistanceComputer* select_distance_computer(
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
d, trained);

case ScalarQuantizer::QT_bf16:
return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
d, trained);

case ScalarQuantizer::QT_8bit_direct:
if (d % 16 == 0) {
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
Expand Down Expand Up @@ -1426,6 +1499,10 @@ void ScalarQuantizer::set_derived_sizes() {
code_size = d * 2;
bits = 16;
break;
case QT_bf16:
code_size = d * 2;
bits = 16;
break;
}
}

Expand Down Expand Up @@ -1462,6 +1539,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
break;
case QT_fp16:
case QT_8bit_direct:
case QT_bf16:
// no training necessary
break;
}
Expand Down Expand Up @@ -1791,6 +1869,11 @@ InvertedListScanner* sel1_InvertedListScanner(
QuantizerFP16<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_bf16:
return sel2_InvertedListScanner<DCTemplate<
QuantizerBF16<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_8bit_direct:
if (sq->d % 16 == 0) {
return sel2_InvertedListScanner<
Expand Down
1 change: 1 addition & 0 deletions faiss/impl/ScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct ScalarQuantizer : Quantizer {
QT_fp16,
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
};

QuantizerType qtype = QT_8bit;
Expand Down
3 changes: 2 additions & 1 deletion faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
{"SQ4", ScalarQuantizer::QT_4bit},
{"SQ6", ScalarQuantizer::QT_6bit},
{"SQfp16", ScalarQuantizer::QT_fp16},
{"SQbf16", ScalarQuantizer::QT_bf16},
};
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16|SQbf16)";

std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
Expand Down
36 changes: 36 additions & 0 deletions faiss/utils/bf16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>
#include <cstring>

namespace faiss {

// Unnamed namespace: the helper type below has internal linkage, so every
// translation unit that includes this header gets its own copy.
namespace {

// Overlays a 32-bit unsigned integer and a float on the same storage.
union fp32_bits {
uint32_t as_u32; // raw 32-bit pattern
float as_f32; // the same storage viewed as a float
};

} // namespace

/**
 * Convert a float to a bfloat16 code, i.e. the upper 16 bits of its 32-bit
 * pattern after rounding.
 *
 * Non-NaN inputs are rounded by adding half of the discarded range (0x8000)
 * before dropping the low 16 bits; values too large for bfloat16 therefore
 * become infinity of the same sign.
 *
 * NaN inputs are handled separately. Applying the rounding increment to a NaN
 * can carry out of the mantissa and turn it into a zero or an infinity, so
 * the payload is truncated instead and a mantissa bit is forced on to make
 * sure the result is still a NaN. The sign bit is preserved.
 *
 * @param f value to encode
 * @return 16-bit bfloat16 code
 */
inline uint16_t encode_bf16(const float f) {
    // memcpy is the well-defined way to read the bit pattern of a float.
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));

    // NaN: exponent all ones and a non-zero mantissa.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    // Round off. No carry can reach the sign bit here because the magnitude
    // is at most 0x7F800000.
    return static_cast<uint16_t>((bits + 0x8000u) >> 16);
}

/**
 * Convert a bfloat16 code back to a float.
 *
 * The 16 code bits become the upper half of the 32-bit float pattern and the
 * lower half is zero-filled, so the conversion is exact.
 *
 * @param v 16-bit bfloat16 code
 * @return the float whose upper 16 bits equal v
 */
inline float decode_bf16(const uint16_t v) {
    const uint32_t bits = static_cast<uint32_t>(v) << 16;
    // memcpy is the well-defined way to reinterpret the bits as a float;
    // reading an inactive union member is undefined behaviour in C++.
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

} // namespace faiss
6 changes: 4 additions & 2 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_4variants_ivf(self):
D, I = index.search(xq, 10)
nok['flat'] = (I[:, 0] == I_ref[:, 0]).sum()

for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split():
for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16 QT_bf16".split():
qtype = getattr(faiss.ScalarQuantizer, qname)
index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
qtype, faiss.METRIC_L2)
Expand All @@ -349,6 +349,7 @@ def test_4variants_ivf(self):
self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit'])
self.assertGreaterEqual(nok['QT_bf16'], nok['QT_8bit'])

def test_4variants(self):
d = 32
Expand All @@ -364,7 +365,7 @@ def test_4variants(self):

nok = {}

for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split():
for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16 QT_bf16".split():
qtype = getattr(faiss.ScalarQuantizer, qname)
index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2)
index.train(xt)
Expand All @@ -377,6 +378,7 @@ def test_4variants(self):
self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit'])
self.assertGreaterEqual(nok['QT_bf16'], nq * 0.9)


class TestRangeSearch(unittest.TestCase):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_standalone_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def test_SQ2(self):
def test_SQ3(self):
self.compare_accuracy('SQ8', 'SQfp16')

def test_SQ4(self):
self.compare_accuracy('SQ8', 'SQbf16')

def test_PQ(self):
self.compare_accuracy('PQ6x8np', 'PQ8x8np')

Expand Down