/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
// This file contains a custom fast implementation of faiss::Index::sa_decode()
// function for the following index families:
// * IVF256,PQ[1]x8np
// * Residual[1]x8,PQ[2]x8
// * IVF[2^9-2^16 bit],PQ[1]x8 (such as IVF1024,PQ16np)
// * Residual1x[9-16 bit],PQ[1]x8 (such as Residual1x9,PQ8)
// * PQ[1]x8
// Additionally, AVX2 and ARM versions support
// * Residual[1]x8,PQ[2]x10
// * Residual[1]x8,PQ[2]x12
// * Residual[1]x8,PQ[2]x16
// * Residual[1]x10,PQ[2]x10
// * Residual[1]x10,PQ[2]x12
// * Residual[1]x10,PQ[2]x16
// * Residual[1]x12,PQ[2]x10
// * Residual[1]x12,PQ[2]x12
// * Residual[1]x12,PQ[2]x16
// * Residual[1]x16,PQ[2]x10
// * Residual[1]x16,PQ[2]x12
// * Residual[1]x16,PQ[2]x16
// * Residual1x[9-16 bit],PQ[1]x10 (such as Residual1x9,PQ16x10)
// * * (use with COARSE_BITS=16)
// * Residual1x[9-16 bit],PQ[1]x12 (such as Residual1x9,PQ16x12)
// * * (use with COARSE_BITS=16)
// * Residual1x[9-16 bit],PQ[1]x16 (such as Residual1x9,PQ16x16)
// * * (use with COARSE_BITS=16)
// * PQ[1]x10
// * PQ[1]x12
// * PQ[1]x16
// * IVF256,PQ[1]x10 (such as IVF256,PQ16x10np)
// * IVF256,PQ[1]x12 (such as IVF256,PQ16x12np)
// * IVF256,PQ[1]x16 (such as IVF256,PQ16x16np)
// * IVF[2^9-2^16 bit],PQ[1]x10 (such as IVF1024,PQ16x10np)
// * IVF[2^9-2^16 bit],PQ[1]x12 (such as IVF1024,PQ16x12np)
// * IVF[2^9-2^16 bit],PQ[1]x16 (such as IVF1024,PQ16x16np)
//
// The goal was to achieve the maximum performance, so the template version it
// is. The provided index families share the same code for sa_decode.
//
// The front-end code provides two high-level structures.
//
// First one:
// {
// template <
// intptr_t DIM,
// intptr_t COARSE_SIZE,
// intptr_t FINE_SIZE,
// intptr_t COARSE_BITS = 8,
// intptr_t FINE_BITS = 8>
// struct Index2LevelDecoder { /*...*/ };
// }
// * DIM is the dimensionality of data
// * COARSE_SIZE is the dimensionality of the coarse quantizer (IVF, Residual)
// * FINE_SIZE is the dimensionality of the ProductQuantizer dsq
// * COARSE_BITS is the number of bits that are needed to represent a coarse
// quantizer code.
// * FINE_BITS is the number of bits that are needed to represent a fine
// quantizer code.
// For example, "IVF256,PQ8np" for 160-dim data translates into
// Index2LevelDecoder<160,160,20,8>
// For example, "Residual4x8,PQ16" for 256-dim data translates into
// Index2LevelDecoder<256,64,1,8>
// For example, "IVF1024,PQ16np" for 256-dim data translates into
// Index2LevelDecoder<256,256,16,10>. But as there is only 1 coarse code
// element, Index2LevelDecoder<256,256,16,16> can be used as a faster
// decoder.
// For example, "Residual4x10,PQ16x10np" for 256-dim data translates into
// Index2LevelDecoder<256,64,16,10,10>
// For example, "IVF1024,PQ16x10np" for 256-dim data translates into
// Index2LevelDecoder<256,256,16,10,10>. But as there is only 1 coarse code
// element, Index2LevelDecoder<256,256,16,16,10> can be used as a faster
// decoder.
//
// Additional supported values for COARSE_BITS and FINE_BITS may be added later.
//
// Second one:
// {
// template <
// intptr_t DIM,
// intptr_t FINE_SIZE,
// intptr_t FINE_BITS = 8>
// struct IndexPQDecoder { /*...*/ };
// }
// * DIM is the dimensionality of data
// * FINE_SIZE is the dimensionality of the ProductQuantizer dsq
// * FINE_BITS is the number of bits that are needed to represent a fine
// quantizer code.
// For example, "PQ8np" for 160-dim data translates into
// IndexPQDecoder<160,20>
//
// Unlike the general purpose version in faiss::Index::sa_decode(),
// this version provides the following functions (please note that
// pqCoarseCentroids params are not available for IndexPQDecoder,
// but the functionality is the same as for Index2LevelDecoder):
//
// * ::store(), which is similar to sa_decode(1, input, output),
// The method signature is the following:
// {
// void store(
// const float* const __restrict pqCoarseCentroids,
// const float* const __restrict pqFineCentroids,
// const uint8_t* const __restrict code,
// float* const __restrict outputStore);
// }
//
// * ::accum(), which is used to create a linear combination
// of decoded vectors:
// {
// const faiss::Index* const index;
// const uint8_t* const input;
// float weight;
//
// std::vector<float> buffer(d, 0);
//
// index->sa_decode(1, input, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
// output[iDim] += weight * buffer[iDim];
// }
// The method signature is the following:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids,
// const float* const __restrict pqFineCentroids,
// const uint8_t* const __restrict code,
// const float weight,
// float* const __restrict outputAccum);
// }
//
// * There is an additional overload for ::accum() that decodes two vectors
// per call. This provides an additional speedup because of a CPU
// superscalar architecture:
// {
// const faiss::Index* const index;
// const uint8_t* const input0;
// float weight0;
// const uint8_t* const input1;
// float weight1;
//
// std::vector<float> buffer(d, 0);
//
// index->sa_decode(1, input0, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
// output[iDim] += weight0 * buffer[iDim];
//
// index->sa_decode(1, input1, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
// output[iDim] += weight1 * buffer[iDim];
// }
// If each code uses its own coarse quantizer centroids table and its own fine
// quantizer centroids table, then the following overload can be used:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids0,
// const float* const __restrict pqFineCentroids0,
// const uint8_t* const __restrict code0,
// const float weight0,
// const float* const __restrict pqCoarseCentroids1,
// const float* const __restrict pqFineCentroids1,
// const uint8_t* const __restrict code1,
// const float weight1,
// float* const __restrict outputAccum);
// }
// If codes share the coarse quantizer centroids table and also share
// the fine quantizer centroids table, then the following overload can be
// used:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids,
// const float* const __restrict pqFineCentroids,
// const uint8_t* const __restrict code0,
// const float weight0,
// const uint8_t* const __restrict code1,
// const float weight1,
// float* const __restrict outputAccum);
// }
//
// * And one more overload for ::accum() that decodes and accumulates
// three vectors per call.
// {
// const faiss::Index* const index;
// const uint8_t* const input0;
// float weight0;
// const uint8_t* const input1;
// float weight1;
// const uint8_t* const input2;
// float weight2;
//
// std::vector<float> buffer(d, 0);
//
// index->sa_decode(1, input0, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
// output[iDim] += weight0 * buffer[iDim];
//
// index->sa_decode(1, input1, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
// output[iDim] += weight1 * buffer[iDim];
//
// index->sa_decode(1, input2, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
// output[iDim] += weight2 * buffer[iDim];
// }
//
// If each code uses its own coarse quantizer centroids table and its own fine
// quantizer centroids table, then the following overload can be used:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids0,
// const float* const __restrict pqFineCentroids0,
// const uint8_t* const __restrict code0,
// const float weight0,
// const float* const __restrict pqCoarseCentroids1,
// const float* const __restrict pqFineCentroids1,
// const uint8_t* const __restrict code1,
// const float weight1,
// const float* const __restrict pqCoarseCentroids2,
// const float* const __restrict pqFineCentroids2,
// const uint8_t* const __restrict code2,
// const float weight2,
// float* const __restrict outputAccum);
// }
// If codes share the coarse quantizer centroids table and also share
// the fine quantizer centroids table, then the following overload can be
// used:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids,
// const float* const __restrict pqFineCentroids,
// const uint8_t* const __restrict code0,
// const float weight0,
// const uint8_t* const __restrict code1,
// const float weight1,
// const uint8_t* const __restrict code2,
// const float weight2,
// float* const __restrict outputAccum);
// }
//
// The provided version is not multithreaded.
//
// Currently, an AVX2+FMA implementation is available. AVX512 version is also
// doable, but it was found to be slower than AVX2 for real world applications
// that I needed.
//
////////////////////////////////////////////////////////////////////////////////////
//
// It is possible to use an additional index wrapper on top of IVFPQ /
// Residual+PQ, known as IndexRowwiseMinMax / IndexRowwiseMinMaxFP16. Index
// wrapper that performs rowwise normalization to [0,1], preserving the
// coefficients. This is a vector codec index only.
// For more details please refer to the description in
// faiss/IndexRowwiseMinMax.h file.
//
// If such a wrapper is used, then the quantizer will look like, say,
// MinMaxFP16,IVF256,PQ32np
// or
// MinMax,PQ16np
// In this case, please use the following construction for the decoding,
// basically, wrapping a kernel in a kernel:
// {
// using SubT = faiss::cppcontrib::Index2LevelDecoder<128, 128, 2>;
// using T = faiss::cppcontrib::IndexMinMaxFP16Decoder<SubT>;
// // do T::store(...) or T::accum(...)
// }
//
// T::accum(...) contains an additional function variable which is
// used for accumulating scaling. Thus, the code pattern is the following:
// {
// const float* const __restrict pqCoarseCentroidsQ;
// const float* const __restrict pqFineCentroidsQ;
// const uint8_t* const __restrict input;
// const float* const __restrict weights;
// float* const __restrict output;
// float outputAccumMin = 0;
//
// for (size_t i = 0; i < n; i++) {
// T::accum(
// pqCoarseCentroidsQ,
// pqFineCentroidsQ,
// input + i * code_size,
// weights[i],
// output,
// outputAccumMin);
// }
// for (size_t j = 0; j < d; j++)
// output[j] += outputAccumMin;
// }
// This is similar to the following regular pseudo-code:
// {
// const faiss::Index* const index;
// const uint8_t* const __restrict input;
// const float* const __restrict weights;
// float* const __restrict output;
//
// for (size_t i = 0; i < n; i++) {
// std::vector<float> buffer(d, 0);
//
// index->sa_decode(1, input + i * code_size, buffer.data());
// for (size_t j = 0; j < d; j++)
// output[j] += weights[i] * buffer[j];
//     }
// }
#include <faiss/cppcontrib/sa_decode/MinMax-inl.h>
#include <faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h>
#ifdef __AVX2__
#include <faiss/cppcontrib/sa_decode/Level2-avx2-inl.h>
#include <faiss/cppcontrib/sa_decode/PQ-avx2-inl.h>
#elif defined(__ARM_NEON)
#include <faiss/cppcontrib/sa_decode/Level2-neon-inl.h>
#include <faiss/cppcontrib/sa_decode/PQ-neon-inl.h>
#else
#include <faiss/cppcontrib/sa_decode/Level2-inl.h>
#include <faiss/cppcontrib/sa_decode/PQ-inl.h>
#endif