diff --git a/test/conftest.py b/test/conftest.py index cae7b6a141..4a5deef8d8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -234,24 +234,25 @@ def document_store(request, test_docs_xs): document_store.faiss_index.reset() -def get_document_store(document_store_type): +def get_document_store(document_store_type, embedding_field="embedding"): if document_store_type == "sql": if os.path.exists("haystack_test.db"): os.remove("haystack_test.db") document_store = SQLDocumentStore(url="sqlite:///haystack_test.db") elif document_store_type == "memory": - document_store = InMemoryDocumentStore(return_embedding=True) + document_store = InMemoryDocumentStore(return_embedding=True, embedding_field=embedding_field) elif document_store_type == "elasticsearch": # make sure we start from a fresh index client = Elasticsearch() client.indices.delete(index='haystack_test*', ignore=[404]) - document_store = ElasticsearchDocumentStore(index="haystack_test", return_embedding=True) + document_store = ElasticsearchDocumentStore( + index="haystack_test", return_embedding=True, embedding_field=embedding_field + ) elif document_store_type == "faiss": if os.path.exists("haystack_test_faiss.db"): os.remove("haystack_test_faiss.db") document_store = FAISSDocumentStore( - sql_url="sqlite:///haystack_test_faiss.db", - return_embedding=True + sql_url="sqlite:///haystack_test_faiss.db", return_embedding=True, embedding_field=embedding_field ) return document_store else: diff --git a/test/test_db.py b/test/test_db.py index 8a45b39328..b9db9a8c6c 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -2,6 +2,7 @@ import pytest from elasticsearch import Elasticsearch +from conftest import get_document_store from haystack import Document, Label from haystack.document_store.elasticsearch import ElasticsearchDocumentStore @@ -369,6 +370,20 @@ def test_elasticsearch_update_meta(document_store): assert updated_document.meta["meta_key_2"] == "2" +@pytest.mark.elasticsearch +@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"]) +def test_custom_embedding_field(document_store_type): + document_store = get_document_store( + document_store_type=document_store_type, embedding_field="custom_embedding_field" + ) + doc_to_write = {"text": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)} + document_store.write_documents([doc_to_write]) + documents = document_store.get_all_documents(return_embedding=True) + assert len(documents) == 1 + assert documents[0].text == "test" + np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding) + + @pytest.mark.elasticsearch def test_elasticsearch_custom_fields(elasticsearch_fixture): client = Elasticsearch()