Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Search By ID #752

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 64 additions & 56 deletions grpc-proto/gen/milvus_pb2.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions grpc-proto/milvus.proto
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ message SearchRequest {
repeated common.KeyValuePair search_params = 9; // must
uint64 travel_timestamp = 10;
uint64 guarantee_timestamp = 11; // guarantee_timestamp
schema.IDs searchIDs = 12; // search by ids
}

message Hits {
Expand Down
13 changes: 13 additions & 0 deletions pymilvus/client/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ def is_legal_search_data(data: Any) -> bool:

return True

def is_legal_search_ids(ids: Any) -> bool:
    """Return True when *ids* is a list made up solely of integer ids.

    An empty list is accepted. ``bool`` values are rejected even though ``bool``
    is a subclass of ``int``, so ``True``/``False`` cannot silently be used as
    the ids 1/0.

    :param ids: The candidate value passed as ``search_ids``.
    :return: True if *ids* is a list of (non-bool) ints, False otherwise.
    """
    if not isinstance(ids, list):
        return False

    return all(
        isinstance(entity_id, int) and not isinstance(entity_id, bool)
        for entity_id in ids
    )


def is_legal_output_fields(output_fields: Any) -> bool:
if output_fields is None:
Expand Down Expand Up @@ -320,6 +330,9 @@ def check_pass_param(*_args: Any, **kwargs: Any) -> None: # pylint: disable=too
elif key in ("search_data",):
if not is_legal_search_data(value):
_raise_param_error(key, value)
elif key in ("search_ids",):
if not is_legal_search_ids(value):
_raise_param_error(key, value)
elif key in ("output_fields",):
if not is_legal_output_fields(value):
_raise_param_error(key, value)
Expand Down
18 changes: 18 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,24 @@ def search(self, collection_name, data, anns_field, param, limit,
_kwargs["auto_id"] = auto_id
_kwargs["round_decimal"] = round_decimal
return self._execute_search_requests(requests, timeout, **_kwargs)

@error_handler(None)
@check_has_collection
def search_by_id(self, collection_name, search_ids, anns_field, param, limit,
                 expression=None, partition_names=None, output_fields=None,
                 timeout=None, round_decimal=-1, **kwargs):
    """Run a search on ``collection_name`` whose queries are given as ids.

    The collection is described first; its description is handed to
    ``Prepare.search_requests_with_ids`` as ``schema`` to build the requests,
    which are then passed to ``self._execute_search_requests`` together with
    the collection's ``auto_id`` flag and ``round_decimal``.
    """
    # Work on a copy so the caller's kwargs are never modified.
    exec_kwargs = copy.deepcopy(kwargs)
    description = self.describe_collection(collection_name, timeout)
    auto_id = description["auto_id"]

    # "schema" is only needed while the requests are being built.
    exec_kwargs["schema"] = description
    requests = Prepare.search_requests_with_ids(
        collection_name, search_ids, anns_field, param, limit, expression,
        partition_names, output_fields, round_decimal, **exec_kwargs)
    del exec_kwargs["schema"]

    exec_kwargs.update(auto_id=auto_id, round_decimal=round_decimal)
    return self._execute_search_requests(requests, timeout, **exec_kwargs)


@error_handler(None)
def get_query_segment_infos(self, collection_name, timeout=30, **kwargs):
Expand Down
55 changes: 55 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,61 @@ def dump(v):

return requests

@classmethod
def search_requests_with_ids(cls, collection_name, search_ids, anns_field, param, limit, expr=None,
                             partition_names=None,
                             output_fields=None, round_decimal=-1, **kwargs):
    """Build the search request for a search whose queries are given as entity ids.

    :param collection_name: Name of the collection to search.
    :param search_ids: List of integer ids, copied into the request's
        ``searchIDs.int_id`` field. An empty list produces no request.
    :param anns_field: Name of the field to search; it must be present in the
        ``fields`` of the schema.
    :param param: Search parameters. ``metric_type`` (default ``"L2"``) and
        ``params`` (default ``{}``) are read from a deep copy of it.
    :param limit: Sent to the server as ``topk``.
    :param expr: Optional boolean expression, used as the request DSL when not None.
    :param partition_names: Forwarded to the request unchanged.
    :param output_fields: Forwarded to the request unchanged.
    :param round_decimal: Sent along with the other search parameters.
    :param kwargs: Must carry ``schema``, the collection description holding a
        ``fields`` list of dicts that each have a ``name``.
    :return: A list with one ``SearchRequest``, or an empty list when
        ``search_ids`` is empty.
    :raises ParamError: If ``search_ids`` is not a list of ints, the schema is
        missing, ``anns_field`` is unknown, or ``param["params"]`` is not a dict.
    """
    ids_error = "search ids array is empty or not a list or ids are not int type"

    # Validate the container before calling len() on it, so bad input surfaces
    # as ParamError instead of a TypeError.
    if not isinstance(search_ids, list):
        raise ParamError(ids_error)

    if not search_ids:
        return []

    # Check every id, not only the first one; bool is excluded because it is a
    # subclass of int and would otherwise be sent as id 0/1.
    if not all(isinstance(i, int) and not isinstance(i, bool) for i in search_ids):
        raise ParamError(ids_error)

    schema = kwargs.get("schema", None)
    if not schema:
        raise ParamError("Collection schema is required to build a search request")

    # TODO: add MaxSearchResultSize check

    field_names = {field["name"] for field in schema.get("fields", None) or []}
    if anns_field not in field_names:
        raise ParamError(f"Field {anns_field} doesn't exist in schema")

    # Copy before popping so the caller's param dict stays intact.
    param_copy = copy.deepcopy(param)
    metric_type = param_copy.pop("metric_type", "L2")
    params = param_copy.pop("params", {})
    if not isinstance(params, dict):
        raise ParamError("Search params must be a dict")
    search_params = {"anns_field": anns_field, "topk": limit, "metric_type": metric_type, "params": params,
                     "round_decimal": round_decimal}

    def dump(v):
        # Nested dicts travel as JSON text, everything else as its str() form.
        if isinstance(v, dict):
            return ujson.dumps(v)
        return str(v)

    request = milvus_types.SearchRequest(
        collection_name=collection_name,
        partition_names=partition_names,
        output_fields=output_fields,
    )

    request.dsl_type = common_types.DslType.BoolExprV1
    if expr is not None:
        request.dsl = expr
    request.search_params.extend([common_types.KeyValuePair(key=str(key), value=dump(value))
                                  for key, value in search_params.items()])

    request.searchIDs.int_id.data.extend(search_ids)

    return [request]

@classmethod
def create_alias_request(cls, collection_name, alias):
return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias)
Expand Down
54 changes: 54 additions & 0 deletions pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,60 @@ def search(self, collection_name, data, anns_field, param, limit, expression=Non
kwargs["_deploy_mode"] = self._deploy_mode
return handler.search(collection_name, data, anns_field, param, limit, expression,
partition_names, output_fields, timeout, round_decimal, **kwargs)

@retry_on_rpc_failure(retry_times=10, wait=1)
def search_by_id(self, collection_name, search_ids, anns_field, param, limit, expression=None, partition_names=None,
                 output_fields=None, timeout=None, round_decimal=-1, **kwargs):
    """
    Searches a collection using the vectors identified by ``search_ids`` as queries.

    The arguments are validated with ``check_pass_param`` and the call is then
    delegated to the connection handler's ``search_by_id``.

    :param collection_name: The name of the collection to search.
    :type  collection_name: str
    :param search_ids: List of ids of the vectors to search, the length of search_ids is number of query (nq).
    :type  search_ids: list[int]
    :param anns_field: The vector field of the collection to search on.
    :type  anns_field: str
    :param param: The parameters of search, such as nprobe, etc.
    :type  param: dict
    :param limit: The max number of returned records, also called topk.
    :type  limit: int
    :param expression: The boolean expression used to filter attributes.
    :type  expression: str
    :param partition_names: The names of partitions to search.
    :type  partition_names: list[str]
    :param output_fields: The fields to return in the search result, not supported now.
    :type  output_fields: list[str]
    :param timeout: An optional duration of time in seconds to allow for the RPC. When timeout
                    is set to None, client waits until server response or error occur.
    :type  timeout: float
    :param round_decimal: The specified number of decimal places of returned distance.
    :type  round_decimal: int
    :param kwargs:
        * *_async* (``bool``) --
          Indicate if invoke asynchronously. When value is true, method returns a SearchFuture object;
          otherwise, method returns results from server.
        * *_callback* (``function``) --
          The callback function which is invoked after server response successfully. It only takes
          effect when _async is set to True.

    :return: Query result. QueryResult is iterable and is a 2d-array-like class, the first dimension is
             the number of vectors to query (nq), the second dimension is the number of limit(topk).
    :rtype: QueryResult

    :raises RpcError: If gRPC encounters an error
    :raises ParamError: If parameters are invalid
    :raises BaseException: If the return result from server is not ok
    """
    # Insertion order is kept, so the checks run in the order listed here.
    to_validate = {
        "limit": limit,
        "round_decimal": round_decimal,
        "anns_field": anns_field,
        "search_ids": search_ids,
        "partition_name_array": partition_names,
        "output_fields": output_fields,
    }
    check_pass_param(**to_validate)

    with self._connection() as handler:
        kwargs.update(_deploy_mode=self._deploy_mode)
        result = handler.search_by_id(collection_name, search_ids, anns_field, param, limit, expression,
                                      partition_names, output_fields, timeout, round_decimal, **kwargs)
        return result


@retry_on_rpc_failure(retry_times=10, wait=1)
def calc_distance(self, vectors_left, vectors_right, params=None, timeout=None, **kwargs):
Expand Down
120 changes: 64 additions & 56 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

84 changes: 84 additions & 0 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,90 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)

def search_by_id(self, search_ids, anns_field, param, limit, expr=None, partition_names=None,
                 output_fields=None, timeout=None, round_decimal=-1, **kwargs):
    """
    Conducts a vector similarity search that uses the vectors identified by
    ``search_ids`` as queries, with an optional boolean expression as filter.

    :param search_ids: List of ids of the vectors to search, the length of search_ids is number of query (nq).
    :type  search_ids: list[int]
    :param anns_field: The vector field of the collection to search on.
    :type  anns_field: str
    :param param: The parameters of search, such as ``nprobe``.
    :type  param: dict
    :param limit: The max number of returned records, also known as ``topk``.
    :type  limit: int
    :param expr: The boolean expression used to filter attributes.
    :type  expr: str
    :param partition_names: The names of partitions to search.
    :type  partition_names: list[str]
    :param output_fields: The fields to return in the search result, not supported now.
    :type  output_fields: list[str]
    :param timeout: An optional duration of time in seconds to allow for the RPC. When timeout
                    is set to None, client waits until server response or error occur.
    :type  timeout: float
    :param round_decimal: The specified number of decimal places of returned distance.
    :type  round_decimal: int
    :param kwargs:
        * *_async* (``bool``) --
          Indicate if invoke asynchronously. When value is true, method returns a
          SearchFuture object; otherwise, method returns results from server directly.
        * *_callback* (``function``) --
          The callback function which is invoked after server response successfully.
          It functions only if _async is set to True.

    :return: SearchResult:
        SearchResult is iterable and is a 2d-array-like class, the first dimension is
        the number of vectors to query (nq), the second dimension is the number of limit(topk).
    :rtype: SearchResult

    :raises RpcError: If gRPC encounters an error.
    :raises ParamError: If parameters are invalid.
    :raises DataTypeNotMatchException: If wrong type of param is passed.
    :raises BaseException: If the return result from server is not ok.

    :example:
        >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
        >>> import random
        >>> connections.connect()
        <pymilvus.client.stub.Milvus object at 0x7f8579002dc0>
        >>> schema = CollectionSchema([
        ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
        ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
        ... ])
        >>> collection = Collection("test_collection_search", schema)
        >>> # insert
        >>> data = [
        ...     [i for i in range(10)],
        ...     [[random.random() for _ in range(2)] for _ in range(10)],
        ... ]
        >>> collection.insert(data)
        >>> collection.num_entities
        10
        >>> collection.load()
        >>> # search
        >>> search_param = {
        ...     "search_ids": [1],
        ...     "anns_field": "films",
        ...     "param": {"metric_type": "L2"},
        ...     "limit": 2,
        ...     "expr": "film_id > 0",
        ... }
        >>> res = collection.search_by_id(**search_param)
        >>> assert len(res) == 1
        >>> hits = res[0]
        >>> assert len(hits) == 2
        >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ")
        - Total hits: 2, hits ids: [1, 9]
        >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ")
        - Top1 hit id: 1, distance: 0.0, score: 0.0
    """
    # Only None or a string is accepted as the filter expression.
    if not (expr is None or isinstance(expr, str)):
        raise DataTypeNotMatchException(0, ExceptionsMessage.ExprType % type(expr))

    connection = self._get_connection()
    raw_result = connection.search_by_id(self._name, search_ids, anns_field, param, limit, expr,
                                         partition_names, output_fields, timeout, round_decimal, **kwargs)

    # Asynchronous calls are wrapped in a future, synchronous ones in a result.
    wrapper = SearchFuture if kwargs.get("_async", False) else SearchResult
    return wrapper(raw_result)

def query(self, expr, output_fields=None, partition_names=None, timeout=None):
"""
Expand Down
79 changes: 79 additions & 0 deletions pymilvus/orm/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,85 @@ def search(self, data, anns_field, param, limit, expr=None, output_fields=None,
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)

def search_by_id(self, search_ids, anns_field, param, limit, expr=None, output_fields=None, timeout=None, round_decimal=-1,
                 **kwargs):
    """
    Vector similarity search inside this partition that uses the vectors
    identified by ``search_ids`` as queries, with an optional boolean
    expression as filter.

    :param search_ids: List of ids of the vectors to search, the length of search_ids is number of query (nq).
    :type  search_ids: list[int]
    :param anns_field: The vector field of the collection to search on.
    :type  anns_field: str
    :param param: The parameters of search, such as nprobe, etc.
    :type  param: dict
    :param limit: The max number of returned records, also called topk.
    :type  limit: int
    :param round_decimal: The specified number of decimal places of returned distance.
    :type  round_decimal: int
    :param expr: The boolean expression used to filter attributes.
    :type  expr: str
    :param output_fields: The fields to return in the search result, not supported now.
    :type  output_fields: list[str]
    :param timeout: An optional duration of time in seconds to allow for the RPC. When timeout
                    is set to None, client waits until server response or error occur.
    :type  timeout: float
    :param kwargs:
        * *_async* (``bool``) --
          Indicate if invoke asynchronously. When value is true, method returns a
          SearchFuture object; otherwise, method returns results from server directly.
        * *_callback* (``function``) --
          The callback function which is invoked after server response successfully. It only
          takes effect when _async is set to True.

    :return: SearchResult:
        SearchResult is iterable and is a 2d-array-like class, the first dimension is
        the number of vectors to query (nq), the second dimension is the number of limit(topk).
    :rtype: SearchResult

    :raises RpcError: If gRPC encounters an error.
    :raises ParamError: If parameters are invalid.
    :raises BaseException: If the return result from server is not ok.

    :example:
        >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType
        >>> import random
        >>> connections.connect()
        <pymilvus.client.stub.Milvus object at 0x7f8579002dc0>
        >>> schema = CollectionSchema([
        ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
        ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
        ... ])
        >>> collection = Collection("test_collection_search", schema)
        >>> partition = Partition(collection, "comedy", "comedy films")
        >>> # insert
        >>> data = [
        ...     [i for i in range(10)],
        ...     [[random.random() for _ in range(2)] for _ in range(10)],
        ... ]
        >>> partition.insert(data)
        >>> partition.num_entities
        10
        >>> partition.load()
        >>> # search
        >>> search_param = {
        ...     "search_ids": [1],
        ...     "anns_field": "films",
        ...     "param": {"metric_type": "L2"},
        ...     "limit": 2,
        ...     "expr": "film_id > 0",
        ... }
        >>> res = partition.search_by_id(**search_param)
        >>> assert len(res) == 1
        >>> hits = res[0]
        >>> assert len(hits) == 2
        >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ")
        - Total hits: 2, hits ids: [1, 5]
        >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ")
        - Top1 hit id: 1, distance: 0.0, score: 0.0
    """
    connection = self._get_connection()
    # The search is restricted to this partition by passing its name as the
    # only entry of the partition list.
    only_this_partition = [self._name]
    raw_result = connection.search_by_id(self._collection.name, search_ids, anns_field, param, limit,
                                         expr, only_this_partition, output_fields, timeout, round_decimal, **kwargs)

    # Asynchronous calls are wrapped in a future, synchronous ones in a result.
    wrapper = SearchFuture if kwargs.get("_async", False) else SearchResult
    return wrapper(raw_result)


def query(self, expr, output_fields=None, timeout=None):
"""
Expand Down