Skip to content

Commit

Permalink
feat: Add return_full_datapoint for MatchEngineIndexEndpoint `mat…
Browse files Browse the repository at this point in the history
…ch()`.

PiperOrigin-RevId: 597148566
  • Loading branch information
lingyinw authored and Copybara-Service committed Jan 10, 2024
1 parent d0f65fd commit ad8d9c1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,7 @@ def find_neighbors(
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
return_full_datapoint=return_full_datapoint,
)

# Create the FindNeighbors request
Expand Down Expand Up @@ -1434,6 +1435,7 @@ def match(
per_crowding_attribute_num_neighbors: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index for private endpoint only.
Expand Down Expand Up @@ -1465,6 +1467,11 @@ def match(
query time allows users to tune search performance. Increasing this
value increases both search accuracy and latency.
The value should be between 0.0 and 1.0.
return_full_datapoint (bool):
Optional. If set to true, the full datapoints (including all
vector values and restricts) of the nearest neighbors are returned.
Note that returning full datapoint will significantly increase the
latency and cost of the query.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
Expand Down Expand Up @@ -1502,6 +1509,7 @@ def match(
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
embedding_enabled=return_full_datapoint,
)
requests.append(request)

Expand Down
27 changes: 15 additions & 12 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ def test_private_index_endpoint_match_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
)

batch_request = match_service_pb2.BatchMatchRequest(
Expand All @@ -1081,6 +1082,7 @@ def test_private_index_endpoint_match_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
)
for i in range(len(_TEST_QUERIES))
],
Expand All @@ -1096,11 +1098,11 @@ def test_private_index_endpoint_find_neighbor_queries(
):
aiplatform.init(project=_TEST_PROJECT)

my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
my_private_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_pubic_index_endpoint.find_neighbors(
my_private_index_endpoint.find_neighbors(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
Expand Down Expand Up @@ -1130,6 +1132,7 @@ def test_private_index_endpoint_find_neighbor_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
)
for test_query in _TEST_QUERIES
],
Expand Down Expand Up @@ -1187,16 +1190,16 @@ def test_index_private_service_connect_endpoint_match_queries(
index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_match_queries(
def test_index_public_endpoint_find_neighbors_queries(
self, index_public_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_pubic_index_endpoint.find_neighbors(
my_public_index_endpoint.find_neighbors(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
Expand All @@ -1208,7 +1211,7 @@ def test_index_public_endpoint_match_queries(
)

find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
index_endpoint=my_pubic_index_endpoint.resource_name,
index_endpoint=my_public_index_endpoint.resource_name,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=[
gca_match_service_v1beta1.FindNeighborsRequest.Query(
Expand Down Expand Up @@ -1241,11 +1244,11 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
):
aiplatform.init(project=_TEST_PROJECT)

my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_pubic_index_endpoint.find_neighbors(
my_public_index_endpoint.find_neighbors(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
Expand All @@ -1258,7 +1261,7 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
)

find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
index_endpoint=my_pubic_index_endpoint.resource_name,
index_endpoint=my_public_index_endpoint.resource_name,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=[
gca_match_service_v1beta1.FindNeighborsRequest.Query(
Expand Down Expand Up @@ -1337,18 +1340,18 @@ def test_index_public_endpoint_read_index_datapoints(
):
aiplatform.init(project=_TEST_PROJECT)

my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_pubic_index_endpoint.read_index_datapoints(
my_public_index_endpoint.read_index_datapoints(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
ids=_TEST_IDS,
)

read_index_datapoints_request = (
gca_match_service_v1beta1.ReadIndexDatapointsRequest(
index_endpoint=my_pubic_index_endpoint.resource_name,
index_endpoint=my_public_index_endpoint.resource_name,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
ids=_TEST_IDS,
)
Expand Down

0 comments on commit ad8d9c1

Please sign in to comment.