Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ragas/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def parse_run_traces(
prompt_trace = traces[prompt_uuid]
prompt_traces[f"{prompt_trace.name}"] = {
"input": prompt_trace.inputs.get("data", {}),
"output": prompt_trace.outputs.get("output", {}),
"output": prompt_trace.outputs.get("output", {})[0],
}
metric_traces[f"{metric_trace.name}"] = prompt_traces
parased_traces.append(metric_traces)
Expand Down
10 changes: 1 addition & 9 deletions src/ragas/optimizers/genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def dict_to_str(dict: t.Dict[str, t.Any]) -> str:
exclude_none=True
)
),
output=traces[idx][prompt_name]["output"][0].model_dump(
output=traces[idx][prompt_name]["output"].model_dump(
exclude_none=True
),
expected_output=dataset[idx]["prompts"][prompt_name][
Expand Down Expand Up @@ -586,14 +586,6 @@ def evaluate_candidate(
_run_id=run_id,
_pbar=parent_pbar,
)
# remap the traces to the original prompt names
remap_traces = {val.name: key for key, val in self.metric.get_prompts().items()}
for trace in results.traces:
for key in remap_traces:
if key in trace[self.metric.name]:
trace[self.metric.name][remap_traces[key]] = trace[
self.metric.name
].pop(key)
return results

def evaluate_fitness(
Expand Down
16 changes: 12 additions & 4 deletions src/ragas/prompt/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@ class PromptMixin:
eg: [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM]
"""

def _get_prompts(self) -> t.Dict[str, PydanticPrompt]:

prompts = {}
for key, value in inspect.getmembers(self):
if isinstance(value, PydanticPrompt):
prompts.update({key: value})
return prompts

def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
"""
Returns a dictionary of prompts for the class.
"""
prompts = {}
for name, value in inspect.getmembers(self):
if isinstance(value, PydanticPrompt):
prompts.update({name: value})
for _, value in self._get_prompts().items():
prompts.update({value.name: value})
return prompts

def set_prompts(self, **prompts):
Expand All @@ -41,6 +48,7 @@ def set_prompts(self, **prompts):
If the prompt is not an instance of `PydanticPrompt`.
"""
available_prompts = self.get_prompts()
name_to_var = {v.name: k for k, v in self._get_prompts().items()}
for key, value in prompts.items():
if key not in available_prompts:
raise ValueError(
Expand All @@ -50,7 +58,7 @@ def set_prompts(self, **prompts):
raise ValueError(
f"Prompt with name '{key}' must be an instance of 'ragas.prompt.PydanticPrompt'"
)
setattr(self, key, value)
setattr(self, name_to_var[key], value)

async def adapt_prompts(
self, language: str, llm: BaseRagasLLM, adapt_instruction: bool = False
Expand Down
Loading