From 55064c2a9423e855757ab9e6532274b9d50e19e0 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Thu, 8 Feb 2024 16:47:29 -0800 Subject: [PATCH 1/3] ensure dict type --- src/ragas/testset/evolutions.py | 7 +++++-- src/ragas/testset/extractor.py | 1 + src/ragas/testset/filters.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 09802a188..90c275e63 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -180,9 +180,12 @@ async def generate_datarow( relevent_contexts_result = await json_loader.safe_load( results.generations[0][0].text.strip(), llm=self.generator_llm ) - relevant_context_indices = relevent_contexts_result.get( - "relevant_context", None + relevant_context_indices = ( + relevent_contexts_result.get("relevant_context", None) + if isinstance(relevent_contexts_result, dict) + else None ) + if relevant_context_indices is None: relevant_context = CurrentNodes( root_node=current_nodes.root_node, nodes=current_nodes.nodes diff --git a/src/ragas/testset/extractor.py b/src/ragas/testset/extractor.py index 77c586c2e..09154d8cf 100644 --- a/src/ragas/testset/extractor.py +++ b/src/ragas/testset/extractor.py @@ -50,6 +50,7 @@ async def extract(self, node: Node, is_async: bool = True) -> t.List[str]: keyphrases = await json_loader.safe_load( results.generations[0][0].text.strip(), llm=self.llm, is_async=is_async ) + keyphrases = keyphrases if isinstance(keyphrases, dict) else {} logger.debug("keyphrases: %s", keyphrases) return keyphrases.get("keyphrases", []) diff --git a/src/ragas/testset/filters.py b/src/ragas/testset/filters.py index 0eb06b77b..9d017b7e5 100644 --- a/src/ragas/testset/filters.py +++ b/src/ragas/testset/filters.py @@ -54,6 +54,7 @@ async def filter(self, node: Node) -> t.Dict: results = await self.llm.generate(prompt=prompt) output = results.generations[0][0].text.strip() score = await json_loader.safe_load(output, llm=self.llm) + score = score if isinstance(score, dict) else {} logger.debug("node filter: %s", score) score.update({"score": score.get("score", 0) >= self.threshold}) return score @@ -85,6 +86,7 @@ async def filter(self, question: str) -> bool: results = await self.llm.generate(prompt=prompt) results = results.generations[0][0].text.strip() json_results = await json_loader.safe_load(results, llm=self.llm) + json_results = json_results if isinstance(json_results, dict) else {} logger.debug("filtered question: %s", json_results) return json_results.get("verdict") == "1" @@ -117,6 +119,7 @@ async def filter(self, simple_question: str, compressed_question: str) -> bool: results = await self.llm.generate(prompt=prompt) results = results.generations[0][0].text.strip() json_results = await json_loader.safe_load(results, llm=self.llm) + json_results = json_results if isinstance(json_results, dict) else {} logger.debug("evolution filter: %s", json_results) return json_results.get("verdict") == "1" From d28bd75c44f3905d44c8666c4492ff0a4ca720bf Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 26 Feb 2024 12:36:22 -0800 Subject: [PATCH 2/3] add demonstration --- src/ragas/metrics/_context_recall.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/ragas/metrics/_context_recall.py b/src/ragas/metrics/_context_recall.py index 57069119b..488678079 100644 --- a/src/ragas/metrics/_context_recall.py +++ b/src/ragas/metrics/_context_recall.py @@ -58,6 +58,16 @@ "Attributed": "1", }, }, + { + "question": """What is the primary fuel for the Sun?""", + "context": """NULL""", + "answer": """Hydrogen""", + "classification": { + "statement_1": "The Sun's primary fuel is hydrogen.", + "reason": "The context contains no information", + "Attributed": "0", + }, + }, ], input_keys=["question", "context", "answer"], output_key="classification", From 9134732e965706561f97fedbddd8c4ceedf9d321 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Mon, 26 Feb 2024 12:36:32 -0800 Subject: [PATCH 3/3] replace assert --- src/ragas/metrics/_faithfulness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py index 9e10c54af..b03ca4f2b 100644 --- a/src/ragas/metrics/_faithfulness.py +++ b/src/ragas/metrics/_faithfulness.py @@ -187,7 +187,7 @@ async def _ascore( is_async=is_async, ) - assert isinstance(statements, dict), "Invalid JSON response" + statements = statements if isinstance(statements, dict) else {} p = self._create_nli_prompt(row, statements.get("statements", [])) nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async) json_output = await json_loader.safe_load(