From 399639a27c832b5522ae6a903f563da52e255a9a Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Wed, 11 Aug 2021 20:07:20 +0200 Subject: [PATCH 01/11] init --- haystack/eval.py | 453 +++++++++++++++++++----------- test/test_eval.py | 6 +- tutorials/Tutorial5_Evaluation.py | 2 +- 3 files changed, 290 insertions(+), 171 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index 42dc65684ec..f5b1aaebfaa 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -1,5 +1,9 @@ from typing import List, Tuple, Dict, Any, Optional import logging +from transformers import AutoConfig +from sentence_transformers import SentenceTransformer, CrossEncoder +from sklearn.metrics.pairwise import cosine_similarity +import numpy as np from haystack import MultiLabel, Label @@ -148,18 +152,29 @@ class EvalAnswers: open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md). """ - def __init__(self, skip_incorrect_retrieval: bool=True, open_domain: bool=True, debug: bool=False): + def __init__(self, + skip_incorrect_retrieval: bool=True, + open_domain: bool=True, + sas_model=None, + debug: bool=False, + ): """ :param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents :param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer + :param sas_model: Semantic Answer Similarity model string. When set, will use the model to calculate similarity between predictions and labels. + possible models can be sentence transformers or cross encoders trained on Semantic Textual Similarity (STS) data + "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - A good default for multiple languages + "cross-encoder/stsb-roberta-large" - large powerful but slow model for English only :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log """ self.outgoing_edges = 1 - self.init_counts() self.log: List = [] self.debug = debug self.skip_incorrect_retrieval = skip_incorrect_retrieval self.open_domain = open_domain + self.sas_model = sas_model + self.init_counts() + def init_counts(self): self.query_count = 0 @@ -176,6 +191,11 @@ def init_counts(self): self.top_k_em = 0.0 self.top_1_f1 = 0.0 self.top_k_f1 = 0.0 + if self.sas_model is not None: + self.top_1_sas_sum = 0 + self.top_k_sas_sum = 0 + self.top_1_sas = 0.0 + self.top_k_sas = 0.0 def run(self, labels, answers, **kwargs): """Run this node on one sample and its labels""" @@ -201,12 +221,26 @@ def run(self, labels, answers, **kwargs): self.has_answer_count += 1 predictions = [p for p in predictions if p["answer"]] top_1_em, top_1_f1, top_k_em, top_k_f1 = self.evaluate_extraction(multi_labels, predictions) + + # Compute Semantic Answer Similarity if present + if self.sas_model is not None: + gold_labels = [multi_labels.multiple_answers] + predictions_list = [[p["answer"] for p in predictions]] + top_1_sas, top_k_sas = semantic_answer_similarity( + predictions=predictions_list, + gold_labels=gold_labels, + sts_model_path_or_string=self.sas_model) + self.top_1_sas_sum += top_1_sas + self.top_k_sas_sum += top_k_sas + if self.debug: self.log.append({"predictions": predictions, "gold_labels": multi_labels, "top_k_f1": top_k_f1, "top_k_em": top_k_em }) + if self.sas_model: + self.log[-1].update({"top_k_sas":top_k_sas}) self.top_1_em_count += top_1_em self.top_1_f1_sum += top_1_f1 @@ -233,6 +267,9 @@ def update_has_answer_metrics(self): self.top_k_em = self.top_k_em_count / self.has_answer_count self.top_1_f1 = self.top_1_f1_sum / self.has_answer_count self.top_k_f1 = self.top_k_f1_sum / self.has_answer_count + if self.sas_model is not None: + self.top_1_sas = self.top_1_sas_sum / self.has_answer_count + self.top_k_sas = self.top_k_sas_sum / self.has_answer_count def update_no_answer_metrics(self): self.top_1_no_answer = self.top_1_no_answer_count / self.no_answer_count @@ -248,6 +285,9 @@ def print(self, mode): print(f"top k EM: {self.top_k_em:.4f}") print(f"top 1 F1: {self.top_1_f1:.4f}") print(f"top k F1: {self.top_k_f1:.4f}") + if self.sas_model is not None: + print(f"top 1 SAS: {self.top_1_sas:.4f}") + print(f"top k SAS: {self.top_k_sas:.4f}") if self.no_answer_count: print() print(f"no_answer queries: {self.no_answer_count}") @@ -266,6 +306,11 @@ def print(self, mode): print(f"top k EM: {pipeline_top_k_em:.4f}") print(f"top 1 F1: {pipeline_top_1_f1:.4f}") print(f"top k F1: {pipeline_top_k_f1:.4f}") + if self.sas_model is not None: + pipeline_top_1_sas = (self.top_1_sas_sum + self.top_1_no_answer_count) / self.query_count + pipeline_top_k_sas = (self.top_k_sas_sum + self.no_answer_count) / self.query_count + print(f"top 1 SAS: {pipeline_top_1_sas:.4f}") + print(f"top k SAS: {pipeline_top_k_sas:.4f}") if self.no_answer_count: print( "(top k results are likely inflated since the Reader always returns a no_answer prediction in its top k)" @@ -294,177 +339,246 @@ def calculate_f1_str_multi(gold_labels, prediction): results.append(result) return max(results) +# TODO delete? +# def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int): +# number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"] +# +# metrics = { +# "reader_top1_accuracy" : metric_counts["correct_readings_top1"] / correct_retrievals, +# "reader_top1_accuracy_has_answer" : metric_counts["correct_readings_top1_has_answer"] / number_of_has_answer, +# "reader_topk_accuracy" : metric_counts["correct_readings_topk"] / correct_retrievals, +# "reader_topk_accuracy_has_answer" : metric_counts["correct_readings_topk_has_answer"] / number_of_has_answer, +# "reader_top1_em" : metric_counts["exact_matches_top1"] / correct_retrievals, +# "reader_top1_em_has_answer" : metric_counts["exact_matches_top1_has_answer"] / number_of_has_answer, +# "reader_topk_em" : metric_counts["exact_matches_topk"] / correct_retrievals, +# "reader_topk_em_has_answer" : metric_counts["exact_matches_topk_has_answer"] / number_of_has_answer, +# "reader_top1_f1" : metric_counts["summed_f1_top1"] / correct_retrievals, +# "reader_top1_f1_has_answer" : metric_counts["summed_f1_top1_has_answer"] / number_of_has_answer, +# "reader_topk_f1" : metric_counts["summed_f1_topk"] / correct_retrievals, +# "reader_topk_f1_has_answer" : metric_counts["summed_f1_topk_has_answer"] / number_of_has_answer, +# } +# +# if metric_counts["number_of_no_answer"]: +# metrics["reader_top1_no_answer_accuracy"] = metric_counts["correct_no_answers_top1"] / metric_counts[ +# "number_of_no_answer"] +# metrics["reader_topk_no_answer_accuracy"] = metric_counts["correct_no_answers_topk"] / metric_counts[ +# "number_of_no_answer"] +# else: +# metrics["reader_top1_no_answer_accuracy"] = None # type: ignore +# metrics["reader_topk_no_answer_accuracy"] = None # type: ignore +# +# return metrics + +# TODO delete? +# def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]): +# questions_with_correct_doc = [] +# summed_avg_precision_retriever = 0.0 +# summed_reciprocal_rank_retriever = 0.0 +# +# for question in questions_with_docs: +# number_relevant_docs = len(set(question["question"].multiple_document_ids)) +# found_relevant_doc = False +# relevant_docs_found = 0 +# current_avg_precision = 0.0 +# for doc_idx, doc in enumerate(question["docs"]): +# # check if correct doc among retrieved docs +# if doc.id in question["question"].multiple_document_ids: +# if not found_relevant_doc: +# summed_reciprocal_rank_retriever += 1 / (doc_idx + 1) +# relevant_docs_found += 1 +# found_relevant_doc = True +# current_avg_precision += relevant_docs_found / (doc_idx + 1) +# if relevant_docs_found == number_relevant_docs: +# break +# if found_relevant_doc: +# all_relevant_docs = len(set(question["question"].multiple_document_ids)) +# summed_avg_precision_retriever += current_avg_precision / all_relevant_docs +# +# if found_relevant_doc: +# questions_with_correct_doc.append({ +# "question": question["question"], +# "docs": question["docs"] +# }) +# +# return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever + +# TODO delete? +# def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]): +# # Calculates evaluation metrics for one question and adds results to counter. +# # check if question is answerable +# if not question.no_answer: +# found_answer = False +# found_em = False +# best_f1 = 0 +# for answer_idx, answer in enumerate(predicted_answers["answers"]): +# if answer["document_id"] in question.multiple_document_ids: +# gold_spans = [{"offset_start": question.multiple_offset_start_in_docs[i], +# "offset_end": question.multiple_offset_start_in_docs[i] + len(question.multiple_answers[i]), +# "doc_id": question.multiple_document_ids[i]} for i in range(len(question.multiple_answers))] # type: ignore +# predicted_span = {"offset_start": answer["offset_start_in_doc"], +# "offset_end": answer["offset_end_in_doc"], +# "doc_id": answer["document_id"]} +# best_f1_in_gold_spans = 0 +# for gold_span in gold_spans: +# if gold_span["doc_id"] == predicted_span["doc_id"]: +# # check if overlap between gold answer and predicted answer +# if not found_answer: +# metric_counts, found_answer = _count_overlap(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore +# +# # check for exact match +# if not found_em: +# metric_counts, found_em = _count_exact_match(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore +# +# # calculate f1 +# current_f1 = _calculate_f1(gold_span, predicted_span) # type: ignore +# if current_f1 > best_f1_in_gold_spans: +# best_f1_in_gold_spans = current_f1 +# # top-1 f1 +# if answer_idx == 0: +# metric_counts["summed_f1_top1"] += best_f1_in_gold_spans +# metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans +# if best_f1_in_gold_spans > best_f1: +# best_f1 = best_f1_in_gold_spans +# +# if found_em: +# break +# # top-k answers: use best f1-score +# metric_counts["summed_f1_topk"] += best_f1 +# metric_counts["summed_f1_topk_has_answer"] += best_f1 +# +# # question not answerable +# else: +# metric_counts["number_of_no_answer"] += 1 +# metric_counts = _count_no_answer(predicted_answers["answers"], metric_counts) +# +# return metric_counts + +# TODO delete? +# def eval_counts_reader_batch(pred: Dict[str, Any], metric_counts: Dict[str, float]): +# # Calculates evaluation metrics for one question and adds results to counter. +# +# # check if question is answerable +# if not pred["label"].no_answer: +# found_answer = False +# found_em = False +# best_f1 = 0 +# for answer_idx, answer in enumerate(pred["answers"]): +# # check if correct document: +# if answer["document_id"] in pred["label"].multiple_document_ids: +# gold_spans = [{"offset_start": pred["label"].multiple_offset_start_in_docs[i], +# "offset_end": pred["label"].multiple_offset_start_in_docs[i] + len(pred["label"].multiple_answers[i]), +# "doc_id": pred["label"].multiple_document_ids[i]} +# for i in range(len(pred["label"].multiple_answers))] # type: ignore +# predicted_span = {"offset_start": answer["offset_start_in_doc"], +# "offset_end": answer["offset_end_in_doc"], +# "doc_id": answer["document_id"]} +# +# best_f1_in_gold_spans = 0 +# for gold_span in gold_spans: +# if gold_span["doc_id"] == predicted_span["doc_id"]: +# # check if overlap between gold answer and predicted answer +# if not found_answer: +# metric_counts, found_answer = _count_overlap( +# gold_span, predicted_span, metric_counts, answer_idx +# ) +# # check for exact match +# if not found_em: +# metric_counts, found_em = _count_exact_match( +# gold_span, predicted_span, metric_counts, answer_idx +# ) +# # calculate f1 +# current_f1 = _calculate_f1(gold_span, predicted_span) +# if current_f1 > best_f1_in_gold_spans: +# best_f1_in_gold_spans = current_f1 +# # top-1 f1 +# if answer_idx == 0: +# metric_counts["summed_f1_top1"] += best_f1_in_gold_spans +# metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans +# if best_f1_in_gold_spans > best_f1: +# best_f1 = best_f1_in_gold_spans +# +# if found_em: +# break +# +# # top-k answers: use best f1-score +# metric_counts["summed_f1_topk"] += best_f1 +# metric_counts["summed_f1_topk_has_answer"] += best_f1 +# +# # question not answerable +# else: +# metric_counts["number_of_no_answer"] += 1 +# metric_counts = _count_no_answer(pred["answers"], metric_counts) +# +# return metric_counts + + +def semantic_answer_similarity(predictions: List[List[str]], + gold_labels: List[List[str]], + sts_model_path_or_string: str="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"): + """ + Computes BERT based similarity of prediction to gold labels. + Returns per QA pair a) the similarity of the most likely prediction to all available gold labels + b) the highest similarity of all predictions to gold labels -def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int): - number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"] - - metrics = { - "reader_top1_accuracy" : metric_counts["correct_readings_top1"] / correct_retrievals, - "reader_top1_accuracy_has_answer" : metric_counts["correct_readings_top1_has_answer"] / number_of_has_answer, - "reader_topk_accuracy" : metric_counts["correct_readings_topk"] / correct_retrievals, - "reader_topk_accuracy_has_answer" : metric_counts["correct_readings_topk_has_answer"] / number_of_has_answer, - "reader_top1_em" : metric_counts["exact_matches_top1"] / correct_retrievals, - "reader_top1_em_has_answer" : metric_counts["exact_matches_top1_has_answer"] / number_of_has_answer, - "reader_topk_em" : metric_counts["exact_matches_topk"] / correct_retrievals, - "reader_topk_em_has_answer" : metric_counts["exact_matches_topk_has_answer"] / number_of_has_answer, - "reader_top1_f1" : metric_counts["summed_f1_top1"] / correct_retrievals, - "reader_top1_f1_has_answer" : metric_counts["summed_f1_top1_has_answer"] / number_of_has_answer, - "reader_topk_f1" : metric_counts["summed_f1_topk"] / correct_retrievals, - "reader_topk_f1_has_answer" : metric_counts["summed_f1_topk_has_answer"] / number_of_has_answer, - } - - if metric_counts["number_of_no_answer"]: - metrics["reader_top1_no_answer_accuracy"] = metric_counts["correct_no_answers_top1"] / metric_counts[ - "number_of_no_answer"] - metrics["reader_topk_no_answer_accuracy"] = metric_counts["correct_no_answers_topk"] / metric_counts[ - "number_of_no_answer"] - else: - metrics["reader_top1_no_answer_accuracy"] = None # type: ignore - metrics["reader_topk_no_answer_accuracy"] = None # type: ignore - - return metrics - - -def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]): - questions_with_correct_doc = [] - summed_avg_precision_retriever = 0.0 - summed_reciprocal_rank_retriever = 0.0 - - for question in questions_with_docs: - number_relevant_docs = len(set(question["question"].multiple_document_ids)) - found_relevant_doc = False - relevant_docs_found = 0 - current_avg_precision = 0.0 - for doc_idx, doc in enumerate(question["docs"]): - # check if correct doc among retrieved docs - if doc.id in question["question"].multiple_document_ids: - if not found_relevant_doc: - summed_reciprocal_rank_retriever += 1 / (doc_idx + 1) - relevant_docs_found += 1 - found_relevant_doc = True - current_avg_precision += relevant_docs_found / (doc_idx + 1) - if relevant_docs_found == number_relevant_docs: - break - if found_relevant_doc: - all_relevant_docs = len(set(question["question"].multiple_document_ids)) - summed_avg_precision_retriever += current_avg_precision / all_relevant_docs - - if found_relevant_doc: - questions_with_correct_doc.append({ - "question": question["question"], - "docs": question["docs"] - }) - - return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever - - -def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]): - # Calculates evaluation metrics for one question and adds results to counter. - # check if question is answerable - if not question.no_answer: - found_answer = False - found_em = False - best_f1 = 0 - for answer_idx, answer in enumerate(predicted_answers["answers"]): - if answer["document_id"] in question.multiple_document_ids: - gold_spans = [{"offset_start": question.multiple_offset_start_in_docs[i], - "offset_end": question.multiple_offset_start_in_docs[i] + len(question.multiple_answers[i]), - "doc_id": question.multiple_document_ids[i]} for i in range(len(question.multiple_answers))] # type: ignore - predicted_span = {"offset_start": answer["offset_start_in_doc"], - "offset_end": answer["offset_end_in_doc"], - "doc_id": answer["document_id"]} - best_f1_in_gold_spans = 0 - for gold_span in gold_spans: - if gold_span["doc_id"] == predicted_span["doc_id"]: - # check if overlap between gold answer and predicted answer - if not found_answer: - metric_counts, found_answer = _count_overlap(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore - - # check for exact match - if not found_em: - metric_counts, found_em = _count_exact_match(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore - - # calculate f1 - current_f1 = _calculate_f1(gold_span, predicted_span) # type: ignore - if current_f1 > best_f1_in_gold_spans: - best_f1_in_gold_spans = current_f1 - # top-1 f1 - if answer_idx == 0: - metric_counts["summed_f1_top1"] += best_f1_in_gold_spans - metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans - if best_f1_in_gold_spans > best_f1: - best_f1 = best_f1_in_gold_spans - - if found_em: - break - # top-k answers: use best f1-score - metric_counts["summed_f1_topk"] += best_f1 - metric_counts["summed_f1_topk_has_answer"] += best_f1 - - # question not answerable - else: - metric_counts["number_of_no_answer"] += 1 - metric_counts = _count_no_answer(predicted_answers["answers"], metric_counts) - - return metric_counts + :param predictions: Predictions as list of multiple preds per question + :param gold_labels: Labels as list of multiple possible answers per question + :param sts_model_path_or_string: SentenceTransformers semantic textual similarity model, should be path or string + pointing to downloadable models. + Returns the average of the semantically evaluated best, and Top N predictions as well as the List of + :return best_pred_similarity, all_preds_highest_similarity + """ + assert len(predictions) == len(gold_labels) + + config = AutoConfig.from_pretrained(sts_model_path_or_string) + cross_encoder_used = False + if config.architectures is not None: + cross_encoder_used = any([arch.endswith('ForSequenceClassification') for arch in config.architectures]) + + # Compute similarities + top_1_sim = [] + top_k_sim = [] + + + # Based on Modelstring we can load either Bi Encoders or Cross Encoders. + # Similarity computation changes for both approaches + if cross_encoder_used: + model = CrossEncoder(sts_model_path_or_string) + for preds,labels in zip (predictions,gold_labels): + # TODO put all texts and labels into grid and extract scores afterwards + grid = [] + for p in preds: + for l in labels: + grid.append((p,l)) + scores = model.predict(grid) + top_1_sim.append(np.max(scores[:len(labels)])) + top_k_sim.append(np.max(scores)) -def eval_counts_reader_batch(pred: Dict[str, Any], metric_counts: Dict[str, float]): - # Calculates evaluation metrics for one question and adds results to counter. - - # check if question is answerable - if not pred["label"].no_answer: - found_answer = False - found_em = False - best_f1 = 0 - for answer_idx, answer in enumerate(pred["answers"]): - # check if correct document: - if answer["document_id"] in pred["label"].multiple_document_ids: - gold_spans = [{"offset_start": pred["label"].multiple_offset_start_in_docs[i], - "offset_end": pred["label"].multiple_offset_start_in_docs[i] + len(pred["label"].multiple_answers[i]), - "doc_id": pred["label"].multiple_document_ids[i]} - for i in range(len(pred["label"].multiple_answers))] # type: ignore - predicted_span = {"offset_start": answer["offset_start_in_doc"], - "offset_end": answer["offset_end_in_doc"], - "doc_id": answer["document_id"]} - - best_f1_in_gold_spans = 0 - for gold_span in gold_spans: - if gold_span["doc_id"] == predicted_span["doc_id"]: - # check if overlap between gold answer and predicted answer - if not found_answer: - metric_counts, found_answer = _count_overlap( - gold_span, predicted_span, metric_counts, answer_idx - ) - # check for exact match - if not found_em: - metric_counts, found_em = _count_exact_match( - gold_span, predicted_span, metric_counts, answer_idx - ) - # calculate f1 - current_f1 = _calculate_f1(gold_span, predicted_span) - if current_f1 > best_f1_in_gold_spans: - best_f1_in_gold_spans = current_f1 - # top-1 f1 - if answer_idx == 0: - metric_counts["summed_f1_top1"] += best_f1_in_gold_spans - metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans - if best_f1_in_gold_spans > best_f1: - best_f1 = best_f1_in_gold_spans - - if found_em: - break - - # top-k answers: use best f1-score - metric_counts["summed_f1_topk"] += best_f1 - metric_counts["summed_f1_topk_has_answer"] += best_f1 - - # question not answerable else: - metric_counts["number_of_no_answer"] += 1 - metric_counts = _count_no_answer(pred["answers"], metric_counts) - - return metric_counts + # For Biencoders we can flatten predictions and labels into one list + model = SentenceTransformer(sts_model_path_or_string) + lengths = [] + all_texts = [] + for p, l in zip(predictions, gold_labels): + # TODO potentially exclude (near) exact matches from computations + all_texts.extend(p) + all_texts.extend(l) + lengths.append((len(p), len(l))) + # then compute embeddings + embeddings = model.encode(all_texts) + + # then select which embeddings will be used for similarity computations + current_position = 0 + for i, (len_p, len_l) in enumerate(lengths): + pred_embeddings = embeddings[current_position:current_position + len_p, :] + current_position += len_p + label_embeddings = embeddings[current_position:current_position + len_l, :] + current_position += len_l + sims = cosine_similarity(pred_embeddings, label_embeddings) + top_1_sim.append(np.max(sims[0, :])) + top_k_sim.append(np.max(sims)) + + return np.mean(top_1_sim), np.mean(top_k_sim) def _count_overlap( @@ -554,3 +668,4 @@ def _count_no_answer(answers: List[dict], metric_counts: Dict[str, float]): break return metric_counts + diff --git a/test/test_eval.py b/test/test_eval.py index 1842c0ae7ec..33c710f7cfc 100644 --- a/test/test_eval.py +++ b/test/test_eval.py @@ -106,7 +106,8 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): labels = document_store.get_all_labels_aggregated(index="haystack_test_feedback") eval_retriever = EvalDocuments() - eval_reader = EvalAnswers() + eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-MiniLM-L3-v2",debug=True) + eval_reader_cross = EvalAnswers(sas_model="cross-encoder/stsb-TinyBERT-L-4",debug=True) assert document_store.get_document_count(index="haystack_test_eval_document") == 2 p = Pipeline() @@ -114,6 +115,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) + p.add_node(component=eval_reader_cross, name="EvalAnswers_cross", inputs=["QAReader"]) for l in labels: res = p.run( query=l.question, @@ -125,6 +127,8 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): assert eval_retriever.recall == 1.0 assert round(eval_reader.top_k_f1, 4) == 0.8333 assert eval_reader.top_k_em == 0.5 + assert round(eval_reader.top_k_sas, 3) == 0.800 + assert round(eval_reader_cross.top_k_sas, 3) == 0.671 @pytest.mark.elasticsearch def test_eval_data_split_word(document_store): diff --git a/tutorials/Tutorial5_Evaluation.py b/tutorials/Tutorial5_Evaluation.py index 1c8f2dcc673..3af9f93fe3b 100644 --- a/tutorials/Tutorial5_Evaluation.py +++ b/tutorials/Tutorial5_Evaluation.py @@ -99,7 +99,7 @@ def tutorial5_evaluation(): # Here we initialize the nodes that perform evaluation eval_retriever = EvalDocuments() - eval_reader = EvalAnswers() + eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") ## Evaluate Retriever on its own in closed domain fashion From 33d4cdd1e0e6160617c0e64f6d988ef507990273 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 09:44:30 +0200 Subject: [PATCH 02/11] Add type annotation --- haystack/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index f5b1aaebfaa..dc16c2b0e7b 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -557,8 +557,8 @@ def semantic_answer_similarity(predictions: List[List[str]], else: # For Biencoders we can flatten predictions and labels into one list model = SentenceTransformer(sts_model_path_or_string) - lengths = [] - all_texts = [] + lengths:List[Tuple[int,int]] = [] + all_texts:List[str] = [] for p, l in zip(predictions, gold_labels): # TODO potentially exclude (near) exact matches from computations all_texts.extend(p) From 8d696b076b2f8178ec6dfc8c2804b1a683bc1355 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 09:52:16 +0200 Subject: [PATCH 03/11] Add test case, fix mypy --- haystack/eval.py | 2 +- test/test_eval.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/haystack/eval.py b/haystack/eval.py index dc16c2b0e7b..13bb57be8ec 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -559,7 +559,7 @@ def semantic_answer_similarity(predictions: List[List[str]], model = SentenceTransformer(sts_model_path_or_string) lengths:List[Tuple[int,int]] = [] all_texts:List[str] = [] - for p, l in zip(predictions, gold_labels): + for p, l in zip(predictions, gold_labels): # type: ignore # TODO potentially exclude (near) exact matches from computations all_texts.extend(p) all_texts.extend(l) diff --git a/test/test_eval.py b/test/test_eval.py index 33c710f7cfc..287fd33a9ec 100644 --- a/test/test_eval.py +++ b/test/test_eval.py @@ -108,6 +108,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): eval_retriever = EvalDocuments() eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-MiniLM-L3-v2",debug=True) eval_reader_cross = EvalAnswers(sas_model="cross-encoder/stsb-TinyBERT-L-4",debug=True) + eval_reader_vanila = EvalAnswers() assert document_store.get_document_count(index="haystack_test_eval_document") == 2 p = Pipeline() @@ -116,6 +117,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) p.add_node(component=eval_reader_cross, name="EvalAnswers_cross", inputs=["QAReader"]) + p.add_node(component=eval_reader_vanila, name="EvalAnswers_vanilla", inputs=["QAReader"]) for l in labels: res = p.run( query=l.question, @@ -129,6 +131,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): assert eval_reader.top_k_em == 0.5 assert round(eval_reader.top_k_sas, 3) == 0.800 assert round(eval_reader_cross.top_k_sas, 3) == 0.671 + assert eval_reader.top_k_em == eval_reader_vanila.top_k_em @pytest.mark.elasticsearch def test_eval_data_split_word(document_store): From 5b9be74a0352f50c68f00d7bc54d140d593a0ced Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Thu, 12 Aug 2021 11:21:28 +0200 Subject: [PATCH 04/11] adjust docstrings. rename model path --- haystack/eval.py | 48 +++++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index 13bb57be8ec..7d3f40dbf07 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -153,18 +153,21 @@ class EvalAnswers: """ def __init__(self, - skip_incorrect_retrieval: bool=True, - open_domain: bool=True, - sas_model=None, - debug: bool=False, + skip_incorrect_retrieval: bool = True, + open_domain: bool = True, + sas_model: str = None, + debug: bool = False, ): """ :param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents :param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer - :param sas_model: Semantic Answer Similarity model string. When set, will use the model to calculate similarity between predictions and labels. - possible models can be sentence transformers or cross encoders trained on Semantic Textual Similarity (STS) data - "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - A good default for multiple languages - "cross-encoder/stsb-roberta-large" - large powerful but slow model for English only + :param sas_model: Name or path of "Semantic Answer Similarity (SAS) model". When set, the model will be used to calculate similarity between predictions and labels and generate the SAS metric. + The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps. + Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture. + Models: + - You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data + - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large" :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log """ self.outgoing_edges = 1 @@ -175,7 +178,6 @@ def __init__(self, self.sas_model = sas_model self.init_counts() - def init_counts(self): self.query_count = 0 self.correct_retrieval_count = 0 @@ -229,7 +231,7 @@ def run(self, labels, answers, **kwargs): top_1_sas, top_k_sas = semantic_answer_similarity( predictions=predictions_list, gold_labels=gold_labels, - sts_model_path_or_string=self.sas_model) + sas_model_name_or_path=self.sas_model) self.top_1_sas_sum += top_1_sas self.top_k_sas_sum += top_k_sas @@ -514,15 +516,15 @@ def calculate_f1_str_multi(gold_labels, prediction): def semantic_answer_similarity(predictions: List[List[str]], gold_labels: List[List[str]], - sts_model_path_or_string: str="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"): + sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2") -> (np.float32, np.float32): """ - Computes BERT based similarity of prediction to gold labels. - Returns per QA pair a) the similarity of the most likely prediction to all available gold labels + Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1. + Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels b) the highest similarity of all predictions to gold labels - :param predictions: Predictions as list of multiple preds per question + :param predictions: Predicted answers as list of multiple preds per question :param gold_labels: Labels as list of multiple possible answers per question - :param sts_model_path_or_string: SentenceTransformers semantic textual similarity model, should be path or string + :param sas_model_name_or_path: SentenceTransformers semantic textual similarity model, should be path or string pointing to downloadable models. Returns the average of the semantically evaluated best, and Top N predictions as well as the List of @@ -530,7 +532,7 @@ def semantic_answer_similarity(predictions: List[List[str]], """ assert len(predictions) == len(gold_labels) - config = AutoConfig.from_pretrained(sts_model_path_or_string) + config = AutoConfig.from_pretrained(sas_model_name_or_path) cross_encoder_used = False if config.architectures is not None: cross_encoder_used = any([arch.endswith('ForSequenceClassification') for arch in config.architectures]) @@ -540,11 +542,11 @@ def semantic_answer_similarity(predictions: List[List[str]], top_k_sim = [] - # Based on Modelstring we can load either Bi Encoders or Cross Encoders. + # Based on Modelstring we can load either Bi-Encoders or Cross Encoders. # Similarity computation changes for both approaches if cross_encoder_used: - model = CrossEncoder(sts_model_path_or_string) - for preds,labels in zip (predictions,gold_labels): + model = CrossEncoder(sas_model_name_or_path) + for preds, labels in zip (predictions,gold_labels): # TODO put all texts and labels into grid and extract scores afterwards grid = [] for p in preds: @@ -555,10 +557,10 @@ def semantic_answer_similarity(predictions: List[List[str]], top_k_sim.append(np.max(scores)) else: - # For Biencoders we can flatten predictions and labels into one list - model = SentenceTransformer(sts_model_path_or_string) - lengths:List[Tuple[int,int]] = [] - all_texts:List[str] = [] + # For Bi-encoders we can flatten predictions and labels into one list + model = SentenceTransformer(sas_model_name_or_path) + lengths: List[Tuple[int,int]] = [] + all_texts: List[str] = [] for p, l in zip(predictions, gold_labels): # type: ignore # TODO potentially exclude (near) exact matches from computations all_texts.extend(p) From 86f83cf776d502a043d877b40df7941083968537 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Thu, 12 Aug 2021 11:31:59 +0200 Subject: [PATCH 05/11] satisfy mypy --- haystack/eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/haystack/eval.py b/haystack/eval.py index 7d3f40dbf07..6323c4536ca 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -516,7 +516,8 @@ def calculate_f1_str_multi(gold_labels, prediction): def semantic_answer_similarity(predictions: List[List[str]], gold_labels: List[List[str]], - sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2") -> (np.float32, np.float32): + sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + ) -> Tuple[np.float32, np.float32]: """ Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1. Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels From 73f92423fb6748cb7c00181559138eecfbeb0de8 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 11:41:25 +0200 Subject: [PATCH 06/11] Change return type of sas function + docstring --- haystack/eval.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index 13bb57be8ec..40eb6369287 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -222,16 +222,17 @@ def run(self, labels, answers, **kwargs): predictions = [p for p in predictions if p["answer"]] top_1_em, top_1_f1, top_k_em, top_k_f1 = self.evaluate_extraction(multi_labels, predictions) - # Compute Semantic Answer Similarity if present + # Compute Semantic Answer Similarity if model is supplied if self.sas_model is not None: + # sas works on batches, so we pack the labels into a list of lists, and unpack the return values as well gold_labels = [multi_labels.multiple_answers] predictions_list = [[p["answer"] for p in predictions]] top_1_sas, top_k_sas = semantic_answer_similarity( predictions=predictions_list, gold_labels=gold_labels, sts_model_path_or_string=self.sas_model) - self.top_1_sas_sum += top_1_sas - self.top_k_sas_sum += top_k_sas + self.top_1_sas_sum += top_1_sas[0] + self.top_k_sas_sum += top_k_sas[0] if self.debug: self.log.append({"predictions": predictions, @@ -517,7 +518,7 @@ def semantic_answer_similarity(predictions: List[List[str]], sts_model_path_or_string: str="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"): """ Computes BERT based similarity of prediction to gold labels. - Returns per QA pair a) the similarity of the most likely prediction to all available gold labels + Returns per QA pair a) the similarity of the most likely prediction to all corresponding gold labels b) the highest similarity of all predictions to gold labels :param predictions: Predictions as list of multiple preds per question @@ -525,8 +526,8 @@ def semantic_answer_similarity(predictions: List[List[str]], :param sts_model_path_or_string: SentenceTransformers semantic textual similarity model, should be path or string pointing to downloadable models. - Returns the average of the semantically evaluated best, and Top N predictions as well as the List of - :return best_pred_similarity, all_preds_highest_similarity + + :return top_1_sas, top_k_sas """ assert len(predictions) == len(gold_labels) @@ -536,8 +537,8 @@ def semantic_answer_similarity(predictions: List[List[str]], cross_encoder_used = any([arch.endswith('ForSequenceClassification') for arch in config.architectures]) # Compute similarities - top_1_sim = [] - top_k_sim = [] + top_1_sas = [] + top_k_sas = [] # Based on Modelstring we can load either Bi Encoders or Cross Encoders. @@ -545,14 +546,14 @@ def semantic_answer_similarity(predictions: List[List[str]], if cross_encoder_used: model = CrossEncoder(sts_model_path_or_string) for preds,labels in zip (predictions,gold_labels): - # TODO put all texts and labels into grid and extract scores afterwards + # TODO add efficient batch mode: put all texts and labels into grid, predict, and extract scores afterwards grid = [] for p in preds: for l in labels: grid.append((p,l)) scores = model.predict(grid) - top_1_sim.append(np.max(scores[:len(labels)])) - top_k_sim.append(np.max(scores)) + top_1_sas.append(np.max(scores[:len(labels)])) + top_k_sas.append(np.max(scores)) else: # For Biencoders we can flatten predictions and labels into one list @@ -575,10 +576,10 @@ def semantic_answer_similarity(predictions: List[List[str]], label_embeddings = embeddings[current_position:current_position + len_l, :] current_position += len_l sims = cosine_similarity(pred_embeddings, label_embeddings) - top_1_sim.append(np.max(sims[0, :])) - top_k_sim.append(np.max(sims)) + top_1_sas.append(np.max(sims[0, :])) + top_k_sas.append(np.max(sims)) - return np.mean(top_1_sim), np.mean(top_k_sim) + return top_1_sas, top_k_sas def _count_overlap( From c6a38c879621270fdb487d57bbc39cd4529efd2f Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 11:46:22 +0200 Subject: [PATCH 07/11] Adjust return type --- haystack/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/eval.py b/haystack/eval.py index e19451df284..23258c8bb43 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -518,7 +518,7 @@ def calculate_f1_str_multi(gold_labels, prediction): def semantic_answer_similarity(predictions: List[List[str]], gold_labels: List[List[str]], sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) -> Tuple[np.float32, np.float32]: + ) -> Tuple[List[float],List[float]]: """ Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1. Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels From 3a7abbb01c84c8132ef40316f948a9ba8dba317c Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 11:57:47 +0200 Subject: [PATCH 08/11] Adjust docstring --- haystack/eval.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/haystack/eval.py b/haystack/eval.py index 23258c8bb43..0955b4bdd19 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -165,7 +165,9 @@ def __init__(self, The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps. Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture. Models: - - You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data + - You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data. + Not all cross encoders can be used because of different return types. + If you use custom cross encoders please make sure they work with sentence_transformers.CrossEncoder class - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large" :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log From d80d364ce22bb6234c59b961c3fee35d1f7b06a9 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 11:58:50 +0200 Subject: [PATCH 09/11] Add german model to docstring --- haystack/eval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/haystack/eval.py b/haystack/eval.py index 0955b4bdd19..b7a0d6ca207 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -170,6 +170,7 @@ def __init__(self, If you use custom cross encoders please make sure they work with sentence_transformers.CrossEncoder class - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large" + - Large model for German only: "deepset/gbert-large-sts" :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log """ self.outgoing_edges = 1 From 66d0fae52e3d3be948b3ef49e1082df887932797 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 12:00:08 +0200 Subject: [PATCH 10/11] Delete unsused fcts --- haystack/eval.py | 172 ----------------------------------------------- 1 file changed, 172 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index b7a0d6ca207..3b149705c66 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -345,178 +345,6 @@ def calculate_f1_str_multi(gold_labels, prediction): results.append(result) return max(results) -# TODO delete? -# def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int): -# number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"] -# -# metrics = { -# "reader_top1_accuracy" : metric_counts["correct_readings_top1"] / correct_retrievals, -# "reader_top1_accuracy_has_answer" : metric_counts["correct_readings_top1_has_answer"] / number_of_has_answer, -# "reader_topk_accuracy" : metric_counts["correct_readings_topk"] / correct_retrievals, -# "reader_topk_accuracy_has_answer" : metric_counts["correct_readings_topk_has_answer"] / number_of_has_answer, -# "reader_top1_em" : metric_counts["exact_matches_top1"] / correct_retrievals, -# "reader_top1_em_has_answer" : metric_counts["exact_matches_top1_has_answer"] / number_of_has_answer, -# "reader_topk_em" : metric_counts["exact_matches_topk"] / correct_retrievals, -# "reader_topk_em_has_answer" : metric_counts["exact_matches_topk_has_answer"] / number_of_has_answer, -# "reader_top1_f1" : metric_counts["summed_f1_top1"] / correct_retrievals, -# "reader_top1_f1_has_answer" : metric_counts["summed_f1_top1_has_answer"] / number_of_has_answer, -# "reader_topk_f1" : metric_counts["summed_f1_topk"] / correct_retrievals, -# "reader_topk_f1_has_answer" : metric_counts["summed_f1_topk_has_answer"] / number_of_has_answer, -# } -# -# if metric_counts["number_of_no_answer"]: -# metrics["reader_top1_no_answer_accuracy"] = metric_counts["correct_no_answers_top1"] / metric_counts[ -# "number_of_no_answer"] -# metrics["reader_topk_no_answer_accuracy"] = metric_counts["correct_no_answers_topk"] / metric_counts[ -# "number_of_no_answer"] -# else: -# metrics["reader_top1_no_answer_accuracy"] = None # type: ignore -# metrics["reader_topk_no_answer_accuracy"] = None # type: ignore -# -# return metrics - -# TODO delete? -# def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]): -# questions_with_correct_doc = [] -# summed_avg_precision_retriever = 0.0 -# summed_reciprocal_rank_retriever = 0.0 -# -# for question in questions_with_docs: -# number_relevant_docs = len(set(question["question"].multiple_document_ids)) -# found_relevant_doc = False -# relevant_docs_found = 0 -# current_avg_precision = 0.0 -# for doc_idx, doc in enumerate(question["docs"]): -# # check if correct doc among retrieved docs -# if doc.id in question["question"].multiple_document_ids: -# if not found_relevant_doc: -# summed_reciprocal_rank_retriever += 1 / (doc_idx + 1) -# relevant_docs_found += 1 -# found_relevant_doc = True -# current_avg_precision += relevant_docs_found / (doc_idx + 1) -# if relevant_docs_found == number_relevant_docs: -# break -# if found_relevant_doc: -# all_relevant_docs = len(set(question["question"].multiple_document_ids)) -# summed_avg_precision_retriever += current_avg_precision / all_relevant_docs -# -# if found_relevant_doc: -# questions_with_correct_doc.append({ -# "question": question["question"], -# "docs": question["docs"] -# }) -# -# return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever - -# TODO delete? -# def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: Dict[str, float]): -# # Calculates evaluation metrics for one question and adds results to counter. -# # check if question is answerable -# if not question.no_answer: -# found_answer = False -# found_em = False -# best_f1 = 0 -# for answer_idx, answer in enumerate(predicted_answers["answers"]): -# if answer["document_id"] in question.multiple_document_ids: -# gold_spans = [{"offset_start": question.multiple_offset_start_in_docs[i], -# "offset_end": question.multiple_offset_start_in_docs[i] + len(question.multiple_answers[i]), -# "doc_id": question.multiple_document_ids[i]} for i in range(len(question.multiple_answers))] # type: ignore -# predicted_span = {"offset_start": answer["offset_start_in_doc"], -# "offset_end": answer["offset_end_in_doc"], -# "doc_id": answer["document_id"]} -# best_f1_in_gold_spans = 0 -# for gold_span in gold_spans: -# if gold_span["doc_id"] == predicted_span["doc_id"]: -# # check if overlap between gold answer and predicted answer -# if not found_answer: -# metric_counts, found_answer = _count_overlap(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore -# -# # check for exact match -# if not found_em: -# metric_counts, found_em = _count_exact_match(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore -# -# # calculate f1 -# current_f1 = _calculate_f1(gold_span, predicted_span) # type: ignore -# if current_f1 > best_f1_in_gold_spans: -# best_f1_in_gold_spans = current_f1 -# # top-1 f1 -# if answer_idx == 0: -# metric_counts["summed_f1_top1"] += best_f1_in_gold_spans -# metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans -# if best_f1_in_gold_spans > best_f1: -# best_f1 = best_f1_in_gold_spans -# -# if found_em: -# break -# # top-k answers: use best f1-score -# metric_counts["summed_f1_topk"] += best_f1 -# metric_counts["summed_f1_topk_has_answer"] += best_f1 -# -# # question not answerable -# else: -# metric_counts["number_of_no_answer"] += 1 -# metric_counts = _count_no_answer(predicted_answers["answers"], metric_counts) -# -# return metric_counts - -# TODO delete? -# def eval_counts_reader_batch(pred: Dict[str, Any], metric_counts: Dict[str, float]): -# # Calculates evaluation metrics for one question and adds results to counter. -# -# # check if question is answerable -# if not pred["label"].no_answer: -# found_answer = False -# found_em = False -# best_f1 = 0 -# for answer_idx, answer in enumerate(pred["answers"]): -# # check if correct document: -# if answer["document_id"] in pred["label"].multiple_document_ids: -# gold_spans = [{"offset_start": pred["label"].multiple_offset_start_in_docs[i], -# "offset_end": pred["label"].multiple_offset_start_in_docs[i] + len(pred["label"].multiple_answers[i]), -# "doc_id": pred["label"].multiple_document_ids[i]} -# for i in range(len(pred["label"].multiple_answers))] # type: ignore -# predicted_span = {"offset_start": answer["offset_start_in_doc"], -# "offset_end": answer["offset_end_in_doc"], -# "doc_id": answer["document_id"]} -# -# best_f1_in_gold_spans = 0 -# for gold_span in gold_spans: -# if gold_span["doc_id"] == predicted_span["doc_id"]: -# # check if overlap between gold answer and predicted answer -# if not found_answer: -# metric_counts, found_answer = _count_overlap( -# gold_span, predicted_span, metric_counts, answer_idx -# ) -# # check for exact match -# if not found_em: -# metric_counts, found_em = _count_exact_match( -# gold_span, predicted_span, metric_counts, answer_idx -# ) -# # calculate f1 -# current_f1 = _calculate_f1(gold_span, predicted_span) -# if current_f1 > best_f1_in_gold_spans: -# best_f1_in_gold_spans = current_f1 -# # top-1 f1 -# if answer_idx == 0: -# metric_counts["summed_f1_top1"] += best_f1_in_gold_spans -# metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans -# if best_f1_in_gold_spans > best_f1: -# best_f1 = best_f1_in_gold_spans -# -# if found_em: -# break -# -# # top-k answers: use best f1-score -# metric_counts["summed_f1_topk"] += best_f1 -# metric_counts["summed_f1_topk_has_answer"] += best_f1 -# -# # question not answerable -# else: -# metric_counts["number_of_no_answer"] += 1 -# metric_counts = _count_no_answer(pred["answers"], metric_counts) -# -# return metric_counts - def semantic_answer_similarity(predictions: List[List[str]], gold_labels: List[List[str]], From e1d6c64c449ca1fe11f52473fb8e35c91377f2c5 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 12:03:03 +0200 Subject: [PATCH 11/11] reformat --- haystack/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index 3b149705c66..7fe9e3dedb3 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -322,6 +322,7 @@ def print(self, mode): "(top k results are likely inflated since the Reader always returns a no_answer prediction in its top k)" ) + def get_label(labels, node_id): if type(labels) in [Label, MultiLabel]: ret = labels @@ -330,6 +331,7 @@ def get_label(labels, node_id): ret = labels[node_id] return ret + def calculate_em_str_multi(gold_labels, prediction): for gold_label in gold_labels: result = calculate_em_str(gold_label, prediction) @@ -374,7 +376,6 @@ def semantic_answer_similarity(predictions: List[List[str]], top_1_sas = [] top_k_sas = [] - # Based on Modelstring we can load either Bi-Encoders or Cross Encoders. # Similarity computation changes for both approaches if cross_encoder_used: @@ -388,7 +389,6 @@ def semantic_answer_similarity(predictions: List[List[str]], scores = model.predict(grid) top_1_sas.append(np.max(scores[:len(labels)])) top_k_sas.append(np.max(scores)) - else: # For Bi-encoders we can flatten predictions and labels into one list model = SentenceTransformer(sas_model_name_or_path)