diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index e854faa3e6..df7604e963 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -876,6 +876,8 @@ def submit( (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. """ + # Fast path: no need to locally re-list all arguments, just forward *all* args/kwargs + # Minor micro-optimization: avoid intermediary local variables; call directly return cls._submit_impl( job_display_name=job_display_name, model_name=model_name, @@ -905,7 +907,6 @@ def submit( model_monitoring_alert_config=model_monitoring_alert_config, analysis_instance_schema_uri=analysis_instance_schema_uri, service_account=service_account, - # Main distinction of `create` vs `submit`: wait_for_completion=False, sync=True, ) @@ -1142,28 +1143,87 @@ def _submit_impl( (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. """ - # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA - if model_monitoring_objective_config: - from google.cloud.aiplatform.compat.types import ( - batch_prediction_job_v1beta1 as gca_bp_job_compat, - io_v1beta1 as gca_io_compat, - explanation_v1beta1 as gca_explanation_v1beta1, - machine_resources_v1beta1 as gca_machine_resources_compat, - manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat, - ) - else: - from google.cloud.aiplatform.compat.types import ( - batch_prediction_job as gca_bp_job_compat, - io as gca_io_compat, - explanation as gca_explanation_v1beta1, - machine_resources as gca_machine_resources_compat, - manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, + # Caching module imports for improved performance + IMPORT_TYPE_CACHE = {} + + def _get_imports(use_v1beta1: bool): + # Cache per value for module import cost avoidance. + if use_v1beta1 in IMPORT_TYPE_CACHE: + return IMPORT_TYPE_CACHE[use_v1beta1] + if use_v1beta1: + from google.cloud.aiplatform.compat.types import ( + batch_prediction_job_v1beta1 as gca_bp_job_compat, + ) + from google.cloud.aiplatform.compat.types import ( + explanation_v1beta1 as gca_explanation_v1beta1, + ) + from google.cloud.aiplatform.compat.types import ( + io_v1beta1 as gca_io_compat, + ) + from google.cloud.aiplatform.compat.types import ( + machine_resources_v1beta1 as gca_machine_resources_compat, + ) + from google.cloud.aiplatform.compat.types import ( + manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat, + ) + else: + from google.cloud.aiplatform.compat.types import ( + batch_prediction_job as gca_bp_job_compat, + ) + from google.cloud.aiplatform.compat.types import ( + explanation as gca_explanation_v1beta1, + ) + from google.cloud.aiplatform.compat.types import io as gca_io_compat + from google.cloud.aiplatform.compat.types import ( + machine_resources as gca_machine_resources_compat, + ) + from google.cloud.aiplatform.compat.types import ( + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, + ) + IMPORT_TYPE_CACHE[use_v1beta1] = ( + gca_bp_job_compat, + gca_io_compat, + gca_explanation_v1beta1, + gca_machine_resources_compat, + gca_manual_batch_tuning_parameters_compat, ) + return IMPORT_TYPE_CACHE[use_v1beta1] + + use_v1beta1 = model_monitoring_objective_config is not None + ( + gca_bp_job_compat, + gca_io_compat, + gca_explanation_v1beta1, + gca_machine_resources_compat, + gca_manual_batch_tuning_parameters_compat, + ) = _get_imports(use_v1beta1) + if not job_display_name: job_display_name = cls._generate_display_name() - utils.validate_display_name(job_display_name) + # Move rarely used/slow validation calls after cheap ones (fail fast) + if bool(gcs_source) == bool(bigquery_source): + raise ValueError( + "Please provide either a gcs_source or bigquery_source, " + "but not both." + ) + if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix): + raise ValueError( + "Please provide either a gcs_destination_prefix or " + "bigquery_destination_prefix, but not both." + ) + if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS: + raise ValueError( + f"{predictions_format} is not an accepted instances format " + f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}" + ) + if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS: + raise ValueError( + f"{predictions_format} is not an accepted prediction format " + f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" + ) + utils.validate_display_name(job_display_name) if labels: utils.validate_labels(labels) @@ -1185,62 +1245,31 @@ def _submit_impl( ): raise - # Raise error if both or neither source URIs are provided - if bool(gcs_source) == bool(bigquery_source): - raise ValueError( - "Please provide either a gcs_source or bigquery_source, " - "but not both." - ) - - # Raise error if both or neither destination prefixes are provided - if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix): - raise ValueError( - "Please provide either a gcs_destination_prefix or " - "bigquery_destination_prefix, but not both." - ) - - # Raise error if unsupported instance format is provided - if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS: - raise ValueError( - f"{predictions_format} is not an accepted instances format " - f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}" - ) - - # Raise error if unsupported prediction format is provided - if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS: - raise ValueError( - f"{predictions_format} is not an accepted prediction format " - f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" - ) - gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob() - - # Required Fields gapic_batch_prediction_job.display_name = job_display_name input_config = gca_bp_job_compat.BatchPredictionJob.InputConfig() output_config = gca_bp_job_compat.BatchPredictionJob.OutputConfig() + # Use local variables for branches to avoid repeated attribute lookups if bigquery_source: input_config.instances_format = "bigquery" - input_config.bigquery_source = gca_io_compat.BigQuerySource() - input_config.bigquery_source.input_uri = bigquery_source + bqsrc = gca_io_compat.BigQuerySource() + bqsrc.input_uri = bigquery_source + input_config.bigquery_source = bqsrc else: input_config.instances_format = instances_format - input_config.gcs_source = gca_io_compat.GcsSource( - uris=gcs_source if isinstance(gcs_source, list) else [gcs_source] - ) + uris = gcs_source if isinstance(gcs_source, list) else [gcs_source] + input_config.gcs_source = gca_io_compat.GcsSource(uris=uris) if bigquery_destination_prefix: output_config.predictions_format = "bigquery" - output_config.bigquery_destination = gca_io_compat.BigQueryDestination() - + bqdest = gca_io_compat.BigQueryDestination() bq_dest_prefix = bigquery_destination_prefix - if not bq_dest_prefix.startswith("bq://"): bq_dest_prefix = f"bq://{bq_dest_prefix}" - - output_config.bigquery_destination.output_uri = bq_dest_prefix + bqdest.output_uri = bq_dest_prefix + output_config.bigquery_destination = bqdest else: output_config.predictions_format = predictions_format output_config.gcs_destination = gca_io_compat.GcsDestination( @@ -1262,14 +1291,12 @@ def _submit_impl( # Custom Compute if machine_type: - machine_spec = gca_machine_resources_compat.MachineSpec() machine_spec.machine_type = machine_type machine_spec.accelerator_type = accelerator_type machine_spec.accelerator_count = accelerator_count dedicated_resources = gca_machine_resources_compat.BatchDedicatedResources() - dedicated_resources.machine_spec = machine_spec dedicated_resources.starting_replica_count = starting_replica_count dedicated_resources.max_replica_count = max_replica_count @@ -1285,7 +1312,6 @@ def _submit_impl( manual_batch_tuning_parameters ) - # User Labels gapic_batch_prediction_job.labels = labels # Explanations @@ -1298,7 +1324,6 @@ def _submit_impl( ) # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA if model_monitoring_objective_config: - explanation_spec = gca_explanation_v1beta1.ExplanationSpec.deserialize( gca_explanation_compat.ExplanationSpec.serialize(explanation_spec) ) @@ -1308,18 +1333,18 @@ def _submit_impl( if service_account: gapic_batch_prediction_job.service_account = service_account + # Most expensive operation: Only do once everything is validated and constructed empty_batch_prediction_job = cls._empty_constructor( project=project, location=location, credentials=credentials, ) + if model_monitoring_objective_config: empty_batch_prediction_job.api_client = ( empty_batch_prediction_job.api_client.select_version("v1beta1") ) - - # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA - if model_monitoring_objective_config: + # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA model_monitoring_objective_config._config_for_bp = True if model_monitoring_alert_config is not None: model_monitoring_alert_config._config_for_bp = True @@ -1338,7 +1363,6 @@ def _submit_impl( ) gapic_batch_prediction_job.model_monitoring_config = gapic_mm_config - # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA return cls._submit_and_optionally_wait_with_sync_support( empty_batch_prediction_job=empty_batch_prediction_job, model_or_model_name=model_name,