Skip to content

Commit

Permalink
feat: add update endpoint (#1162)
Browse files Browse the repository at this point in the history
* feat: add update endpoint

* add validate_traffic and validate_traffic_split

* remove validation, add system tests

* Test fixes

* Nox blacken change

Co-authored-by: Sam Goodman <goodmansam@google.com>
  • Loading branch information
morgandu and Sam Goodman committed May 26, 2022
1 parent b4a0bee commit 0ecfe1e
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 45 deletions.
114 changes: 109 additions & 5 deletions google/cloud/aiplatform/models.py
Expand Up @@ -51,6 +51,7 @@
from google.protobuf import field_mask_pb2, json_format

_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -485,7 +486,7 @@ def _allocate_traffic(
new_traffic_split[deployed_model] += 1
unallocated_traffic -= 1

new_traffic_split["0"] = traffic_percentage
new_traffic_split[_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY] = traffic_percentage

return new_traffic_split

Expand Down Expand Up @@ -611,7 +612,6 @@ def _validate_deploy_args(
raise ValueError("Traffic percentage cannot be negative.")

elif traffic_split:
# TODO(b/172678233) verify every referenced deployed model exists
if sum(traffic_split.values()) != 100:
raise ValueError(
"Sum of all traffic within traffic split needs to be 100."
Expand Down Expand Up @@ -1290,6 +1290,110 @@ def _instantiate_prediction_client(
prediction_client=True,
)

def update(
    self,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    traffic_split: Optional[Dict[str, int]] = None,
    request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
    update_request_timeout: Optional[float] = None,
) -> "Endpoint":
    """Updates an endpoint.

    Only the fields that are explicitly passed are sent to the service;
    all other fields are left untouched via the request's field mask.

    Example usage:
        my_endpoint = my_endpoint.update(
            display_name='my-updated-endpoint',
            description='my updated description',
            labels={'key': 'value'},
            traffic_split={
                '123456': 20,
                '234567': 80,
            },
        )

    Args:
        display_name (str):
            Optional. The display name of the Endpoint.
            The name can be up to 128 characters long and can be consist of any UTF-8
            characters.
        description (str):
            Optional. The description of the Endpoint.
        labels (Dict[str, str]):
            Optional. The labels with user-defined metadata to organize your Endpoints.
            Label keys and values can be no longer than 64 characters
            (Unicode codepoints), can only contain lowercase letters, numeric
            characters, underscores and dashes. International characters are allowed.
            See https://goo.gl/xmQnxf for more information and examples of labels.
        traffic_split (Dict[str, int]):
            Optional. A map from a DeployedModel's ID to the percentage of this Endpoint's
            traffic that should be forwarded to that DeployedModel.
            If a DeployedModel's ID is not listed in this map, then it receives no traffic.
            The traffic percentage values must add up to 100, or map must be empty if
            the Endpoint is to not accept any traffic at a moment.
        request_metadata (Sequence[Tuple[str, str]]):
            Optional. Strings which should be sent along with the request as metadata.
        update_request_timeout (float):
            Optional. The timeout for the update request in seconds.

    Returns:
        Endpoint - Updated endpoint resource.

    Raises:
        ValueError: If `labels` is not the correct format.
    """
    self.wait()

    # Mutate a copy of the resource proto so the locally cached resource is
    # untouched if the update RPC fails.
    current_endpoint_proto = self.gca_resource
    copied_endpoint_proto = current_endpoint_proto.__class__(current_endpoint_proto)

    # Collect the field-mask paths separately from the FieldMask proto so
    # each name keeps a single, consistent type.
    field_paths: List[str] = []

    if display_name:
        utils.validate_display_name(display_name)
        copied_endpoint_proto.display_name = display_name
        field_paths.append("display_name")

    if description:
        copied_endpoint_proto.description = description
        field_paths.append("description")

    if labels:
        utils.validate_labels(labels)
        copied_endpoint_proto.labels = labels
        field_paths.append("labels")

    if traffic_split:
        field_paths.append("traffic_split")
        copied_endpoint_proto.traffic_split = traffic_split

    update_mask = field_mask_pb2.FieldMask(paths=field_paths)

    _LOGGER.log_action_start_against_resource(
        "Updating",
        "endpoint",
        self,
    )

    update_endpoint_lro = self.api_client.update_endpoint(
        endpoint=copied_endpoint_proto,
        update_mask=update_mask,
        metadata=request_metadata,
        timeout=update_request_timeout,
    )

    _LOGGER.log_action_started_against_resource_with_lro(
        "Update", "endpoint", self.__class__, update_endpoint_lro
    )

    # Block until the long-running operation completes before reporting success.
    update_endpoint_lro.result()

    _LOGGER.log_action_completed_against_resource("endpoint", "updated", self)

    return self

def predict(
self,
instances: List,
Expand Down Expand Up @@ -1445,15 +1549,15 @@ def list(
credentials=credentials,
)

def list_models(self) -> List[gca_endpoint_compat.DeployedModel]:
    """Returns a list of the models deployed to this Endpoint.

    Returns:
        deployed_models (List[aiplatform.gapic.DeployedModel]):
            A list of the models deployed in this Endpoint.
    """
    # Refresh the cached resource first so the result reflects server state.
    self._sync_gca_resource()
    # Copy the repeated field into a plain list so callers cannot mutate
    # the cached proto in place.
    return list(self._gca_resource.deployed_models)

def undeploy_all(self, sync: bool = True) -> "Endpoint":
"""Undeploys every model deployed to this Endpoint.
Expand Down
9 changes: 9 additions & 0 deletions tests/system/aiplatform/test_model_upload.py
Expand Up @@ -76,3 +76,12 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
assert model.display_name == "new_name"
assert model.description == "new_description"
assert model.labels == {"my_label": "updated"}

# list_models is a method: it must be called. `len(endpoint.list_models)`
# would raise TypeError (a bound method has no len()).
assert len(endpoint.list_models()) == 1
endpoint.deploy(model, traffic_percentage=100)
assert len(endpoint.list_models()) == 2
traffic_split = {
    deployed_model.id: 50 for deployed_model in endpoint.list_models()
}
endpoint.update(traffic_split=traffic_split)
assert endpoint.traffic_split == traffic_split
133 changes: 93 additions & 40 deletions tests/unit/aiplatform/test_endpoints.py
Expand Up @@ -25,6 +25,8 @@
from google.api_core import operation as ga_operation
from google.auth import credentials as auth_credentials

from google.protobuf import field_mask_pb2

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -58,6 +60,8 @@
_TEST_ID_2 = "4366591682456584192"
_TEST_ID_3 = "5820582938582924817"
_TEST_DESCRIPTION = "test-description"
_TEST_REQUEST_METADATA = ()
_TEST_TIMEOUT = None

_TEST_ENDPOINT_NAME = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}"
Expand Down Expand Up @@ -270,6 +274,16 @@ def create_endpoint_mock():
yield create_endpoint_mock


@pytest.fixture
def update_endpoint_mock():
    """Patches EndpointServiceClient.update_endpoint to return a mock LRO."""
    with mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "update_endpoint"
    ) as mocked_update_endpoint:
        mocked_update_endpoint.return_value = mock.Mock(ga_operation.Operation)
        yield mocked_update_endpoint

@pytest.fixture
def deploy_model_mock():
with mock.patch.object(
Expand Down Expand Up @@ -726,6 +740,54 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
timeout=None,
)

@pytest.mark.usefixtures("get_endpoint_mock")
def test_update_endpoint(self, update_endpoint_mock):
    """update() sends exactly the changed fields plus a matching field mask."""
    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
    endpoint.update(
        display_name=_TEST_DISPLAY_NAME,
        description=_TEST_DESCRIPTION,
        labels=_TEST_LABELS,
    )

    update_endpoint_mock.assert_called_once_with(
        endpoint=gca_endpoint.Endpoint(
            name=_TEST_ENDPOINT_NAME,
            display_name=_TEST_DISPLAY_NAME,
            description=_TEST_DESCRIPTION,
            labels=_TEST_LABELS,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        ),
        update_mask=field_mask_pb2.FieldMask(
            paths=["display_name", "description", "labels"]
        ),
        metadata=_TEST_REQUEST_METADATA,
        timeout=_TEST_TIMEOUT,
    )

@pytest.mark.usefixtures("get_endpoint_with_models_mock")
def test_update_traffic_split(self, update_endpoint_mock):
    """update(traffic_split=...) updates only the traffic_split field."""
    new_split = {_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10}

    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
    endpoint.update(traffic_split=new_split)

    update_endpoint_mock.assert_called_once_with(
        endpoint=gca_endpoint.Endpoint(
            name=_TEST_ENDPOINT_NAME,
            display_name=_TEST_DISPLAY_NAME,
            deployed_models=_TEST_DEPLOYED_MODELS,
            traffic_split=new_split,
        ),
        update_mask=field_mask_pb2.FieldMask(paths=["traffic_split"]),
        metadata=_TEST_REQUEST_METADATA,
        timeout=_TEST_TIMEOUT,
    )

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_deploy(self, deploy_model_mock, sync):
Expand Down Expand Up @@ -920,7 +982,7 @@ def test_deploy_raise_error_max_replica(self, sync):
)
test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync)

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@pytest.mark.usefixtures("get_endpoint_with_models_mock", "get_model_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_deploy_raise_error_traffic_split(self, sync):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -973,48 +1035,39 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
timeout=None,
)

@pytest.mark.usefixtures("get_endpoint_with_models_mock", "get_model_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_deploy_with_traffic_split(self, deploy_model_mock, sync):
    """deploy() forwards a caller-supplied traffic_split to the service unchanged."""
    test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
    test_model = models.Model(_TEST_ID)
    test_model._gca_resource.supported_deployment_resources_types.append(
        aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
    )
    # Key "0" stands for the model being deployed in this request.
    test_endpoint.deploy(
        model=test_model,
        traffic_split={_TEST_ID: 10, _TEST_ID_2: 40, _TEST_ID_3: 10, "0": 40},
        sync=sync,
        deploy_request_timeout=None,
    )

    if not sync:
        test_endpoint.wait()
    automatic_resources = gca_machine_resources.AutomaticResources(
        min_replica_count=1,
        max_replica_count=1,
    )
    deployed_model = gca_endpoint.DeployedModel(
        automatic_resources=automatic_resources,
        model=test_model.resource_name,
        display_name=None,
    )
    deploy_model_mock.assert_called_once_with(
        endpoint=test_endpoint.resource_name,
        deployed_model=deployed_model,
        traffic_split={_TEST_ID: 10, _TEST_ID_2: 40, _TEST_ID_3: 10, "0": 40},
        metadata=(),
        timeout=None,
    )

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@pytest.mark.parametrize("sync", [True, False])
Expand Down

0 comments on commit 0ecfe1e

Please sign in to comment.