Skip to content
Permalink
Browse files

feat(proto): speedup send/recv by separating raw_bytes from pb

  • Loading branch information...
hanxiao committed Sep 25, 2019
1 parent 82951d9 commit 10788951261f102dd121b191db807f18336b69f6
Showing with 204 additions and 43 deletions.
  1. +8 −8 gnes/client/base.py
  2. +96 −30 gnes/proto/__init__.py
  3. +5 −5 gnes/service/frontend.py
  4. +95 −0 tests/test_raw_bytes_send.py
@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, List, Union

import grpc
import zmq
from termcolor import colored
from typing import Tuple, List, Union, Type

from ..helper import set_logger
from ..proto import gnes_pb2_grpc
from ..proto import send_message, gnes_pb2, recv_message
from ..proto import send_message as _send_message, gnes_pb2, recv_message as _recv_message
from ..service.base import build_socket


@@ -97,15 +98,14 @@ def close(self):
self.receiver.close()
self.ctx.term()

def send_message(self, message: "gnes_pb2.Message", timeout: int = -1):
def send_message(self, message: "gnes_pb2.Message", **kwargs):
self.logger.debug('send message: %s' % message.envelope)
send_message(self.sender, message, timeout=timeout)
_send_message(self.sender, message, **kwargs)

def recv_message(self, timeout: int = -1) -> gnes_pb2.Message:
r = recv_message(
def recv_message(self, **kwargs) -> gnes_pb2.Message:
r = _recv_message(
self.receiver,
timeout=timeout,
check_version=self.args.check_version)
check_version=self.args.check_version, **kwargs)
self.logger.debug('recv a message: %s' % r.envelope)
return r

@@ -15,7 +15,7 @@

import ctypes
import random
from typing import List, Iterator
from typing import List, Iterator, Tuple
from typing import Optional

import numpy as np
@@ -121,14 +121,98 @@ def merge_routes(msg: 'gnes_pb2.Message', prev_msgs: List['gnes_pb2.Message']):
msg.envelope.routes.extend(sorted(routes.values(), key=lambda x: (x.start_time.seconds, x.start_time.nanos)))


def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1) -> None:
def is_valid_version(msg: 'gnes_pb2.Message'):
from .. import __version__, __proto_version__
if hasattr(msg.envelope, 'gnes_version'):
if not msg.envelope.gnes_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "gnes_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __version__ != msg.envelope.gnes_version:
raise AttributeError('mismatched GNES version! '
'incoming message has GNES version %s, whereas local GNES version %s' % (
msg.envelope.gnes_version, __version__))

if hasattr(msg.envelope, 'proto_version'):
if not msg.envelope.proto_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "proto_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __proto_version__ != msg.envelope.proto_version:
raise AttributeError('mismatched protobuf version! '
'incoming message has protobuf version %s, whereas local protobuf version %s' % (
msg.envelope.proto_version, __proto_version__))

if not hasattr(msg.envelope, 'proto_version') and not hasattr(msg.envelope, 'gnes_version'):
raise AttributeError('version_check=True locally, '
'but incoming message contains no version info in its envelope. '
'the message is probably sent from a very outdated GNES version')


def extract_raw_bytes_from_msg(msg: 'gnes_pb2.Message') -> Tuple[Optional[List[bytes]], Optional[List[bytes]]]:
doc_bytes = [msg.envelope.client_id.encode()]
chunk_bytes = [msg.envelope.client_id.encode()]

# for train request
for d in msg.request.train.docs:
doc_bytes.append(d.raw_bytes)
d.ClearField('raw_bytes')
for c in d.chunks:
chunk_bytes.append(c.raw)
c.ClearField('raw')

# for index request
for d in msg.request.index.docs:
doc_bytes.append(d.raw_bytes)
d.ClearField('raw_bytes')
for c in d.chunks:
chunk_bytes.append(c.raw)
c.ClearField('raw')

# for query
if msg.request.search.query.raw_bytes:
doc_bytes.append(msg.request.search.query.raw_bytes)
msg.request.search.query.ClearField('raw_bytes')

