Make tqdm progress bars optional (less verbose prod logs) #796

Merged (6 commits) on Feb 1, 2021
4 changes: 3 additions & 1 deletion docs/_src/api/api/document_store.md
@@ -782,7 +782,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_

```python
| __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", **kwargs)
| __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, **kwargs)
```

**Arguments**:
@@ -816,6 +816,8 @@ added already exists.
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
- `embedding_field`: Name of field containing an embedding vector.
- `progress_bar`: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.

<a name="faiss.FAISSDocumentStore.write_documents"></a>
#### write\_documents
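A minimal usage sketch of the new parameter documented above; the import path mirrors `haystack/document_store/faiss.py` from this PR, and the sample document is purely illustrative:

```python
from haystack.document_store.faiss import FAISSDocumentStore

# Disable the tqdm bar so production logs stay clean
document_store = FAISSDocumentStore(sql_url="sqlite:///", progress_bar=False)
document_store.write_documents([
    {"text": "Haystack makes neural search easy to deploy.", "meta": {"name": "intro"}}
])
```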
4 changes: 3 additions & 1 deletion docs/_src/api/api/reader.md
@@ -23,7 +23,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf
#### \_\_init\_\_

```python
| __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128)
| __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True)
```

**Arguments**:
@@ -59,6 +59,8 @@ multiprocessing. Set to None to let Inferencer determine optimum number. If you
want to debug the Language Model, you might need to disable multiprocessing!
- `max_seq_len`: Max sequence length of one input text for the model
- `doc_stride`: Length of striding window for splitting long texts (used if ``len(text) > max_seq_len``)
- `progress_bar`: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.

<a name="farm.FARMReader.train"></a>
#### train
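A minimal construction sketch for the new flag; the model name is an assumption (any extractive QA checkpoint from the Hugging Face hub should work):

```python
from haystack.reader.farm import FARMReader

# progress_bar=False silences tqdm output from the underlying inferencer
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", progress_bar=False)
```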
4 changes: 3 additions & 1 deletion docs/_src/api/api/retriever.md
@@ -225,7 +225,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_

```python
| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product")
| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", progress_bar: bool = True)
```

Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -264,6 +264,8 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
- `progress_bar`: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.

<a name="dense.DensePassageRetriever.retrieve"></a>
#### retrieve
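A sketch of pairing the retriever with a document store so that neither component prints a bar; the DPR checkpoints are left at the defaults from the signature above:

```python
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = FAISSDocumentStore(sql_url="sqlite:///", progress_bar=False)
retriever = DensePassageRetriever(document_store=document_store, progress_bar=False)

# Embeddings are computed without a "Creating Embeddings" bar in the logs
document_store.update_embeddings(retriever)
```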
7 changes: 6 additions & 1 deletion haystack/document_store/faiss.py
@@ -42,6 +42,7 @@ def __init__(
index: str = "document",
similarity: str = "dot_product",
embedding_field: str = "embedding",
progress_bar: bool = True,
**kwargs,
):
"""
@@ -74,6 +75,8 @@ def __init__(
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
:param embedding_field: Name of field containing an embedding vector.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""
self.vector_dim = vector_dim

@@ -91,6 +94,8 @@
else:
raise ValueError("The FAISS document store can currently only support dot_product similarity. "
"Please set similarity=\"dot_product\"")
self.progress_bar = progress_bar

super().__init__(
url=sql_url,
update_existing_documents=update_existing_documents,
@@ -189,7 +194,7 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non

result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
batched_documents = get_batches_from_generator(result, batch_size)
with tqdm(total=document_count) as progress_bar:
with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
for document_batch in batched_documents:
embeddings = retriever.embed_passages(document_batch) # type: ignore
assert len(document_batch) == len(embeddings)
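For context on the `disable=not self.progress_bar` call above: tqdm's `disable` flag is the inverse of "show the bar", so the user-facing setting has to be negated. A standalone sketch, independent of Haystack:

```python
from tqdm import tqdm

def embed_all(documents, progress_bar: bool = True):
    # disable=True hides the bar, hence the negation of the user-facing flag
    with tqdm(total=len(documents), disable=not progress_bar) as bar:
        for doc in documents:
            ...  # compute and store the embedding for this document
            bar.update(1)
```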
8 changes: 6 additions & 2 deletions haystack/document_store/milvus.py
@@ -49,6 +49,7 @@ def __init__(
update_existing_documents: bool = False,
return_embedding: bool = False,
embedding_field: str = "embedding",
progress_bar: bool = True,
**kwargs,
):
"""
@@ -90,6 +91,8 @@
added already exists.
:param return_embedding: To return document embedding.
:param embedding_field: Name of field containing an embedding vector.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
self.vector_dim = vector_dim
@@ -108,6 +111,7 @@ def __init__(
self._create_collection_and_index_if_not_exist(self.index)
self.return_embedding = return_embedding
self.embedding_field = embedding_field
self.progress_bar = progress_bar

super().__init__(
url=sql_url,
@@ -173,7 +177,7 @@ def write_documents(
add_vectors = False if document_objects[0].embedding is None else True

batched_documents = get_batches_from_generator(document_objects, batch_size)
with tqdm(total=len(document_objects)) as progress_bar:
with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar:
for document_batch in batched_documents:
vector_ids = []
if add_vectors:
@@ -234,7 +238,7 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non

result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
batched_documents = get_batches_from_generator(result, batch_size)
with tqdm(total=document_count) as progress_bar:
with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
for document_batch in batched_documents:
self._delete_vector_ids_from_milvus(documents=document_batch, index=index)

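Both call sites above follow the same pattern: batch a generator, then advance a manually created bar once per batch. A generic sketch of that pattern (the `batched` helper here is a stand-in, not Haystack's `get_batches_from_generator`):

```python
from itertools import islice
from tqdm import tqdm

def batched(iterable, batch_size):
    iterator = iter(iterable)
    while batch := list(islice(iterator, batch_size)):
        yield batch

def write_in_batches(documents, batch_size=10_000, progress_bar=True):
    with tqdm(total=len(documents), disable=not progress_bar) as bar:
        for batch in batched(documents, batch_size):
            ...  # insert the batch into the vector store
            bar.update(len(batch))
```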
11 changes: 8 additions & 3 deletions haystack/reader/farm.py
@@ -52,6 +52,7 @@ def __init__(
num_processes: Optional[int] = None,
max_seq_len: int = 256,
doc_stride: int = 128,
progress_bar: bool = True
):

"""
@@ -86,14 +87,16 @@
want to debug the Language Model, you might need to disable multiprocessing!
:param max_seq_len: Max sequence length of one input text for the model
:param doc_stride: Length of striding window for splitting long texts (used if ``len(text) > max_seq_len``)

:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""

self.return_no_answers = return_no_answer
self.top_k_per_candidate = top_k_per_candidate
self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
task_type="question_answering", max_seq_len=max_seq_len,
doc_stride=doc_stride, num_processes=num_processes, revision=model_version)
task_type="question_answering", max_seq_len=max_seq_len,
doc_stride=doc_stride, num_processes=num_processes, revision=model_version,
disable_tqdm=not progress_bar)
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer
@@ -103,6 +106,7 @@ def __init__(
logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
self.max_seq_len = max_seq_len
self.use_gpu = use_gpu
self.progress_bar = progress_bar

def train(
self,
@@ -226,6 +230,7 @@ def train(
evaluate_every=evaluate_every,
device=device,
use_amp=use_amp,
disable_tqdm=not self.progress_bar
)


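Since the flag is stored on the instance and forwarded as `disable_tqdm` to both the inferencer and the trainer, one constructor argument quiets inference and fine-tuning alike. A hedged usage sketch; the data paths are placeholders and the training keywords follow typical FARMReader tutorial usage:

```python
from haystack.reader.farm import FARMReader

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", progress_bar=False)

# Fine-tuning also runs without tqdm bars cluttering the logs
reader.train(data_dir="data/squad20", train_filename="train-sample.json", n_epochs=1, use_gpu=False)
```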
15 changes: 13 additions & 2 deletions haystack/retriever/dense.py
@@ -46,7 +46,8 @@ def __init__(self,
batch_size: int = 16,
embed_title: bool = True,
use_fast_tokenizers: bool = True,
similarity_function: str = "dot_product"
similarity_function: str = "dot_product",
progress_bar: bool = True
):
"""
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -83,12 +84,15 @@
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""

self.document_store = document_store
self.batch_size = batch_size
self.max_seq_len_passage = max_seq_len_passage
self.max_seq_len_query = max_seq_len_query
self.progress_bar = progress_bar

if document_store is None:
logger.warning("DensePassageRetriever initialized without a document store. "
@@ -193,7 +197,14 @@ def _get_predictions(self, dicts):
)
all_embeddings = {"query": [], "passages": []}
self.model.eval()
for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=False)):

# When running evaluations etc., we don't want a progress bar for every single query
if len(dataset) == 1:
disable_tqdm = True
else:
disable_tqdm = not self.progress_bar

for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)):
batch = {key: batch[key].to(self.device) for key in batch}

# get logits
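A standalone sketch of the conditional above: single-batch runs (for example, one query during evaluation) never show a bar, while everything else follows the user's `progress_bar` setting:

```python
from tqdm import tqdm

def create_embeddings(batches, progress_bar: bool = True):
    # Hide the bar for single-batch runs; otherwise respect the user's choice
    disable_tqdm = True if len(batches) == 1 else not progress_bar
    for batch in tqdm(batches, desc="Creating Embeddings", unit=" Batches", disable=disable_tqdm):
        ...  # run the batch through the encoder
```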