From a89ec79eb1d419b62ed895e053aa6c865f34aa8e Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Fri, 26 Feb 2021 10:35:08 +0100 Subject: [PATCH] fix boolean for disabling tqdm progressbar --- haystack/document_store/faiss.py | 2 +- haystack/document_store/memory.py | 2 +- haystack/document_store/milvus.py | 4 ++-- haystack/reader/farm.py | 4 ++-- haystack/retriever/dense.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py index dd45c4821d2..9c202e0ba77 100644 --- a/haystack/document_store/faiss.py +++ b/haystack/document_store/faiss.py @@ -210,7 +210,7 @@ def update_embeddings( only_documents_without_embedding=not update_existing_embeddings ) batched_documents = get_batches_from_generator(result, batch_size) - with tqdm(total=document_count, disable=self.progress_bar) as progress_bar: + with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: embeddings = retriever.embed_passages(document_batch) # type: ignore assert len(document_batch) == len(embeddings) diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py index 8499a99a430..721b637d6f5 100644 --- a/haystack/document_store/memory.py +++ b/haystack/document_store/memory.py @@ -200,7 +200,7 @@ def update_embeddings( document_count = len(result) logger.info(f"Updating embeddings for {document_count} docs ...") batched_documents = get_batches_from_generator(result, batch_size) - with tqdm(total=document_count, disable=self.progress_bar) as progress_bar: + with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: embeddings = retriever.embed_passages(document_batch) # type: ignore assert len(document_batch) == len(embeddings) diff --git a/haystack/document_store/milvus.py b/haystack/document_store/milvus.py index 6552e7f3f82..90e64050a93 100644 --- a/haystack/document_store/milvus.py +++ b/haystack/document_store/milvus.py @@ -177,7 +177,7 @@ def write_documents( add_vectors = False if document_objects[0].embedding is None else True batched_documents = get_batches_from_generator(document_objects, batch_size) - with tqdm(total=len(document_objects), disable=self.progress_bar) as progress_bar: + with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: vector_ids = [] if add_vectors: @@ -257,7 +257,7 @@ def update_embeddings( only_documents_without_embedding=not update_existing_embeddings ) batched_documents = get_batches_from_generator(result, batch_size) - with tqdm(total=document_count, disable=self.progress_bar) as progress_bar: + with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: self._delete_vector_ids_from_milvus(documents=document_batch, index=index) diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index c8a6e665b69..ba02c530fb0 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -96,7 +96,7 @@ def __init__( self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering", max_seq_len=max_seq_len, doc_stride=doc_stride, num_processes=num_processes, revision=model_version, - disable_tqdm=progress_bar) + disable_tqdm=not progress_bar) self.inferencer.model.prediction_heads[0].context_window_size = context_window_size self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer @@ -230,7 +230,7 @@ def train( evaluate_every=evaluate_every, device=device, use_amp=use_amp, - disable_tqdm=self.progress_bar + disable_tqdm=not self.progress_bar ) diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py index 71674bd71ca..4acd1edd46e 100644 --- a/haystack/retriever/dense.py +++ b/haystack/retriever/dense.py @@ -213,7 +213,7 @@ def _get_predictions(self, dicts): if len(dataset) == 1: disable_tqdm=True else: - disable_tqdm = self.progress_bar + disable_tqdm = not self.progress_bar for i, batch in enumerate(tqdm(data_loader, desc=f"Creating Embeddings", unit=" Batches", disable=disable_tqdm)): batch = {key: batch[key].to(self.device) for key in batch}