Skip to content

Commit

Permalink
Clean API docs and increase coverage (#621)
Browse files Browse the repository at this point in the history
* Fix docstrings

* Fix docstrings

* docstrings for retrievers and docstores

* Clean and add more docstrings
  • Loading branch information
brandenchan committed Nov 27, 2020
1 parent fa55de2 commit 9fbd845
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 8 deletions.
46 changes: 43 additions & 3 deletions haystack/document_store/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ def __init__(
If set to False, an error is raised if the document ID of the document being
added already exists.
:param refresh_type: Type of ES refresh used to control when changes made by a request (e.g. bulk) are made visible to search.
Values:
- 'wait_for' => continue only after changes are visible (slow, but safe)
- 'false' => continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion)
If set to 'wait_for', continue only after changes are visible (slow, but safe).
If set to 'false', continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion).
More info at https://www.elastic.co/guide/en/elasticsearch/reference/6.8/docs-refresh.html
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
Expand Down Expand Up @@ -220,6 +219,7 @@ def _create_document_field_map(self) -> Dict:
}

def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
index = index or self.index
documents = self.get_documents_by_id([id], index=index)
if documents:
Expand All @@ -228,6 +228,7 @@ def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
return None

def get_documents_by_id(self, ids: List[str], index=None) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
query = {"query": {"ids": {"values": ids}}}
result = self.client.search(index=index, body=query)["hits"]["hits"]
Expand Down Expand Up @@ -298,6 +299,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)

def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None):
"""Write annotation labels into document store."""
index = index or self.label_index
if index and not self.client.indices.exists(index=index):
self._create_label_index(index)
Expand All @@ -317,10 +319,16 @@ def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[s
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)

def update_document_meta(self, id: str, meta: Dict[str, str]):
    """
    Update the metadata of a single document, identified by its string id.

    :param id: Id of the document to update.
    :param meta: Metadata fields to send as a partial-document update.
    """
    # The metadata is wrapped under "doc" and sent to the client's update call
    # against the store's default index.
    self.client.update(
        index=self.index,
        id=id,
        body={"doc": meta},
        refresh=self.refresh_type,
    )

def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
Return the number of documents in the document store.
"""
index = index or self.index

body: dict = {"query": {"bool": {}}}
Expand All @@ -343,6 +351,9 @@ def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, ind
return count

def get_label_count(self, index: Optional[str] = None) -> int:
    """
    Return the number of labels in the document store.

    :param index: Name of the label index to count. Defaults to `self.label_index`.
    :return: Number of entries in the label index.
    """
    # get_document_count() falls back to self.index (the document index) when no
    # index is given, so resolve the label index here to avoid counting documents.
    index = index or self.label_index
    return self.get_document_count(index=index)

def get_all_documents(
Expand Down Expand Up @@ -372,12 +383,18 @@ def get_all_documents(
return documents

def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
    """
    Return all labels in the document store.

    :param index: Name of the label index to read from. Defaults to `self.label_index`.
    :param filters: Optional filters passed through to `get_all_documents_in_index()`.
    :return: One `Label` per hit, built from the hit's `_source`.
    """
    label_index = index if index else self.label_index
    hits = self.get_all_documents_in_index(index=label_index, filters=filters)
    return [Label.from_dict(hit["_source"]) for hit in hits]

def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]:
"""
Return all documents in a specific index in the document store
"""
body = {
"query": {
"bool": {
Expand Down Expand Up @@ -409,6 +426,15 @@ def query(
custom_query: Optional[str] = None,
index: Optional[str] = None,
) -> List[Document]:
"""
Scan through documents in DocumentStore and return a small number of documents
that are most relevant to the query as defined by the BM25 algorithm.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""

if index is None:
index = self.index
Expand Down Expand Up @@ -483,6 +509,17 @@ def query_by_embedding(self,
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None) -> List[Document]:
"""
Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: Index name for storing the docs and metadata
:param return_embedding: To return document embedding
:return:
"""
if index is None:
index = self.index

