diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 66dc82c26bf..d657053b59c 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -383,7 +383,9 @@ def _to_hashed_id(doc_id: Optional[str]) -> int: def _load_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index: """Load an existing HNSW index from disk.""" index = self._create_index_class(col) - index.load_index(self._hnsw_locations[col_name]) + index.load_index( + self._hnsw_locations[col_name], max_elements=col.config['max_elements'] + ) return index # HNSWLib helpers diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index e56452d0f51..e3c8c455bba 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -87,95 +87,95 @@ def test_parametrization(): with pytest.raises(ValueError): DummyDocIndex() - store = DummyDocIndex[SimpleDoc]() - assert store._schema is SimpleDoc + index = DummyDocIndex[SimpleDoc]() + assert index._schema is SimpleDoc def test_build_query(): - store = DummyDocIndex[SimpleDoc]() - q = store.build_query() - assert isinstance(q, store.QueryBuilder) + index = DummyDocIndex[SimpleDoc]() + q = index.build_query() + assert isinstance(q, index.QueryBuilder) def test_create_columns(): # Simple doc - store = DummyDocIndex[SimpleDoc]() - assert list(store._column_infos.keys()) == ['id', 'tens'] + index = DummyDocIndex[SimpleDoc]() + assert list(index._column_infos.keys()) == ['id', 'tens'] - assert store._column_infos['id'].docarray_type == ID - assert store._column_infos['id'].db_type == str - assert store._column_infos['id'].n_dim is None - assert store._column_infos['id'].config == {'hi': 'there'} + assert index._column_infos['id'].docarray_type == ID + assert index._column_infos['id'].db_type == str + assert index._column_infos['id'].n_dim is None + assert index._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['tens'].docarray_type, AbstractTensor) - assert store._column_infos['tens'].db_type == str - assert store._column_infos['tens'].n_dim == 10 - assert store._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'} + assert issubclass(index._column_infos['tens'].docarray_type, AbstractTensor) + assert index._column_infos['tens'].db_type == str + assert index._column_infos['tens'].n_dim == 10 + assert index._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'} # Flat doc - store = DummyDocIndex[FlatDoc]() - assert list(store._column_infos.keys()) == ['id', 'tens_one', 'tens_two'] + index = DummyDocIndex[FlatDoc]() + assert list(index._column_infos.keys()) == ['id', 'tens_one', 'tens_two'] - assert store._column_infos['id'].docarray_type == ID - assert store._column_infos['id'].db_type == str - assert store._column_infos['id'].n_dim is None - assert store._column_infos['id'].config == {'hi': 'there'} + assert index._column_infos['id'].docarray_type == ID + assert index._column_infos['id'].db_type == str + assert index._column_infos['id'].n_dim is None + assert index._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['tens_one'].docarray_type, AbstractTensor) - assert store._column_infos['tens_one'].db_type == str - assert store._column_infos['tens_one'].n_dim is None - assert store._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'} + assert issubclass(index._column_infos['tens_one'].docarray_type, AbstractTensor) + assert index._column_infos['tens_one'].db_type == str + assert index._column_infos['tens_one'].n_dim is None + assert index._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'} - assert issubclass(store._column_infos['tens_two'].docarray_type, AbstractTensor) - assert store._column_infos['tens_two'].db_type == str - assert store._column_infos['tens_two'].n_dim is None - assert store._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'} + assert issubclass(index._column_infos['tens_two'].docarray_type, AbstractTensor) + assert index._column_infos['tens_two'].db_type == str + assert index._column_infos['tens_two'].n_dim is None + assert index._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'} # Nested doc - store = DummyDocIndex[NestedDoc]() - assert list(store._column_infos.keys()) == ['id', 'd__id', 'd__tens'] + index = DummyDocIndex[NestedDoc]() + assert list(index._column_infos.keys()) == ['id', 'd__id', 'd__tens'] - assert store._column_infos['id'].docarray_type == ID - assert store._column_infos['id'].db_type == str - assert store._column_infos['id'].n_dim is None - assert store._column_infos['id'].config == {'hi': 'there'} + assert index._column_infos['id'].docarray_type == ID + assert index._column_infos['id'].db_type == str + assert index._column_infos['id'].n_dim is None + assert index._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['d__tens'].docarray_type, AbstractTensor) - assert store._column_infos['d__tens'].db_type == str - assert store._column_infos['d__tens'].n_dim == 10 - assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} + assert issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor) + assert index._column_infos['d__tens'].db_type == str + assert index._column_infos['d__tens'].n_dim == 10 + assert index._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} def test_flatten_schema(): - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() fields = SimpleDoc.__fields__ - assert set(store._flatten_schema(SimpleDoc)) == { + assert set(index._flatten_schema(SimpleDoc)) == { ('id', ID, fields['id']), ('tens', AbstractTensor, fields['tens']), } - store = DummyDocIndex[FlatDoc]() + index = DummyDocIndex[FlatDoc]() fields = FlatDoc.__fields__ - assert set(store._flatten_schema(FlatDoc)) == { + assert set(index._flatten_schema(FlatDoc)) == { ('id', ID, fields['id']), ('tens_one', AbstractTensor, fields['tens_one']), ('tens_two', AbstractTensor, fields['tens_two']), } - store = DummyDocIndex[NestedDoc]() + index = DummyDocIndex[NestedDoc]() fields = NestedDoc.__fields__ fields_nested = SimpleDoc.__fields__ - assert set(store._flatten_schema(NestedDoc)) == { + assert set(index._flatten_schema(NestedDoc)) == { ('id', ID, fields['id']), ('d__id', ID, fields_nested['id']), ('d__tens', AbstractTensor, fields_nested['tens']), } - store = DummyDocIndex[DeepNestedDoc]() + index = DummyDocIndex[DeepNestedDoc]() fields = DeepNestedDoc.__fields__ fields_nested = NestedDoc.__fields__ fields_nested_nested = SimpleDoc.__fields__ - assert set(store._flatten_schema(DeepNestedDoc)) == { + assert set(index._flatten_schema(DeepNestedDoc)) == { ('id', ID, fields['id']), ('d__id', ID, fields_nested['id']), ('d__d__id', ID, fields_nested_nested['id']), @@ -187,14 +187,14 @@ def test_flatten_schema_union(): class MyDoc(BaseDoc): image: ImageDoc - store = DummyDocIndex[MyDoc]() + index = DummyDocIndex[MyDoc]() fields = MyDoc.__fields__ fields_image = ImageDoc.__fields__ if torch_imported: from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor - assert set(store._flatten_schema(MyDoc)) == { + assert set(index._flatten_schema(MyDoc)) == { ('id', ID, fields['id']), ('image__id', ID, fields_image['id']), ('image__url', ImageUrl, fields_image['url']), @@ -212,9 +212,9 @@ class MyDoc2(BaseDoc): class MyDoc3(BaseDoc): tensor: Union[NdArray, ImageTorchTensor] - store = DummyDocIndex[MyDoc3]() + index = DummyDocIndex[MyDoc3]() fields = MyDoc3.__fields__ - assert set(store._flatten_schema(MyDoc3)) == { + assert set(index._flatten_schema(MyDoc3)) == { ('id', ID, fields['id']), ('tensor', AbstractTensor, fields['tensor']), } @@ -224,19 +224,19 @@ def test_columns_db_type_with_user_defined_mapping(tmp_path): class MyDoc(BaseDoc): tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray) - store = DummyDocIndex[MyDoc](work_dir=str(tmp_path)) + index = DummyDocIndex[MyDoc](work_dir=str(tmp_path)) - assert store._column_infos['tens'].db_type == np.ndarray + assert index._column_infos['tens'].db_type == np.ndarray def test_columns_db_type_with_user_defined_mapping_additional_params(tmp_path): class MyDoc(BaseDoc): tens: NdArray[10] = Field(dim=1000, col_type='varchar', max_len=1024) - store = DummyDocIndex[MyDoc](work_dir=str(tmp_path)) + index = DummyDocIndex[MyDoc](work_dir=str(tmp_path)) - assert store._column_infos['tens'].db_type == 'varchar' - assert store._column_infos['tens'].config['max_len'] == 1024 + assert index._column_infos['tens'].db_type == 'varchar' + assert index._column_infos['tens'].config['max_len'] == 1024 def test_columns_illegal_mapping(tmp_path): @@ -260,18 +260,18 @@ class OtherNestedDoc(NestedDoc): ... # SIMPLE - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[SimpleDoc](in_list) - assert store._validate_docs(in_da) == in_da + assert index._validate_docs(in_da) == in_da in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))] - assert isinstance(store._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) in_other_da = DocList[OtherSimpleDoc](in_other_list) - assert store._validate_docs(in_other_da) == in_other_da + assert index._validate_docs(in_other_da) == in_other_da with pytest.raises(ValueError): - store._validate_docs( + index._validate_docs( [ FlatDoc( tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) @@ -279,7 +279,7 @@ class OtherNestedDoc(NestedDoc): ] ) with pytest.raises(ValueError): - store._validate_docs( + index._validate_docs( DocList[FlatDoc]( [ FlatDoc( @@ -291,19 +291,19 @@ class OtherNestedDoc(NestedDoc): ) # FLAT - store = DummyDocIndex[FlatDoc]() + index = DummyDocIndex[FlatDoc]() in_list = [ FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[FlatDoc]( [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] ) - assert store._validate_docs(in_da) == in_da + assert index._validate_docs(in_da) == in_da in_other_list = [ OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] - assert isinstance(store._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) in_other_da = DocList[OtherFlatDoc]( [ OtherFlatDoc( @@ -311,31 +311,31 @@ class OtherNestedDoc(NestedDoc): ) ] ) - assert store._validate_docs(in_other_da) == in_other_da + assert index._validate_docs(in_other_da) == in_other_da with pytest.raises(ValueError): - store._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) + index._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) with pytest.raises(ValueError): - assert not store._validate_docs( + assert not index._validate_docs( DocList[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) ) # NESTED - store = DummyDocIndex[NestedDoc]() + index = DummyDocIndex[NestedDoc]() in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]) - assert store._validate_docs(in_da) == in_da + assert index._validate_docs(in_da) == in_da in_other_list = [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] - assert isinstance(store._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) in_other_da = DocList[OtherNestedDoc]( [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] ) - assert store._validate_docs(in_other_da) == in_other_da + assert index._validate_docs(in_other_da) == in_other_da with pytest.raises(ValueError): - store._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) + index._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) with pytest.raises(ValueError): - store._validate_docs( + index._validate_docs( DocList[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) ) @@ -351,37 +351,37 @@ class TensorUnionDoc(BaseDoc): tens: Union[NdArray[10], AbstractTensor] = Field(dim=1000) # OPTIONAL - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() in_list = [OptionalDoc(tens=np.random.random((10,)))] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[OptionalDoc](in_list) - assert store._validate_docs(in_da) == in_da + assert index._validate_docs(in_da) == in_da with pytest.raises(ValueError): - store._validate_docs([OptionalDoc(tens=None)]) + index._validate_docs([OptionalDoc(tens=None)]) # MIXED UNION - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() in_list = [MixedUnionDoc(tens=np.random.random((10,)))] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[MixedUnionDoc](in_list) - assert isinstance(store._validate_docs(in_da), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_da), DocList[BaseDoc]) with pytest.raises(ValueError): - store._validate_docs([MixedUnionDoc(tens='hello')]) + index._validate_docs([MixedUnionDoc(tens='hello')]) # TENSOR UNION - store = DummyDocIndex[TensorUnionDoc]() + index = DummyDocIndex[TensorUnionDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[SimpleDoc](in_list) - assert store._validate_docs(in_da) == in_da + assert index._validate_docs(in_da) == in_da - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() in_list = [TensorUnionDoc(tens=np.random.random((10,)))] - assert isinstance(store._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) in_da = DocList[TensorUnionDoc](in_list) - assert store._validate_docs(in_da) == in_da + assert index._validate_docs(in_da) == in_da def test_get_value(): @@ -405,38 +405,38 @@ def test_get_value(): def test_get_data_by_columns(): - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() docs = [SimpleDoc(tens=np.random.random((10,))) for _ in range(10)] - data_by_columns = store._get_col_value_dict(docs) + data_by_columns = index._get_col_value_dict(docs) assert list(data_by_columns.keys()) == ['id', 'tens'] assert list(data_by_columns['id']) == [doc.id for doc in docs] assert list(data_by_columns['tens']) == [doc.tens for doc in docs] - store = DummyDocIndex[FlatDoc]() + index = DummyDocIndex[FlatDoc]() docs = [ FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) for _ in range(10) ] - data_by_columns = store._get_col_value_dict(docs) + data_by_columns = index._get_col_value_dict(docs) assert list(data_by_columns.keys()) == ['id', 'tens_one', 'tens_two'] assert list(data_by_columns['id']) == [doc.id for doc in docs] assert list(data_by_columns['tens_one']) == [doc.tens_one for doc in docs] assert list(data_by_columns['tens_two']) == [doc.tens_two for doc in docs] - store = DummyDocIndex[NestedDoc]() + index = DummyDocIndex[NestedDoc]() docs = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,)))) for _ in range(10)] - data_by_columns = store._get_col_value_dict(docs) + data_by_columns = index._get_col_value_dict(docs) assert list(data_by_columns.keys()) == ['id', 'd__id', 'd__tens'] assert list(data_by_columns['id']) == [doc.id for doc in docs] assert list(data_by_columns['d__id']) == [doc.d.id for doc in docs] assert list(data_by_columns['d__tens']) == [doc.d.tens for doc in docs] - store = DummyDocIndex[DeepNestedDoc]() + index = DummyDocIndex[DeepNestedDoc]() docs = [ DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))) for _ in range(10) ] - data_by_columns = store._get_col_value_dict(docs) + data_by_columns = index._get_col_value_dict(docs) assert list(data_by_columns.keys()) == ['id', 'd__id', 'd__d__id', 'd__d__tens'] assert list(data_by_columns['id']) == [doc.id for doc in docs] assert list(data_by_columns['d__id']) == [doc.d.id for doc in docs] @@ -445,45 +445,45 @@ def test_get_data_by_columns(): def test_transpose_data_by_columns(): - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() docs = [SimpleDoc(tens=np.random.random((10,))) for _ in range(10)] - data_by_columns = store._get_col_value_dict(docs) - data_by_rows = list(store._transpose_col_value_dict(data_by_columns)) + data_by_columns = index._get_col_value_dict(docs) + data_by_rows = list(index._transpose_col_value_dict(data_by_columns)) assert len(data_by_rows) == len(docs) for doc, row in zip(docs, data_by_rows): assert doc.id == row['id'] assert np.all(doc.tens == row['tens']) - store = DummyDocIndex[FlatDoc]() + index = DummyDocIndex[FlatDoc]() docs = [ FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) for _ in range(10) ] - data_by_columns = store._get_col_value_dict(docs) - data_by_rows = list(store._transpose_col_value_dict(data_by_columns)) + data_by_columns = index._get_col_value_dict(docs) + data_by_rows = list(index._transpose_col_value_dict(data_by_columns)) assert len(data_by_rows) == len(docs) for doc, row in zip(docs, data_by_rows): assert doc.id == row['id'] assert np.all(doc.tens_one == row['tens_one']) assert np.all(doc.tens_two == row['tens_two']) - store = DummyDocIndex[NestedDoc]() + index = DummyDocIndex[NestedDoc]() docs = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,)))) for _ in range(10)] - data_by_columns = store._get_col_value_dict(docs) - data_by_rows = list(store._transpose_col_value_dict(data_by_columns)) + data_by_columns = index._get_col_value_dict(docs) + data_by_rows = list(index._transpose_col_value_dict(data_by_columns)) assert len(data_by_rows) == len(docs) for doc, row in zip(docs, data_by_rows): assert doc.id == row['id'] assert doc.d.id == row['d__id'] assert np.all(doc.d.tens == row['d__tens']) - store = DummyDocIndex[DeepNestedDoc]() + index = DummyDocIndex[DeepNestedDoc]() docs = [ DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))) for _ in range(10) ] - data_by_columns = store._get_col_value_dict(docs) - data_by_rows = list(store._transpose_col_value_dict(data_by_columns)) + data_by_columns = index._get_col_value_dict(docs) + data_by_rows = list(index._transpose_col_value_dict(data_by_columns)) assert len(data_by_rows) == len(docs) for doc, row in zip(docs, data_by_rows): assert doc.id == row['id'] @@ -493,32 +493,32 @@ def test_transpose_data_by_columns(): def test_convert_dict_to_doc(): - store = DummyDocIndex[SimpleDoc]() + index = DummyDocIndex[SimpleDoc]() doc_dict = {'id': 'simple', 'tens': np.random.random((10,))} - doc = store._convert_dict_to_doc(doc_dict, store._schema) + doc = index._convert_dict_to_doc(doc_dict, index._schema) assert doc.id == doc_dict['id'] assert np.all(doc.tens == doc_dict['tens']) - store = DummyDocIndex[FlatDoc]() + index = DummyDocIndex[FlatDoc]() doc_dict = { 'id': 'nested', 'tens_one': np.random.random((10,)), 'tens_two': np.random.random((50,)), } - doc = store._convert_dict_to_doc(doc_dict, store._schema) + doc = index._convert_dict_to_doc(doc_dict, index._schema) assert doc.id == doc_dict['id'] assert np.all(doc.tens_one == doc_dict['tens_one']) assert np.all(doc.tens_two == doc_dict['tens_two']) - store = DummyDocIndex[NestedDoc]() + index = DummyDocIndex[NestedDoc]() doc_dict = {'id': 'nested', 'd__id': 'simple', 'd__tens': np.random.random((10,))} doc_dict_copy = doc_dict.copy() - doc = store._convert_dict_to_doc(doc_dict, store._schema) + doc = index._convert_dict_to_doc(doc_dict, index._schema) assert doc.id == doc_dict_copy['id'] assert doc.d.id == doc_dict_copy['d__id'] assert np.all(doc.d.tens == doc_dict_copy['d__tens']) - store = DummyDocIndex[DeepNestedDoc]() + index = DummyDocIndex[DeepNestedDoc]() doc_dict = { 'id': 'deep', 'd__id': 'nested', @@ -526,7 +526,7 @@ def test_convert_dict_to_doc(): 'd__d__tens': np.random.random((10,)), } doc_dict_copy = doc_dict.copy() - doc = store._convert_dict_to_doc(doc_dict, store._schema) + doc = index._convert_dict_to_doc(doc_dict, index._schema) assert doc.id == doc_dict_copy['id'] assert doc.d.id == doc_dict_copy['d__id'] assert doc.d.d.id == doc_dict_copy['d__d__id'] @@ -535,13 +535,13 @@ def test_convert_dict_to_doc(): class MyDoc(BaseDoc): image: ImageDoc - store = DummyDocIndex[MyDoc]() + index = DummyDocIndex[MyDoc]() doc_dict = { 'id': 'root', 'image__id': 'nested', 'image__tensor': np.random.random((128,)), } - doc = store._convert_dict_to_doc(doc_dict, store._schema) + doc = index._convert_dict_to_doc(doc_dict, index._schema) if torch_imported: from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor @@ -549,26 +549,26 @@ class MyDoc(BaseDoc): class MyDoc2(BaseDoc): tens: Union[NdArray, ImageTorchTensor] - store = DummyDocIndex[MyDoc2]() + index = DummyDocIndex[MyDoc2]() doc_dict = { 'id': 'root', 'tens': np.random.random((128,)), } doc_dict_copy = doc_dict.copy() - doc = store._convert_dict_to_doc(doc_dict, store._schema) + doc = index._convert_dict_to_doc(doc_dict, index._schema) assert doc.id == doc_dict_copy['id'] assert np.all(doc.tens == doc_dict_copy['tens']) def test_validate_search_fields(): - store = DummyDocIndex[SimpleDoc]() - assert list(store._column_infos.keys()) == ['id', 'tens'] + index = DummyDocIndex[SimpleDoc]() + assert list(index._column_infos.keys()) == ['id', 'tens'] # 'tens' is a valid field - assert store._validate_search_field(search_field='tens') + assert index._validate_search_field(search_field='tens') # should not fail when an empty string or None is passed - assert store._validate_search_field(search_field='') - store._validate_search_field(search_field=None) + assert index._validate_search_field(search_field='') + index._validate_search_field(search_field=None) # 'ten' is not a valid field with pytest.raises(ValueError): - store._validate_search_field('ten') + index._validate_search_field('ten') diff --git a/tests/index/base_classes/test_configs.py b/tests/index/base_classes/test_configs.py index 8cae5524ec9..cba31ad296f 100644 --- a/tests/index/base_classes/test_configs.py +++ b/tests/index/base_classes/test_configs.py @@ -63,10 +63,10 @@ def python_type_to_db_type(self, x): def test_defaults(): - store = DummyDocIndex[SimpleDoc]() - assert store._db_config.other == 5 - assert store._db_config.work_dir == '.' - assert store._runtime_config.default_column_config[str] == { + index = DummyDocIndex[SimpleDoc]() + assert index._db_config.other == 5 + assert index._db_config.work_dir == '.' + assert index._runtime_config.default_column_config[str] == { 'dim': 128, 'space': 'l2', } @@ -74,39 +74,39 @@ def test_defaults(): def test_set_by_class(): # change all settings - store = DummyDocIndex[SimpleDoc](DBConfig(work_dir='hi', other=10)) - assert store._db_config.other == 10 - assert store._db_config.work_dir == 'hi' - store.configure(RuntimeConfig(default_column_config={}, default_ef=10)) - assert store._runtime_config.default_column_config == {} + index = DummyDocIndex[SimpleDoc](DBConfig(work_dir='hi', other=10)) + assert index._db_config.other == 10 + assert index._db_config.work_dir == 'hi' + index.configure(RuntimeConfig(default_column_config={}, default_ef=10)) + assert index._runtime_config.default_column_config == {} # change only some settings - store = DummyDocIndex[SimpleDoc](DBConfig(work_dir='hi')) - assert store._db_config.other == 5 - assert store._db_config.work_dir == 'hi' - store.configure(RuntimeConfig(default_column_config={})) - assert store._runtime_config.default_column_config == {} + index = DummyDocIndex[SimpleDoc](DBConfig(work_dir='hi')) + assert index._db_config.other == 5 + assert index._db_config.work_dir == 'hi' + index.configure(RuntimeConfig(default_column_config={})) + assert index._runtime_config.default_column_config == {} def test_set_by_kwargs(): # change all settings - store = DummyDocIndex[SimpleDoc](work_dir='hi', other=10) - assert store._db_config.other == 10 - assert store._db_config.work_dir == 'hi' - store.configure(default_column_config={}, default_ef=10) - assert store._runtime_config.default_column_config == {} + index = DummyDocIndex[SimpleDoc](work_dir='hi', other=10) + assert index._db_config.other == 10 + assert index._db_config.work_dir == 'hi' + index.configure(default_column_config={}, default_ef=10) + assert index._runtime_config.default_column_config == {} # change only some settings - store = DummyDocIndex[SimpleDoc](work_dir='hi') - assert store._db_config.other == 5 - assert store._db_config.work_dir == 'hi' - store.configure(default_column_config={}) - assert store._runtime_config.default_column_config == {} + index = DummyDocIndex[SimpleDoc](work_dir='hi') + assert index._db_config.other == 5 + assert index._db_config.work_dir == 'hi' + index.configure(default_column_config={}) + assert index._runtime_config.default_column_config == {} def test_default_column_config(): - store = DummyDocIndex[SimpleDoc]() - assert store._runtime_config.default_column_config == { + index = DummyDocIndex[SimpleDoc]() + assert index._runtime_config.default_column_config == { str: { 'dim': 128, 'space': 'l2', diff --git a/tests/index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py index a0d4aa4dec9..f1fa93d7748 100644 --- a/tests/index/elastic/v7/test_column_config.py +++ b/tests/index/elastic/v7/test_column_config.py @@ -13,20 +13,20 @@ class MyDoc(BaseDoc): text: str color: str = Field(col_type='keyword') - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() index_docs = [ MyDoc(id='0', text='hello world', color='red'), MyDoc(id='1', text='never gonna give you up', color='blue'), MyDoc(id='2', text='we are the world', color='green'), ] - store.index(index_docs) + index.index(index_docs) query = 'world' - docs, _ = store.text_search(query, search_field='text') + docs, _ = index.text_search(query, search_field='text') assert [doc.id for doc in docs] == ['0', '2'] filter_query = {'terms': {'color': ['red', 'blue']}} - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert [doc.id for doc in docs] == ['0', '1'] @@ -44,19 +44,19 @@ class MyDoc(BaseDoc): } ) - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() doc = [ MyDoc(manager={'age': 25, 'name': {'first': 'Rachel', 'last': 'Green'}}), MyDoc(manager={'age': 30, 'name': {'first': 'Monica', 'last': 'Geller'}}), MyDoc(manager={'age': 35, 'name': {'first': 'Phoebe', 'last': 'Buffay'}}), ] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert store[id_].manager == doc[0].manager + assert index[id_].id == id_ + assert index[id_].manager == doc[0].manager filter_query = {'range': {'manager.age': {'gte': 30}}} - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert [doc.id for doc in docs] == [doc[1].id, doc[2].id] @@ -64,13 +64,13 @@ def test_field_geo_point(): class MyDoc(BaseDoc): location: dict = Field(col_type='geo_point') - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() doc = [ MyDoc(location={'lat': 40.12, 'lon': -72.34}), MyDoc(location={'lat': 41.12, 'lon': -73.34}), MyDoc(location={'lat': 42.12, 'lon': -74.34}), ] - store.index(doc) + index.index(doc) query = { 'query': { @@ -83,7 +83,7 @@ class MyDoc(BaseDoc): }, } - docs, _ = store.execute_query(query) + docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] @@ -92,7 +92,7 @@ class MyDoc(BaseDoc): expected_attendees: dict = Field(col_type='integer_range') time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() doc = [ MyDoc( expected_attendees={'gte': 10, 'lt': 20}, @@ -107,7 +107,7 @@ class MyDoc(BaseDoc): time_frame={'gte': '2023-03-01', 'lt': '2023-04-01'}, ), ] - store.index(doc) + index.index(doc) query = { 'query': { @@ -127,5 +127,5 @@ class MyDoc(BaseDoc): } }, } - docs, _ = store.execute_query(query) + docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] diff --git a/tests/index/elastic/v7/test_find.py b/tests/index/elastic/v7/test_find.py index 6665c8b2b60..d54b3b0480d 100644 --- a/tests/index/elastic/v7/test_find.py +++ b/tests/index/elastic/v7/test_find.py @@ -16,13 +16,13 @@ def test_find_simple_schema(): class SimpleSchema(BaseDoc): tens: NdArray[10] - store = ElasticV7DocIndex[SimpleSchema]() + index = ElasticV7DocIndex[SimpleSchema]() index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -36,18 +36,18 @@ class FlatSchema(BaseDoc): tens_one: NdArray = Field(dims=10) tens_two: NdArray = Field(dims=50) - store = ElasticV7DocIndex[FlatSchema]() + index = ElasticV7DocIndex[FlatSchema]() index_docs = [ FlatDoc(tens_one=np.random.rand(10), tens_two=np.random.rand(50)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] # find on tens_one - docs, scores = store.find(query, search_field='tens_one', limit=5) + docs, scores = index.find(query, search_field='tens_one', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -55,7 +55,7 @@ class FlatSchema(BaseDoc): assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) # find on tens_two - docs, scores = store.find(query, search_field='tens_two', limit=5) + docs, scores = index.find(query, search_field='tens_two', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -75,7 +75,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(dims=10) - store = ElasticV7DocIndex[DeepNestedDoc]() + index = ElasticV7DocIndex[DeepNestedDoc]() index_docs = [ DeepNestedDoc( @@ -84,26 +84,26 @@ class DeepNestedDoc(BaseDoc): ) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] # find on root level - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id assert np.allclose(docs[0].tens, index_docs[-1].tens) # find on first nesting level - docs, scores = store.find(query, search_field='d__tens', limit=5) + docs, scores = index.find(query, search_field='d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id assert np.allclose(docs[0].d.tens, index_docs[-1].d.tens) # find on second nesting level - docs, scores = store.find(query, search_field='d__d__tens', limit=5) + docs, scores = index.find(query, search_field='d__d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -114,19 +114,19 @@ def test_find_torch(): class TorchDoc(BaseDoc): tens: TorchTensor[10] - store = ElasticV7DocIndex[TorchDoc]() + index = ElasticV7DocIndex[TorchDoc]() # A dense_vector field stores dense vectors of float values. index_docs = [ TorchDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TorchTensor) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -143,18 +143,18 @@ def test_find_tensorflow(): class TfDoc(BaseDoc): tens: TensorFlowTensor[10] - store = ElasticV7DocIndex[TfDoc]() + index = ElasticV7DocIndex[TfDoc]() index_docs = [ TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TensorFlowTensor) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -168,13 +168,13 @@ class TfDoc(BaseDoc): def test_find_batched(): - store = ElasticV7DocIndex[SimpleDoc]() + index = ElasticV7DocIndex[SimpleDoc]() index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] - store.index(index_docs) + index.index(index_docs) queries = index_docs[-2:] - docs_batched, scores_batched = store.find_batched( + docs_batched, scores_batched = index.find_batched( queries, search_field='tens', limit=5 ) @@ -191,13 +191,13 @@ class MyDoc(BaseDoc): B: int C: float - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() index_docs = [MyDoc(id=f'{i}', A=(i % 2 == 0), B=i, C=i + 0.5) for i in range(10)] - store.index(index_docs) + index.index(index_docs) filter_query = {'term': {'A': True}} - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert len(docs) > 0 for doc in docs: assert doc.A @@ -210,7 +210,7 @@ class MyDoc(BaseDoc): ] } } - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert [doc.id for doc in docs] == ['3', '4'] @@ -218,16 +218,16 @@ def test_text_search(): class MyDoc(BaseDoc): text: str - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() index_docs = [ MyDoc(text='hello world'), MyDoc(text='never gonna give you up'), MyDoc(text='we are the world'), ] - store.index(index_docs) + index.index(index_docs) query = 'world' - docs, scores = store.text_search(query, search_field='text') + docs, scores = index.text_search(query, search_field='text') assert len(docs) == 2 assert len(scores) == 2 @@ -235,7 +235,7 @@ class MyDoc(BaseDoc): assert docs[1].text.index(query) >= 0 queries = ['world', 'never'] - docs, scores = store.text_search_batched(queries, search_field='text') + docs, scores = index.text_search_batched(queries, search_field='text') for query, da, score in zip(queries, docs, scores): assert len(da) > 0 assert len(score) > 0 @@ -249,46 +249,46 @@ class MyDoc(BaseDoc): num: int text: str - store = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc]() index_docs = [ MyDoc( id=f'{i}', tens=np.random.rand(10), num=int(i / 2), text=f'text {int(i/2)}' ) for i in range(10) ] - store.index(index_docs) + index.index(index_docs) # build_query - q = store.build_query() - assert isinstance(q, store.QueryBuilder) + q = index.build_query() + assert isinstance(q, index.QueryBuilder) # filter - q = store.build_query().filter({'term': {'num': 0}}).build() - docs, _ = store.execute_query(q) + q = index.build_query().filter({'term': {'num': 0}}).build() + docs, _ = index.execute_query(q) assert [doc['id'] for doc in docs] == ['0', '1'] # find - q = store.build_query().find(index_docs[-1], search_field='tens', limit=3).build() - docs, scores = store.execute_query(q) + q = index.build_query().find(index_docs[-1], search_field='tens', limit=3).build() + docs, scores = index.execute_query(q) assert len(docs) == 3 assert len(scores) == 3 assert docs[0]['id'] == index_docs[-1].id assert np.allclose(docs[0]['tens'], index_docs[-1].tens) # text search - q = store.build_query().text_search('0', search_field='text').build() - docs, _ = store.execute_query(q) + q = index.build_query().text_search('0', search_field='text').build() + docs, _ = index.execute_query(q) assert [doc['id'] for doc in docs] == ['0', '1'] # combination q = ( - store.build_query() + index.build_query() .filter({'range': {'num': {'lte': 3}}}) .find(index_docs[-1], search_field='tens') .text_search('0', search_field='text') .build() ) - docs, _ = store.execute_query(q) + docs, _ = index.execute_query(q) assert sorted([doc['id'] for doc in docs]) == ['0', '1'] # direct @@ -296,7 +296,7 @@ class MyDoc(BaseDoc): MyDoc(id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i/2)}') for i in range(10) ] - store.index(index_docs) + index.index(index_docs) query = { 'query': { @@ -317,5 +317,5 @@ class MyDoc(BaseDoc): } } - docs, _ = store.execute_query(query) + docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == ['7', '6', '5', '4'] diff --git a/tests/index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py index d5ead493c03..e6e6baf3e60 100644 --- a/tests/index/elastic/v7/test_index_get_del.py +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -25,114 +25,114 @@ @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 - store = ElasticV7DocIndex[SimpleDoc]() + index = ElasticV7DocIndex[SimpleDoc]() if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) - store.index(ten_simple_docs) - assert store.num_docs() == 10 + index.index(ten_simple_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 - store = ElasticV7DocIndex[FlatDoc]() + index = ElasticV7DocIndex[FlatDoc]() if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) - store.index(ten_flat_docs) - assert store.num_docs() == 10 + index.index(ten_flat_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 - store = ElasticV7DocIndex[NestedDoc]() + index = ElasticV7DocIndex[NestedDoc]() if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) - store.index(ten_nested_docs) - assert store.num_docs() == 10 + index.index(ten_nested_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 - store = ElasticV7DocIndex[DeepNestedDoc]() + index = ElasticV7DocIndex[DeepNestedDoc]() if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) - store.index(ten_deep_nested_docs) - assert store.num_docs() == 10 + index.index(ten_deep_nested_docs) + assert index.num_docs() == 10 def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 # simple - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_simple_docs: id_ = d.id - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) # flat - store = ElasticV7DocIndex[FlatDoc]() - store.index(ten_flat_docs) + index = ElasticV7DocIndex[FlatDoc]() + index.index(ten_flat_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_flat_docs: id_ = d.id - assert store[id_].id == id_ - assert np.all(store[id_].tens_one == d.tens_one) - assert np.all(store[id_].tens_two == d.tens_two) + assert index[id_].id == id_ + assert np.all(index[id_].tens_one == d.tens_one) + assert np.all(index[id_].tens_two == d.tens_two) # nested - store = ElasticV7DocIndex[NestedDoc]() - store.index(ten_nested_docs) + index = ElasticV7DocIndex[NestedDoc]() + index.index(ten_nested_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_nested_docs: id_ = d.id - assert store[id_].id == id_ - assert store[id_].d.id == d.d.id - assert np.all(store[id_].d.tens == d.d.tens) + assert index[id_].id == id_ + assert index[id_].d.id == d.d.id + assert np.all(index[id_].d.tens == d.d.tens) def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 docs_to_get_idx = [0, 2, 4, 6, 8] # simple - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_simple_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert np.all(d_out.tens == d_in.tens) # flat - store = ElasticV7DocIndex[FlatDoc]() - store.index(ten_flat_docs) + index = ElasticV7DocIndex[FlatDoc]() + index.index(ten_flat_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_flat_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert np.all(d_out.tens_one == d_in.tens_one) assert np.all(d_out.tens_two == d_in.tens_two) # nested - store = ElasticV7DocIndex[NestedDoc]() - store.index(ten_nested_docs) + index = ElasticV7DocIndex[NestedDoc]() + index.index(ten_nested_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_nested_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert d_out.d.id == d_in.d.id @@ -140,94 +140,94 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: def test_get_key_error(ten_simple_docs): # noqa: F811 - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) with pytest.raises(KeyError): - store['not_a_real_id'] + index['not_a_real_id'] def test_persisting(ten_simple_docs): # noqa: F811 - store = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') + index.index(ten_simple_docs) - store2 = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') - assert store2.num_docs() == 10 + index2 = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') + assert index2.num_docs() == 10 def test_del_single(ten_simple_docs): # noqa: F811 - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) # delete once - assert store.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + assert index.num_docs() == 10 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 for i, d in enumerate(ten_simple_docs): id_ = d.id if i == 0: # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) # delete again - del store[ten_simple_docs[3].id] - assert store.num_docs() == 8 + del index[ten_simple_docs[3].id] + assert index.num_docs() == 8 for i, d in enumerate(ten_simple_docs): id_ = d.id if i in (0, 3): # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) def test_del_multiple(ten_simple_docs): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] ids_to_del = [d.id for d in docs_to_del] - del store[ids_to_del] + del index[ids_to_del] for i, doc in enumerate(ten_simple_docs): if i in docs_to_del_idx: with pytest.raises(KeyError): - store[doc.id] + index[doc.id] else: - assert store[doc.id].id == doc.id - assert np.all(store[doc.id].tens == doc.tens) + assert index[doc.id].id == doc.id + assert np.all(index[doc.id].tens == doc.tens) def test_del_key_error(ten_simple_docs): # noqa: F811 - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) with pytest.warns(UserWarning): - del store['not_a_real_id'] + del index['not_a_real_id'] def test_num_docs(ten_simple_docs): # noqa: F811 - store = ElasticV7DocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticV7DocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 - del store[ten_simple_docs[3].id, ten_simple_docs[5].id] - assert store.num_docs() == 7 + del index[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert index.num_docs() == 7 more_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] - store.index(more_docs) - assert store.num_docs() == 12 + index.index(more_docs) + assert index.num_docs() == 12 - del store[more_docs[2].id, ten_simple_docs[7].id] - assert store.num_docs() == 10 + del index[more_docs[2].id, ten_simple_docs[7].id] + assert index.num_docs() == 10 def test_index_union_doc(): @@ -237,13 +237,13 @@ class MyDoc(BaseDoc): class MySchema(BaseDoc): tensor: NdArray - store = ElasticV7DocIndex[MySchema]() + index = ElasticV7DocIndex[MySchema]() doc = [MyDoc(tensor=np.random.randn(128))] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert np.all(store[id_].tensor == doc[0].tensor) + assert index[id_].id == id_ + assert np.all(index[id_].tensor == doc[0].tensor) def test_index_multi_modal_doc(): @@ -251,20 +251,20 @@ class MyMultiModalDoc(BaseDoc): image: MyImageDoc text: TextDoc - store = ElasticV7DocIndex[MyMultiModalDoc]() + index = ElasticV7DocIndex[MyMultiModalDoc]() doc = [ MyMultiModalDoc( image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') ) ] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert np.all(store[id_].image.embedding == doc[0].image.embedding) - assert store[id_].text.text == doc[0].text.text + assert index[id_].id == id_ + assert np.all(index[id_].image.embedding == doc[0].image.embedding) + assert index[id_].text.text == doc[0].text.text query = doc[0] - docs, _ = store.find(query, limit=10, search_field='image__embedding') + docs, _ = index.find(query, limit=10, search_field='image__embedding') assert len(docs) > 0 diff --git a/tests/index/elastic/v8/test_column_config.py b/tests/index/elastic/v8/test_column_config.py index 2b3bbcee0f8..0edd105697d 100644 --- a/tests/index/elastic/v8/test_column_config.py +++ b/tests/index/elastic/v8/test_column_config.py @@ -13,20 +13,20 @@ class MyDoc(BaseDoc): text: str color: str = Field(col_type='keyword') - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() index_docs = [ MyDoc(id='0', text='hello world', color='red'), MyDoc(id='1', text='never gonna give you up', color='blue'), MyDoc(id='2', text='we are the world', color='green'), ] - store.index(index_docs) + index.index(index_docs) query = 'world' - docs, _ = store.text_search(query, search_field='text') + docs, _ = index.text_search(query, search_field='text') assert [doc.id for doc in docs] == ['0', '2'] filter_query = {'terms': {'color': ['red', 'blue']}} - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert [doc.id for doc in docs] == ['0', '1'] @@ -44,19 +44,19 @@ class MyDoc(BaseDoc): } ) - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() doc = [ MyDoc(manager={'age': 25, 'name': {'first': 'Rachel', 'last': 'Green'}}), MyDoc(manager={'age': 30, 'name': {'first': 'Monica', 'last': 'Geller'}}), MyDoc(manager={'age': 35, 'name': {'first': 'Phoebe', 'last': 'Buffay'}}), ] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert store[id_].manager == doc[0].manager + assert index[id_].id == id_ + assert index[id_].manager == doc[0].manager filter_query = {'range': {'manager.age': {'gte': 30}}} - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert [doc.id for doc in docs] == [doc[1].id, doc[2].id] @@ -64,13 +64,13 @@ def test_field_geo_point(): class MyDoc(BaseDoc): location: dict = Field(col_type='geo_point') - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() doc = [ MyDoc(location={'lat': 40.12, 'lon': -72.34}), MyDoc(location={'lat': 41.12, 'lon': -73.34}), MyDoc(location={'lat': 42.12, 'lon': -74.34}), ] - store.index(doc) + index.index(doc) query = { 'query': { @@ -83,7 +83,7 @@ class MyDoc(BaseDoc): }, } - docs, _ = store.execute_query(query) + docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] @@ -92,7 +92,7 @@ class MyDoc(BaseDoc): expected_attendees: dict = Field(col_type='integer_range') time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() doc = [ MyDoc( expected_attendees={'gte': 10, 'lt': 20}, @@ -107,7 +107,7 @@ class MyDoc(BaseDoc): time_frame={'gte': '2023-03-01', 'lt': '2023-04-01'}, ), ] - store.index(doc) + index.index(doc) query = { 'query': { @@ -127,5 +127,5 @@ class MyDoc(BaseDoc): } }, } - docs, _ = store.execute_query(query) + docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] diff --git a/tests/index/elastic/v8/test_find.py b/tests/index/elastic/v8/test_find.py index 5ee0956bb87..bb87755254c 100644 --- a/tests/index/elastic/v8/test_find.py +++ b/tests/index/elastic/v8/test_find.py @@ -17,7 +17,7 @@ def test_find_simple_schema(similarity): class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(similarity=similarity) - store = ElasticDocIndex[SimpleSchema]() + index = ElasticDocIndex[SimpleSchema]() index_docs = [] for _ in range(10): @@ -25,10 +25,10 @@ class SimpleSchema(BaseDoc): if similarity == 'dot_product': vec = vec / np.linalg.norm(vec) index_docs.append(SimpleDoc(tens=vec)) - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -42,7 +42,7 @@ class FlatSchema(BaseDoc): tens_one: NdArray = Field(dims=10, similarity=similarity) tens_two: NdArray = Field(dims=50, similarity=similarity) - store = ElasticDocIndex[FlatSchema]() + index = ElasticDocIndex[FlatSchema]() index_docs = [] for _ in range(10): @@ -53,12 +53,12 @@ class FlatSchema(BaseDoc): vec_two = vec_two / np.linalg.norm(vec_two) index_docs.append(FlatDoc(tens_one=vec_one, tens_two=vec_two)) - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] # find on tens_one - docs, scores = store.find(query, search_field='tens_one', limit=5) + docs, scores = index.find(query, search_field='tens_one', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -66,7 +66,7 @@ class FlatSchema(BaseDoc): assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) # find on tens_two - docs, scores = store.find(query, search_field='tens_two', limit=5) + docs, scores = index.find(query, search_field='tens_two', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -87,7 +87,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(similarity=similarity, dims=10) - store = ElasticDocIndex[DeepNestedDoc]() + index = ElasticDocIndex[DeepNestedDoc]() index_docs = [] for _ in range(10): @@ -105,26 +105,26 @@ class DeepNestedDoc(BaseDoc): ) ) - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] # find on root level - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id assert np.allclose(docs[0].tens, index_docs[-1].tens) # find on first nesting level - docs, scores = store.find(query, search_field='d__tens', limit=5) + docs, scores = index.find(query, search_field='d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id assert np.allclose(docs[0].d.tens, index_docs[-1].d.tens) # find on second nesting level - docs, scores = store.find(query, search_field='d__d__tens', limit=5) + docs, scores = index.find(query, search_field='d__d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -135,19 +135,19 @@ def test_find_torch(): class TorchDoc(BaseDoc): tens: TorchTensor[10] - store = ElasticDocIndex[TorchDoc]() + index = ElasticDocIndex[TorchDoc]() # A dense_vector field stores dense vectors of float values. index_docs = [ TorchDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TorchTensor) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -164,18 +164,18 @@ def test_find_tensorflow(): class TfDoc(BaseDoc): tens: TensorFlowTensor[10] - store = ElasticDocIndex[TfDoc]() + index = ElasticDocIndex[TfDoc]() index_docs = [ TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TensorFlowTensor) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -189,13 +189,13 @@ class TfDoc(BaseDoc): def test_find_batched(): - store = ElasticDocIndex[SimpleDoc]() + index = ElasticDocIndex[SimpleDoc]() index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] - store.index(index_docs) + index.index(index_docs) queries = index_docs[-2:] - docs_batched, scores_batched = store.find_batched( + docs_batched, scores_batched = index.find_batched( queries, search_field='tens', limit=5 ) @@ -212,13 +212,13 @@ class MyDoc(BaseDoc): B: int C: float - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() index_docs = [MyDoc(id=f'{i}', A=(i % 2 == 0), B=i, C=i + 0.5) for i in range(10)] - store.index(index_docs) + index.index(index_docs) filter_query = {'term': {'A': True}} - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert len(docs) > 0 for doc in docs: assert doc.A @@ -231,7 +231,7 @@ class MyDoc(BaseDoc): ] } } - docs = store.filter(filter_query) + docs = index.filter(filter_query) assert [doc.id for doc in docs] == ['3', '4'] @@ -239,16 +239,16 @@ def test_text_search(): class MyDoc(BaseDoc): text: str - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() index_docs = [ MyDoc(text='hello world'), MyDoc(text='never gonna give you up'), MyDoc(text='we are the world'), ] - store.index(index_docs) + index.index(index_docs) query = 'world' - docs, scores = store.text_search(query, search_field='text') + docs, scores = index.text_search(query, search_field='text') assert len(docs) == 2 assert len(scores) == 2 @@ -256,7 +256,7 @@ class MyDoc(BaseDoc): assert docs[1].text.index(query) >= 0 queries = ['world', 'never'] - docs, scores = store.text_search_batched(queries, search_field='text') + docs, scores = index.text_search_batched(queries, search_field='text') for query, da, score in zip(queries, docs, scores): assert len(da) > 0 assert len(score) > 0 @@ -270,41 +270,41 @@ class MyDoc(BaseDoc): num: int text: str - store = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc]() index_docs = [ MyDoc(id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i/2)}') for i in range(10) ] - store.index(index_docs) + index.index(index_docs) # build_query - q = store.build_query() - assert isinstance(q, store.QueryBuilder) + q = index.build_query() + assert isinstance(q, index.QueryBuilder) # filter - q = store.build_query().filter({'term': {'num': 0}}).build() - docs, _ = store.execute_query(q) + q = index.build_query().filter({'term': {'num': 0}}).build() + docs, _ = index.execute_query(q) assert [doc['id'] for doc in docs] == ['0', '1'] # find - q = store.build_query().find(index_docs[-1], search_field='tens', limit=3).build() - docs, _ = store.execute_query(q) + q = index.build_query().find(index_docs[-1], search_field='tens', limit=3).build() + docs, _ = index.execute_query(q) assert [doc['id'] for doc in docs] == ['9', '8', '7'] # text_search - q = store.build_query().text_search('0', search_field='text').build() - docs, _ = store.execute_query(q) + q = index.build_query().text_search('0', search_field='text').build() + docs, _ = index.execute_query(q) assert [doc['id'] for doc in docs] == ['0', '1'] # combination q = ( - store.build_query() + index.build_query() .filter({'range': {'num': {'lte': 3}}}) .find(index_docs[-1], search_field='tens') .text_search('0', search_field='text') .build() ) - docs, _ = store.execute_query(q) + docs, _ = index.execute_query(q) assert [doc['id'] for doc in docs] == ['1', '0'] # direct @@ -325,5 +325,5 @@ class MyDoc(BaseDoc): }, } - docs, _ = store.execute_query(query) + docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == ['7', '6', '5', '4'] diff --git a/tests/index/elastic/v8/test_index_get_del.py b/tests/index/elastic/v8/test_index_get_del.py index 03560caae7d..8efd66429b0 100644 --- a/tests/index/elastic/v8/test_index_get_del.py +++ b/tests/index/elastic/v8/test_index_get_del.py @@ -25,114 +25,114 @@ @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 - store = ElasticDocIndex[SimpleDoc]() + index = ElasticDocIndex[SimpleDoc]() if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) - store.index(ten_simple_docs) - assert store.num_docs() == 10 + index.index(ten_simple_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 - store = ElasticDocIndex[FlatDoc]() + index = ElasticDocIndex[FlatDoc]() if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) - store.index(ten_flat_docs) - assert store.num_docs() == 10 + index.index(ten_flat_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 - store = ElasticDocIndex[NestedDoc]() + index = ElasticDocIndex[NestedDoc]() if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) - store.index(ten_nested_docs) - assert store.num_docs() == 10 + index.index(ten_nested_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 - store = ElasticDocIndex[DeepNestedDoc]() + index = ElasticDocIndex[DeepNestedDoc]() if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) - store.index(ten_deep_nested_docs) - assert store.num_docs() == 10 + index.index(ten_deep_nested_docs) + assert index.num_docs() == 10 def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 # simple - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_simple_docs: id_ = d.id - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) # flat - store = ElasticDocIndex[FlatDoc]() - store.index(ten_flat_docs) + index = ElasticDocIndex[FlatDoc]() + index.index(ten_flat_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_flat_docs: id_ = d.id - assert store[id_].id == id_ - assert np.all(store[id_].tens_one == d.tens_one) - assert np.all(store[id_].tens_two == d.tens_two) + assert index[id_].id == id_ + assert np.all(index[id_].tens_one == d.tens_one) + assert np.all(index[id_].tens_two == d.tens_two) # nested - store = ElasticDocIndex[NestedDoc]() - store.index(ten_nested_docs) + index = ElasticDocIndex[NestedDoc]() + index.index(ten_nested_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_nested_docs: id_ = d.id - assert store[id_].id == id_ - assert store[id_].d.id == d.d.id - assert np.all(store[id_].d.tens == d.d.tens) + assert index[id_].id == id_ + assert index[id_].d.id == d.d.id + assert np.all(index[id_].d.tens == d.d.tens) def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 docs_to_get_idx = [0, 2, 4, 6, 8] # simple - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_simple_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert np.all(d_out.tens == d_in.tens) # flat - store = ElasticDocIndex[FlatDoc]() - store.index(ten_flat_docs) + index = ElasticDocIndex[FlatDoc]() + index.index(ten_flat_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_flat_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert np.all(d_out.tens_one == d_in.tens_one) assert np.all(d_out.tens_two == d_in.tens_two) # nested - store = ElasticDocIndex[NestedDoc]() - store.index(ten_nested_docs) + index = ElasticDocIndex[NestedDoc]() + index.index(ten_nested_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_nested_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert d_out.d.id == d_in.d.id @@ -140,94 +140,94 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: def test_get_key_error(ten_simple_docs): # noqa: F811 - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) with pytest.raises(KeyError): - store['not_a_real_id'] + index['not_a_real_id'] def test_persisting(ten_simple_docs): # noqa: F811 - store = ElasticDocIndex[SimpleDoc](index_name='test_persisting') - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc](index_name='test_persisting') + index.index(ten_simple_docs) - store2 = ElasticDocIndex[SimpleDoc](index_name='test_persisting') - assert store2.num_docs() == 10 + index2 = ElasticDocIndex[SimpleDoc](index_name='test_persisting') + assert index2.num_docs() == 10 def test_del_single(ten_simple_docs): # noqa: F811 - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) # delete once - assert store.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + assert index.num_docs() == 10 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 for i, d in enumerate(ten_simple_docs): id_ = d.id if i == 0: # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) # delete again - del store[ten_simple_docs[3].id] - assert store.num_docs() == 8 + del index[ten_simple_docs[3].id] + assert index.num_docs() == 8 for i, d in enumerate(ten_simple_docs): id_ = d.id if i in (0, 3): # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) def test_del_multiple(ten_simple_docs): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] ids_to_del = [d.id for d in docs_to_del] - del store[ids_to_del] + del index[ids_to_del] for i, doc in enumerate(ten_simple_docs): if i in docs_to_del_idx: with pytest.raises(KeyError): - store[doc.id] + index[doc.id] else: - assert store[doc.id].id == doc.id - assert np.all(store[doc.id].tens == doc.tens) + assert index[doc.id].id == doc.id + assert np.all(index[doc.id].tens == doc.tens) def test_del_key_error(ten_simple_docs): # noqa: F811 - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) with pytest.warns(UserWarning): - del store['not_a_real_id'] + del index['not_a_real_id'] def test_num_docs(ten_simple_docs): # noqa: F811 - store = ElasticDocIndex[SimpleDoc]() - store.index(ten_simple_docs) + index = ElasticDocIndex[SimpleDoc]() + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 - del store[ten_simple_docs[3].id, ten_simple_docs[5].id] - assert store.num_docs() == 7 + del index[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert index.num_docs() == 7 more_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] - store.index(more_docs) - assert store.num_docs() == 12 + index.index(more_docs) + assert index.num_docs() == 12 - del store[more_docs[2].id, ten_simple_docs[7].id] - assert store.num_docs() == 10 + del index[more_docs[2].id, ten_simple_docs[7].id] + assert index.num_docs() == 10 def test_index_union_doc(): # noqa: F811 @@ -237,13 +237,13 @@ class MyDoc(BaseDoc): class MySchema(BaseDoc): tensor: NdArray[128] - store = ElasticDocIndex[MySchema]() + index = ElasticDocIndex[MySchema]() doc = [MyDoc(tensor=np.random.randn(128))] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert np.all(store[id_].tensor == doc[0].tensor) + assert index[id_].id == id_ + assert np.all(index[id_].tensor == doc[0].tensor) def test_index_multi_modal_doc(): @@ -251,22 +251,22 @@ class MyMultiModalDoc(BaseDoc): image: MyImageDoc text: TextDoc - store = ElasticDocIndex[MyMultiModalDoc]() + index = ElasticDocIndex[MyMultiModalDoc]() doc = [ MyMultiModalDoc( image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') ) ] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert np.all(store[id_].image.embedding == doc[0].image.embedding) - assert store[id_].text.text == doc[0].text.text + assert index[id_].id == id_ + assert np.all(index[id_].image.embedding == doc[0].image.embedding) + assert index[id_].text.text == doc[0].text.text query = doc[0] - docs, _ = store.find(query, limit=10, search_field='image__embedding') + docs, _ = index.find(query, limit=10, search_field='image__embedding') assert len(docs) > 0 diff --git a/tests/index/hnswlib/test_find.py b/tests/index/hnswlib/test_find.py index 0aca0383a94..bfaf5e7c1e6 100644 --- a/tests/index/hnswlib/test_find.py +++ b/tests/index/hnswlib/test_find.py @@ -36,15 +36,15 @@ def test_find_simple_schema(tmp_path, space): class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) - store = HnswDocumentIndex[SimpleSchema](work_dir=str(tmp_path)) + index = HnswDocumentIndex[SimpleSchema](work_dir=str(tmp_path)) index_docs = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(SimpleDoc(tens=np.ones(10))) - store.index(index_docs) + index.index(index_docs) query = SimpleDoc(tens=np.ones(10)) - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -56,18 +56,18 @@ class SimpleSchema(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_torch(tmp_path, space): - store = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path)) index_docs = [TorchDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(TorchDoc(tens=np.ones(10))) - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TorchTensor) query = TorchDoc(tens=np.ones(10)) - result_docs, scores = store.find(query, search_field='tens', limit=5) + result_docs, scores = index.find(query, search_field='tens', limit=5) assert len(result_docs) == 5 assert len(scores) == 5 @@ -86,18 +86,18 @@ def test_find_tensorflow(tmp_path): class TfDoc(BaseDoc): tens: TensorFlowTensor[10] - store = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path)) index_docs = [TfDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(TfDoc(tens=np.ones(10))) - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TensorFlowTensor) query = TfDoc(tens=np.ones(10)) - result_docs, scores = store.find(query, search_field='tens', limit=5) + result_docs, scores = index.find(query, search_field='tens', limit=5) assert len(result_docs) == 5 assert len(scores) == 5 @@ -117,19 +117,19 @@ class FlatSchema(BaseDoc): tens_one: NdArray = Field(dim=10, space=space) tens_two: NdArray = Field(dim=50, space=space) - store = HnswDocumentIndex[FlatSchema](work_dir=str(tmp_path)) + index = HnswDocumentIndex[FlatSchema](work_dir=str(tmp_path)) index_docs = [ FlatDoc(tens_one=np.zeros(10), tens_two=np.zeros(50)) for _ in range(10) ] index_docs.append(FlatDoc(tens_one=np.zeros(10), tens_two=np.ones(50))) index_docs.append(FlatDoc(tens_one=np.ones(10), tens_two=np.zeros(50))) - store.index(index_docs) + index.index(index_docs) query = FlatDoc(tens_one=np.ones(10), tens_two=np.ones(50)) # find on tens_one - docs, scores = store.find(query, search_field='tens_one', limit=5) + docs, scores = index.find(query, search_field='tens_one', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id @@ -137,7 +137,7 @@ class FlatSchema(BaseDoc): assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) # find on tens_two - docs, scores = store.find(query, search_field='tens_two', limit=5) + docs, scores = index.find(query, search_field='tens_two', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-2].id @@ -158,7 +158,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(space=space, dim=10) - store = HnswDocumentIndex[DeepNestedDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[DeepNestedDoc](work_dir=str(tmp_path)) index_docs = [ DeepNestedDoc( @@ -185,28 +185,28 @@ class DeepNestedDoc(BaseDoc): tens=np.ones(10), ) ) - store.index(index_docs) + index.index(index_docs) query = DeepNestedDoc( d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.ones(10)), tens=np.ones(10) ) # find on root level - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-1].id assert np.allclose(docs[0].tens, index_docs[-1].tens) # find on first nesting level - docs, scores = store.find(query, search_field='d__tens', limit=5) + docs, scores = index.find(query, search_field='d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-2].id assert np.allclose(docs[0].d.tens, index_docs[-2].d.tens) # find on second nesting level - docs, scores = store.find(query, search_field='d__d__tens', limit=5) + docs, scores = index.find(query, search_field='d__d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 assert docs[0].id == index_docs[-3].id diff --git a/tests/index/hnswlib/test_index_get_del.py b/tests/index/hnswlib/test_index_get_del.py index d9437878698..77eca0efd86 100644 --- a/tests/index/hnswlib/test_index_get_del.py +++ b/tests/index/hnswlib/test_index_get_del.py @@ -55,13 +55,13 @@ def ten_nested_docs(): @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_simple_schema(ten_simple_docs, tmp_path, use_docarray): - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) - store.index(ten_simple_docs) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + index.index(ten_simple_docs) + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 @@ -69,31 +69,31 @@ def test_schema_with_user_defined_mapping(tmp_path): class MyDoc(BaseDoc): tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray) - store = HnswDocumentIndex[MyDoc](work_dir=str(tmp_path)) - assert store._column_infos['tens'].db_type == np.ndarray + index = HnswDocumentIndex[MyDoc](work_dir=str(tmp_path)) + assert index._column_infos['tens'].db_type == np.ndarray @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_flat_schema(ten_flat_docs, tmp_path, use_docarray): - store = HnswDocumentIndex[FlatDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[FlatDoc](work_dir=str(tmp_path)) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) - store.index(ten_flat_docs) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + index.index(ten_flat_docs) + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_nested_schema(ten_nested_docs, tmp_path, use_docarray): - store = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) - store.index(ten_nested_docs) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + index.index(ten_nested_docs) + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 @@ -102,11 +102,11 @@ def test_index_torch(tmp_path): assert isinstance(docs[0].tens, torch.Tensor) assert isinstance(docs[0].tens, TorchTensor) - store = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path)) - store.index(docs) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + index.index(docs) + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 @@ -121,11 +121,11 @@ class TfDoc(BaseDoc): # assert isinstance(docs[0].tens, torch.Tensor) assert isinstance(docs[0].tens, TensorFlowTensor) - store = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path)) + index = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path)) - store.index(docs) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + index.index(docs) + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 @@ -134,26 +134,26 @@ def test_index_builtin_docs(tmp_path): class TextSchema(TextDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - store = HnswDocumentIndex[TextSchema](work_dir=str(tmp_path)) + index = HnswDocumentIndex[TextSchema](work_dir=str(tmp_path)) - store.index( + index.index( DocList[TextDoc]( [TextDoc(embedding=np.random.randn(10), text=f'{i}') for i in range(10)] ) ) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 # ImageDoc class ImageSchema(ImageDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - store = HnswDocumentIndex[ImageSchema]( + index = HnswDocumentIndex[ImageSchema]( work_dir=str(os.path.join(tmp_path, 'image')) ) - store.index( + index.index( DocList[ImageDoc]( [ ImageDoc( @@ -163,8 +163,8 @@ class ImageSchema(ImageDoc): ] ) ) - assert store.num_docs() == 10 - for index in store._hnsw_indices.values(): + assert index.num_docs() == 10 + for index in index._hnsw_indices.values(): assert index.get_current_count() == 10 @@ -174,36 +174,36 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): nested_path = tmp_path / 'nested' # simple - store = HnswDocumentIndex[SimpleDoc](work_dir=str(simple_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(simple_path)) + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_simple_docs: id_ = d.id - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) # flat - store = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path)) - store.index(ten_flat_docs) + index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path)) + index.index(ten_flat_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_flat_docs: id_ = d.id - assert store[id_].id == id_ - assert np.all(store[id_].tens_one == d.tens_one) - assert np.all(store[id_].tens_two == d.tens_two) + assert index[id_].id == id_ + assert np.all(index[id_].tens_one == d.tens_one) + assert np.all(index[id_].tens_two == d.tens_two) # nested - store = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path)) - store.index(ten_nested_docs) + index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path)) + index.index(ten_nested_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 for d in ten_nested_docs: id_ = d.id - assert store[id_].id == id_ - assert store[id_].d.id == d.d.id - assert np.all(store[id_].d.tens == d.d.tens) + assert index[id_].id == id_ + assert index[id_].d.id == d.d.id + assert np.all(index[id_].d.tens == d.d.tens) def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): @@ -213,38 +213,38 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path) docs_to_get_idx = [0, 2, 4, 6, 8] # simple - store = HnswDocumentIndex[SimpleDoc](work_dir=str(simple_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(simple_path)) + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_simple_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert np.all(d_out.tens == d_in.tens) # flat - store = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path)) - store.index(ten_flat_docs) + index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path)) + index.index(ten_flat_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_flat_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert np.all(d_out.tens_one == d_in.tens_one) assert np.all(d_out.tens_two == d_in.tens_two) # nested - store = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path)) - store.index(ten_nested_docs) + index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path)) + index.index(ten_nested_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_get = [ten_nested_docs[i] for i in docs_to_get_idx] ids_to_get = [d.id for d in docs_to_get] - retrieved_docs = store[ids_to_get] + retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert d_out.d.id == d_in.d.id @@ -252,83 +252,83 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path) def test_get_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + index.index(ten_simple_docs) with pytest.raises(KeyError): - store['not_a_real_id'] + index['not_a_real_id'] def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + index.index(ten_simple_docs) # delete once - assert store.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + assert index.num_docs() == 10 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 for i, d in enumerate(ten_simple_docs): id_ = d.id if i == 0: # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) # delete again - del store[ten_simple_docs[3].id] - assert store.num_docs() == 8 + del index[ten_simple_docs[3].id] + assert index.num_docs() == 8 for i, d in enumerate(ten_simple_docs): id_ = d.id if i in (0, 3): # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ - assert np.all(store[id_].tens == d.tens) + assert index[id_].id == id_ + assert np.all(index[id_].tens == d.tens) def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): docs_to_del_idx = [0, 2, 4, 6, 8] - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] ids_to_del = [d.id for d in docs_to_del] - del store[ids_to_del] + del index[ids_to_del] for i, doc in enumerate(ten_simple_docs): if i in docs_to_del_idx: with pytest.raises(KeyError): - store[doc.id] + index[doc.id] else: - assert store[doc.id].id == doc.id - assert np.all(store[doc.id].tens == doc.tens) + assert index[doc.id].id == doc.id + assert np.all(index[doc.id].tens == doc.tens) def test_del_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + index.index(ten_simple_docs) with pytest.raises(KeyError): - del store['not_a_real_id'] + del index['not_a_real_id'] def test_num_docs(ten_simple_docs, tmp_path): - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - store.index(ten_simple_docs) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 - del store[ten_simple_docs[3].id, ten_simple_docs[5].id] - assert store.num_docs() == 7 + del index[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert index.num_docs() == 7 more_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] - store.index(more_docs) - assert store.num_docs() == 12 + index.index(more_docs) + assert index.num_docs() == 12 - del store[more_docs[2].id, ten_simple_docs[7].id] - assert store.num_docs() == 10 + del index[more_docs[2].id, ten_simple_docs[7].id] + assert index.num_docs() == 10 diff --git a/tests/index/hnswlib/test_persist_data.py b/tests/index/hnswlib/test_persist_data.py index 7724d5408c5..fab761582c1 100644 --- a/tests/index/hnswlib/test_persist_data.py +++ b/tests/index/hnswlib/test_persist_data.py @@ -22,23 +22,27 @@ def test_persist_and_restore(tmp_path): query = SimpleDoc(tens=np.random.random((10,))) # create index - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - store.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) - assert store.num_docs() == 10 - find_results_before = store.find(query, search_field='tens', limit=5) + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + + # load existing index file + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + assert index.num_docs() == 0 + index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) + assert index.num_docs() == 10 + find_results_before = index.find(query, search_field='tens', limit=5) # delete and restore - del store - store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - assert store.num_docs() == 10 - find_results_after = store.find(query, search_field='tens', limit=5) + del index + index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + assert index.num_docs() == 10 + find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): assert doc_before.id == doc_after.id assert (doc_before.tens == doc_after.tens).all() # add new data - store.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)]) - assert store.num_docs() == 15 + index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)]) + assert index.num_docs() == 15 def test_persist_and_restore_nested(tmp_path): @@ -47,8 +51,8 @@ def test_persist_and_restore_nested(tmp_path): ) # create index - store = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) - store.index( + index = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) + index.index( [ NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) @@ -56,20 +60,20 @@ def test_persist_and_restore_nested(tmp_path): for _ in range(10) ] ) - assert store.num_docs() == 10 - find_results_before = store.find(query, search_field='d__tens', limit=5) + assert index.num_docs() == 10 + find_results_before = index.find(query, search_field='d__tens', limit=5) # delete and restore - del store - store = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) - assert store.num_docs() == 10 - find_results_after = store.find(query, search_field='d__tens', limit=5) + del index + index = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) + assert index.num_docs() == 10 + find_results_after = index.find(query, search_field='d__tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): assert doc_before.id == doc_after.id assert (doc_before.tens == doc_after.tens).all() # delete and restore - store.index( + index.index( [ NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) @@ -77,9 +81,4 @@ def test_persist_and_restore_nested(tmp_path): for _ in range(5) ] ) - assert store.num_docs() == 15 - - -def test_persist_index_file(tmp_path): - _ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) - _ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + assert index.num_docs() == 15 diff --git a/tests/index/qdrant/test_filter.py b/tests/index/qdrant/test_filter.py index a4278ef62d9..cff4ca367ac 100644 --- a/tests/index/qdrant/test_filter.py +++ b/tests/index/qdrant/test_filter.py @@ -1,16 +1,11 @@ -import pytest -import qdrant_client import numpy as np - from pydantic import Field +from qdrant_client.http import models as rest from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray - -from qdrant_client.http import models as rest - -from .fixtures import qdrant_config, qdrant # ignore: type[import] +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 class SimpleDoc(BaseDoc): @@ -18,12 +13,12 @@ class SimpleDoc(BaseDoc): number: int -def test_filter_range(qdrant_config, qdrant): +def test_filter_range(qdrant_config): # noqa: F811 class SimpleSchema(BaseDoc): embedding: NdArray[10] = Field(space='cosine') # type: ignore[valid-type] number: int - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) index_docs = [ SimpleDoc( @@ -32,11 +27,11 @@ class SimpleSchema(BaseDoc): ) for i in range(10) ] - store.index(index_docs) + index.index(index_docs) filter_query = rest.Filter( must=[rest.FieldCondition(key='number', range=rest.Range(gte=5, lte=7))] ) - docs = store.filter(filter_query, limit=5) + docs = index.filter(filter_query, limit=5) assert len(docs) == 3 diff --git a/tests/index/qdrant/test_find.py b/tests/index/qdrant/test_find.py index c508b0de9fe..610695e5c81 100644 --- a/tests/index/qdrant/test_find.py +++ b/tests/index/qdrant/test_find.py @@ -5,8 +5,7 @@ from docarray import BaseDoc, DocList from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray, TorchTensor - -from .fixtures import qdrant_config, qdrant +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -33,38 +32,38 @@ class TorchDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_simple_schema(qdrant_config, space, qdrant): +def test_find_simple_schema(qdrant_config, space): # noqa: F811 class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) index_docs = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(SimpleDoc(tens=np.ones(10))) - store.index(index_docs) + index.index(index_docs) query = SimpleDoc(tens=np.ones(10)) - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_torch(qdrant_config, space, qdrant): - store = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) +def test_find_torch(qdrant_config, space): # noqa: F811 + index = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) index_docs = [TorchDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(TorchDoc(tens=np.ones(10))) - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TorchTensor) query = TorchDoc(tens=np.ones(10)) - result_docs, scores = store.find(query, search_field='tens', limit=5) + result_docs, scores = index.find(query, search_field='tens', limit=5) assert len(result_docs) == 5 assert len(scores) == 5 @@ -74,24 +73,24 @@ def test_find_torch(qdrant_config, space, qdrant): @pytest.mark.tensorflow @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_tensorflow(qdrant_config, space, qdrant): +def test_find_tensorflow(qdrant_config, space): # noqa: F811 from docarray.typing import TensorFlowTensor class TfDoc(BaseDoc): tens: TensorFlowTensor[10] # type: ignore[valid-type] - store = QdrantDocumentIndex[TfDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[TfDoc](db_config=qdrant_config) index_docs = [ TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TensorFlowTensor) query = index_docs[-1] - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -102,35 +101,35 @@ class TfDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_flat_schema(qdrant_config, space, qdrant): +def test_find_flat_schema(qdrant_config, space): # noqa: F811 class FlatSchema(BaseDoc): tens_one: NdArray = Field(dim=10, space=space) tens_two: NdArray = Field(dim=50, space=space) - store = QdrantDocumentIndex[FlatSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[FlatSchema](db_config=qdrant_config) index_docs = [ FlatDoc(tens_one=np.zeros(10), tens_two=np.zeros(50)) for _ in range(10) ] index_docs.append(FlatDoc(tens_one=np.zeros(10), tens_two=np.ones(50))) index_docs.append(FlatDoc(tens_one=np.ones(10), tens_two=np.zeros(50))) - store.index(index_docs) + index.index(index_docs) query = FlatDoc(tens_one=np.ones(10), tens_two=np.ones(50)) # find on tens_one - docs, scores = store.find(query, search_field='tens_one', limit=5) + docs, scores = index.find(query, search_field='tens_one', limit=5) assert len(docs) == 5 assert len(scores) == 5 # find on tens_two - docs, scores = store.find(query, search_field='tens_two', limit=5) + docs, scores = index.find(query, search_field='tens_two', limit=5) assert len(docs) == 5 assert len(scores) == 5 @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_nested_schema(qdrant_config, space, qdrant): +def test_find_nested_schema(qdrant_config, space): # noqa: F811 class SimpleDoc(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] @@ -142,7 +141,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(space=space, dim=10) - store = QdrantDocumentIndex[DeepNestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[DeepNestedDoc](db_config=qdrant_config) index_docs = [ DeepNestedDoc( @@ -169,37 +168,37 @@ class DeepNestedDoc(BaseDoc): tens=np.ones(10), ) ) - store.index(index_docs) + index.index(index_docs) query = DeepNestedDoc( d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.ones(10)), tens=np.ones(10) ) # find on root level - docs, scores = store.find(query, search_field='tens', limit=5) + docs, scores = index.find(query, search_field='tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 # find on first nesting level - docs, scores = store.find(query, search_field='d__tens', limit=5) + docs, scores = index.find(query, search_field='d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 # find on second nesting level - docs, scores = store.find(query, search_field='d__d__tens', limit=5) + docs, scores = index.find(query, search_field='d__d__tens', limit=5) assert len(docs) == 5 assert len(scores) == 5 @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_batched(qdrant_config, space, qdrant): +def test_find_batched(qdrant_config, space): # noqa: F811 class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) index_docs = [SimpleDoc(tens=vector) for vector in np.identity(10)] - store.index(index_docs) + index.index(index_docs) queries = DocList[SimpleDoc]( [ @@ -212,7 +211,7 @@ class SimpleSchema(BaseDoc): ] ) - docs, scores = store.find_batched(queries, search_field='tens', limit=1) + docs, scores = index.find_batched(queries, search_field='tens', limit=1) assert len(docs) == 2 assert len(docs[0]) == 1 diff --git a/tests/index/qdrant/test_index_get_del.py b/tests/index/qdrant/test_index_get_del.py index 3ddd060817f..a1db816e58c 100644 --- a/tests/index/qdrant/test_index_get_del.py +++ b/tests/index/qdrant/test_index_get_del.py @@ -10,8 +10,7 @@ from docarray.documents import ImageDoc, TextDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray, NdArrayEmbedding, TorchTensor - -from .fixtures import qdrant_config, qdrant # ignore: type[import] +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -56,68 +55,72 @@ def ten_nested_docs(): @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_simple_schema(ten_simple_docs, qdrant_config, use_docarray, qdrant): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_index_simple_schema( + ten_simple_docs, qdrant_config, use_docarray # noqa: F811 +): + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) - store.index(ten_simple_docs) - assert store.num_docs() == 10 + index.index(ten_simple_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, qdrant_config, use_docarray, qdrant): - store = QdrantDocumentIndex[FlatDoc](db_config=qdrant_config) +def test_index_flat_schema(ten_flat_docs, qdrant_config, use_docarray): # noqa: F811 + index = QdrantDocumentIndex[FlatDoc](db_config=qdrant_config) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) - store.index(ten_flat_docs) - assert store.num_docs() == 10 + index.index(ten_flat_docs) + assert index.num_docs() == 10 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_nested_schema(ten_nested_docs, qdrant_config, use_docarray, qdrant): - store = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) +def test_index_nested_schema( + ten_nested_docs, qdrant_config, use_docarray # noqa: F811 +): + index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) - store.index(ten_nested_docs) - assert store.num_docs() == 10 + index.index(ten_nested_docs) + assert index.num_docs() == 10 -def test_index_torch(qdrant_config, qdrant): +def test_index_torch(qdrant_config): # noqa: F811 docs = [TorchDoc(tens=np.random.randn(10)) for _ in range(10)] assert isinstance(docs[0].tens, torch.Tensor) assert isinstance(docs[0].tens, TorchTensor) - store = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) - store.index(docs) - assert store.num_docs() == 10 + index.index(docs) + assert index.num_docs() == 10 @pytest.mark.skip('Qdrant does not support storing image tensors yet') -def test_index_builtin_docs(qdrant_config, qdrant): +def test_index_builtin_docs(qdrant_config): # noqa: F811 # TextDoc class TextSchema(TextDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - store = QdrantDocumentIndex[TextSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[TextSchema](db_config=qdrant_config) - store.index( + index.index( DocList[TextDoc]( [TextDoc(embedding=np.random.randn(10), text=f'{i}') for i in range(10)] ) ) - assert store.num_docs() == 10 + assert index.num_docs() == 10 # ImageDoc class ImageSchema(ImageDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - store = QdrantDocumentIndex[ImageSchema](collection_name='images') # type: ignore[assignment] + index = QdrantDocumentIndex[ImageSchema](collection_name='images') # type: ignore[assignment] - store.index( + index.index( DocList[ImageDoc]( [ ImageDoc( @@ -127,114 +130,106 @@ class ImageSchema(ImageDoc): ] ) ) - assert store.num_docs() == 10 + assert index.num_docs() == 10 -def test_get_key_error( - ten_simple_docs, ten_flat_docs, ten_nested_docs, qdrant_config, qdrant -): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(ten_simple_docs) +def test_get_key_error(ten_simple_docs, qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(ten_simple_docs) with pytest.raises(KeyError): - store['not_a_real_id'] + index['not_a_real_id'] -def test_del_single( - ten_simple_docs, ten_flat_docs, ten_nested_docs, qdrant_config, qdrant -): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(ten_simple_docs) +def test_del_single(ten_simple_docs, qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(ten_simple_docs) # delete once - assert store.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + assert index.num_docs() == 10 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 for i, d in enumerate(ten_simple_docs): id_ = d.id if i == 0: # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ + assert index[id_].id == id_ # delete again - del store[ten_simple_docs[3].id] - assert store.num_docs() == 8 + del index[ten_simple_docs[3].id] + assert index.num_docs() == 8 for i, d in enumerate(ten_simple_docs): id_ = d.id if i in (0, 3): # deleted with pytest.raises(KeyError): - store[id_] + index[id_] else: - assert store[id_].id == id_ + assert index[id_].id == id_ -def test_del_multiple( - ten_simple_docs, ten_flat_docs, ten_nested_docs, qdrant_config, qdrant -): +def test_del_multiple(ten_simple_docs, qdrant_config): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(ten_simple_docs) + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] ids_to_del = [d.id for d in docs_to_del] - del store[ids_to_del] + del index[ids_to_del] for i, doc in enumerate(ten_simple_docs): if i in docs_to_del_idx: with pytest.raises(KeyError): - store[doc.id] + index[doc.id] else: - assert store[doc.id].id == doc.id + assert index[doc.id].id == doc.id -def test_del_key_error( - ten_simple_docs, ten_flat_docs, ten_nested_docs, qdrant_config, qdrant -): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(ten_simple_docs) +def test_del_key_error(ten_simple_docs, qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(ten_simple_docs) with pytest.raises(KeyError): - del store['not_a_real_id'] + del index['not_a_real_id'] -def test_num_docs(ten_simple_docs, qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(ten_simple_docs) +def test_num_docs(ten_simple_docs, qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(ten_simple_docs) - assert store.num_docs() == 10 + assert index.num_docs() == 10 - del store[ten_simple_docs[0].id] - assert store.num_docs() == 9 + del index[ten_simple_docs[0].id] + assert index.num_docs() == 9 - del store[ten_simple_docs[3].id, ten_simple_docs[5].id] - assert store.num_docs() == 7 + del index[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert index.num_docs() == 7 more_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] - store.index(more_docs) - assert store.num_docs() == 12 + index.index(more_docs) + assert index.num_docs() == 12 - del store[more_docs[2].id, ten_simple_docs[7].id] # type: ignore[arg-type] - assert store.num_docs() == 10 + del index[more_docs[2].id, ten_simple_docs[7].id] # type: ignore[arg-type] + assert index.num_docs() == 10 -def test_multimodal_doc(qdrant_config, qdrant): +def test_multimodal_doc(qdrant_config): # noqa: F811 class MyMultiModalDoc(BaseDoc): image: ImageDoc text: TextDoc - store = QdrantDocumentIndex[MyMultiModalDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[MyMultiModalDoc](db_config=qdrant_config) doc = [ MyMultiModalDoc( image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') ) ] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ # type: ignore[index] - assert cosine(store[id_].image.embedding, doc[0].image.embedding) == pytest.approx( + assert index[id_].id == id_ # type: ignore[index] + assert cosine(index[id_].image.embedding, doc[0].image.embedding) == pytest.approx( 0.0 ) - assert store[id_].text.text == doc[0].text.text + assert index[id_].text.text == doc[0].text.text diff --git a/tests/index/qdrant/test_persist_data.py b/tests/index/qdrant/test_persist_data.py index 0255b78c120..9fd54715a30 100644 --- a/tests/index/qdrant/test_persist_data.py +++ b/tests/index/qdrant/test_persist_data.py @@ -5,8 +5,7 @@ from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray - -from .fixtures import qdrant_config, qdrant +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -20,37 +19,37 @@ class NestedDoc(BaseDoc): tens: NdArray[50] # type: ignore[valid-type] -def test_persist_and_restore(qdrant_config, qdrant): +def test_persist_and_restore(qdrant_config): # noqa: F811 query = SimpleDoc(tens=np.random.random((10,))) # create index - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) - assert store.num_docs() == 10 - find_results_before = store.find(query, search_field='tens', limit=5) + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) + assert index.num_docs() == 10 + find_results_before = index.find(query, search_field='tens', limit=5) # delete and restore - del store - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - assert store.num_docs() == 10 - find_results_after = store.find(query, search_field='tens', limit=5) + del index + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + assert index.num_docs() == 10 + find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): assert doc_before.id == doc_after.id assert doc_before.tens == pytest.approx(doc_after.tens) # add new data - store.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)]) - assert store.num_docs() == 15 + index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)]) + assert index.num_docs() == 15 -def test_persist_and_restore_nested(qdrant_config, qdrant): +def test_persist_and_restore_nested(qdrant_config): # noqa: F811 query = NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) ) # create index - store = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) - store.index( + index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index.index( [ NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) @@ -58,20 +57,20 @@ def test_persist_and_restore_nested(qdrant_config, qdrant): for _ in range(10) ] ) - assert store.num_docs() == 10 - find_results_before = store.find(query, search_field='d__tens', limit=5) + assert index.num_docs() == 10 + find_results_before = index.find(query, search_field='d__tens', limit=5) # delete and restore - del store - store = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) - assert store.num_docs() == 10 - find_results_after = store.find(query, search_field='d__tens', limit=5) + del index + index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + assert index.num_docs() == 10 + find_results_after = index.find(query, search_field='d__tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): assert doc_before.id == doc_after.id assert doc_before.tens == pytest.approx(doc_after.tens) # delete and restore - store.index( + index.index( [ NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) @@ -79,4 +78,4 @@ def test_persist_and_restore_nested(qdrant_config, qdrant): for _ in range(5) ] ) - assert store.num_docs() == 15 + assert index.num_docs() == 15 diff --git a/tests/index/qdrant/test_query_builder.py b/tests/index/qdrant/test_query_builder.py index ff3ae109879..fba5b5e9d09 100644 --- a/tests/index/qdrant/test_query_builder.py +++ b/tests/index/qdrant/test_query_builder.py @@ -1,15 +1,12 @@ -import pytest import numpy as np - +import pytest from pydantic import Field +from qdrant_client.http import models as rest from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray - -from qdrant_client.http import models as rest - -from .fixtures import qdrant_config, qdrant +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 class SimpleDoc(BaseDoc): @@ -24,10 +21,10 @@ class SimpleSchema(BaseDoc): text: str -def test_find_uses_provided_vector(qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) +def test_find_uses_provided_vector(qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) - query = store.build_query().find(np.ones(10), 'embedding').build(7) # type: ignore[attr-defined] + query = index.build_query().find(np.ones(10), 'embedding').build(7) # type: ignore[attr-defined] assert query.vector_field == 'embedding' assert np.allclose(query.vector_query, np.ones(10)) @@ -35,11 +32,11 @@ def test_find_uses_provided_vector(qdrant_config, qdrant): assert query.limit == 7 -def test_multiple_find_returns_averaged_vector(qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) +def test_multiple_find_returns_averaged_vector(qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) query = ( - store.build_query() # type: ignore[attr-defined] + index.build_query() # type: ignore[attr-defined] .find(np.ones(10), 'embedding') .find(np.zeros(10), 'embedding') .build(5) @@ -51,22 +48,22 @@ def test_multiple_find_returns_averaged_vector(qdrant_config, qdrant): assert query.limit == 5 -def test_multiple_find_different_field_raises_error(qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) +def test_multiple_find_different_field_raises_error(qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) with pytest.raises(ValueError): ( - store.build_query() # type: ignore[attr-defined] + index.build_query() # type: ignore[attr-defined] .find(np.ones(10), 'embedding_1') .find(np.zeros(10), 'embedding_2') ) -def test_filter_passes_qdrant_filter(qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) +def test_filter_passes_qdrant_filter(qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) qdrant_filter = rest.Filter(should=[rest.HasIdCondition(has_id=[1, 2, 3])]) - query = store.build_query().filter(qdrant_filter).build(11) # type: ignore[attr-defined] + query = index.build_query().filter(qdrant_filter).build(11) # type: ignore[attr-defined] assert query.vector_field is None assert query.vector_query is None @@ -74,10 +71,10 @@ def test_filter_passes_qdrant_filter(qdrant_config, qdrant): assert query.limit == 11 -def test_text_search_creates_qdrant_filter(qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) +def test_text_search_creates_qdrant_filter(qdrant_config): # noqa: F811 + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) - query = store.build_query().text_search('lorem ipsum', 'text').build(3) # type: ignore[attr-defined] + query = index.build_query().text_search('lorem ipsum', 'text').build(3) # type: ignore[attr-defined] assert query.vector_field is None assert query.vector_query is None @@ -89,8 +86,10 @@ def test_text_search_creates_qdrant_filter(qdrant_config, qdrant): assert query.limit == 3 -def test_query_builder_execute_query_find_text_search_filter(qdrant_config, qdrant): - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) +def test_query_builder_execute_query_find_text_search_filter( + qdrant_config, # noqa: F811 +): + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) index_docs = [ SimpleDoc( @@ -100,7 +99,7 @@ def test_query_builder_execute_query_find_text_search_filter(qdrant_config, qdra ) for i in range(10, 30, 2) ] - store.index(index_docs) + index.index(index_docs) find_query = np.ones(10) text_search_query = 'ipsum 1' @@ -116,13 +115,13 @@ def test_query_builder_execute_query_find_text_search_filter(qdrant_config, qdra ] ) query = ( - store.build_query() # type: ignore[attr-defined] + index.build_query() # type: ignore[attr-defined] .find(find_query, search_field='embedding') .text_search(text_search_query, search_field='text') .filter(filter_query) .build(limit=5) ) - docs = store.execute_query(query) + docs = index.execute_query(query) assert len(docs) == 3 assert all(x in docs.number for x in [12, 14, 16]) diff --git a/tests/index/qdrant/test_raw_query.py b/tests/index/qdrant/test_raw_query.py index 7bc8ffd7e0c..27e05573763 100644 --- a/tests/index/qdrant/test_raw_query.py +++ b/tests/index/qdrant/test_raw_query.py @@ -1,15 +1,13 @@ -import numpy as np - from typing import Optional, Sequence +import numpy as np import pytest from pydantic import Field from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray - -from .fixtures import qdrant_config, qdrant +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 class SimpleDoc(BaseDoc): @@ -24,9 +22,9 @@ def index_docs() -> Sequence[SimpleDoc]: @pytest.mark.parametrize('limit', [1, 5, 10]) -def test_dict_limit(qdrant_config, qdrant, index_docs, limit): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(index_docs) +def test_dict_limit(qdrant_config, index_docs, limit): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(index_docs) # Search test query = { @@ -35,7 +33,7 @@ def test_dict_limit(qdrant_config, qdrant, index_docs, limit): 'with_vectors': True, } - points = store.execute_query(query=query) + points = index.execute_query(query=query) assert points is not None assert len(points) == limit @@ -45,14 +43,14 @@ def test_dict_limit(qdrant_config, qdrant, index_docs, limit): 'with_vectors': True, } - points = store.execute_query(query=query) + points = index.execute_query(query=query) assert points is not None assert len(points) == limit -def test_dict_full_text_filter(qdrant_config, qdrant, index_docs): - store = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) - store.index(index_docs) +def test_dict_full_text_filter(qdrant_config, index_docs): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index.index(index_docs) # Search test query = { @@ -63,7 +61,7 @@ def test_dict_full_text_filter(qdrant_config, qdrant, index_docs): 'with_vectors': True, } - points = store.execute_query(query=query) + points = index.execute_query(query=query) assert points is not None assert len(points) == 1 assert points[0].id == index_docs[2].id @@ -76,7 +74,7 @@ def test_dict_full_text_filter(qdrant_config, qdrant, index_docs): 'with_vectors': True, } - points = store.execute_query(query=query) + points = index.execute_query(query=query) assert points is not None assert len(points) == 1 assert points[0].id == index_docs[2].id diff --git a/tests/index/qdrant/test_text_search.py b/tests/index/qdrant/test_text_search.py index 9a863e2dbbf..31815163433 100644 --- a/tests/index/qdrant/test_text_search.py +++ b/tests/index/qdrant/test_text_search.py @@ -1,12 +1,10 @@ import numpy as np - from pydantic import Field from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray - -from .fixtures import qdrant_config, qdrant +from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 class SimpleDoc(BaseDoc): @@ -14,12 +12,12 @@ class SimpleDoc(BaseDoc): text: str -def test_text_search(qdrant_config, qdrant): +def test_text_search(qdrant_config): # noqa: F811 class SimpleSchema(BaseDoc): embedding: NdArray[10] = Field(space='cosine') # type: ignore[valid-type] text: str - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) index_docs = [ SimpleDoc( @@ -28,10 +26,10 @@ class SimpleSchema(BaseDoc): ) for i in range(10) ] - store.index(index_docs) + index.index(index_docs) query = 'ipsum 2' - docs, scores = store.text_search(query, search_field='text', limit=5) + docs, scores = index.text_search(query, search_field='text', limit=5) assert len(docs) == 1 assert len(scores) == 1 @@ -39,12 +37,12 @@ class SimpleSchema(BaseDoc): assert scores[0] > 0.0 -def test_text_search_batched(qdrant_config, qdrant): +def test_text_search_batched(qdrant_config): # noqa: F811 class SimpleSchema(BaseDoc): embedding: NdArray[10] = Field(space='cosine') # type: ignore[valid-type] text: str - store = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) index_docs = [ SimpleDoc( @@ -53,10 +51,10 @@ class SimpleSchema(BaseDoc): ) for i in range(10) ] - store.index(index_docs) + index.index(index_docs) queries = ['ipsum 2', 'ipsum 4', 'Lorem'] - docs, scores = store.text_search_batched(queries, search_field='text', limit=5) + docs, scores = index.text_search_batched(queries, search_field='text', limit=5) assert len(docs) == 3 assert len(docs[0]) == 1 diff --git a/tests/index/weaviate/test_column_config_weaviate.py b/tests/index/weaviate/test_column_config_weaviate.py index 037bdba6205..4789a6d707f 100644 --- a/tests/index/weaviate/test_column_config_weaviate.py +++ b/tests/index/weaviate/test_column_config_weaviate.py @@ -15,8 +15,8 @@ def test_column_config(weaviate_client): - def get_text_field_data_type(store, index_name): - props = store._client.schema.get(index_name)["properties"] + def get_text_field_data_type(index, index_name): + props = index._client.schema.get(index_name)["properties"] text_field = [p for p in props if p["name"] == "text"][0] return text_field["dataType"][0] @@ -28,9 +28,9 @@ class StringDoc(BaseDoc): text: str = Field(col_type="string") dbconfig = WeaviateDocumentIndex.DBConfig(index_name="TextDoc") - store = WeaviateDocumentIndex[TextDoc](db_config=dbconfig) - assert get_text_field_data_type(store, "TextDoc") == "text" + index = WeaviateDocumentIndex[TextDoc](db_config=dbconfig) + assert get_text_field_data_type(index, "TextDoc") == "text" dbconfig = WeaviateDocumentIndex.DBConfig(index_name="StringDoc") - store = WeaviateDocumentIndex[StringDoc](db_config=dbconfig) - assert get_text_field_data_type(store, "StringDoc") == "string" + index = WeaviateDocumentIndex[StringDoc](db_config=dbconfig) + assert get_text_field_data_type(index, "StringDoc") == "string" diff --git a/tests/index/weaviate/test_find_weaviate.py b/tests/index/weaviate/test_find_weaviate.py index 21a20be3921..7908d2c0ce6 100644 --- a/tests/index/weaviate/test_find_weaviate.py +++ b/tests/index/weaviate/test_find_weaviate.py @@ -21,15 +21,15 @@ def test_find_torch(weaviate_client): class TorchDoc(BaseDoc): tens: TorchTensor[10] = Field(dims=10, is_embedding=True) - store = WeaviateDocumentIndex[TorchDoc]() + index = WeaviateDocumentIndex[TorchDoc]() index_docs = [ TorchDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] - docs, scores = store.find(query, limit=5) + docs, scores = index.find(query, limit=5) assert len(docs) == 5 assert len(scores) == 5 @@ -47,15 +47,15 @@ def test_find_tensorflow(): class TfDoc(BaseDoc): tens: TensorFlowTensor[10] = Field(dims=10, is_embedding=True) - store = WeaviateDocumentIndex[TfDoc]() + index = WeaviateDocumentIndex[TfDoc]() index_docs = [ TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) ] - store.index(index_docs) + index.index(index_docs) query = index_docs[-1] - docs, scores = store.find(query, limit=5) + docs, scores = index.find(query, limit=5) assert len(docs) == 5 assert len(scores) == 5 diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py index 431501c2730..ba4d6a27f5e 100644 --- a/tests/index/weaviate/test_index_get_del_weaviate.py +++ b/tests/index/weaviate/test_index_get_del_weaviate.py @@ -58,16 +58,16 @@ def documents(): @pytest.fixture -def test_store(weaviate_client, documents): - store = WeaviateDocumentIndex[Document]() - store.index(documents) - yield store +def test_index(weaviate_client, documents): + index = WeaviateDocumentIndex[Document]() + index.index(documents) + yield index def test_index_simple_schema(weaviate_client, ten_simple_docs): - store = WeaviateDocumentIndex[SimpleDoc]() - store.index(ten_simple_docs) - assert store.num_docs() == 10 + index = WeaviateDocumentIndex[SimpleDoc]() + index.index(ten_simple_docs) + assert index.num_docs() == 10 for doc in ten_simple_docs: doc_id = doc.id @@ -111,24 +111,24 @@ class Document(BaseDoc): vectors = [[10, 10], [10.5, 10.5], [-100, -100]] docs = [Document(embedding=vector) for vector in vectors] - store = WeaviateDocumentIndex[Document]() - store.index(docs) + index = WeaviateDocumentIndex[Document]() + index.index(docs) query = [10.1, 10.1] - results = store.find( + results = index.find( query, search_field='', limit=3, score_name="distance", score_threshold=1e-2 ) assert len(results) == 2 - results = store.find(query, search_field='', limit=3, score_threshold=0.99) + results = index.find(query, search_field='', limit=3, score_threshold=0.99) assert len(results) == 2 with pytest.raises( ValueError, match=r"Argument search_field is not supported for WeaviateDocumentIndex", ): - store.find(query, search_field="foo", limit=10) + index.find(query, search_field="foo", limit=10) def test_find_batched(weaviate_client, caplog): @@ -138,19 +138,19 @@ class Document(BaseDoc): vectors = [[10, 10], [10.5, 10.5], [-100, -100]] docs = [Document(embedding=vector) for vector in vectors] - store = WeaviateDocumentIndex[Document]() - store.index(docs) + index = WeaviateDocumentIndex[Document]() + index.index(docs) queries = np.array([[10.1, 10.1], [-100, -100]]) - results = store.find_batched( + results = index.find_batched( queries, search_field='', limit=3, score_name="distance", score_threshold=1e-2 ) assert len(results) == 2 assert len(results.documents[0]) == 2 assert len(results.documents[1]) == 1 - results = store.find_batched( + results = index.find_batched( queries, search_field='', limit=3, score_name="certainty" ) assert len(results) == 2 @@ -161,7 +161,7 @@ class Document(BaseDoc): ValueError, match=r"Argument search_field is not supported for WeaviateDocumentIndex", ): - store.find_batched(queries, search_field="foo", limit=10) + index.find_batched(queries, search_field="foo", limit=10) @pytest.mark.parametrize( @@ -172,8 +172,8 @@ class Document(BaseDoc): ({"path": ["id"], "operator": "Equal", "valueString": "1"}, 1), ], ) -def test_filter(test_store, filter_query, expected_num_docs): - docs = test_store.filter(filter_query, limit=3) +def test_filter(test_index, filter_query, expected_num_docs): + docs = test_index.filter(filter_query, limit=3) actual_num_docs = len(docs) assert actual_num_docs == expected_num_docs @@ -198,50 +198,50 @@ def test_filter(test_store, filter_query, expected_num_docs): ), ], ) -def test_filter_batched(test_store, filter_queries, expected_num_docs): +def test_filter_batched(test_index, filter_queries, expected_num_docs): filter_queries = [ {"path": ["text"], "operator": "Equal", "valueText": "lorem ipsum"}, {"path": ["text"], "operator": "Equal", "valueText": "foo"}, ] - results = test_store.filter_batched(filter_queries, limit=3) + results = test_index.filter_batched(filter_queries, limit=3) actual_num_docs = [len(docs) for docs in results] assert actual_num_docs == expected_num_docs -def test_text_search(test_store): - results = test_store.text_search(query="lorem", search_field="text", limit=3) +def test_text_search(test_index): + results = test_index.text_search(query="lorem", search_field="text", limit=3) assert len(results.documents) == 1 -def test_text_search_batched(test_store): +def test_text_search_batched(test_index): text_queries = ["lorem", "foo"] - results = test_store.text_search_batched( + results = test_index.text_search_batched( queries=text_queries, search_field="text", limit=3 ) assert len(results.documents[0]) == 1 assert len(results.documents[1]) == 0 -def test_del_items(test_store): - del test_store[["1", "2"]] - assert test_store.num_docs() == 1 +def test_del_items(test_index): + del test_index[["1", "2"]] + assert test_index.num_docs() == 1 -def test_get_items(test_store): - docs = test_store[["1", "2"]] +def test_get_items(test_index): + docs = test_index[["1", "2"]] assert len(docs) == 2 assert set(doc.id for doc in docs) == {'1', '2'} def test_index_nested_documents(weaviate_client): - store = WeaviateDocumentIndex[NestedDocument]() + index = WeaviateDocumentIndex[NestedDocument]() document = NestedDocument( text="lorem ipsum", child=Document(embedding=[10, 10], text="dolor sit amet") ) - store.index([document]) - assert store.num_docs() == 1 + index.index([document]) + assert index.num_docs() == 1 @pytest.mark.parametrize( @@ -256,13 +256,13 @@ def test_index_nested_documents(weaviate_client): def test_text_search_nested_documents( weaviate_client, search_field, query, expected_num_docs ): - store = WeaviateDocumentIndex[NestedDocument]() + index = WeaviateDocumentIndex[NestedDocument]() document = NestedDocument( text="lorem ipsum", child=Document(embedding=[10, 10], text="dolor sit amet") ) - store.index([document]) + index.index([document]) - results = store.text_search(query=query, search_field=search_field, limit=3) + results = index.text_search(query=query, search_field=search_field, limit=3) assert len(results.documents) == expected_num_docs @@ -275,37 +275,37 @@ def test_reuse_existing_schema(weaviate_client, caplog): assert "Will reuse existing schema" in caplog.text -def test_query_builder(test_store): +def test_query_builder(test_index): query_embedding = [10.25, 10.25] query_text = "ipsum" where_filter = {"path": ["id"], "operator": "Equal", "valueString": "1"} q = ( - test_store.build_query() + test_index.build_query() .find(query=query_embedding) .filter(where_filter) .build() ) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert len(docs) == 1 q = ( - test_store.build_query() + test_index.build_query() .text_search(query=query_text, search_field="text") .build() ) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert len(docs) == 1 -def test_batched_query_builder(test_store): +def test_batched_query_builder(test_index): query_embeddings = [[10.25, 10.25], [-100, -100]] query_texts = ["ipsum", "foo"] where_filters = [{"path": ["id"], "operator": "Equal", "valueString": "1"}] q = ( - test_store.build_query() + test_index.build_query() .find_batched( queries=query_embeddings, score_name="certainty", score_threshold=0.99 ) @@ -313,22 +313,22 @@ def test_batched_query_builder(test_store): .build() ) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert len(docs[0]) == 1 assert len(docs[1]) == 0 q = ( - test_store.build_query() + test_index.build_query() .text_search_batched(queries=query_texts, search_field="text") .build() ) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert len(docs[0]) == 1 assert len(docs[1]) == 0 -def test_raw_graphql(test_store): +def test_raw_graphql(test_index): graphql_query = """ { Aggregate { @@ -341,35 +341,35 @@ def test_raw_graphql(test_store): } """ - results = test_store.execute_query(graphql_query) + results = test_index.execute_query(graphql_query) num_docs = results["data"]["Aggregate"]["Document"][0]["meta"]["count"] assert num_docs == 3 -def test_hybrid_query(test_store): +def test_hybrid_query(test_index): query_embedding = [10.25, 10.25] query_text = "ipsum" where_filter = {"path": ["id"], "operator": "Equal", "valueString": "1"} q = ( - test_store.build_query() + test_index.build_query() .find(query=query_embedding) .text_search(query=query_text, search_field="text") .filter(where_filter) .build() ) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert len(docs) == 1 -def test_hybrid_query_batched(test_store): +def test_hybrid_query_batched(test_index): query_embeddings = [[10.25, 10.25], [-100, -100]] query_texts = ["dolor", "elit"] q = ( - test_store.build_query() + test_index.build_query() .find_batched( queries=query_embeddings, score_name="certainty", score_threshold=0.99 ) @@ -377,7 +377,7 @@ def test_hybrid_query_batched(test_store): .build() ) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert docs[0][0].id == '1' assert docs[1][0].id == '2' @@ -387,28 +387,28 @@ class MyMultiModalDoc(BaseDoc): image: ImageDoc text: TextDoc - store = WeaviateDocumentIndex[MyMultiModalDoc]() + index = WeaviateDocumentIndex[MyMultiModalDoc]() doc = [ MyMultiModalDoc( image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') ) ] - store.index(doc) + index.index(doc) id_ = doc[0].id - assert store[id_].id == id_ - assert np.all(store[id_].image.embedding == doc[0].image.embedding) - assert store[id_].text.text == doc[0].text.text + assert index[id_].id == id_ + assert np.all(index[id_].image.embedding == doc[0].image.embedding) + assert index[id_].text.text == doc[0].text.text def test_index_document_with_bytes(weaviate_client): doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo") - store = WeaviateDocumentIndex[ImageDoc]() - store.index([doc]) + index = WeaviateDocumentIndex[ImageDoc]() + index.index([doc]) - results = store.filter( + results = index.filter( filter_query={"path": ["id"], "operator": "Equal", "valueString": "1"} ) @@ -423,22 +423,22 @@ class Document(BaseDoc): doc = Document(not_embedding=[2, 5], text="dolor sit amet", id="1") - store = WeaviateDocumentIndex[Document]() + index = WeaviateDocumentIndex[Document]() - store.index([doc]) + index.index([doc]) - results = store.filter( + results = index.filter( filter_query={"path": ["id"], "operator": "Equal", "valueString": "1"} ) assert doc == results[0] -def test_limit_query_builder(test_store): +def test_limit_query_builder(test_index): query_vector = [10.25, 10.25] - q = test_store.build_query().find(query=query_vector).limit(2) + q = test_index.build_query().find(query=query_vector).limit(2) - docs = test_store.execute_query(q) + docs = test_index.execute_query(q) assert len(docs) == 2 @@ -448,6 +448,6 @@ class Document(BaseDoc): embedded_options = EmbeddedOptions() db_config = WeaviateDocumentIndex.DBConfig(embedded_options=embedded_options) - store = WeaviateDocumentIndex[Document](db_config=db_config) + index = WeaviateDocumentIndex[Document](db_config=db_config) - assert store._client._connection.embedded_db + assert index._client._connection.embedded_db