for c in msg.request.search.query.chunks:
chunk_bytes.append(c.raw)
c.ClearField('raw')

return doc_bytes, chunk_bytes


def fill_raw_bytes_to_msg(msg: 'gnes_pb2.Message', doc_raw_bytes: Optional[List[bytes]],
chunk_raw_bytes: Optional[List[bytes]]):
c_idx = 0
d_idx = 0
for d in msg.request.train.docs:
if doc_raw_bytes and doc_raw_bytes[d_idx]:
d.raw_bytes = doc_raw_bytes[d_idx]
d_idx += 1
for c in d.chunks:
if chunk_raw_bytes and chunk_raw_bytes[c_idx]:
c.raw = chunk_raw_bytes[c_idx]
c_idx += 1


def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1,
raw_bytes_in_separate: bool = False) -> None:
try:
if timeout > 0:
sock.setsockopt(zmq.SNDTIMEO, timeout)
else:
sock.setsockopt(zmq.SNDTIMEO, -1)

sock.send_multipart([msg.envelope.client_id.encode(), msg.SerializeToString()])
if not raw_bytes_in_separate:
sock.send_multipart([msg.envelope.client_id.encode(), b'0', msg.SerializeToString()])
else:
doc_bytes, chunk_bytes = extract_raw_bytes_from_msg(msg)
# now raw_bytes are removed from message, hoping for faster de/serialization
sock.send_multipart([msg.envelope.client_id.encode(), b'1', msg.SerializeToString()])
sock.send_multipart(doc_bytes)
sock.send_multipart(chunk_bytes)
except zmq.error.Again:
raise TimeoutError(
'cannot send message to sock %s after timeout=%dms, please check the following:'
@@ -148,36 +232,18 @@ def recv_message(sock: 'zmq.Socket', timeout: int = -1, check_version: bool = Fa
else:
sock.setsockopt(zmq.RCVTIMEO, -1)

_, msg_data = sock.recv_multipart()
msg = gnes_pb2.Message()
_, raw_bytes_in_separate, msg_data = sock.recv_multipart()
msg.ParseFromString(msg_data)

if check_version and msg.envelope:
from .. import __version__, __proto_version__
if hasattr(msg.envelope, 'gnes_version'):
if not msg.envelope.gnes_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "gnes_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __version__ != msg.envelope.gnes_version:
raise AttributeError('mismatched GNES version! '
'incoming message has GNES version %s, whereas local GNES version %s' % (
msg.envelope.gnes_version, __version__))
if hasattr(msg.envelope, 'proto_version'):
if not msg.envelope.proto_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "proto_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __proto_version__ != msg.envelope.proto_version:
raise AttributeError('mismatched protobuf version! '
'incoming message has protobuf version %s, whereas local protobuf version %s' % (
msg.envelope.proto_version, __proto_version__))
if not hasattr(msg.envelope, 'proto_version') and not hasattr(msg.envelope, 'gnes_version'):
raise AttributeError('version_check=True locally, '
'but incoming message contains no version info in its envelope. '
'the message is probably sent from a very outdated GNES version')
if check_version:
is_valid_version(msg)

# now we have a barebone msg, we need to fill in data
if raw_bytes_in_separate == b'1':
doc_bytes = sock.recv_multipart()
chunk_bytes = sock.recv_multipart()
fill_raw_bytes_to_msg(msg, doc_bytes[1:], chunk_bytes[1:])
return msg

except ValueError:
@@ -88,8 +88,8 @@ def remove_envelope(self, m: 'gnes_pb2.Message'):

def Call(self, request, context):
with self.zmq_context as zmq_client:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))
zmq_client.send_message(self.add_envelope(request, zmq_client), timeout=self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(timeout=self.args.timeout))

def Train(self, request, context):
return self.Call(request, context)
@@ -105,16 +105,16 @@ def StreamCall(self, request_iterator, context):
num_request = 0

