From 1b5ae4402b74d234d0fd8c886e935b3e8919bb50 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 19 Apr 2023 11:39:28 -0700 Subject: [PATCH] feat: add support for return public endpoint dns name in matching engine PiperOrigin-RevId: 525507137 --- .../matching_engine_index_endpoint.py | 6 ++ .../test_matching_engine_index_endpoint.py | 65 ++++++++++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index cedc792532..9d5aba2db2 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -344,6 +344,12 @@ def _create( return index_obj + @property + def public_endpoint_domain_name(self) -> Optional[str]: + """Public endpoint DNS name.""" + self._assert_gca_resource_is_available() + return self._gca_resource.public_endpoint_domain_name + def update( self, display_name: str, diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 24b0c1e2c3..9d176b96b5 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -60,6 +60,9 @@ # index_endpoint _TEST_INDEX_ENDPOINT_ID = "index_endpoint_id" +_TEST_INDEX_ENDPOINT_PUBLIC_DNS = ( + "1114627793.us-central1-249381615684.vdb.vertexai.goog" +) _TEST_INDEX_ENDPOINT_NAME = f"{_TEST_PARENT}/indexEndpoints/{_TEST_INDEX_ENDPOINT_ID}" _TEST_INDEX_ENDPOINT_DISPLAY_NAME = "index_endpoint_display_name" _TEST_INDEX_ENDPOINT_DESCRIPTION = "index_endpoint_description" @@ -308,6 +311,57 @@ def get_index_endpoint_mock(): yield get_index_endpoint_mock +@pytest.fixture +def get_index_public_endpoint_mock(): + with patch.object( + index_endpoint_service_client.IndexEndpointServiceClient, "get_index_endpoint" + ) as get_index_public_endpoint_mock: + index_endpoint = gca_index_endpoint.IndexEndpoint( + name=_TEST_INDEX_ENDPOINT_NAME, + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + public_endpoint_domain_name=_TEST_INDEX_ENDPOINT_PUBLIC_DNS, + ) + index_endpoint.deployed_indexes = [ + gca_index_endpoint.DeployedIndex( + id=_TEST_DEPLOYED_INDEX_ID, + index=_TEST_INDEX_NAME, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + ), + gca_index_endpoint.DeployedIndex( + id=f"{_TEST_DEPLOYED_INDEX_ID}_2", + index=f"{_TEST_INDEX_NAME}_2", + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + ), + ] + get_index_public_endpoint_mock.return_value = index_endpoint + yield get_index_public_endpoint_mock + + @pytest.fixture def deploy_index_mock(): with patch.object( @@ -556,7 +610,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc metadata=_TEST_REQUEST_METADATA, ) - @pytest.mark.usefixtures("get_index_endpoint_mock") + @pytest.mark.usefixtures("get_index_public_endpoint_mock") def test_create_index_endpoint_with_public_endpoint_enabled( self, create_index_endpoint_mock ): @@ -569,6 +623,10 @@ def test_create_index_endpoint_with_public_endpoint_enabled( labels=_TEST_LABELS, ) + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + expected = gca_index_endpoint.IndexEndpoint( display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, description=_TEST_INDEX_ENDPOINT_DESCRIPTION, @@ -582,6 +640,11 @@ def test_create_index_endpoint_with_public_endpoint_enabled( metadata=_TEST_REQUEST_METADATA, ) + assert ( + my_index_endpoint.public_endpoint_domain_name + == _TEST_INDEX_ENDPOINT_PUBLIC_DNS + ) + def test_create_index_endpoint_missing_argument_throw_error( self, create_index_endpoint_mock ):