Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/ragas/llms/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,26 @@ def format(self, **kwargs: t.Any) -> PromptValue:
def adapt(
self, language: str, llm: BaseRagasLLM, cache_dir: t.Optional[str] = None
) -> Prompt:

def get_all_keys(nested_json):
keys = set()
for key, value in nested_json.items():
keys.add(key)
if isinstance(value, dict):
keys = keys.union(get_all_keys(value))
return keys

if self.language == language:
return self

# TODO: Add callbacks
cache_dir = cache_dir if cache_dir else get_cache_dir()
if os.path.exists(os.path.join(cache_dir, language, f"{self.name}.json")):
return self._load(language, self.name, cache_dir)

logger.info("Adapting %s to %s", self.name, language)
prompts = []
output_keys = []
for example in self.examples:
prompts.extend(
[
Expand All @@ -176,6 +187,8 @@ def adapt(
translate_to=language, input=example.get(self.output_key)
)
)
if self.output_type.lower() == "json":
output_keys.append(get_all_keys(example.get(self.output_key)))

# NOTE: this is a slow loop, consider Executor to fasten this
results = []
Expand All @@ -200,6 +213,11 @@ def adapt(
else example[-1]
)

if self.output_type.lower() == "json":
assert (
set(example_dict[self.output_key].keys()) == output_keys[i]
), "Adapted output keys do not match with the original output keys"

self.examples[i] = example_dict

self.language = language
Expand Down