From 721a34cd4c9eae04fdbbf2ace3bba90c426825e2 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sat, 20 Jan 2024 12:23:14 -0800 Subject: [PATCH 01/10] cleanup simple evolution --- src/ragas/testset/evolutions.py | 17 ++++++++--------- src/ragas/testset/generator.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 88e879467..2e8f7c43a 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -130,8 +130,10 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): self.nodes = docstore.get_random_nodes(k=1) return await self.aretry_evolve(llm, docstore, update_count=False) - # frame a basic question with with node - seed_question = await simple_evolution(llm, merged_node) + results = llm.generate_text( + prompt=seed_question_prompt.format(context=merged_node.page_content) + ) + seed_question = results.generations[0][0].text # NOTE: might need improvement # select only one seed question here is_valid_question = await self.question_filter.afilter(seed_question) @@ -145,13 +147,6 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): return seed_question -async def simple_evolution(llm: BaseRagasLLM, seed_doc: Document): - prompt = seed_question_prompt.format(context=seed_doc.page_content) - results = llm.generate_text(prompt=prompt) - results = results.generations[0][0].text - return results - - async def multi_context_evolution( llm: BaseRagasLLM, seed_node: Node, doc_store: DocumentStore ): @@ -163,3 +158,7 @@ async def multi_context_evolution( results = await llm.agenerate_text(prompt=prompt) question = results.generations[0][0].text.strip() return question + + +class MultiContext(Evolution): + ... diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 94b957034..1ff77b5ce 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -6,10 +6,10 @@ from llama_index.readers.schema import Document as LlamaindexDocument from ragas.embeddings import BaseRagasEmbeddings -from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution +from ragas.executor import Executor +from ragas.testset.evolutions import SimpleEvolution, QuestionFilter, NodeFilter @dataclass From 6d583b265b3f382dade8b5c3877eabea6694166a Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sat, 20 Jan 2024 18:19:08 -0800 Subject: [PATCH 02/10] added multi_context question --- src/ragas/testset/evolutions.py | 63 ++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 2e8f7c43a..d49a5522d 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import typing as t from abc import ABC, abstractmethod @@ -14,11 +16,15 @@ filter_question_prompt, multi_context_question_prompt, seed_question_prompt, + compress_question_prompt, ) rng = default_rng() logger = logging.getLogger(__name__) +if t.TYPE_CHECKING: + from ragas.llms.prompt import Prompt + @dataclass class Filter(ABC): @@ -147,18 +153,49 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): return seed_question -async def multi_context_evolution( - llm: BaseRagasLLM, seed_node: Node, doc_store: DocumentStore -): - question = simple_evolution(llm, seed_node) - similar_context = doc_store.get_similar(seed_node)[0] - prompt = multi_context_question_prompt.format( - question=question, context1=seed_node.page_content, context2=similar_context - ) - results = await llm.agenerate_text(prompt=prompt) - question = results.generations[0][0].text.strip() - return question +@dataclass +class MultiContextEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution(self.node_filter, self.question_filter) -class MultiContext(Evolution): - ... + def _transform_question( + self, llm: BaseRagasLLM, prompt: Prompt, question: str + ) -> str: + results = llm.generate_text(prompt=prompt.format(question=question)) + return results.generations[0][0].text.strip() + + def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + logger.info("evolving question") + return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + + async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + # gerenate seed question + self._root_node = docstore.get_random_nodes(k=1)[0] + question = await self.se.aevolve(llm, docstore) + logger.debug("[MultiContextEvolution] simple question generated: %s", question) + + # find a similar node and generate a question based on both + similar_context = docstore.get_similar(self._root_node)[0] + prompt = multi_context_question_prompt.format( + question=question, + context1=self._root_node.page_content, + context2=similar_context, + ) + results = await llm.agenerate_text(prompt=prompt) + question = results.generations[0][0].text.strip() + logger.debug( + "[MultiContextEvolution] multicontext question generated: %s", question + ) + + # compress the question + compressed_question = self._transform_question( + llm=llm, prompt=compress_question_prompt, question=question + ) + logger.debug( + "[MultiContextEvolution] multicontext question compressed: %s", question + ) + + return compressed_question From 7d07f4f64ac5c68c82c97b45fa484efe9925055f Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sat, 20 Jan 2024 20:43:24 -0800 Subject: [PATCH 03/10] resoning_question --- src/ragas/testset/evolutions.py | 53 +++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index d49a5522d..cff96ba79 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -17,6 +17,7 @@ multi_context_question_prompt, seed_question_prompt, compress_question_prompt, + reasoning_question_prompt, ) rng = default_rng() @@ -174,15 +175,17 @@ def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: # gerenate seed question self._root_node = docstore.get_random_nodes(k=1)[0] - question = await self.se.aevolve(llm, docstore) - logger.debug("[MultiContextEvolution] simple question generated: %s", question) + simple_question = await self.se.aevolve(llm, docstore) + logger.debug( + "[MultiContextEvolution] simple question generated: %s", simple_question + ) # find a similar node and generate a question based on both - similar_context = docstore.get_similar(self._root_node)[0] + similar_node = docstore.get_similar(self._root_node)[0] prompt = multi_context_question_prompt.format( - question=question, + question=simple_question, context1=self._root_node.page_content, - context2=similar_context, + context2=similar_node, ) results = await llm.agenerate_text(prompt=prompt) question = results.generations[0][0].text.strip() @@ -198,4 +201,44 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: "[MultiContextEvolution] multicontext question compressed: %s", question ) + if not await self.question_filter.afilter(compressed_question): + # retry + ... + + # if not await self.evolution_elimation.afilter( + # simple_question, compressed_question + # ): + # ... + return compressed_question + + +@dataclass +class ReasoningEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution(self.node_filter, self.question_filter) + + def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + logger.debug("evolving question") + return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + + async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + self._root_node = docstore.get_random_nodes(k=1)[0] + simple_question = await self.se.aevolve(llm, docstore) + logger.debug( + "[MultiContextEvolution] simple question generated: %s", simple_question + ) + + result = await llm.agenerate_text( + prompt=reasoning_question_prompt.format( + question=simple_question, context=self._root_node + ) + ) + reasoning_question = result.generations[0][0].text.strip() + if not await self.question_filter.afilter(reasoning_question): + # retry + ... + return reasoning_question From 77640338f4873f919e9d06e487fc2c448f421657 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 12:50:18 -0800 Subject: [PATCH 04/10] refactored evolutions --- src/ragas/testset/evolutions.py | 211 ++++++++++++++++++++------------ 1 file changed, 135 insertions(+), 76 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index cff96ba79..c63b4539d 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -18,6 +18,7 @@ seed_question_prompt, compress_question_prompt, reasoning_question_prompt, + evolution_elimination_prompt, ) rng = default_rng() @@ -65,22 +66,49 @@ async def afilter(self, question: str) -> bool: return json_results.get("verdict") != "No" +@dataclass +class EvolutionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, simple_question: str, compressed_question: str) -> bool: + return asyncio.get_event_loop().run_until_complete( + self.afilter(simple_question, compressed_question) + ) + + async def afilter(self, simple_question: str, compressed_question: str) -> bool: + prompt = evolution_elimination_prompt.format( + question1=simple_question, question2=compressed_question + ) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" + + +@dataclass +class CurrentNodes: + root_node: Node + nodes: t.List[Node] = field(default_factory=list) + + @dataclass class Evolution: + generator_llm: BaseRagasLLM + docstore: DocumentStore node_filter: NodeFilter question_filter: QuestionFilter - nodes: t.List[Node] = field(default_factory=list) max_tries: int = 5 - _root_node: t.Optional[Node] = field(default=None, init=False, repr=False) _tries: int = field(default=0, init=False, repr=False) - def merged_nodes(self) -> Node: + @staticmethod + def merge_nodes(nodes: CurrentNodes) -> Node: return Node( - doc_id="merged", page_content=" ".join(n.page_content for n in self.nodes) + doc_id="merged", page_content=" ".join(n.page_content for n in nodes.nodes) ) async def aretry_evolve( - self, llm: BaseRagasLLM, docstore: DocumentStore, update_count: bool = True + self, current_nodes: CurrentNodes, update_count: bool = True ): if update_count: self._tries += 1 @@ -88,56 +116,81 @@ async def aretry_evolve( if self._tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self.aevolve(llm, docstore) + return await self.aevolve(current_nodes) + + def _transform_question(self, prompt: Prompt, question: str) -> str: + results = self.generator_llm.generate_text( + prompt=prompt.format(question=question) + ) + return results.generations[0][0].text.strip() @abstractmethod - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + def evolve(self, current_nodes: CurrentNodes) -> str: ... @abstractmethod - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: + async def aevolve(self, current_nodes: CurrentNodes) -> str: ... +@dataclass +class ComplexEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + evolution_filter: EvolutionFilter = field(init=False, repr=False) + + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + @dataclass class SimpleEvolution(Evolution): - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): + def evolve(self, current_nodes: CurrentNodes) -> str: logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - def _get_more_adjacent_nodes(self, docstore: DocumentStore): + def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes """ - assert self._root_node is not None, "root node cannot be None" # get more nodes from above the context window - prev_adjacent_node = docstore.get_adjacent(self._root_node, Direction.PREV) + prev_adjacent_node = self.docstore.get_adjacent( + current_nodes.nodes[0], Direction.PREV + ) if prev_adjacent_node is None: # get more nodes from below the context window - next_adjacent_node = docstore.get_adjacent(self._root_node, Direction.NEXT) + next_adjacent_node = self.docstore.get_adjacent( + current_nodes.nodes[-1], Direction.NEXT + ) if next_adjacent_node is not None: # add next nodes towards the end - self.nodes.append(next_adjacent_node) + current_nodes.nodes.append(next_adjacent_node) else: # retry with new base node - self.nodes = docstore.get_random_nodes(k=1) - self._root_node = self.nodes[0] + nodes = self.docstore.get_random_nodes(k=1) + return CurrentNodes(root_node=nodes[0], nodes=nodes) else: # add prev nodes in index 0 - self.nodes.insert(0, prev_adjacent_node) - - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): - # can the node be used to frame a question? - if self._tries == 0: - self.nodes = docstore.get_random_nodes(k=1) - self._root_node = self.nodes[0] - merged_node = self.merged_nodes() - passed = await self.node_filter.afilter(self.nodes[0]) + current_nodes.nodes.insert(0, prev_adjacent_node) + + return current_nodes + + async def aevolve(self, current_nodes: CurrentNodes) -> str: + merged_node = self.merge_nodes(current_nodes) + passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: - self.nodes = docstore.get_random_nodes(k=1) - return await self.aretry_evolve(llm, docstore, update_count=False) + nodes = self.docstore.get_random_nodes(k=1) + new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) + return await self.aretry_evolve(new_current_nodes, update_count=False) - results = llm.generate_text( + results = self.generator_llm.generate_text( prompt=seed_question_prompt.format(context=merged_node.page_content) ) seed_question = results.generations[0][0].text @@ -146,48 +199,34 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore): is_valid_question = await self.question_filter.afilter(seed_question) if not is_valid_question: # get more context to rewrite question - self._get_more_adjacent_nodes(docstore) + current_nodes = self._get_more_adjacent_nodes(current_nodes) # retry with new nodes added - return await self.aretry_evolve(llm, docstore) + return await self.aretry_evolve(current_nodes) else: # if valid question return seed_question @dataclass -class MultiContextEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution(self.node_filter, self.question_filter) - - def _transform_question( - self, llm: BaseRagasLLM, prompt: Prompt, question: str - ) -> str: - results = llm.generate_text(prompt=prompt.format(question=question)) - return results.generations[0][0].text.strip() - - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): +class MultiContextEvolution(ComplexEvolution): + def evolve(self, current_nodes: CurrentNodes) -> str: logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: - # gerenate seed question - self._root_node = docstore.get_random_nodes(k=1)[0] - simple_question = await self.se.aevolve(llm, docstore) + async def aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se.aevolve(current_nodes) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) # find a similar node and generate a question based on both - similar_node = docstore.get_similar(self._root_node)[0] + similar_node = self.docstore.get_similar(current_nodes.root_node)[0] prompt = multi_context_question_prompt.format( question=simple_question, - context1=self._root_node.page_content, + context1=current_nodes.root_node.page_content, context2=similar_node, ) - results = await llm.agenerate_text(prompt=prompt) + results = await self.generator_llm.agenerate_text(prompt=prompt) question = results.generations[0][0].text.strip() logger.debug( "[MultiContextEvolution] multicontext question generated: %s", question @@ -195,7 +234,7 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: # compress the question compressed_question = self._transform_question( - llm=llm, prompt=compress_question_prompt, question=question + prompt=compress_question_prompt, question=question ) logger.debug( "[MultiContextEvolution] multicontext question compressed: %s", question @@ -203,42 +242,62 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: if not await self.question_filter.afilter(compressed_question): # retry - ... + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_nodes) - # if not await self.evolution_elimation.afilter( - # simple_question, compressed_question - # ): - # ... + assert self.evolution_filter is not None, "evolution filter cannot be None" + if not await self.evolution_filter.afilter( + simple_question, compressed_question + ): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_nodes) return compressed_question @dataclass -class ReasoningEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution(self.node_filter, self.question_filter) - - def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore): +class ReasoningEvolution(ComplexEvolution): + def evolve(self, current_nodes: CurrentNodes) -> str: logger.debug("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore)) + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str: - self._root_node = docstore.get_random_nodes(k=1)[0] - simple_question = await self.se.aevolve(llm, docstore) + async def aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se.aevolve(current_nodes) logger.debug( - "[MultiContextEvolution] simple question generated: %s", simple_question + "[ReasoningEvolution] simple question generated: %s", simple_question ) - result = await llm.agenerate_text( + result = await self.generator_llm.agenerate_text( prompt=reasoning_question_prompt.format( - question=simple_question, context=self._root_node + question=simple_question, context=current_nodes.root_node.page_content ) ) reasoning_question = result.generations[0][0].text.strip() - if not await self.question_filter.afilter(reasoning_question): + # + # compress the question + compressed_question = self._transform_question( + prompt=compress_question_prompt, question=reasoning_question + ) + logger.debug( + "[ReasoningEvolution] multicontext question compressed: %s", + reasoning_question, + ) + + if not await self.question_filter.afilter(compressed_question): # retry - ... + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_nodes) + + assert self.evolution_filter is not None, "evolution filter cannot be None" + if not await self.evolution_filter.afilter( + simple_question, compressed_question + ): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + logger.debug( + "evolution_filter failed, retrying with %s", len(current_nodes.nodes) + ) + return await self.aretry_evolve(current_nodes) + return reasoning_question From 3127ad1c70d2a29f111059ae6457c0359b960b97 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 14:53:27 -0800 Subject: [PATCH 05/10] return DataRow now --- src/ragas/testset/evolutions.py | 132 ++++++++++++++++++++------------ 1 file changed, 84 insertions(+), 48 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index c63b4539d..dfe891b90 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -4,6 +4,7 @@ import typing as t from abc import ABC, abstractmethod from dataclasses import dataclass, field +from collections import namedtuple from fsspec.exceptions import asyncio from numpy.random import default_rng @@ -19,6 +20,7 @@ compress_question_prompt, reasoning_question_prompt, evolution_elimination_prompt, + question_answer_prompt, ) rng = default_rng() @@ -92,6 +94,18 @@ class CurrentNodes: nodes: t.List[Node] = field(default_factory=list) +DataRow = namedtuple( + "DataRow", + [ + "question", + "context", + "answer", + "question_type", + "evolution_elimination", + ], +) + + @dataclass class Evolution: generator_llm: BaseRagasLLM @@ -104,19 +118,19 @@ class Evolution: @staticmethod def merge_nodes(nodes: CurrentNodes) -> Node: return Node( - doc_id="merged", page_content=" ".join(n.page_content for n in nodes.nodes) + doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes) ) async def aretry_evolve( self, current_nodes: CurrentNodes, update_count: bool = True - ): + ) -> str: if update_count: self._tries += 1 logger.info("retrying evolution: %s times", self._tries) if self._tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self.aevolve(current_nodes) + return await self._aevolve(current_nodes) def _transform_question(self, prompt: Prompt, question: str) -> str: results = self.generator_llm.generate_text( @@ -124,38 +138,6 @@ def _transform_question(self, prompt: Prompt, question: str) -> str: ) return results.generations[0][0].text.strip() - @abstractmethod - def evolve(self, current_nodes: CurrentNodes) -> str: - ... - - @abstractmethod - async def aevolve(self, current_nodes: CurrentNodes) -> str: - ... - - -@dataclass -class ComplexEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - evolution_filter: EvolutionFilter = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution( - generator_llm=self.generator_llm, - docstore=self.docstore, - node_filter=self.node_filter, - question_filter=self.question_filter, - ) - # init evolution filter with critic llm from another filter - self.evolution_filter = EvolutionFilter(self.node_filter.llm) - - -@dataclass -class SimpleEvolution(Evolution): - def evolve(self, current_nodes: CurrentNodes) -> str: - logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes @@ -182,7 +164,69 @@ def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): return current_nodes - async def aevolve(self, current_nodes: CurrentNodes) -> str: + def evolve(self, current_nodes: CurrentNodes) -> DataRow: + return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) + + async def aevolve(self, current_nodes: CurrentNodes) -> DataRow: + evolved_question = await self._aevolve(current_nodes) + return self.generate_datarow( + question=evolved_question, + current_nodes=current_nodes, + ) + + @abstractmethod + async def _aevolve(self, current_nodes: CurrentNodes) -> str: + ... + + def generate_datarow( + self, + question: str, + current_nodes: CurrentNodes, + question_type: str = "", + evolution_elimination: bool = False, + ): + merged_nodes = self.merge_nodes(current_nodes) + results = self.generator_llm.generate_text( + prompt=question_answer_prompt.format( + question=question, context=merged_nodes.page_content + ) + ) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("answer generated: %s", json_results) + + # TODO: what do if answer is -1? + answer = json_results.get("answer") + + return DataRow( + question=question, + context=merged_nodes.page_content, + answer=answer, + question_type=question_type, + evolution_elimination=evolution_elimination, + ) + + +@dataclass +class ComplexEvolution(Evolution): + se: SimpleEvolution = field(init=False, repr=False) + evolution_filter: EvolutionFilter = field(init=False, repr=False) + + def __post_init__(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + +@dataclass +class SimpleEvolution(Evolution): + async def _aevolve(self, current_nodes: CurrentNodes) -> str: merged_node = self.merge_nodes(current_nodes) passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: @@ -209,12 +253,8 @@ async def aevolve(self, current_nodes: CurrentNodes) -> str: @dataclass class MultiContextEvolution(ComplexEvolution): - def evolve(self, current_nodes: CurrentNodes) -> str: - logger.info("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - - async def aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se.aevolve(current_nodes) + async def _aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se._aevolve(current_nodes) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) @@ -258,12 +298,8 @@ async def aevolve(self, current_nodes: CurrentNodes) -> str: @dataclass class ReasoningEvolution(ComplexEvolution): - def evolve(self, current_nodes: CurrentNodes) -> str: - logger.debug("evolving question") - return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) - - async def aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se.aevolve(current_nodes) + async def _aevolve(self, current_nodes: CurrentNodes) -> str: + simple_question = await self.se._aevolve(current_nodes) logger.debug( "[ReasoningEvolution] simple question generated: %s", simple_question ) From 4529ebdb6d3f776b6acff961aad0f267ff9fffc4 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 15:04:22 -0800 Subject: [PATCH 06/10] moved out filter --- src/ragas/testset/evolutions.py | 64 +------------------------- src/ragas/testset/filters.py | 79 +++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 62 deletions(-) create mode 100644 src/ragas/testset/filters.py diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index dfe891b90..9e626216d 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -9,83 +9,23 @@ from fsspec.exceptions import asyncio from numpy.random import default_rng -from ragas.llms import BaseRagasLLM from ragas.llms.json_load import load_as_json from ragas.testset.docstore import Direction, Document, DocumentStore, Node from ragas.testset.prompts import ( - context_scoring_prompt, - filter_question_prompt, multi_context_question_prompt, seed_question_prompt, compress_question_prompt, reasoning_question_prompt, - evolution_elimination_prompt, question_answer_prompt, ) +from ragas.testset.filters import NodeFilter, QuestionFilter, EvolutionFilter rng = default_rng() logger = logging.getLogger(__name__) if t.TYPE_CHECKING: from ragas.llms.prompt import Prompt - - -@dataclass -class Filter(ABC): - ... - - -@dataclass -class NodeFilter(Filter): - llm: BaseRagasLLM - threshold: float = 7.5 - - def filter(self, node: Node) -> t.Dict: - return asyncio.get_event_loop().run_until_complete(self.afilter(node)) - - async def afilter(self, node: Node) -> t.Dict: - prompt = context_scoring_prompt.format(context=node.page_content) - results = await self.llm.agenerate_text(prompt=prompt) - output = results.generations[0][0].text.strip() - score = load_as_json(output) - score.update({"score": score.get("score", 0) >= self.threshold}) - return score - - -@dataclass -class QuestionFilter(Filter): - llm: BaseRagasLLM - - def filter(self, question: str) -> bool: - return asyncio.get_event_loop().run_until_complete(self.afilter(question)) - - async def afilter(self, question: str) -> bool: - prompt = filter_question_prompt.format(question=question) - results = await self.llm.agenerate_text(prompt=prompt) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("filtered question: %s", json_results) - return json_results.get("verdict") != "No" - - -@dataclass -class EvolutionFilter(Filter): - llm: BaseRagasLLM - - def filter(self, simple_question: str, compressed_question: str) -> bool: - return asyncio.get_event_loop().run_until_complete( - self.afilter(simple_question, compressed_question) - ) - - async def afilter(self, simple_question: str, compressed_question: str) -> bool: - prompt = evolution_elimination_prompt.format( - question1=simple_question, question2=compressed_question - ) - results = await self.llm.agenerate_text(prompt=prompt) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("filtered question: %s", json_results) - return json_results.get("verdict") != "No" + from ragas.llms import BaseRagasLLM @dataclass diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py new file mode 100644 index 000000000..e74e0cef5 --- /dev/null +++ b/src/ragas/testset/filters.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from dataclasses import dataclass +import logging +from abc import ABC +import asyncio +import typing as t + +from ragas.testset.prompts import ( + context_scoring_prompt, + filter_question_prompt, + evolution_elimination_prompt, +) +from ragas.llms.json_load import load_as_json + +if t.TYPE_CHECKING: + from ragas.llms.base import BaseRagasLLM + from ragas.testset.docstore import Node + + +logger = logging.getLogger(__name__) + + +@dataclass +class Filter(ABC): + ... + + +@dataclass +class NodeFilter(Filter): + llm: BaseRagasLLM + threshold: float = 7.5 + + def filter(self, node: Node) -> t.Dict: + return asyncio.get_event_loop().run_until_complete(self.afilter(node)) + + async def afilter(self, node: Node) -> t.Dict: + prompt = context_scoring_prompt.format(context=node.page_content) + results = await self.llm.agenerate_text(prompt=prompt) + output = results.generations[0][0].text.strip() + score = load_as_json(output) + score.update({"score": score.get("score", 0) >= self.threshold}) + return score + + +@dataclass +class QuestionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, question: str) -> bool: + return asyncio.get_event_loop().run_until_complete(self.afilter(question)) + + async def afilter(self, question: str) -> bool: + prompt = filter_question_prompt.format(question=question) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" + + +@dataclass +class EvolutionFilter(Filter): + llm: BaseRagasLLM + + def filter(self, simple_question: str, compressed_question: str) -> bool: + return asyncio.get_event_loop().run_until_complete( + self.afilter(simple_question, compressed_question) + ) + + async def afilter(self, simple_question: str, compressed_question: str) -> bool: + prompt = evolution_elimination_prompt.format( + question1=simple_question, question2=compressed_question + ) + results = await self.llm.agenerate_text(prompt=prompt) + results = results.generations[0][0].text.strip() + json_results = load_as_json(results) + logger.debug("filtered question: %s", json_results) + return json_results.get("verdict") != "No" From 9ce7eedb8ed1703022b22339c8cf30c7bc8a1ad9 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Sun, 21 Jan 2024 15:05:16 -0800 Subject: [PATCH 07/10] fix fmt --- src/ragas/testset/evolutions.py | 16 ++++++++-------- src/ragas/testset/filters.py | 10 +++++----- src/ragas/testset/generator.py | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 9e626216d..693a6cccf 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -2,30 +2,30 @@ import logging import typing as t -from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from abc import abstractmethod from collections import namedtuple +from dataclasses import dataclass, field from fsspec.exceptions import asyncio from numpy.random import default_rng from ragas.llms.json_load import load_as_json -from ragas.testset.docstore import Direction, Document, DocumentStore, Node +from ragas.testset.docstore import Direction, DocumentStore, Node +from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( - multi_context_question_prompt, - seed_question_prompt, compress_question_prompt, - reasoning_question_prompt, + multi_context_question_prompt, question_answer_prompt, + reasoning_question_prompt, + seed_question_prompt, ) -from ragas.testset.filters import NodeFilter, QuestionFilter, EvolutionFilter rng = default_rng() logger = logging.getLogger(__name__) if t.TYPE_CHECKING: - from ragas.llms.prompt import Prompt from ragas.llms import BaseRagasLLM + from ragas.llms.prompt import Prompt @dataclass diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py index e74e0cef5..79d19bd47 100644 --- a/src/ragas/testset/filters.py +++ b/src/ragas/testset/filters.py @@ -1,17 +1,17 @@ from __future__ import annotations -from dataclasses import dataclass -import logging -from abc import ABC import asyncio +import logging import typing as t +from abc import ABC +from dataclasses import dataclass +from ragas.llms.json_load import load_as_json from ragas.testset.prompts import ( context_scoring_prompt, - filter_question_prompt, evolution_elimination_prompt, + filter_question_prompt, ) -from ragas.llms.json_load import load_as_json if t.TYPE_CHECKING: from ragas.llms.base import BaseRagasLLM diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 1ff77b5ce..94b957034 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -6,10 +6,10 @@ from llama_index.readers.schema import Document as LlamaindexDocument from ragas.embeddings import BaseRagasEmbeddings +from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.executor import Executor -from ragas.testset.evolutions import SimpleEvolution, QuestionFilter, NodeFilter +from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution @dataclass From 9a82bc0a7381f6070432e655a3186b979cf32eb8 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Mon, 22 Jan 2024 16:53:36 -0800 Subject: [PATCH 08/10] basics working --- src/ragas/testset/evolutions.py | 123 +++++++++++++++++++------------- src/ragas/testset/generator.py | 90 ++++++++++++++++++----- 2 files changed, 146 insertions(+), 67 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 693a6cccf..28e93d2e2 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -9,7 +9,6 @@ from fsspec.exceptions import asyncio from numpy.random import default_rng -from ragas.llms.json_load import load_as_json from ragas.testset.docstore import Direction, DocumentStore, Node from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.testset.prompts import ( @@ -48,12 +47,11 @@ class CurrentNodes: @dataclass class Evolution: - generator_llm: BaseRagasLLM - docstore: DocumentStore - node_filter: NodeFilter - question_filter: QuestionFilter + generator_llm: t.Optional[BaseRagasLLM] = None + docstore: t.Optional[DocumentStore] = None + node_filter: t.Optional[NodeFilter] = None + question_filter: t.Optional[QuestionFilter] = None max_tries: int = 5 - _tries: int = field(default=0, init=False, repr=False) @staticmethod def merge_nodes(nodes: CurrentNodes) -> Node: @@ -62,17 +60,19 @@ def merge_nodes(nodes: CurrentNodes) -> Node: ) async def aretry_evolve( - self, current_nodes: CurrentNodes, update_count: bool = True + self, current_tries: int, current_nodes: CurrentNodes, update_count: bool = True ) -> str: if update_count: - self._tries += 1 - logger.info("retrying evolution: %s times", self._tries) - if self._tries > self.max_tries: + current_tries += 1 + logger.info("retrying evolution: %s times", current_tries) + if current_tries > self.max_tries: # TODO: make this into a custom exception raise ValueError("Max tries reached") - return await self._aevolve(current_nodes) + return await self._aevolve(current_tries, current_nodes) def _transform_question(self, prompt: Prompt, question: str) -> str: + assert self.generator_llm is not None, "generator_llm cannot be None" + results = self.generator_llm.generate_text( prompt=prompt.format(question=question) ) @@ -82,6 +82,8 @@ def _get_more_adjacent_nodes(self, current_nodes: CurrentNodes): """ if the evolutions doesn't have enough nodes to frame a question, get more nodes """ + assert self.docstore is not None, "docstore cannot be None" + # get more nodes from above the context window prev_adjacent_node = self.docstore.get_adjacent( current_nodes.nodes[0], Direction.PREV @@ -108,14 +110,16 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow: return asyncio.get_event_loop().run_until_complete(self.aevolve(current_nodes)) async def aevolve(self, current_nodes: CurrentNodes) -> DataRow: - evolved_question = await self._aevolve(current_nodes) + # init tries with 0 when first called + current_tries = 0 + evolved_question = await self._aevolve(current_tries, current_nodes) return self.generate_datarow( question=evolved_question, current_nodes=current_nodes, ) @abstractmethod - async def _aevolve(self, current_nodes: CurrentNodes) -> str: + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: ... def generate_datarow( @@ -125,18 +129,19 @@ def generate_datarow( question_type: str = "", evolution_elimination: bool = False, ): + assert self.generator_llm is not None, "generator_llm cannot be None" + merged_nodes = self.merge_nodes(current_nodes) results = self.generator_llm.generate_text( prompt=question_answer_prompt.format( question=question, context=merged_nodes.page_content ) ) - results = results.generations[0][0].text.strip() - json_results = load_as_json(results) - logger.debug("answer generated: %s", json_results) + answer = results.generations[0][0].text.strip() + logger.debug("answer generated: %s", answer) - # TODO: what do if answer is -1? - answer = json_results.get("answer") + if answer == "-1": + answer = None return DataRow( question=question, @@ -147,32 +152,22 @@ def generate_datarow( ) -@dataclass -class ComplexEvolution(Evolution): - se: SimpleEvolution = field(init=False, repr=False) - evolution_filter: EvolutionFilter = field(init=False, repr=False) - - def __post_init__(self): - # init simple evolution to get seed question - self.se = SimpleEvolution( - generator_llm=self.generator_llm, - docstore=self.docstore, - node_filter=self.node_filter, - question_filter=self.question_filter, - ) - # init evolution filter with critic llm from another filter - self.evolution_filter = EvolutionFilter(self.node_filter.llm) - - -@dataclass +@dataclass(unsafe_hash=True) class SimpleEvolution(Evolution): - async def _aevolve(self, current_nodes: CurrentNodes) -> str: + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.docstore is not None, "docstore cannot be None" + assert self.node_filter is not None, "node filter cannot be None" + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + merged_node = self.merge_nodes(current_nodes) passed = await self.node_filter.afilter(current_nodes.root_node) if not passed["score"]: nodes = self.docstore.get_random_nodes(k=1) new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) - return await self.aretry_evolve(new_current_nodes, update_count=False) + return await self.aretry_evolve( + current_tries, new_current_nodes, update_count=False + ) results = self.generator_llm.generate_text( prompt=seed_question_prompt.format(context=merged_node.page_content) @@ -185,16 +180,39 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: # get more context to rewrite question current_nodes = self._get_more_adjacent_nodes(current_nodes) # retry with new nodes added - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) else: # if valid question return seed_question @dataclass +class ComplexEvolution(Evolution): + se: t.Optional[SimpleEvolution] = field(default=None, repr=False) + evolution_filter: t.Optional[EvolutionFilter] = field(default=None, repr=False) + + def init_evolution(self): + # init simple evolution to get seed question + self.se = SimpleEvolution( + generator_llm=self.generator_llm, + docstore=self.docstore, + node_filter=self.node_filter, + question_filter=self.question_filter, + ) + # init evolution filter with critic llm from another filter + assert self.node_filter is not None, "node filter cannot be None" + self.evolution_filter = EvolutionFilter(self.node_filter.llm) + + +@dataclass(unsafe_hash=True) class MultiContextEvolution(ComplexEvolution): - async def _aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se._aevolve(current_nodes) + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.docstore is not None, "docstore cannot be None" + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + assert self.se is not None, "simple evolution cannot be None" + + simple_question = await self.se._aevolve(current_tries, current_nodes) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) @@ -223,7 +241,7 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: if not await self.question_filter.afilter(compressed_question): # retry current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) assert self.evolution_filter is not None, "evolution filter cannot be None" if not await self.evolution_filter.afilter( @@ -231,15 +249,19 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: ): # retry current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) return compressed_question -@dataclass +@dataclass(unsafe_hash=True) class ReasoningEvolution(ComplexEvolution): - async def _aevolve(self, current_nodes: CurrentNodes) -> str: - simple_question = await self.se._aevolve(current_nodes) + async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: + assert self.generator_llm is not None, "generator_llm cannot be None" + assert self.question_filter is not None, "question_filter cannot be None" + assert self.se is not None, "simple evolution cannot be None" + + simple_question = await self.se._aevolve(current_tries, current_nodes) logger.debug( "[ReasoningEvolution] simple question generated: %s", simple_question ) @@ -263,7 +285,7 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: if not await self.question_filter.afilter(compressed_question): # retry current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) assert self.evolution_filter is not None, "evolution filter cannot be None" if not await self.evolution_filter.afilter( @@ -274,6 +296,11 @@ async def _aevolve(self, current_nodes: CurrentNodes) -> str: logger.debug( "evolution_filter failed, retrying with %s", len(current_nodes.nodes) ) - return await self.aretry_evolve(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) return reasoning_question + + +simple = SimpleEvolution() +multi_context = MultiContextEvolution() +reasoning = ReasoningEvolution() diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 94b957034..b0a2f7628 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import typing as t from dataclasses import dataclass +import pandas as pd from langchain.embeddings import OpenAIEmbeddings from langchain_openai.chat_models import ChatOpenAI from llama_index.readers.schema import Document as LlamaindexDocument @@ -9,7 +12,35 @@ from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import NodeFilter, QuestionFilter, SimpleEvolution +from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow, Evolution +from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter + +Distributions = t.Dict[Evolution, float] + + +@dataclass +class TestDataset: + """ + TestDataset class + """ + + test_data: t.List[DataRow] + + def to_pandas(self) -> pd.DataFrame: + data_samples = [] + for data in self.test_data: + question_type = data.question_type + data = { + "question": data.question, + "context": data.context, + "answer": "" if data.answer is None else data.answer, + "question_type": question_type, + "episode_done": True, + "evolution_elimination": data.evolution_elimination, + } + data_samples.append(data) + + return pd.DataFrame.from_records(data_samples) @dataclass @@ -53,31 +84,52 @@ def with_openai( ) def generate_with_llamaindex_docs( - self, documents: t.Sequence[LlamaindexDocument], test_size: int + self, + documents: t.Sequence[LlamaindexDocument], + test_size: int, + distributions: Distributions = {}, ): # chunk documents and add to docstore self.docstore.add_documents( [Document.from_llamaindex_document(doc) for doc in documents] ) - return self.generate(test_size=test_size) - - def generate(self, test_size: int): - node_filter = NodeFilter(self.critic_llm) - ques_filter = QuestionFilter(self.critic_llm) - exec = Executor() - qs = [] - for i in range(test_size): - se = SimpleEvolution(node_filter, ques_filter) - exec.submit( - se.aevolve, - self.generator_llm, - self.docstore, - name=f"SimpleEvolution-{i}", - ) + return self.generate(test_size=test_size, distributions=distributions) + + def generate(self, test_size: int, distributions: Distributions = {}): + # init filters and evolutions + for evolution in distributions.keys(): + if evolution.generator_llm is None: + evolution.generator_llm = self.generator_llm + if evolution.docstore is None: + evolution.docstore = self.docstore + + if evolution.question_filter is None: + evolution.question_filter = QuestionFilter(llm=self.critic_llm) + if evolution.node_filter is None: + evolution.node_filter = NodeFilter(llm=self.critic_llm) + + if isinstance(evolution, ComplexEvolution): + evolution.init_evolution() + if evolution.evolution_filter is None: + evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm) + + exec = Executor(raise_exceptions=True, is_async=True) + + current_nodes = [ + CurrentNodes(root_node=n, nodes=[n]) + for n in self.docstore.get_random_nodes(k=test_size) + ] + for evolution, probability in distributions.items(): + for i in range(round(probability * test_size)): + exec.submit( + evolution.aevolve, + current_nodes[i], + name=f"{evolution.__class__.__name__}-{i}", + ) try: - qs = exec.results() + test_data_rows = exec.results() except ValueError as e: raise e - return qs + return TestDataset(test_data=test_data_rows) From 602ae7bcf11f695bd02791b4a2b89db6dc353fbb Mon Sep 17 00:00:00 2001 From: jjmachan Date: Mon, 22 Jan 2024 17:18:19 -0800 Subject: [PATCH 09/10] final polish --- src/ragas/testset/evolutions.py | 15 ++++++++++++--- src/ragas/testset/generator.py | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 28e93d2e2..09e62eaa3 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -152,7 +152,7 @@ def generate_datarow( ) -@dataclass(unsafe_hash=True) +@dataclass class SimpleEvolution(Evolution): async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: assert self.docstore is not None, "docstore cannot be None" @@ -185,6 +185,9 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str # if valid question return seed_question + def __hash__(self): + return hash(self.__class__.__name__) + @dataclass class ComplexEvolution(Evolution): @@ -204,7 +207,7 @@ def init_evolution(self): self.evolution_filter = EvolutionFilter(self.node_filter.llm) -@dataclass(unsafe_hash=True) +@dataclass class MultiContextEvolution(ComplexEvolution): async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: assert self.docstore is not None, "docstore cannot be None" @@ -253,8 +256,11 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str return compressed_question + def __hash__(self): + return hash(self.__class__.__name__) -@dataclass(unsafe_hash=True) + +@dataclass class ReasoningEvolution(ComplexEvolution): async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str: assert self.generator_llm is not None, "generator_llm cannot be None" @@ -300,6 +306,9 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str return reasoning_question + def __hash__(self): + return hash(self.__class__.__name__) + simple = SimpleEvolution() multi_context = MultiContextEvolution() diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index b0a2f7628..748ff615e 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from dataclasses import dataclass @@ -15,6 +16,7 @@ from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow, Evolution from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter +logger = logging.getLogger(__name__) Distributions = t.Dict[Evolution, float] @@ -88,6 +90,7 @@ def generate_with_llamaindex_docs( documents: t.Sequence[LlamaindexDocument], test_size: int, distributions: Distributions = {}, + **kwargs, ): # chunk documents and add to docstore self.docstore.add_documents( @@ -96,9 +99,11 @@ def generate_with_llamaindex_docs( return self.generate(test_size=test_size, distributions=distributions) - def generate(self, test_size: int, distributions: Distributions = {}): + def generate( + self, test_size: int, distributions: Distributions = {}, show_debug_logs=False + ): # init filters and evolutions - for evolution in distributions.keys(): + for evolution in distributions: if evolution.generator_llm is None: evolution.generator_llm = self.generator_llm if evolution.docstore is None: @@ -113,8 +118,12 @@ def generate(self, test_size: int, distributions: Distributions = {}): evolution.init_evolution() if evolution.evolution_filter is None: evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm) + if show_debug_logs: + from ragas.utils import patch_logger + + patch_logger("ragas.testset.evolutions", logging.DEBUG) - exec = Executor(raise_exceptions=True, is_async=True) + exec = Executor(desc="Generating", raise_exceptions=True, is_async=True) current_nodes = [ CurrentNodes(root_node=n, nodes=[n]) From 25d8e15dd2a523a7cda5d84690a2c8c009dbbc8f Mon Sep 17 00:00:00 2001 From: jjmachan Date: Mon, 22 Jan 2024 17:28:19 -0800 Subject: [PATCH 10/10] added benchmark --- src/ragas/testset/generator.py | 4 ++-- tests/benchmarks/benchmark_testsetgen.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 748ff615e..29714a9a9 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -13,11 +13,11 @@ from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore -from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow, Evolution +from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter logger = logging.getLogger(__name__) -Distributions = t.Dict[Evolution, float] +Distributions = t.Dict[t.Any, float] @dataclass diff --git a/tests/benchmarks/benchmark_testsetgen.py b/tests/benchmarks/benchmark_testsetgen.py index e78e7eb48..282e2d75f 100644 --- a/tests/benchmarks/benchmark_testsetgen.py +++ b/tests/benchmarks/benchmark_testsetgen.py @@ -3,10 +3,13 @@ from llama_index import download_loader +from ragas.testset.evolutions import multi_context, reasoning, simple from ragas.testset.generator import TestsetGenerator generator = TestsetGenerator.with_openai() +distributions = {simple: 0.5, multi_context: 0.4, reasoning: 0.1} + def get_documents(): SemanticScholarReader = download_loader("SemanticScholarReader") @@ -30,7 +33,9 @@ def get_documents(): os.environ["PYTHONASYNCIODEBUG"] = "1" print("Starting [Asyncio]") start = time.time() - generator.generate_with_llamaindex_docs(documents, test_size=100) + generator.generate_with_llamaindex_docs( + documents=documents, test_size=100, distributions=distributions + ) print(f"Time taken: {time.time() - start:.2f}s") # Threads