diff --git a/src/ragas/llms/prompt.py b/src/ragas/llms/prompt.py
index 1faa74d39..3b273f18c 100644
--- a/src/ragas/llms/prompt.py
+++ b/src/ragas/llms/prompt.py
@@ -149,8 +149,23 @@ def format(self, **kwargs: t.Any) -> PromptValue:
     def adapt(
         self, language: str, llm: BaseRagasLLM, cache_dir: t.Optional[str] = None
     ) -> Prompt:
+
+        def get_all_keys(nested_json):
+            # Collect every dict key at any depth. Lists are walked as well so
+            # a JSON array of objects contributes the keys of each object;
+            # scalars (and None) contribute nothing instead of raising.
+            keys = set()
+            if isinstance(nested_json, dict):
+                for key, value in nested_json.items():
+                    keys.add(key)
+                    keys |= get_all_keys(value)
+            elif isinstance(nested_json, list):
+                for item in nested_json:
+                    keys |= get_all_keys(item)
+            return keys
+
         if self.language == language:
             return self
+
         # TODO: Add callbacks
         cache_dir = cache_dir if cache_dir else get_cache_dir()
         if os.path.exists(os.path.join(cache_dir, language, f"{self.name}.json")):
@@ -158,6 +173,9 @@ def adapt(
 
         logger.info("Adapting %s to %s", self.name, language)
         prompts = []
+        # Key set of each example's original JSON output, in example order;
+        # compared against the adapted output's keys further down.
+        output_keys = []
         for example in self.examples:
             prompts.extend(
                 [
@@ -176,6 +194,8 @@ def adapt(
                     translate_to=language, input=example.get(self.output_key)
                 )
             )
+            if self.output_type.lower() == "json":
+                output_keys.append(get_all_keys(example.get(self.output_key)))
 
         # NOTE: this is a slow loop, consider Executor to fasten this
         results = []
@@ -200,6 +220,15 @@ def adapt(
                 else example[-1]
             )
 
+            if self.output_type.lower() == "json":
+                # Compare keys at every depth on both sides (the original side
+                # was collected recursively). Raise rather than assert so the
+                # check is not stripped under ``python -O``.
+                adapted_keys = get_all_keys(example_dict[self.output_key])
+                if adapted_keys != output_keys[i]:
+                    raise ValueError(
+                        "Adapted output keys do not match with the original output keys"
+                    )
+
             self.examples[i] = example_dict
 
         self.language = language