Skip to content

Commit

Permalink
fix: read_index_endpoint private endpoint support.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589880026
  • Loading branch information
lingyinw authored and Copybara-Service committed Dec 11, 2023
1 parent 0a4d772 commit 3d8835e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1285,24 +1285,32 @@ def read_index_datapoints(
"""
if not self._public_match_client:
# Call private match service stub with BatchGetEmbeddings request
response = self._batch_get_embeddings(
embeddings = self._batch_get_embeddings(
deployed_index_id=deployed_index_id, ids=ids
)
return [
gca_index_v1beta1.IndexDatapoint(

response = []
for embedding in embeddings:
index_datapoint = gca_index_v1beta1.IndexDatapoint(
datapoint_id=embedding.id,
feature_vector=embedding.float_val,
restricts=gca_index_v1beta1.IndexDatapoint.Restriction(
namespace=embedding.restricts.name,
allow_list=embedding.restricts.allow_tokens,
),
deny_list=embedding.restricts.deny_tokens,
crowding_attributes=gca_index_v1beta1.CrowdingEmbedding(
str(embedding.crowding_tag)
),
restricts=[
gca_index_v1beta1.IndexDatapoint.Restriction(
namespace=restrict.name,
allow_list=restrict.allow_tokens,
deny_list=restrict.deny_tokens,
)
for restrict in embedding.restricts
],
)
for embedding in response.embeddings
]
if embedding.crowding_attribute:
index_datapoint.crowding_tag = (
gca_index_v1beta1.IndexDatapoint.CrowdingTag(
crowding_attribute=str(embedding.crowding_attribute)
)
)
response.append(index_datapoint)
return response

# Create the ReadIndexDatapoints request
read_index_datapoints_request = (
Expand All @@ -1326,7 +1334,7 @@ def _batch_get_embeddings(
*,
deployed_index_id: str,
ids: List[str] = [],
) -> List[List[match_service_pb2.Embedding]]:
) -> List[match_service_pb2.Embedding]:
"""
Reads the datapoints/vectors of the given IDs on the specified index
which is deployed to private endpoint.
Expand Down
50 changes: 49 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,26 @@
_TEST_RETURN_FULL_DATAPOINT = True
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
gca_index_v1beta1.IndexDatapoint(
datapoint_id="1",
feature_vector=[0.1, 0.2, 0.3],
restricts=[
gca_index_v1beta1.IndexDatapoint.Restriction(
namespace="class",
allow_list=["token_1"],
deny_list=["token_2"],
)
],
),
gca_index_v1beta1.IndexDatapoint(
datapoint_id="2",
feature_vector=[0.5, 0.2, 0.3],
crowding_tag=gca_index_v1beta1.IndexDatapoint.CrowdingTag(
crowding_attribute="1"
),
),
]


def uuid_mock():
Expand Down Expand Up @@ -505,7 +525,13 @@ def index_endpoint_batch_get_embeddings_mock():
match_service_pb2.Embedding(
id="1",
float_val=[0.1, 0.2, 0.3],
crowding_attribute=1,
restricts=[
match_service_pb2.Namespace(
name="class",
allow_tokens=["token_1"],
deny_tokens=["token_2"],
)
],
),
match_service_pb2.Embedding(
id="2",
Expand Down Expand Up @@ -1249,3 +1275,25 @@ def test_index_endpoint_batch_get_embeddings(
)

index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_find_neighbors_for_private(
self, index_endpoint_batch_get_embeddings_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

response = my_index_endpoint.read_index_datapoints(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, ids=["1", "2"]
)

batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
)

index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE

0 comments on commit 3d8835e

Please sign in to comment.