Skip to content

Commit

Permalink
feat: add support for per_crowding_attribute_num_neighbors `approx_…
Browse files Browse the repository at this point in the history
…num_neighbors`to MatchingEngineIndexEndpoint `match()`

PiperOrigin-RevId: 578666956
  • Loading branch information
lingyinw authored and Copybara-Service committed Nov 1, 2023
1 parent 77dec1e commit 4e357d5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
Expand Up @@ -1071,6 +1071,8 @@ def match(
self,
deployed_index_id: str,
queries: List[List[float]],
per_crowding_attribute_num_neighbors: int,
approx_num_neighbors: int,
num_neighbors: int = 1,
filter: Optional[List[Namespace]] = [],
) -> List[List[MatchNeighbor]]:
Expand All @@ -1081,6 +1083,15 @@ def match(
Required. The ID of the DeployedIndex to match the queries against.
queries (List[List[float]]):
Required. A list of queries. Each query is a list of floats, representing a single embedding.
per_crowding_attribute_num_neighbors (int):
Optional. Crowding is a constraint on a neighbor list produced by nearest neighbor
search requiring that no more than some value k' of the k neighbors
returned have the same value of crowding_attribute.
It's used for improving result diversity.
This field is the maximum number of matches with the same crowding tag.
approx_num_neighbors (int):
The number of neighbors to find via approximate search before exact reordering is performed.
If not set, the default value from scam config is used; if set, this value must be > 0.
num_neighbors (int):
Required. The number of nearest neighbors to be retrieved from database for
each query.
Expand Down Expand Up @@ -1123,6 +1134,8 @@ def match(
num_neighbors=num_neighbors,
deployed_index_id=deployed_index_id,
float_val=query,
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
)
for namespace in filter:
restrict = match_service_pb2.Namespace()
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Expand Up @@ -232,6 +232,8 @@
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
]
_TEST_IDS = ["123", "456", "789"]
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
_TEST_APPROX_NUM_NEIGHBORS = 2


def uuid_mock():
Expand Down Expand Up @@ -866,6 +868,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
filter=_TEST_FILTER,
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
)

batch_request = match_service_pb2.BatchMatchRequest(
Expand All @@ -884,6 +888,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
deny_tokens=["token_2"],
)
],
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
)
],
)
Expand Down

0 comments on commit 4e357d5

Please sign in to comment.