Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,40 @@ def test_create_dataset_from_bigquery(client):
)


@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
def test_create_dataset_from_bigquery_with_uri(client):
    # Creating a dataset from a bare BigQuery URI should record that same
    # URI in the dataset's input-config metadata.
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    dataset = client.datasets.create_from_bigquery(bigquery_uri=expected_uri)
    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri


def test_create_dataset_from_bigquery_preserves_other_metadata(client):
    # Passing an explicit multimodal_dataset alongside bigquery_uri must keep
    # the caller-supplied display name and metadata while still applying the
    # BigQuery source URI to the input config.
    uri = f"bq://{BIGQUERY_TABLE_NAME}"
    seed_dataset = {
        "display_name": "test-from-bigquery-uri",
        "metadata": {
            "gemini_request_read_config": {
                "assembled_request_column_name": "test_column"
            }
        },
    }
    dataset = client.datasets.create_from_bigquery(
        bigquery_uri=uri,
        multimodal_dataset=seed_dataset,
    )
    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigquery-uri"
    read_config = dataset.metadata.gemini_request_read_config
    assert read_config.assembled_request_column_name == "test_column"
    assert dataset.metadata.input_config.bigquery_source.uri == uri


@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
def test_create_dataset_from_bigquery_no_display_name(client):
dataset = client.datasets.create_from_bigquery(
Expand Down Expand Up @@ -254,6 +288,44 @@ async def test_create_dataset_from_bigquery_async(client):
)


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
async def test_create_dataset_from_bigquery_with_uri_async(client):
    # Async variant: a bare BigQuery URI should be recorded in the dataset's
    # input-config metadata.
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    dataset = await client.aio.datasets.create_from_bigquery(
        bigquery_uri=expected_uri
    )
    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri


@pytest.mark.asyncio
async def test_create_dataset_from_bigquery_preserves_other_metadata_async(
    client,
):
    # Async variant: caller-supplied display name and metadata must survive
    # while the BigQuery source URI is still applied to the input config.
    uri = f"bq://{BIGQUERY_TABLE_NAME}"
    seed_dataset = {
        "display_name": "test-from-bigquery-uri",
        "metadata": {
            "gemini_request_read_config": {
                "assembled_request_column_name": "test_column"
            }
        },
    }
    dataset = await client.aio.datasets.create_from_bigquery(
        bigquery_uri=uri,
        multimodal_dataset=seed_dataset,
    )
    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigquery-uri"
    read_config = dataset.metadata.gemini_request_read_config
    assert read_config.assembled_request_column_name == "test_column"
    assert dataset.metadata.input_config.bigquery_source.uri == uri


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
async def test_create_dataset_from_bigquery_no_display_name_async(client):
Expand Down
29 changes: 19 additions & 10 deletions vertexai/_genai/_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import google.auth.credentials
from vertexai._genai.types import common
from pydantic import BaseModel
from google.genai import _common


