diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 90c275e63..46e85a448 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -63,12 +63,23 @@ class Evolution: @staticmethod def merge_nodes(nodes: CurrentNodes) -> Node: - return Node( + # TODO: while merging merge according to the order of documents + # if any nodes from same document take account their page order + + new_node = Node( doc_id="merged", page_content="\n".join(n.page_content for n in nodes.nodes), keyphrases=[phrase for n in nodes.nodes for phrase in n.keyphrases], ) + embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding else None + if embed_dim: + node_embeddings = np.array([n.embedding for n in nodes.nodes]).reshape( + -1, embed_dim + ) + new_node.embedding = np.average(node_embeddings, axis=0) + return new_node + def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None): self.is_async = is_async if run_config is None: @@ -191,7 +202,10 @@ async def generate_datarow( root_node=current_nodes.root_node, nodes=current_nodes.nodes ) else: - relevant_context = current_nodes + selected_nodes = [current_nodes.nodes[i] for i in relevant_context_indices] + relevant_context = CurrentNodes( + root_node=selected_nodes[0], nodes=selected_nodes + ) merged_nodes = self.merge_nodes(relevant_context) results = await self.generator_llm.generate( @@ -207,7 +221,7 @@ async def generate_datarow( return DataRow( question=question, - contexts=[n.page_content for n in current_nodes.nodes], + contexts=[n.page_content for n in relevant_context.nodes], ground_truth="" if answer is None else answer, evolution_type=evolution_type, ) @@ -253,7 +267,7 @@ async def _aevolve( assert self.question_filter is not None, "question_filter cannot be None" merged_node = self.merge_nodes(current_nodes) - passed = await self.node_filter.filter(current_nodes.root_node) + passed = await self.node_filter.filter(merged_node) if not passed["score"]: nodes = self.docstore.get_random_nodes(k=1) new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes) @@ -334,16 +348,19 @@ async def _acomplex_evolution( assert self.question_filter is not None, "question_filter cannot be None" assert self.se is not None, "simple evolution cannot be None" - simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes) + simple_question, current_nodes, _ = await self.se._aevolve( + current_tries, current_nodes + ) logger.debug( "[%s] simple question generated: %s", self.__class__.__name__, simple_question, ) + merged_node = self.merge_nodes(current_nodes) result = await self.generator_llm.generate( prompt=question_prompt.format( - question=simple_question, context=current_nodes.root_node.page_content + question=simple_question, context=merged_node.page_content ) ) reasoning_question = result.generations[0][0].text.strip() @@ -409,22 +426,32 @@ async def _aevolve( assert self.question_filter is not None, "question_filter cannot be None" assert self.se is not None, "simple evolution cannot be None" - simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes) + simple_question, current_nodes, _ = await self.se._aevolve( + current_tries, current_nodes + ) logger.debug( "[MultiContextEvolution] simple question generated: %s", simple_question ) - # find a similar node and generate a question based on both - similar_node = self.docstore.get_similar(current_nodes.root_node) + merged_node = self.merge_nodes(current_nodes) + similar_node = self.docstore.get_similar(merged_node, top_k=1) if similar_node == []: # retry - current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + new_random_nodes = self.docstore.get_random_nodes(k=1) + current_nodes = CurrentNodes( + root_node=new_random_nodes[0], nodes=new_random_nodes + ) return await self.aretry_evolve(current_tries, current_nodes) + else: + assert isinstance(similar_node[0], Node), "similar_node must be a Node" + current_nodes = CurrentNodes( + root_node=merged_node, nodes=[merged_node, similar_node[0]] + ) prompt = self.multi_context_question_prompt.format( question=simple_question, - context1=current_nodes.root_node.page_content, - context2=similar_node, + context1=merged_node.page_content, + context2=similar_node[0].page_content, ) results = await self.generator_llm.generate(prompt=prompt) question = results.generations[0][0].text.strip() @@ -432,19 +459,20 @@ async def _aevolve( "[MultiContextEvolution] multicontext question generated: %s", question ) + if not await self.question_filter.filter(question): + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) + # compress the question compressed_question = await self._transform_question( prompt=self.compress_question_prompt, question=question ) logger.debug( - "[MultiContextEvolution] multicontext question compressed: %s", question + "[MultiContextEvolution] multicontext question compressed: %s", + compressed_question, ) - if not await self.question_filter.filter(compressed_question): - # retry - current_nodes = self.se._get_more_adjacent_nodes(current_nodes) - return await self.aretry_evolve(current_tries, current_nodes) - assert self.evolution_filter is not None, "evolution filter cannot be None" if await self.evolution_filter.filter(simple_question, compressed_question): # retry