From 55064c2a9423e855757ab9e6532274b9d50e19e0 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Thu, 8 Feb 2024 16:47:29 -0800 Subject: [PATCH 1/2] ensure dict type --- src/ragas/testset/evolutions.py | 7 +++++-- src/ragas/testset/extractor.py | 1 + src/ragas/testset/filters.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 09802a188..90c275e63 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -180,9 +180,12 @@ async def generate_datarow( relevent_contexts_result = await json_loader.safe_load( results.generations[0][0].text.strip(), llm=self.generator_llm ) - relevant_context_indices = relevent_contexts_result.get( - "relevant_context", None + relevant_context_indices = ( + relevent_contexts_result.get("relevant_context", None) + if isinstance(relevent_contexts_result, dict) + else None ) + if relevant_context_indices is None: relevant_context = CurrentNodes( root_node=current_nodes.root_node, nodes=current_nodes.nodes diff --git a/src/ragas/testset/extractor.py b/src/ragas/testset/extractor.py index 77c586c2e..09154d8cf 100644 --- a/src/ragas/testset/extractor.py +++ b/src/ragas/testset/extractor.py @@ -50,6 +50,7 @@ async def extract(self, node: Node, is_async: bool = True) -> t.List[str]: keyphrases = await json_loader.safe_load( results.generations[0][0].text.strip(), llm=self.llm, is_async=is_async ) + keyphrases = keyphrases if isinstance(keyphrases, dict) else {} logger.debug("keyphrases: %s", keyphrases) return keyphrases.get("keyphrases", []) diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py index 0eb06b77b..9d017b7e5 100644 --- a/src/ragas/testset/filters.py +++ b/src/ragas/testset/filters.py @@ -54,6 +54,7 @@ async def filter(self, node: Node) -> t.Dict: results = await self.llm.generate(prompt=prompt) output = results.generations[0][0].text.strip() score = await json_loader.safe_load(output, llm=self.llm) + score = score if isinstance(score, dict) else {} logger.debug("node filter: %s", score) score.update({"score": score.get("score", 0) >= self.threshold}) return score @@ -85,6 +86,7 @@ async def filter(self, question: str) -> bool: results = await self.llm.generate(prompt=prompt) results = results.generations[0][0].text.strip() json_results = await json_loader.safe_load(results, llm=self.llm) + json_results = json_results if isinstance(json_results, dict) else {} logger.debug("filtered question: %s", json_results) return json_results.get("verdict") == "1" @@ -117,6 +119,7 @@ async def filter(self, simple_question: str, compressed_question: str) -> bool: results = await self.llm.generate(prompt=prompt) results = results.generations[0][0].text.strip() json_results = await json_loader.safe_load(results, llm=self.llm) + json_results = json_results if isinstance(json_results, dict) else {} logger.debug("evolution filter: %s", json_results) return json_results.get("verdict") == "1" From 121cf3f6544adc713e36639aeefe6e5d3fc70126 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Wed, 14 Feb 2024 16:56:18 -0800 Subject: [PATCH 2/2] fix evolution flows --- src/ragas/testset/evolutions.py | 64 +++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 90c275e63..46e85a448 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -63,12 +63,23 @@ class Evolution: @staticmethod def merge_nodes(nodes: CurrentNodes) -> Node: - return Node( + # TODO: while merging merge according to the order of documents + # if any nodes from same document take account their page order + + new_node = Node( doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes), keyphrases=[phrase for n in nodes.nodes for phrase in n.keyphrases], ) + embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding else None + if embed_dim: + node_embeddings = np.array([n.embedding for n in nodes.nodes]).reshape( + -1, embed_dim + ) + new_node.embedding = np.average(node_embeddings, axis=0) + return new_node + def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None): self.is_async = is_async if run_config is None: @@ -191,7 +202,10 @@ async def generate_datarow( root_node=current_nodes.root_node, nodes=current_nodes.nodes ) else: - relevant_context = current_nodes + selected_nodes = [current_nodes.nodes[i] for i in relevant_context_indices] + relevant_context = CurrentNodes( + root_node=selected_nodes[0], nodes=selected_nodes + ) merged_nodes = self.merge_nodes(relevant_context) results = await self.generator_llm.generate( @@ -207,7 +221,7 @@ async def generate_datarow( return DataRow( question=question, - contexts=[n.page_content for n in current_nodes.nodes], + contexts=[n.page_content for n in relevant_context.nodes], ground_truth="" if answer is None else answer, evolution_type=evolution_type, ) @@ -253,7 +267,7 @@ async def _aevolve( assert self.question_filter is not None, "question_filter cannot be None" merged_node = self.merge_nodes(current_nodes) - passed = await self.node_filter.filter(current_nodes.root_node) + passed = await self.node_filter.filter(merged_node) if not passed["score"]: nodes = self.docstore.get_random_nodes(k=1) new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) @@ -334,16 +348,19 @@ async def _acomplex_evolution( assert self.question_filter is not None, "question_filter cannot be None" assert self.se is not None, "simple evolution cannot be None" - simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes) + simple_question, current_nodes, _ = await self.se._aevolve( + current_tries, current_nodes + ) logger.debug( "[%s] simple question generated: %s", self.__class__.__name__, simple_question, ) + merged_node = self.merge_nodes(current_nodes) result = await self.generator_llm.generate( prompt=question_prompt.format( - question=simple_question, context=current_nodes.root_node.page_content + question=simple_question, context=merged_node.page_content ) ) reasoning_question = result.generations[0][0].text.strip() @@ -409,22 +426,32 @@ async def _aevolve( assert self.question_filter is not None, "question_filter cannot be None" assert self.se is not None, "simple evolution cannot be None" - simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes) + simple_question, current_nodes, _ = await self.se._aevolve( + current_tries, current_nodes + ) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) - # find a similar node and generate a question based on both - similar_node = self.docstore.get_similar(current_nodes.root_node) + merged_node = self.merge_nodes(current_nodes) + similar_node = self.docstore.get_similar(merged_node, top_k=1) if similar_node == []: # retry - current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + new_random_nodes = self.docstore.get_random_nodes(k=1) + current_nodes = CurrentNodes( + root_node=new_random_nodes[0], nodes=new_random_nodes + ) return await self.aretry_evolve(current_tries, current_nodes) + else: + assert isinstance(similar_node[0], Node), "similar_node must be a Node" + current_nodes = CurrentNodes( + root_node=merged_node, nodes=[merged_node, similar_node[0]] + ) prompt = self.multi_context_question_prompt.format( question=simple_question, - context1=current_nodes.root_node.page_content, - context2=similar_node, + context1=merged_node.page_content, + context2=similar_node[0].page_content, ) results = await self.generator_llm.generate(prompt=prompt) question = results.generations[0][0].text.strip() @@ -432,19 +459,20 @@ async def _aevolve( "[MultiContextEvolution] multicontext question generated: %s", question ) + if not await self.question_filter.filter(question): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) + # compress the question compressed_question = await self._transform_question( prompt=self.compress_question_prompt, question=question ) logger.debug( - "[MultiContextEvolution] multicontext question compressed: %s", question + "[MultiContextEvolution] multicontext question compressed: %s", + compressed_question, ) - if not await self.question_filter.filter(compressed_question): - # retry - current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_tries, current_nodes) - assert self.evolution_filter is not None, "evolution filter cannot be None" if await self.evolution_filter.filter(simple_question, compressed_question): # retry