From 50ae3ebcb2805f0fc3b84acd77ae9f2b9fb7c9a0 Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Fri, 26 Jan 2024 17:22:23 -0800
Subject: [PATCH 1/5] rough outline

---
 src/ragas/testset/evolutions.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index 255a1b93a..14ba95055 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -211,6 +211,18 @@ def init_evolution(self):
         assert self.node_filter is not None, "node filter cannot be None"
         self.evolution_filter = EvolutionFilter(self.node_filter.llm)
 
+    def _acomplex_evolution(self, current_tries, current_nodes, question_prompt):
+        ...  # this is a copy of ReasoningEvolution._aevolve
+
+@dataclass
+class ConditionalEvolution(ComplexEvolution):
+    conditional_question: Prompt = field(default=question_answer_prompt)
+
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
+        return await self._acomplex_evolution(current_tries, current_nodes, self.conditional_question)
+
 
 @dataclass
 class MultiContextEvolution(ComplexEvolution):
@@ -287,7 +299,7 @@ async def _aevolve(
             )
         )
         reasoning_question = result.generations[0][0].text.strip()
-        #
+        # compress the question
         compressed_question = self._transform_question(
             prompt=compress_question_prompt, question=reasoning_question
         )
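Note (PATCH 1/5): the outline above sketches a template-method refactor: the shared evolve pipeline lives once in ComplexEvolution._acomplex_evolution, parameterized by the question prompt, so each subclass only contributes its prompt. A minimal standalone sketch of that shape, with purely illustrative names (this is not the patch's final API):

    from dataclasses import dataclass

    class ComplexEvolutionSketch:
        async def _acomplex_evolution(self, tries, nodes, question_prompt):
            # shared pipeline: generate a seed question, rewrite it with
            # question_prompt, compress it, then run the quality filters
            ...

    @dataclass
    class ConditionalEvolutionSketch(ComplexEvolutionSketch):
        async def _aevolve(self, tries, nodes):
            # the subclass contributes only its prompt
            return await self._acomplex_evolution(tries, nodes, "conditional prompt")

PATCH 2/5 below fills in this outline.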
From 58101fa21eb53442af7bc05bf6154e596db0697f Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Fri, 26 Jan 2024 19:47:47 -0800
Subject: [PATCH 2/5] added conditional

---
 src/ragas/testset/evolutions.py          | 112 +++++++++++++----------
 src/ragas/testset/generator.py           |   3 +-
 tests/benchmarks/benchmark_testsetgen.py |   4 +-
 3 files changed, 70 insertions(+), 49 deletions(-)

diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index 14ba95055..f680a7d2f 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -19,6 +19,7 @@
     question_answer_prompt,
     reasoning_question_prompt,
     seed_question_prompt,
+    conditional_question_prompt,
 )
 
 rng = default_rng()
@@ -56,6 +57,9 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
             doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes)
         )
 
+    def init_evolution(self):
+        ...
+
     async def aretry_evolve(
         self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
     ) -> EvolutionOutput:
@@ -211,17 +215,54 @@ def init_evolution(self):
         assert self.node_filter is not None, "node filter cannot be None"
         self.evolution_filter = EvolutionFilter(self.node_filter.llm)
 
-    def _acomplex_evolution(self, current_tries, current_nodes, question_prompt):
-        ...  # this is a copy of ReasoningEvolution._aevolve
-
-@dataclass
-class ConditionalEvolution(ComplexEvolution):
-    conditional_question: Prompt = field(default=question_answer_prompt)
-
-    async def _aevolve(
-        self, current_tries: int, current_nodes: CurrentNodes
-    ) -> EvolutionOutput:
-        return await self._acomplex_evolution(current_tries, current_nodes, self.conditional_question)
+    async def _acomplex_evolution(
+        self, current_tries: int, current_nodes: CurrentNodes, question_prompt: Prompt
+    ):
+        assert self.generator_llm is not None, "generator_llm cannot be None"
+        assert self.question_filter is not None, "question_filter cannot be None"
+        assert self.se is not None, "simple evolution cannot be None"
+
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
+        logger.debug(
+            "[%s] simple question generated: %s",
+            self.__class__.__name__,
+            simple_question,
+        )
+
+        result = await self.generator_llm.agenerate_text(
+            prompt=question_prompt.format(
+                question=simple_question, context=current_nodes.root_node.page_content
+            )
+        )
+        reasoning_question = result.generations[0][0].text.strip()
+
+        # compress the question
+        compressed_question = self._transform_question(
+            prompt=compress_question_prompt, question=reasoning_question
+        )
+        logger.debug(
+            "[%s] multicontext question compressed: %s",
+            self.__class__.__name__,
+            reasoning_question,
+        )
+
+        if not await self.question_filter.afilter(compressed_question):
+            # retry
+            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
+            return await self.aretry_evolve(current_tries, current_nodes)
+
+        assert self.evolution_filter is not None, "evolution filter cannot be None"
+        if not await self.evolution_filter.afilter(
+            simple_question, compressed_question
+        ):
+            # retry
+            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
+            logger.debug(
+                "evolution_filter failed, retrying with %s", len(current_nodes.nodes)
+            )
+            return await self.aretry_evolve(current_tries, current_nodes)
+
+        return reasoning_question, current_nodes, "reasoning"
 
 
 @dataclass
@@ -281,51 +322,29 @@ def __hash__(self):
 
 @dataclass
 class ReasoningEvolution(ComplexEvolution):
+    reasoning_question_prompt: Prompt = field(default=reasoning_question_prompt)
+
     async def _aevolve(
         self, current_tries: int, current_nodes: CurrentNodes
     ) -> EvolutionOutput:
-        assert self.generator_llm is not None, "generator_llm cannot be None"
-        assert self.question_filter is not None, "question_filter cannot be None"
-        assert self.se is not None, "simple evolution cannot be None"
-
-        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
-        logger.debug(
-            "[ReasoningEvolution] simple question generated: %s", simple_question
+        return await self._acomplex_evolution(
+            current_tries, current_nodes, self.reasoning_question_prompt
         )
 
-        result = await self.generator_llm.agenerate_text(
-            prompt=reasoning_question_prompt.format(
-                question=simple_question, context=current_nodes.root_node.page_content
-            )
-        )
-        reasoning_question = result.generations[0][0].text.strip()
-
-        # compress the question
-        compressed_question = self._transform_question(
-            prompt=compress_question_prompt, question=reasoning_question
-        )
-        logger.debug(
-            "[ReasoningEvolution] multicontext question compressed: %s",
-            reasoning_question,
-        )
+    def __hash__(self):
+        return hash(self.__class__.__name__)
 
-        if not await self.question_filter.afilter(compressed_question):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            return await self.aretry_evolve(current_tries, current_nodes)
 
-        assert self.evolution_filter is not None, "evolution filter cannot be None"
-        if not await self.evolution_filter.afilter(
-            simple_question, compressed_question
-        ):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            logger.debug(
-                "evolution_filter failed, retrying with %s", len(current_nodes.nodes)
-            )
-            return await self.aretry_evolve(current_tries, current_nodes)
+@dataclass
+class ConditionalEvolution(ComplexEvolution):
+    conditional_question_prompt: Prompt = field(default=conditional_question_prompt)
 
-        return reasoning_question, current_nodes, "reasoning"
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
+        return await self._acomplex_evolution(
+            current_tries, current_nodes, self.conditional_question_prompt
+        )
 
     def __hash__(self):
         return hash(self.__class__.__name__)
@@ -334,3 +353,4 @@ def __hash__(self):
 simple = SimpleEvolution()
 multi_context = MultiContextEvolution()
 reasoning = ReasoningEvolution()
+conditional = ConditionalEvolution()
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index 425d5f349..d9f449f40 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -156,9 +156,10 @@ def generate(
             evolution.node_filter = NodeFilter(llm=self.critic_llm)
 
             if isinstance(evolution, ComplexEvolution):
-                evolution.init_evolution()
                 if evolution.evolution_filter is None:
                     evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm)
+
+            evolution.init_evolution()
 
         if with_debugging_logs:
             from ragas.utils import patch_logger
diff --git a/tests/benchmarks/benchmark_testsetgen.py b/tests/benchmarks/benchmark_testsetgen.py
index 282e2d75f..e0fd7ee02 100644
--- a/tests/benchmarks/benchmark_testsetgen.py
+++ b/tests/benchmarks/benchmark_testsetgen.py
@@ -3,12 +3,12 @@
 
 from llama_index import download_loader
 
-from ragas.testset.evolutions import multi_context, reasoning, simple
+from ragas.testset.evolutions import multi_context, reasoning, simple, conditional
 from ragas.testset.generator import TestsetGenerator
 
 generator = TestsetGenerator.with_openai()
 
-distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1}
+distributions = {simple: 0.5, multi_context: 0.3, reasoning: 0.1, conditional: 0.1}
 
 
 def get_documents():
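Note (PATCH 2/5): with the refactor landed, ConditionalEvolution is exposed as a module-level `conditional` instance and can be mixed into a test-set distribution exactly like the existing evolutions. A usage sketch built only from names visible in this patch (the surrounding generate call is omitted here, since its full signature is not shown in the diff):

    from ragas.testset.evolutions import conditional, multi_context, reasoning, simple
    from ragas.testset.generator import TestsetGenerator

    generator = TestsetGenerator.with_openai()

    # evolution weights are fractions of the generated test set; they should sum to 1
    distributions = {simple: 0.5, multi_context: 0.3, reasoning: 0.1, conditional: 0.1}
    assert abs(sum(distributions.values()) - 1.0) < 1e-9

Also worth noting in generator.py: init_evolution() now runs after the evolution_filter default is assigned, and for every evolution rather than only ComplexEvolution subclasses, which works because the base Evolution class gained a no-op init_evolution() in this patch.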
From 0ad840d909c5cbd993ca135604d30ef78a8ba8ae Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Fri, 26 Jan 2024 19:48:51 -0800
Subject: [PATCH 3/5] fmt

---
 src/ragas/llms/base.py                   | 7 ++-----
 src/ragas/testset/evolutions.py          | 2 +-
 tests/benchmarks/benchmark_testsetgen.py | 2 +-
 3 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index da43b7776..bd099523d 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -8,11 +8,8 @@
 from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult
-from tenacity import (
-    retry,
-    stop_after_attempt,  # for exponential backoff
-    wait_random_exponential,
-)
+from tenacity import stop_after_attempt  # for exponential backoff
+from tenacity import retry, wait_random_exponential
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index f680a7d2f..0ce6b22ef 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -15,11 +15,11 @@
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.testset.prompts import (
     compress_question_prompt,
+    conditional_question_prompt,
     multi_context_question_prompt,
     question_answer_prompt,
     reasoning_question_prompt,
     seed_question_prompt,
-    conditional_question_prompt,
 )
 
 rng = default_rng()
diff --git a/tests/benchmarks/benchmark_testsetgen.py b/tests/benchmarks/benchmark_testsetgen.py
index e0fd7ee02..4aa802e79 100644
--- a/tests/benchmarks/benchmark_testsetgen.py
+++ b/tests/benchmarks/benchmark_testsetgen.py
@@ -3,7 +3,7 @@
 
 from llama_index import download_loader
 
-from ragas.testset.evolutions import multi_context, reasoning, simple, conditional
+from ragas.testset.evolutions import conditional, multi_context, reasoning, simple
 from ragas.testset.generator import TestsetGenerator
 
 generator = TestsetGenerator.with_openai()
From 3ea5ba1b64f2d36adaa6bda3a2f1e4de4b34a616 Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Fri, 26 Jan 2024 19:50:53 -0800
Subject: [PATCH 4/5] fix test warning

---
 tests/unit/llms/test_llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/llms/test_llm.py b/tests/unit/llms/test_llm.py
index c83ce48f9..ed50a8d31 100644
--- a/tests/unit/llms/test_llm.py
+++ b/tests/unit/llms/test_llm.py
@@ -10,7 +10,7 @@
 from ragas.llms.prompt import PromptValue
 
 
-class TestLLM(BaseRagasLLM):
+class FakeTestLLM(BaseRagasLLM):
     def llm(self):
         return self
 
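Note (PATCH 4/5): the rename from TestLLM to FakeTestLLM most likely silences pytest's collection warning. pytest tries to collect any class whose name matches Test* and emits a PytestCollectionWarning ("cannot collect test class ... because it has a __init__ constructor") when such a class defines __init__, as dataclass-based helpers like this one do. A minimal reproduction, independent of ragas:

    # test_collect_warning.py -- run with: pytest test_collect_warning.py
    class TestLLM:            # name matches pytest's Test* collection pattern,
        def __init__(self):  # so this __init__ triggers PytestCollectionWarning
            self.ready = True

    class FakeTestLLM:        # not collected: name no longer matches Test*
        def __init__(self):
            self.ready = True

    def test_fake_llm():
        assert FakeTestLLM().ready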
From 43c81305fb101a03a34d2eee2fac47050028237c Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Fri, 26 Jan 2024 19:54:26 -0800
Subject: [PATCH 5/5] fix default factory error

---
 src/ragas/testset/evolutions.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index 0ce6b22ef..2502977a9 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -322,7 +322,9 @@ def __hash__(self):
 
 @dataclass
 class ReasoningEvolution(ComplexEvolution):
-    reasoning_question_prompt: Prompt = field(default=reasoning_question_prompt)
+    reasoning_question_prompt: Prompt = field(
+        default_factory=lambda: reasoning_question_prompt
+    )
 
     async def _aevolve(
         self, current_tries: int, current_nodes: CurrentNodes
@@ -337,7 +339,9 @@ def __hash__(self):
 
 @dataclass
 class ConditionalEvolution(ComplexEvolution):
-    conditional_question_prompt: Prompt = field(default=conditional_question_prompt)
+    conditional_question_prompt: Prompt = field(
+        default_factory=lambda: conditional_question_prompt
+    )
 
     async def _aevolve(
         self, current_tries: int, current_nodes: CurrentNodes
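Note (PATCH 5/5): the "default factory error" is the standard dataclass guard against shared mutable defaults. field(default=...) raises ValueError at class-creation time whenever the default's type is unhashable, which Prompt appears to be (pydantic-style models define __eq__, and defining __eq__ without __hash__ sets __hash__ to None); field(default_factory=...) sidesteps the check by deferring to a callable. A minimal reproduction with a stand-in Prompt class:

    from dataclasses import dataclass, field

    class Prompt:
        # defining __eq__ sets __hash__ to None, so instances count as
        # "mutable" for the dataclass default check (like pydantic models)
        def __eq__(self, other):
            return isinstance(other, Prompt)

    reasoning_question_prompt = Prompt()

    try:
        @dataclass
        class Bad:
            prompt: Prompt = field(default=reasoning_question_prompt)
    except ValueError as err:
        # ValueError: mutable default <class 'Prompt'> for field prompt
        # is not allowed: use default_factory
        print(err)

    @dataclass
    class Good:  # the pattern this patch adopts
        prompt: Prompt = field(default_factory=lambda: reasoning_question_prompt)

Note that the lambda still returns the same shared Prompt instance: it satisfies the dataclass check rather than making a copy, which is presumably acceptable here because the prompt objects are treated as read-only at runtime.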