Expand Down Expand Up @@ -572,6 +609,9 @@ def _convert_es_hit_to_document(
return document

def describe_documents(self, index=None):
"""
Return a summary of the documents in the document store
"""
if index is None:
index = self.index
docs = self.get_all_documents(index)
Expand Down
4 changes: 4 additions & 0 deletions haystack/document_store/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _create_new_index(self, vector_dim: int, index_factory: str = "Flat", metric
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""
Add new documents to the DocumentStore.
:param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
:param index: (SQL) index name for storing the docs and metadata
Expand Down Expand Up @@ -229,6 +230,9 @@ def train_index(self, documents: Optional[Union[List[dict], List[Document]]], em
self.faiss_index.train(embeddings)

def delete_all_documents(self, index=None):
"""
Delete all documents from the document store.
"""
index = index or self.index
self.faiss_index.reset()
super().delete_all_documents(index=index)
Expand Down
26 changes: 25 additions & 1 deletion haystack/document_store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
self.indexes[index][document.id] = document

def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
"""Write annotation labels into document store."""
index = index or self.label_index
label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]

Expand All @@ -55,6 +56,7 @@ def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[s
self.indexes[index][label_id] = label

def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
index = index or self.index
documents = self.get_documents_by_id([id], index=index)
if documents:
Expand All @@ -63,6 +65,7 @@ def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[D
return None

def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
    """
    Fetch documents by specifying a list of text id strings.

    :param ids: Ids of the documents to fetch.
    :param index: Name of the index to read from. Defaults to `self.index`.
    :return: The stored documents, in the same order as `ids`.
    """
    target = index if index else self.index
    # Direct subscripting: an id that is not stored raises instead of being skipped.
    return [self.indexes[target][doc_id] for doc_id in ids]
Expand All @@ -74,6 +77,18 @@ def query_by_embedding(self,
index: Optional[str] = None,
return_embedding: Optional[bool] = None) -> List[Document]:

"""
Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: Index name for storing the docs and metadata
:param return_embedding: To return document embedding
:return:
"""

from numpy import dot
from numpy.linalg import norm

Expand Down Expand Up @@ -137,13 +152,19 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non
self.indexes[index][doc.id].embedding = emb

def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
    """
    Return the number of documents in the document store.

    :param filters: Optional filters passed through to `get_all_documents()`.
    :param index: Optional index name passed through to `get_all_documents()`.
    :return: Length of the list returned by `get_all_documents()`.
    """
    return len(self.get_all_documents(index=index, filters=filters))

def get_label_count(self, index: Optional[str] = None) -> int:
    """
    Return the number of labels in the document store.

    :param index: Name of the label index to count. Defaults to `self.label_index`.
    :return: Number of entries stored under that index.
    """
    label_index = index if index else self.label_index
    stored_labels = self.indexes[label_index]
    return len(stored_labels.items())

def get_all_documents(
self,
index: Optional[str] = None,
Expand Down Expand Up @@ -187,6 +208,9 @@ def get_all_documents(
return filtered_documents

def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
"""
Return all labels in the document store
"""
index = index or self.label_index

if filters:
Expand Down
16 changes: 16 additions & 0 deletions haystack/document_store/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,21 @@ def __init__(
self.update_existing_documents = update_existing_documents

def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
    """
    Fetch a single document by its text id string.

    :param id: Id of the document to fetch.
    :param index: Optional index name, passed through to `get_documents_by_id()`.
    :return: The first matching document, or None if nothing matched.
    """
    matches = self.get_documents_by_id([id], index)
    if not matches:
        return None
    return matches[0]

def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
    """
    Fetch documents by specifying a list of text id strings.

    :param ids: Ids of the documents to fetch.
    :param index: Name of the index to read from. Defaults to `self.index`.
    :return: One converted document per matching row.
    """
    target_index = index if index else self.index
    # Match rows whose id is in `ids` and that belong to the requested index.
    row_query = self.session.query(DocumentORM).filter(
        DocumentORM.id.in_(ids),
        DocumentORM.index == target_index,
    )
    return [self._convert_sql_row_to_document(row) for row in row_query.all()]

def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
"""Fetch documents by specifying a list of text vector id strings"""
index = index or self.index
results = self.session.query(DocumentORM).filter(
DocumentORM.vector_id.in_(vector_ids),
Expand Down Expand Up @@ -138,6 +141,9 @@ def get_all_documents(
return documents

def get_all_labels(self, index=None, filters: Optional[dict] = None):
"""
Return all labels in the document store
"""
index = index or self.label_index
label_rows = self.session.query(LabelORM).filter_by(index=index).all()
labels = [self._convert_sql_row_to_label(row) for row in label_rows]
Expand Down Expand Up @@ -182,6 +188,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
raise ex

def write_labels(self, labels, index=None):
"""Write annotation labels into document store."""

labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
index = index or self.label_index
Expand Down Expand Up @@ -221,6 +228,9 @@ def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str]
self.session.commit()

def update_document_meta(self, id: str, meta: Dict[str, str]):
"""
Update the metadata dictionary of a document by specifying its string id
"""
self.session.query(MetaORM).filter_by(document_id=id).delete()
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
for m in meta_orms:
Expand All @@ -244,6 +254,9 @@ def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_i
self.write_labels(labels, index=label_index)

def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
Return the number of documents in the document store.
"""
index = index or self.index
query = self.session.query(DocumentORM).filter_by(index=index)

Expand All @@ -256,6 +269,9 @@ def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, ind
return count

def get_label_count(self, index: Optional[str] = None) -> int:
    """
    Return the number of labels in the document store.

    :param index: Name of the label index to count. Defaults to `self.label_index`.
    :return: Number of label rows stored under that index.
    """
    # Labels are written and read under self.label_index (see write_labels and
    # get_all_labels); defaulting to the document index would count the wrong rows.
    index = index or self.label_index
    return self.session.query(LabelORM).filter_by(index=index).count()

Expand Down
5 changes: 5 additions & 0 deletions haystack/file_converter/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def __init__(self, remove_numeric_tables: Optional[bool] = False, valid_language
super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)

def convert(self, file_path: Path, meta: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
"""
Extract text from a .pdf file.
:param file_path: Path to the .pdf file you want to convert
"""

pages = self._read_pdf(file_path, layout=False)

Expand Down
3 changes: 3 additions & 0 deletions haystack/preprocessor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

class BasePreProcessor:
def process(self, document: dict) -> List[dict]:
    """
    Clean a single document, then split it.

    :param document: The document to process.
    :return: The list of documents produced by `split()` from the cleaned document.
    """
    return self.split(self.clean(document))
Expand Down
4 changes: 4 additions & 0 deletions haystack/preprocessor/cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@


def clean_wiki_text(text: str) -> str:
"""
Clean wikipedia text by removing multiple new lines, removing extremely short lines,
adding paragraph breaks and removing empty paragraphs
"""
# get rid of multiple new lines
while "\n\n" in text:
text = text.replace("\n\n", "\n")
Expand Down
8 changes: 8 additions & 0 deletions haystack/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def __init__(
self.split_respect_sentence_boundary = split_respect_sentence_boundary

def clean(self, document: dict) -> dict:
"""
Perform document cleaning on a single document and return a single document. This method will deal with whitespaces, headers, footers
and empty lines. Its exact functionality is defined by the parameters passed into PreProcessor.__init__().
"""
text = document["text"]
if self.clean_header_footer:
text = self._find_and_remove_header_footer(
Expand All @@ -74,6 +78,10 @@ def clean(self, document: dict) -> dict:
return document

def split(self, document: dict) -> List[dict]:
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
with different strides. It can also respect sentence boundaries. Its exact functionality is defined by
the parameters passed into PreProcessor.__init__(). Takes a single document as input and returns a list of documents. """

if not self.split_by:
return [document]

Expand Down
5 changes: 1 addition & 4 deletions haystack/reader/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ class TransformersReader(BaseReader):
Transformer based model for extractive Question Answering using the HuggingFace's transformers framework
(https://github.com/huggingface/transformers).
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
| With the reader, you can:
- directly get predictions via predict()
With this reader, you can directly get predictions via predict()
"""

def __init__(
Expand Down
1 change: 1 addition & 0 deletions haystack/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str
pass

def timing(self, fn):
"""Wrapper method used to time functions. """
@wraps(fn)
def wrapper(*args, **kwargs):
if "retrieve_time" not in self.__dict__:
Expand Down
Loading

0 comments on commit 9fbd845

Please sign in to comment.