feat(index): move sort logic out to base
hanhxiao committed Sep 4, 2019
1 parent 573d193 commit a2d55dd
Showing 12 changed files with 95 additions and 37 deletions.
5 changes: 4 additions & 1 deletion gnes/cli/parser.py
@@ -203,6 +203,8 @@ def set_router_parser(parser=None):
_set_loadable_service_parser(parser)
parser.add_argument('--num_part', type=int, default=None,
help='explicitly set the number of parts of the message')
parser.add_argument('--sort_response', type=bool, default=True,
help='sort the response (if it exists) by score')
parser.set_defaults(read_only=True)
return parser
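
Note that argparse's `type=bool` does not do what the help text suggests: `bool('False')` is `True`, so any explicit value passed on the command line enables the flag. A minimal sketch of a safer converter (a hypothetical `str2bool` helper, not part of this commit):

    import argparse

    def str2bool(v: str) -> bool:
        # map common spellings explicitly; bool('False') would be True
        if v.lower() in ('yes', 'true', 't', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

    parser = argparse.ArgumentParser()
    parser.add_argument('--sort_response', type=str2bool, default=True,
                        help='sort the response (if it exists) by score')
    print(parser.parse_args(['--sort_response', 'false']).sort_response)  # False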

@@ -213,7 +215,8 @@ def set_indexer_parser(parser=None):
if not parser:
parser = set_base_parser()
_set_loadable_service_parser(parser)

parser.add_argument('--sort_response', type=bool, default=True,
help='sort the response (if it exists) by score')
# encoder's port_out is indexer's port_in
parser.set_defaults(port_in=parser.get_default('port_out'),
port_out=parser.get_default('port_out') + 2,
20 changes: 18 additions & 2 deletions gnes/indexer/base.py
@@ -24,10 +24,19 @@
class BaseIndexer(TrainableBase):
def __init__(self,
normalize_fn: 'BaseScoreFn' = ModifierScoreFn(),
score_fn: 'BaseScoreFn' = ModifierScoreFn(), *args, **kwargs):
score_fn: 'BaseScoreFn' = ModifierScoreFn(),
is_big_score_similar: bool = False,
*args, **kwargs):
"""
Base indexer, a valid indexer must implement `add` and `query` methods
:type score_fn: advanced score function
:type normalize_fn: normalizing score function
:type is_big_score_similar: when set to true, then larger score means more similar
"""
super().__init__(*args, **kwargs)
self.normalize_fn = normalize_fn
self.score_fn = score_fn
self.is_big_score_similar = is_big_score_similar

def add(self, keys: Any, docs: Any, weights: List[float], *args, **kwargs):
pass
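
A minimal sketch (assumed subclasses, not part of this commit) of how the new flag is meant to be set: similarity metrics grow with relevance, distance metrics shrink:

    class CosineChunkIndexer(BaseIndexer):
        def __init__(self, *args, **kwargs):
            # cosine similarity: a larger score means more similar
            super().__init__(*args, is_big_score_similar=True, **kwargs)

    class EuclideanChunkIndexer(BaseIndexer):
        def __init__(self, *args, **kwargs):
            # euclidean distance: a smaller score means more similar
            super().__init__(*args, is_big_score_similar=False, **kwargs)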
@@ -59,7 +68,14 @@ def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, *
r.chunk.doc_id = _doc_id
r.chunk.offset = _offset
r.chunk.weight = _weight
_score = get_unary_score(value=_relevance, name=self.__class__.__name__)
_score = get_unary_score(value=_relevance,
name=self.__class__.__name__,
operands=[
dict(name='doc_chunk',
doc_id=_doc_id,
offset=_offset),
dict(name='query_chunk',
offset=q_chunk.offset)])
_score = self.normalize_fn(_score)
_score = self.score_fn(_score, q_chunk, r.chunk)
r.score.CopyFrom(_score)
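
The score returned by `query_and_score` now records its operands, i.e. which doc chunk was matched against which query chunk, so downstream routers can explain the final ranking. A rough stand-in (plain dicts, not gnes' actual score objects) of the structure being built:

    # plain-dict stand-in for gnes' unary score, for illustration only
    def get_unary_score(value, name, operands=()):
        return {'value': value, 'name': name, 'operands': list(operands)}

    score = get_unary_score(value=0.83, name='BIndexer',
                            operands=[dict(name='doc_chunk', doc_id=7, offset=2),
                                      dict(name='query_chunk', offset=0)])
    print(score['operands'][0]['doc_id'])  # 7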
19 changes: 13 additions & 6 deletions gnes/indexer/chunk/bindexer/__init__.py
@@ -38,9 +38,7 @@ def __init__(self,
self.ef = ef
self.insert_iterations = insert_iterations
self.query_iterations = query_iterations

self.data_path = data_path
self._weight_norm = 2 ** 16 - 1

def post_init(self):
self.bindexer = IndexCore(self.num_bytes, 4, self.ef,
@@ -67,9 +65,18 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
keys, offsets = zip(*keys)
keys = np.array(keys, dtype=np.uint32).tobytes()
offsets = np.array(offsets, dtype=np.uint16).tobytes()
weights = np.array([w * self._weight_norm for w in weights], dtype=np.uint16).tobytes()
weights = self.float2uint_weight(weights).tobytes()
self.bindexer.index_trie(vectors.tobytes(), num_rows, keys, offsets, weights)

@staticmethod
def float2uint_weight(weights: List[float], norm: int = 2 ** 16 - 1):
weights = norm * np.array(weights)
return np.array(weights, dtype=np.uint16)

@staticmethod
def uint2float_weight(weight: int, norm: int = 2 ** 16 - 1):
return weight / norm
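
The two helpers above replace the inline `_weight_norm` math: weights in [0, 1] are scaled by 2**16 - 1 and stored as uint16, so the round-trip loses at most about 1/65535 of precision. A quick sanity-check sketch (values are illustrative):

    import numpy as np

    norm = 2 ** 16 - 1
    weights = [0.0, 0.25, 1.0]                   # assumed to lie in [0, 1]
    packed = np.array(norm * np.array(weights), dtype=np.uint16)
    restored = [w / norm for w in packed]
    # truncation error is bounded by one quantization step
    assert all(abs(a - b) < 1 / norm for a, b in zip(weights, restored))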

def query(self,
keys: np.ndarray,
top_k: int,
@@ -91,15 +98,15 @@ def query(self,
q_idx, doc_ids, offsets, weights = self.bindexer.find_batch_trie(
keys, num_rows)
for (i, q, o, w) in zip(doc_ids, q_idx, offsets, weights):
result[q].append((i, o, w / self._weight_norm, 1))
result[q].append((i, o, self.uint2float_weight(w), 0))

# search the indexed items with similar value
doc_ids, offsets, weights, dists, q_idx = self.bindexer.nsw_search(
keys, num_rows, top_k)
for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx):
if d == 0:
continue
result[q].append((i, o, w / self._weight_norm, d))
result[q].append((i, o, self.uint2float_weight(w), d))

# get the top-k
for q in range(num_rows):
@@ -108,7 +115,7 @@ def query(self,
doc_ids, offsets, weights, dists, q_idx = self.bindexer.force_search(
keys, num_rows, top_k)
for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx):
result[q].append((i, o, w / self._weight_norm, d))
result[q].append((i, o, self.uint2float_weight(w), d))
return result

def __getstate__(self):
14 changes: 11 additions & 3 deletions gnes/indexer/chunk/hbindexer/__init__.py
@@ -37,7 +37,6 @@ def __init__(self,
self.n_clusters = num_clusters
self.n_idx = n_idx
self.data_path = data_path
self._weight_norm = 2 ** 16 - 1
if self.n_idx <= 0:
raise ValueError('There should be at least 1 clustering slot')

@@ -63,11 +62,20 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
keys, offsets = zip(*keys)
keys = np.array(keys, dtype=np.uint32).tobytes()
offsets = np.array(offsets, dtype=np.uint16).tobytes()
weights = np.array(weights * self._weight_norm, dtype=np.uint16).tobytes()
weights = self.float2uint_weight(weights).tobytes()
clusters = vectors[:, :self.n_idx].tobytes()
vectors = vectors[:, self.n_idx:].astype(np.uint8).tobytes()
self.hbindexer.index_trie(vectors, clusters, keys, offsets, weights, n)

@staticmethod
def float2uint_weight(weights: List[float], norm: int = 2 ** 16 - 1):
weights = norm * np.array(weights)
return np.array(weights, dtype=np.uint16)

@staticmethod
def uint2float_weight(weight: int, norm: int = 2 ** 16 - 1):
return weight / norm

def query(self,
vectors: np.ndarray,
top_k: int,
@@ -87,7 +95,7 @@ def query(self,
doc_ids, offsets, weights, dists, q_idx = self.hbindexer.query(
vectors, clusters, n, top_k * self.n_idx)
for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx):
result[q][(i, o, w / self._weight_norm)] = d
result[q][(i, o, self.uint2float_weight(w))] = d

return [list(ret.items()) for ret in result]

2 changes: 1 addition & 1 deletion gnes/indexer/chunk/numpy.py
@@ -52,7 +52,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs
) -> List[List[Tuple]]:
keys = np.expand_dims(keys, axis=1)
dist = keys - np.expand_dims(self._vectors, axis=0)
score = 1 - np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes
score = np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes

ret = []
for ids in score:
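
With the leading `1 -` removed, the numpy indexer now returns a normalized distance in [0, 1] (0 means an exact key match), i.e. a smaller-is-better score. A small sketch of the computation (toy values):

    import numpy as np

    keys = np.array([[1, 2, 3]])                 # one query key
    vectors = np.array([[1, 2, 3], [9, 9, 9]])   # the indexed keys
    num_bytes = keys.shape[-1]

    dist = np.expand_dims(keys, axis=1) - np.expand_dims(vectors, axis=0)
    score = np.sum(np.minimum(np.abs(dist), 1), -1) / num_bytes
    print(score)  # [[0. 1.]] -- exact match scores 0, the far item scores 1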
3 changes: 2 additions & 1 deletion gnes/proto/gnes.proto
@@ -186,7 +186,8 @@ message Response {
Status status = 1;
uint32 top_k = 2;
repeated ScoredResult topk_results = 3;

bool is_big_score_similar = 4;
bool is_sorted = 5;
}
}

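
The two new fields let a response carry its own ordering semantics through the pipeline. A short sketch of how they appear from Python (assuming the generated gnes_pb2 module; proto3 bools default to False):

    from gnes.proto import gnes_pb2

    msg = gnes_pb2.Message()
    msg.response.search.is_big_score_similar = True  # larger score = more similar
    msg.response.search.is_sorted = False            # downstream sorter still has work to do
    print(msg.response.search.is_sorted)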
1 change: 0 additions & 1 deletion gnes/router/__init__.py
@@ -27,7 +27,6 @@
'DocFillReducer': 'reduce',
'PublishRouter': 'map',
'DocBatchRouter': 'map',
'SortedTopkRouter': 'map',
}

register_all_class(_cls2file_map, 'router')
9 changes: 3 additions & 6 deletions gnes/router/base.py
@@ -59,10 +59,9 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *


class BaseTopkReduceRouter(BaseReduceRouter):
def __init__(self, reduce_op: str = 'sum', descending: bool = True, *args, **kwargs):
def __init__(self, reduce_op: str = 'sum', *args, **kwargs):
super().__init__(*args, **kwargs)
self._reduce_op = reduce_op
self.descending = descending

def post_init(self):
self.reduce_op = CombinedScoreFn(score_mode=self._reduce_op)
@@ -80,16 +79,14 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *

# count score by iterating over chunks
for c in all_scored_results:
k = self.get_key(c)
score_dict[k].append(c.score)
score_dict[self.get_key(c)].append(c.score)

for k, v in score_dict.items():
score_dict[k] = self.reduce_op(*v)

msg.response.search.ClearField('topk_results')

# sort and add docs
for k, v in sorted(score_dict.items(), key=lambda kv: kv[1].value, reverse=self.descending):
for k, v in score_dict.items():
r = msg.response.search.topk_results.add()
r.score.CopyFrom(v)
self.set_key(r, k)
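
After this change the router only groups and reduces; the ordering of `topk_results` is left unspecified and the sort happens once, in the service's post_handler. A minimal sketch of the remaining reduce step, with plain tuples standing in for protobuf objects:

    from collections import defaultdict

    def topk_reduce(scored_results, reduce_op=sum):
        score_dict = defaultdict(list)
        for key, score in scored_results:     # key plays the role of get_key(c)
            score_dict[key].append(score)
        # no sorting here any more -- result order is unspecified
        return {k: reduce_op(v) for k, v in score_dict.items()}

    print(topk_reduce([('doc1', 0.5), ('doc2', 0.9), ('doc1', 0.3)]))
    # {'doc1': 0.8, 'doc2': 0.9}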
9 changes: 0 additions & 9 deletions gnes/router/map.py
@@ -20,15 +20,6 @@
from ..proto import gnes_pb2


class SortedTopkRouter(BaseMapRouter):
def __init__(self, descending: bool = True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.descending = descending

def apply(self, msg: 'gnes_pb2.Message', *args, **kwargs):
msg.response.search.topk_results.sort(key=lambda x: x.score.value, reverse=self.descending)


class PublishRouter(BaseMapRouter):

def __init__(self, num_part: int, *args, **kwargs):
13 changes: 11 additions & 2 deletions gnes/router/reduce.py
@@ -19,6 +19,13 @@


class DocFillReducer(BaseReduceRouter):
"""
Gather all documents raw content from multiple shards.
This is only useful when you have
- multiple doc-indexer and docs are spreaded over multiple shards.
- require full-doc retrieval with the original content, not just an doc id
Ideally, only each doc can only belong to one shard.
"""
def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs):
final_docs = []
for idx in range(len(msg.response.search.topk_results)):
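
A toy sketch of the idea (plain dicts stand in for shard responses; ids are illustrative): each shard only knows the content of the docs it owns, and the reducer fills the final top-k list from whichever shard returned the content:

    topk = [101, 7, 42]                      # doc ids from the ranking stage
    shard_a = {101: 'doc-101 body', 42: 'doc-42 body'}
    shard_b = {7: 'doc-7 body'}

    final_docs = []
    for doc_id in topk:
        for shard in (shard_a, shard_b):
            if doc_id in shard:
                final_docs.append((doc_id, shard[doc_id]))
                break                        # ideally each doc lives in one shard
    print(final_docs)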
@@ -45,7 +52,9 @@ def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):

class Chunk2DocTopkReducer(BaseTopkReduceRouter):
"""
Gather all chunks by their doc_id, result in a topk doc list
Gather all chunks by their doc_id, resulting in a top-k doc list.
This is almost always useful, as the final result should be grouped by doc_id,
not by chunk.
"""

def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
@@ -57,7 +66,7 @@ def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):

class ChunkTopkReducer(BaseTopkReduceRouter):
"""
Gather all chunks by their chunk_id, aka doc_id-offset, result in a topk chunk list
Gather all chunks from all shards by their chunk_id (i.e. doc_id-offset), resulting in a top-k chunk list
"""

def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
12 changes: 12 additions & 0 deletions gnes/service/base.py
@@ -257,6 +257,16 @@ def dump(self):
else:
self.logger.info('no dumping as "read_only" set to true.')

def post_handler(self, msg: 'gnes_pb2.Message'):
if 'sort_response' in self.args and self.args.sort_response and msg.response.search.topk_results:
msg.response.search.topk_results.sort(key=lambda x: x.score.value,
reverse=msg.response.search.is_big_score_similar)

msg.response.search.is_sorted = True
self.logger.info('sorted %d results in %s order' %
(len(msg.response.search.topk_results),
'descending' if msg.response.search.is_big_score_similar else 'ascending'))
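
A quick check of the direction logic above: similarity-style scores (is_big_score_similar=True) sort descending, distance-style scores ascending (toy values):

    results = [0.2, 0.9, 0.5]
    for is_big_score_similar in (True, False):
        print(sorted(results, reverse=is_big_score_similar))
    # [0.9, 0.5, 0.2]  -- similarity: best first
    # [0.2, 0.5, 0.9]  -- distance: best first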

def message_handler(self, msg: 'gnes_pb2.Message', out_sock, ctrl_sock):
try:
fn = self.handler.serve(msg)
@@ -273,9 +283,11 @@ def message_handler(self, msg: 'gnes_pb2.Message', out_sock, ctrl_sock):
ret = fn(self, msg)
if ret is None:
# assume 'msg' is modified inside fn()
self.post_handler(msg)
send_message(out_sock, msg, timeout=self.args.timeout)
elif isinstance(ret, types.GeneratorType):
for r_msg in ret:
self.post_handler(r_msg)
send_message(out_sock, r_msg, timeout=self.args.timeout)
else:
raise ServiceError('unknown return type from the handler: %s' % fn)
25 changes: 20 additions & 5 deletions gnes/service/indexer.py
@@ -63,17 +63,23 @@ def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
[d for d in msg.request.index.docs],
[d.weight for d in msg.request.index.docs])

def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'):
msg.response.search.ClearField('topk_results')
msg.response.search.topk_results.extend(results)
msg.response.search.top_k = len(results)
msg.response.search.is_big_score_similar = self._model.is_big_score_similar

@handler.register(gnes_pb2.Request.QueryRequest)
def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
from ..indexer.base import BaseChunkIndexer
if not isinstance(self._model, BaseChunkIndexer):
raise ServiceError(
'unsupported indexer, don\'t know how to use %s to handle this message' % self._model.__bases__)

# assume the chunk search invalidates whatever sort order the message had
msg.response.search.is_sorted = False
results = self._model.query_and_score(msg.request.search.query.chunks, top_k=msg.request.search.top_k)
msg.response.search.ClearField('topk_results')
msg.response.search.topk_results.extend(results)
msg.response.search.top_k = len(results)
self._put_result_into_message(results, msg)

@handler.register(gnes_pb2.Response.QueryResponse)
def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
@@ -82,6 +88,15 @@ def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
raise ServiceError(
'unsupported indexer, don\'t know how to use %s to handle this message' % self._model.__bases__)

# check that the chunk indexer and the doc indexer share the same score ordering
# (proto3 bools default to False and are never None, so compare the flags directly)
if msg.response.search.is_big_score_similar != self._model.is_big_score_similar:
raise ServiceError(
'is_big_score_similar is inconsistent. last topk-list: is_big_score_similar=%s, but '
'this indexer: is_big_score_similar=%s' % (
msg.response.search.is_big_score_similar, self._model.is_big_score_similar))

# assume the doc search invalidates whatever sort order the message had
msg.response.search.is_sorted = False
results = self._model.query_and_score(msg.response.search.topk_results)
msg.response.search.ClearField('topk_results')
msg.response.search.topk_results.extend(results)
self._put_result_into_message(results, msg)
