diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index ab6ad877e1..4986581b7a 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -1071,6 +1071,8 @@ def match( self, deployed_index_id: str, queries: List[List[float]], + per_crowding_attribute_num_neighbors: int, + approx_num_neighbors: int, num_neighbors: int = 1, filter: Optional[List[Namespace]] = [], ) -> List[List[MatchNeighbor]]: @@ -1081,6 +1083,15 @@ def match( Required. The ID of the DeployedIndex to match the queries against. queries (List[List[float]]): Required. A list of queries. Each query is a list of floats, representing a single embedding. + per_crowding_attribute_num_neighbors (int): + Optional. Crowding is a constraint on a neighbor list produced by nearest neighbor + search requiring that no more than some value k' of the k neighbors + returned have the same value of crowding_attribute. + It's used for improving result diversity. + This field is the maximum number of matches with the same crowding tag. + approx_num_neighbors (int): + The number of neighbors to find via approximate search before exact reordering is performed. + If not set, the default value from scam config is used; if set, this value must be > 0. num_neighbors (int): Required. The number of nearest neighbors to be retrieved from database for each query. @@ -1123,6 +1134,8 @@ def match( num_neighbors=num_neighbors, deployed_index_id=deployed_index_id, float_val=query, + per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors, + approx_num_neighbors=approx_num_neighbors, ) for namespace in filter: restrict = match_service_pb2.Namespace() diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 87ab5a9c5f..0d11d16757 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -232,6 +232,8 @@ Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"]) ] _TEST_IDS = ["123", "456", "789"] +_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3 +_TEST_APPROX_NUM_NEIGHBORS = 2 def uuid_mock(): @@ -866,6 +868,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock): queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, filter=_TEST_FILTER, + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, ) batch_request = match_service_pb2.BatchMatchRequest( @@ -884,6 +888,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock): deny_tokens=["token_2"], ) ], + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, ) ], )