Skip to content

Commit b3dba66

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI SDK client - add async version of prompt optimizer
FUTURE_COPYBARA_INTEGRATE_REVIEW=#5501 from googleapis:release-please--branches--main 259a77b PiperOrigin-RevId: 776295311
1 parent 610c523 commit b3dba66

File tree

6 files changed

+222
-39
lines changed

6 files changed

+222
-39
lines changed

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def client(use_vertex, replays_prefix, http_options, request):
112112
)
113113
os.environ["GOOGLE_CLOUD_PROJECT"] = "project-id"
114114
os.environ["GOOGLE_CLOUD_LOCATION"] = "location"
115+
os.environ["VAPO_CONFIG_PATH"] = "gs://dummy-test/dummy-config.json"
116+
os.environ["VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"] = "1234567890"
115117

116118
# Set the replay directory to the root directory of the replays.
117119
# This is needed to ensure that the replay files are found.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
import os
18+
19+
from tests.unit.vertexai.genai.replays import pytest_helper
20+
from vertexai._genai import types
21+
22+
23+
def test_optimize(client):
    """Runs a VAPO optimize call and checks that the job succeeded.

    The config path and the service account project number are taken from
    the VAPO_CONFIG_PATH and VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER environment
    variables; the test raises ValueError when either one is unset or empty.
    """
    config_path = os.environ.get("VAPO_CONFIG_PATH")
    if not config_path:
        raise ValueError("VAPO_CONFIG_PATH environment variable is not set.")

    project_number = os.environ.get("VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER")
    if not project_number:
        raise ValueError(
            "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER environment variable is not set."
        )

    vapo_config = types.PromptOptimizerVAPOConfig(
        config_path=config_path,
        wait_for_completion=True,
        service_account_project_number=project_number,
    )
    job = client.prompt_optimizer.optimize(method="vapo", config=vapo_config)
    assert job.state == types.JobState.JOB_STATE_SUCCEEDED
45+
46+
47+
# Module-level pytest marker produced by the replay test helper.
# `test_method` names the SDK method this file exercises; presumably the
# helper uses it to locate the matching replay data -- confirm in pytest_helper.
pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
    test_method="prompt_optimizer.optimize",
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Utility functions for prompt optimizer."""
16+
17+
from . import types
18+
19+
20+
def _get_service_account(
    config: types.PromptOptimizerVAPOConfigOrDict,
) -> str:
    """Resolve the service account to run the prompt optimizer custom job as.

    Args:
        config: The VAPO config, either as a config object or as a plain dict
            with the same keys. Only ``service_account`` and
            ``service_account_project_number`` are read.

    Returns:
        ``service_account`` when it is set; otherwise the address
        ``{service_account_project_number}-compute@developer.gserviceaccount.com``
        built from the project number.

    Raises:
        ValueError: If both ``service_account`` and
            ``service_account_project_number`` are set, or if neither is.
    """
    if isinstance(config, dict):
        # The signature accepts the dict form, but attribute lookups never
        # see dict keys, so read them by key instead.
        service_account = config.get("service_account")
        project_number = config.get("service_account_project_number")
    else:
        service_account = getattr(config, "service_account", None)
        project_number = getattr(config, "service_account_project_number", None)

    # Truthiness (not `is None`) is deliberate: empty strings and 0 count as
    # "not provided", matching the previous behavior for config objects.
    if service_account and project_number:
        raise ValueError(
            "Only one of service_account or service_account_project_number "
            "can be provided."
        )
    if service_account:
        return service_account
    if project_number:
        return f"{project_number}-compute@developer.gserviceaccount.com"
    raise ValueError(
        "Either service_account or service_account_project_number is required."
    )

vertexai/_genai/client.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(self, api_client: genai_client.Client):
3232
self._api_client = api_client
3333
self._evals = None
3434
self._agent_engines = None
35+
self._prompt_optimizer = None
3536

3637
@property
3738
@_common.experimental_warning(
@@ -52,7 +53,17 @@ def evals(self):
5253
) from e
5354
return self._evals.AsyncEvals(self._api_client)
5455

55-
# TODO(b/424176979): add async prompt optimizer here.
56+
@property
@_common.experimental_warning(
    "The Vertex SDK GenAI prompt optimizer module is experimental, "
    "and may change in future versions."
)
def prompt_optimizer(self):
    """Returns an AsyncPromptOptimizer bound to this client's API client.

    The prompt_optimizer module is imported on first access and cached on
    the instance; a new AsyncPromptOptimizer is built on every access.
    """
    module = self._prompt_optimizer
    if module is None:
        module = importlib.import_module(".prompt_optimizer", __package__)
        self._prompt_optimizer = module
    return module.AsyncPromptOptimizer(self._api_client)
5667

5768
@property
5869
@_common.experimental_warning(

vertexai/_genai/prompt_optimizer.py

Lines changed: 99 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from google.genai._common import get_value_by_path as getv
2828
from google.genai._common import set_value_by_path as setv
2929

30+
from . import _prompt_optimizer_utils
3031
from . import types
3132

3233

@@ -574,6 +575,7 @@ def _wait_for_completion(self, job_name: str) -> None:
574575
raise RuntimeError(f"Job failed with state: {job.state}")
575576
else:
576577
logger.info(f"Job completed with state: {job.state}")
578+
return job
577579

578580
def optimize(
579581
self,
@@ -584,17 +586,11 @@ def optimize(
584586
585587
Args:
586588
method: The method for optimizing multiple prompts.
587-
config: The config to use. Config consists of the following fields: -
588-
config_path: The gcs path to the config file, e.g.
589-
gs://bucket/config.json. - service_account: Optional. The service
590-
account to use for the custom job. Cannot be provided at the same
591-
time as 'service_account_project_number'. -
592-
service_account_project_number: Optional. The project number used to
593-
construct the default service account:
594-
f"{service_account_project_number}-compute@developer.gserviceaccount.com"
595-
Cannot be provided at the same time as 'service_account'. -
596-
wait_for_completion: Optional. Whether to wait for the job to
597-
complete. Default is True.
589+
config: PromptOptimizerVAPOConfig instance containing the
590+
configuration for prompt optimization.
591+
592+
Returns:
593+
The custom job that was created.
598594
"""
599595

