Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions src/ragas/testset/evolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,23 @@ class Evolution:

@staticmethod
def merge_nodes(nodes: CurrentNodes) -> Node:
return Node(
# TODO: while merging merge according to the order of documents
# if any nodes from same document take account their page order

new_node = Node(
doc_id="merged",
page_content="\n".join(n.page_content for n in nodes.nodes),
keyphrases=[phrase for n in nodes.nodes for phrase in n.keyphrases],
)

embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding else None
if embed_dim:
node_embeddings = np.array([n.embedding for n in nodes.nodes]).reshape(
-1, embed_dim
)
new_node.embedding = np.average(node_embeddings, axis=0)
return new_node

def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
self.is_async = is_async
if run_config is None:
Expand Down Expand Up @@ -191,7 +202,10 @@ async def generate_datarow(
root_node=current_nodes.root_node, nodes=current_nodes.nodes
)
else:
relevant_context = current_nodes
selected_nodes = [current_nodes.nodes[i] for i in relevant_context_indices]
relevant_context = CurrentNodes(
root_node=selected_nodes[0], nodes=selected_nodes
)

merged_nodes = self.merge_nodes(relevant_context)
results = await self.generator_llm.generate(
Expand All @@ -207,7 +221,7 @@ async def generate_datarow(

return DataRow(
question=question,
contexts=[n.page_content for n in current_nodes.nodes],
contexts=[n.page_content for n in relevant_context.nodes],
ground_truth="" if answer is None else answer,
evolution_type=evolution_type,
)
Expand Down Expand Up @@ -253,7 +267,7 @@ async def _aevolve(
assert self.question_filter is not None, "question_filter cannot be None"

merged_node = self.merge_nodes(current_nodes)
passed = await self.node_filter.filter(current_nodes.root_node)
passed = await self.node_filter.filter(merged_node)
if not passed["score"]:
nodes = self.docstore.get_random_nodes(k=1)
new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes)
Expand Down Expand Up @@ -334,16 +348,19 @@ async def _acomplex_evolution(
assert self.question_filter is not None, "question_filter cannot be None"
assert self.se is not None, "simple evolution cannot be None"

simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
simple_question, current_nodes, _ = await self.se._aevolve(
current_tries, current_nodes
)
logger.debug(
"[%s] simple question generated: %s",
self.__class__.__name__,
simple_question,
)

merged_node = self.merge_nodes(current_nodes)
result = await self.generator_llm.generate(
prompt=question_prompt.format(
question=simple_question, context=current_nodes.root_node.page_content
question=simple_question, context=merged_node.page_content
)
)
reasoning_question = result.generations[0][0].text.strip()
Expand Down Expand Up @@ -409,42 +426,53 @@ async def _aevolve(
assert self.question_filter is not None, "question_filter cannot be None"
assert self.se is not None, "simple evolution cannot be None"

simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
simple_question, current_nodes, _ = await self.se._aevolve(
current_tries, current_nodes
)
logger.debug(
"[MultiContextEvolution] simple question generated: %s", simple_question
)

# find a similar node and generate a question based on both
similar_node = self.docstore.get_similar(current_nodes.root_node)
merged_node = self.merge_nodes(current_nodes)
similar_node = self.docstore.get_similar(merged_node, top_k=1)
if similar_node == []:
# retry
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
new_random_nodes = self.docstore.get_random_nodes(k=1)
current_nodes = CurrentNodes(
root_node=new_random_nodes[0], nodes=new_random_nodes
)
return await self.aretry_evolve(current_tries, current_nodes)
else:
assert isinstance(similar_node[0], Node), "similar_node must be a Node"
current_nodes = CurrentNodes(
root_node=merged_node, nodes=[merged_node, similar_node[0]]
)

prompt = self.multi_context_question_prompt.format(
question=simple_question,
context1=current_nodes.root_node.page_content,
context2=similar_node,
context1=merged_node.page_content,
context2=similar_node[0].page_content,
)
results = await self.generator_llm.generate(prompt=prompt)
question = results.generations[0][0].text.strip()
logger.debug(
"[MultiContextEvolution] multicontext question generated: %s", question
)

if not await self.question_filter.filter(question):
# retry
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
return await self.aretry_evolve(current_tries, current_nodes)

# compress the question
compressed_question = await self._transform_question(
prompt=self.compress_question_prompt, question=question
)
logger.debug(
"[MultiContextEvolution] multicontext question compressed: %s", question
"[MultiContextEvolution] multicontext question compressed: %s",
compressed_question,
)

if not await self.question_filter.filter(compressed_question):
# retry
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
return await self.aretry_evolve(current_tries, current_nodes)

assert self.evolution_filter is not None, "evolution filter cannot be None"
if await self.evolution_filter.filter(simple_question, compressed_question):
# retry
Expand Down