Skip to content
Permalink
Browse files

fix(indexer): fix empty chunk and dump_interval

  • Loading branch information...
hanxiao committed Oct 10, 2019
1 parent 199a71a commit 72f4a044f5c5fd547d72b9b38f0ecac9a75427a8
Showing with 44 additions and 18 deletions.
  1. +11 −4 gnes/flow/__init__.py
  2. +19 −5 gnes/service/base.py
  3. +13 −8 gnes/service/indexer.py
  4. +1 −1 tests/test_gnes_flow.py
@@ -1,4 +1,5 @@
import copy
import os
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from functools import wraps
@@ -175,6 +176,9 @@ def query(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):

@_build_level(BuildLevel.RUNTIME)
def _call_client(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):

os.unsetenv('http_proxy')
os.unsetenv('https_proxy')
args, p_args = self._get_parsed_args(self, set_client_cli_parser, kwargs)
p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host
@@ -356,19 +360,20 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
#
# when a socket is BIND, then host must NOT be set, aka default host 0.0.0.0
# host_in and host_out is only set when corresponding socket is CONNECT
e_pargs.port_in = s_pargs.port_out

if len(edges_with_same_start) > 1 and len(edges_with_same_end) == 1:
s_pargs.socket_out = SocketType.PUB_BIND
s_pargs.host_out = BaseService.default_host
e_pargs.socket_in = SocketType.SUB_CONNECT
e_pargs.host_in = start_node
e_pargs.port_in = s_pargs.port_out
op_flow._service_edges[k] = 'PUB-sub'
elif len(edges_with_same_end) > 1 and len(edges_with_same_start) == 1:
s_pargs.socket_out = SocketType.PUSH_CONNECT
s_pargs.host_out = end_node
e_pargs.socket_in = SocketType.PULL_BIND
e_pargs.host_in = BaseService.default_host
s_pargs.port_out = e_pargs.port_in
op_flow._service_edges[k] = 'push-PULL'
elif len(edges_with_same_start) == 1 and len(edges_with_same_end) == 1:
# in this case, either side can be BIND
@@ -386,10 +391,12 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
if s_pargs.socket_out.is_bind:
s_pargs.host_out = BaseService.default_host
e_pargs.host_in = start_node
e_pargs.port_in = s_pargs.port_out
op_flow._service_edges[k] = 'PUSH-pull'
elif e_pargs.socket_in.is_bind:
s_pargs.host_out = end_node
e_pargs.host_in = BaseService.default_host
s_pargs.port_out = e_pargs.port_in
op_flow._service_edges[k] = 'push-PULL'
else:
raise FlowTopologyError('edge %s -> %s is ambiguous, at least one socket should be BIND')
@@ -423,7 +430,7 @@ def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *arg
# for thread and process backend which runs locally, host_in and host_out should not be set
p_args.host_in = BaseService.default_host
p_args.host_out = BaseService.default_host
op_flow._service_contexts.append((Flow._service2builder[v['service']], p_args))
op_flow._service_contexts.append(Flow._service2builder[v['service']](p_args))
op_flow._build_level = Flow.BuildLevel.RUNTIME
else:
raise NotImplementedError('backend=%s is not supported yet' % backend)
@@ -440,9 +447,9 @@ def __enter__(self):
'build the flow now via build() with default parameters' % self._build_level)
self.build(copy_flow=False)
self._service_stack = ExitStack()
for k, v in self._service_contexts:
self._service_stack.enter_context(k(v))

for k in self._service_contexts:
self._service_stack.enter_context(k)
self.logger.critical('flow is built and ready, current build level is %s' % self._build_level)
return self

@@ -156,8 +156,13 @@ def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketT

class MessageHandler:
def __init__(self, mh: 'MessageHandler' = None):
    """Create a message handler, optionally cloning an existing one.

    :param mh: an existing handler to clone; its registered routes and
        pre/post hooks are deep-copied so that later mutation of this
        instance (e.g. by ``_override_handler``) never leaks back into
        the template handler shared at class level.
    """
    self.routes = {}
    self.hooks = {'pre': [], 'post': []}

    if mh:
        # deepcopy, not a shallow dict copy: the hook lists are mutable and
        # would otherwise be shared between all cloned handlers
        self.routes = copy.deepcopy(mh.routes)
        self.hooks = copy.deepcopy(mh.hooks)

    self.logger = set_logger(self.__class__.__name__)
    # set by the owning service before the event loop starts (see _run)
    self.service_context = None

