From a41c8833b73362b173b74d7c9a81d394cb01d31f Mon Sep 17 00:00:00 2001 From: Christian Leopoldseder Date: Wed, 15 Apr 2026 08:13:22 -0700 Subject: [PATCH] feat: GenAI SDK client(multimodal) - Accept an explicit bigquery_uri parameter in create_from_bigquery PiperOrigin-RevId: 900174983 --- .../test_create_multimodal_datasets.py | 72 +++++++++++++++++++ vertexai/_genai/_datasets_utils.py | 29 +++++--- vertexai/_genai/datasets.py | 54 ++++++++++---- 3 files changed, 133 insertions(+), 22 deletions(-) diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py index 9fa1711ac9..46f4228a5a 100644 --- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py @@ -115,6 +115,40 @@ def test_create_dataset_from_bigquery(client): ) +@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") +def test_create_dataset_from_bigquery_with_uri(client): + dataset = client.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +def test_create_dataset_from_bigquery_preserves_other_metadata(client): + dataset = client.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + multimodal_dataset={ + "display_name": "test-from-bigquery-uri", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery-uri" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + @pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") def test_create_dataset_from_bigquery_no_display_name(client): dataset = client.datasets.create_from_bigquery( @@ -254,6 +288,44 @@ async def test_create_dataset_from_bigquery_async(client): ) +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") +async def test_create_dataset_from_bigquery_with_uri_async(client): + dataset = await client.aio.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.asyncio +async def test_create_dataset_from_bigquery_preserves_other_metadata_async( + client, +): + dataset = await client.aio.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + multimodal_dataset={ + "display_name": "test-from-bigquery-uri", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery-uri" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") async def 
test_create_dataset_from_bigquery_no_display_name_async(client): diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index e063e6802a..bf2ffd7cf2 100644 --- a/vertexai/_genai/_datasets_utils.py +++ b/vertexai/_genai/_datasets_utils.py @@ -21,7 +21,7 @@ import google.auth.credentials from vertexai._genai.types import common -from pydantic import BaseModel +from google.genai import _common METADATA_SCHEMA_URI = ( @@ -31,18 +31,27 @@ _DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets" _DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset" -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T", bound=_common.BaseModel) -def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T: +def create_from_response( + model_type: Type[T], + response: dict[str, Any], + config: Any | None = None, +) -> T: """Creates a model from a response.""" - model_field_names = model_type.model_fields.keys() - filtered_response = {} - for key, value in response.items(): - snake_key = common.camel_to_snake(key) - if snake_key in model_field_names: - filtered_response[snake_key] = value - return model_type(**filtered_response) + kwargs = ( + { + "config": { + "response_schema": getattr(config, "response_schema", None), + "response_json_schema": getattr(config, "response_json_schema", None), + "include_all_fields": getattr(config, "include_all_fields", None), + } + } + if config + else {} + ) + return model_type._from_response(response=response, kwargs=kwargs) def validate_multimodal_dataset_bigquery_uri( diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 046803edf0..c1d82997ff 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -924,14 +924,18 @@ def _wait_for_operation( def create_from_bigquery( self, *, - multimodal_dataset: types.MultimodalDatasetOrDict, + bigquery_uri: Optional[str] = None, + multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None, config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, ) -> types.MultimodalDataset: """Creates a multimodal dataset from a BigQuery table. Args: + bigquery_uri: + Optional. The BigQuery URI of the table to create the dataset from. + e.g. "bq://project.dataset.table". multimodal_dataset: - Required. A representation of a multimodal dataset. + Optional. A representation of a multimodal dataset. config: Optional. A configuration for creating the multimodal dataset. If not provided, the default configuration will be used. @@ -939,8 +943,15 @@ def create_from_bigquery( Returns: A types.MultimodalDataset object representing a multimodal dataset. 
""" - if isinstance(multimodal_dataset, dict): + if multimodal_dataset is None: + multimodal_dataset = types.MultimodalDataset() + elif isinstance(multimodal_dataset, dict): multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + + if bigquery_uri: + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(bigquery_uri) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) if isinstance(config, dict): @@ -963,7 +974,9 @@ def create_from_bigquery( operation=multimodal_dataset_operation, timeout_seconds=config.timeout, ) - return _datasets_utils.create_from_response(types.MultimodalDataset, response) + return _datasets_utils.create_from_response( + types.MultimodalDataset, response, config + ) def create_from_pandas( self, @@ -1302,6 +1315,7 @@ def assess_tuning_resources( return _datasets_utils.create_from_response( types.TuningResourceUsageAssessmentResult, response["tuningResourceUsageAssessmentResult"], + config, ) def assess_tuning_validity( @@ -1368,6 +1382,7 @@ def assess_tuning_validity( return _datasets_utils.create_from_response( types.TuningValidationAssessmentResult, response["tuningValidationAssessmentResult"], + config, ) def assess_batch_prediction_resources( @@ -1430,7 +1445,7 @@ def assess_batch_prediction_resources( ) result = response["batchPredictionResourceUsageAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionResourceUsageAssessmentResult, result + types.BatchPredictionResourceUsageAssessmentResult, result, config ) def assess_batch_prediction_validity( @@ -1493,7 +1508,7 @@ def assess_batch_prediction_validity( ) result = response["batchPredictionValidationAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionValidationAssessmentResult, result + types.BatchPredictionValidationAssessmentResult, result, config ) @@ -2192,14 +2207,18 @@ async def _wait_for_operation( async def create_from_bigquery( self, *, - multimodal_dataset: types.MultimodalDatasetOrDict, + bigquery_uri: Optional[str] = None, + multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None, config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, ) -> types.MultimodalDataset: """Creates a multimodal dataset from a BigQuery table. Args: + bigquery_uri: + Optional. The BigQuery URI of the table to create the dataset from. + e.g. "bq://project.dataset.table". multimodal_dataset: - Required. A representation of a multimodal dataset. + Optional. A representation of a multimodal dataset. config: Optional. A configuration for creating the multimodal dataset. If not provided, the default configuration will be used. @@ -2207,8 +2226,15 @@ async def create_from_bigquery( Returns: A types.MultimodalDataset object representing a multimodal dataset. 
""" - if isinstance(multimodal_dataset, dict): + if multimodal_dataset is None: + multimodal_dataset = types.MultimodalDataset() + elif isinstance(multimodal_dataset, dict): multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + + if bigquery_uri: + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(bigquery_uri) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) if isinstance(config, dict): @@ -2231,7 +2257,9 @@ async def create_from_bigquery( operation=multimodal_dataset_operation, timeout_seconds=config.timeout, ) - return _datasets_utils.create_from_response(types.MultimodalDataset, response) + return _datasets_utils.create_from_response( + types.MultimodalDataset, response, config + ) async def create_from_pandas( self, @@ -2568,6 +2596,7 @@ async def assess_tuning_resources( return _datasets_utils.create_from_response( types.TuningResourceUsageAssessmentResult, response["tuningResourceUsageAssessmentResult"], + config, ) async def assess_tuning_validity( @@ -2634,6 +2663,7 @@ async def assess_tuning_validity( return _datasets_utils.create_from_response( types.TuningValidationAssessmentResult, response["tuningValidationAssessmentResult"], + config, ) async def assess_batch_prediction_resources( @@ -2696,7 +2726,7 @@ async def assess_batch_prediction_resources( ) result = response["batchPredictionResourceUsageAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionResourceUsageAssessmentResult, result + types.BatchPredictionResourceUsageAssessmentResult, result, config ) async def assess_batch_prediction_validity( @@ -2759,5 +2789,5 @@ async def assess_batch_prediction_validity( ) result = response["batchPredictionValidationAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionValidationAssessmentResult, result + types.BatchPredictionValidationAssessmentResult, result, config )