From 2ef09fad0a2b12a49a3181382c452d024f6ca788 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 2 Feb 2024 12:36:19 -0800 Subject: [PATCH 1/2] test adapted dict keys --- src/ragas/llms/prompt.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/ragas/llms/prompt.py b/src/ragas/llms/prompt.py index 505849c77..6b3f8cedd 100644 --- a/src/ragas/llms/prompt.py +++ b/src/ragas/llms/prompt.py @@ -147,12 +147,21 @@ def format(self, **kwargs: t.Any) -> PromptValue: def adapt( self, language: str, llm: BaseRagasLLM, cache_dir: t.Optional[str] = None ) -> Prompt: + def get_all_keys(nested_json): + keys = set() + for key, value in nested_json.items(): + keys.add(key) + if isinstance(value, dict): + keys = keys.union(get_all_keys(value)) + return keys + # TODO: Add callbacks cache_dir = cache_dir if cache_dir else get_cache_dir() if os.path.exists(os.path.join(cache_dir, language, f"{self.name}.json")): return self._load(language, self.name, cache_dir) prompts = [] + output_keys = [] for example in self.examples: prompts.extend( [ @@ -171,6 +180,8 @@ def adapt( translate_to=language, input=example.get(self.output_key) ) ) + if self.output_type.lower() == "json": + output_keys.append(get_all_keys(example.get(self.output_key))) # NOTE: this is a slow loop, consider Executor to fasten this results = [] @@ -195,6 +206,10 @@ def adapt( else example[-1] ) + assert ( + set(example_dict[self.output_key].keys()) == output_keys[i] + ), "Adapted output keys do not match with the original output keys" + self.examples[i] = example_dict self.language = language From a0919c11f32ef230ce24e2172b9196aacf38e8a6 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 2 Feb 2024 12:36:58 -0800 Subject: [PATCH 2/2] fix checks --- src/ragas/llms/prompt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ragas/llms/prompt.py b/src/ragas/llms/prompt.py index 6b3f8cedd..e8c5a8d67 100644 --- a/src/ragas/llms/prompt.py +++ b/src/ragas/llms/prompt.py @@ -206,9 +206,10 @@ def get_all_keys(nested_json): else example[-1] ) - assert ( - set(example_dict[self.output_key].keys()) == output_keys[i] - ), "Adapted output keys do not match with the original output keys" + if self.output_type.lower() == "json": + assert ( + set(example_dict[self.output_key].keys()) == output_keys[i] + ), "Adapted output keys do not match with the original output keys" self.examples[i] = example_dict