Skip to content

Commit

Permalink
feat: Support private service connect for `MatchingEngineIndexEndpoin…
Browse files Browse the repository at this point in the history
…t` `match()` and `read_index_datapoints()`.

PiperOrigin-RevId: 596852286
  • Loading branch information
lingyinw authored and Copybara-Service committed Jan 9, 2024
1 parent 776d0da commit 61cff4b
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 20 deletions.
Expand Up @@ -220,6 +220,9 @@ def __init__(
if self.public_endpoint_domain_name:
self._public_match_client = self._instantiate_public_match_client()

self._match_grpc_stub_cache = {}
self._private_service_connect_ip_address = None

@classmethod
def create(
cls,
Expand Down Expand Up @@ -521,40 +524,85 @@ def _instantiate_public_match_client(

def _instantiate_private_match_service_stub(
self,
deployed_index_id: str,
deployed_index_id: Optional[str] = None,
ip_address: Optional[str] = None,
) -> match_service_pb2_grpc.MatchServiceStub:
"""Helper method to instantiate private match service stub.
Args:
deployed_index_id (str):
Required. The user specified ID of the
DeployedIndex.
Optional. Required for private service access endpoint.
The user specified ID of the DeployedIndex.
ip_address (str):
Optional. Required for private service connect. The ip address
the forwarding rule makes use of.
Returns:
stub (match_service_pb2_grpc.MatchServiceStub):
Initialized match service stub.
Raises:
RuntimeError: No deployed index with id deployed_index_id found
ValueError: Should not set ip address for networks other than
private service connect.
"""
# Find the deployed index by id
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]
if ip_address:
# Should only set for Private Service Connect
if self.public_endpoint_domain_name:
raise ValueError(
"MatchingEngineIndexEndpoint is set to use ",
"public network. Could not establish connection using "
"provided ip address",
)
elif self.private_service_access_network:
raise ValueError(
"MatchingEngineIndexEndpoint is set to use ",
"private service access network. Could not establish "
"connection using provided ip address",
)
else:
# Private Service Access, find server ip for deployed index
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]

if not deployed_indexes:
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
if not deployed_indexes:
raise RuntimeError(
f"No deployed index with id '{deployed_index_id}' found"
)

# Retrieve server ip from deployed index
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
# Retrieve server ip from deployed index
ip_address = deployed_indexes[0].private_endpoints.match_grpc_address

# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(server_ip))
return match_service_pb2_grpc.MatchServiceStub(channel)
if ip_address not in self._match_grpc_stub_cache:
# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(ip_address))
self._match_grpc_stub_cache[
ip_address
] = match_service_pb2_grpc.MatchServiceStub(channel)
return self._match_grpc_stub_cache[ip_address]

@property
def public_endpoint_domain_name(self) -> Optional[str]:
"""Public endpoint DNS name."""
self._assert_gca_resource_is_available()
return self._gca_resource.public_endpoint_domain_name

@property
def private_service_access_network(self) -> Optional[str]:
""" "Private service access network."""
self._assert_gca_resource_is_available()
return self._gca_resource.network

@property
def private_service_connect_ip_address(self) -> Optional[str]:
""" "Private service connect ip address."""
return self._private_service_connect_ip_address

@private_service_connect_ip_address.setter
def private_service_connect_ip_address(self, ip_address: str) -> Optional[str]:
""" "Setter for private service connect ip address."""
self._private_service_connect_ip_address = ip_address

def update(
self,
display_name: str,
Expand Down Expand Up @@ -1300,7 +1348,8 @@ def read_index_datapoints(
if not self._public_match_client:
# Call private match service stub with BatchGetEmbeddings request
embeddings = self._batch_get_embeddings(
deployed_index_id=deployed_index_id, ids=ids
deployed_index_id=deployed_index_id,
ids=ids,
)

response = []
Expand Down Expand Up @@ -1362,7 +1411,8 @@ def _batch_get_embeddings(
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
"""
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
deployed_index_id=deployed_index_id,
ip_address=self._private_service_connect_ip_address,
)

# Create the batch get embeddings request
Expand Down Expand Up @@ -1420,7 +1470,8 @@ def match(
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
deployed_index_id=deployed_index_id,
ip_address=self._private_service_connect_ip_address,
)

# Create the batch match request
Expand Down
77 changes: 76 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Expand Up @@ -246,6 +246,7 @@
_TEST_RETURN_FULL_DATAPOINT = True
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5"
_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
gca_index_v1beta1.IndexDatapoint(
datapoint_id="1",
Expand Down Expand Up @@ -1137,6 +1138,54 @@ def test_private_index_endpoint_find_neighbor_queries(
)
index_endpoint_match_queries_mock.assert_called_with(batch_match_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_private_service_connect_endpoint_match_queries(
self, index_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_index_endpoint.private_service_connect_ip_address = (
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
)
my_index_endpoint.match(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
filter=_TEST_FILTER,
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
)

batch_request = match_service_pb2.BatchMatchRequest(
requests=[
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
requests=[
match_service_pb2.MatchRequest(
num_neighbors=_TEST_NUM_NEIGHBOURS,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
float_val=_TEST_QUERIES[0],
restricts=[
match_service_pb2.Namespace(
name="class",
allow_tokens=["token_1"],
deny_tokens=["token_2"],
)
],
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
)
],
)
]
)

index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_match_queries(
self, index_public_endpoint_match_queries_mock
Expand Down Expand Up @@ -1330,7 +1379,7 @@ def test_index_endpoint_batch_get_embeddings(
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_private_endpoint_read_index_datapoints(
def test_index_endpoint_find_neighbors_for_private_service_access(
self, index_endpoint_batch_get_embeddings_mock
):
aiplatform.init(project=_TEST_PROJECT)
Expand All @@ -1350,3 +1399,29 @@ def test_index_private_endpoint_read_index_datapoints(
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_find_neighbors_for_private_service_connect(
self, index_endpoint_batch_get_embeddings_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_index_endpoint.private_service_connect_ip = (
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
)
response = my_index_endpoint.read_index_datapoints(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
ids=["1", "2"],
)

batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
)

index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE

0 comments on commit 61cff4b

Please sign in to comment.