Using column names instead of ORM to get all documents #620

Merged (14 commits, Jan 6, 2021)
Changes from 12 commits
175 changes: 129 additions & 46 deletions haystack/document_store/sql.py
@@ -1,3 +1,4 @@
import itertools
import logging
from typing import Any, Dict, Union, List, Optional
from uuid import uuid4
@@ -42,7 +43,12 @@ class MetaORM(ORMBase):

name = Column(String(100), index=True)
value = Column(String(1000), index=True)
document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
document_id = Column(
String(100),
ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
nullable=False,
index=True
)

documents = relationship(DocumentORM, backref="Meta")

@@ -69,6 +75,7 @@ def __init__(
index: str = "document",
label_index: str = "label",
update_existing_documents: bool = False,
batch_size: int = 999,
):
"""
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -80,7 +87,11 @@ def __init__(
:param update_existing_documents: Whether to update any existing documents with the same ID when adding
documents. When set as True, any document with an existing ID gets updated.
If set to False, an error is raised if the document ID of the document being
added already exists. Using this parameter coud cause performance degradation for document insertion.
added already exists. Using this parameter could cause performance degradation
for document insertion.
:param batch_size: Maximum number of host parameters in a single SQL statement.
This helps avoid excessive memory allocations.
For more info, refer to: https://www.sqlite.org/limits.html
"""
engine = create_engine(url)
ORMBase.metadata.create_all(engine)
@@ -91,6 +102,7 @@ def __init__(
self.update_existing_documents = update_existing_documents
if getattr(self, "similarity", None) is None:
self.similarity = None
self.batch_size = batch_size

def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
@@ -101,21 +113,33 @@ def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[D
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids), DocumentORM.index == index).all()
documents = [self._convert_sql_row_to_document(row) for row in results]

documents = []
for i in range(0, len(ids), self.batch_size):
query = self.session.query(DocumentORM).filter(
DocumentORM.id.in_(ids[i: i + self.batch_size]),
DocumentORM.index == index
)
for row in query.all():
documents.append(self._convert_sql_row_to_document(row))

return documents

def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
"""Fetch documents by specifying a list of text vector id strings"""
index = index or self.index
results = self.session.query(DocumentORM).filter(
DocumentORM.vector_id.in_(vector_ids),
DocumentORM.index == index
).all()
sorted_results = sorted(results, key=lambda doc: vector_ids.index(doc.vector_id))
documents = [self._convert_sql_row_to_document(row) for row in sorted_results]
return documents

documents = []
for i in range(0, len(vector_ids), self.batch_size):
query = self.session.query(DocumentORM).filter(
DocumentORM.vector_id.in_(vector_ids[i: i + self.batch_size]),
DocumentORM.index == index
)
for row in query.all():
documents.append(self._convert_sql_row_to_document(row))

sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"])) # type: ignore
return sorted_documents

def get_all_documents(
self,
@@ -134,21 +158,52 @@ def get_all_documents(
"""

index = index or self.index
query = self.session.query(DocumentORM).filter_by(index=index)
# ORM objects kept in memory generally cause performance issues,
# so querying column names directly improves memory usage and performance.
# Refer to https://stackoverflow.com/questions/23185319/why-is-loading-sqlalchemy-objects-via-the-orm-5-8x-slower-than-rows-via-a-raw-my
documents_query = self.session.query(
DocumentORM.id,
DocumentORM.text,
DocumentORM.vector_id
).filter_by(index=index)

if filters:
query = query.join(MetaORM)
documents_query = documents_query.join(MetaORM)
for key, values in filters.items():
query = query.filter(MetaORM.name == key, MetaORM.value.in_(values))
documents_query = documents_query.filter(
MetaORM.name == key,
MetaORM.value.in_(values),
DocumentORM.id == MetaORM.document_id
)

documents_map = {}
for row in documents_query.all():
documents_map[row.id] = Document(
id=row.id,
text=row.text,
meta=None if row.vector_id is None else {"vector_id": row.vector_id} # type: ignore
)

documents = [self._convert_sql_row_to_document(row) for row in query.all()]
return documents
for doc_ids in self.chunked_iterable(documents_map.keys(), size=self.batch_size):
meta_query = self.session.query(
MetaORM.document_id,
MetaORM.name,
MetaORM.value
).filter(MetaORM.document_id.in_(doc_ids))

for row in meta_query.all():
Reviewer comment:
The first query (line 120) executes quickly as it should if just querying by index (< 1 sec). Memory usage seems to be normal / expected.

The second call to get metadata is much slower and errors eventually:

sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) too many SQL variables
[SQL: SELECT meta.document_id AS meta_document_id, meta.name AS meta_name, meta.value AS meta_value
FROM meta
WHERE meta.document_id IN (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,.....?)

It seems odd that 2 separate queries would be required to get the same fields in a document.

Contributor Author:
The second call to get metadata is much slower and errors eventually:

To fix it I have added an index on document_id and limited the number of host variable parameters passed to SQL queries (999 for SQLite < 3.32, 32K for >= 3.32).

It seems odd that 2 separate queries would be required to get the same fields in a document.

It is done to prevent duplicating the very long text field in memory. Each document can have multiple meta entries, and it is better not to keep a duplicate copy of the text in memory for each meta. Hence I have split it into two queries.
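
For readers following the discussion, here is a minimal, self-contained sketch of the batching idea described above. It uses plain sqlite3 rather than the Haystack store, and the table name, columns, and sizes are invented for illustration:

import itertools
import sqlite3

BATCH_SIZE = 999  # conservative limit for SQLite < 3.32; newer builds allow ~32K host parameters

def chunked(iterable, size):
    # Yield fixed-size tuples from any iterable (same idea as chunked_iterable in this PR).
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, size))
        if not chunk:
            break
        yield chunk

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE meta (document_id TEXT, name TEXT, value TEXT)")
conn.executemany(
    "INSERT INTO meta VALUES (?, ?, ?)",
    [(f"doc-{i}", "author", f"author-{i}") for i in range(5000)],
)

doc_ids = [f"doc-{i}" for i in range(5000)]
rows = []
for chunk in chunked(doc_ids, BATCH_SIZE):
    placeholders = ", ".join("?" * len(chunk))  # at most BATCH_SIZE parameters per statement
    rows.extend(
        conn.execute(
            f"SELECT document_id, name, value FROM meta WHERE document_id IN ({placeholders})",
            chunk,
        ).fetchall()
    )

print(len(rows))  # 5000, without hitting "too many SQL variables"

The committed code follows the same pattern through SQLAlchemy, chunking the id list with self.batch_size before applying the in_() filter.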

Contributor Author:
Could you please test with the latest changes, where I have fixed the issue you reported?
It seems you have a good amount of data to benchmark it with.

Contributor Author:
@vinchg, if you get a chance, could you please test the latest changes?

if documents_map[row.document_id].meta is None:
documents_map[row.document_id].meta = {}
documents_map[row.document_id].meta[row.name] = row.value # type: ignore

return list(documents_map.values())

def get_all_labels(self, index=None, filters: Optional[dict] = None):
"""
Return all labels in the document store
"""
index = index or self.label_index
# TODO: Use batch_size
label_rows = self.session.query(LabelORM).filter_by(index=index).all()
labels = [self._convert_sql_row_to_label(row) for row in label_rows]

@@ -169,33 +224,41 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
:return: None
"""

# Make sure we comply to Document class format
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
index = index or self.index
for doc in document_objects:
meta_fields = doc.meta or {}
vector_id = meta_fields.get("vector_id")
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
if self.update_existing_documents:
# First old meta data cleaning is required
self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
self.session.merge(doc_orm)
else:
self.session.add(doc_orm)
try:
self.session.commit()
except Exception as ex:
logger.error(f"Transaction rollback: {ex.__cause__}")
# Rollback is important here otherwise self.session will be in inconsistent state and next call will fail
self.session.rollback()
raise ex
if len(documents) == 0:
return
# Make sure we comply to Document class format
if isinstance(documents[0], dict):
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
else:
document_objects = documents

for i in range(0, len(document_objects), self.batch_size):
for doc in document_objects[i: i + self.batch_size]:
meta_fields = doc.meta or {}
vector_id = meta_fields.get("vector_id")
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
if self.update_existing_documents:
# First old meta data cleaning is required
self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
self.session.merge(doc_orm)
else:
self.session.add(doc_orm)
try:
self.session.commit()
except Exception as ex:
logger.error(f"Transaction rollback: {ex.__cause__}")
# Rollback is important here otherwise self.session will be in inconsistent state and next call will fail
self.session.rollback()
raise ex

def write_labels(self, labels, index=None):
"""Write annotation labels into document store."""

labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
index = index or self.label_index
# TODO: Use batch_size
for label in labels:
label_orm = LabelORM(
document_id=label.document_id,
@@ -220,16 +283,22 @@ def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str]
:param index: filter documents by the optional index attribute for documents in database.
"""
index = index or self.index
self.session.query(DocumentORM).filter(
DocumentORM.id.in_(vector_id_map),
DocumentORM.index == index
).update({
DocumentORM.vector_id: case(
vector_id_map,
value=DocumentORM.id,
)
}, synchronize_session=False)
self.session.commit()
for chunk_map in self.chunked_dict(vector_id_map, size=self.batch_size):
self.session.query(DocumentORM).filter(
DocumentORM.id.in_(chunk_map),
DocumentORM.index == index
).update({
DocumentORM.vector_id: case(
chunk_map,
value=DocumentORM.id,
)
}, synchronize_session=False)
try:
self.session.commit()
except Exception as ex:
logger.error(f"Transaction rollback: {ex.__cause__}")
self.session.rollback()
raise ex

def update_document_meta(self, id: str, meta: Dict[str, str]):
"""
@@ -338,3 +407,17 @@ def _get_or_create(self, session, model, **kwargs):
session.add(instance)
session.commit()
return instance

# Refer: https://alexwlchan.net/2018/12/iterating-in-fixed-size-chunks/
def chunked_iterable(self, iterable, size):
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, size))
if not chunk:
break
yield chunk

def chunked_dict(self, dictionary, size):
it = iter(dictionary)
for i in range(0, len(dictionary), size):
yield {k: dictionary[k] for k in itertools.islice(it, size)}
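
For illustration, a standalone sketch of what these chunk helpers yield (the self parameter dropped, example inputs invented):

import itertools

def chunked_iterable(iterable, size):
    # Yield fixed-size tuples until the iterable is exhausted.
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, size))
        if not chunk:
            break
        yield chunk

def chunked_dict(dictionary, size):
    # Yield fixed-size sub-dicts, preserving insertion order.
    it = iter(dictionary)
    for _ in range(0, len(dictionary), size):
        yield {k: dictionary[k] for k in itertools.islice(it, size)}

print(list(chunked_iterable(["d1", "d2", "d3", "d4", "d5"], 2)))
# [('d1', 'd2'), ('d3', 'd4'), ('d5',)]
print(list(chunked_dict({"d1": "v1", "d2": "v2", "d3": "v3"}, 2)))
# [{'d1': 'v1', 'd2': 'v2'}, {'d3': 'v3'}]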
21 changes: 19 additions & 2 deletions test/test_faiss.py
@@ -1,3 +1,5 @@
import os

import faiss
import numpy as np
import pytest
@@ -120,16 +122,31 @@ def test_faiss_update_with_empty_store(document_store, retriever):

@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
def test_faiss_retrieving(index_factory):
document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test_faiss.db", faiss_index_factory_str=index_factory)
document_store = FAISSDocumentStore(
sql_url="sqlite:///test_faiss_retrieving.db",
faiss_index_factory_str=index_factory
)

document_store.delete_all_documents(index="document")
if "ivf" in index_factory.lower():
document_store.train_index(DOCUMENTS)
document_store.write_documents(DOCUMENTS)
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)

retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="deepset/sentence_bert",
use_gpu=False
)
result = retriever.retrieve(query="How to test this?")

assert len(result) == len(DOCUMENTS)
assert type(result[0]) == Document

# Cleanup
document_store.faiss_index.reset()
if os.path.exists("test_faiss_retrieving.db"):
os.remove("test_faiss_retrieving.db")


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)