Add AnswerExactMatchEvaluator
#7050
Merged
Changes from 3 commits
@@ -0,0 +1,3 @@
from .answer_exact_match import AnswerExactMatchEvaluator

__all__ = ["AnswerExactMatchEvaluator"]
@@ -0,0 +1,46 @@
from typing import Any, Dict, List

from haystack import default_from_dict, default_to_dict
from haystack.core.component import component


@component
class AnswerExactMatchEvaluator:
    """
    Evaluator that checks if the predicted answers match any of the ground truth answers exactly.
    """

    def to_dict(self) -> Dict[str, Any]:
        return default_to_dict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnswerExactMatchEvaluator":
        return default_from_dict(cls, data)

    @component.output_types(result=float)
    def run(
        self, questions: List[str], ground_truth_answers: List[List[str]], predicted_answers: List[List[str]]
    ) -> Dict[str, float]:
        """
        Run the AnswerExactMatchEvaluator on the given inputs.
        All lists must have the same length.

        :param questions: A list of questions.
        :param ground_truth_answers: A list of expected answers for each question.
        :param predicted_answers: A list of predicted answers for each question.
        :returns: A dictionary with the following outputs:
            * `result` - A number from 0.0 to 1.0 representing the proportion of questions where any
              predicted answer matched one of the ground truth answers.
        """
        if not len(questions) == len(ground_truth_answers) == len(predicted_answers):
            raise ValueError("The length of questions, ground_truth_answers, and predicted_answers must be the same.")

        matches = 0
        for truths, extracted in zip(ground_truth_answers, predicted_answers):
            if set(truths) & set(extracted):
                matches += 1

        # The proportion of questions where any predicted answer matched one of the ground truth answers
        result = matches / len(questions)

        return {"result": result}
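Stripped of the component plumbing, the metric in `run` reduces to a set intersection per question: a question counts as matched if any predicted answer appears verbatim among its ground truth answers. A minimal self-contained sketch of that computation (the function name `exact_match_proportion` is illustrative, not part of this PR):

```python
from typing import List


def exact_match_proportion(
    ground_truth_answers: List[List[str]], predicted_answers: List[List[str]]
) -> float:
    """Proportion of questions where any predicted answer exactly matches a ground truth answer."""
    if len(ground_truth_answers) != len(predicted_answers):
        raise ValueError("Both lists must have the same length.")
    # A non-empty set intersection means at least one exact (verbatim) match for that question.
    matches = sum(
        1
        for truths, preds in zip(ground_truth_answers, predicted_answers)
        if set(truths) & set(preds)
    )
    return matches / len(ground_truth_answers)


print(exact_match_proportion([["Berlin"], ["Paris"]], [["Berlin"], ["London"]]))  # 0.5
```

Note that the comparison is exact string equality: "berlin" would not match "Berlin", which is consistent with the tests below.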
releasenotes/notes/exact-match-evaluator-197bb87b65e19d0c.yaml (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
---
features:
  - |
    Add `AnswerExactMatchEvaluator`, a Component that can be used to calculate the Exact Match metric
    given a list of questions, a list of expected answers for each question, and the list of predicted
    answers for each question.
@@ -0,0 +1,61 @@
import pytest

from haystack.components.evaluators import AnswerExactMatchEvaluator


def test_run_with_all_matching():
    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Berlin"], ["Paris"]],
    )

    assert result["result"] == 1.0


def test_run_with_no_matching():
    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Paris"], ["London"]],
    )

    assert result["result"] == 0.0


def test_run_with_partial_matching():
    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Berlin"], ["London"]],
    )

    assert result["result"] == 0.5


def test_run_with_different_lengths():
    evaluator = AnswerExactMatchEvaluator()

    with pytest.raises(ValueError):
        evaluator.run(
            questions=["What is the capital of Germany?"],
            ground_truth_answers=[["Berlin"], ["Paris"]],
            predicted_answers=[["Berlin"], ["London"]],
        )

    with pytest.raises(ValueError):
        evaluator.run(
            questions=["What is the capital of Germany?", "What is the capital of France?"],
            ground_truth_answers=[["Berlin"]],
            predicted_answers=[["Berlin"], ["London"]],
        )

    with pytest.raises(ValueError):
        evaluator.run(
            questions=["What is the capital of Germany?", "What is the capital of France?"],
            ground_truth_answers=[["Berlin"], ["Paris"]],
            predicted_answers=[["Berlin"]],
        )
Now that the metrics are (in)directly associated with specific upstream components, I don't think we need an additional `statistical` submodule after all.