diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index bef961353c..9fc201ea39 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -9,7 +9,6 @@ #include -#include #include #include @@ -543,11 +542,12 @@ int search_from_candidates( for (int i = 0; i < candidates.size(); i++) { idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; - assert(v1 >= 0); + FAISS_ASSERT(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (d < D[0]) { - faiss::maxheap_replace_top(k, D, I, d, v1); - nres++; + if (nres < k) { + faiss::maxheap_push(++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_replace_top(nres, D, I, d, v1); } } vt.set(v1); @@ -612,9 +612,10 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (dis < D[0]) { - faiss::maxheap_replace_top(k, D, I, dis, idx); - nres++; + if (nres < k) { + faiss::maxheap_push(++nres, D, I, dis, idx); + } else if (dis < D[0]) { + faiss::maxheap_replace_top(nres, D, I, dis, idx); } } candidates.push(idx, dis); @@ -667,7 +668,7 @@ int search_from_candidates( stats.n3 += ndis; } - return std::min(nres, k); + return nres; } std::priority_queue search_from_candidate_unbounded( @@ -815,11 +816,6 @@ HNSWStats HNSW::search( // greedy search on upper levels storage_idx_t nearest = entry_point; float d_nearest = qdis(nearest); - if (!std::isfinite(d_nearest)) { - // means either the query or the entry point are NaN: in - // both cases we can only return -1 as a result - return stats; - } for (int level = max_level; level >= 1; level--) { greedy_update_nearest(*this, qdis, level, nearest, d_nearest); @@ -830,6 +826,7 @@ HNSWStats HNSW::search( MinimaxHeap candidates(ef); candidates.push(nearest, d_nearest); + search_from_candidates( *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params); } else { diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index 945f68cf93..d096fbcfa3 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -445,8 +445,8 @@ struct SingleBestResultHandler { /// begin results for query # i void begin(const size_t current_idx) { this->current_idx = current_idx; - min_dis = C::neutral(); - min_idx = -1; + min_dis = HUGE_VALF; + min_idx = 0; } /// add one result for query i @@ -472,8 +472,7 @@ struct SingleBestResultHandler { this->i1 = i1; for (size_t i = i0; i < i1; i++) { - this->dis_tab[i] = C::neutral(); - this->ids_tab[i] = -1; + this->dis_tab[i] = HUGE_VALF; } } diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index b1da370e6f..680a3bc059 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -1075,11 +1075,6 @@ void ScalarQuantizer::set_derived_sizes() { } void ScalarQuantizer::train(size_t n, const float* x) { - for (size_t i = 0; i < n * d; i++) { - FAISS_THROW_IF_NOT_MSG( - std::isfinite(x[i]), "training data contains NaN or Inf"); - } - int bit_per_dim = qtype == QT_4bit_uniform ? 4 : qtype == QT_4bit ? 4 : qtype == QT_6bit ? 6 diff --git a/tests/test_error_reporting.py b/tests/test_error_reporting.py deleted file mode 100644 index 04a023a36a..0000000000 --- a/tests/test_error_reporting.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -"""This script tests a few failure cases of Faiss and whether they are handled -properly.""" - -import numpy as np -import unittest -import faiss - -from common_faiss_tests import get_dataset_2 -from faiss.contrib.datasets import SyntheticDataset - - -class TestValidIndexParams(unittest.TestCase): - - def test_IndexIVFPQ(self): - d = 32 - nb = 1000 - nt = 1500 - nq = 200 - - (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) - - coarse_quantizer = faiss.IndexFlatL2(d) - index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) - index.cp.min_points_per_centroid = 5 # quiet warning - index.train(xt) - index.add(xb) - - # invalid nprobe - index.nprobe = 0 - k = 10 - self.assertRaises(RuntimeError, index.search, xq, k) - - # invalid k - index.nprobe = 4 - k = -10 - self.assertRaises(AssertionError, index.search, xq, k) - - # valid params - index.nprobe = 4 - k = 10 - D, nns = index.search(xq, k) - - self.assertEqual(D.shape[0], nq) - self.assertEqual(D.shape[1], k) - - def test_IndexFlat(self): - d = 32 - nb = 1000 - nt = 0 - nq = 200 - - (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) - index = faiss.IndexFlat(d, faiss.METRIC_L2) - - index.add(xb) - - # invalid k - k = -5 - self.assertRaises(AssertionError, index.search, xq, k) - - # valid k - k = 5 - D, I = index.search(xq, k) - - self.assertEqual(D.shape[0], nq) - self.assertEqual(D.shape[1], k) - - -class TestReconsException(unittest.TestCase): - - def test_recons_exception(self): - - d = 64 # dimension - nb = 1000 - rs = np.random.RandomState(1234) - xb = rs.rand(nb, d).astype('float32') - nlist = 10 - quantizer = faiss.IndexFlatL2(d) # the other index - index = faiss.IndexIVFFlat(quantizer, d, nlist) - index.train(xb) - index.add(xb) - index.make_direct_map() - - index.reconstruct(9) - - self.assertRaises( - RuntimeError, - index.reconstruct, 100001 - ) - - def test_reconstuct_after_add(self): - index = faiss.index_factory(10, 'IVF5,SQfp16') - index.train(faiss.randn((100, 10), 123)) - index.add(faiss.randn((100, 10), 345)) - index.make_direct_map() - index.add(faiss.randn((100, 10), 678)) - - # should not raise an exception - index.reconstruct(5) - print(index.ntotal) - index.reconstruct(150) - - -class TestNaN(unittest.TestCase): - """ NaN values handling is transparent: they don't produce results - but should not crash. The tests below cover a few common index types. - """ - - def do_test_train(self, factory_string): - """ NaN and Inf should raise an exception at train time """ - ds = SyntheticDataset(32, 200, 20, 10) - index = faiss.index_factory(ds.d, factory_string) - # try to train with NaNs - xt = ds.get_train().copy() - xt[:, ::4] = np.nan - self.assertRaises(RuntimeError, index.train, xt) - - def test_train_IVFSQ(self): - self.do_test_train("IVF10,SQ8") - - def test_train_IVFPQ(self): - self.do_test_train("IVF10,PQ4np") - - def test_train_SQ(self): - self.do_test_train("SQ8") - - def do_test_add(self, factory_string): - """ stored NaNs should not be returned at search time """ - ds = SyntheticDataset(32, 200, 20, 10) - index = faiss.index_factory(ds.d, factory_string) - if not index.is_trained: - index.train(ds.get_train()) - xb = ds.get_database() - xb[12, 3] = np.nan - index.add(xb) - D, I = index.search(ds.get_queries(), 20) - self.assertTrue(np.where(I == 12)[0].size == 0) - - def test_add_Flat(self): - self.do_test_add("Flat") - - def test_add_HNSW(self): - self.do_test_add("HNSW32,Flat") - - def xx_test_add_SQ8(self): - # this is expected to fail because: - # in ASAN mode, the float NaN -> int conversion crashes - # in opt mode it works but there is no way to encode the NaN, - # so the value cannot be ignored. - self.do_test_add("SQ8") - - def test_add_IVFFlat(self): - self.do_test_add("IVF10,Flat") - - def do_test_search(self, factory_string): - """ NaN query vectors should return -1 """ - ds = SyntheticDataset(32, 200, 20, 10) - index = faiss.index_factory(ds.d, factory_string) - if not index.is_trained: - index.train(ds.get_train()) - index.add(ds.get_database()) - xq = ds.get_queries() - xq[7, 3] = np.nan - D, I = index.search(ds.get_queries(), 20) - self.assertTrue(np.all(I[7] == -1)) - - def test_search_Flat(self): - self.do_test_search("Flat") - - def test_search_HNSW(self): - self.do_test_search("HNSW32,Flat") - - def test_search_IVFFlat(self): - self.do_test_search("IVF10,Flat") - - def test_search_SQ(self): - self.do_test_search("SQ8") diff --git a/tests/test_index.py b/tests/test_index.py index e850f5aab9..0e828e08c1 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. """this is a basic test script for simple indices work""" +from __future__ import absolute_import, division, print_function +# no unicode_literals because it messes up in py2 import numpy as np import unittest @@ -11,17 +13,16 @@ import tempfile import os import re +import warnings from common_faiss_tests import get_dataset, get_dataset_2 - class TestModuleInterface(unittest.TestCase): def test_version_attribute(self): assert hasattr(faiss, '__version__') assert re.match('^\\d+\\.\\d+\\.\\d+$', faiss.__version__) - class TestIndexFlat(unittest.TestCase): def do_test(self, nq, metric_type=faiss.METRIC_L2, k=10): @@ -108,14 +109,6 @@ def test_noblas_reservoir_ip(self): def test_with_blas_reservoir_ip(self): self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=150) - def test_noblas_1res(self): - self.do_test(10, k=1) - - def test_with_blas_1res(self): - self.do_test(200, k=1) - - def test_with_blas_1res_ip(self): - self.do_test(200, faiss.METRIC_INNER_PRODUCT, k=1) class TestIndexFlatL2(unittest.TestCase): def test_indexflat_l2_sync_norms_1(self): @@ -1014,6 +1007,41 @@ def test_replica_flag_propagation(self): index.remove_replica(index1) self.assertEqual(index.ntotal, 0) +class TestReconsException(unittest.TestCase): + + def test_recons_exception(self): + + d = 64 # dimension + nb = 1000 + rs = np.random.RandomState(1234) + xb = rs.rand(nb, d).astype('float32') + nlist = 10 + quantizer = faiss.IndexFlatL2(d) # the other index + index = faiss.IndexIVFFlat(quantizer, d, nlist) + index.train(xb) + index.add(xb) + index.make_direct_map() + + index.reconstruct(9) + + self.assertRaises( + RuntimeError, + index.reconstruct, 100001 + ) + + def test_reconstuct_after_add(self): + index = faiss.index_factory(10, 'IVF5,SQfp16') + index.train(faiss.randn((100, 10), 123)) + index.add(faiss.randn((100, 10), 345)) + index.make_direct_map() + index.add(faiss.randn((100, 10), 678)) + + # should not raise an exception + index.reconstruct(5) + print(index.ntotal) + index.reconstruct(150) + + class TestReconsHash(unittest.TestCase): def do_test(self, index_key): @@ -1085,6 +1113,62 @@ def test_IVFPQ(self): self.do_test("IVF5,PQ4x4np") +class TestValidIndexParams(unittest.TestCase): + + def test_IndexIVFPQ(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + + coarse_quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.train(xt) + index.add(xb) + + # invalid nprobe + index.nprobe = 0 + k = 10 + self.assertRaises(RuntimeError, index.search, xq, k) + + # invalid k + index.nprobe = 4 + k = -10 + self.assertRaises(AssertionError, index.search, xq, k) + + # valid params + index.nprobe = 4 + k = 10 + D, nns = index.search(xq, k) + + self.assertEqual(D.shape[0], nq) + self.assertEqual(D.shape[1], k) + + def test_IndexFlat(self): + d = 32 + nb = 1000 + nt = 0 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + index = faiss.IndexFlat(d, faiss.METRIC_L2) + + index.add(xb) + + # invalid k + k = -5 + self.assertRaises(AssertionError, index.search, xq, k) + + # valid k + k = 5 + D, I = index.search(xq, k) + + self.assertEqual(D.shape[0], nq) + self.assertEqual(D.shape[1], k) + class TestLargeRangeSearch(unittest.TestCase):