/
_context_relevancy.py
98 lines (75 loc) · 3.39 KB
/
_context_relevancy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import annotations
import logging
import typing as t
from dataclasses import dataclass, field
from typing import List
import pysbd
from ragas.llms.prompt import Prompt
from ragas.metrics.base import EvaluationMode, MetricWithLLM
if t.TYPE_CHECKING:
from langchain_core.callbacks.base import Callbacks
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Prompt given to the LLM by ContextRelevancy._ascore. It is formatted with a
# "question" and a "context" and asks the model to copy out, unmodified, the
# context sentences needed to answer the question, or to reply with the phrase
# "Insufficient Information" when nothing relevant is found.
# NOTE(review): the instruction reads "absolutely required answer" (the word
# "to" looks missing). It is left untouched here because this text is sent to
# the model at runtime; confirm before rewording.
CONTEXT_RELEVANCE = Prompt(
name="context_relevancy",
instruction="""Please extract relevant sentences from the provided context that is absolutely required answer the following question. If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the phrase "Insufficient Information". While extracting candidate sentences you're not allowed to make any changes to sentences from given context.""",
input_keys=["question", "context"],
output_key="candidate sentences",
output_type="json",
)
# Shared pysbd sentence segmenter configured for English; used by sent_tokenize.
# NOTE(review): clean=False presumably disables pysbd's text clean-up pass so
# the input is segmented as-is; confirm against the pysbd documentation.
seg = pysbd.Segmenter(language="en", clean=False)
def sent_tokenize(text: str) -> List[str]:
    """
    Split ``text`` into sentences with the module-level pysbd segmenter.

    Parameters
    ----------
    text : str
        The text to segment.

    Returns
    -------
    List[str]
        The sentences produced by ``seg.segment``.

    Raises
    ------
    AssertionError
        If the segmenter returns something other than a list.
    """
    segmented = seg.segment(text)
    # Narrow the segmenter's return type for type checkers.
    assert isinstance(segmented, list)
    return segmented
@dataclass
class ContextRelevancy(MetricWithLLM):
    """
    Scores how much of the retrieved context is relevant to the question.

    The LLM is asked to extract the context sentences needed to answer the
    question. The score is the number of extracted sentences divided by the
    number of sentences in the context, capped at 1.

    Attributes
    ----------
    name : str
        Name of the metric.
    evaluation_mode : EvaluationMode
        Set to ``EvaluationMode.qc``; the metric reads the ``"question"`` and
        ``"contexts"`` entries of a row.
    context_relevancy_prompt : Prompt
        Prompt formatted with the question and the joined contexts and sent
        to the LLM.
    show_deprecation_warning : bool
        When True, a deprecation warning is logged on every scoring call.
    """

    name: str = "context_relevancy"  # type: ignore
    evaluation_mode: EvaluationMode = EvaluationMode.qc  # type: ignore
    context_relevancy_prompt: Prompt = field(default_factory=lambda: CONTEXT_RELEVANCE)
    show_deprecation_warning: bool = False

    def _compute_score(self, response: str, row: t.Dict) -> float:
        """
        Turn the LLM response into a relevancy score in ``[0, 1]``.

        Parameters
        ----------
        response : str
            Raw text returned by the LLM: either the extracted sentences or
            the "Insufficient Information" phrase.
        row : dict
            Row being scored; ``row["contexts"]`` is joined with newlines and
            split into sentences to obtain the denominator.

        Returns
        -------
        float
            Extracted-sentence count divided by context-sentence count,
            capped at 1.0; 0.0 when the context has no sentences or the model
            reported insufficient information.
        """
        context = "\n".join(row["contexts"])
        context_sents = sent_tokenize(context)
        if not context_sents:
            return 0.0

        # The prompt asks for the phrase "Insufficient Information" (no
        # trailing period), and models often add whitespace, quotes or a
        # period around it. Normalise before comparing so that the sentinel
        # is never counted as an extracted sentence.
        normalized = response.strip(" \t\r\n\"'.").lower()
        if normalized == "insufficient information":
            extracted: t.List[str] = []
        else:
            extracted = sent_tokenize(response.strip())

        return min(len(extracted) / len(context_sents), 1.0)

    async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
        """
        Query the LLM for relevant sentences and score the row.

        Parameters
        ----------
        row : dict
            Must contain ``"question"`` and ``"contexts"``; the contexts are
            joined with newlines before being placed in the prompt.
        callbacks : Callbacks
            Forwarded to ``self.llm.generate``.
        is_async : bool
            Not used by this method.

        Returns
        -------
        float
            The score computed from the first generation's text.

        Raises
        ------
        AssertionError
            If no LLM has been set on the metric.
        """
        assert self.llm is not None, "LLM is not initialized"

        if self.show_deprecation_warning:
            logger.warning(
                "The 'context_relevancy' metric is going to be deprecated soon! Please use the 'context_precision' metric instead. It is a drop-in replacement just a simple search and replace should work."  # noqa
            )

        question, contexts = row["question"], row["contexts"]
        result = await self.llm.generate(
            self.context_relevancy_prompt.format(
                question=question, context="\n".join(contexts)
            ),
            callbacks=callbacks,
        )
        return self._compute_score(result.generations[0][0].text, row)

    def adapt(self, language: str, cache_dir: str | None = None) -> None:
        """
        Replace the prompt with the result of adapting it to ``language``.

        Parameters
        ----------
        language : str
            Target language passed to ``Prompt.adapt``.
        cache_dir : str | None
            Passed through to ``Prompt.adapt``.

        Raises
        ------
        AssertionError
            If no LLM has been set on the metric.
        """
        assert self.llm is not None, "set LLM before use"

        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Adapting Context Relevancy to %s", language)
        self.context_relevancy_prompt = self.context_relevancy_prompt.adapt(
            language, self.llm, cache_dir
        )

    def save(self, cache_dir: str | None = None) -> None:
        """
        Save the current prompt by calling its ``save`` with ``cache_dir``.

        Parameters
        ----------
        cache_dir : str | None
            Passed through to ``Prompt.save``.
        """
        self.context_relevancy_prompt.save(cache_dir)
# Module-level instance of the metric, constructed with the dataclass defaults.
context_relevancy = ContextRelevancy()