Skip to content

Commit

Permalink
Throw when attempting to move IndexPQ to GPU (#3328)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #3328

Reviewed By: junjieqi

Differential Revision: D55476917

fbshipit-source-id: e7f64adefa07650fda32ad2300a1b933cedc9c79
  • Loading branch information
ramilbakhshyiev authored and facebook-github-bot committed Mar 29, 2024
1 parent 4e6b6f8 commit 77e2e79
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
2 changes: 2 additions & 0 deletions faiss/gpu/GpuCloner.cpp
Expand Up @@ -224,6 +224,8 @@ faiss::Index* index_cpu_to_gpu(
int device,
const faiss::Index* index,
const GpuClonerOptions* options) {
auto index_pq = dynamic_cast<const faiss::IndexPQ*>(index);
FAISS_THROW_IF_MSG(index_pq, "This index type is not implemented on GPU.");
GpuClonerOptions defaults;
ToGpuCloner cl(provider, device, options ? *options : defaults);
return cl.clone_Index(index);
Expand Down
29 changes: 29 additions & 0 deletions faiss/gpu/test/test_index_cpu_to_gpu.py
@@ -0,0 +1,29 @@
import numpy as np
import unittest
import faiss


class TestMoveToGpu(unittest.TestCase):
def test_index_cpu_to_gpu(self):
dimension = 128
n = 2500
db_vectors = np.random.random((n, dimension)).astype('float32')
code_size = 16
res = faiss.StandardGpuResources()
index_pq = faiss.IndexPQ(dimension, code_size, 6)
index_pq.train(db_vectors)
index_pq.add(db_vectors)
self.assertRaisesRegex(Exception, ".*not implemented.*",
faiss.index_cpu_to_gpu, res, 0, index_pq)

def test_index_cpu_to_gpu_does_not_throw_with_index_flat(self):
dimension = 128
n = 100
db_vectors = np.random.random((n, dimension)).astype('float32')
res = faiss.StandardGpuResources()
index_flat = faiss.IndexFlatL2(dimension)
index_flat.add(db_vectors)
try:
faiss.index_cpu_to_gpu(res, 0, index_flat)
except Exception:
self.fail("index_cpu_to_gpu() threw an unexpected exception.")
7 changes: 7 additions & 0 deletions faiss/impl/FaissAssert.h
Expand Up @@ -94,6 +94,13 @@
} \
} while (false)

#define FAISS_THROW_IF_MSG(X, MSG) \
do { \
if (X) { \
FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \
} \
} while (false)

#define FAISS_THROW_IF_NOT_MSG(X, MSG) \
do { \
if (!(X)) { \
Expand Down

0 comments on commit 77e2e79

Please sign in to comment.