From e8b2c7c98347fac16c6e4b408d2e74cd065c920c Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Wed, 17 May 2023 15:04:40 -0700 Subject: [PATCH] Support RAFT from python (#2864) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2864 Adds use_raft to the cloner options. Adds tests for the python interface. Also continue cleanup of data structures to set default arguments. Add flags GPU and NVIDIA_RAFT to get_compile_options() Reviewed By: algoriddle Differential Revision: D45943372 fbshipit-source-id: 276bedf7461e2f61a91ec72aa8695d97156e7fbe --- faiss/gpu/GpuCloner.cpp | 9 ++++ faiss/gpu/GpuClonerOptions.h | 3 ++ faiss/gpu/GpuDistance.h | 56 +++++++-------------- faiss/gpu/GpuIndex.cu | 17 +++++++ faiss/gpu/GpuIndex.h | 6 +-- faiss/gpu/GpuIndexFlat.h | 6 +-- faiss/gpu/GpuIndexIVF.h | 4 +- faiss/gpu/GpuIndexIVFFlat.h | 4 +- faiss/gpu/GpuIndexIVFPQ.h | 14 ++---- faiss/gpu/GpuIndexIVFScalarQuantizer.h | 4 +- faiss/gpu/GpuResources.h | 18 +++---- faiss/gpu/test/test_gpu_basics.py | 6 +++ faiss/gpu/test/test_raft.py | 68 ++++++++++++++++++++++++++ faiss/python/gpu_wrappers.py | 3 +- faiss/utils/utils.cpp | 11 +++-- tests/test_fast_scan.py | 9 ---- 16 files changed, 149 insertions(+), 89 deletions(-) create mode 100644 faiss/gpu/test/test_raft.py diff --git a/faiss/gpu/GpuCloner.cpp b/faiss/gpu/GpuCloner.cpp index 9109981a73..a77ee0485d 100644 --- a/faiss/gpu/GpuCloner.cpp +++ b/faiss/gpu/GpuCloner.cpp @@ -121,6 +121,7 @@ Index* ToGpuCloner::clone_Index(const Index* index) { GpuIndexFlatConfig config; config.device = device; config.useFloat16 = useFloat16; + config.use_raft = use_raft; return new GpuIndexFlat(provider, ifl, config); } else if ( dynamic_cast(index) && @@ -129,6 +130,8 @@ Index* ToGpuCloner::clone_Index(const Index* index) { GpuIndexFlatConfig config; config.device = device; config.useFloat16 = true; + FAISS_THROW_IF_NOT_MSG( + !use_raft, "this type of index is not implemented for RAFT"); GpuIndexFlat* gif = new GpuIndexFlat( 
provider, index->d, index->metric_type, config); // transfer data by blocks @@ -146,6 +149,8 @@ Index* ToGpuCloner::clone_Index(const Index* index) { config.device = device; config.indicesOptions = indicesOptions; config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + FAISS_THROW_IF_NOT_MSG( + !use_raft, "this type of index is not implemented for RAFT"); GpuIndexIVFFlat* res = new GpuIndexIVFFlat( provider, ifl->d, ifl->nlist, ifl->metric_type, config); @@ -162,6 +167,8 @@ Index* ToGpuCloner::clone_Index(const Index* index) { config.device = device; config.indicesOptions = indicesOptions; config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + FAISS_THROW_IF_NOT_MSG( + !use_raft, "this type of index is not implemented for RAFT"); GpuIndexIVFScalarQuantizer* res = new GpuIndexIVFScalarQuantizer( provider, @@ -194,6 +201,8 @@ Index* ToGpuCloner::clone_Index(const Index* index) { config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; config.useFloat16LookupTables = useFloat16; config.usePrecomputedTables = usePrecomputed; + FAISS_THROW_IF_NOT_MSG( + !use_raft, "this type of index is not implemented for RAFT"); GpuIndexIVFPQ* res = new GpuIndexIVFPQ(provider, ipq, config); diff --git a/faiss/gpu/GpuClonerOptions.h b/faiss/gpu/GpuClonerOptions.h index dd3e6c6e85..fbde4c4ea4 100644 --- a/faiss/gpu/GpuClonerOptions.h +++ b/faiss/gpu/GpuClonerOptions.h @@ -36,6 +36,9 @@ struct GpuClonerOptions { /// Set verbose options on the index bool verbose = false; + + /// use the RAFT implementation + bool use_raft = false; }; struct GpuMultipleClonerOptions : public GpuClonerOptions { diff --git a/faiss/gpu/GpuDistance.h b/faiss/gpu/GpuDistance.h index 3d9a318990..858abe28eb 100644 --- a/faiss/gpu/GpuDistance.h +++ b/faiss/gpu/GpuDistance.h @@ -28,44 +28,24 @@ enum class IndicesDataType { /// Arguments to brute-force GPU k-nearest neighbor searching struct GpuDistanceParams { - GpuDistanceParams() - : metric(faiss::MetricType::METRIC_L2), - metricArg(0), - k(0), - 
dims(0), - vectors(nullptr), - vectorType(DistanceDataType::F32), - vectorsRowMajor(true), - numVectors(0), - vectorNorms(nullptr), - queries(nullptr), - queryType(DistanceDataType::F32), - queriesRowMajor(true), - numQueries(0), - outDistances(nullptr), - ignoreOutDistances(false), - outIndicesType(IndicesDataType::I64), - outIndices(nullptr), - device(-1) {} - // // Search parameters // /// Search parameter: distance metric - faiss::MetricType metric; + faiss::MetricType metric = METRIC_L2; /// Search parameter: distance metric argument (if applicable) /// For metric == METRIC_Lp, this is the p-value - float metricArg; + float metricArg = 0; /// Search parameter: return k nearest neighbors /// If the value provided is -1, then we report all pairwise distances /// without top-k filtering - int k; + int k = 0; /// Vector dimensionality - int dims; + int dims = 0; // // Vectors being queried @@ -74,14 +54,14 @@ struct GpuDistanceParams { /// If vectorsRowMajor is true, this is /// numVectors x dims, with dims innermost; otherwise, /// dims x numVectors, with numVectors innermost - const void* vectors; - DistanceDataType vectorType; - bool vectorsRowMajor; - idx_t numVectors; + const void* vectors = nullptr; + DistanceDataType vectorType = DistanceDataType::F32; + bool vectorsRowMajor = true; + idx_t numVectors = 0; /// Precomputed L2 norms for each vector in `vectors`, which can be /// optionally provided in advance to speed computation for METRIC_L2 - const float* vectorNorms; + const float* vectorNorms = nullptr; // // The query vectors (i.e., find k-nearest neighbors in `vectors` for each @@ -91,10 +71,10 @@ struct GpuDistanceParams { /// If queriesRowMajor is true, this is /// numQueries x dims, with dims innermost; otherwise, /// dims x numQueries, with numQueries innermost - const void* queries; - DistanceDataType queryType; - bool queriesRowMajor; - idx_t numQueries; + const void* queries = nullptr; + DistanceDataType queryType = DistanceDataType::F32; + bool 
queriesRowMajor = true; + idx_t numQueries = 0; // // Output results @@ -103,16 +83,16 @@ struct GpuDistanceParams { /// A region of memory size numQueries x k, with k /// innermost (row major) if k > 0, or if k == -1, a region of memory of /// size numQueries x numVectors - float* outDistances; + float* outDistances = nullptr; /// Do we only care about the indices reported, rather than the output /// distances? Not used if k == -1 (all pairwise distances) - bool ignoreOutDistances; + bool ignoreOutDistances = false; /// A region of memory size numQueries x k, with k /// innermost (row major). Not used if k == -1 (all pairwise distances) - IndicesDataType outIndicesType; - void* outIndices; + IndicesDataType outIndicesType = IndicesDataType::I64; + void* outIndices = nullptr; // // Execution information @@ -123,7 +103,7 @@ struct GpuDistanceParams { /// (via cudaGetDevice/cudaSetDevice) is used /// Otherwise, an integer 0 <= device < numDevices indicates the device for /// execution - int device; + int device = -1; /// Should the index dispatch down to RAFT? 
bool use_raft = false; diff --git a/faiss/gpu/GpuIndex.cu b/faiss/gpu/GpuIndex.cu index 8647e66588..575656db71 100644 --- a/faiss/gpu/GpuIndex.cu +++ b/faiss/gpu/GpuIndex.cu @@ -514,4 +514,21 @@ bool isGpuIndexImplemented(faiss::Index* index) { } } // namespace gpu + +// This is the one defined in utils.cpp +// Crossing fingers that the InitGpuOptions_instance will +// be instantiated after this global variable +extern std::string gpu_options; + +struct InitGpuOptions { + InitGpuOptions() { + gpu_options = "GPU "; +#ifdef USE_NVIDIA_RAFT + gpu_options += "NVIDIA_RAFT "; +#endif + } +}; + +InitGpuOptions InitGpuOptions_instance; + } // namespace faiss diff --git a/faiss/gpu/GpuIndex.h b/faiss/gpu/GpuIndex.h index 8f981ccd74..629a57583d 100644 --- a/faiss/gpu/GpuIndex.h +++ b/faiss/gpu/GpuIndex.h @@ -29,15 +29,13 @@ namespace faiss { namespace gpu { struct GpuIndexConfig { - inline GpuIndexConfig() : device(0), memorySpace(MemorySpace::Device) {} - /// GPU device on which the index is resident - int device; + int device = 0; /// What memory space to use for primary storage. /// On Pascal and above (CC 6+) architectures, allows GPUs to use /// more memory than is available on the GPU. - MemorySpace memorySpace; + MemorySpace memorySpace = MemorySpace::Device; /// Should the index dispatch down to RAFT?
bool use_raft = false; diff --git a/faiss/gpu/GpuIndexFlat.h b/faiss/gpu/GpuIndexFlat.h index 514220039d..eb7780187a 100644 --- a/faiss/gpu/GpuIndexFlat.h +++ b/faiss/gpu/GpuIndexFlat.h @@ -24,15 +24,13 @@ namespace gpu { class FlatIndex; struct GpuIndexFlatConfig : public GpuIndexConfig { - inline GpuIndexFlatConfig() : useFloat16(false) {} - /// Whether or not data is stored as float16 - bool useFloat16; + bool useFloat16 = false; /// Deprecated: no longer used /// Previously used to indicate whether internal storage of vectors is /// transposed - bool storeTransposed; + bool storeTransposed = false; }; /// Wrapper around the GPU implementation that looks like diff --git a/faiss/gpu/GpuIndexIVF.h b/faiss/gpu/GpuIndexIVF.h index 7bb77c06a0..48096eaaf0 100644 --- a/faiss/gpu/GpuIndexIVF.h +++ b/faiss/gpu/GpuIndexIVF.h @@ -21,10 +21,8 @@ class GpuIndexFlat; class IVFBase; struct GpuIndexIVFConfig : public GpuIndexConfig { - inline GpuIndexIVFConfig() : indicesOptions(INDICES_64_BIT) {} - /// Index storage options for the GPU - IndicesOptions indicesOptions; + IndicesOptions indicesOptions = INDICES_64_BIT; /// Configuration for the coarse quantizer object GpuIndexFlatConfig flatConfig; diff --git a/faiss/gpu/GpuIndexIVFFlat.h b/faiss/gpu/GpuIndexIVFFlat.h index 3618797df3..9206d20f61 100644 --- a/faiss/gpu/GpuIndexIVFFlat.h +++ b/faiss/gpu/GpuIndexIVFFlat.h @@ -21,11 +21,9 @@ class IVFFlat; class GpuIndexFlat; struct GpuIndexIVFFlatConfig : public GpuIndexIVFConfig { - inline GpuIndexIVFFlatConfig() : interleavedLayout(true) {} - /// Use the alternative memory layout for the IVF lists /// (currently the default) - bool interleavedLayout; + bool interleavedLayout = true; }; /// Wrapper around the GPU implementation that looks like diff --git a/faiss/gpu/GpuIndexIVFPQ.h b/faiss/gpu/GpuIndexIVFPQ.h index 466f902ebf..22e9961675 100644 --- a/faiss/gpu/GpuIndexIVFPQ.h +++ b/faiss/gpu/GpuIndexIVFPQ.h @@ -23,24 +23,18 @@ class GpuIndexFlat; class IVFPQ; struct 
GpuIndexIVFPQConfig : public GpuIndexIVFConfig { - inline GpuIndexIVFPQConfig() - : useFloat16LookupTables(false), - usePrecomputedTables(false), - interleavedLayout(false), - useMMCodeDistance(false) {} - /// Whether or not float16 residual distance tables are used in the /// list scanning kernels. When subQuantizers * 2^bitsPerCode > /// 16384, this is required. - bool useFloat16LookupTables; + bool useFloat16LookupTables = false; /// Whether or not we enable the precomputed table option for /// search, which can substantially increase the memory requirement. - bool usePrecomputedTables; + bool usePrecomputedTables = false; /// Use the alternative memory layout for the IVF lists /// WARNING: this is a feature under development, do not use! - bool interleavedLayout; + bool interleavedLayout = false; /// Use GEMM-backed computation of PQ code distances for the no precomputed /// table version of IVFPQ. @@ -50,7 +44,7 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig { /// Note that MM code distance is enabled automatically if one uses a number /// of dimensions per sub-quantizer that is not natively specialized (an odd /// number like 7 or so). 
- bool useMMCodeDistance; + bool useMMCodeDistance = false; }; /// IVFPQ index for the GPU diff --git a/faiss/gpu/GpuIndexIVFScalarQuantizer.h b/faiss/gpu/GpuIndexIVFScalarQuantizer.h index 96dbb3a058..af966bc8ac 100644 --- a/faiss/gpu/GpuIndexIVFScalarQuantizer.h +++ b/faiss/gpu/GpuIndexIVFScalarQuantizer.h @@ -18,11 +18,9 @@ class IVFFlat; class GpuIndexFlat; struct GpuIndexIVFScalarQuantizerConfig : public GpuIndexIVFConfig { - inline GpuIndexIVFScalarQuantizerConfig() : interleavedLayout(true) {} - /// Use the alternative memory layout for the IVF lists /// (currently the default) - bool interleavedLayout; + bool interleavedLayout = true; }; /// Wrapper around the GPU implementation that looks like diff --git a/faiss/gpu/GpuResources.h b/faiss/gpu/GpuResources.h index 5177065374..7d0459955b 100644 --- a/faiss/gpu/GpuResources.h +++ b/faiss/gpu/GpuResources.h @@ -102,11 +102,7 @@ std::string memorySpaceToString(MemorySpace s); /// Information on what/where an allocation is struct AllocInfo { - inline AllocInfo() - : type(AllocType::Other), - device(0), - space(MemorySpace::Device), - stream(nullptr) {} + inline AllocInfo() {} inline AllocInfo(AllocType at, int dev, MemorySpace sp, cudaStream_t st) : type(at), device(dev), space(sp), stream(st) {} @@ -115,13 +111,13 @@ struct AllocInfo { std::string toString() const; /// The internal category of the allocation - AllocType type; + AllocType type = AllocType::Other; /// The device on which the allocation is happening - int device; + int device = 0; /// The memory space of the allocation - MemorySpace space; + MemorySpace space = MemorySpace::Device; /// The stream on which new work on the memory will be ordered (e.g., if a /// piece of memory cached and to be returned for this call was last used on @@ -131,7 +127,7 @@ struct AllocInfo { /// /// The memory manager guarantees that the returned memory is free to use /// without data races on this stream specified. 
- cudaStream_t stream; + cudaStream_t stream = nullptr; }; /// Create an AllocInfo for the current device with MemorySpace::Device @@ -145,7 +141,7 @@ AllocInfo makeSpaceAlloc(AllocType at, MemorySpace sp, cudaStream_t st); /// Information on what/where an allocation is, along with how big it should be struct AllocRequest : public AllocInfo { - inline AllocRequest() : AllocInfo(), size(0) {} + inline AllocRequest() {} inline AllocRequest(const AllocInfo& info, size_t sz) : AllocInfo(info), size(sz) {} @@ -162,7 +158,7 @@ struct AllocRequest : public AllocInfo { std::string toString() const; /// The size in bytes of the allocation - size_t size; + size_t size = 0; }; /// A RAII object that manages a temporary memory request diff --git a/faiss/gpu/test/test_gpu_basics.py b/faiss/gpu/test/test_gpu_basics.py index 0fa7acbb77..f1604eacd0 100755 --- a/faiss/gpu/test/test_gpu_basics.py +++ b/faiss/gpu/test/test_gpu_basics.py @@ -426,3 +426,9 @@ def test_with_gpu(self): self.assertTrue(0.9 * err_rq0 < err_rq1 < 1.1 * err_rq0) # np.testing.assert_array_equal(codes0, codes1) + + +class TestGpuFlags(unittest.TestCase): + + def test_gpu_flag(self): + assert "GPU" in faiss.get_compile_options().split() diff --git a/faiss/gpu/test/test_raft.py b/faiss/gpu/test/test_raft.py new file mode 100644 index 0000000000..fe99b13cde --- /dev/null +++ b/faiss/gpu/test/test_raft.py @@ -0,0 +1,68 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import unittest +import numpy as np +import faiss +from faiss.contrib.datasets import SyntheticDataset + + +@unittest.skipIf( + "RAFT" not in faiss.get_compile_options(), + "only if RAFT is compiled in") +class TestBfKnn(unittest.TestCase): + + def test_bfKnn(self): + + ds = SyntheticDataset(32, 0, 4321, 1234) + + Dref, Iref = faiss.knn(ds.get_queries(), ds.get_database(), 12) + + res = faiss.StandardGpuResources() + + # Faiss internal implementation + Dnew, Inew = faiss.knn_gpu( + res, ds.get_queries(), ds.get_database(), 12, use_raft=False) + np.testing.assert_allclose(Dref, Dnew, atol=1e-5) + np.testing.assert_array_equal(Iref, Inew) + + # RAFT version + Dnew, Inew = faiss.knn_gpu( + res, ds.get_queries(), ds.get_database(), 12, use_raft=True) + np.testing.assert_allclose(Dref, Dnew, atol=1e-5) + np.testing.assert_array_equal(Iref, Inew) + + def test_IndexFlat(self): + ds = SyntheticDataset(32, 0, 4000, 1234) + + # add only first half of database + xb = ds.get_database() + index = faiss.IndexFlatL2(ds.d) + index.add(xb[:2000]) + Dref, Iref = index.search(ds.get_queries(), 13) + + res = faiss.StandardGpuResources() + co = faiss.GpuClonerOptions() + co.use_raft = True + index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co) + Dnew, Inew = index_gpu.search(ds.get_queries(), 13) + np.testing.assert_allclose(Dref, Dnew, atol=1e-5) + np.testing.assert_array_equal(Iref, Inew) + + # add rest of database + index.add(xb[2000:]) + Dref, Iref = index.search(ds.get_queries(), 13) + + index_gpu.add(xb[2000:]) + Dnew, Inew = index_gpu.search(ds.get_queries(), 13) + np.testing.assert_allclose(Dref, Dnew, atol=1e-5) + np.testing.assert_array_equal(Iref, Inew) + + # copy back to CPU + index2 = faiss.index_gpu_to_cpu(index_gpu) + Dnew, Inew = index2.search(ds.get_queries(), 13) + np.testing.assert_allclose(Dref, Dnew, atol=1e-5) + np.testing.assert_array_equal(Iref, Inew) diff --git a/faiss/python/gpu_wrappers.py b/faiss/python/gpu_wrappers.py index 24c24fac39..fd3dd0c1c7 
100644 --- a/faiss/python/gpu_wrappers.py +++ b/faiss/python/gpu_wrappers.py @@ -54,7 +54,7 @@ def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1): # allows numpy ndarray usage with bfKnn -def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1): +def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1, use_raft=False): """ Compute the k nearest neighbors of a vector on one GPU without constructing an index @@ -168,6 +168,7 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1): args.outIndices = I_ptr args.outIndicesType = I_type args.device = device + args.use_raft = use_raft # no stream synchronization needed, inputs and outputs are guaranteed to # be on the CPU (numpy arrays) diff --git a/faiss/utils/utils.cpp b/faiss/utils/utils.cpp index 3bc362c9c2..994965a9cc 100644 --- a/faiss/utils/utils.cpp +++ b/faiss/utils/utils.cpp @@ -101,6 +101,9 @@ int sgemv_( namespace faiss { +// this will be set at load time from GPU Faiss +std::string gpu_options; + std::string get_compile_options() { std::string options; @@ -110,13 +113,15 @@ std::string get_compile_options() { #endif #ifdef __AVX2__ - options += "AVX2"; + options += "AVX2 "; #elif defined(__aarch64__) - options += "NEON"; + options += "NEON "; #else - options += "GENERIC"; + options += "GENERIC "; #endif + options += gpu_options; + return options; } diff --git a/tests/test_fast_scan.py b/tests/test_fast_scan.py index 7dab10d3fd..b061ee3af0 100644 --- a/tests/test_fast_scan.py +++ b/tests/test_fast_scan.py @@ -17,15 +17,6 @@ # the tests tend to timeout in stress modes + dev otherwise faiss.omp_set_num_threads(4) -class TestCompileOptions(unittest.TestCase): - - def test_compile_options(self): - options = faiss.get_compile_options() - options = options.split(' ') - for option in options: - assert option in ['AVX2', 'NEON', 'GENERIC', 'OPTIMIZE'] - - class TestSearch(unittest.TestCase): def test_PQ4_accuracy(self):