7 changes: 2 additions & 5 deletions src/ragas/llms/base.py
@@ -8,11 +8,8 @@
 from langchain_community.llms import AzureOpenAI, OpenAI, VertexAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.outputs import LLMResult
-from tenacity import (
-    retry,
-    stop_after_attempt,  # for exponential backoff
-    wait_random_exponential,
-)
+from tenacity import stop_after_attempt  # for exponential backoff
+from tenacity import retry, wait_random_exponential
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
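For context, the imports kept here are the standard tenacity building blocks for exponential backoff. A minimal sketch of how they are typically combined; the decorated function is hypothetical and not part of this PR:

    from tenacity import retry, stop_after_attempt, wait_random_exponential

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
    def call_llm_api():
        # hypothetical LLM call that may hit rate limits; retried with
        # randomized exponential backoff, giving up after 6 attempts
        ...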
110 changes: 73 additions & 37 deletions src/ragas/testset/evolutions.py
@@ -15,6 +15,7 @@
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.testset.prompts import (
     compress_question_prompt,
+    conditional_question_prompt,
     multi_context_question_prompt,
     question_answer_prompt,
     reasoning_question_prompt,
@@ -56,6 +57,9 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
             doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes)
         )
 
+    def init_evolution(self):
+        ...
+
     async def aretry_evolve(
         self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True
     ) -> EvolutionOutput:
@@ -211,6 +215,55 @@ def init_evolution(self):
         assert self.node_filter is not None, "node filter cannot be None"
         self.evolution_filter = EvolutionFilter(self.node_filter.llm)
 
+    async def _acomplex_evolution(
+        self, current_tries: int, current_nodes: CurrentNodes, question_prompt: Prompt
+    ):
+        assert self.generator_llm is not None, "generator_llm cannot be None"
+        assert self.question_filter is not None, "question_filter cannot be None"
+        assert self.se is not None, "simple evolution cannot be None"
+
+        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
+        logger.debug(
+            "[%s] simple question generated: %s",
+            self.__class__.__name__,
+            simple_question,
+        )
+
+        result = await self.generator_llm.agenerate_text(
+            prompt=question_prompt.format(
+                question=simple_question, context=current_nodes.root_node.page_content
+            )
+        )
+        reasoning_question = result.generations[0][0].text.strip()
+
+        # compress the question
+        compressed_question = self._transform_question(
+            prompt=compress_question_prompt, question=reasoning_question
+        )
+        logger.debug(
+            "[%s] question compressed: %s",
+            self.__class__.__name__,
+            compressed_question,
+        )
+
+        if not await self.question_filter.afilter(compressed_question):
+            # retry
+            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
+            return await self.aretry_evolve(current_tries, current_nodes)
+
+        assert self.evolution_filter is not None, "evolution filter cannot be None"
+        if not await self.evolution_filter.afilter(
+            simple_question, compressed_question
+        ):
+            # retry
+            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
+            logger.debug(
+                "evolution_filter failed, retrying with %s", len(current_nodes.nodes)
+            )
+            return await self.aretry_evolve(current_tries, current_nodes)
+
+        return reasoning_question, current_nodes, "reasoning"
+
+
 @dataclass
 class MultiContextEvolution(ComplexEvolution):
@@ -269,51 +322,33 @@ def __hash__(self):
 
 @dataclass
 class ReasoningEvolution(ComplexEvolution):
+    reasoning_question_prompt: Prompt = field(
+        default_factory=lambda: reasoning_question_prompt
+    )
+
     async def _aevolve(
         self, current_tries: int, current_nodes: CurrentNodes
     ) -> EvolutionOutput:
-        assert self.generator_llm is not None, "generator_llm cannot be None"
-        assert self.question_filter is not None, "question_filter cannot be None"
-        assert self.se is not None, "simple evolution cannot be None"
-
-        simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
-        logger.debug(
-            "[ReasoningEvolution] simple question generated: %s", simple_question
-        )
-
-        result = await self.generator_llm.agenerate_text(
-            prompt=reasoning_question_prompt.format(
-                question=simple_question, context=current_nodes.root_node.page_content
-            )
-        )
-        reasoning_question = result.generations[0][0].text.strip()
-        #
-        # compress the question
-        compressed_question = self._transform_question(
-            prompt=compress_question_prompt, question=reasoning_question
-        )
-        logger.debug(
-            "[ReasoningEvolution] multicontext question compressed: %s",
-            reasoning_question,
-        )
-
-        if not await self.question_filter.afilter(compressed_question):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            return await self.aretry_evolve(current_tries, current_nodes)
-
-        assert self.evolution_filter is not None, "evolution filter cannot be None"
-        if not await self.evolution_filter.afilter(
-            simple_question, compressed_question
-        ):
-            # retry
-            current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
-            logger.debug(
-                "evolution_filter failed, retrying with %s", len(current_nodes.nodes)
-            )
-            return await self.aretry_evolve(current_tries, current_nodes)
-
-        return reasoning_question, current_nodes, "reasoning"
+        return await self._acomplex_evolution(
+            current_tries, current_nodes, self.reasoning_question_prompt
+        )
+
+    def __hash__(self):
+        return hash(self.__class__.__name__)
+
+
+@dataclass
+class ConditionalEvolution(ComplexEvolution):
+    conditional_question_prompt: Prompt = field(
+        default_factory=lambda: conditional_question_prompt
+    )
+
+    async def _aevolve(
+        self, current_tries: int, current_nodes: CurrentNodes
+    ) -> EvolutionOutput:
+        return await self._acomplex_evolution(
+            current_tries, current_nodes, self.conditional_question_prompt
+        )
+
+    def __hash__(self):
+        return hash(self.__class__.__name__)
@@ -322,3 +357,4 @@ def __hash__(self):
 simple = SimpleEvolution()
 multi_context = MultiContextEvolution()
 reasoning = ReasoningEvolution()
+conditional = ConditionalEvolution()
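The net effect of this refactor is a template-method pattern: `_acomplex_evolution` owns the shared generate, compress, filter, and retry loop, and each concrete evolution only supplies its prompt. A hedged sketch of adding another evolution type along these lines; `MyEvolution` is hypothetical, the `Prompt` import path is an assumption, and a real prompt is borrowed as a stand-in so the sketch is self-contained:

    from dataclasses import dataclass, field

    from ragas.llms.prompt import Prompt  # assumed import path for Prompt
    from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, EvolutionOutput
    # stand-in prompt for illustration; a real evolution would ship its own
    # prompt in ragas.testset.prompts
    from ragas.testset.prompts import reasoning_question_prompt as my_question_prompt


    @dataclass
    class MyEvolution(ComplexEvolution):  # hypothetical, not part of this PR
        my_question_prompt: Prompt = field(default_factory=lambda: my_question_prompt)

        async def _aevolve(
            self, current_tries: int, current_nodes: CurrentNodes
        ) -> EvolutionOutput:
            # the generate/compress/filter/retry machinery lives in the shared
            # helper; the subclass only chooses the prompt
            return await self._acomplex_evolution(
                current_tries, current_nodes, self.my_question_prompt
            )

        def __hash__(self):
            return hash(self.__class__.__name__)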
3 changes: 2 additions & 1 deletion src/ragas/testset/generator.py
@@ -156,9 +156,10 @@ def generate(
             evolution.node_filter = NodeFilter(llm=self.critic_llm)
 
             if isinstance(evolution, ComplexEvolution):
-                evolution.init_evolution()
                 if evolution.evolution_filter is None:
                     evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm)
 
+            evolution.init_evolution()
+
         if with_debugging_logs:
             from ragas.utils import patch_logger
4 changes: 2 additions & 2 deletions tests/benchmarks/benchmark_testsetgen.py
@@ -3,12 +3,12 @@
 
 from llama_index import download_loader
 
-from ragas.testset.evolutions import multi_context, reasoning, simple
+from ragas.testset.evolutions import conditional, multi_context, reasoning, simple
 from ragas.testset.generator import TestsetGenerator
 
 generator = TestsetGenerator.with_openai()
 
-distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1}
+distributions = {simple: 0.5, multi_context: 0.3, reasoning: 0.1, conditional: 0.1}
 
 
 def get_documents():
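For orientation, a hedged sketch of how a distribution like this is consumed. The weights cover the evolutions mixed into the generated testset and sum to 1.0; `generate_with_llamaindex_docs` and its parameters are an assumption based on the TestsetGenerator API of this era, not something this diff shows:

    from ragas.testset.evolutions import conditional, multi_context, reasoning, simple
    from ragas.testset.generator import TestsetGenerator

    generator = TestsetGenerator.with_openai()
    distributions = {simple: 0.5, multi_context: 0.3, reasoning: 0.1, conditional: 0.1}

    # assumed helper mirroring how this benchmark drives generation; check the
    # generator module for the exact signature
    testset = generator.generate_with_llamaindex_docs(
        documents=get_documents(),  # loader defined in this benchmark file
        test_size=50,
        distributions=distributions,
    )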
2 changes: 1 addition & 1 deletion tests/unit/llms/test_llm.py
@@ -10,7 +10,7 @@
 from ragas.llms.prompt import PromptValue
 
 
-class TestLLM(BaseRagasLLM):
+class FakeTestLLM(BaseRagasLLM):
     def llm(self):
         return self
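A note on why this rename helps, as a reading of the change rather than something the diff states: pytest's default collection rule (`python_classes = "Test*"`) treats any class named `Test*` as a test class, so the old helper was picked up and warned about during collection. A minimal illustration with hypothetical stand-in classes:

    class TestHelper:        # pytest tries to collect this as a test class and
        def __init__(self):  # warns because it defines a constructor
            ...

    class FakeTestHelper:    # name no longer matches "Test*", so pytest ignores it
        def __init__(self):
            ...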