Skip to content

Commit

Permalink
feat: add support for create public index endpoint in matching engine
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 524917003
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Apr 17, 2023
1 parent 4d032d5 commit 7e6022b
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 7 deletions.
Expand Up @@ -28,7 +28,9 @@
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
)
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc
from google.cloud.aiplatform.matching_engine._protos import (
match_service_pb2_grpc,
)
from google.protobuf import field_mask_pb2

import grpc
Expand Down Expand Up @@ -130,6 +132,7 @@ def create(
cls,
display_name: str,
network: Optional[str] = None,
public_endpoint_enabled: Optional[bool] = False,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
Expand Down Expand Up @@ -163,6 +166,9 @@ def create(
projects/{project}/global/networks/{network}. Where
{project} is a project number, as in '12345', and {network}
is network name.
public_endpoint_enabled (bool):
Optional. If true, the deployed index will be
accessible through public endpoint.
description (str):
Optional. The description of the IndexEndpoint.
labels (Dict[str, str]):
Expand Down Expand Up @@ -203,15 +209,20 @@ def create(
"""
network = network or initializer.global_config.network

if not network:
if not network and not public_endpoint_enabled:
raise ValueError(
"Please provide `network` argument or set network"
"using aiplatform.init(network=...)"
"Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
)

if network and public_endpoint_enabled:
raise ValueError(
"`network` and `public_endpoint_enabled` argument should not be set at the same time"
)

return cls._create(
display_name=display_name,
network=network,
public_endpoint_enabled=public_endpoint_enabled,
description=description,
labels=labels,
project=project,
Expand All @@ -227,6 +238,7 @@ def _create(
cls,
display_name: str,
network: Optional[str] = None,
public_endpoint_enabled: Optional[bool] = False,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
Expand All @@ -253,6 +265,9 @@ def _create(
projects/{project}/global/networks/{network}. Where
{project} is a project number, as in '12345', and {network}
is network name.
public_endpoint_enabled (bool):
Optional. If true, the deployed index will be
accessible through public endpoint.
description (str):
Optional. The description of the IndexEndpoint.
labels (Dict[str, str]):
Expand Down Expand Up @@ -288,9 +303,17 @@ def _create(
Returns:
MatchingEngineIndexEndpoint - IndexEndpoint resource object
"""
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name, description=description, network=network
)

if public_endpoint_enabled:
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name,
description=description,
public_endpoint_enabled=public_endpoint_enabled,
)
else:
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
display_name=display_name, description=description, network=network
)

if labels:
utils.validate_labels(labels)
Expand Down
49 changes: 49 additions & 0 deletions tests/system/aiplatform/test_matching_engine_index.py
Expand Up @@ -52,11 +52,15 @@

# ENDPOINT
_TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name"
_TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME = "public_endpoint_name"
_TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint"
_TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION = "my public endpoint"

# DEPLOYED INDEX
_TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}".replace("-", "_")
_TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}"
_TEST_DEPLOYED_INDEX_ID_PUBLIC = f"deployed_index_id_{uuid.uuid4()}".replace("-", "_")
_TEST_DEPLOYED_INDEX_DISPLAY_NAME_PUBLIC = f"deployed_index_display_name_{uuid.uuid4()}"
_TEST_MIN_REPLICA_COUNT_UPDATED = 4
_TEST_MAX_REPLICA_COUNT_UPDATED = 4

Expand Down Expand Up @@ -241,6 +245,27 @@ def test_create_get_list_matching_engine_index(self, shared_state):
assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME
assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION

# Create public endpoint and check that it is listed
public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=_TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION,
public_endpoint_enabled=True,
labels=_TEST_LABELS,
)
assert public_index_endpoint.resource_name in [
index_endpoint.resource_name
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
]

assert public_index_endpoint.labels == _TEST_LABELS
assert (
public_index_endpoint.display_name
== _TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME
)
assert (
public_index_endpoint.description == _TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION
)

