diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index da43b7776..bd099523d 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -8,11 +8,8 @@
 from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult
-from tenacity import (
-    retry,
-    stop_after_attempt,  # for exponential backoff
-    wait_random_exponential,
-)
+from tenacity import stop_after_attempt  # for exponential backoff
+from tenacity import retry, wait_random_exponential
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index 255a1b93a..2502977a9 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -15,6 +15,7 @@
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.testset.prompts import (
     compress_question_prompt,
+    conditional_question_prompt,
     multi_context_question_prompt,
     question_answer_prompt,
     reasoning_question_prompt,
@@ -56,6 +57,9 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
         doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes)
     )
 
+    def init_evolution(self):
+        ...
+
     async def aretry_evolve(
         self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
     ) -> EvolutionOutput:
@@ -211,6 +215,55 @@ def init_evolution(self):
         assert self.node_filter is not None, "node filter cannot be None"
         self.evolution_filter = EvolutionFilter(self.node_filter.llm)
 
+    async def _acomplex_evolution(
+        self, current_tries: int, current_nodes: CurrentNodes, question_prompt: Prompt
+    ) -> EvolutionOutput:
+        assert self.generator_llm is not None, "generator_llm cannot be None"
+        assert self.question_filter is not None, "question_filter cannot be None"
+        assert self.se is not None, "simple evolution cannot be None"
+
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
+        logger.debug(
+            "[%s] simple question generated: %s",
+            self.__class__.__name__,
+            simple_question,
+        )
+
+        result = await self.generator_llm.agenerate_text(
+            prompt=question_prompt.format(
+                question=simple_question, context=current_nodes.root_node.page_content
+            )
+        )
+        reasoning_question = result.generations[0][0].text.strip()
+
+        # compress the question
+        compressed_question = self._transform_question(
+            prompt=compress_question_prompt, question=reasoning_question
+        )
+        logger.debug(
+            "[%s] question compressed: %s",
+            self.__class__.__name__,
+            compressed_question,
+        )
+
+        if not await self.question_filter.afilter(compressed_question):
+            # retry with more adjacent nodes
+            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
+            return await self.aretry_evolve(current_tries, current_nodes)
+
+        assert self.evolution_filter is not None, "evolution filter cannot be None"
+        if not await self.evolution_filter.afilter(
+            simple_question, compressed_question
+        ):
+            # retry with more adjacent nodes
+            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
+            logger.debug(
+                "evolution_filter failed, retrying with %s", len(current_nodes.nodes)
+            )
+            return await self.aretry_evolve(current_tries, current_nodes)
+
+        return reasoning_question, current_nodes, "reasoning"
+
 
 @dataclass
 class MultiContextEvolution(ComplexEvolution):
@@ -269,51 +322,33 @@ def __hash__(self):
 
 @dataclass
 class ReasoningEvolution(ComplexEvolution):
+    reasoning_question_prompt: Prompt = field(
+        default_factory=lambda: reasoning_question_prompt
+    )
+
     async def _aevolve(
         self, current_tries: int, current_nodes: CurrentNodes
     ) -> EvolutionOutput:
-        assert self.generator_llm is not None, "generator_llm cannot be None"
-        assert self.question_filter is not None, "question_filter cannot be None"
-        assert self.se is not None, "simple evolution cannot be None"
-
-        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
-        logger.debug(
-            "[ReasoningEvolution] simple question generated: %s", simple_question
+        return await self._acomplex_evolution(
+            current_tries, current_nodes, self.reasoning_question_prompt
         )
-        result = await self.generator_llm.agenerate_text(
-            prompt=reasoning_question_prompt.format(
-                question=simple_question, context=current_nodes.root_node.page_content
-            )
-        )
-        reasoning_question = result.generations[0][0].text.strip()
-        #
-        # compress the question
-        compressed_question = self._transform_question(
-            prompt=compress_question_prompt, question=reasoning_question
-        )
-        logger.debug(
-            "[ReasoningEvolution] multicontext question compressed: %s",
-            reasoning_question,
-        )
+
+    def __hash__(self):
+        return hash(self.__class__.__name__)
-        if not await self.question_filter.afilter(compressed_question):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            return await self.aretry_evolve(current_tries, current_nodes)
-
-        assert self.evolution_filter is not None, "evolution filter cannot be None"
-        if not await self.evolution_filter.afilter(
-            simple_question, compressed_question
-        ):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            logger.debug(
-                "evolution_filter failed, retrying with %s", len(current_nodes.nodes)
-            )
-            return await self.aretry_evolve(current_tries, current_nodes)
+
+@dataclass
+class ConditionalEvolution(ComplexEvolution):
+    conditional_question_prompt: Prompt = field(
+        default_factory=lambda: conditional_question_prompt
+    )
-        return reasoning_question, current_nodes, "reasoning"
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
+        return await self._acomplex_evolution(
+            current_tries, current_nodes, self.conditional_question_prompt
+        )
 
     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -322,3 +357,4 @@ def __hash__(self):
 simple = SimpleEvolution()
 multi_context = MultiContextEvolution()
 reasoning = ReasoningEvolution()
+conditional = ConditionalEvolution()
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index 425d5f349..d9f449f40 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -156,9 +156,10 @@ def generate(
                 evolution.node_filter = NodeFilter(llm=self.critic_llm)
 
             if isinstance(evolution, ComplexEvolution):
-                evolution.init_evolution()
                 if evolution.evolution_filter is None:
                     evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm)
+
+            evolution.init_evolution()
 
         if with_debugging_logs:
             from ragas.utils import patch_logger
diff --git a/tests/benchmarks/benchmark_testsetgen.py b/tests/benchmarks/benchmark_testsetgen.py
index 282e2d75f..4aa802e79 100644
--- a/tests/benchmarks/benchmark_testsetgen.py
+++ b/tests/benchmarks/benchmark_testsetgen.py
@@ -3,12 +3,12 @@
 
 from llama_index import download_loader
 
-from ragas.testset.evolutions import multi_context, reasoning, simple
+from ragas.testset.evolutions import conditional, multi_context, reasoning, simple
 from ragas.testset.generator import TestsetGenerator
 
 generator = TestsetGenerator.with_openai()
 
-distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1}
+distributions = {simple: 0.5, multi_context: 0.3, reasoning: 0.1, conditional: 0.1}
 
 
 def get_documents():
diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index c83ce48f9..ed50a8d31 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -10,7 +10,7 @@
 from ragas.llms.prompt import PromptValue
 
 
-class TestLLM(BaseRagasLLM):
+class FakeTestLLM(BaseRagasLLM):
     def llm(self):
         return self
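
A minimal usage sketch to go with the diff: the new module-level `conditional` evolution slots into a distribution dict exactly like `simple`, `multi_context`, and `reasoning`, with the weights summing to 1.0. The imports, `TestsetGenerator.with_openai()`, and the weights below come straight from the benchmark change above; the `generate(test_size=..., distributions=...)` call shape is an assumption based on the `generate` hunk, not something the diff itself confirms, and document loading is elided.

```python
from ragas.testset.evolutions import conditional, multi_context, reasoning, simple
from ragas.testset.generator import TestsetGenerator

# Same setup as tests/benchmarks/benchmark_testsetgen.py.
generator = TestsetGenerator.with_openai()

# Weights sum to 1.0; `conditional` takes the 0.1 share that
# `multi_context` gave up (0.4 -> 0.3) in the benchmark script.
distributions = {simple: 0.5, multi_context: 0.3, reasoning: 0.1, conditional: 0.1}

# Assumed call shape: the diff only shows that `generate` now calls
# init_evolution() on every evolution (ConditionalEvolution included)
# before sampling questions.
testset = generator.generate(test_size=10, distributions=distributions)
```

Moving `init_evolution()` out of the `isinstance(evolution, ComplexEvolution)` branch works because the base class now defines it as a no-op stub (`...`), so `generate` can call it uniformly on every evolution type.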