31 changes: 27 additions & 4 deletions tests/unit/vertexai/genai/replays/test_evaluate.py
@@ -13,7 +13,7 @@
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import re
from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
from google.genai import types as genai_types
@@ -353,13 +353,22 @@ def test_evaluation_agent_data(client):
assert case_result.response_candidate_results is not None


def test_metric_resource_name(client):
def test_evaluation_metric_resource_name(client):
"""Tests with a metric resource name in types.Metric."""
client._api_client._http_options.api_version = "v1beta1"
client._api_client._http_options.base_url = (
"https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
)
metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
metric_resource_name = client.evals.create_evaluation_metric(
display_name="test_metric",
description="test_description",
metric=types.RubricMetric.GENERAL_QUALITY,
)
assert isinstance(metric_resource_name, str)
assert re.match(
r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$",
metric_resource_name,
)
byor_df = pd.DataFrame(
{
"prompt": ["Write a simple story about a dinosaur"],
@@ -375,8 +384,22 @@ def test_metric_resource_name(client):
)
assert isinstance(evaluation_result, types.EvaluationResult)
assert evaluation_result.eval_case_results is not None
assert len(evaluation_result.eval_case_results) > 0
assert len(evaluation_result.eval_case_results) == 1
assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
assert evaluation_result.summary_metrics[0].mean_score is not None
assert evaluation_result.summary_metrics[0].num_cases_valid == 1
assert evaluation_result.summary_metrics[0].num_cases_error == 0

case_result = evaluation_result.eval_case_results[0]
assert case_result.response_candidate_results is not None
assert len(case_result.response_candidate_results) == 1

metric_result = case_result.response_candidate_results[0].metric_results[
"my_custom_metric"
]
assert metric_result.score is not None
assert metric_result.score > 0.2
assert metric_result.error_message is None


pytestmark = pytest_helper.setup(
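For readers following the new test, here is a minimal standalone sketch of the flow it exercises: register a metric, then reference it by resource name when evaluating. This is an illustration only; the exact evaluate() arguments and the EvaluationDataset/Metric field names are assumptions based on the snippets above, and `client` is the test fixture (an initialized Vertex AI client).

import pandas as pd
from vertexai._genai import types

# Register a rubric-based metric; the call returns its full resource name.
resource_name = client.evals.create_evaluation_metric(
    display_name="test_metric",
    description="test_description",
    metric=types.RubricMetric.GENERAL_QUALITY,
)

# Reference the registered metric by resource name when evaluating a dataset.
prompts_df = pd.DataFrame({"prompt": ["Write a simple story about a dinosaur"]})
evaluation_result = client.evals.evaluate(
    dataset=types.EvaluationDataset(eval_dataset_df=prompts_df),
    metrics=[types.Metric(metric_resource_name=resource_name)],
)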
76 changes: 50 additions & 26 deletions vertexai/_genai/_evals_metric_handlers.py
@@ -1281,66 +1281,91 @@ def aggregate(
)


class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
class RegisteredMetricHandler(MetricHandler[types.Metric]):
"""Metric handler for registered metrics."""

def __init__(
self,
module: "evals.Evals",
metric: Union[types.MetricSource, types.MetricSourceDict],
metric: types.Metric,
):
if isinstance(metric, dict):
metric = types.MetricSource(**metric)
super().__init__(module=module, metric=metric)

# TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
def _build_request_payload(
self, eval_case: types.EvalCase, response_index: int
) -> dict[str, Any]:
"""Builds request payload for registered metric."""
if not self.metric.metric:
"""Builds request payload for registered metric by assembling EvaluationInstance."""
response_content = _get_response_from_eval_case(
eval_case, response_index, self.metric_name
)

if not response_content and not getattr(eval_case, "agent_data", None):
raise ValueError(
"Registered metric must have an underlying metric definition."
f"Response content missing for candidate {response_index}."
)

reference_instance_data = None
if eval_case.reference:
reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
eval_case.reference.response
)
return PredefinedMetricHandler(
self.module, metric=self.metric.metric
)._build_request_payload(eval_case, response_index)

extracted_prompt = _get_prompt_from_eval_case(eval_case)
prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
extracted_prompt
)

instance_payload = types.EvaluationInstance(
prompt=prompt_instance_data,
response=PredefinedMetricHandler._content_to_instance_data(
response_content
),
reference=reference_instance_data,
rubric_groups=eval_case.rubric_groups,
agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
)

request_payload = {
"instance": instance_payload,
}
return request_payload

@property
def metric_name(self) -> str:
# Resolve name from resource name or internal metric name
if isinstance(self.metric, types.MetricSource):
if self.metric.metric and self.metric.metric.name:
return self.metric.metric.name
if self.metric.metric_resource_name:
return self.metric.metric_resource_name
return "unknown"
else: # Should be Metric
metric_like = self.metric
if metric_like.name:
return metric_like.name
if metric_like.metric_resource_name:
return metric_like.metric_resource_name
return "unknown"
return self.metric.name or "unknown_metric"

@override
def get_metric_result(
self, eval_case: types.EvalCase, response_index: int
) -> types.EvalCaseMetricResult:
"""Processes a single evaluation case for a registered metric."""
"""Processes a single evaluation case using a MetricSource reference."""
metric_name = self.metric_name
metric_source = types.MetricSource(
metric_resource_name=self.metric.metric_resource_name
)

try:
payload = self._build_request_payload(eval_case, response_index)
for attempt in range(_MAX_RETRIES):
try:
api_response = self.module._evaluate_instances(
metric_sources=[self.metric],
metric_sources=[metric_source],
instance=payload.get("instance"),
autorater_config=payload.get("autorater_config"),
)
break
except genai_errors.ClientError as e:
if e.code == 429:
logger.warning(
"Resource Exhausted error on attempt %d/%d: %s. Retrying in %s"
" seconds...",
attempt + 1,
_MAX_RETRIES,
e,
2**attempt,
)
if attempt == _MAX_RETRIES - 1:
return types.EvalCaseMetricResult(
metric_name=metric_name,
@@ -1377,7 +1402,6 @@ def aggregate(
self, eval_case_metric_results: list[types.EvalCaseMetricResult]
) -> types.AggregatedMetricResult:
"""Aggregates the metric results for a registered metric."""
logger.debug("Aggregating results for registered metric: %s", self.metric_name)
return _default_aggregate_scores(
self.metric_name, eval_case_metric_results, calculate_pass_rate=True
)
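The retry loop above logs an exponential wait (2**attempt seconds) before retrying on HTTP 429. A standalone sketch of that pattern follows; the helper name and the _MAX_RETRIES value here are illustrative, not the module's own.

import time

from google.genai import errors as genai_errors

_MAX_RETRIES = 3  # illustrative; the real constant is defined elsewhere in this module


def call_with_backoff(fn):
    # Retry fn() on quota errors (HTTP 429) with exponential backoff: 1s, 2s, 4s, ...
    for attempt in range(_MAX_RETRIES):
        try:
            return fn()
        except genai_errors.ClientError as e:
            if e.code != 429 or attempt == _MAX_RETRIES - 1:
                raise
            time.sleep(2**attempt)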
70 changes: 59 additions & 11 deletions vertexai/_genai/_transformers.py
@@ -25,13 +25,6 @@
_METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$"


def t_metric(
metric: "types.MetricSubclass",
) -> dict[str, Any]:
"""Prepares the metric payload for a single metric."""
return t_metrics([metric])[0]


def t_metrics(
metrics: "list[types.MetricSubclass]",
set_default_aggregation_metrics: bool = False,
@@ -82,16 +75,19 @@ def t_metrics(
}
# Pointwise metrics
elif hasattr(metric, "prompt_template") and metric.prompt_template:
pointwise_spec = {"metric_prompt_template": metric.prompt_template}
llm_based_spec = {"metric_prompt_template": metric.prompt_template}
system_instruction = getv(metric, ["judge_model_system_instruction"])
if system_instruction:
pointwise_spec["system_instruction"] = system_instruction
llm_based_spec["system_instruction"] = system_instruction
rubric_group_name = getv(metric, ["rubric_group_name"])
if rubric_group_name:
llm_based_spec["rubric_group_key"] = rubric_group_name
return_raw_output = getv(metric, ["return_raw_output"])
if return_raw_output:
pointwise_spec["custom_output_format_config"] = {
llm_based_spec["custom_output_format_config"] = {
"return_raw_output": return_raw_output
}
metric_payload_item["pointwise_metric_spec"] = pointwise_spec
metric_payload_item["llm_based_metric_spec"] = llm_based_spec
elif getattr(metric, "metric_resource_name", None) is not None:
# Safe pass
pass
@@ -127,3 +123,55 @@ def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]:
metric_payload = t_metrics([metric])[0]
sources_payload.append({"metric": metric_payload})
return sources_payload


def t_metric_for_registry(
metric: "types.Metric",
) -> dict[str, Any]:
"""Prepares the metric payload specifically for EvaluationMetric registration."""
metric_payload_item: dict[str, Any] = {}
metric_name = getattr(metric, "name", None)
if metric_name:
metric_name = metric_name.lower()

# Handle standard computation metrics
if metric_name == "exact_match":
metric_payload_item["exact_match_spec"] = {}
elif metric_name == "bleu":
metric_payload_item["bleu_spec"] = {}
elif metric_name and metric_name.startswith("rouge"):
rouge_type = metric_name.replace("_", "")
metric_payload_item["rouge_spec"] = {"rouge_type": rouge_type}
# API Pre-defined metrics
elif metric_name and metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
metric_payload_item["predefined_metric_spec"] = {
"metric_spec_name": metric_name,
"metric_spec_parameters": metric.metric_spec_parameters,
}
# Custom Code Execution Metric
elif hasattr(metric, "remote_custom_function") and metric.remote_custom_function:
metric_payload_item["custom_code_execution_spec"] = {
"evaluation_function": metric.remote_custom_function
}

# Map LLM-based metrics to the new llm_based_metric_spec
elif (hasattr(metric, "prompt_template") and metric.prompt_template) or (
hasattr(metric, "rubric_group_name") and metric.rubric_group_name
):
llm_based_spec = {}

if hasattr(metric, "prompt_template") and metric.prompt_template:
llm_based_spec["metric_prompt_template"] = metric.prompt_template
system_instruction = getv(metric, ["judge_model_system_instruction"])
if system_instruction:
llm_based_spec["system_instruction"] = system_instruction
rubric_group_name = getv(metric, ["rubric_group_name"])
if rubric_group_name:
llm_based_spec["rubric_group_key"] = rubric_group_name

metric_payload_item["llm_based_metric_spec"] = llm_based_spec

else:
raise ValueError(f"Unsupported metric type: {metric_name}")

return metric_payload_item
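To make the dispatch above concrete, here are a few hypothetical inputs and the payloads they would map to. The types.Metric constructor fields are assumed from their use in t_metrics above; the example values are illustrative, not taken from the API.

from vertexai._genai import _transformers as t
from vertexai._genai import types

# Computation metric dispatched by name.
t.t_metric_for_registry(types.Metric(name="bleu"))
# -> {"bleu_spec": {}}

# ROUGE variants: the underscore is dropped to form rouge_type.
t.t_metric_for_registry(types.Metric(name="rouge_l"))
# -> {"rouge_spec": {"rouge_type": "rougel"}}

# LLM-based metric: the prompt template lands in llm_based_metric_spec.
t.t_metric_for_registry(
    types.Metric(name="my_custom_metric", prompt_template="Rate the {response}.")
)
# -> {"llm_based_metric_spec": {"metric_prompt_template": "Rate the {response}."}}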
20 changes: 19 additions & 1 deletion vertexai/_genai/evals.py
@@ -79,7 +79,11 @@ def _CreateEvaluationMetricParameters_to_vertex(
setv(to_object, ["description"], getv(from_object, ["description"]))

if getv(from_object, ["metric"]) is not None:
setv(to_object, ["metric"], t.t_metric(getv(from_object, ["metric"])))
setv(
to_object,
["metric"],
t.t_metric_for_registry(getv(from_object, ["metric"])),
)

if getv(from_object, ["config"]) is not None:
setv(to_object, ["config"], getv(from_object, ["config"]))
@@ -2346,6 +2350,13 @@ def create_evaluation_metric(
)
metric = resolved_metrics[0]

# Fall back to the metric's name when display_name is not provided.
if display_name is None and metric:
if isinstance(metric, dict):
display_name = metric.get("name")
else:
display_name = getattr(metric, "name", None)

result = self._create_evaluation_metric(
display_name=display_name,
description=description,
@@ -3519,6 +3530,13 @@ async def create_evaluation_metric(
)
metric = resolved_metrics[0]

# Fall back to the metric's name when display_name is not provided.
if display_name is None and metric:
if isinstance(metric, dict):
display_name = metric.get("name")
else:
display_name = getattr(metric, "name", None)

result = await self._create_evaluation_metric(
display_name=display_name,
description=description,
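The effect of the new fallback, sketched as a caller-side example (assuming `client` is an initialized Vertex AI client and that types.Metric exposes a `name` field, as above):

from vertexai._genai import types

# display_name is omitted, so it falls back to the metric's own name.
resource_name = client.evals.create_evaluation_metric(
    metric=types.Metric(name="my_custom_metric", prompt_template="Rate the {response}."),
    description="test_description",
)
# The registered EvaluationMetric gets display_name "my_custom_metric".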