600596
if method != "vapo":
@@ -631,23 +627,7 @@ def optimize(
631627
}
632628
]
633629

634-
if config.service_account:
635-
if config.service_account_project_number:
636-
raise ValueError(
637-
"Only one of service_account or"
638-
" service_account_project_number can be provided."
639-
)
640-
service_account = config.service_account
641-
elif config.project_number:
642-
service_account = (
643-
f"{config.service_account_project_number}"
644-
"-compute@developer.gserviceaccount.com"
645-
)
646-
else:
647-
raise ValueError(
648-
"Either service_account or service_account_project_number is"
649-
" required."
650-
)
630+
service_account = _prompt_optimizer_utils._get_service_account(config)
651631

652632
job_spec = types.CustomJobSpec(
653633
worker_pool_specs=worker_pool_specs,
@@ -672,11 +652,11 @@ def optimize(
672652
logger.info("Job created: %s", job.name)
673653

674654
# Construct the dashboard URL
675-
dashboard_url = f"https://pantheon.corp.google.com/vertex-ai/locations/{region}/training/{job_id}/cpu?e=13802955&project={project}"
655+
dashboard_url = f"https://console.cloud.google.com/vertex-ai/locations/{region}/training/{job_id}/cpu?project={project}"
676656
logger.info("View the job status at: %s", dashboard_url)
677657

678658
if wait_for_completion:
679-
self._wait_for_completion(job_id)
659+
job = self._wait_for_completion(job_id)
680660
return job
681661

682662

@@ -843,3 +823,92 @@ async def _get_custom_job(
843823

844824
self._api_client._verify_response(return_value)
845825
return return_value
826+
827+
async def optimize(
    self,
    method: str,
    config: types.PromptOptimizerVAPOConfigOrDict,
) -> types.CustomJob:
    """Launch a Vertex AI Prompt Optimizer (VAPO) custom job asynchronously.

    The custom job is created and then returned as-is: this method does not
    read ``config.wait_for_completion`` and does not poll for a final state.

    Example usage:
        client = vertexai.Client(project=PROJECT_NAME, location='us-central1')
        vapo_config = vertexai.types.PromptOptimizerVAPOConfig(
            config_path="gs://your-bucket-name/your-config.json",
            service_account=service_account,
        )
        job = await client.aio.prompt_optimizer.optimize(
            method="vapo", config=vapo_config)

    Args:
        method: The optimization method; only "vapo" is accepted.
        config: A PromptOptimizerVAPOConfig, or a dict with the same fields,
            describing the optimization run.

    Returns:
        The custom job returned by the create call.

    Raises:
        ValueError: If ``method`` is not "vapo", if ``config.config_path`` is
            empty, or if the service account settings are missing or
            conflicting.
    """
    # TODO(b/428953357): Add example in the README.
    if method != "vapo":
        raise ValueError("Only vapo methods is currently supported.")

    vapo_config = (
        types.PromptOptimizerVAPOConfig(**config)
        if isinstance(config, dict)
        else config
    )

    # The launch time makes the display name unique per submission.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    display_name = f"vapo-optimizer-{timestamp}"

    config_path = vapo_config.config_path
    if not config_path:
        raise ValueError("Config path is required.")
    # Job output is written next to the config file: everything before the
    # last "/" of the config path.
    output_uri_prefix = config_path.rpartition("/")[0]

    region = self._api_client.location
    project = self._api_client.project

    worker_pool_specs = [
        {
            "replica_count": 1,
            "container_spec": {
                "image_uri": "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0",
                "args": [f"--config={config_path}"],
            },
            "machine_spec": {
                "machine_type": "n1-standard-4",
            },
        }
    ]

    job_spec = types.CustomJobSpec(
        worker_pool_specs=worker_pool_specs,
        base_output_directory=types.GcsDestination(
            output_uri_prefix=output_uri_prefix
        ),
        service_account=_prompt_optimizer_utils._get_service_account(vapo_config),
    )

    job = await self._create_custom_job_resource(
        custom_job=types.CustomJob(display_name=display_name, job_spec=job_spec),
    )

    # The trailing segment of the resource name is the job id used in the
    # console URL shown to the user.
    job_id = job.name.rpartition("/")[2]
    logger.info("Job created: %s", job.name)

    dashboard_url = f"https://console.cloud.google.com/vertex-ai/locations/{region}/training/{job_id}/cpu?project={project}"
    logger.info("View the job status at: %s", dashboard_url)

    return job

vertexai/_genai/types.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5207,29 +5207,37 @@ class PromptOptimizerVAPOConfig(_common.BaseModel):
52075207
"""VAPO Prompt Optimizer Config."""
52085208

52095209
config_path: Optional[str] = Field(
5210-
default=None, description="""The gcs path to the config file."""
5210+
default=None,
5211+
description="""The gcs path to the config file, e.g. gs://bucket/config.json.""",
5212+
)
5213+
service_account: Optional[str] = Field(
5214+
default=None,
5215+
description="""The service account to use for the custom job. Cannot be provided at the same time as service_account_project_number.""",
52115216
)
5212-
service_account: Optional[str] = Field(default=None, description="""""")
52135217
service_account_project_number: Optional[Union[int, str]] = Field(
5214-
default=None, description=""""""
5218+
default=None,
5219+
description="""The project number used to construct the default service account:{service_account_project_number}-compute@developer.gserviceaccount.comCannot be provided at the same time as "service_account".""",
5220+
)
5221+
wait_for_completion: Optional[bool] = Field(
5222+
default=True,
5223+
description="""Whether to wait for the job tocomplete. Ignored for async jobs.""",
52155224
)
5216-
wait_for_completion: Optional[bool] = Field(default=True, description="""""")
52175225

52185226

52195227
class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
52205228
"""VAPO Prompt Optimizer Config."""
52215229

52225230
config_path: Optional[str]
5223-
"""The gcs path to the config file."""
5231+
"""The gcs path to the config file, e.g. gs://bucket/config.json."""
52245232

52255233
service_account: Optional[str]
5226-
""""""
5234+
"""The service account to use for the custom job. Cannot be provided at the same time as service_account_project_number."""
52275235

52285236
service_account_project_number: Optional[Union[int, str]]
5229-
""""""
5237+
"""The project number used to construct the default service account:{service_account_project_number}-compute@developer.gserviceaccount.comCannot be provided at the same time as "service_account"."""
52305238

52315239
wait_for_completion: Optional[bool]
5232-
""""""
5240+
"""Whether to wait for the job tocomplete. Ignored for async jobs."""
52335241

52345242

52355243
PromptOptimizerVAPOConfigOrDict = Union[

0 commit comments

Comments
 (0)