diff --git a/tests/unit/vertexai/genai/replays/test_assemble_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_assemble_multimodal_datasets.py
new file mode 100644
index 0000000000..ac6499c985
--- /dev/null
+++ b/tests/unit/vertexai/genai/replays/test_assemble_multimodal_datasets.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pylint: disable=protected-access,bad-continuation,missing-function-docstring
+
+from tests.unit.vertexai.genai.replays import pytest_helper
+from vertexai._genai import types
+
+import pytest
+
+METADATA_SCHEMA_URI = (
+    "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
+)
+BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
+DATASET = "8810841321427173376"
+
+
+def test_assemble_dataset(client):
+    operation = client.datasets._assemble_multimodal_dataset(
+        name=DATASET,
+        gemini_request_read_config={
+            "template_config": {
+                "field_mapping": {"question": "questionColumn"},
+            },
+        },
+    )
+    assert isinstance(operation, types.MultimodalDatasetOperation)
+
+
+def test_assemble_dataset_public(client):
+    bigquery_destination = client.datasets.assemble(
+        name=DATASET,
+        template_config=types.GeminiTemplateConfig(
+            gemini_example=types.GeminiExample(
+                model="gemini-1.5-flash",
+                contents=[
+                    {
+                        "role": "user",
+                        "parts": [{"text": "What is the capital of {name}?"}],
+                    }
+                ],
+            ),
+        ),
+    )
+    assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}")
+
+
+pytestmark = pytest_helper.setup(
+    file=__file__,
+    globals_for_file=globals(),
+)
+
+pytest_plugins = ("pytest_asyncio",)
+
+
+@pytest.mark.asyncio
+async def test_assemble_dataset_async(client):
+    operation = await client.aio.datasets._assemble_multimodal_dataset(
+        name=DATASET,
+        gemini_request_read_config={
+            "template_config": {
+                "field_mapping": {"question": "questionColumn"},
+            },
+        },
+    )
+    assert isinstance(operation, types.MultimodalDatasetOperation)
+
+
+@pytest.mark.asyncio
+async def test_assemble_dataset_public_async(client):
+    bigquery_destination = await client.aio.datasets.assemble(
+        name=DATASET,
+        template_config=types.GeminiTemplateConfig(
+            gemini_example=types.GeminiExample(
+                model="gemini-1.5-flash",
+                contents=[
+                    {
+                        "role": "user",
+                        "parts": [{"text": "What is the capital of {name}?"}],
+                    }
+                ],
+            ),
+        ),
+    )
+    assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}")
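The tests above pass `gemini_request_read_config` as a plain dict. For reference, a sketch of the typed equivalent; it assumes `GeminiRequestReadConfig` exposes a `template_config` field and `GeminiTemplateConfig` a `field_mapping` field, which the dict form used in the test implies:

```python
from vertexai._genai import types

# Typed counterpart of the dict literal in test_assemble_dataset: maps the
# "question" template placeholder to the "questionColumn" BigQuery column.
read_config = types.GeminiRequestReadConfig(
    template_config=types.GeminiTemplateConfig(
        field_mapping={"question": "questionColumn"},
    ),
)
```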
f"bq://{BIGQUERY_TABLE_NAME}" + ) def test_create_dataset_from_bigquery_without_bq_prefix(client): @@ -70,6 +73,9 @@ def test_create_dataset_from_bigquery_without_bq_prefix(client): ) assert isinstance(dataset, types.MultimodalDataset) assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) pytestmark = pytest_helper.setup( @@ -111,6 +117,9 @@ async def test_create_dataset_from_bigquery_async(client): ) assert isinstance(dataset, types.MultimodalDataset) assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) @pytest.mark.asyncio @@ -129,6 +138,9 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client): ) assert isinstance(dataset, types.MultimodalDataset) assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) @pytest.mark.asyncio @@ -146,3 +158,6 @@ async def test_create_dataset_from_bigquery_async_without_bq_prefix(client): ) assert isinstance(dataset, types.MultimodalDataset) assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index 764d6b7a6d..9e5014bab9 100644 --- a/vertexai/_genai/_datasets_utils.py +++ b/vertexai/_genai/_datasets_utils.py @@ -14,7 +14,23 @@ # """Utility functions for multimodal dataset.""" +from typing import Any, TypeVar, Type +from vertexai._genai.types import common +from pydantic import BaseModel METADATA_SCHEMA_URI = ( "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" ) + +T = TypeVar("T", bound=BaseModel) + + +def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T: + """Creates a model from a response.""" + model_field_names = model_type.model_fields.keys() + filtered_response = {} + for key, value in response.items(): + snake_key = common.camel_to_snake(key) + if snake_key in model_field_names: + filtered_response[snake_key] = value + return model_type(**filtered_response) diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 4252665d00..6bf2967a9d 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -35,6 +35,27 @@ logger = logging.getLogger("vertexai_genai.datasets") +def _AssembleDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["gemini_request_read_config"]) is not None: + setv( + to_object, + ["geminiRequestReadConfig"], + getv(from_object, ["gemini_request_read_config"]), + ) + + return to_object + + def _CreateMultimodalDatasetParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -174,6 +195,63 @@ def _UpdateMultimodalDatasetParameters_to_vertex( class Datasets(_api_module.BaseModule): + def _assemble_multimodal_dataset( + self, + *, + config: Optional[types.AssembleDatasetConfigOrDict] = None, + name: str, + gemini_request_read_config: 
diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py
index 4252665d00..6bf2967a9d 100644
--- a/vertexai/_genai/datasets.py
+++ b/vertexai/_genai/datasets.py
@@ -35,6 +35,27 @@
 logger = logging.getLogger("vertexai_genai.datasets")
 
 
+def _AssembleDatasetParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["config"]) is not None:
+        setv(to_object, ["config"], getv(from_object, ["config"]))
+
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    if getv(from_object, ["gemini_request_read_config"]) is not None:
+        setv(
+            to_object,
+            ["geminiRequestReadConfig"],
+            getv(from_object, ["gemini_request_read_config"]),
+        )
+
+    return to_object
+
+
 def _CreateMultimodalDatasetParameters_to_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
@@ -174,6 +195,63 @@ def _UpdateMultimodalDatasetParameters_to_vertex(
 
 
 class Datasets(_api_module.BaseModule):
+    def _assemble_multimodal_dataset(
+        self,
+        *,
+        config: Optional[types.AssembleDatasetConfigOrDict] = None,
+        name: str,
+        gemini_request_read_config: Optional[
+            types.GeminiRequestReadConfigOrDict
+        ] = None,
+    ) -> types.MultimodalDatasetOperation:
+        """Assembles a multimodal dataset resource."""
+
+        parameter_model = types._AssembleDatasetParameters(
+            config=config,
+            name=name,
+            gemini_request_read_config=gemini_request_read_config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError("This method is only supported in the Vertex AI client.")
+        else:
+            request_dict = _AssembleDatasetParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "datasets/{name}:assemble".format_map(request_url_dict)
+            else:
+                path = "datasets/{name}:assemble"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("post", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.MultimodalDatasetOperation._from_response(
+            response=response_dict, kwargs=parameter_model.model_dump()
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
     def _create_multimodal_dataset(
         self,
         *,
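The `_wait_for_operation` rework in the next hunks keeps the existing polling loop: fetch the operation, sleep, and grow the sleep interval up to a 60-second cap. A standalone sketch of that pattern; the doubling factor is an assumption, since the actual growth expression falls between the hunks and is not visible in this diff:

```python
import time


def poll_until_done(get_operation, timeout_seconds: int):
    """Polls get_operation() with capped exponential backoff."""
    start_time = time.time()
    sleep_duration_seconds = 1
    max_wait_time_seconds = 60
    while (time.time() - start_time) < timeout_seconds:
        operation = get_operation()
        if operation.done:
            return operation
        time.sleep(sleep_duration_seconds)
        # Assumed doubling; the SDK caps the interval at 60 seconds.
        sleep_duration_seconds = min(
            sleep_duration_seconds * 2, max_wait_time_seconds
        )
    raise TimeoutError(
        f"The operation did not complete within {timeout_seconds} seconds."
    )
```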
@@ -517,23 +595,21 @@ def _wait_for_operation(
         self,
         operation: types.MultimodalDatasetOperation,
         timeout_seconds: int,
-    ) -> types.MultimodalDataset:
-        """Waits for a multimodal dataset operation to complete.
+    ) -> dict[str, Any]:
+        """Waits for a multimodal or assemble dataset operation to complete.
 
         Args:
-            operation: The multimodal dataset operation to wait for.
+            operation: The multimodal or assemble dataset operation to wait for.
             timeout_seconds: The maximum time in seconds to wait for the operation
                 to complete.
 
         Returns:
-            The name of the Multimodal Dataset resource from the operation result.
+            A dict containing the operation response.
 
         Raises:
             TimeoutError: If the operation does not complete within the timeout.
             ValueError: If the operation fails.
         """
-        multimodal_operation: Optional[types.MultimodalDatasetOperation] = None
-
         response_operation_name = operation.name
         dataset_id = response_operation_name.split("/datasets/")[1].split("/")[0]
         operation_id = response_operation_name.split("/")[-1]
@@ -544,11 +620,11 @@ def _wait_for_operation(
         max_wait_time_seconds = 60
 
         while (time.time() - start_time) < timeout_seconds:
-            multimodal_operation = self._get_multimodal_dataset_operation(
+            operation = self._get_multimodal_dataset_operation(
                 dataset_id=dataset_id,
                 operation_id=operation_id,
             )
-            if multimodal_operation.done:
+            if operation.done:
                 break
             time.sleep(sleep_duration_seconds)
             sleep_duration_seconds = min(
@@ -556,26 +632,15 @@ def _wait_for_operation(
             )
         else:
             raise TimeoutError(
-                "Create multimodal dataset operation did not complete within the"
+                "The operation did not complete within the"
                 f" specified timeout of {timeout_seconds} seconds."
             )
-        if (
-            not multimodal_operation
-            or multimodal_operation.response is None
-            or multimodal_operation.response.name is None
-        ):
-            logger.error(
-                f"Error creating multimodal dataset resource for the operation {operation.name}."
-            )
-            raise ValueError("Error creating multimodal dataset resource.")
-        if (
-            hasattr(multimodal_operation, "error")
-            and multimodal_operation.error is not None
-        ):
-            raise ValueError(
-                f"Error creating multimodal dataset resource: {multimodal_operation.error}"
-            )
-        return multimodal_operation.response
+        if not operation or operation.response is None:
+            logger.error(f"Error running the operation {operation.name}.")
+            raise ValueError(f"Error running the operation {operation.name}.")
+        if hasattr(operation, "error") and operation.error is not None:
+            raise ValueError(f"Error running the operation: {operation.error}")
+        return operation.response
 
     def create_from_bigquery(
         self,
@@ -614,10 +679,11 @@ def create_from_bigquery(
             metadata_schema_uri=_datasets_utils.METADATA_SCHEMA_URI,
             metadata=multimodal_dataset.metadata,
         )
-        return self._wait_for_operation(
+        response = self._wait_for_operation(
             operation=multimodal_dataset_operation,
             timeout_seconds=config.timeout,
         )
+        return _datasets_utils.create_from_response(types.MultimodalDataset, response)
 
     def update_multimodal_dataset(
         self,
@@ -715,9 +781,112 @@ def delete_multimodal_dataset(
 
         return self._delete_multimodal_dataset(config=config, name=name)
 
+    def assemble(
+        self,
+        *,
+        name: str,
+        template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
+        config: Optional[types.AssembleDatasetConfigOrDict] = None,
+    ) -> str:
+        """Assembles the dataset into a BigQuery table.
+
+        Waits for the assemble operation to complete before returning.
+
+        Args:
+            name:
+                Required. The name of the dataset to assemble. The name should be in
+                the format of "projects/{project}/locations/{location}/datasets/{dataset}".
+            template_config:
+                Optional. The template config to use to assemble the dataset. If
+                not provided, the template config attached to the dataset will be
+                used.
+            config:
+                Optional. A configuration for assembling the dataset. If not
+                provided, the default configuration will be used.
+
+        Returns:
+            The URI of the BigQuery table containing the assembled dataset.
+        """
+        if isinstance(config, dict):
+            config = types.AssembleDatasetConfig(**config)
+        elif not config:
+            config = types.AssembleDatasetConfig()
+
+        operation = self._assemble_multimodal_dataset(
+            name=name,
+            gemini_request_read_config={
+                "template_config": template_config,
+            },
+            config=config,
+        )
+        response = self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return response["bigqueryDestination"]
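End to end, the public `assemble` call looks like this. A minimal sketch assuming a `vertexai.Client` configured for Vertex AI; the project, location, and dataset ID are placeholders:

```python
import vertexai
from vertexai._genai import types

client = vertexai.Client(project="my-project", location="us-central1")

# Blocks until the assemble LRO finishes, then returns the "bq://..." URI.
table_uri = client.datasets.assemble(
    name="projects/my-project/locations/us-central1/datasets/123",
    template_config=types.GeminiTemplateConfig(
        gemini_example=types.GeminiExample(
            model="gemini-1.5-flash",
            contents=[
                {"role": "user", "parts": [{"text": "Summarize {document}."}]}
            ],
        ),
    ),
)
print(table_uri)  # e.g. bq://my-project.my_dataset.assembled_table
```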
+ """ + + parameter_model = types._AssembleDatasetParameters( + config=config, + name=name, + gemini_request_read_config=gemini_request_read_config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _AssembleDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{name}:assemble".format_map(request_url_dict) + else: + path = "datasets/{name}:assemble" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + + self._api_client._verify_response(return_value) + return return_value + async def _create_multimodal_dataset( self, *, @@ -1073,7 +1242,7 @@ async def _wait_for_operation( self, operation: types.MultimodalDatasetOperation, timeout_seconds: int, - ) -> types.MultimodalDataset: + ) -> dict[str, Any]: """Waits for a multimodal dataset operation to complete. Args: @@ -1082,14 +1251,12 @@ async def _wait_for_operation( to complete. Returns: - The name of the Multimodal Dataset resource from the operation result. + A dict containing the operation response. Raises: TimeoutError: If the operation does not complete within the timeout. ValueError: If the operation fails. """ - multimodal_operation: Optional[types.MultimodalDatasetOperation] = None - response_operation_name = operation.name dataset_id = response_operation_name.split("/datasets/")[1].split("/")[0] operation_id = response_operation_name.split("/")[-1] @@ -1100,11 +1267,11 @@ async def _wait_for_operation( max_wait_time_seconds = 60 while (time.time() - start_time) < timeout_seconds: - multimodal_operation = await self._get_multimodal_dataset_operation( + operation = await self._get_multimodal_dataset_operation( dataset_id=dataset_id, operation_id=operation_id, ) - if multimodal_operation.done: + if operation.done: break await asyncio.sleep(sleep_duration_seconds) sleep_duration_seconds = min( @@ -1112,26 +1279,15 @@ async def _wait_for_operation( ) else: raise TimeoutError( - "Create multimodal dataset operation did not complete within the" + "The operation did not complete within the" f" specified timeout of {timeout_seconds} seconds." ) - if ( - not multimodal_operation - or multimodal_operation.response is None - or multimodal_operation.response.name is None - ): - logger.error( - f"Error creating multimodal dataset resource for the operation {operation.name}." 
@@ -1170,10 +1326,11 @@ async def create_from_bigquery(
             metadata_schema_uri=_datasets_utils.METADATA_SCHEMA_URI,
             metadata=multimodal_dataset.metadata,
         )
-        return await self._wait_for_operation(
+        response = await self._wait_for_operation(
             operation=multimodal_dataset_operation,
             timeout_seconds=config.timeout,
         )
+        return _datasets_utils.create_from_response(types.MultimodalDataset, response)
 
     async def update_multimodal_dataset(
         self,
@@ -1266,3 +1423,47 @@ async def delete_multimodal_dataset(
             config = types.CreateMultimodalDatasetConfig()
 
         return await self._delete_multimodal_dataset(config=config, name=name)
+
+    async def assemble(
+        self,
+        *,
+        name: str,
+        template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
+        config: Optional[types.AssembleDatasetConfigOrDict] = None,
+    ) -> str:
+        """Assembles the dataset into a BigQuery table.
+
+        Waits for the assemble operation to complete before returning.
+
+        Args:
+            name:
+                Required. The name of the dataset to assemble. The name should be in
+                the format of "projects/{project}/locations/{location}/datasets/{dataset}".
+            template_config:
+                Optional. The template config to use to assemble the dataset. If
+                not provided, the template config attached to the dataset will be
+                used.
+            config:
+                Optional. A configuration for assembling the dataset. If not
+                provided, the default configuration will be used.
+
+        Returns:
+            The URI of the BigQuery table containing the assembled dataset.
+        """
+        if isinstance(config, dict):
+            config = types.AssembleDatasetConfig(**config)
+        elif not config:
+            config = types.AssembleDatasetConfig()
+
+        operation = await self._assemble_multimodal_dataset(
+            name=name,
+            gemini_request_read_config={
+                "template_config": template_config,
+            },
+            config=config,
+        )
+        response = await self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return response["bigqueryDestination"]
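The returned value is the raw `bigqueryDestination` string from the operation response, so callers who want to read the assembled rows must strip the `bq://` scheme themselves. A sketch assuming the `google-cloud-bigquery` client library is installed and configured:

```python
from google.cloud import bigquery

table_uri = "bq://my-project.my_dataset.assembled_table"  # placeholder
bq_client = bigquery.Client()

# BigQuery client APIs take "project.dataset.table", not the bq:// URI.
table = bq_client.get_table(table_uri.removeprefix("bq://"))
print(table.num_rows)
```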
diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py
index 4b9ad32eb7..f8dafa2a0f 100644
--- a/vertexai/_genai/types/__init__.py
+++ b/vertexai/_genai/types/__init__.py
@@ -22,6 +22,7 @@
 from . import agent_engines
 from . import evals
 from .common import _AppendAgentEngineSessionEventRequestParameters
+from .common import _AssembleDatasetParameters
 from .common import _CreateAgentEngineMemoryRequestParameters
 from .common import _CreateAgentEngineRequestParameters
 from .common import _CreateAgentEngineSandboxRequestParameters
@@ -123,6 +124,12 @@
 from .common import ApplicableGuideline
 from .common import ApplicableGuidelineDict
 from .common import ApplicableGuidelineOrDict
+from .common import AssembleDataset
+from .common import AssembleDatasetConfig
+from .common import AssembleDatasetConfigDict
+from .common import AssembleDatasetConfigOrDict
+from .common import AssembleDatasetDict
+from .common import AssembleDatasetOrDict
 from .common import BigQueryRequestSet
 from .common import BigQueryRequestSetDict
 from .common import BigQueryRequestSetOrDict
@@ -366,6 +373,15 @@
 from .common import GcsSource
 from .common import GcsSourceDict
 from .common import GcsSourceOrDict
+from .common import GeminiExample
+from .common import GeminiExampleDict
+from .common import GeminiExampleOrDict
+from .common import GeminiRequestReadConfig
+from .common import GeminiRequestReadConfigDict
+from .common import GeminiRequestReadConfigOrDict
+from .common import GeminiTemplateConfig
+from .common import GeminiTemplateConfigDict
+from .common import GeminiTemplateConfigOrDict
 from .common import GenerateAgentEngineMemoriesConfig
 from .common import GenerateAgentEngineMemoriesConfigDict
 from .common import GenerateAgentEngineMemoriesConfigOrDict
@@ -1587,6 +1603,21 @@
     "ListAgentEngineSessionEventsResponse",
     "ListAgentEngineSessionEventsResponseDict",
     "ListAgentEngineSessionEventsResponseOrDict",
+    "AssembleDatasetConfig",
+    "AssembleDatasetConfigDict",
+    "AssembleDatasetConfigOrDict",
+    "GeminiExample",
+    "GeminiExampleDict",
+    "GeminiExampleOrDict",
+    "GeminiTemplateConfig",
+    "GeminiTemplateConfigDict",
+    "GeminiTemplateConfigOrDict",
+    "GeminiRequestReadConfig",
+    "GeminiRequestReadConfigDict",
+    "GeminiRequestReadConfigOrDict",
+    "MultimodalDatasetOperation",
+    "MultimodalDatasetOperationDict",
+    "MultimodalDatasetOperationOrDict",
     "CreateMultimodalDatasetConfig",
     "CreateMultimodalDatasetConfigDict",
     "CreateMultimodalDatasetConfigOrDict",
@@ -1602,9 +1633,6 @@
     "MultimodalDataset",
     "MultimodalDatasetDict",
     "MultimodalDatasetOrDict",
-    "MultimodalDatasetOperation",
-    "MultimodalDatasetOperationDict",
-    "MultimodalDatasetOperationOrDict",
     "GetMultimodalDatasetOperationConfig",
     "GetMultimodalDatasetOperationConfigDict",
     "GetMultimodalDatasetOperationConfigOrDict",
@@ -1773,6 +1801,9 @@
     "AgentEngineConfig",
     "AgentEngineConfigDict",
     "AgentEngineConfigOrDict",
+    "AssembleDataset",
+    "AssembleDatasetDict",
+    "AssembleDatasetOrDict",
     "Prompt",
     "PromptDict",
     "PromptOrDict",
@@ -1876,6 +1907,7 @@
     "_UpdateAgentEngineSessionRequestParameters",
     "_AppendAgentEngineSessionEventRequestParameters",
     "_ListAgentEngineSessionEventsRequestParameters",
+    "_AssembleDatasetParameters",
     "_CreateMultimodalDatasetParameters",
     "_DeleteMultimodalDatasetRequestParameters",
     "_GetMultimodalDatasetParameters",
diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py
index ea471a172e..1e9ec94c93 100644
--- a/vertexai/_genai/types/common.py
+++ b/vertexai/_genai/types/common.py
@@ -47,7 +47,7 @@
 from . import evals as evals_types
 
 
-def _camel_to_snake(camel_case_string: str) -> str:
+def camel_to_snake(camel_case_string: str) -> str:
     snake_case_string = re.sub(r"(?
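The body of `camel_to_snake` is cut off in this diff. For reference, a minimal conversion with the same intent; the regex below is an assumption, not necessarily the SDK's exact pattern:

```python
import re


def camel_to_snake_sketch(camel_case_string: str) -> str:
    # Insert "_" before each capital that follows a lowercase letter or
    # digit, then lowercase: "bigqueryDestination" -> "bigquery_destination".
    return re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", camel_case_string).lower()
```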