78 changes: 69 additions & 9 deletions tests/unit/vertexai/genai/replays/conftest.py
@@ -22,8 +22,11 @@
from vertexai._genai import (
client as vertexai_genai_client_module,
)
from google.cloud import storage, bigquery
from google.genai import _replay_api_client
from google.genai import client as google_genai_client_module
from vertexai._genai import _evals_utils
from vertexai._genai import prompt_optimizer
import pytest

IS_KOKORO = os.getenv("KOKORO_BUILD_NUMBER") is not None
@@ -82,11 +85,39 @@ def _get_replay_id(use_vertex: bool, replays_prefix: str) -> str:
return "/".join([replays_prefix, test_name])


EVAL_CONFIG_GCS_URI = (
"gs://vertex-ai-generative-ai-eval-sdk-resources/metrics/text_quality/v1.0.0.yaml"
)


def _mock_read_file_contents_side_effect(uri: str):
"""
Side effect to mock GcsUtils.read_file_contents for the eval test test_batch_evaluate.
"""
if uri == EVAL_CONFIG_GCS_URI:
# Construct the absolute path to the local mock file.
current_dir = os.path.dirname(__file__)
local_yaml_path = os.path.join(
current_dir, "test_resources/mock_eval_config.yaml"
)
try:
with open(local_yaml_path, "r") as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundError(
"The mock data file 'mock_eval_config.yaml' was not found."
)

raise ValueError(
f"Unexpected GCS URI '{uri}' in replay test. Only "
f"'{EVAL_CONFIG_GCS_URI}' is mocked."
)


@pytest.fixture
def client(use_vertex, replays_prefix, http_options, request):

mode = request.config.getoption("--mode")
replays_directory_prefix = request.config.getoption("--replays-directory-prefix")
if mode not in ["auto", "record", "replay", "api", "tap"]:
raise ValueError("Invalid mode: " + mode)
test_function_name = request.function.__name__
@@ -114,13 +145,14 @@ def client(use_vertex, replays_prefix, http_options, request):
os.environ["GOOGLE_CLOUD_LOCATION"] = "location"
os.environ["VAPO_CONFIG_PATH"] = "gs://dummy-test/dummy-config.json"
os.environ["VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"] = "1234567890"
os.environ["GCS_BUCKET"] = "test-bucket"

# Set the replay directory to the root directory of the replays.
# This is needed to ensure that the replay files are found.
replays_root_directory = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
"../../../../../../../../../google/cloud/aiplatform/sdk/genai/replays",
"../../../../../../../../../../google/cloud/aiplatform/sdk/genai/replays",
)
)
os.environ["GOOGLE_GENAI_REPLAYS_DIRECTORY"] = replays_root_directory
@@ -131,18 +163,46 @@
http_options=http_options,
)

replay_client.replays_directory = (
f"{replays_directory_prefix}/google/cloud/aiplatform/sdk/replays/"
)

with mock.patch.object(
google_genai_client_module.Client, "_get_api_client"
) as patch_method:
patch_method.return_value = replay_client
google_genai_client = vertexai_genai_client_module.Client()

# Yield the client so that cleanup can be completed at the end of the test.
yield google_genai_client
if mode != "replay":
yield google_genai_client
else:
# Eval tests make calls to GCS and BigQuery.
# These clients need to be mocked so no live service is called in replay mode.
with mock.patch.object(storage, "Client") as mock_storage_client:
mock_client_instance = mock.MagicMock()

mock_blob = mock.MagicMock()

mock_blob.name = "v1.0.0.yaml"

mock_client_instance.list_blobs.return_value = [mock_blob]

mock_storage_client.return_value = mock_client_instance

with mock.patch.object(bigquery, "Client") as mock_bigquery_client:
mock_bigquery_client.return_value = mock.MagicMock()

with mock.patch.object(
_evals_utils.GcsUtils, "read_file_contents"
) as mock_read_file_contents:
mock_read_file_contents.side_effect = (
_mock_read_file_contents_side_effect
)

with mock.patch.object(
prompt_optimizer.time, "sleep"
) as mock_job_wait:
mock_job_wait.return_value = None

google_genai_client = vertexai_genai_client_module.Client()

# Yield the client so that cleanup can be completed at the end of the test.
yield google_genai_client

# Save the replay after the test if we're in recording mode.
replay_client.close()
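
As an aside, here is a minimal sketch of how the GCS mock above behaves. It calls the side effect directly, assumes this conftest is importable as `conftest`, and is illustrative rather than part of the test suite:

# Illustrative only: in the fixture, GcsUtils.read_file_contents is patched
# with this side effect so eval replay tests never reach GCS or BigQuery.
from conftest import EVAL_CONFIG_GCS_URI, _mock_read_file_contents_side_effect

# The known eval-config URI resolves to the checked-in mock YAML file.
yaml_text = _mock_read_file_contents_side_effect(EVAL_CONFIG_GCS_URI)
assert yaml_text  # assumes test_resources/mock_eval_config.yaml is non-empty

# Any other URI raises, which surfaces accidental live GCS reads in replay mode.
try:
    _mock_read_file_contents_side_effect("gs://hypothetical-bucket/other.yaml")
except ValueError as unexpected_uri_error:
    print(unexpected_uri_error)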
3 changes: 3 additions & 0 deletions tests/unit/vertexai/genai/replays/credentials.json
@@ -0,0 +1,3 @@
{
"type": "service_account"
}
@@ -20,6 +20,8 @@
from vertexai._genai import types


# If you re-record this test, you will need to update the replay file to
# include the placeholder values for the config path and service account.
def test_optimize(client):
"""Tests the optimize request parameters method."""

@@ -36,6 +38,7 @@ def test_optimize(client):
service_account_project_number=os.environ.get(
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
),
optimizer_job_display_name="optimizer_job_test",
)
job = client.prompt_optimizer.optimize(
method="vapo",
8 changes: 3 additions & 5 deletions vertexai/_genai/evals.py
@@ -964,11 +964,9 @@ def evaluate(
config = types.EvaluateMethodConfig.model_validate(config)
if isinstance(dataset, list):
dataset = [
(
types.EvaluationDataset.model_validate(ds_item)
if isinstance(ds_item, dict)
else ds_item
)
types.EvaluationDataset.model_validate(ds_item)
if isinstance(ds_item, dict)
else ds_item
for ds_item in dataset
]
else:
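
The reflow above only drops the redundant parentheses around the conditional expression inside the list comprehension; the normalization itself is unchanged. A standalone sketch of the pattern with hypothetical sample data (the gcs_source/uris field names are assumptions, not taken from this diff):

from vertexai._genai import types

# Hypothetical mixed input: raw dicts alongside already-constructed objects.
raw_items = [
    {"gcs_source": {"uris": ["gs://hypothetical-bucket/eval_dataset.jsonl"]}},
    types.EvaluationDataset(),
]

# Same comprehension as in evaluate(): validate dicts into EvaluationDataset,
# pass through anything that already is one.
dataset = [
    types.EvaluationDataset.model_validate(item)
    if isinstance(item, dict)
    else item
    for item in raw_items
]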
14 changes: 10 additions & 4 deletions vertexai/_genai/prompt_optimizer.py
@@ -599,8 +599,11 @@ def optimize(
if isinstance(config, dict):
config = types.PromptOptimizerVAPOConfig(**config)

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
display_name = f"vapo-optimizer-{timestamp}"
if config.optimizer_job_display_name:
display_name = config.optimizer_job_display_name
else:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
display_name = f"vapo-optimizer-{timestamp}"
wait_for_completion = config.wait_for_completion
if not config.config_path:
raise ValueError("Config path is required.")
@@ -857,8 +860,11 @@ async def optimize(
if isinstance(config, dict):
config = types.PromptOptimizerVAPOConfig(**config)

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
display_name = f"vapo-optimizer-{timestamp}"
if config.optimizer_job_display_name:
display_name = config.optimizer_job_display_name
else:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
display_name = f"vapo-optimizer-{timestamp}"

if not config.config_path:
raise ValueError("Config path is required.")
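
The same display-name fallback now appears in both the sync and async optimize(). A compact restatement follows, where resolve_display_name is a hypothetical helper used only to illustrate the behavior:

import datetime
from typing import Optional


def resolve_display_name(optimizer_job_display_name: Optional[str]) -> str:
    # An explicit optimizer_job_display_name wins; otherwise fall back to the
    # previous default of a timestamped "vapo-optimizer-<YYYYmmdd-HHMMSS>" name.
    if optimizer_job_display_name:
        return optimizer_job_display_name
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    return f"vapo-optimizer-{timestamp}"


assert resolve_display_name("my-vapo-run") == "my-vapo-run"
assert resolve_display_name(None).startswith("vapo-optimizer-")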
13 changes: 10 additions & 3 deletions vertexai/_genai/types.py
@@ -5223,6 +5223,10 @@ class PromptOptimizerVAPOConfig(_common.BaseModel):
default=True,
description="""Whether to wait for the job tocomplete. Ignored for async jobs.""",
)
optimizer_job_display_name: Optional[str] = Field(
default=None,
description="""The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used.""",
)


class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
@@ -5240,6 +5244,9 @@ class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
wait_for_completion: Optional[bool]
"""Whether to wait for the job tocomplete. Ignored for async jobs."""

optimizer_job_display_name: Optional[str]
"""The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used."""


PromptOptimizerVAPOConfigOrDict = Union[
PromptOptimizerVAPOConfig, PromptOptimizerVAPOConfigDict
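
Because PromptOptimizerVAPOConfigOrDict above accepts either form, the new field can be supplied as a model or as a plain dict. A hedged sketch (placeholder paths, and config_path assumed to be part of the config per the required-path check in optimize()):

from vertexai._genai import types

# Equivalent ways to name the optimization job; values are placeholders.
config_model = types.PromptOptimizerVAPOConfig(
    config_path="gs://hypothetical-bucket/vapo-config.json",
    optimizer_job_display_name="my-vapo-run",
)
config_dict: types.PromptOptimizerVAPOConfigDict = {
    "config_path": "gs://hypothetical-bucket/vapo-config.json",
    "optimizer_job_display_name": "my-vapo-run",
}

# optimize() normalizes the dict form via PromptOptimizerVAPOConfig(**config),
# so both spellings resolve to the same display name.
assert (
    types.PromptOptimizerVAPOConfig(**config_dict).optimizer_job_display_name
    == config_model.optimizer_job_display_name
)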
@@ -5769,9 +5776,9 @@ def to_yaml_file(self, file_path: str, version: Optional[str] = None) -> None:
exclude_unset=True,
exclude_none=True,
mode="json",
exclude=(
fields_to_exclude_callables if fields_to_exclude_callables else None
),
exclude=fields_to_exclude_callables
if fields_to_exclude_callables
else None,
)

if version: