Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 90 additions & 66 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,8 @@ def submit(
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
# Fast path: no need to locally re-list all arguments, just forward *all* args/kwargs
# Minor micro-optimization: avoid intermediary local variables; call directly
return cls._submit_impl(
job_display_name=job_display_name,
model_name=model_name,
Expand Down Expand Up @@ -905,7 +907,6 @@ def submit(
model_monitoring_alert_config=model_monitoring_alert_config,
analysis_instance_schema_uri=analysis_instance_schema_uri,
service_account=service_account,
# Main distinction of `create` vs `submit`:
wait_for_completion=False,
sync=True,
)
Expand Down Expand Up @@ -1142,28 +1143,87 @@ def _submit_impl(
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
batch_prediction_job_v1beta1 as gca_bp_job_compat,
io_v1beta1 as gca_io_compat,
explanation_v1beta1 as gca_explanation_v1beta1,
machine_resources_v1beta1 as gca_machine_resources_compat,
manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
io as gca_io_compat,
explanation as gca_explanation_v1beta1,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
# Caching module imports for improved performance
IMPORT_TYPE_CACHE = {}

def _get_imports(use_v1beta1: bool):
# Cache per value for module import cost avoidance.
if use_v1beta1 in IMPORT_TYPE_CACHE:
return IMPORT_TYPE_CACHE[use_v1beta1]
if use_v1beta1:
from google.cloud.aiplatform.compat.types import (
batch_prediction_job_v1beta1 as gca_bp_job_compat,
)
from google.cloud.aiplatform.compat.types import (
explanation_v1beta1 as gca_explanation_v1beta1,
)
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
)
from google.cloud.aiplatform.compat.types import (
machine_resources_v1beta1 as gca_machine_resources_compat,
)
from google.cloud.aiplatform.compat.types import (
manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
)
from google.cloud.aiplatform.compat.types import (
explanation as gca_explanation_v1beta1,
)
from google.cloud.aiplatform.compat.types import io as gca_io_compat
from google.cloud.aiplatform.compat.types import (
machine_resources as gca_machine_resources_compat,
)
from google.cloud.aiplatform.compat.types import (
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
)
IMPORT_TYPE_CACHE[use_v1beta1] = (
gca_bp_job_compat,
gca_io_compat,
gca_explanation_v1beta1,
gca_machine_resources_compat,
gca_manual_batch_tuning_parameters_compat,
)
return IMPORT_TYPE_CACHE[use_v1beta1]

use_v1beta1 = model_monitoring_objective_config is not None
(
gca_bp_job_compat,
gca_io_compat,
gca_explanation_v1beta1,
gca_machine_resources_compat,
gca_manual_batch_tuning_parameters_compat,
) = _get_imports(use_v1beta1)

if not job_display_name:
job_display_name = cls._generate_display_name()

utils.validate_display_name(job_display_name)
# Move rarely used/slow validation calls after cheap ones (fail fast)
if bool(gcs_source) == bool(bigquery_source):
raise ValueError(
"Please provide either a gcs_source or bigquery_source, "
"but not both."
)
if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix):
raise ValueError(
"Please provide either a gcs_destination_prefix or "
"bigquery_destination_prefix, but not both."
)
if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS:
raise ValueError(
f"{predictions_format} is not an accepted instances format "
f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}"
)
if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS:
raise ValueError(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

utils.validate_display_name(job_display_name)
if labels:
utils.validate_labels(labels)

Expand All @@ -1185,62 +1245,31 @@ def _submit_impl(
):
raise

# Raise error if both or neither source URIs are provided
if bool(gcs_source) == bool(bigquery_source):
raise ValueError(
"Please provide either a gcs_source or bigquery_source, "
"but not both."
)

# Raise error if both or neither destination prefixes are provided
if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix):
raise ValueError(
"Please provide either a gcs_destination_prefix or "
"bigquery_destination_prefix, but not both."
)

# Raise error if unsupported instance format is provided
if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS:
raise ValueError(
f"{predictions_format} is not an accepted instances format "
f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}"
)

# Raise error if unsupported prediction format is provided
if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS:
raise ValueError(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

# Required Fields
gapic_batch_prediction_job.display_name = job_display_name

input_config = gca_bp_job_compat.BatchPredictionJob.InputConfig()
output_config = gca_bp_job_compat.BatchPredictionJob.OutputConfig()

# Use local variables for branches to avoid repeated attribute lookups
if bigquery_source:
input_config.instances_format = "bigquery"
input_config.bigquery_source = gca_io_compat.BigQuerySource()
input_config.bigquery_source.input_uri = bigquery_source
bqsrc = gca_io_compat.BigQuerySource()
bqsrc.input_uri = bigquery_source
input_config.bigquery_source = bqsrc
else:
input_config.instances_format = instances_format
input_config.gcs_source = gca_io_compat.GcsSource(
uris=gcs_source if isinstance(gcs_source, list) else [gcs_source]
)
uris = gcs_source if isinstance(gcs_source, list) else [gcs_source]
input_config.gcs_source = gca_io_compat.GcsSource(uris=uris)

if bigquery_destination_prefix:
output_config.predictions_format = "bigquery"
output_config.bigquery_destination = gca_io_compat.BigQueryDestination()

bqdest = gca_io_compat.BigQueryDestination()
bq_dest_prefix = bigquery_destination_prefix

if not bq_dest_prefix.startswith("bq://"):
bq_dest_prefix = f"bq://{bq_dest_prefix}"

output_config.bigquery_destination.output_uri = bq_dest_prefix
bqdest.output_uri = bq_dest_prefix
output_config.bigquery_destination = bqdest
else:
output_config.predictions_format = predictions_format
output_config.gcs_destination = gca_io_compat.GcsDestination(
Expand All @@ -1262,14 +1291,12 @@ def _submit_impl(

# Custom Compute
if machine_type:

machine_spec = gca_machine_resources_compat.MachineSpec()
machine_spec.machine_type = machine_type
machine_spec.accelerator_type = accelerator_type
machine_spec.accelerator_count = accelerator_count

dedicated_resources = gca_machine_resources_compat.BatchDedicatedResources()

dedicated_resources.machine_spec = machine_spec
dedicated_resources.starting_replica_count = starting_replica_count
dedicated_resources.max_replica_count = max_replica_count
Expand All @@ -1285,7 +1312,6 @@ def _submit_impl(
manual_batch_tuning_parameters
)

# User Labels
gapic_batch_prediction_job.labels = labels

# Explanations
Expand All @@ -1298,7 +1324,6 @@ def _submit_impl(
)
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
if model_monitoring_objective_config:

explanation_spec = gca_explanation_v1beta1.ExplanationSpec.deserialize(
gca_explanation_compat.ExplanationSpec.serialize(explanation_spec)
)
Expand All @@ -1308,18 +1333,18 @@ def _submit_impl(
if service_account:
gapic_batch_prediction_job.service_account = service_account

# Most expensive operation: Only do once everything is validated and constructed
empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
credentials=credentials,
)

if model_monitoring_objective_config:
empty_batch_prediction_job.api_client = (
empty_batch_prediction_job.api_client.select_version("v1beta1")
)

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
model_monitoring_objective_config._config_for_bp = True
if model_monitoring_alert_config is not None:
model_monitoring_alert_config._config_for_bp = True
Expand All @@ -1338,7 +1363,6 @@ def _submit_impl(
)
gapic_batch_prediction_job.model_monitoring_config = gapic_mm_config

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
return cls._submit_and_optionally_wait_with_sync_support(
empty_batch_prediction_job=empty_batch_prediction_job,
model_or_model_name=model_name,
Expand Down