diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 88e879467..09e62eaa3 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -1,165 +1,315 @@ +from __future__ import annotations + import logging import typing as t -from abc import ABC, abstractmethod +from abc import abstractmethod +from collections import namedtuple from dataclasses import dataclass, field from fsspec.exceptions import asyncio from numpy.random import default_rng -from ragas.llms import BaseRagasLLM -from ragas.llms.json_load import load_as_json -from ragas.testset.docstore import Direction, Document, DocumentStore, Node +from ragas.testset.docstore import Direction, DocumentStore, Node +from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( - context_scoring_prompt, - filter_question_prompt, + compress_question_prompt, multi_context_question_prompt, + question_answer_prompt, + reasoning_question_prompt, seed_question_prompt, ) rng = default_rng() logger = logging.getLogger(__name__) - -@dataclass -class Filter(ABC): - ... - - -@dataclass -class NodeFilter(Filter): - llm: BaseRagasLLM - threshold: float = 7.5 - - def filter(self, node: Node) -> t.Dict: - return asyncio.get_event_loop().run_until_complete(self.afilter(node)) - - async def afilter(self, node: Node) -> t.Dict: - prompt = context_scoring_prompt.format(context=node.page_content) - results = await self.llm.agenerate_text(prompt=prompt) - output = results.generations[0][0].text.strip() - score = load_as_json(output) - score.update({"score": score.get("score", 0) >= self.threshold}) - return score +if t.TYPE_CHECKING: + from ragas.llms import BaseRagasLLM + from ragas.llms.prompt import Prompt @dataclass -class QuestionFilter(Filter): - llm: BaseRagasLLM +class CurrentNodes: + root_node: Node + nodes: t.List[Node] = field(default_factory=list) - def filter(self, question: str) -> bool: - return asyncio.get_event_loop().run_until_complete(self.afilter(question)) - async def afilter(self, question: str) -> bool: - prompt = filter_question_prompt.format(question=question) - results = await self.llm.agenerate_text(prompt=prompt) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("filtered question: %s", json_results) - return json_results.get("verdict") != "No" +DataRow = namedtuple( + "DataRow", + [ + "question", + "context", + "answer", + "question_type", + "evolution_elimination", + ], +) @dataclass class Evolution: - node_filter: NodeFilter - question_filter: QuestionFilter - nodes: t.List[Node] = field(default_factory=list) + generator_llm: t.Optional[BaseRagasLLM] = None + docstore: t.Optional[DocumentStore] = None + node_filter: t.Optional[NodeFilter] = None + question_filter: t.Optional[QuestionFilter] = None max_tries: int = 5 - _root_node: t.Optional[Node] = field(default=None, init=False, repr=False) - _tries: int = field(default=0, init=False, repr=False) - def merged_nodes(self) -> Node: + @staticmethod + def merge_nodes(nodes: CurrentNodes) -> Node: return Node( - doc_id="merged", page_content=" ".join(n.page_content for n in self.nodes) + doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes) ) async def aretry_evolve( - self, llm: BaseRagasLLM, docstore: DocumentStore, update_count: bool = True - ): + self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True + ) -> str: if update_count: - self._tries += 1 - logger.info("retrying evolution: %s times", self._tries) - if self._tries > self.max_tries: + current_tries += 1 + logger.info("retrying evolution: %s times", current_tries) + if current_tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self.aevolve(llm, docstore) - - @abstractmethod - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: - ... - - @abstractmethod - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: - ... + return await self._aevolve(current_tries, current_nodes) + def _transform_question(self, prompt: Prompt, question: str) -> str: + assert self.generator_llm is not None, "generator_llm cannot be None" -@dataclass -class SimpleEvolution(Evolution): - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): - logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + results = self.generator_llm.generate_text( + prompt=prompt.format(question=question) + ) + return results.generations[0][0].text.strip() - def _get_more_adjacent_nodes(self, docstore: DocumentStore): + def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes """ - assert self._root_node is not None, "root node cannot be None" + assert self.docstore is not None, "docstore cannot be None" + # get more nodes from above the context window - prev_adjacent_node = docstore.get_adjacent(self._root_node, Direction.PREV) + prev_adjacent_node = self.docstore.get_adjacent( + current_nodes.nodes[0], Direction.PREV + ) if prev_adjacent_node is None: # get more nodes from below the context window - next_adjacent_node = docstore.get_adjacent(self._root_node, Direction.NEXT) + next_adjacent_node = self.docstore.get_adjacent( + current_nodes.nodes[-1], Direction.NEXT + ) if next_adjacent_node is not None: # add next nodes towards the end - self.nodes.append(next_adjacent_node) + current_nodes.nodes.append(next_adjacent_node) else: # retry with new base node - self.nodes = docstore.get_random_nodes(k=1) - self._root_node = self.nodes[0] + nodes = self.docstore.get_random_nodes(k=1) + return CurrentNodes(root_node=nodes[0], nodes=nodes) else: # add prev nodes in index 0 - self.nodes.insert(0, prev_adjacent_node) - - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): - # can the node be used to frame a question? - if self._tries == 0: - self.nodes = docstore.get_random_nodes(k=1) - self._root_node = self.nodes[0] - merged_node = self.merged_nodes() - passed = await self.node_filter.afilter(self.nodes[0]) + current_nodes.nodes.insert(0, prev_adjacent_node) + + return current_nodes + + def evolve(self, current_nodes: CurrentNodes) -> DataRow: + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) + + async def aevolve(self, current_nodes: CurrentNodes) -> DataRow: + # init tries with 0 when first called + current_tries = 0 + evolved_question = await self._aevolve(current_tries, current_nodes) + return self.generate_datarow( + question=evolved_question, + current_nodes=current_nodes, + ) + + @abstractmethod + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + ... + + def generate_datarow( + self, + question: str, + current_nodes: CurrentNodes, + question_type: str = "", + evolution_elimination: bool = False, + ): + assert self.generator_llm is not None, "generator_llm cannot be None" + + merged_nodes = self.merge_nodes(current_nodes) + results = self.generator_llm.generate_text( + prompt=question_answer_prompt.format( + question=question, context=merged_nodes.page_content + ) + ) + answer = results.generations[0][0].text.strip() + logger.debug("answer generated: %s", answer) + + if answer == "-1": + answer = None + + return DataRow( + question=question, + context=merged_nodes.page_content, + answer=answer, + question_type=question_type, + evolution_elimination=evolution_elimination, + ) + + +@dataclass +class SimpleEvolution(Evolution): + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.docstore is not None, "docstore cannot be None" + assert self.node_filter is not None, "node filter cannot be None" + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + + merged_node = self.merge_nodes(current_nodes) + passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: - self.nodes = docstore.get_random_nodes(k=1) - return await self.aretry_evolve(llm, docstore, update_count=False) + nodes = self.docstore.get_random_nodes(k=1) + new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) + return await self.aretry_evolve( + current_tries, new_current_nodes, update_count=False + ) - # frame a basic question with with node - seed_question = await simple_evolution(llm, merged_node) + results = self.generator_llm.generate_text( + prompt=seed_question_prompt.format(context=merged_node.page_content) + ) + seed_question = results.generations[0][0].text # NOTE: might need improvement # select only one seed question here is_valid_question = await self.question_filter.afilter(seed_question) if not is_valid_question: # get more context to rewrite question - self._get_more_adjacent_nodes(docstore) + current_nodes = self._get_more_adjacent_nodes(current_nodes) # retry with new nodes added - return await self.aretry_evolve(llm, docstore) + return await self.aretry_evolve(current_tries, current_nodes) else: # if valid question return seed_question + def __hash__(self): + return hash(self.__class__.__name__) + + +@dataclass +class ComplexEvolution(Evolution): + se: t.Optional[SimpleEvolution] = field(default=None, repr=False) + evolution_filter: t.Optional[EvolutionFilter] = field(default=None, repr=False) + + def init_evolution(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + assert self.node_filter is not None, "node filter cannot be None" + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + +@dataclass +class MultiContextEvolution(ComplexEvolution): + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.docstore is not None, "docstore cannot be None" + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + assert self.se is not None, "simple evolution cannot be None" + + simple_question = await self.se._aevolve(current_tries, current_nodes) + logger.debug( + "[MultiContextEvolution] simple question generated: %s", simple_question + ) + + # find a similar node and generate a question based on both + similar_node = self.docstore.get_similar(current_nodes.root_node)[0] + prompt = multi_context_question_prompt.format( + question=simple_question, + context1=current_nodes.root_node.page_content, + context2=similar_node, + ) + results = await self.generator_llm.agenerate_text(prompt=prompt) + question = results.generations[0][0].text.strip() + logger.debug( + "[MultiContextEvolution] multicontext question generated: %s", question + ) + + # compress the question + compressed_question = self._transform_question( + prompt=compress_question_prompt, question=question + ) + logger.debug( + "[MultiContextEvolution] multicontext question compressed: %s", question + ) + + if not await self.question_filter.afilter(compressed_question): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) + + assert self.evolution_filter is not None, "evolution filter cannot be None" + if not await self.evolution_filter.afilter( + simple_question, compressed_question + ): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) + + return compressed_question + + def __hash__(self): + return hash(self.__class__.__name__) + + +@dataclass +class ReasoningEvolution(ComplexEvolution): + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + assert self.se is not None, "simple evolution cannot be None" + + simple_question = await self.se._aevolve(current_tries, current_nodes) + logger.debug( + "[ReasoningEvolution] simple question generated: %s", simple_question + ) + + result = await self.generator_llm.agenerate_text( + prompt=reasoning_question_prompt.format( + question=simple_question, context=current_nodes.root_node.page_content + ) + ) + reasoning_question = result.generations[0][0].text.strip() + # + # compress the question + compressed_question = self._transform_question( + prompt=compress_question_prompt, question=reasoning_question + ) + logger.debug( + "[ReasoningEvolution] multicontext question compressed: %s", + reasoning_question, + ) + + if not await self.question_filter.afilter(compressed_question): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) + + assert self.evolution_filter is not None, "evolution filter cannot be None" + if not await self.evolution_filter.afilter( + simple_question, compressed_question + ): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + logger.debug( + "evolution_filter failed, retrying with %s", len(current_nodes.nodes) + ) + return await self.aretry_evolve(current_tries, current_nodes) + + return reasoning_question + + def __hash__(self): + return hash(self.__class__.__name__) + -async def simple_evolution(llm: BaseRagasLLM, seed_doc: Document): - prompt = seed_question_prompt.format(context=seed_doc.page_content) - results = llm.generate_text(prompt=prompt) - results = results.generations[0][0].text - return results - - -async def multi_context_evolution( - llm: BaseRagasLLM, seed_node: Node, doc_store: DocumentStore -): - question = simple_evolution(llm, seed_node) - similar_context = doc_store.get_similar(seed_node)[0] - prompt = multi_context_question_prompt.format( - question=question, context1=seed_node.page_content, context2=similar_context - ) - results = await llm.agenerate_text(prompt=prompt) - question = results.generations[0][0].text.strip() - return question +simple = SimpleEvolution() +multi_context = MultiContextEvolution() +reasoning = ReasoningEvolution() diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py new file mode 100644 index 000000000..79d19bd47 --- /dev/null +++ b/src/ragas/testset/filters.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import asyncio +import logging +import typing as t +from abc import ABC +from dataclasses import dataclass + +from ragas.llms.json_load import load_as_json +from ragas.testset.prompts import ( + context_scoring_prompt, + evolution_elimination_prompt, + filter_question_prompt, +) + +if t.TYPE_CHECKING: + from ragas.llms.base import BaseRagasLLM + from ragas.testset.docstore import Node + + +logger = logging.getLogger(__name__) + + +@dataclass +class Filter(ABC): + ... + + +@dataclass +class NodeFilter(Filter): + llm: BaseRagasLLM + threshold: float = 7.5 + + def filter(self, node: Node) -> t.Dict: + return asyncio.get_event_loop().run_until_complete(self.afilter(node)) + + async def afilter(self, node: Node) -> t.Dict: + prompt = context_scoring_prompt.format(context=node.page_content) + results = await self.llm.agenerate_text(prompt=prompt) + output = results.generations[0][0].text.strip() + score = load_as_json(output) + score.update({"score": score.get("score", 0) >= self.threshold}) + return score + + +@dataclass +class QuestionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, question: str) -> bool: + return asyncio.get_event_loop().run_until_complete(self.afilter(question)) + + async def afilter(self, question: str) -> bool: + prompt = filter_question_prompt.format(question=question) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" + + +@dataclass +class EvolutionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, simple_question: str, compressed_question: str) -> bool: + return asyncio.get_event_loop().run_until_complete( + self.afilter(simple_question, compressed_question) + ) + + async def afilter(self, simple_question: str, compressed_question: str) -> bool: + prompt = evolution_elimination_prompt.format( + question1=simple_question, question2=compressed_question + ) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 94b957034..29714a9a9 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +import logging import typing as t from dataclasses import dataclass +import pandas as pd from langchain.embeddings import OpenAIEmbeddings from langchain_openai.chat_models import ChatOpenAI from llama_index.readers.schema import Document as LlamaindexDocument @@ -9,7 +13,36 @@ from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution +from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow +from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter + +logger = logging.getLogger(__name__) +Distributions = t.Dict[t.Any, float] + + +@dataclass +class TestDataset: + """ + TestDataset class + """ + + test_data: t.List[DataRow] + + def to_pandas(self) -> pd.DataFrame: + data_samples = [] + for data in self.test_data: + question_type = data.question_type + data = { + "question": data.question, + "context": data.context, + "answer": "" if data.answer is None else data.answer, + "question_type": question_type, + "episode_done": True, + "evolution_elimination": data.evolution_elimination, + } + data_samples.append(data) + + return pd.DataFrame.from_records(data_samples) @dataclass @@ -53,31 +86,59 @@ def with_openai( ) def generate_with_llamaindex_docs( - self, documents: t.Sequence[LlamaindexDocument], test_size: int + self, + documents: t.Sequence[LlamaindexDocument], + test_size: int, + distributions: Distributions = {}, + **kwargs, ): # chunk documents and add to docstore self.docstore.add_documents( [Document.from_llamaindex_document(doc) for doc in documents] ) - return self.generate(test_size=test_size) - - def generate(self, test_size: int): - node_filter = NodeFilter(self.critic_llm) - ques_filter = QuestionFilter(self.critic_llm) - exec = Executor() - qs = [] - for i in range(test_size): - se = SimpleEvolution(node_filter, ques_filter) - exec.submit( - se.aevolve, - self.generator_llm, - self.docstore, - name=f"SimpleEvolution-{i}", - ) + return self.generate(test_size=test_size, distributions=distributions) + + def generate( + self, test_size: int, distributions: Distributions = {}, show_debug_logs=False + ): + # init filters and evolutions + for evolution in distributions: + if evolution.generator_llm is None: + evolution.generator_llm = self.generator_llm + if evolution.docstore is None: + evolution.docstore = self.docstore + + if evolution.question_filter is None: + evolution.question_filter = QuestionFilter(llm=self.critic_llm) + if evolution.node_filter is None: + evolution.node_filter = NodeFilter(llm=self.critic_llm) + + if isinstance(evolution, ComplexEvolution): + evolution.init_evolution() + if evolution.evolution_filter is None: + evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm) + if show_debug_logs: + from ragas.utils import patch_logger + + patch_logger("ragas.testset.evolutions", logging.DEBUG) + + exec = Executor(desc="Generating", raise_exceptions=True, is_async=True) + + current_nodes = [ + CurrentNodes(root_node=n, nodes=[n]) + for n in self.docstore.get_random_nodes(k=test_size) + ] + for evolution, probability in distributions.items(): + for i in range(round(probability * test_size)): + exec.submit( + evolution.aevolve, + current_nodes[i], + name=f"{evolution.__class__.__name__}-{i}", + ) try: - qs = exec.results() + test_data_rows = exec.results() except ValueError as e: raise e - return qs + return TestDataset(test_data=test_data_rows) diff --git a/tests/benchmarks/benchmark_testsetgen.py b/tests/benchmarks/benchmark_testsetgen.py index e78e7eb48..282e2d75f 100644 --- a/tests/benchmarks/benchmark_testsetgen.py +++ b/tests/benchmarks/benchmark_testsetgen.py @@ -3,10 +3,13 @@ from llama_index import download_loader +from ragas.testset.evolutions import multi_context, reasoning, simple from ragas.testset.generator import TestsetGenerator generator = TestsetGenerator.with_openai() +distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1} + def get_documents(): SemanticScholarReader = download_loader("SemanticScholarReader") @@ -30,7 +33,9 @@ def get_documents(): os.environ["PYTHONASYNCIODEBUG"] = "1" print("Starting [Asyncio]") start = time.time() - generator.generate_with_llamaindex_docs(documents, test_size=100) + generator.generate_with_llamaindex_docs( + documents=documents, test_size=100, distributions=distributions + ) print(f"Time taken: {time.time() - start:.2f}s") # Threads