Skip to content

Commit

Permalink
feat: add per_crowding_attribute_neighbor_count, `approx_num_neighb…
Browse files Browse the repository at this point in the history
…ors`, `fraction_leaf_nodes_to_search_override`, and `return_full_datapoint` to MatchingEngineIndexEndpoint `find_neighbors`

PiperOrigin-RevId: 579967420
  • Loading branch information
lingyinw authored and Copybara-Service committed Nov 6, 2023
1 parent a0103c5 commit 33c551e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
Expand Up @@ -956,6 +956,10 @@ def find_neighbors(
queries: List[List[float]],
num_neighbors: int = 10,
filter: Optional[List[Namespace]] = [],
per_crowding_attribute_neighbor_count: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
Expand All @@ -979,25 +983,58 @@ def find_neighbors(
For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
that satisfy "red color" but not include datapoints with "squared shape".
Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
per_crowding_attribute_neighbor_count (int):
Optional. Crowding is a constraint on a neighbor list produced
by nearest neighbor search requiring that no more than some
value k' of the k neighbors returned have the same value of
crowding_attribute. It's used for improving result diversity.
This field is the maximum number of matches with the same crowding tag.
approx_num_neighbors (int):
Optional. The number of neighbors to find via approximate search
before exact reordering is performed. If not set, the default
value from scam config is used; if set, this value must be > 0.
fraction_leaf_nodes_to_search_override (float):
Optional. The fraction of the number of leaves to search, set at
query time allows user to tune search performance. This value
increase result in both search accuracy and latency increase.
The value should be between 0.0 and 1.0.
return_full_datapoint (bool):
Optional. If set to true, the full datapoints (including all
vector values and of the nearest neighbors are returned.
Note that returning full datapoint will significantly increase the
latency and cost of the query.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""

if not self._public_match_client:
raise ValueError(
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
"Please make sure index has been deployed to public endpoint,and follow the example usage to call this method."
)

# Create the FindNeighbors request
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest()
find_neighbors_request.index_endpoint = self.resource_name
find_neighbors_request.deployed_index_id = deployed_index_id
find_neighbors_request.return_full_datapoint = return_full_datapoint

for query in queries:
find_neighbors_query = (
gca_match_service_v1beta1.FindNeighborsRequest.Query()
)
find_neighbors_query.neighbor_count = num_neighbors
find_neighbors_query.per_crowding_attribute_neighbor_count = (
per_crowding_attribute_neighbor_count
)
find_neighbors_query.approximate_neighbor_count = approx_num_neighbors
find_neighbors_query.fraction_leaf_nodes_to_search_override = (
fraction_leaf_nodes_to_search_override
)
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Expand Up @@ -234,6 +234,8 @@
_TEST_IDS = ["123", "456", "789"]
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
_TEST_APPROX_NUM_NEIGHBORS = 2
_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8
_TEST_RETURN_FULL_DATAPOINT = True


def uuid_mock():
Expand Down Expand Up @@ -954,6 +956,10 @@ def test_index_public_endpoint_match_queries(
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
filter=_TEST_FILTER,
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
)

find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
Expand All @@ -972,8 +978,12 @@ def test_index_public_endpoint_match_queries(
)
],
),
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
)
],
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
)

index_public_endpoint_match_queries_mock.assert_called_with(
Expand Down

2 comments on commit 33c551e

@tristincodes
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lingyinw Should the return_full_datapoint flag also conditionally trigger returning all of the restricts from the datapoint if it is set to true? Otherwise, what is the point of wanting to return the whole datapoint if the MatchNeighbor is still only being returned that includes the ID and distance? What does return_full_datapoint do for the caller of the find_neighbors method?

@lingyinw
Copy link
Contributor Author

@lingyinw lingyinw commented on 33c551e Jan 19, 2024 via email

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.