Skip to content

Commit

Permalink
feat: Add numeric_filter to MatchingEngineIndexEndpoint match()
Browse files Browse the repository at this point in the history
… and `find_neighbors()` private endpoint queries.

PiperOrigin-RevId: 602661540
  • Loading branch information
lingyinw authored and Copybara-Service committed Jan 30, 2024
1 parent 512b82d commit 679646a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
Expand Up @@ -1273,6 +1273,7 @@ def find_neighbors(
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
numeric_filter=numeric_filter,
)

# Create the FindNeighbors request
Expand Down Expand Up @@ -1456,6 +1457,7 @@ def match(
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
low_level_batch_size: int = 0,
numeric_filter: Optional[List[NumericNamespace]] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index for private endpoint only.
Expand Down Expand Up @@ -1494,6 +1496,11 @@ def match(
This field is optional, defaults to 0 if not set. A non-positive
number disables low level batching, i.e. all queries are
executed sequentially.
numeric_filter (Optional[list[NumericNamespace]]):
Optional. A list of NumericNamespaces for filtering the matching
results. For example:
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
will match datapoints whose cost is greater than 5.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
Expand All @@ -1513,13 +1520,30 @@ def match(

# Preprocess restricts to be used for each request
restricts = []
# Token restricts
if filter:
for namespace in filter:
restrict = match_service_pb2.Namespace()
restrict.name = namespace.name
restrict.allow_tokens.extend(namespace.allow_tokens)
restrict.deny_tokens.extend(namespace.deny_tokens)
restricts.append(restrict)
numeric_restricts = []
# Numeric restricts
if numeric_filter:
for numeric_namespace in numeric_filter:
numeric_restrict = match_service_pb2.NumericNamespace()
numeric_restrict.name = numeric_namespace.name
numeric_restrict.op = match_service_pb2.NumericNamespace.Operator.Value(
numeric_namespace.op
)
if numeric_namespace.value_int is not None:
numeric_restrict.value_int = numeric_namespace.value_int
if numeric_namespace.value_float is not None:
numeric_restrict.value_float = numeric_namespace.value_float
if numeric_namespace.value_double is not None:
numeric_restrict.value_double = numeric_namespace.value_double
numeric_restricts.append(numeric_restrict)

requests = []
if queries:
Expand All @@ -1532,6 +1556,7 @@ def match(
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
numeric_restricts=numeric_restricts,
)
requests.append(request)
else:
Expand Down
21 changes: 15 additions & 6 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Expand Up @@ -237,8 +237,13 @@
]
_TEST_NUMERIC_FILTER = [
NumericNamespace(name="cost", value_double=0.3, op="EQUAL"),
NumericNamespace(name="size", value_int=10, op="GREATER"),
NumericNamespace(name="seconds", value_float=20.5, op="LESS_EQUAL"),
NumericNamespace(name="size", value_int=0, op="GREATER"),
NumericNamespace(name="seconds", value_float=-20.5, op="LESS_EQUAL"),
]
# Expected protobuf form of _TEST_NUMERIC_FILTER, entry for entry, as the
# private-endpoint tests expect it in `numeric_restricts` of each request.
# `op` is given here as the raw enum number; the pairing with
# _TEST_NUMERIC_FILTER implies 3 <-> "EQUAL", 5 <-> "GREATER",
# 2 <-> "LESS_EQUAL" (NOTE(review): confirm against the
# match_service_pb2.NumericNamespace.Operator definition).
_TEST_NUMERIC_NAMESPACE = [
match_service_pb2.NumericNamespace(name="cost", value_double=0.3, op=3),
# value_int=0 is falsy but must still be carried through to the request.
match_service_pb2.NumericNamespace(name="size", value_int=0, op=5),
match_service_pb2.NumericNamespace(name="seconds", value_float=-20.5, op=2),
]
_TEST_IDS = ["123", "456", "789"]
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
Expand Down Expand Up @@ -1045,7 +1050,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_private_service_access_index_endpoint_match_queries(
def test_private_service_access_service_access_index_endpoint_match_queries(
self, index_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)
Expand All @@ -1063,6 +1068,7 @@ def test_private_service_access_index_endpoint_match_queries(
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
numeric_filter=_TEST_NUMERIC_FILTER,
)

batch_request = match_service_pb2.BatchMatchRequest(
Expand All @@ -1085,6 +1091,7 @@ def test_private_service_access_index_endpoint_match_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
numeric_restricts=_TEST_NUMERIC_NAMESPACE,
)
for i in range(len(_TEST_QUERIES))
],
Expand All @@ -1095,7 +1102,7 @@ def test_private_service_access_index_endpoint_match_queries(
index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_private_index_endpoint_find_neighbor_queries(
def test_index_private_service_access_endpoint_find_neighbor_queries(
self, index_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)
Expand All @@ -1113,6 +1120,7 @@ def test_private_index_endpoint_find_neighbor_queries(
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
numeric_filter=_TEST_NUMERIC_FILTER,
)

batch_match_request = match_service_pb2.BatchMatchRequest(
Expand All @@ -1134,6 +1142,7 @@ def test_private_index_endpoint_find_neighbor_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
numeric_restricts=_TEST_NUMERIC_NAMESPACE,
)
for test_query in _TEST_QUERIES
],
Expand Down Expand Up @@ -1331,10 +1340,10 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering(
namespace="cost", value_double=0.3, op="EQUAL"
),
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
namespace="size", value_int=10, op="GREATER"
namespace="size", value_int=0, op="GREATER"
),
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
namespace="seconds", value_float=20.5, op="LESS_EQUAL"
namespace="seconds", value_float=-20.5, op="LESS_EQUAL"
),
],
),
Expand Down

0 comments on commit 679646a

Please sign in to comment.