METADATA_SCHEMA_URI = (
Expand All @@ -31,18 +31,27 @@
_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"

T = TypeVar("T", bound=_common.BaseModel)


def create_from_response(
    model_type: Type[T],
    response: dict[str, Any],
    config: Any = None,
) -> T:
    """Builds a ``model_type`` instance from a raw API response dict.

    Args:
        model_type: The ``_common.BaseModel`` subclass to instantiate.
        response: The raw response payload from the service.
        config: Optional request config. When provided, its
            ``response_schema``, ``response_json_schema`` and
            ``include_all_fields`` attributes (missing attributes default to
            ``None``) are forwarded to ``_from_response`` so response
            handling matches the original request's configuration.

    Returns:
        An instance of ``model_type`` populated from ``response``.
    """
    # NOTE(review): assumes ``config`` is an object (not a dict) by the time
    # it reaches here; ``getattr`` on a dict would always yield the default.
    # Callers appear to convert dict configs first -- verify against callers.
    if config is not None:
        kwargs = {
            "config": {
                "response_schema": getattr(config, "response_schema", None),
                "response_json_schema": getattr(
                    config, "response_json_schema", None
                ),
                "include_all_fields": getattr(
                    config, "include_all_fields", None
                ),
            }
        }
    else:
        kwargs = {}
    return model_type._from_response(response=response, kwargs=kwargs)


def validate_multimodal_dataset_bigquery_uri(
Expand Down
54 changes: 42 additions & 12 deletions vertexai/_genai/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,23 +924,34 @@ def _wait_for_operation(
def create_from_bigquery(
self,
*,
multimodal_dataset: types.MultimodalDatasetOrDict,
bigquery_uri: Optional[str] = None,
multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
) -> types.MultimodalDataset:
"""Creates a multimodal dataset from a BigQuery table.

Args:
bigquery_uri:
Optional. The BigQuery URI of the table to create the dataset from.
e.g. "bq://project.dataset.table".
multimodal_dataset:
Required. A representation of a multimodal dataset.
Optional. A representation of a multimodal dataset.
config:
Optional. A configuration for creating the multimodal dataset. If not
provided, the default configuration will be used.

Returns:
A types.MultimodalDataset object representing a multimodal dataset.
"""
if isinstance(multimodal_dataset, dict):
if multimodal_dataset is None:
multimodal_dataset = types.MultimodalDataset()
elif isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)

if bigquery_uri:
multimodal_dataset = multimodal_dataset.model_copy(deep=True)
multimodal_dataset.set_bigquery_uri(bigquery_uri)

_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)

if isinstance(config, dict):
Expand All @@ -963,7 +974,9 @@ def create_from_bigquery(
operation=multimodal_dataset_operation,
timeout_seconds=config.timeout,
)
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
return _datasets_utils.create_from_response(
types.MultimodalDataset, response, config
)

def create_from_pandas(
self,
Expand Down Expand Up @@ -1302,6 +1315,7 @@ def assess_tuning_resources(
return _datasets_utils.create_from_response(
types.TuningResourceUsageAssessmentResult,
response["tuningResourceUsageAssessmentResult"],
config,
)

def assess_tuning_validity(
Expand Down Expand Up @@ -1368,6 +1382,7 @@ def assess_tuning_validity(
return _datasets_utils.create_from_response(
types.TuningValidationAssessmentResult,
response["tuningValidationAssessmentResult"],
config,
)

def assess_batch_prediction_resources(
Expand Down Expand Up @@ -1430,7 +1445,7 @@ def assess_batch_prediction_resources(
)
result = response["batchPredictionResourceUsageAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionResourceUsageAssessmentResult, result
types.BatchPredictionResourceUsageAssessmentResult, result, config
)

def assess_batch_prediction_validity(
Expand Down Expand Up @@ -1493,7 +1508,7 @@ def assess_batch_prediction_validity(
)
result = response["batchPredictionValidationAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionValidationAssessmentResult, result
types.BatchPredictionValidationAssessmentResult, result, config
)


Expand Down Expand Up @@ -2192,23 +2207,34 @@ async def _wait_for_operation(
async def create_from_bigquery(
self,
*,
multimodal_dataset: types.MultimodalDatasetOrDict,
bigquery_uri: Optional[str] = None,
multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
) -> types.MultimodalDataset:
"""Creates a multimodal dataset from a BigQuery table.

Args:
bigquery_uri:
Optional. The BigQuery URI of the table to create the dataset from.
e.g. "bq://project.dataset.table".
multimodal_dataset:
Required. A representation of a multimodal dataset.
Optional. A representation of a multimodal dataset.
config:
Optional. A configuration for creating the multimodal dataset. If not
provided, the default configuration will be used.

Returns:
A types.MultimodalDataset object representing a multimodal dataset.
"""
if isinstance(multimodal_dataset, dict):
if multimodal_dataset is None:
multimodal_dataset = types.MultimodalDataset()
elif isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)

if bigquery_uri:
multimodal_dataset = multimodal_dataset.model_copy(deep=True)
multimodal_dataset.set_bigquery_uri(bigquery_uri)

_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)

if isinstance(config, dict):
Expand All @@ -2231,7 +2257,9 @@ async def create_from_bigquery(
operation=multimodal_dataset_operation,
timeout_seconds=config.timeout,
)
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
return _datasets_utils.create_from_response(
types.MultimodalDataset, response, config
)

async def create_from_pandas(
self,
Expand Down Expand Up @@ -2568,6 +2596,7 @@ async def assess_tuning_resources(
return _datasets_utils.create_from_response(
types.TuningResourceUsageAssessmentResult,
response["tuningResourceUsageAssessmentResult"],
config,
)

async def assess_tuning_validity(
Expand Down Expand Up @@ -2634,6 +2663,7 @@ async def assess_tuning_validity(
return _datasets_utils.create_from_response(
types.TuningValidationAssessmentResult,
response["tuningValidationAssessmentResult"],
config,
)

async def assess_batch_prediction_resources(
Expand Down Expand Up @@ -2696,7 +2726,7 @@ async def assess_batch_prediction_resources(
)
result = response["batchPredictionResourceUsageAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionResourceUsageAssessmentResult, result
types.BatchPredictionResourceUsageAssessmentResult, result, config
)

async def assess_batch_prediction_validity(
Expand Down Expand Up @@ -2759,5 +2789,5 @@ async def assess_batch_prediction_validity(
)
result = response["batchPredictionValidationAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionValidationAssessmentResult, result
types.BatchPredictionValidationAssessmentResult, result, config
)
Loading