refactor ventilator-worker socket
hanhxiao committed Dec 18, 2018
1 parent 94d9065 commit eafb30f
Showing 2 changed files with 101 additions and 45 deletions.
86 changes: 41 additions & 45 deletions server/bert_serving/server/__init__.py
@@ -3,6 +3,7 @@
# Han Xiao <artex.xh@gmail.com> <https://hanxiao.github.io>
import multiprocessing
import os
import random
import sys
import threading
import time
@@ -18,11 +19,13 @@
from zmq.utils import jsonapi

from .helper import *
from .zmq_decor import multi_socket

__all__ = ['__version__', 'BertServer']
__version__ = '1.5.9'

_tf_ver_ = check_tf_version()
_num_socket_ = 8 # optimize concurrency for multi-clients


class ServerCommand:
@@ -78,23 +81,19 @@ def run(self):
@zmqd.context()
@zmqd.socket(zmq.PULL)
@zmqd.socket(zmq.PAIR)
@zmqd.socket(zmq.PUSH)
@zmqd.socket(zmq.PUSH)
def _run(self, _, frontend, sink, backend, backend_hprio):
@multi_socket(zmq.PUSH, num_socket=8)
def _run(self, _, frontend, sink, backend_socks):

def push_new_job(_job_id, _json_msg, _msg_len):
""" push to backend based on the msg length """
if _msg_len <= self.args.priority_batch_size:
backend_hprio.send_multipart([_job_id, _json_msg])
else:
backend.send_multipart([_job_id, _json_msg])
# backend_socks[0] is always at the highest priority
_sock = backend_socks[0] if _msg_len <= self.args.priority_batch_size else rand_backend_socket
_sock.send_multipart([_job_id, _json_msg])

# bind all sockets
self.logger.info('bind all sockets')
frontend.bind('tcp://*:%d' % self.port)
addr_front2sink = auto_bind(sink)
addr_backend = auto_bind(backend)
addr_backend_hprio = auto_bind(backend_hprio) # a new socket for high priority job
addr_backend_list = [auto_bind(b) for b in backend_socks]

# start the sink process
self.logger.info('start the sink')
@@ -106,7 +105,7 @@ def push_new_job(_job_id, _json_msg, _msg_len):
# start the backend processes
device_map = self._get_device_map()
for idx, device_id in enumerate(device_map):
process = BertWorker(idx, self.args, addr_backend, addr_backend_hprio, addr_sink, device_id,
process = BertWorker(idx, self.args, addr_backend_list, addr_sink, device_id,
self.graph_path)
self.processes.append(process)
process.start()
@@ -123,14 +122,14 @@ def push_new_job(_job_id, _json_msg, _msg_len):
self.logger.info('new config request\treq id: %d\tclient: %s' % (int(req_id), client))
status_runtime = {'client': client.decode('ascii'),
'num_process': len(self.processes),
'ventilator -> worker': addr_backend,
'ventilator -> worker (priority)': addr_backend_hprio,
'ventilator -> worker': addr_backend_list,
'worker -> sink': addr_sink,
'ventilator <-> sink': addr_front2sink,
'server_current_time': str(datetime.now()),
'num_config_request': num_req['config'],
'num_data_request': num_req['data'],
'device_map': device_map}
'device_map': device_map,
'num_sockets': _num_socket_}

sink.send_multipart([client, msg, jsonapi.dumps({**status_runtime,
**self.status_args,
@@ -141,7 +140,11 @@ def push_new_job(_job_id, _json_msg, _msg_len):
(int(req_id), int(msg_len), client))
# register a new job at sink
sink.send_multipart([client, ServerCommand.new_job, msg_len, req_id])
# renew the backend socket to prevent large job queueing up
rand_backend_socket = random.choice(backend_socks[1:])

# push a new job, note super large job will be pushed to one socket only,
# leaving other sockets free
job_id = client + b'#' + req_id
if int(msg_len) > self.max_batch_size:
seqs = jsonapi.loads(msg)
@@ -278,16 +281,15 @@ def _run(self, receiver, frontend, sender):


class BertWorker(Process):
def __init__(self, id, args, worker_address, worker_address_hprio, sink_address, device_id, graph_path):
def __init__(self, id, args, worker_address_list, sink_address, device_id, graph_path):
super().__init__()
self.worker_id = id
self.device_id = device_id
self.logger = set_logger(colored('WORKER-%d' % self.worker_id, 'yellow'), args.verbose)
self.max_seq_len = args.max_seq_len
self.daemon = True
self.exit_flag = multiprocessing.Event()
self.worker_address = worker_address
self.worker_address_hprio = worker_address_hprio
self.worker_address = worker_address_list
self.sink_address = sink_address
self.prefetch_factor = 10
self.gpu_memory_fraction = args.gpu_memory_fraction
@@ -336,57 +338,51 @@ def model_fn(features, labels, mode, params):
def run(self):
self._run()

@zmqd.socket(zmq.PULL)
@zmqd.socket(zmq.PULL)
@zmqd.socket(zmq.PUSH)
def _run(self, receiver, receiver_hprio, sink):
@multi_socket(zmq.PULL, num_socket=_num_socket_)
def _run(self, sink, receivers):
self.logger.info('use device %s, load graph from %s' %
('cpu' if self.device_id < 0 else ('gpu: %d' % self.device_id), self.graph_path))

tf = import_tf(self.device_id, self.verbose)
estimator = self.get_estimator(tf)

receiver.connect(self.worker_address)
receiver_hprio.connect(self.worker_address_hprio)
for sock, addr in zip(receivers, self.worker_address):
sock.connect(addr)

sink.connect(self.sink_address)
for r in estimator.predict(self.input_fn_builder(receiver, receiver_hprio, tf), yield_single_examples=False):
for r in estimator.predict(self.input_fn_builder(receivers, tf), yield_single_examples=False):
send_ndarray(sink, r['client_id'], r['encodes'])
self.logger.info('job done\tsize: %s\tclient: %s' % (r['encodes'].shape, r['client_id']))

def input_fn_builder(self, sock, sock_hprio, tf):
def input_fn_builder(self, socks, tf):
from .bert.extract_features import convert_lst_to_features
from .bert.tokenization import FullTokenizer

def gen():
def build_job(socket):
client_id, raw_msg = socket.recv_multipart()
msg = jsonapi.loads(raw_msg)
self.logger.info('new job\tsize: %d\tclient: %s' % (len(msg), client_id))
# check if msg is a list of list, if yes consider the input is already tokenized
is_tokenized = all(isinstance(el, list) for el in msg)
tmp_f = list(convert_lst_to_features(msg, self.max_seq_len, tokenizer, is_tokenized))
return {
'client_id': client_id,
'input_ids': [f.input_ids for f in tmp_f],
'input_mask': [f.input_mask for f in tmp_f],
'input_type_ids': [f.input_type_ids for f in tmp_f]
}

tokenizer = FullTokenizer(vocab_file=os.path.join(self.model_dir, 'vocab.txt'))
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
poller.register(sock_hprio, zmq.POLLIN)
for sock in socks:
poller.register(sock, zmq.POLLIN)

self.logger.info('ready and listening!')

while not self.exit_flag.is_set():
events = dict(poller.poll())
if sock_hprio in events:
self.logger.info('a high priority job received')
yield build_job(sock_hprio)
if sock in events:
yield build_job(sock)
for sock_idx, sock in enumerate(socks):
if sock in events:
client_id, raw_msg = sock.recv_multipart()
msg = jsonapi.loads(raw_msg)
self.logger.info('new job\tsocket: %d\tsize: %d\tclient: %s' % (sock_idx, len(msg), client_id))
# check if msg is a list of list, if yes consider the input is already tokenized
is_tokenized = all(isinstance(el, list) for el in msg)
tmp_f = list(convert_lst_to_features(msg, self.max_seq_len, tokenizer, is_tokenized))
yield {
'client_id': client_id,
'input_ids': [f.input_ids for f in tmp_f],
'input_mask': [f.input_mask for f in tmp_f],
'input_type_ids': [f.input_type_ids for f in tmp_f]
}

def input_fn():
return (tf.data.Dataset.from_generator(
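
The core of this refactor on the ventilator side is the routing rule in push_new_job: backend_socks[0] is reserved for small, high-priority batches, while every other request is pushed to a backend socket picked at random (and re-picked on each new request) from the remaining sockets, so one oversized job occupies only a single pipe instead of queueing everything up. The snippet below is a minimal standalone sketch of that rule; the constant values and the helper name pick_backend_socket are illustrative stand-ins for args.priority_batch_size and the real zmq.PUSH sockets, not code from this commit.

    import random

    PRIORITY_BATCH_SIZE = 16   # stand-in for args.priority_batch_size
    NUM_SOCKET = 8             # matches _num_socket_ in the diff

    def pick_backend_socket(backend_socks, msg_len, rand_backend_socket):
        # socket 0 is the high-priority lane; everything else goes to the
        # per-request random socket so a huge job blocks only one pipe
        return backend_socks[0] if msg_len <= PRIORITY_BATCH_SIZE else rand_backend_socket

    # integers stand in for the actual zmq.PUSH sockets
    backend_socks = list(range(NUM_SOCKET))
    rand_backend_socket = random.choice(backend_socks[1:])  # renewed on every new request

    print(pick_backend_socket(backend_socks, 8, rand_backend_socket))     # -> 0 (priority lane)
    print(pick_backend_socket(backend_socks, 4096, rand_backend_socket))  # -> some socket in 1..7
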
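On the worker side, a single generator now drains all the PULL sockets through one zmq.Poller and checks them in index order, which keeps socket 0 (the priority lane) served first. Below is a tiny self-contained demo of that polling pattern over two inproc sockets; the addresses and message contents are made up for illustration and are not part of the commit.

    import zmq

    ctx = zmq.Context.instance()
    pulls = [ctx.socket(zmq.PULL) for _ in range(2)]
    pushes = [ctx.socket(zmq.PUSH) for _ in range(2)]
    for i, (pl, ps) in enumerate(zip(pulls, pushes)):
        pl.bind('inproc://demo-%d' % i)     # bind before connect, required for inproc
        ps.connect('inproc://demo-%d' % i)

    poller = zmq.Poller()
    for pl in pulls:
        poller.register(pl, zmq.POLLIN)

    pushes[0].send(b'priority job')
    pushes[1].send(b'regular job')

    # one poll call reports every readable socket; iterating in index order
    # mirrors the worker loop, so socket 0 is always handled first
    events = dict(poller.poll(timeout=100))
    for sock_idx, pl in enumerate(pulls):
        if pl in events:
            print(sock_idx, pl.recv())

    for s in pulls + pushes:
        s.close()
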
60 changes: 60 additions & 0 deletions server/bert_serving/server/zmq_decor.py
@@ -0,0 +1,60 @@
from contextlib import ExitStack

from zmq.decorators import _Decorator

__all__ = ['multi_socket']

from functools import wraps

import zmq


class _MyDecorator(_Decorator):
def __call__(self, *dec_args, **dec_kwargs):
kw_name, dec_args, dec_kwargs = self.process_decorator_args(*dec_args, **dec_kwargs)
num_socket = dec_kwargs.pop('num_socket')

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
targets = [self.get_target(*args, **kwargs) for _ in range(num_socket)]
with ExitStack() as stack:
for target in targets:
obj = stack.enter_context(target(*dec_args, **dec_kwargs))
args = args + (obj,)

return func(*args, **kwargs)

return wrapper

return decorator


class _SocketDecorator(_MyDecorator):
def process_decorator_args(self, *args, **kwargs):
"""Also grab context_name out of kwargs"""
kw_name, args, kwargs = super(_SocketDecorator, self).process_decorator_args(*args, **kwargs)
self.context_name = kwargs.pop('context_name', 'context')
return kw_name, args, kwargs

def get_target(self, *args, **kwargs):
"""Get context, based on call-time args"""
context = self._get_context(*args, **kwargs)
return context.socket

def _get_context(self, *args, **kwargs):
if self.context_name in kwargs:
ctx = kwargs[self.context_name]

if isinstance(ctx, zmq.Context):
return ctx

for arg in args:
if isinstance(arg, zmq.Context):
return arg
# not specified by any decorator
return zmq.Context.instance()


def multi_socket(*args, **kwargs):
return _SocketDecorator()(*args, **kwargs)
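
For orientation, here is a hedged usage sketch of the new multi_socket decorator. Because the wrapper above appends each opened socket as a separate positional argument (args = args + (obj,)), the decorated function below collects them with a star parameter; the function name, socket count, and the assumption that the package is importable on the path shown are illustrative only and not part of this commit.

    import zmq
    import zmq.decorators as zmqd

    from bert_serving.server.zmq_decor import multi_socket  # path of the new module in this repo

    @zmqd.context()
    @zmqd.socket(zmq.PUSH)
    @multi_socket(zmq.PULL, num_socket=4)
    def bind_pullers(ctx, sender, *pullers):
        # ctx and sender come from the stock pyzmq decorators; the four PULL
        # sockets are opened by multi_socket and closed by its ExitStack on return
        return [p.bind_to_random_port('tcp://127.0.0.1') for p in pullers]

    if __name__ == '__main__':
        print(bind_pullers())  # prints the four auto-bound port numbers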
