From 24f626c804f82ec3d9249b85ce2c84eec1485fd8 Mon Sep 17 00:00:00 2001
From: A Vertex SDK engineer
Date: Wed, 18 Mar 2026 18:11:00 -0700
Subject: [PATCH] feat: Refactor RegisteredMetricHandler and implement llm_based_metric_spec support

PiperOrigin-RevId: 885888137
---
 .../vertexai/genai/replays/test_evaluate.py | 31 +++++++-
 vertexai/_genai/_evals_metric_handlers.py   | 76 ++++++++++++-------
 vertexai/_genai/_transformers.py            | 70 ++++++++++++++---
 vertexai/_genai/evals.py                    | 20 ++++-
 4 files changed, 155 insertions(+), 42 deletions(-)

diff --git a/tests/unit/vertexai/genai/replays/test_evaluate.py b/tests/unit/vertexai/genai/replays/test_evaluate.py
index 89c802c43d..51f394ab1c 100644
--- a/tests/unit/vertexai/genai/replays/test_evaluate.py
+++ b/tests/unit/vertexai/genai/replays/test_evaluate.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring
-
+import re
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
 from google.genai import types as genai_types
@@ -353,13 +353,22 @@ def test_evaluation_agent_data(client):
     assert case_result.response_candidate_results is not None
 
 
-def test_metric_resource_name(client):
+def test_evaluation_metric_resource_name(client):
     """Tests with a metric resource name in types.Metric."""
     client._api_client._http_options.api_version = "v1beta1"
     client._api_client._http_options.base_url = (
         "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
     )
-    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric_resource_name = client.evals.create_evaluation_metric(
+        display_name="test_metric",
+        description="test_description",
+        metric=types.RubricMetric.GENERAL_QUALITY,
+    )
+    assert isinstance(metric_resource_name, str)
+    assert re.match(
+        r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$",
+        metric_resource_name,
+    )
     byor_df = pd.DataFrame(
         {
             "prompt": ["Write a simple story about a dinosaur"],
@@ -375,8 +384,22 @@
     )
     assert isinstance(evaluation_result, types.EvaluationResult)
     assert evaluation_result.eval_case_results is not None
-    assert len(evaluation_result.eval_case_results) > 0
+    assert len(evaluation_result.eval_case_results) == 1
     assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
+    assert evaluation_result.summary_metrics[0].mean_score is not None
+    assert evaluation_result.summary_metrics[0].num_cases_valid == 1
+    assert evaluation_result.summary_metrics[0].num_cases_error == 0
+
+    case_result = evaluation_result.eval_case_results[0]
+    assert case_result.response_candidate_results is not None
+    assert len(case_result.response_candidate_results) == 1
+
+    metric_result = case_result.response_candidate_results[0].metric_results[
+        "my_custom_metric"
+    ]
+    assert metric_result.score is not None
+    assert metric_result.score > 0.2
+    assert metric_result.error_message is None
 
 
 pytestmark = pytest_helper.setup(
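For orientation, the end-to-end flow this test exercises looks roughly like the sketch below. This is a minimal illustration rather than part of the patch: the client construction is assumed, and the "my_custom_metric" definition shown here is hypothetical (the test's own metric setup is elided by the hunk boundaries).

    import pandas as pd
    import vertexai
    from vertexai._genai import types

    client = vertexai.Client(project="my-project", location="us-central1")

    # Register the metric once; the call returns the evaluationMetrics
    # resource name ("projects/.../locations/.../evaluationMetrics/...").
    metric_resource_name = client.evals.create_evaluation_metric(
        display_name="test_metric",
        description="test_description",
        metric=types.RubricMetric.GENERAL_QUALITY,
    )

    # Later, reference the registered metric by its resource name.
    evaluation_result = client.evals.evaluate(
        dataset=types.EvaluationDataset(
            eval_dataset_df=pd.DataFrame(
                {
                    "prompt": ["Write a simple story about a dinosaur"],
                    "response": ["Once upon a time, a dinosaur learned to write."],
                }
            )
        ),
        metrics=[
            types.Metric(
                name="my_custom_metric",
                metric_resource_name=metric_resource_name,
            )
        ],
    )
    print(evaluation_result.summary_metrics[0].mean_score)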
diff --git a/vertexai/_genai/_evals_metric_handlers.py b/vertexai/_genai/_evals_metric_handlers.py
index 10e07fe07a..19b8aa4fd4 100644
--- a/vertexai/_genai/_evals_metric_handlers.py
+++ b/vertexai/_genai/_evals_metric_handlers.py
@@ -1281,66 +1281,91 @@ def aggregate(
     )
 
 
-class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
+class RegisteredMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for registered metrics."""
 
     def __init__(
         self,
         module: "evals.Evals",
-        metric: Union[types.MetricSource, types.MetricSourceDict],
+        metric: types.Metric,
     ):
         if isinstance(metric, dict):
-            metric = types.MetricSource(**metric)
+            metric = types.Metric(**metric)
         super().__init__(module=module, metric=metric)
 
-    # TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
     def _build_request_payload(
         self, eval_case: types.EvalCase, response_index: int
     ) -> dict[str, Any]:
-        """Builds request payload for registered metric."""
-        if not self.metric.metric:
-            raise ValueError(
-                "Registered metric must have an underlying metric definition."
-            )
-        return PredefinedMetricHandler(
-            self.module, metric=self.metric.metric
-        )._build_request_payload(eval_case, response_index)
+        """Builds request payload for registered metric by assembling EvaluationInstance."""
+        response_content = _get_response_from_eval_case(
+            eval_case, response_index, self.metric_name
+        )
+
+        if not response_content and not getattr(eval_case, "agent_data", None):
+            raise ValueError(
+                f"Response content missing for candidate {response_index}."
+            )
+
+        reference_instance_data = None
+        if eval_case.reference:
+            reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
+                eval_case.reference.response
+            )
+
+        extracted_prompt = _get_prompt_from_eval_case(eval_case)
+        prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
+            extracted_prompt
+        )
+
+        instance_payload = types.EvaluationInstance(
+            prompt=prompt_instance_data,
+            response=PredefinedMetricHandler._content_to_instance_data(
+                response_content
+            ),
+            reference=reference_instance_data,
+            rubric_groups=eval_case.rubric_groups,
+            agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
+        )
+
+        request_payload = {
+            "instance": instance_payload,
+        }
+        return request_payload
 
     @property
     def metric_name(self) -> str:
-        # Resolve name from resource name or internal metric name
-        if isinstance(self.metric, types.MetricSource):
-            if self.metric.metric and self.metric.metric.name:
-                return self.metric.metric.name
-            if self.metric.metric_resource_name:
-                return self.metric.metric_resource_name
-            return "unknown"
-        else:  # Should be Metric
-            metric_like = self.metric
-            if metric_like.name:
-                return metric_like.name
-            if metric_like.metric_resource_name:
-                return metric_like.metric_resource_name
-            return "unknown"
+        return self.metric.name or self.metric.metric_resource_name or "unknown_metric"
 
     @override
     def get_metric_result(
         self, eval_case: types.EvalCase, response_index: int
     ) -> types.EvalCaseMetricResult:
-        """Processes a single evaluation case for a registered metric."""
+        """Processes a single evaluation case using a MetricSource reference."""
         metric_name = self.metric_name
+        metric_source = types.MetricSource(
+            metric_resource_name=self.metric.metric_resource_name
+        )
+
         try:
             payload = self._build_request_payload(eval_case, response_index)
             for attempt in range(_MAX_RETRIES):
                 try:
                     api_response = self.module._evaluate_instances(
-                        metric_sources=[self.metric],
+                        metric_sources=[metric_source],
                         instance=payload.get("instance"),
                         autorater_config=payload.get("autorater_config"),
                     )
                     break
                 except genai_errors.ClientError as e:
                     if e.code == 429:
+                        logger.warning(
+                            "Resource Exhausted error on attempt %d/%d: %s."
+                            " Retrying in %s seconds...",
+                            attempt + 1,
+                            _MAX_RETRIES,
+                            e,
+                            2**attempt,
+                        )
                         if attempt == _MAX_RETRIES - 1:
                             return types.EvalCaseMetricResult(
                                 metric_name=metric_name,
@@ -1377,7 +1402,6 @@ def aggregate(
         self, eval_case_metric_results: list[types.EvalCaseMetricResult]
     ) -> types.AggregatedMetricResult:
         """Aggregates the metric results for a registered metric."""
-        logger.debug("Aggregating results for registered metric: %s", self.metric_name)
         return _default_aggregate_scores(
             self.metric_name, eval_case_metric_results, calculate_pass_rate=True
         )
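The 429 handling added above is a standard exponential-backoff loop, isolated below for clarity. The helper and its parameter names are illustrative, not SDK symbols, and the sleep between attempts is assumed (the hunk shows only the logging; the surrounding unchanged code is cut off by the hunk boundary).

    import logging
    import time

    logger = logging.getLogger(__name__)
    _MAX_RETRIES = 3  # stand-in for the handler's module-level constant

    def call_with_backoff(call, is_rate_limited):
        """Retries call() on rate-limit errors, waiting 1s, 2s, 4s, ..."""
        for attempt in range(_MAX_RETRIES):
            try:
                return call()
            except Exception as e:  # the handler catches genai_errors.ClientError
                if not is_rate_limited(e) or attempt == _MAX_RETRIES - 1:
                    raise
                delay = 2**attempt
                logger.warning(
                    "Resource Exhausted error on attempt %d/%d: %s."
                    " Retrying in %s seconds...",
                    attempt + 1,
                    _MAX_RETRIES,
                    e,
                    delay,
                )
                time.sleep(delay)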
diff --git a/vertexai/_genai/_transformers.py b/vertexai/_genai/_transformers.py
index 05e1460497..f3e56e0be8 100644
--- a/vertexai/_genai/_transformers.py
+++ b/vertexai/_genai/_transformers.py
@@ -25,13 +25,6 @@
 _METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$"
 
 
-def t_metric(
-    metric: "types.MetricSubclass",
-) -> dict[str, Any]:
-    """Prepares the metric payload for a single metric."""
-    return t_metrics([metric])[0]
-
-
 def t_metrics(
     metrics: "list[types.MetricSubclass]",
     set_default_aggregation_metrics: bool = False,
@@ -82,16 +75,19 @@
         }
     # Pointwise metrics
     elif hasattr(metric, "prompt_template") and metric.prompt_template:
-        pointwise_spec = {"metric_prompt_template": metric.prompt_template}
+        llm_based_spec = {"metric_prompt_template": metric.prompt_template}
         system_instruction = getv(metric, ["judge_model_system_instruction"])
         if system_instruction:
-            pointwise_spec["system_instruction"] = system_instruction
+            llm_based_spec["system_instruction"] = system_instruction
+        rubric_group_name = getv(metric, ["rubric_group_name"])
+        if rubric_group_name:
+            llm_based_spec["rubric_group_key"] = rubric_group_name
         return_raw_output = getv(metric, ["return_raw_output"])
         if return_raw_output:
-            pointwise_spec["custom_output_format_config"] = {
+            llm_based_spec["custom_output_format_config"] = {
                 "return_raw_output": return_raw_output
             }
-        metric_payload_item["pointwise_metric_spec"] = pointwise_spec
+        metric_payload_item["llm_based_metric_spec"] = llm_based_spec
     elif getattr(metric, "metric_resource_name", None) is not None:
         # Safe pass
         pass
@@ -127,3 +123,55 @@ def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]:
         metric_payload = t_metrics([metric])[0]
         sources_payload.append({"metric": metric_payload})
     return sources_payload
+
+
+def t_metric_for_registry(
+    metric: "types.Metric",
+) -> dict[str, Any]:
+    """Prepares the metric payload specifically for EvaluationMetric registration."""
+    metric_payload_item: dict[str, Any] = {}
+    metric_name = getattr(metric, "name", None)
+    if metric_name:
+        metric_name = metric_name.lower()
+
+    # Handle standard computation metrics
+    if metric_name == "exact_match":
+        metric_payload_item["exact_match_spec"] = {}
+    elif metric_name == "bleu":
+        metric_payload_item["bleu_spec"] = {}
+    elif metric_name and metric_name.startswith("rouge"):
+        rouge_type = metric_name.replace("_", "")
+        metric_payload_item["rouge_spec"] = {"rouge_type": rouge_type}
+    # API Pre-defined metrics
+    elif metric_name and metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
+        metric_payload_item["predefined_metric_spec"] = {
+            "metric_spec_name": metric_name,
+            "metric_spec_parameters": metric.metric_spec_parameters,
+        }
+    # Custom Code Execution Metric
+    elif hasattr(metric, "remote_custom_function") and metric.remote_custom_function:
+        metric_payload_item["custom_code_execution_spec"] = {
+            "evaluation_function": metric.remote_custom_function
+        }
+
+    # Map LLM-based metrics to the new llm_based_metric_spec
+    elif (hasattr(metric, "prompt_template") and metric.prompt_template) or (
+        hasattr(metric, "rubric_group_name") and metric.rubric_group_name
+    ):
+        llm_based_spec = {}
+
+        if hasattr(metric, "prompt_template") and metric.prompt_template:
+            llm_based_spec["metric_prompt_template"] = metric.prompt_template
+        system_instruction = getv(metric, ["judge_model_system_instruction"])
+        if system_instruction:
+            llm_based_spec["system_instruction"] = system_instruction
+        rubric_group_name = getv(metric, ["rubric_group_name"])
+        if rubric_group_name:
+            llm_based_spec["rubric_group_key"] = rubric_group_name
+
+        metric_payload_item["llm_based_metric_spec"] = llm_based_spec
+
+    else:
+        raise ValueError(f"Unsupported metric type: {metric_name}")
+
+    return metric_payload_item
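To make the new registry mapping concrete, here is roughly what t_metric_for_registry emits for two representative inputs. The expected payloads follow directly from the branches above; constructing types.Metric with name and prompt_template keyword arguments is an assumption about the type's surface, not something this patch shows.

    from vertexai._genai import _transformers, types

    # An LLM-based metric maps to the new llm_based_metric_spec.
    llm_metric = types.Metric(
        name="my_custom_metric",
        prompt_template="Rate the fluency of {response} on a 1-5 scale.",
    )
    print(_transformers.t_metric_for_registry(llm_metric))
    # {'llm_based_metric_spec': {'metric_prompt_template': 'Rate the fluency ...'}}

    # A computation metric maps to its dedicated spec.
    print(_transformers.t_metric_for_registry(types.Metric(name="bleu")))
    # {'bleu_spec': {}}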
diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py
index 1dea3ac853..05b37bd369 100644
--- a/vertexai/_genai/evals.py
+++ b/vertexai/_genai/evals.py
@@ -79,7 +79,11 @@ def _CreateEvaluationMetricParameters_to_vertex(
         setv(to_object, ["description"], getv(from_object, ["description"]))
 
     if getv(from_object, ["metric"]) is not None:
-        setv(to_object, ["metric"], t.t_metric(getv(from_object, ["metric"])))
+        setv(
+            to_object,
+            ["metric"],
+            t.t_metric_for_registry(getv(from_object, ["metric"])),
+        )
 
     if getv(from_object, ["config"]) is not None:
         setv(to_object, ["config"], getv(from_object, ["config"]))
@@ -2346,6 +2350,13 @@ def create_evaluation_metric(
         )
         metric = resolved_metrics[0]
 
+        # Fall back to the metric's own name when display_name is not provided.
+        if display_name is None and metric:
+            if isinstance(metric, dict):
+                display_name = metric.get("name")
+            else:
+                display_name = getattr(metric, "name", None)
+
         result = self._create_evaluation_metric(
             display_name=display_name,
             description=description,
@@ -3519,6 +3530,13 @@ async def create_evaluation_metric(
         )
         metric = resolved_metrics[0]
 
+        # Fall back to the metric's own name when display_name is not provided.
+        if display_name is None and metric:
+            if isinstance(metric, dict):
+                display_name = metric.get("name")
+            else:
+                display_name = getattr(metric, "name", None)
+
         result = await self._create_evaluation_metric(
             display_name=display_name,
             description=description,
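The display_name fallback duplicated across the sync and async paths reduces to a small pure function; a sketch for illustration only (the helper name is hypothetical).

    def _resolve_display_name(display_name, metric):
        # Prefer an explicit display_name; otherwise reuse the metric's own
        # name, whether the metric arrived as a dict or as a types.Metric.
        if display_name is None and metric:
            if isinstance(metric, dict):
                return metric.get("name")
            return getattr(metric, "name", None)
        return display_name

    assert _resolve_display_name(None, {"name": "my_custom_metric"}) == "my_custom_metric"
    assert _resolve_display_name("explicit", {"name": "ignored"}) == "explicit"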