fix: failsafe for non-valid json and failed LLM calls (#7723)
* wip

* initial import

* adding tests

* adding params

* adding safeguards for nan in evaluators

* adding docstrings

* fixing tests

* removing unused imports

* adding tests to context and faithfulness evaluators

* fixing docstrings

* nit

* removing unused imports

* adding release notes

* addressing PR comments

* fixing tests

* fixing tests

* adding types

* removing unused imports

* Update haystack/components/evaluators/context_relevance.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Update haystack/components/evaluators/faithfulness.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* addressing PR comments

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
davidsbatista and shadeMe committed May 23, 2024
1 parent e3dccf4 commit 38747ff
Showing 7 changed files with 199 additions and 18 deletions.
10 changes: 9 additions & 1 deletion haystack/components/evaluators/context_relevance.py
@@ -70,6 +70,7 @@ def __init__(
progress_bar: bool = True,
api: str = "openai",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
raise_on_failure: bool = True,
):
"""
Creates an instance of ContextRelevanceEvaluator.
@@ -97,6 +98,9 @@ def __init__(
Supported APIs: "openai".
:param api_key:
The API key.
:param raise_on_failure:
Whether to raise an exception if the API call fails.
"""
self.instructions = (
"Your task is to judge how relevant the provided context is for answering a question. "
@@ -117,6 +121,7 @@ def __init__(
examples=self.examples,
api=self.api,
api_key=self.api_key,
raise_on_failure=raise_on_failure,
progress_bar=progress_bar,
)

@@ -138,7 +143,10 @@ def run(self, questions: List[str], contexts: List[List[str]]) -> Dict[str, Any]
result = super().run(questions=questions, contexts=contexts)

# calculate average statement relevance score per query
for res in result["results"]:
for idx, res in enumerate(result["results"]):
if res is None:
result["results"][idx] = {"statements": [], "statement_scores": [], "score": float("nan")}
continue
if not res["statements"]:
res["score"] = 0
else:
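
The hunk above substitutes an empty placeholder with score=NaN for any sample whose LLM call failed. A minimal illustration, outside Haystack, of why this choice also turns the aggregate score into NaN when the aggregate is a plain mean over the per-sample scores (an assumption for this sketch):

import math
from statistics import fmean

# Per-sample results in the shape the evaluator returns; the second sample
# failed and carries the NaN placeholder introduced in the hunk above.
results = [
    {"statements": ["a", "b"], "statement_scores": [1, 1], "score": 1.0},
    {"statements": [], "statement_scores": [], "score": float("nan")},
]

individual_scores = [res["score"] for res in results]
aggregate = fmean(individual_scores)  # NaN propagates through the mean

print(individual_scores)      # [1.0, nan]
print(math.isnan(aggregate))  # True
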
9 changes: 8 additions & 1 deletion haystack/components/evaluators/faithfulness.py
@@ -84,6 +84,7 @@ def __init__(
progress_bar: bool = True,
api: str = "openai",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
raise_on_failure: bool = True,
):
"""
Creates an instance of FaithfulnessEvaluator.
@@ -112,6 +113,8 @@ def __init__(
Supported APIs: "openai".
:param api_key:
The API key.
:param raise_on_failure:
Whether to raise an exception if the API call fails.
"""
self.instructions = (
@@ -134,6 +137,7 @@ def __init__(
examples=self.examples,
api=self.api,
api_key=self.api_key,
raise_on_failure=raise_on_failure,
progress_bar=progress_bar,
)

@@ -157,7 +161,10 @@ def run(self, questions: List[str], contexts: List[List[str]], predicted_answers
result = super().run(questions=questions, contexts=contexts, predicted_answers=predicted_answers)

# calculate average statement faithfulness score per query
for res in result["results"]:
for idx, res in enumerate(result["results"]):
if res is None:
result["results"][idx] = {"statements": [], "statement_scores": [], "score": float("nan")}
continue
if not res["statements"]:
res["score"] = 0
else:
71 changes: 57 additions & 14 deletions haystack/components/evaluators/llm_evaluator.py
@@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any, Dict, List, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
from warnings import warn

from tqdm import tqdm

@@ -54,6 +55,7 @@ def __init__(
examples: List[Dict[str, Any]],
progress_bar: bool = True,
*,
raise_on_failure: bool = True,
api: str = "openai",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
):
@@ -73,6 +75,8 @@
`outputs` parameters.
Each example is a dictionary with keys "inputs" and "outputs"
They contain the input and output as dictionaries respectively.
:param raise_on_failure:
If True, the component will raise an exception on an unsuccessful API call.
:param progress_bar:
Whether to show a progress bar during the evaluation.
:param api:
@@ -83,6 +87,7 @@
"""
self.validate_init_parameters(inputs, outputs, examples)
self.raise_on_failure = raise_on_failure
self.instructions = instructions
self.inputs = inputs
self.outputs = outputs
@@ -168,7 +173,11 @@ def run(self, **inputs) -> Dict[str, Any]:
:returns:
A dictionary with a single `results` entry that contains a list of results.
Each result is a dictionary containing the keys as defined in the `outputs` parameter of the LLMEvaluator
and the evaluation results as the values.
and the evaluation results as the values. If an exception occurs for a particular input value, the result
will be `None` for that entry.
:raises ValueError:
Only in the case that `raise_on_failure` is set to True and the received inputs are not lists or have
different lengths, or if the output is not a valid JSON or doesn't contain the expected keys.
"""
self.validate_input_parameters(dict(self.inputs), inputs)

@@ -177,14 +186,31 @@ def run(self, **inputs) -> Dict[str, Any]:
input_names, values = inputs.keys(), list(zip(*inputs.values()))
list_of_input_names_to_values = [dict(zip(input_names, v)) for v in values]

results = []
results: List[Optional[Dict[str, Any]]] = []
errors = 0
for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar):
prompt = self.builder.run(**input_names_to_values)
result = self.generator.run(prompt=prompt["prompt"])

self.validate_outputs(expected=self.outputs, received=result["replies"][0])
parsed_result = json.loads(result["replies"][0])
results.append(parsed_result)
try:
result = self.generator.run(prompt=prompt["prompt"])
except Exception as e:
msg = f"Error while generating response for prompt: {prompt}. Error: {e}"
if self.raise_on_failure:
raise ValueError(msg)
warn(msg)
results.append(None)
errors += 1
continue

if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0]):
parsed_result = json.loads(result["replies"][0])
results.append(parsed_result)
else:
results.append(None)
errors += 1

if errors > 0:
msg = f"LLM evaluator failed for {errors} out of {len(list_of_input_names_to_values)} inputs."
warn(msg)

return {"results": results}

@@ -299,20 +325,37 @@ def validate_input_parameters(expected: Dict[str, Any], received: Dict[str, Any]
)
raise ValueError(msg)

@staticmethod
def validate_outputs(expected: List[str], received: str) -> None:
def is_valid_json_and_has_expected_keys(self, expected: List[str], received: str) -> bool:
"""
Validate the output.
Output must be a valid JSON with the expected keys.
:param expected:
Names of expected outputs
:param received:
Names of received outputs
:raises ValueError:
If not all expected outputs are present in the received outputs
If the output is not a valid JSON with the expected keys:
- with `raise_on_failure` set to True a ValueError is raised.
- with `raise_on_failure` set to False a warning is issued and False is returned.
:returns:
True if the received output is a valid JSON with the expected keys, False otherwise.
"""
parsed_output = json.loads(received)
try:
parsed_output = json.loads(received)
except json.JSONDecodeError:
msg = "Response from LLM evaluator is not a valid JSON."
if self.raise_on_failure:
raise ValueError(msg)
warn(msg)
return False

if not all(output in parsed_output for output in expected):
msg = f"Expected response from LLM evaluator to be JSON with keys {expected}, got {received}."
raise ValueError(msg)
if self.raise_on_failure:
raise ValueError(msg)
warn(msg)
return False

return True
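
Because failures only emit warnings when `raise_on_failure=False`, callers who still want hard failures (for example in CI) can escalate those warnings with Python's standard warnings machinery. A short sketch of that standard-library behaviour, independent of Haystack:

import warnings
from warnings import warn

def flaky_evaluation() -> None:
    # Stand-in for an evaluator run that degrades gracefully and warns.
    warn("LLM evaluator failed for 1 out of 2 inputs.")

# Default behaviour: the warning is printed and execution continues.
flaky_evaluation()

# Escalate warnings to exceptions inside a scoped filter.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        flaky_evaluation()
    except UserWarning as e:
        print(f"escalated: {e}")
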
5 changes: 5 additions & 0 deletions releasenotes/notes/… (new release note)
@@ -0,0 +1,5 @@
---
enhancements:
- |
If an LLM-based evaluator (e.g., `Faithfulness` or `ContextRelevance`) is initialised with `raise_on_failure=False`, and if a call to an LLM fails or an LLM outputs an invalid JSON, the score of the sample is set to `NaN` instead of raising an exception.
The user is notified with a warning indicating the number of requests that failed.
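
A minimal usage sketch of the new flag (assuming an `OPENAI_API_KEY` is set in the environment); a failed call or invalid reply now surfaces as a NaN score on that sample rather than an exception:

import math

from haystack.components.evaluators import ContextRelevanceEvaluator

# With raise_on_failure=False a failed API call or invalid JSON reply no longer
# aborts the run; the affected sample gets score=NaN and a warning is emitted.
evaluator = ContextRelevanceEvaluator(raise_on_failure=False)
result = evaluator.run(
    questions=["Who created the Python language?"],
    contexts=[["Python was created by Guido van Rossum in the late 1980s."]],
)

for sample in result["results"]:
    if math.isnan(sample["score"]):
        print("sample could not be evaluated")
    else:
        print("context relevance:", sample["score"])
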
41 changes: 41 additions & 0 deletions test/components/evaluators/test_context_relevance_evaluator.py
@@ -4,6 +4,8 @@
import os
from typing import List

import math

import pytest

from haystack.components.evaluators import ContextRelevanceEvaluator
@@ -159,6 +161,45 @@ def test_run_missing_parameters(self, monkeypatch):
with pytest.raises(TypeError, match="missing 2 required positional arguments"):
component.run()

def test_run_returns_nan_raise_on_failure_false(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = ContextRelevanceEvaluator(raise_on_failure=False)

def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
[
"The popularity of sports can be measured in various ways, including TV viewership, social media "
"presence, number of participants, and economic impact. Football is undoubtedly the world's most "
"popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and "
"Messi, drawing a followership of more than 4 billion people."
],
[
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects."
],
]
results = component.run(questions=questions, contexts=contexts)

assert math.isnan(results["score"])

assert results["individual_scores"][0] == 1.0
assert math.isnan(results["individual_scores"][1])

assert results["results"][0] == {"statements": ["c", "d"], "statement_scores": [1, 1], "score": 1.0}

assert results["results"][1]["statements"] == []
assert results["results"][1]["statement_scores"] == []
assert math.isnan(results["results"][1]["score"])

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
45 changes: 45 additions & 0 deletions test/components/evaluators/test_faithfulness_evaluator.py
@@ -2,8 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
import math
from typing import List

import numpy as np
import pytest

from haystack.components.evaluators import FaithfulnessEvaluator
@@ -191,6 +193,49 @@ def test_run_missing_parameters(self, monkeypatch):
with pytest.raises(TypeError, match="missing 3 required positional arguments"):
component.run()

def test_run_returns_nan_raise_on_failure_false(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = FaithfulnessEvaluator(raise_on_failure=False)

def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
[
"The popularity of sports can be measured in various ways, including TV viewership, social media "
"presence, number of participants, and economic impact. Football is undoubtedly the world's most "
"popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and "
"Messi, drawing a followership of more than 4 billion people."
],
[
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects."
],
]
predicted_answers = [
"Football is the most popular sport with around 4 billion followers worldwide.",
"Guido van Rossum.",
]
results = component.run(questions=questions, contexts=contexts, predicted_answers=predicted_answers)

assert math.isnan(results["score"])

assert results["individual_scores"][0] == 1.0
assert math.isnan(results["individual_scores"][1])

assert results["results"][0] == {"statements": ["c", "d"], "statement_scores": [1, 1], "score": 1.0}

assert results["results"][1]["statements"] == []
assert results["results"][1]["statement_scores"] == []
assert math.isnan(results["results"][1]["score"])

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
36 changes: 34 additions & 2 deletions test/components/evaluators/test_llm_evaluator.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List

import numpy as np
import pytest

from haystack.components.evaluators import LLMEvaluator
@@ -379,10 +380,41 @@ def test_invalid_outputs(self, monkeypatch):
],
)
with pytest.raises(ValueError):
component.validate_outputs(expected=["score", "another_expected_output"], received='{"score": 1.0}')
component.is_valid_json_and_has_expected_keys(
expected=["score", "another_expected_output"], received='{"score": 1.0}'
)

with pytest.raises(ValueError):
component.is_valid_json_and_has_expected_keys(expected=["score"], received='{"wrong_name": 1.0}')

def test_output_invalid_json_raise_on_failure_false(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = LLMEvaluator(
instructions="test-instruction",
inputs=[("predicted_answers", List[str])],
outputs=["score"],
examples=[
{"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
],
raise_on_failure=False,
)
assert (
component.is_valid_json_and_has_expected_keys(expected=["score"], received="some_invalid_json_output")
is False
)

def test_output_invalid_json_raise_on_failure_true(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = LLMEvaluator(
instructions="test-instruction",
inputs=[("predicted_answers", List[str])],
outputs=["score"],
examples=[
{"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
],
)
with pytest.raises(ValueError):
component.validate_outputs(expected=["score"], received='{"wrong_name": 1.0}')
component.is_valid_json_and_has_expected_keys(expected=["score"], received="some_invalid_json_output")

def test_unsupported_api(self):
with pytest.raises(ValueError):
