diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 90d26161..4bf91c17 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -260,11 +260,16 @@ class Index { if (!ids_.is_none()) { py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); auto ids_numpy = items.request(); - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; + + if (ids_numpy.ndim == 0) { + throw std::invalid_argument("get_items accepts a list of indices and returns a list of vectors"); + } else { + std::vector ids1(ids_numpy.shape[0]); + for (size_t i = 0; i < ids1.size(); i++) { + ids1[i] = items.data()[i]; + } + ids.swap(ids1); } - ids.swap(ids1); } std::vector> data; diff --git a/python_bindings/tests/bindings_test_getdata.py b/python_bindings/tests/bindings_test_getdata.py index 2985c1dd..515ecebd 100644 --- a/python_bindings/tests/bindings_test_getdata.py +++ b/python_bindings/tests/bindings_test_getdata.py @@ -41,6 +41,9 @@ def testGettingItems(self): print("Adding all elements (%d)" % (len(data))) p.add_items(data, labels) + # Getting data by label should raise an exception if a scalar is passed: + self.assertRaises(ValueError, lambda: p.get_items(labels[0])) + # After adding them, all labels should be retrievable returned_items = p.get_items(labels) self.assertSequenceEqual(data.tolist(), returned_items)