From 55064c2a9423e855757ab9e6532274b9d50e19e0 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Thu, 8 Feb 2024 16:47:29 -0800 Subject: [PATCH 1/3] ensure dict type --- src/ragas/testset/evolutions.py | 7 +++++-- src/ragas/testset/extractor.py | 1 + src/ragas/testset/filters.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 09802a188..90c275e63 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -180,9 +180,12 @@ async def generate_datarow( relevent_contexts_result = await json_loader.safe_load( results.generations[0][0].text.strip(), llm=self.generator_llm ) - relevant_context_indices = relevent_contexts_result.get( - "relevant_context", None + relevant_context_indices = ( + relevent_contexts_result.get("relevant_context", None) + if isinstance(relevent_contexts_result, dict) + else None ) + if relevant_context_indices is None: relevant_context = CurrentNodes( root_node=current_nodes.root_node, nodes=current_nodes.nodes diff --git a/src/ragas/testset/extractor.py b/src/ragas/testset/extractor.py index 77c586c2e..09154d8cf 100644 --- a/src/ragas/testset/extractor.py +++ b/src/ragas/testset/extractor.py @@ -50,6 +50,7 @@ async def extract(self, node: Node, is_async: bool = True) -> t.List[str]: keyphrases = await json_loader.safe_load( results.generations[0][0].text.strip(), llm=self.llm, is_async=is_async ) + keyphrases = keyphrases if isinstance(keyphrases, dict) else {} logger.debug("keyphrases: %s", keyphrases) return keyphrases.get("keyphrases", []) diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py index 0eb06b77b..9d017b7e5 100644 --- a/src/ragas/testset/filters.py +++ b/src/ragas/testset/filters.py @@ -54,6 +54,7 @@ async def filter(self, node: Node) -> t.Dict: results = await self.llm.generate(prompt=prompt) output = results.generations[0][0].text.strip() score = await json_loader.safe_load(output, llm=self.llm) + score = score if isinstance(score, dict) else {} logger.debug("node filter: %s", score) score.update({"score": score.get("score", 0) >= self.threshold}) return score @@ -85,6 +86,7 @@ async def filter(self, question: str) -> bool: results = await self.llm.generate(prompt=prompt) results = results.generations[0][0].text.strip() json_results = await json_loader.safe_load(results, llm=self.llm) + json_results = json_results if isinstance(json_results, dict) else {} logger.debug("filtered question: %s", json_results) return json_results.get("verdict") == "1" @@ -117,6 +119,7 @@ async def filter(self, simple_question: str, compressed_question: str) -> bool: results = await self.llm.generate(prompt=prompt) results = results.generations[0][0].text.strip() json_results = await json_loader.safe_load(results, llm=self.llm) + json_results = json_results if isinstance(json_results, dict) else {} logger.debug("evolution filter: %s", json_results) return json_results.get("verdict") == "1" From 1459df5d6104b37a581822d7b0744b9e1b749dff Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 16 Feb 2024 14:47:45 -0800 Subject: [PATCH 2/3] fix condition --- src/ragas/testset/evolutions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 46e85a448..a75574e1c 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -72,7 +72,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node: keyphrases=[phrase for n in nodes.nodes for phrase in n.keyphrases], ) - embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding else None + embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding is not None else None if embed_dim: node_embeddings = np.array([n.embedding for n in nodes.nodes]).reshape( -1, embed_dim From e44c88816d94211df9eaef7543bae4654ad4b953 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 16 Feb 2024 17:49:58 -0800 Subject: [PATCH 3/3] fix node idx --- src/ragas/testset/evolutions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index a75574e1c..e312621ab 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -72,7 +72,11 @@ def merge_nodes(nodes: CurrentNodes) -> Node: keyphrases=[phrase for n in nodes.nodes for phrase in n.keyphrases], ) - embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding is not None else None + embed_dim = ( + len(nodes.nodes[0].embedding) + if nodes.nodes[0].embedding is not None + else None + ) if embed_dim: node_embeddings = np.array([n.embedding for n in nodes.nodes]).reshape( -1, embed_dim @@ -444,9 +448,7 @@ async def _aevolve( return await self.aretry_evolve(current_tries, current_nodes) else: assert isinstance(similar_node[0], Node), "similar_node must be a Node" - current_nodes = CurrentNodes( - root_node=merged_node, nodes=[merged_node, similar_node[0]] - ) + current_nodes.nodes.append(similar_node[0]) prompt = self.multi_context_question_prompt.format( question=simple_question,