diff --git a/src/ragas/embeddings/base.py b/src/ragas/embeddings/base.py index 8c722b21d..355697161 100644 --- a/src/ragas/embeddings/base.py +++ b/src/ragas/embeddings/base.py @@ -7,8 +7,8 @@ import numpy as np from langchain.embeddings import AzureOpenAIEmbeddings as BaseAzureOpenAIEmbeddings -from langchain.embeddings import OpenAIEmbeddings as BaseOpenAIEmbeddings from langchain.embeddings import FastEmbedEmbeddings as BaseFastEmbedEmbeddings +from langchain.embeddings import OpenAIEmbeddings as BaseOpenAIEmbeddings from langchain.schema.embeddings import Embeddings from pydantic.dataclasses import dataclass @@ -47,6 +47,7 @@ def validate_api_key(self): else: raise OpenAIKeyNotFound + class FastEmbedEmbeddings(BaseFastEmbedEmbeddings, RagasEmbeddings): """ Find the list of supported models at: @@ -64,6 +65,7 @@ def validate_api_key(self): """ pass + class AzureOpenAIEmbeddings(BaseAzureOpenAIEmbeddings, RagasEmbeddings): azure_endpoint: t.Optional[str] = None deployment: t.Optional[str] = None diff --git a/src/ragas/llms/prompt.py b/src/ragas/llms/prompt.py index d6df3d05e..61d2bb589 100644 --- a/src/ragas/llms/prompt.py +++ b/src/ragas/llms/prompt.py @@ -1,6 +1,8 @@ from __future__ import annotations import json +import logging +import os import typing as t from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate @@ -8,17 +10,22 @@ from langchain_core.prompt_values import PromptValue from langchain_core.pydantic_v1 import root_validator +from ragas.llms import RagasLLM +from ragas.utils import RAGAS_CACHE_HOME, json_loader + class Prompt(PromptValue): """ - RagasPrompt is a class that represents a prompt for the ragas metrics. + Prompt is a class that represents a prompt for the ragas metrics. """ + name: str instruction: str examples: t.List[t.Dict[str, t.Any]] = [] input_keys: t.List[str] output_key: str output_type: str = "json" + language = "en" @root_validator def validate_prompt(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: @@ -44,10 +51,11 @@ def validate_prompt(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: raise ValueError( f"example {no+1} does not have the variable {output_key} in the definition" ) - if values["output_type"] == "json": + if values["output_type"].lower() == "json": try: if output_key in example: - json.loads(example[output_key]) + if isinstance(example[output_key], str): + json.loads(example[output_key]) except ValueError as e: raise ValueError( f"{output_key} in example {no+1} is not in valid json format: {e}" @@ -64,6 +72,7 @@ def to_string(self) -> str: # Format the examples to match the Langchain prompt template for example in self.examples: for key, value in example.items(): + value = json.dumps(value, ensure_ascii=False).encode("utf8").decode() value = ( value.replace("{", "{{").replace("}", "}}") if self.output_type.lower() == "json" @@ -90,6 +99,7 @@ def get_example_str(self, example_no: int) -> str: example = self.examples[example_no] example_str = "" for key, value in example.items(): + value = json.dumps(value, ensure_ascii=False).encode("utf8").decode() value = ( value.replace("{", "{{").replace("}", "}}") if self.output_type.lower() == "json" @@ -100,7 +110,7 @@ def get_example_str(self, example_no: int) -> str: def format(self, **kwargs: t.Any) -> ChatPromptTemplate: """ - Format the RagasPrompt object into a ChatPromptTemplate object to be used in metrics. + Format the Prompt object into a ChatPromptTemplate object to be used in metrics. 
""" if set(self.input_keys) != set(kwargs.keys()): raise ValueError( @@ -109,3 +119,112 @@ def format(self, **kwargs: t.Any) -> ChatPromptTemplate: prompt = self.to_string() human_prompt = HumanMessagePromptTemplate.from_template(prompt) return ChatPromptTemplate.from_messages([human_prompt.format(**kwargs)]) + + def adapt( + self, language: str, llm: RagasLLM, cache_dir: t.Optional[str] = None + ) -> Prompt: + # TODO: Add callbacks + cache_dir = cache_dir if cache_dir else RAGAS_CACHE_HOME + if os.path.exists(os.path.join(cache_dir, language, f"{self.name}.json")): + return self._load(language, self.name, cache_dir) + + prompts = [] + for example in self.examples: + prompts.extend( + [ + str_translation.format( + translate_to=language, input=example.get(key) + ) + for key in self.input_keys + ] + ) + prompts.append( + json_translatation.format( + translate_to=language, input=example.get(self.output_key) + ) + if self.output_type.lower() == "json" + else str_translation.format( + translate_to=language, input=example.get(self.output_key) + ) + ) + + results = [result[0].text for result in llm.generate(prompts).generations] + per_example_items = len(self.input_keys) + 1 + grouped_results = [ + results[i : i + per_example_items] + for i in range(0, len(results), per_example_items) + ] + assert len(grouped_results) == len( + self.examples + ), "examples and adapted examples must be of equal length" + for i, example in enumerate(grouped_results): + example_dict = {} + example_dict.update( + {k: v for k, v in zip(self.input_keys, example[: len(self.input_keys)])} + ) + example_dict[self.output_key] = ( + json_loader.safe_load(example[-1], llm) + if self.output_type.lower() == "json" + else example[-1] + ) + + self.examples[i] = example_dict + + self.language = language + + # TODO:Validate the prompt after adaptation + + return self + + def save(self, cache_dir: t.Optional[str] = None) -> None: + cache_dir = cache_dir if cache_dir else RAGAS_CACHE_HOME + cache_dir = os.path.join(cache_dir, self.language) + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + cache_path = os.path.join(cache_dir, f"{self.name}.json") + with open(cache_path, "w") as file: + json.dump(self.to_json(), file, indent=4) + + @classmethod + def _load(cls, language: str, name: str, cache_dir: str) -> Prompt: + logging.log(logging.INFO, f"Loading {name} from {cache_dir}") + path = os.path.join(cache_dir, language, f"{name}.json") + return cls(**json.load(open(path))["kwargs"]) + + +str_translation = Prompt( + name="str_translation", + instruction="Language translation", + examples=[ + { + "translate_to": "hindi", + "input": "Who was Albert Einstein and what is he best known for?", + "output": "अल्बर्ट आइंस्टीन कौन थे और वे किसके लिए सबसे ज्यादा प्रसिद्ध हैं?", + }, + ], + input_keys=["translate_to", "input"], + output_key="output", + output_type="str", +) + +json_translatation = Prompt( + name="json_translation", + instruction="Translate values in given json to target language ", + examples=[ + { + "translate_to": "hindi", + "input": """{"statements": [ + "Albert Einstein was born in Germany.", + "Albert Einstein was best known for his theory of relativity." 
+ ]}""", + "output": """{"statements": [ + "अल्बर्ट आइंस्टीन का जन्म जर्मनी में हुआ था।", + "अल्बर्ट आइंस्टीन अपने सापेक्षता के सिद्धांत के लिए सबसे अधिक प्रसिद्ध थे।" + ]}""", + } + ], + input_keys=["translate_to", "input"], + output_key="output", + output_type="JSON", +) diff --git a/src/ragas/metrics/_answer_correctness.py b/src/ragas/metrics/_answer_correctness.py index 8a90dcfb9..89f83005e 100644 --- a/src/ragas/metrics/_answer_correctness.py +++ b/src/ragas/metrics/_answer_correctness.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from dataclasses import dataclass, field @@ -7,15 +8,18 @@ from datasets import Dataset from langchain.callbacks.manager import CallbackManager, trace_as_chain_group -from ragas.utils import json_loader from ragas.llms.prompt import Prompt from ragas.metrics._answer_similarity import AnswerSimilarity from ragas.metrics.base import EvaluationMode, MetricWithLLM +from ragas.utils import json_loader + +logger = logging.getLogger(__name__) if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks CORRECTNESS_PROMPT = Prompt( + name="answer_correctness", instruction="""Extract following from given question and ground truth""", examples=[ { @@ -71,13 +75,16 @@ class AnswerCorrectness(MetricWithLLM): name: str = "answer_correctness" # type: ignore[reportIncompatibleMethodOverride] evaluation_mode: EvaluationMode = EvaluationMode.qga # type: ignore[reportIncompatibleMethodOverride] + correctness_prompt: Prompt = field(default_factory=lambda: CORRECTNESS_PROMPT) batch_size: int = 15 weights: list[float] = field(default_factory=lambda: [0.75, 0.25]) answer_similarity: AnswerSimilarity | None = None def __post_init__(self: t.Self): if len(self.weights) != 2: - raise ValueError("Expects a list of two weights. First for factuality, second for semantic similarity") + raise ValueError( + "Expects a list of two weights. 
First for factuality, second for semantic similarity" + ) if all([w == 0 for w in self.weights]): raise ValueError("At least one weight must be non-zero") if not all([w >= 0 for w in self.weights]): @@ -88,6 +95,15 @@ def __post_init__(self: t.Self): llm=self.llm, batch_size=self.batch_size ) + def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None: + logger.info(f"Adapting AnswerCorrectness metric to {language}") + self.correctness_prompt = self.correctness_prompt.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: t.Optional[str] = None) -> None: + self.correctness_prompt.save(cache_dir) + def _score_batch( self: t.Self, dataset: Dataset, @@ -107,7 +123,9 @@ def _score_batch( ) as batch_group: for q, a, g in zip(question, answer, ground_truths): prompts.append( - CORRECTNESS_PROMPT.format(question=q, ground_truth=g[0], answer=a) + self.correctness_prompt.format( + question=q, ground_truth=g[0], answer=a + ) ) result = self.llm.generate(prompts, callbacks=batch_group) @@ -121,7 +139,9 @@ def _score_batch( f1_score = [] for prediction in outputs: prediction = json_loader.safe_load(prediction[0].text, self.llm) - prediction = prediction if isinstance(prediction, list) else [] + prediction = ( + prediction if isinstance(prediction, list) else [prediction] + ) if prediction: prediction = [ item.get(key_map[k], np.nan) diff --git a/src/ragas/metrics/_answer_relevance.py b/src/ragas/metrics/_answer_relevance.py index 390e94e69..6b736aabd 100644 --- a/src/ragas/metrics/_answer_relevance.py +++ b/src/ragas/metrics/_answer_relevance.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from dataclasses import dataclass, field @@ -10,9 +11,11 @@ from ragas.embeddings.base import embedding_factory from ragas.exceptions import OpenAIKeyNotFound -from ragas.utils import json_loader from ragas.llms.prompt import Prompt from ragas.metrics.base import EvaluationMode, MetricWithLLM +from ragas.utils import json_loader + +logger = logging.getLogger(__name__) if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks @@ -20,6 +23,7 @@ from ragas.embeddings.base import RagasEmbeddings QUESTION_GEN = Prompt( + name="question_generation", instruction="""Generate a question for the given answer and Identify if answer is noncommittal""", examples=[ { @@ -72,6 +76,7 @@ class AnswerRelevancy(MetricWithLLM): name: str = "answer_relevancy" # type: ignore evaluation_mode: EvaluationMode = EvaluationMode.qac # type: ignore + question_generation: Prompt = field(default_factory=lambda: QUESTION_GEN) batch_size: int = 15 strictness: int = 3 embeddings: RagasEmbeddings = field(default_factory=embedding_factory) @@ -83,6 +88,15 @@ def init_model(self): if self.embeddings.openai_api_key == "no-key": raise OpenAIKeyNotFound + def adapt(self, language: str, cache_dir: str | None = None) -> None: + logger.info(f"Adapting AnswerRelevancy metric to {language}") + self.question_generation = self.question_generation.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: str | None = None) -> None: + self.question_generation.save(cache_dir) + def _score_batch( self: t.Self, dataset: Dataset, @@ -101,7 +115,9 @@ def _score_batch( ) as batch_group: prompts = [] for ans, ctx in zip(answers, contexts): - prompts.append(QUESTION_GEN.format(answer=ans, context="\n".join(ctx))) + prompts.append( + self.question_generation.format(answer=ans, context="\n".join(ctx)) + ) results = self.llm.generate( prompts, diff --git 
a/src/ragas/metrics/_answer_similarity.py b/src/ragas/metrics/_answer_similarity.py index df42554e4..350c44c70 100644 --- a/src/ragas/metrics/_answer_similarity.py +++ b/src/ragas/metrics/_answer_similarity.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from dataclasses import dataclass, field @@ -19,6 +20,8 @@ from ragas.embeddings.base import RagasEmbeddings +logger = logging.getLogger(__name__) + @dataclass class AnswerSimilarity(MetricWithLLM): diff --git a/src/ragas/metrics/_context_precision.py b/src/ragas/metrics/_context_precision.py index 782a75f6d..a8f40b814 100644 --- a/src/ragas/metrics/_context_precision.py +++ b/src/ragas/metrics/_context_precision.py @@ -2,20 +2,21 @@ import logging import typing as t -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np from datasets import Dataset from langchain.callbacks.manager import CallbackManager, trace_as_chain_group -from ragas.utils import json_loader from ragas.llms.prompt import Prompt from ragas.metrics.base import EvaluationMode, MetricWithLLM +from ragas.utils import json_loader if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks CONTEXT_PRECISION = Prompt( + name="context_precision", instruction="""Given question, answer and context verify if the context was useful in arriving at the given answer. Give verdict as "1" if useful and "0" if not. """, examples=[ { @@ -70,8 +71,18 @@ class ContextPrecision(MetricWithLLM): name: str = "context_precision" # type: ignore evaluation_mode: EvaluationMode = EvaluationMode.qcg # type: ignore + context_precision_prompt: Prompt = field(default_factory=lambda: CONTEXT_PRECISION) batch_size: int = 15 + def adapt(self, language: str, cache_dir: str | None = None) -> None: + logging.info(f"Adapting Context Precision to {language}") + self.context_precision_prompt = self.context_precision_prompt.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: str | None = None) -> None: + self.context_precision_prompt.save(cache_dir) + def get_dataset_attributes(self, dataset: Dataset): answer = "ground_truths" if answer not in dataset.features.keys(): @@ -97,7 +108,9 @@ def _score_batch( ) as batch_group: for qstn, ctx, answer in zip(questions, contexts, answers): human_prompts = [ - CONTEXT_PRECISION.format(question=qstn, context=c, answer=answer) + self.context_precision_prompt.format( + question=qstn, context=c, answer=answer + ) for c in ctx ] diff --git a/src/ragas/metrics/_context_recall.py b/src/ragas/metrics/_context_recall.py index 68fff571e..faabce302 100644 --- a/src/ragas/metrics/_context_recall.py +++ b/src/ragas/metrics/_context_recall.py @@ -1,20 +1,24 @@ from __future__ import annotations +import logging import typing as t -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np from datasets import Dataset from langchain.callbacks.manager import CallbackManager, trace_as_chain_group -from ragas.utils import json_loader from ragas.llms.prompt import Prompt from ragas.metrics.base import EvaluationMode, MetricWithLLM +from ragas.utils import json_loader if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks +logger = logging.getLogger(__name__) + CONTEXT_RECALL_RA = Prompt( + name="context_recall", instruction="""Given a context, and an answer, analyze each sentence in the answer and classify if the sentence can be attributed to the given context or not. Use only "Yes" (1) or "No" (0) as a binary classification. 
Output json with reason.""", examples=[ { @@ -79,8 +83,18 @@ class ContextRecall(MetricWithLLM): name: str = "context_recall" # type: ignore evaluation_mode: EvaluationMode = EvaluationMode.qcg # type: ignore + context_recall_prompt: Prompt = field(default_factory=lambda: CONTEXT_RECALL_RA) batch_size: int = 15 + def adapt(self, language: str, cache_dir: str | None = None) -> None: + logger.info(f"Adapting Context Recall to {language}") + self.context_recall_prompt = self.context_recall_prompt.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: str | None = None) -> None: + self.context_recall_prompt.save(cache_dir) + def _score_batch( self: t.Self, dataset: Dataset, @@ -102,7 +116,9 @@ def _score_batch( gt = "\n".join(gt) if isinstance(gt, list) else gt ctx = "\n".join(ctx) if isinstance(ctx, list) else ctx prompts.append( - CONTEXT_RECALL_RA.format(question=qstn, context=ctx, answer=gt) + self.context_recall_prompt.format( + question=qstn, context=ctx, answer=gt + ) ) responses: list[list[str]] = [] diff --git a/src/ragas/metrics/_context_relevancy.py b/src/ragas/metrics/_context_relevancy.py index d4b716934..a43fa4454 100644 --- a/src/ragas/metrics/_context_relevancy.py +++ b/src/ragas/metrics/_context_relevancy.py @@ -2,7 +2,7 @@ import logging import typing as t -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List import numpy as np @@ -16,7 +16,10 @@ if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks +logger = logging.getLogger(__name__) + CONTEXT_RELEVANCE = Prompt( + name="context_relevancy", instruction="""Please extract relevant sentences from the provided context that is absolutely required answer the following question. If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the phrase "Insufficient Information". While extracting candidate sentences you're not allowed to make any changes to sentences from given context.""", input_keys=["question", "context"], output_key="candidate sentences", @@ -51,11 +54,18 @@ class ContextRelevancy(MetricWithLLM): name: str = "context_relevancy" # type: ignore evaluation_mode: EvaluationMode = EvaluationMode.qc # type: ignore + context_relevancy_prompt: Prompt = field(default_factory=lambda: CONTEXT_RELEVANCE) batch_size: int = 15 show_deprecation_warning: bool = False - def __post_init__(self: t.Self): - pass + def adapt(self, language: str, cache_dir: str | None = None) -> None: + logger.info(f"Adapting Context Relevancy to {language}") + self.context_relevancy_prompt = self.context_relevancy_prompt.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: str | None = None) -> None: + self.context_relevancy_prompt.save(cache_dir) def _score_batch( self: t.Self, @@ -64,7 +74,7 @@ def _score_batch( callback_group_name: str = "batch", ) -> list[float]: if self.show_deprecation_warning: - logging.warning( + logger.warning( "The 'context_relevancy' metric is going to be deprecated soon! Please use the 'context_precision' metric instead. It is a drop-in replacement just a simple search and replace should work." 
# noqa ) prompts = [] @@ -76,7 +86,9 @@ def _score_batch( ) as batch_group: for q, c in zip(questions, contexts): prompts.append( - CONTEXT_RELEVANCE.format(question=q, context="\n".join(c)) + self.context_relevancy_prompt.format( + question=q, context="\n".join(c) + ) ) responses: list[list[str]] = [] diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py index 3c49ba685..9df7ed897 100644 --- a/src/ragas/metrics/_faithfulness.py +++ b/src/ragas/metrics/_faithfulness.py @@ -1,21 +1,24 @@ from __future__ import annotations +import logging import typing as t -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np from langchain.callbacks.manager import CallbackManager, trace_as_chain_group -from ragas.utils import json_loader from ragas.llms.prompt import Prompt from ragas.metrics.base import EvaluationMode, MetricWithLLM +from ragas.utils import json_loader if t.TYPE_CHECKING: from datasets import Dataset from langchain.callbacks.base import Callbacks +logger = logging.getLogger(__name__) LONG_FORM_ANSWER_PROMPT = Prompt( + name="long_form_answer", instruction="Create one or more statements from each sentence in the given answer.", examples=[ { @@ -55,6 +58,7 @@ NLI_STATEMENTS_MESSAGE = Prompt( + name="nli_statements", instruction="Natural language inference. Use only 'Yes' (1), 'No' (0) and 'Null' (-1) as verdict.", examples=[ { @@ -118,8 +122,27 @@ class Faithfulness(MetricWithLLM): name: str = "faithfulness" # type: ignore evaluation_mode: EvaluationMode = EvaluationMode.qac # type: ignore + long_form_answer_prompt: Prompt = field( + default_factory=lambda: LONG_FORM_ANSWER_PROMPT + ) + nli_statements_message: Prompt = field( + default_factory=lambda: NLI_STATEMENTS_MESSAGE + ) batch_size: int = 15 + def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None: + logger.info(f"Adapting Faithfulness metric to {language}") + self.long_form_answer_prompt = self.long_form_answer_prompt.adapt( + language, self.llm, cache_dir + ) + self.nli_statements_message = self.nli_statements_message.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: t.Optional[str] = None) -> None: + self.long_form_answer_prompt.save(cache_dir) + self.nli_statements_message.save(cache_dir) + def _score_batch( self: t.Self, dataset: Dataset, @@ -157,7 +180,7 @@ def _score_batch( [f"statement_{i+1}: {st}" for i, st in enumerate(statements)] ) contexts_str: str = "\n".join(context) - human_prompt = NLI_STATEMENTS_MESSAGE.format( + human_prompt = self.nli_statements_message.format( context=contexts_str, statements=statements_str ) prompts.append(human_prompt) diff --git a/src/ragas/metrics/base.py b/src/ragas/metrics/base.py index 097058733..cfe7225c9 100644 --- a/src/ragas/metrics/base.py +++ b/src/ragas/metrics/base.py @@ -17,13 +17,11 @@ from tqdm import tqdm from ragas.embeddings.base import RagasEmbeddings -from ragas.llms import llm_factory +from ragas.llms import RagasLLM, llm_factory if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks - from ragas.llms import RagasLLM - def make_batches(total_size: int, batch_size: int) -> list[range]: """ @@ -64,6 +62,20 @@ def init_model(self): """ ... + # @abstractmethod + def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None: + """ + Adapt the metric to a different language. + """ + pass + + # @abstractmethod + def save(self, cache_dir: t.Optional[str] = None) -> None: + """ + Save the metric to a path. 
+        """
+        pass
+
     def score(
         self: t.Self,
         dataset: Dataset,
diff --git a/src/ragas/metrics/critique.py b/src/ragas/metrics/critique.py
index b9b97976c..a4b3db30b 100644
--- a/src/ragas/metrics/critique.py
+++ b/src/ragas/metrics/critique.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 import typing as t
 from collections import Counter
 from dataclasses import dataclass, field
@@ -8,18 +9,20 @@
 from datasets import Dataset
 from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
 
-from ragas.utils import json_loader
 from ragas.llms import llm_factory
 from ragas.llms.prompt import Prompt
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
+from ragas.utils import json_loader
 
 if t.TYPE_CHECKING:
     from langchain.callbacks.base import Callbacks
 
     from ragas.llms import RagasLLM
 
+logger = logging.getLogger(__name__)
 
 CRITIQUE_PROMPT = Prompt(
+    name="critique",
     instruction="Given a input and submission. Evaluate the submission only using the given criteria. Use only 'Yes' (1) and 'No' (0) as verdict.",
     examples=[
         {
@@ -63,6 +66,7 @@ class AspectCritique(MetricWithLLM):
 
     name: str = field(default="", repr=True)  # type: ignore
     evaluation_mode: EvaluationMode = EvaluationMode.qac  # type: ignore
+    critic_prompt: Prompt = field(default_factory=lambda: CRITIQUE_PROMPT)
     definition: str = field(default="", repr=True)
     strictness: int = field(default=1, repr=False)
     batch_size: int = field(default=15, repr=False)
@@ -82,6 +86,13 @@ def __post_init__(self: t.Self):
             self.strictness if self.strictness % 2 != 0 else self.strictness + 1
         )
 
+    def adapt(self, language: str, cache_dir: str | None = None) -> None:
+        logger.info(f"Adapting Critic to {language}")
+        self.critic_prompt = self.critic_prompt.adapt(language, self.llm, cache_dir)
+
+    def save(self, cache_dir: str | None = None) -> None:
+        self.critic_prompt.save(cache_dir)
+
     def prompt_format(
         self: t.Self,
         question: str,
diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py
index 096dd816a..f53206ba9 100644
--- a/src/ragas/testset/testset_generator.py
+++ b/src/ragas/testset/testset_generator.py
@@ -63,9 +63,7 @@
     "conditional": "_condition_question",
 }
 
-retry_errors = (
-    ValidationError,
-)
+retry_errors = (ValidationError,)
 
 DataRow = namedtuple(
     "DataRow",
@@ -416,9 +414,7 @@ def generate(
 
                 proposal = None
                 try:
-                    proposal = self._make_proposal(
-                        curr_node, neighbor_nodes, evolve_type
-                    )
+                    proposal = self._make_proposal(curr_node, neighbor_nodes, evolve_type)
                 except Exception as e:
                     err_cause = e.__cause__
                     if not isinstance(err_cause, retry_errors):
diff --git a/src/ragas/utils.py b/src/ragas/utils.py
index 944e89fd4..00cc57d5f 100644
--- a/src/ragas/utils.py
+++ b/src/ragas/utils.py
@@ -17,6 +17,12 @@
 # constant to tell us that there is no key passed to the llm/embeddings
 NO_KEY = "no-key"
 
+# Cache location
+DEFAULT_XDG_CACHE_HOME = "~/.cache"
+XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME)
+DEFAULT_RAGAS_CACHE_HOME = os.path.join(XDG_CACHE_HOME, "ragas")
+RAGAS_CACHE_HOME = os.path.expanduser(os.getenv("RAGAS_HOME", DEFAULT_RAGAS_CACHE_HOME))
+
 
 @lru_cache(maxsize=1)
 def get_debug_mode() -> bool:
@@ -39,6 +45,7 @@ def load_as_json(text):
         return {}
 
 
+# not migrating to Prompt format to avoid circular imports
 JSON_PROMPT = HumanMessagePromptTemplate.from_template(
     """
 
@@ -152,4 +159,4 @@ def _find_outermost_json(self, text):
         return -1, -1  # No valid JSON found
 
 
-json_loader = JsonLoader()
\ No newline at end of file
+json_loader = JsonLoader()
diff --git 
a/tests/unit/test_import.py b/tests/unit/test_import.py index 53b312db8..0df78a883 100644 --- a/tests/unit/test_import.py +++ b/tests/unit/test_import.py @@ -27,4 +27,4 @@ def test_import_module(): assert hasattr(ragas.metrics, metric) for metric in test_critique: - assert hasattr(ragas.metrics.critique, metric) \ No newline at end of file + assert hasattr(ragas.metrics.critique, metric) diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py index feae8fa94..4fcd1d487 100644 --- a/tests/unit/test_prompt.py +++ b/tests/unit/test_prompt.py @@ -1,39 +1,41 @@ from ragas.llms.prompt import Prompt TESTCASES = [ - { - "instruction" : 'Create one or more statements from each sentence in the given answer.', - "examples" : [ - { - "question":"Cadmium Chloride is slightly soluble in this chemical, it is also called what?", - "answer":"alcohol", - "statements in json":"""{ + { + "name": "test-prompt", + "instruction": "Create one or more statements from each sentence in the given answer.", + "examples": [ + { + "question": "Cadmium Chloride is slightly soluble in this chemical, it is also called what?", + "answer": "alcohol", + "statements in json": """{ "statements": [ "Cadmium Chloride is slightly soluble in alcohol." ] - }""" - }, - { - "question":"Were Hitler and Benito Mussolini of the same nationality?", - "answer":"Sorry, I can't provide answer to that question.", - "statements in json":"""{ + }""", + }, + { + "question": "Were Hitler and Benito Mussolini of the same nationality?", + "answer": "Sorry, I can't provide answer to that question.", + "statements in json": """{ "statements": [] - }""" - } - ], - "input_keys" : ["question", "answer"], - "output_key" : "statements in json", - }, - { - "instruction" : 'Natural language inference. Use only "Yes" (1) or "No" (0) as a binary verdict.', - "examples" : [ - { - "Context":"""John is a student at XYZ University. He is pursuing a degree in Computer Science. He is enrolled in several courses this semester, including Data Structures, Algorithms, and Database Management. John is a diligent student and spends a significant amount of time studying and completing assignments. He often stays late in the library to work on his projects. + }""", + }, + ], + "input_keys": ["question", "answer"], + "output_key": "statements in json", + }, + { + "name": "test-prompt", + "instruction": 'Natural language inference. Use only "Yes" (1) or "No" (0) as a binary verdict.', + "examples": [ + { + "Context": """John is a student at XYZ University. He is pursuing a degree in Computer Science. He is enrolled in several courses this semester, including Data Structures, Algorithms, and Database Management. John is a diligent student and spends a significant amount of time studying and completing assignments. He often stays late in the library to work on his projects. statement_1: John is majoring in Biology. statement_2: John is taking a course on Artificial Intelligence. statement_3: John is a dedicated student. statement_4: John has a part-time job.""", - "Answer":"""[ + "Answer": """[ { "statement_1": "John is majoring in Biology.", "reason": "John's major is explicitly mentioned as Computer Science. 
There is no information suggesting he is majoring in Biology.", @@ -54,31 +56,44 @@ "reason": "There is no information given in the context about John having a part-time job.", "verdict": "0" }] - """ - } - ], - "input_keys" : ["Context"], - "output_key" : "Answer", - "output_type" : "json" - }, - { - "instruction" : 'This is a test prompt without examples', - "input_keys" : ["Context"], - "output_key" : "Answer", - "output_type" : "json" - }, + """, + } + ], + "input_keys": ["Context"], + "output_key": "Answer", + "output_type": "json", + }, + { + "name": "test-prompt", + "instruction": "This is a test prompt without examples", + "input_keys": ["Context"], + "output_key": "Answer", + "output_type": "json", + }, ] -def test_prompt_object(): +def test_prompt_object(): for testcase in TESTCASES: prompt = Prompt(**testcase) assert prompt is not None, "Prompt object is not created" - assert prompt.instruction==testcase['instruction'], "instruction in object is not same as in the testcase" - assert prompt.input_keys==testcase['input_keys'], "input_keys in object is not same as in the testcase" - assert prompt.output_key==testcase['output_key'], "output_key in object is not same as in the testcase" - assert prompt.output_type==testcase.get('output_type', 'json'), "output_type in object is not same as in the testcase" - assert prompt.examples==testcase.get('examples', []), "examples should be empty if not provided" - if testcase.get('examples'): - assert isinstance(prompt.get_example_str(0), str), "get_example_str should return a string" \ No newline at end of file + assert ( + prompt.instruction == testcase["instruction"] + ), "instruction in object is not same as in the testcase" + assert ( + prompt.input_keys == testcase["input_keys"] + ), "input_keys in object is not same as in the testcase" + assert ( + prompt.output_key == testcase["output_key"] + ), "output_key in object is not same as in the testcase" + assert prompt.output_type == testcase.get( + "output_type", "json" + ), "output_type in object is not same as in the testcase" + assert prompt.examples == testcase.get( + "examples", [] + ), "examples should be empty if not provided" + if testcase.get("examples"): + assert isinstance( + prompt.get_example_str(0), str + ), "get_example_str should return a string"
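
Taken together, this diff gives every `Prompt` a `name`, an `adapt()` method that translates its examples to a target language with the metric's LLM, and a `save()`/`_load()` pair that caches the adapted prompt under `RAGAS_CACHE_HOME` (default `~/.cache/ragas/<language>/<name>.json`), with matching `adapt()`/`save()` hooks on the metrics. A minimal usage sketch follows; it assumes the pre-built `faithfulness` metric instance and `evaluate()` entry point that ragas already exports, a configured OpenAI key, and a hypothetical `hindi_dataset` with the standard ragas columns — it is an illustration of the API added here, not part of the patch.

# Illustrative sketch only -- exercises the adapt()/save() API introduced in this diff.
# `hindi_dataset` is a hypothetical datasets.Dataset with the usual ragas columns
# (question, contexts, answer, ground_truths); an OpenAI key must be configured.
from ragas import evaluate
from ragas.metrics import faithfulness

# Translate the metric's internal prompts to Hindi; on the first call this invokes
# the metric's LLM, and subsequent calls reuse the cache under ~/.cache/ragas/hindi/.
faithfulness.adapt(language="hindi")
faithfulness.save()  # persist the adapted prompts so later runs hit the cache

result = evaluate(hindi_dataset, metrics=[faithfulness])
print(result)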