Skip to content

Commit

Permalink
feat: Support publisher models in BatchPredictionJob.create
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 536581722
  • Loading branch information
Ark-kun authored and copybara-github committed May 31, 2023
1 parent 0463678 commit 13b11c6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
26 changes: 17 additions & 9 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform import model_monitoring
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.preview import _publisher_model
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils
Expand Down Expand Up @@ -624,15 +625,22 @@ def create(
utils.validate_labels(labels)

if isinstance(model_name, str):
model_name = utils.full_resource_name(
resource_name=model_name,
resource_noun="models",
parse_resource_name_method=aiplatform.Model._parse_resource_name,
format_resource_name_method=aiplatform.Model._format_resource_name,
project=project,
location=location,
resource_id_validator=super()._revisioned_resource_id_validator,
)
try:
model_name = utils.full_resource_name(
resource_name=model_name,
resource_noun="models",
parse_resource_name_method=aiplatform.Model._parse_resource_name,
format_resource_name_method=aiplatform.Model._format_resource_name,
project=project,
location=location,
resource_id_validator=super()._revisioned_resource_id_validator,
)
except ValueError:
# Do not raise exception if model_name is a valid PublisherModel name
if not _publisher_model._PublisherModel._parse_resource_name(
model_name
):
raise

# Raise error if both or neither source URIs are provided
if bool(gcs_source) == bool(bigquery_source):
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@
_TEST_MODEL_VERSION_ID = "2"
_TEST_VERSIONED_MODEL_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ALT_ID}@{_TEST_MODEL_VERSION_ID}"

_TEST_PUBLISHER_MODEL_NAME = (
f"publishers/google/models/text-model-name@{_TEST_MODEL_VERSION_ID}"
)

_TEST_BATCH_PREDICTION_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_ID}"
_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job"

Expand Down Expand Up @@ -1267,6 +1271,28 @@ def test_batch_predict_job_with_versioned_model(
== _TEST_VERSIONED_MODEL_NAME
)

@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_job_with_publisher_model(
self, create_batch_prediction_job_mock
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

# Make SDK batch_predict method call
_ = jobs.BatchPredictionJob.create(
model_name=_TEST_PUBLISHER_MODEL_NAME,
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=True,
service_account=_TEST_SERVICE_ACCOUNT,
)
assert (
create_batch_prediction_job_mock.call_args_list[0][1][
"batch_prediction_job"
].model
== _TEST_PUBLISHER_MODEL_NAME
)


@pytest.fixture
def get_mdm_job_mock():
Expand Down

0 comments on commit 13b11c6

Please sign in to comment.