shared_state["resources"].append(my_index_endpoint)

# Deploy endpoint
Expand All @@ -250,6 +275,15 @@ def test_create_get_list_matching_engine_index(self, shared_state):
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
)

# Deploy public endpoint
public_index_endpoint = public_index_endpoint.deploy_index(
index=index,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID_PUBLIC,
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME_PUBLIC,
min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED,
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
)

# Update endpoint
updated_index_endpoint = my_index_endpoint.update(
display_name=_TEST_DISPLAY_NAME_UPDATE,
Expand All @@ -268,6 +302,7 @@ def test_create_get_list_matching_engine_index(self, shared_state):
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
)

# deployed index on private endpoint.
deployed_index = my_index_endpoint.deployed_indexes[0]

assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID
Expand All @@ -281,6 +316,20 @@ def test_create_get_list_matching_engine_index(self, shared_state):
== _TEST_MAX_REPLICA_COUNT_UPDATED
)

# deployed index on public endpoint.
deployed_index_public = public_index_endpoint.deployed_indexes[0]

assert deployed_index_public.id == _TEST_DEPLOYED_INDEX_ID_PUBLIC
assert deployed_index_public.index == index.resource_name
assert (
deployed_index_public.automatic_resources.min_replica_count
== _TEST_MIN_REPLICA_COUNT_UPDATED
)
assert (
deployed_index_public.automatic_resources.max_replica_count
== _TEST_MAX_REPLICA_COUNT_UPDATED
)

# TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC.
# results = my_index_endpoint.match(
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY]
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Expand Up @@ -547,6 +547,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
labels=_TEST_LABELS,
public_endpoint_enabled=False,
)

create_index_endpoint_mock.assert_called_once_with(
Expand All @@ -555,6 +556,66 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_create_index_endpoint_with_public_endpoint_enabled(
    self, create_index_endpoint_mock
):
    """Check the create request sent when `public_endpoint_enabled=True`.

    The mocked create call must receive an IndexEndpoint built from the
    same display name, description, labels and public-endpoint flag that
    were passed to `MatchingEngineIndexEndpoint.create`.
    """
    # One set of fields drives both the SDK call and the expected proto.
    endpoint_fields = dict(
        display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
        description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
        public_endpoint_enabled=True,
        labels=_TEST_LABELS,
    )

    aiplatform.init(project=_TEST_PROJECT)
    aiplatform.MatchingEngineIndexEndpoint.create(**endpoint_fields)

    create_index_endpoint_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        index_endpoint=gca_index_endpoint.IndexEndpoint(**endpoint_fields),
        metadata=_TEST_REQUEST_METADATA,
    )

def test_create_index_endpoint_missing_argument_throw_error(
    self, create_index_endpoint_mock
):
    """Omitting both `network` and `public_endpoint_enabled` raises ValueError."""
    aiplatform.init(project=_TEST_PROJECT)

    with pytest.raises(ValueError) as exc_info:
        aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
            description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
            labels=_TEST_LABELS,
        )

    # The message must match exactly, not just the exception type.
    wanted = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
    assert str(exc_info.value) == wanted

def test_create_index_endpoint_set_both_throw_error(
    self, create_index_endpoint_mock
):
    """Passing both `network` and `public_endpoint_enabled=True` raises ValueError."""
    aiplatform.init(project=_TEST_PROJECT)

    # Conflicting configuration: a VPC network together with the public flag.
    conflicting_kwargs = dict(
        display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
        description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
        public_endpoint_enabled=True,
        network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
        labels=_TEST_LABELS,
    )

    with pytest.raises(ValueError) as exc_info:
        aiplatform.MatchingEngineIndexEndpoint.create(**conflicting_kwargs)

    wanted = "`network` and `public_endpoint_enabled` argument should not be set at the same time"
    assert str(exc_info.value) == wanted

@pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock")
def test_deploy_index(self, deploy_index_mock, undeploy_index_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit 7e6022b

Please sign in to comment.