Skip to content

Commit

Permalink
Support search by id
Browse files Browse the repository at this point in the history
Signed-off-by: Jellal-HT <ychuht333@gmail.com>
Signed-off-by: Yicheng Hu <ychuht333@gmail.com>
  • Loading branch information
Jellal-HT committed Oct 19, 2021
1 parent fff6b96 commit bade352
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 1 deletion.
9 changes: 9 additions & 0 deletions grpc-proto/gen/milvus_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions grpc-proto/gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def __init__(self, channel):
request_serializer=milvus__pb2.GetMetricsRequest.SerializeToString,
response_deserializer=milvus__pb2.GetMetricsResponse.FromString,
)
self.GetVectorsByID = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/GetVectorsByID',
request_serializer=milvus__pb2.VectorIDs.SerializeToString,
response_deserializer=milvus__pb2.VectorsArray.FromString,
)


class MilvusServiceServicer(object):
Expand Down Expand Up @@ -395,6 +400,11 @@ def GetMetrics(self, request, context):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetVectorsByID(self, request, context):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_MilvusServiceServicer_to_server(servicer, server):
Expand Down Expand Up @@ -569,6 +579,11 @@ def add_MilvusServiceServicer_to_server(servicer, server):
request_deserializer=milvus__pb2.GetMetricsRequest.FromString,
response_serializer=milvus__pb2.GetMetricsResponse.SerializeToString,
),
'GetVectorsByID': grpc.unary_unary_rpc_method_handler(
servicer.GetVectorsByID,
request_deserializer=milvus__pb2.VectorIDs.FromString,
response_serializer=milvus__pb2.VectorsArray.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'milvus.proto.milvus.MilvusService', rpc_method_handlers)
Expand Down Expand Up @@ -1156,6 +1171,23 @@ def GetMetrics(request,
milvus__pb2.GetMetricsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetVectorsByID(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/GetVectorsByID',
milvus__pb2.GetMetricsRequest.SerializeToString,
milvus__pb2.GetMetricsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)


class ProxyServiceStub(object):
Expand Down
1 change: 1 addition & 0 deletions grpc-proto/milvus.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ service MilvusService {
rpc Flush(FlushRequest) returns (FlushResponse) {}
rpc Query(QueryRequest) returns (QueryResults) {}
rpc CalcDistance(CalcDistanceRequest) returns (CalcDistanceResults) {}
rpc GetVectorsByID(VectorIDs) returns (VectorsArray) {}

rpc GetPersistentSegmentInfo(GetPersistentSegmentInfoRequest) returns (GetPersistentSegmentInfoResponse) {}
rpc GetQuerySegmentInfo(GetQuerySegmentInfoRequest) returns (GetQuerySegmentInfoResponse) {}
Expand Down
38 changes: 38 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,44 @@ def search(self, collection_name, data, anns_field, param, limit,
_kwargs["auto_id"] = auto_id
_kwargs["round_decimal"] = round_decimal
return self._execute_search_requests(requests, timeout, **_kwargs)

def search_by_id(self, collection_name, query_id, anns_field, param, limit,
expression = None, partition_tags=None, output_fields = None,
timeout=None, round_decimal=-1, **kwargs):

## first part of the method: get the vector by id
rf = self._stub.HasCollection.future(Prepare.has_collection_request(collection_name), wait_for_ready=True,
timeout=timeout)
reply = rf.result()
if reply.status.error_code != 0 or not reply.value:
raise CollectionNotExistException(reply.status.error_code, "collection not exists")

request = milvus_types.VectorIDs(collection_name=collection_name, field_name = None, id_array=query_id,
partition_names=partition_tags)

future = self._stub.GetVectorsByID.future(request, wait_for_ready=True, timeout=timeout)
response = future.result()
## variable that stores the vector corresponding to the id
vector = list()
if response.data_array == None:
print("can not obtain vector")
return
else:
for datas in response.data_array:
data = bytes(datas.binary_data) or list(datas.float_data)
## vector corresponding to the query_id
vector.append(data)

## second part of the method: do vector querying
_kwargs = copy.deepcopy(kwargs)
schema = self.self.describe_collection(collection_name, timeout)
_kwargs["schema"] = schema
_kwargs["auto_id"] = schema["auto_id"]
_kwargs["round_decimal"] = round_decimal
requests = Prepare.search_requests_with_expr(collection_name, vector, anns_field, param, limit, expression,
partition_tags, output_fields, round_decimal, **_kwargs)
return self._execute_search_requests(requests, timeout, **_kwargs)


@error_handler(None)
def get_query_segment_infos(self, collection_name, timeout=30, **kwargs):
Expand Down
9 changes: 9 additions & 0 deletions pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,15 @@ def search(self, collection_name, data, anns_field, param, limit, expression=Non
kwargs["_deploy_mode"] = self._deploy_mode
return handler.search(collection_name, data, anns_field, param, limit, expression,
partition_names, output_fields, timeout, round_decimal, **kwargs)

def search_by_id(self, collection_name, query_id, anns_field, param, limit,
expression = None, partition_tags=None, output_fields = None,
timeout=None, round_decimal=-1, **kwargs):
with self._connection() as handler:
kwargs["_deploy_mode"] = self._deploy_mode
return handler.search_by_id(collection_name, query_id, anns_field, param, limit, expression,
partition_tags, output_fields, timeout, round_decimal, **kwargs)


@retry_on_rpc_failure(retry_times=10, wait=1)
def calc_distance(self, vectors_left, vectors_right, params=None, timeout=None, **kwargs):
Expand Down
9 changes: 9 additions & 0 deletions pymilvus/grpc_gen/milvus_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions pymilvus/grpc_gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def __init__(self, channel):
request_serializer=milvus__pb2.GetMetricsRequest.SerializeToString,
response_deserializer=milvus__pb2.GetMetricsResponse.FromString,
)
self.GetVectorsByID = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/GetVectorsByID',
request_serializer=milvus__pb2.VectorIDs.SerializeToString,
response_deserializer=milvus__pb2.VectorsArray.FromString,
)


class MilvusServiceServicer(object):
Expand Down Expand Up @@ -395,6 +400,12 @@ def GetMetrics(self, request, context):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetVectorsByID(self, request, context):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')



def add_MilvusServiceServicer_to_server(servicer, server):
Expand Down Expand Up @@ -569,6 +580,11 @@ def add_MilvusServiceServicer_to_server(servicer, server):
request_deserializer=milvus__pb2.GetMetricsRequest.FromString,
response_serializer=milvus__pb2.GetMetricsResponse.SerializeToString,
),
'GetVectorsByID': grpc.unary_unary_rpc_method_handler(
servicer.GetVectorsByID,
request_deserializer=milvus__pb2.VectorIDs.FromString,
response_serializer=milvus__pb2.VectorsArray.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'milvus.proto.milvus.MilvusService', rpc_method_handlers)
Expand Down Expand Up @@ -1156,6 +1172,23 @@ def GetMetrics(request,
milvus__pb2.GetMetricsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

def GetVectorsByID(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/GetVectorsByID',
milvus__pb2.GetMetricsRequest.SerializeToString,
milvus__pb2.GetMetricsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)



class ProxyServiceStub(object):
Expand Down
16 changes: 15 additions & 1 deletion pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,20 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)

def searchByID(self, query_id, anns_field, param, limit, expr=None, partition_names=None,
output_fields=None, timeout=None, round_decimal=-1, **kwargs):

if expr is not None and not isinstance(expr, str):
raise DataTypeNotMatchException(0, ExceptionsMessage.ExprType % type(expr))

conn = self._get_connection()
res = conn.search_by_id(self._name, query_id, anns_field, param, limit, expr,
partition_names, output_fields, timeout, round_decimal, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)


def query(self, expr, output_fields=None, partition_names=None, timeout=None):
"""
Expand Down Expand Up @@ -1200,4 +1214,4 @@ def alter_alias(self, alias, timeout=None, **kwargs):
otherwise return Status(code=1, message='alias does not exist')
"""
conn = self._get_connection()
conn.alter_alias(self._name, alias, timeout=timeout, **kwargs)
conn.alter_alias(self._name, alias, timeout=timeout, **kwargs)
10 changes: 10 additions & 0 deletions pymilvus/orm/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,16 @@ def search(self, data, anns_field, param, limit, expr=None, output_fields=None,
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)

def searchByID(self, query_id, anns_field, param, limit, expr=None, output_fields=None, timeout=None, round_decimal=-1,
**kwargs):
conn = self._get_connection()
res = conn.search_by_id(self._collection.name, query_id, anns_field, param, limit,
expr, [self._name], output_fields, timeout, round_decimal, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)


def query(self, expr, output_fields=None, timeout=None):
"""
Expand Down

0 comments on commit bade352

Please sign in to comment.