for request in request_iterator:
zmq_client.send_message(self.add_envelope(request, zmq_client), -1)
zmq_client.send_message(self.add_envelope(request, zmq_client), timeout=-1)
num_request += 1

if zmq_client.receiver.poll(1):
msg = zmq_client.recv_message(self.args.timeout)
msg = zmq_client.recv_message(timeout=self.args.timeout)
num_request -= 1
yield self.remove_envelope(msg)

for _ in range(num_request):
msg = zmq_client.recv_message(self.args.timeout)
msg = zmq_client.recv_message(timeout=self.args.timeout)
yield self.remove_envelope(msg)

class ZmqContext:
@@ -0,0 +1,95 @@
import random
import unittest

from gnes.cli.parser import _set_client_parser
from gnes.client.base import ZmqClient
from gnes.helper import TimeContext
from gnes.proto import gnes_pb2
from gnes.service.base import SocketType


class TestProto(unittest.TestCase):
def setUp(self):
self.c1_args = _set_client_parser().parse_args([
'--port_in', str(5678),
'--port_out', str(5679),
'--socket_out', str(SocketType.PUSH_BIND),
'--no-check_version'
])
self.c2_args = _set_client_parser().parse_args([
'--port_in', str(self.c1_args.port_out),
'--port_out', str(self.c1_args.port_in),
'--socket_in', str(SocketType.PULL_CONNECT),
'--no-check_version'
])

def test_send_recv(self):
with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
msg = gnes_pb2.Message()
msg.envelope.client_id = c1.args.identity
d = msg.request.index.docs.add()
d.raw_bytes = b'aa'
c1.send_message(msg)
r_msg = c2.recv_message()
self.assertEqual(r_msg.request.index.docs[0].raw_bytes, d.raw_bytes)

def test_send_recv_raw_bytes(self):
with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
msg = gnes_pb2.Message()
msg.envelope.client_id = c1.args.identity
for j in range(random.randint(10, 20)):
d = msg.request.index.docs.add()
d.raw_bytes = b'a' * random.randint(100, 1000)
c1.send_message(msg, raw_bytes_in_separate=True)
r_msg = c2.recv_message()
for d, r_d in zip(msg.request.index.docs, r_msg.request.index.docs):
self.assertEqual(d.raw_bytes, r_d.raw_bytes)
print('.', end='')
print('checked %d docs' % len(msg.request.index.docs))

def test_benchmark(self):
all_msgs = []
num_msg = 20
for j in range(num_msg):
msg = gnes_pb2.Message()
msg.envelope.client_id = 'abc'
for j in range(random.randint(10, 20)):
d = msg.request.index.docs.add()
# each doc is about 1MB to 10MB
d.raw_bytes = b'a' * random.randint(1000000, 10000000)
all_msgs.append(msg)

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
with TimeContext('send, raw_bytes_in_separate=False'):
for m in all_msgs:
c1.send_message(m)
with TimeContext('recv, raw_bytes_in_separate=False'):
for _ in all_msgs:
c2.recv_message()

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
with TimeContext('send, raw_bytes_in_separate=True'):
for m in all_msgs:
c1.send_message(m, raw_bytes_in_separate=True)
with TimeContext('recv, raw_bytes_in_separate=True'):
for _ in all_msgs:
c2.recv_message()

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
with TimeContext('send->recv, raw_bytes_in_separate=False'):
for m in all_msgs:
c1.send_message(m, raw_bytes_in_separate=False)
c2.recv_message()

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
with TimeContext('send->recv, raw_bytes_in_separate=True'):
for m in all_msgs:
c1.send_message(m, raw_bytes_in_separate=True)
c2.recv_message()

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
for m in all_msgs:
c1.send_message(m, raw_bytes_in_separate=True)
r_m = c2.recv_message()
for d, r_d in zip(m.request.index.docs, r_m.request.index.docs):
self.assertEqual(d.raw_bytes, r_d.raw_bytes)

0 comments on commit 1078895

Please sign in to comment.
You can’t perform that action at this time.