Skip to content

Commit

Permalink
feat: Added aiplatform.Model.update method (#952)
Browse files Browse the repository at this point in the history
* Initial commit for updating models

* Added update functionality

* Added test

* Fixed validation

* Fixed docstrings and linting

* Fixed whitespace

* Mutate copy of proto instead of the original proto

* Added return type

* Added model.update integration test

* Update google/cloud/aiplatform/models.py

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>

* Ran linter

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
  • Loading branch information
ivanmkc and sasha-gitg committed Jan 24, 2022
1 parent 02a92f6 commit 44e208a
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 3 deletions.
70 changes: 68 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@
env_var as gca_env_var_compat,
)

from google.protobuf import json_format

from google.protobuf import field_mask_pb2, json_format

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -1502,6 +1501,73 @@ def __init__(
)
self._gca_resource = self._get_gca_resource(resource_name=model_name)

def update(
self,
display_name: Optional[str] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
) -> "Model":
"""Updates a model.
Example usage:
my_model = my_model.update(
display_name='my-model',
description='my description',
labels={'key': 'value'},
)
Args:
display_name (str):
The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
description (str):
The description of the model.
labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize your Models.
Label keys and values can be no longer than 64
characters (Unicode codepoints), can only
contain lowercase letters, numeric characters,
underscores and dashes. International characters
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
Returns:
model: Updated model resource.
Raises:
ValueError: If `labels` is not the correct format.
"""

current_model_proto = self.gca_resource
copied_model_proto = current_model_proto.__class__(current_model_proto)

update_mask: List[str] = []

if display_name:
utils.validate_display_name(display_name)

copied_model_proto.display_name = display_name
update_mask.append("display_name")

if description:
copied_model_proto.description = description
update_mask.append("description")

if labels:
utils.validate_labels(labels)

copied_model_proto.labels = labels
update_mask.append("labels")

update_mask = field_mask_pb2.FieldMask(paths=update_mask)

self.api_client.update_model(model=copied_model_proto, update_mask=update_mask)

self._sync_gca_resource()

return self

# TODO(b/170979552) Add support for predict schemata
# TODO(b/170979926) Add support for metadata and metadata schema
@classmethod
Expand Down
11 changes: 10 additions & 1 deletion tests/system/aiplatform/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TestModel(e2e_base.TestEndToEnd):
_temp_prefix = f"{_TEST_PROJECT}-vertex-staging-{_TEST_LOCATION}"

def test_upload_and_deploy_xgboost_model(self, shared_state):
"""Upload XGBoost model from local file and deploy it for prediction."""
"""Upload XGBoost model from local file and deploy it for prediction. Additionally, update model name, description and labels"""

aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

Expand Down Expand Up @@ -65,3 +65,12 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
shared_state["resources"].append(endpoint)
predict_response = endpoint.predict(instances=[[0, 0, 0]])
assert len(predict_response.predictions) == 1

model = model.update(
display_name="new_name",
description="new_description",
labels={"my_label": "updated"},
)
assert model.display_name == "new_name"
assert model.display_name == "new_description"
assert model.labels == {"my_label": "updated"}
49 changes: 49 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
encryption_spec as gca_encryption_spec,
)

from google.protobuf import field_mask_pb2

from test_endpoints import create_endpoint_mock # noqa: F401

Expand Down Expand Up @@ -177,6 +178,27 @@
_TEST_CONTAINER_REGISTRY_DESTINATION


@pytest.fixture
def mock_model():
model = mock.MagicMock(models.Model)
model.name = _TEST_ID
model._latest_future = None
model._exception = None
model._gca_resource = gca_model.Model(
display_name=_TEST_MODEL_NAME,
description=_TEST_DESCRIPTION,
labels=_TEST_LABEL,
)
yield model


@pytest.fixture
def update_model_mock(mock_model):
with patch.object(model_service_client.ModelServiceClient, "update_model") as mock:
mock.return_value = mock_model
yield mock


@pytest.fixture
def get_endpoint_mock():
with mock.patch.object(
Expand All @@ -199,6 +221,7 @@ def get_model_mock():
get_model_mock.return_value = gca_model.Model(
display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME,
)

yield get_model_mock


Expand Down Expand Up @@ -1660,3 +1683,29 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model(
]
staged_model_file_name = staged_model_file_path.split("/")[-1]
assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"]

@pytest.mark.usefixtures("get_model_mock")
def test_update(self, update_model_mock, get_model_mock):

test_model = models.Model(_TEST_ID)

test_model.update(
display_name=_TEST_MODEL_NAME,
description=_TEST_DESCRIPTION,
labels=_TEST_LABEL,
)

current_model_proto = gca_model.Model(
display_name=_TEST_MODEL_NAME,
description=_TEST_DESCRIPTION,
labels=_TEST_LABEL,
name=_TEST_MODEL_RESOURCE_NAME,
)

update_mask = field_mask_pb2.FieldMask(
paths=["display_name", "description", "labels"]
)

update_model_mock.assert_called_once_with(
model=current_model_proto, update_mask=update_mask
)

0 comments on commit 44e208a

Please sign in to comment.