Merged
Changes from all commits
7 changes: 6 additions & 1 deletion src/ragas/evaluation.py
@@ -144,7 +144,12 @@ def evaluate(
row_run_managers.append((row_rm, row_group_cm))

if is_async:
[executor.submit(metric.ascore, row, row_group_cm) for metric in metrics]
[
executor.submit(
metric.ascore, row, row_group_cm, name=f"{metric.name}-{i}"
)
for metric in metrics
]
else:
[executor.submit(metric.score, row, row_group_cm) for metric in metrics]

8 changes: 6 additions & 2 deletions src/ragas/executor.py
@@ -43,7 +43,9 @@ async def wrapped_callable_async(*args, **kwargs):
else:
return wrapped_callable

def submit(self, callable: t.Callable, *args, **kwargs):
def submit(
self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs
):
if self.is_async:
self.executor = t.cast(asyncio.AbstractEventLoop, self.executor)
callable_with_index = self.wrap_callable_with_index(
@@ -52,7 +54,9 @@ def submit(self, callable: t.Callable, *args, **kwargs):
# is type correct?
callable_with_index = t.cast(t.Callable, callable_with_index)
self.futures.append(
self.executor.create_task(callable_with_index(*args, **kwargs))
self.executor.create_task(
callable_with_index(*args, **kwargs), name=name
)
)
else:
self.executor = t.cast(ThreadPoolExecutor, self.executor)
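Note: the name= argument threaded through Executor.submit above ends up as the asyncio task name. A minimal, hypothetical sketch of the behaviour it relies on, using only the standard asyncio API (loop.create_task has accepted a name keyword since Python 3.8; nothing below is part of the PR):

import asyncio

async def _demo_task_naming() -> int:
    async def noop() -> int:
        return 42

    # Mirrors what Executor.submit does in the async branch: the task gets a
    # human-readable label such as "faithfulness-0", visible via get_name()
    # and in asyncio debug/trace output.
    task = asyncio.get_running_loop().create_task(noop(), name="faithfulness-0")
    assert task.get_name() == "faithfulness-0"
    return await task

if __name__ == "__main__":
    asyncio.run(_demo_task_naming())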
37 changes: 27 additions & 10 deletions src/ragas/llms/json_load.py
@@ -82,21 +82,38 @@ def safe_load(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None):
start, end = self._find_outermost_json(text)
return json.loads(text[start:end])
except ValueError:
text = self._fix_to_json(text, llm, callbacks)
from ragas.llms.prompt import PromptValue

results = llm.generate_text(
PromptValue(prompt_str=JSON_PROMPT.format(input=text)),
n=1,
callbacks=callbacks,
)
text = results.generations[0][0].text
retry += 1

return {}

def _fix_to_json(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks):
from ragas.llms.prompt import PromptValue
async def asafe_load(
self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None
):
retry = 0
while retry <= self.max_retries:
try:
start, end = self._find_outermost_json(text)
return json.loads(text[start:end])
except ValueError:
from ragas.llms.prompt import PromptValue

results = await llm.agenerate_text(
PromptValue(prompt_str=JSON_PROMPT.format(input=text)),
n=1,
callbacks=callbacks,
)
text = results.generations[0][0].text
retry += 1

# TODO (executor)
results = llm.generate_text(
PromptValue(prompt_str=JSON_PROMPT.format(input=text)),
n=1,
callbacks=callbacks,
)
return results.generations[0][0].text
return {}

def _find_outermost_json(self, text):
stack = []
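Note: asafe_load mirrors safe_load but awaits llm.agenerate_text when json.loads fails, so it can run on the asyncio path without blocking the event loop. A hedged usage sketch, assuming the module-level json_loader instance the metrics already use and any configured BaseRagasLLM (the import path and helper name are assumptions, not part of the PR):

from ragas.llms.json_load import json_loader  # assumed import path, as used by the metrics

async def parse_llm_json(raw_text: str, llm) -> dict:
    # Retries up to json_loader.max_retries, asking the LLM to repair the
    # payload each time parsing fails; falls back to {} if it never succeeds.
    return await json_loader.asafe_load(raw_text, llm)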
2 changes: 1 addition & 1 deletion src/ragas/llms/prompt.py
@@ -30,7 +30,7 @@ def to_string(self) -> str:
class Prompt(BaseModel):
"""
Prompt is a class that represents a prompt for the ragas metrics.

Prompt is a class that represents a prompt for the ragas metrics.

Attributes:
15 changes: 9 additions & 6 deletions src/ragas/metrics/_answer_correctness.py
@@ -15,7 +15,6 @@

if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult

CORRECTNESS_PROMPT = Prompt(
name="answer_correctness",
@@ -110,17 +109,15 @@ def __post_init__(self: t.Self):
llm=self.llm, batch_size=self.batch_size
)

def _compute_statement_presence(self, result: LLMResult) -> float:
def _compute_statement_presence(self, prediction: t.Any) -> float:
assert self.llm is not None, "LLM must be set"

key_map = {
"TP": "statements that are present in both the answer and the ground truth",
"FP": "statements present in the answer but not found in the ground truth",
"FN": "relevant statements found in the ground truth but omitted in the answer", # noqa: E501
}
outputs = result.generations[0]

prediction = json_loader.safe_load(outputs[0].text, self.llm)
prediction = prediction if isinstance(prediction, list) else [prediction]
if prediction:
prediction = [
@@ -146,7 +143,10 @@ def _score(self, row: t.Dict, callbacks: Callbacks) -> float:
p_value = self.correctness_prompt.format(question=q, ground_truth=g, answer=a)
is_statement_present = self.llm.generate_text(p_value, callbacks=callbacks)

f1_score = self._compute_statement_presence(is_statement_present)
prediction = json_loader.safe_load(
is_statement_present.generations[0][0].text, self.llm
)
f1_score = self._compute_statement_presence(prediction)

if self.weights[1] == 0:
similarity_score = 0
@@ -169,7 +169,10 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
p_value, callbacks=callbacks
)

f1_score = self._compute_statement_presence(is_statement_present)
prediction = await json_loader.asafe_load(
is_statement_present.generations[0][0].text, self.llm
)
f1_score = self._compute_statement_presence(prediction)

if self.weights[1] == 0:
similarity_score = 0
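Note: _compute_statement_presence now receives the already-parsed TP/FP/FN classification instead of a raw LLMResult. For orientation, an illustrative sketch of the statement-presence F1 it is built on (the real method also remaps the verbose keys shown in key_map; treat the exact weighting as an assumption, not a spec):

def f1_from_statement_counts(prediction: dict) -> float:
    # prediction looks like {"TP": [...], "FP": [...], "FN": [...]}
    tp = len(prediction.get("TP", []))
    fp = len(prediction.get("FP", []))
    fn = len(prediction.get("FN", []))
    return tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0.0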
7 changes: 2 additions & 5 deletions src/ragas/metrics/_answer_relevance.py
@@ -83,10 +83,6 @@ class AnswerRelevancy(MetricWithLLM):
def init_model(self):
super().init_model()

if isinstance(self.embeddings, OpenAIEmbeddings):
if self.embeddings.openai_api_key == "no-key":
raise OpenAIKeyNotFound

def calculate_similarity(
self: t.Self, question: str, generated_questions: list[str]
):
@@ -143,7 +139,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
callbacks=callbacks,
)
response = [
json_loader.safe_load(r.text, self.llm) for r in result.generations[0]
await json_loader.asafe_load(r.text, self.llm)
for r in result.generations[0]
]

return self._calculate_score(response, row)
4 changes: 3 additions & 1 deletion src/ragas/metrics/_context_precision.py
@@ -143,7 +143,9 @@ async def _ascore(
)
responses.append(result.generations[0][0].text)

json_responses = [json_loader.safe_load(item, self.llm) for item in responses]
json_responses = [
await json_loader.asafe_load(item, self.llm) for item in responses
]
score = self._calculate_average_precision(json_responses)
return score

2 changes: 1 addition & 1 deletion src/ragas/metrics/_context_recall.py
@@ -123,7 +123,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
result = await self.llm.agenerate_text(
self._create_context_recall_prompt(row), callbacks=callbacks
)
response = json_loader.safe_load(result.generations[0][0].text, self.llm)
response = await json_loader.asafe_load(result.generations[0][0].text, self.llm)

return self._compute_score(response)

1 change: 0 additions & 1 deletion src/ragas/metrics/_context_relevancy.py
@@ -5,7 +5,6 @@
from dataclasses import dataclass, field
from typing import List

import numpy as np
import pysbd

from ragas.llms.prompt import Prompt
44 changes: 26 additions & 18 deletions src/ragas/metrics/_faithfulness.py
@@ -12,7 +12,6 @@

if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult

from ragas.llms.prompt import PromptValue

@@ -38,7 +37,9 @@
{
"question": "Cadmium Chloride is slightly soluble in this chemical, it is also called what?",
"answer": "alcohol",
"statements": {"statements": ["Cadmium Chloride is slightly soluble in alcohol."]},
"statements": {
"statements": ["Cadmium Chloride is slightly soluble in alcohol."]
},
},
{
"question": "Were Hitler and Benito Mussolini of the same nationality?",
@@ -99,7 +100,11 @@
{
"context": """Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time.""",
"statements": """statement_1: Nil""",
"answer": {"statement_1": "Nil", "reason": "The statement is invalid", "verdict": "-1"},
"answer": {
"statement_1": "Nil",
"reason": "The statement is invalid",
"verdict": "-1",
},
},
],
input_keys=["context", "statements"],
@@ -127,15 +132,12 @@ def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
prompt_value = LONG_FORM_ANSWER_PROMPT.format(question=question, answer=answer)
return prompt_value

def _create_nli_prompt(self, row: t.Dict, answer_result: LLMResult) -> PromptValue:
def _create_nli_prompt(self, row: t.Dict, statements: t.Any) -> PromptValue:
assert self.llm is not None, "llm must be set to compute score"

contexts = row["contexts"]
# check if the statements are support in the contexts
contexts_str: str = "\n".join(contexts)
statements = json_loader.safe_load(
answer_result.generations[0][0].text, self.llm
).get("statements", [])
statements = statements if statements != [] else ["Nil"]
statements_str: str = "\n".join(
[f"statement_{i+1}: {st}" for i, st in enumerate(statements)]
@@ -145,13 +147,9 @@ def _create_nli_prompt(self, row: t.Dict, answer_result: LLMResult) -> PromptVal
)
return prompt_value

def _compute_score(self, result: LLMResult):
assert self.llm is not None, "llm must be set to compute score"

def _compute_score(self, output: t.Any):
# check the verdicts and compute the score
output = result.generations[0][0]
verdict_score_map = {"1": 1, "0": 0, "null": np.nan}
output = json_loader.safe_load(output.text, self.llm)
output = output if isinstance(output, list) else [output]
faithful_statements = sum(
verdict_score_map.get(
@@ -173,22 +171,32 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
"""
assert self.llm is not None, "LLM is not set"
p = self._create_answer_prompt(row)
result = await self.llm.agenerate_text(p, callbacks=callbacks)
answer_result = await self.llm.agenerate_text(p, callbacks=callbacks)

p = self._create_nli_prompt(row, result)
statements = await json_loader.asafe_load(
answer_result.generations[0][0].text, self.llm
)
p = self._create_nli_prompt(row, statements.get("statements", []))
result = await self.llm.agenerate_text(p, callbacks=callbacks)

return self._compute_score(result)
json_output = await json_loader.asafe_load(
result.generations[0][0].text, self.llm
)
return self._compute_score(json_output)

def _score(self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.llm is not None, "LLM is not set"
p = self._create_answer_prompt(row)
result = self.llm.generate_text(p, callbacks=callbacks)
answer_result = self.llm.generate_text(p, callbacks=callbacks)

p = self._create_nli_prompt(row, result)
statements = json_loader.safe_load(
answer_result.generations[0][0].text, self.llm
)
p = self._create_nli_prompt(row, statements.get("statements", []))
result = self.llm.generate_text(p, callbacks=callbacks)

return self._compute_score(result)
json_output = json_loader.safe_load(result.generations[0][0].text, self.llm)
return self._compute_score(json_output)

def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
assert self.llm is not None, "LLM is not set"
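Note: with JSON parsing pulled out of _compute_score, the method reduces to counting supported verdicts against the verdict_score_map shown above. A rough sketch of that reduction (illustrative only; the real code also guards against non-list output and missing verdict keys):

import numpy as np

def faithfulness_from_verdicts(output: list[dict]) -> float:
    # Each item looks like {"statement_1": "...", "reason": "...", "verdict": "1"}.
    verdict_score_map = {"1": 1, "0": 0, "null": np.nan}
    faithful = sum(
        verdict_score_map.get(str(item.get("verdict", "")).strip().lower(), np.nan)
        for item in output
    )
    return faithful / len(output) if output else np.nan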
4 changes: 3 additions & 1 deletion src/ragas/metrics/critique.py
@@ -136,7 +136,9 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
)

responses = [r.text for r in result.generations[0]]
safe_loaded_responses = [json_loader.safe_load(r, self.llm) for r in responses]
safe_loaded_responses = [
await json_loader.asafe_load(r, self.llm) for r in responses
]

return self._compute_score(safe_loaded_responses)

42 changes: 23 additions & 19 deletions tests/benchmarks/benchmark_eval.py
@@ -16,9 +16,9 @@
from ragas.metrics.critique import harmfulness

# data
ds = load_dataset("explodinggradients/fiqa", "ragas_eval")
ds = load_dataset("explodinggradients/amnesty_qa", "english")
Member: Might want to add .select(range(0, 10)) to speed up the process.
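A minimal sketch of that suggestion, assuming the Hugging Face datasets API already used in this file (Dataset.select takes an iterable of row indices):

from datasets import DatasetDict, load_dataset

ds = load_dataset("explodinggradients/amnesty_qa", "english")
assert isinstance(ds, DatasetDict)
# Subsample to the first 10 rows so the benchmark finishes quickly.
eval_dataset = ds["train"].select(range(0, 10))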

assert isinstance(ds, DatasetDict)
fiqa = ds["baseline"]
eval_dataset = ds["train"]

# metrics
metrics = [
@@ -33,24 +33,28 @@
answer_similarity,
]

IGNORE_THREADS = False
IGNORE_ASYNCIO = False

if __name__ == "__main__":
# asyncio
start = time.time()
print("ignored")
# _ = evaluate(
# fiqa,
# metrics=[
# faithfulness,
# ],
# is_async=True,
# )
print(f"Time taken [Asyncio]: {time.time() - start:.2f}s")
if not IGNORE_ASYNCIO:
print("Starting [Asyncio]")
start = time.time()
_ = evaluate(
eval_dataset,
metrics=metrics,
is_async=True,
)
print(f"Time taken [Asyncio]: {time.time() - start:.2f}s")

# Threads
start = time.time()
_ = evaluate(
fiqa,
metrics=metrics,
is_async=False,
)
print(f"Time taken [Threads]: {time.time() - start:.2f}s")
if not IGNORE_THREADS:
print("Starting [Threads]")
start = time.time()
_ = evaluate(
eval_dataset,
metrics=metrics,
is_async=False,
)
print(f"Time taken [Threads]: {time.time() - start:.2f}s")