diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index a682adc6ff..d00c0a57ef 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -1273,6 +1273,7 @@ def find_neighbors( per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count, approx_num_neighbors=approx_num_neighbors, fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, + numeric_filter=numeric_filter, ) # Create the FindNeighbors request @@ -1456,6 +1457,7 @@ def match( approx_num_neighbors: Optional[int] = None, fraction_leaf_nodes_to_search_override: Optional[float] = None, low_level_batch_size: int = 0, + numeric_filter: Optional[List[NumericNamespace]] = None, ) -> List[List[MatchNeighbor]]: """Retrieves nearest neighbors for the given embedding queries on the specified deployed index for private endpoint only. @@ -1494,6 +1496,11 @@ def match( This field is optional, defaults to 0 if not set. A non-positive number disables low level batching, i.e. all queries are executed sequentially. + numeric_filter (Optional[list[NumericNamespace]]): + Optional. A list of NumericNamespaces for filtering the matching + results. For example: + [NumericNamespace(name="cost", value_int=5, op="GREATER")] + will match datapoints whose cost is greater than 5. Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. 
@@ -1513,6 +1520,7 @@ def match( # Preprocess restricts to be used for each request restricts = [] + # Token restricts if filter: for namespace in filter: restrict = match_service_pb2.Namespace() @@ -1520,6 +1528,22 @@ def match( restrict.allow_tokens.extend(namespace.allow_tokens) restrict.deny_tokens.extend(namespace.deny_tokens) restricts.append(restrict) + numeric_restricts = [] + # Numeric restricts + if numeric_filter: + for numeric_namespace in numeric_filter: + numeric_restrict = match_service_pb2.NumericNamespace() + numeric_restrict.name = numeric_namespace.name + numeric_restrict.op = match_service_pb2.NumericNamespace.Operator.Value( + numeric_namespace.op + ) + if numeric_namespace.value_int is not None: + numeric_restrict.value_int = numeric_namespace.value_int + if numeric_namespace.value_float is not None: + numeric_restrict.value_float = numeric_namespace.value_float + if numeric_namespace.value_double is not None: + numeric_restrict.value_double = numeric_namespace.value_double + numeric_restricts.append(numeric_restrict) requests = [] if queries: @@ -1532,6 +1556,7 @@ def match( per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors, approx_num_neighbors=approx_num_neighbors, fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, + numeric_restricts=numeric_restricts, ) requests.append(request) else: diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 22be012350..dba6424edb 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -237,8 +237,13 @@ ] _TEST_NUMERIC_FILTER = [ NumericNamespace(name="cost", value_double=0.3, op="EQUAL"), - NumericNamespace(name="size", value_int=10, op="GREATER"), - NumericNamespace(name="seconds", value_float=20.5, op="LESS_EQUAL"), + NumericNamespace(name="size", value_int=0, op="GREATER"), + 
NumericNamespace(name="seconds", value_float=-20.5, op="LESS_EQUAL"), +] +_TEST_NUMERIC_NAMESPACE = [ + match_service_pb2.NumericNamespace(name="cost", value_double=0.3, op=3), + match_service_pb2.NumericNamespace(name="size", value_int=0, op=5), + match_service_pb2.NumericNamespace(name="seconds", value_float=-20.5, op=2), ] _TEST_IDS = ["123", "456", "789"] _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3 @@ -1045,7 +1050,7 @@ def test_index_endpoint_match_queries_backward_compatibility( index_endpoint_match_queries_mock.assert_called_with(batch_request) @pytest.mark.usefixtures("get_index_endpoint_mock") - def test_private_service_access_index_endpoint_match_queries( + def test_private_service_access_service_access_index_endpoint_match_queries( self, index_endpoint_match_queries_mock ): aiplatform.init(project=_TEST_PROJECT) @@ -1063,6 +1068,7 @@ def test_private_service_access_index_endpoint_match_queries( approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE, + numeric_filter=_TEST_NUMERIC_FILTER, ) batch_request = match_service_pb2.BatchMatchRequest( @@ -1085,6 +1091,7 @@ def test_private_service_access_index_endpoint_match_queries( per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + numeric_restricts=_TEST_NUMERIC_NAMESPACE, ) for i in range(len(_TEST_QUERIES)) ], @@ -1095,7 +1102,7 @@ def test_private_service_access_index_endpoint_match_queries( index_endpoint_match_queries_mock.assert_called_with(batch_request) @pytest.mark.usefixtures("get_index_endpoint_mock") - def test_private_index_endpoint_find_neighbor_queries( + def test_index_private_service_access_endpoint_find_neighbor_queries( self, index_endpoint_match_queries_mock ): aiplatform.init(project=_TEST_PROJECT) 
@@ -1113,6 +1120,7 @@ def test_private_index_endpoint_find_neighbor_queries( approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, + numeric_filter=_TEST_NUMERIC_FILTER, ) batch_match_request = match_service_pb2.BatchMatchRequest( @@ -1134,6 +1142,7 @@ def test_private_index_endpoint_find_neighbor_queries( per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + numeric_restricts=_TEST_NUMERIC_NAMESPACE, ) for test_query in _TEST_QUERIES ], @@ -1331,10 +1340,10 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering( namespace="cost", value_double=0.3, op="EQUAL" ), gca_index_v1beta1.IndexDatapoint.NumericRestriction( - namespace="size", value_int=10, op="GREATER" + namespace="size", value_int=0, op="GREATER" ), gca_index_v1beta1.IndexDatapoint.NumericRestriction( - namespace="seconds", value_float=20.5, op="LESS_EQUAL" + namespace="seconds", value_float=-20.5, op="LESS_EQUAL" ), ], ),