@@ -329,6 +334,14 @@ def __init__(self, args):
check_version=self.args.check_version,
timeout=self.args.timeout,
squeeze_pb=self.args.squeeze_pb)
# self._override_handler()

def _override_handler(self):
    """Rebind the handler table from method *names* to bound methods.

    The class-level handler stores route/hook targets as string method
    names; this builds a fresh per-instance ``MessageHandler`` whose
    entries are the corresponding bound methods of ``self``.
    """
    resolved = MessageHandler()
    resolved.routes = {
        msg_type: getattr(self, fn_name)
        for msg_type, fn_name in self.handler.routes.items()
    }
    resolved.hooks = {
        stage: [(getattr(self, entry[0]), entry[1]) for entry in entries]
        for stage, entries in self.handler.hooks.items()
    }
    self.handler = resolved

def run(self):
try:
@@ -341,9 +354,9 @@ def dump(self, respect_dump_interval: bool = True):
and self.args.dump_interval > 0
and self._model
and self.is_model_changed.is_set()
and (respect_dump_interval
and (time.perf_counter() - self.last_dump_time) > self.args.dump_interval)
or not respect_dump_interval):
and ((respect_dump_interval
and (time.perf_counter() - self.last_dump_time) > self.args.dump_interval)
or not respect_dump_interval)):
self.is_model_changed.clear()
self.logger.info('dumping changes to the model, %3.0fs since last the dump'
% (time.perf_counter() - self.last_dump_time))
@@ -385,6 +398,7 @@ def _hook_update_route_timestamp(self, msg: 'gnes_pb2.Message', *args, **kwargs)
def _run(self, ctx):
ctx.setsockopt(zmq.LINGER, 0)
self.handler.service_context = self
# print('!!!! t_id: %d service_context: %r' % (threading.get_ident(), self.handler.service_context))
self.logger.info('bind sockets...')
in_sock, _ = build_socket(ctx, self.args.host_in, self.args.port_in, self.args.socket_in,
self.args.identity)
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading

import numpy as np

@@ -25,10 +25,16 @@ class IndexerService(BS):

def post_init(self):
    """Load the indexer model after the service process/thread starts.

    Imported lazily so the indexer dependency is only paid by indexer
    services.
    """
    from ..indexer.base import BaseIndexer
    self._model = self.load_model(BaseIndexer)
    # NOTE(review): debug-only field — it was read solely by (now removed)
    # commented-out thread-identity prints; candidate for deletion
    self._tmp_a = threading.get_ident()

@handler.register(gnes_pb2.Request.IndexRequest)
def _handler_index(self, msg: 'gnes_pb2.Message'):
# print('tid: %s, model: %r, self._tmp_a: %r' % (threading.get_ident(), self._model, self._tmp_a))
# if self._tmp_a != threading.get_ident():
# print('tid: %s, tmp_a: %r !!! %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
from ..indexer.base import BaseChunkIndexer, BaseDocIndexer
if isinstance(self._model, BaseChunkIndexer):
self._handler_chunk_index(msg)
@@ -41,22 +47,21 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
self.is_model_changed.set()

def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
    """Add chunk-level embeddings from every doc in the request to the model.

    Robust to degenerate input: documents with no chunks, and chunks
    whose embedding carries no data, are skipped with a warning instead
    of crashing ``np.stack`` on an empty/ragged list.

    :param msg: the incoming index-request message
    """
    embed_info = []

    for d in msg.request.index.docs:
        if not d.chunks:
            self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
            continue

        # keep only chunks that actually carry an embedded vector
        embed_info += [(blob2array(c.embedding), d.doc_id, c.offset, c.weight) for c in d.chunks if
                       c.embedding.data]

    if embed_info:
        vecs, doc_ids, offsets, weights = zip(*embed_info)
        self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
    else:
        self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing')

def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
self._model.add([d.doc_id for d in msg.request.index.docs],
@@ -130,7 +130,7 @@ def _test_index_flow(self):
with flow.build(backend='thread') as f:
f.index(txt_file=self.test_file, batch_size=20)

for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
for k in [self.indexer1_bin, self.indexer2_bin]:
self.assertTrue(os.path.exists(k))

def _test_query_flow(self):

0 comments on commit 72f4a04

Please sign in to comment.
You can’t perform that action at this time.