diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
index 85e3ceff7d..5501c07d2f 100644
--- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
+++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
@@ -220,6 +220,9 @@ def __init__(
         if self.public_endpoint_domain_name:
             self._public_match_client = self._instantiate_public_match_client()
 
+        self._match_grpc_stub_cache = {}
+        self._private_service_connect_ip_address = None
+
     @classmethod
     def create(
         cls,
@@ -521,33 +524,62 @@ def _instantiate_public_match_client(
 
     def _instantiate_private_match_service_stub(
         self,
-        deployed_index_id: str,
+        deployed_index_id: Optional[str] = None,
+        ip_address: Optional[str] = None,
     ) -> match_service_pb2_grpc.MatchServiceStub:
         """Helper method to instantiate private match service stub.
         Args:
             deployed_index_id (str):
-                Required. The user specified ID of the
-                DeployedIndex.
+                Optional. Required for private service access endpoint.
+                The user specified ID of the DeployedIndex.
+            ip_address (str):
+                Optional. Required for private service connect. The ip address
+                the forwarding rule makes use of.
         Returns:
             stub (match_service_pb2_grpc.MatchServiceStub):
                 Initialized match service stub.
+        Raises:
+            RuntimeError: No deployed index with id deployed_index_id found
+            ValueError: Should not set ip address for networks other than
+                private service connect.
         """
-        # Find the deployed index by id
-        deployed_indexes = [
-            deployed_index
-            for deployed_index in self.deployed_indexes
-            if deployed_index.id == deployed_index_id
-        ]
+        if ip_address:
+            # Should only set for Private Service Connect
+            if self.public_endpoint_domain_name:
+                raise ValueError(
+                    "MatchingEngineIndexEndpoint is set to use "
+                    "public network. Could not establish connection using "
+                    "provided ip address",
+                )
+            elif self.private_service_access_network:
+                raise ValueError(
+                    "MatchingEngineIndexEndpoint is set to use "
+                    "private service access network. Could not establish "
+                    "connection using provided ip address",
+                )
+        else:
+            # Private Service Access, find server ip for deployed index
+            deployed_indexes = [
+                deployed_index
+                for deployed_index in self.deployed_indexes
+                if deployed_index.id == deployed_index_id
+            ]
 
-        if not deployed_indexes:
-            raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
+            if not deployed_indexes:
+                raise RuntimeError(
+                    f"No deployed index with id '{deployed_index_id}' found"
+                )
 
-        # Retrieve server ip from deployed index
-        server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
+            # Retrieve server ip from deployed index
+            ip_address = deployed_indexes[0].private_endpoints.match_grpc_address
 
-        # Set up channel and stub
-        channel = grpc.insecure_channel("{}:10000".format(server_ip))
-        return match_service_pb2_grpc.MatchServiceStub(channel)
+        if ip_address not in self._match_grpc_stub_cache:
+            # Set up channel and stub
+            channel = grpc.insecure_channel("{}:10000".format(ip_address))
+            self._match_grpc_stub_cache[
+                ip_address
+            ] = match_service_pb2_grpc.MatchServiceStub(channel)
+        return self._match_grpc_stub_cache[ip_address]
 
     @property
     def public_endpoint_domain_name(self) -> Optional[str]:
@@ -555,6 +587,22 @@ def public_endpoint_domain_name(self) -> Optional[str]:
         self._assert_gca_resource_is_available()
         return self._gca_resource.public_endpoint_domain_name
 
+    @property
+    def private_service_access_network(self) -> Optional[str]:
+        """Private service access network."""
+        self._assert_gca_resource_is_available()
+        return self._gca_resource.network
+
+    @property
+    def private_service_connect_ip_address(self) -> Optional[str]:
+        """Private service connect ip address."""
+        return self._private_service_connect_ip_address
+
+    @private_service_connect_ip_address.setter
+    def private_service_connect_ip_address(self, ip_address: str) -> None:
+        """Setter for private service connect ip address."""
+        self._private_service_connect_ip_address = ip_address
+
     def update(
         self,
         display_name: str,
@@ -1300,7 +1348,8 @@ def read_index_datapoints(
         if not self._public_match_client:
             # Call private match service stub with BatchGetEmbeddings request
             embeddings = self._batch_get_embeddings(
-                deployed_index_id=deployed_index_id, ids=ids
+                deployed_index_id=deployed_index_id,
+                ids=ids,
             )
 
             response = []
@@ -1362,7 +1411,8 @@ def _batch_get_embeddings(
             List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
         """
         stub = self._instantiate_private_match_service_stub(
-            deployed_index_id=deployed_index_id
+            deployed_index_id=deployed_index_id,
+            ip_address=self._private_service_connect_ip_address,
         )
 
         # Create the batch get embeddings request
@@ -1420,7 +1470,8 @@ def match(
             List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
         """
         stub = self._instantiate_private_match_service_stub(
-            deployed_index_id=deployed_index_id
+            deployed_index_id=deployed_index_id,
+            ip_address=self._private_service_connect_ip_address,
         )
 
         # Create the batch match request
diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py
index 6b6af65e68..76ec65692c 100644
--- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py
+++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py
@@ -246,6 +246,7 @@ _TEST_RETURN_FULL_DATAPOINT = True
 _TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
 _TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
 
+_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5"
 _TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
     gca_index_v1beta1.IndexDatapoint(
         datapoint_id="1",
@@ -1137,6 +1138,54 @@ def test_private_index_endpoint_find_neighbor_queries(
         )
         index_endpoint_match_queries_mock.assert_called_with(batch_match_request)
 
+    @pytest.mark.usefixtures("get_index_endpoint_mock")
+    def test_index_private_service_connect_endpoint_match_queries(
+        self, index_endpoint_match_queries_mock
+    ):
+        aiplatform.init(project=_TEST_PROJECT)
+
+        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
+            index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
+        )
+
+        my_index_endpoint.private_service_connect_ip_address = (
+            _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
+        )
+        my_index_endpoint.match(
+            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+            queries=_TEST_QUERIES,
+            num_neighbors=_TEST_NUM_NEIGHBOURS,
+            filter=_TEST_FILTER,
+            per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
+            approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
+        )
+
+        batch_request = match_service_pb2.BatchMatchRequest(
+            requests=[
+                match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
+                    deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+                    requests=[
+                        match_service_pb2.MatchRequest(
+                            num_neighbors=_TEST_NUM_NEIGHBOURS,
+                            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+                            float_val=_TEST_QUERIES[0],
+                            restricts=[
+                                match_service_pb2.Namespace(
+                                    name="class",
+                                    allow_tokens=["token_1"],
+                                    deny_tokens=["token_2"],
+                                )
+                            ],
+                            per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
+                            approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
+                        )
+                    ],
+                )
+            ]
+        )
+
+        index_endpoint_match_queries_mock.assert_called_with(batch_request)
+
     @pytest.mark.usefixtures("get_index_public_endpoint_mock")
     def test_index_public_endpoint_match_queries(
         self, index_public_endpoint_match_queries_mock
@@ -1330,7 +1379,7 @@ def test_index_endpoint_batch_get_embeddings(
         index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
 
     @pytest.mark.usefixtures("get_index_endpoint_mock")
-    def test_index_private_endpoint_read_index_datapoints(
+    def test_index_endpoint_read_index_datapoints_for_private_service_access(
         self, index_endpoint_batch_get_embeddings_mock
     ):
         aiplatform.init(project=_TEST_PROJECT)
@@ -1350,3 +1399,29 @@ def test_index_private_endpoint_read_index_datapoints(
         index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
 
         assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE
+
+    @pytest.mark.usefixtures("get_index_endpoint_mock")
+    def test_index_endpoint_read_index_datapoints_for_private_service_connect(
+        self, index_endpoint_batch_get_embeddings_mock
+    ):
+        aiplatform.init(project=_TEST_PROJECT)
+
+        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
+            index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
+        )
+
+        my_index_endpoint.private_service_connect_ip_address = (
+            _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
+        )
+        response = my_index_endpoint.read_index_datapoints(
+            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+            ids=["1", "2"],
+        )
+
+        batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
+            deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
+        )
+
+        index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
+
+        assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE