diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 7b82a482c5..cedc792532 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -28,7 +28,9 @@ matching_engine_index_endpoint as gca_matching_engine_index_endpoint, ) from google.cloud.aiplatform.matching_engine._protos import match_service_pb2 -from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc +from google.cloud.aiplatform.matching_engine._protos import ( + match_service_pb2_grpc, +) from google.protobuf import field_mask_pb2 import grpc @@ -130,6 +132,7 @@ def create( cls, display_name: str, network: Optional[str] = None, + public_endpoint_enabled: Optional[bool] = False, description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, project: Optional[str] = None, @@ -163,6 +166,9 @@ def create( projects/{project}/global/networks/{network}. Where {project} is a project number, as in '12345', and {network} is network name. + public_endpoint_enabled (bool): + Optional. If true, the deployed index will be + accessible through public endpoint. description (str): Optional. The description of the IndexEndpoint. labels (Dict[str, str]): @@ -203,15 +209,20 @@ def create( """ network = network or initializer.global_config.network - if not network: + if not network and not public_endpoint_enabled: raise ValueError( - "Please provide `network` argument or set network" - "using aiplatform.init(network=...)" + "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint" + ) + + if network and public_endpoint_enabled: + raise ValueError( + "`network` and `public_endpoint_enabled` argument should not be set at the same time" ) return cls._create( display_name=display_name, network=network, + public_endpoint_enabled=public_endpoint_enabled, description=description, labels=labels, project=project, @@ -227,6 +238,7 @@ def _create( cls, display_name: str, network: Optional[str] = None, + public_endpoint_enabled: Optional[bool] = False, description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, project: Optional[str] = None, @@ -253,6 +265,9 @@ def _create( projects/{project}/global/networks/{network}. Where {project} is a project number, as in '12345', and {network} is network name. + public_endpoint_enabled (bool): + Optional. If true, the deployed index will be + accessible through public endpoint. description (str): Optional. The description of the IndexEndpoint. labels (Dict[str, str]): @@ -288,9 +303,17 @@ def _create( Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object """ - gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint( - display_name=display_name, description=description, network=network - ) + + if public_endpoint_enabled: + gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint( + display_name=display_name, + description=description, + public_endpoint_enabled=public_endpoint_enabled, + ) + else: + gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint( + display_name=display_name, description=description, network=network + ) if labels: utils.validate_labels(labels) diff --git a/tests/system/aiplatform/test_matching_engine_index.py b/tests/system/aiplatform/test_matching_engine_index.py index 110baa37ab..4955c92e39 100644 --- a/tests/system/aiplatform/test_matching_engine_index.py +++ b/tests/system/aiplatform/test_matching_engine_index.py @@ -52,11 +52,15 @@ # ENDPOINT _TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name" +_TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME = "public_endpoint_name" _TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint" +_TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION = "my public endpoint" # DEPLOYED INDEX _TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}".replace("-", "_") _TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}" +_TEST_DEPLOYED_INDEX_ID_PUBLIC = f"deployed_index_id_{uuid.uuid4()}".replace("-", "_") +_TEST_DEPLOYED_INDEX_DISPLAY_NAME_PUBLIC = f"deployed_index_display_name_{uuid.uuid4()}" _TEST_MIN_REPLICA_COUNT_UPDATED = 4 _TEST_MAX_REPLICA_COUNT_UPDATED = 4 @@ -241,6 +245,27 @@ def test_create_get_list_matching_engine_index(self, shared_state): assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION + # Create endpoint and check that it is listed + public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION, + public_endpoint_enabled=True, + labels=_TEST_LABELS, + ) + assert public_index_endpoint.resource_name in [ + index_endpoint.resource_name + for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list() + ] + + assert public_index_endpoint.labels == _TEST_LABELS + assert ( + public_index_endpoint.display_name + == _TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME + ) + assert ( + public_index_endpoint.description == _TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION + ) + shared_state["resources"].append(my_index_endpoint) # Deploy endpoint @@ -250,6 +275,15 @@ def test_create_get_list_matching_engine_index(self, shared_state): display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, ) + # Deploy public endpoint + public_index_endpoint = public_index_endpoint.deploy_index( + index=index, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID_PUBLIC, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME_PUBLIC, + min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED, + max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED, + ) + # Update endpoint updated_index_endpoint = my_index_endpoint.update( display_name=_TEST_DISPLAY_NAME_UPDATE, @@ -268,6 +302,7 @@ def test_create_get_list_matching_engine_index(self, shared_state): max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED, ) + # deployed index on private endpoint. deployed_index = my_index_endpoint.deployed_indexes[0] assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID @@ -281,6 +316,20 @@ def test_create_get_list_matching_engine_index(self, shared_state): == _TEST_MAX_REPLICA_COUNT_UPDATED ) + # deployed index on public endpoint. + deployed_index_public = public_index_endpoint.deployed_indexes[0] + + assert deployed_index_public.id == _TEST_DEPLOYED_INDEX_ID_PUBLIC + assert deployed_index_public.index == index.resource_name + assert ( + deployed_index_public.automatic_resources.min_replica_count + == _TEST_MIN_REPLICA_COUNT_UPDATED + ) + assert ( + deployed_index_public.automatic_resources.max_replica_count + == _TEST_MAX_REPLICA_COUNT_UPDATED + ) + # TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC. # results = my_index_endpoint.match( # deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY] diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 5dce7878e5..24b0c1e2c3 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -547,6 +547,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc network=_TEST_INDEX_ENDPOINT_VPC_NETWORK, description=_TEST_INDEX_ENDPOINT_DESCRIPTION, labels=_TEST_LABELS, + public_endpoint_enabled=False, ) create_index_endpoint_mock.assert_called_once_with( @@ -555,6 +556,66 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_create_index_endpoint_with_public_endpoint_enabled( + self, create_index_endpoint_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + public_endpoint_enabled=True, + labels=_TEST_LABELS, + ) + + expected = gca_index_endpoint.IndexEndpoint( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + public_endpoint_enabled=True, + labels=_TEST_LABELS, + ) + + create_index_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + index_endpoint=expected, + metadata=_TEST_REQUEST_METADATA, + ) + + def test_create_index_endpoint_missing_argument_throw_error( + self, create_index_endpoint_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + expected_message = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint" + + with pytest.raises(ValueError) as exception: + _ = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + labels=_TEST_LABELS, + ) + + assert str(exception.value) == expected_message + + def test_create_index_endpoint_set_both_throw_error( + self, create_index_endpoint_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + expected_message = "`network` and `public_endpoint_enabled` argument should not be set at the same time" + + with pytest.raises(ValueError) as exception: + _ = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + public_endpoint_enabled=True, + network=_TEST_INDEX_ENDPOINT_VPC_NETWORK, + labels=_TEST_LABELS, + ) + + assert str(exception.value) == expected_message + @pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock") def test_deploy_index(self, deploy_index_mock, undeploy_index_mock): aiplatform.init(project=_TEST_PROJECT)