Skip to content
Permalink
Browse files

feat(service): send ndarray separately

  • Loading branch information...
hanxiao committed Sep 25, 2019
1 parent 8a0beec commit 0ea566ffdfeb0cb9a6bdde656a2349042bbeffe3
Showing with 181 additions and 73 deletions.
  1. +3 −2 gnes/cli/parser.py
  2. +101 −54 gnes/proto/__init__.py
  3. +1 −1 gnes/service/base.py
  4. +1 −1 gnes/service/frontend.py
  5. +75 −15 tests/test_raw_bytes_send.py
@@ -182,8 +182,9 @@ def set_service_parser(parser=None):
help='identity of the service, empty by default')
parser.add_argument('--route_table', action=ActionNoYes, default=False,
help='showing a route table with time cost after receiving the result')
parser.add_argument('--raw_bytes_in_separate', action=ActionNoYes, default=True,
help='excluding raw_bytes from protobuf message, usually yields better network efficiency')
parser.add_argument('--squeeze_pb', action=ActionNoYes, default=True,
help='sending bytes and ndarray separately apart from the protobuf message, '
'usually yields better network efficiency')
return parser


@@ -151,70 +151,126 @@ def check_msg_version(msg: 'gnes_pb2.Message'):
'the message is probably sent from a very outdated GNES version')


def extract_bytes_from_msg(msg: 'gnes_pb2.Message') -> Tuple:
    """Strip heavy binary payloads out of *msg* so the protobuf itself stays small.

    Mutates *msg* in place: for every doc it removes the ``raw_data`` payload,
    and for every chunk it removes the embedding data and the ``content``
    payload. The stripped bytes are returned so the caller can ship them as
    separate ZMQ frames (see ``send_message``).

    :param msg: the message to squeeze; modified in place.
    :return: tuple ``(doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type)``
        where the ``*_byte_type`` entries are the encoded oneof field names
        used later by ``fill_raw_bytes_to_msg`` to restore the payloads.

    NOTE(review): ``doc_byte_type``/``chunk_byte_type`` keep only the LAST
    doc's/chunk's oneof tag — this assumes all docs (and all chunks) in one
    message carry the same payload kind; confirm against callers.
    """
    doc_bytes = []
    chunk_bytes = []
    doc_byte_type = b''
    chunk_byte_type = b''

    # exactly one of train/index/search is populated per request
    docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
    for d in docs:
        # oneof raw_data {
        #     string raw_text = 5;
        #     NdArray raw_image = 6;
        #     NdArray raw_video = 7;
        #     bytes raw_bytes = 8;  // for other types
        # }
        dtype = d.WhichOneof('raw_data') or ''
        doc_byte_type = dtype.encode()
        if dtype == 'raw_bytes':
            doc_bytes.append(d.raw_bytes)
            d.ClearField('raw_bytes')
        elif dtype == 'raw_image':
            doc_bytes.append(d.raw_image.data)
            d.raw_image.ClearField('data')
        elif dtype == 'raw_video':
            doc_bytes.append(d.raw_video.data)
            d.raw_video.ClearField('data')
        elif dtype == 'raw_text':
            doc_bytes.append(d.raw_text.encode())
            d.ClearField('raw_text')

        for c in d.chunks:
            # the embedding ndarray is always shipped separately
            chunk_bytes.append(c.embedding.data)
            c.embedding.ClearField('data')

            # oneof content {
            #     string text = 2;
            #     NdArray blob = 3;
            #     bytes raw = 7;
            # }
            ctype = c.WhichOneof('content') or ''
            chunk_byte_type = ctype.encode()
            if ctype == 'raw':
                chunk_bytes.append(c.raw)
                c.ClearField('raw')
            elif ctype == 'blob':
                chunk_bytes.append(c.blob.data)
                c.blob.ClearField('data')
            elif ctype == 'text':
                chunk_bytes.append(c.text.encode())
                c.ClearField('text')

    return doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type


def fill_raw_bytes_to_msg(msg: 'gnes_pb2.Message', msg_data: List[bytes]):
    """Restore the payloads stripped by ``extract_bytes_from_msg``.

    Reads the multipart frames produced by ``send_message`` with
    ``squeeze_pb=True`` and writes the doc/chunk payloads back into *msg*
    in place. Frame layout: 0 client_id, 1 squeeze flag, 2 serialized pb,
    3 doc byte type, 4 chunk byte type, 5 doc count, 6 chunk count,
    7.. doc payloads followed by chunk payloads.

    :param msg: the barebone message (already parsed from frame 2).
    :param msg_data: the full multipart frame list as received from ZMQ.
    :raises ValueError: if the trailing chunk-frame count does not match
        the declared ``chunk_bytes_len``.
    """
    doc_byte_type = msg_data[3].decode()
    chunk_byte_type = msg_data[4].decode()
    doc_bytes_len = int(msg_data[5])
    chunk_bytes_len = int(msg_data[6])

    doc_bytes = msg_data[7:(7 + doc_bytes_len)]
    chunk_bytes = msg_data[(7 + doc_bytes_len):]

    if len(chunk_bytes) != chunk_bytes_len:
        raise ValueError('"chunk_bytes_len"=%d in message, but the actual length is %d' % (
            chunk_bytes_len, len(chunk_bytes)))

    c_idx = 0
    d_idx = 0
    docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
    for d in docs:
        if doc_bytes and doc_bytes[d_idx]:
            # NOTE(review): extract_bytes_from_msg encodes the oneof tag
            # 'raw_bytes', not 'raw' — this branch looks like it can never
            # match and raw_bytes docs would not be restored; confirm the
            # actual oneof field name before changing.
            if doc_byte_type == 'raw':
                d.raw_bytes = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_image':
                d.raw_image.data = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_video':
                d.raw_video.data = doc_bytes[d_idx]
                d_idx += 1
            elif doc_byte_type == 'raw_text':
                d.raw_text = doc_bytes[d_idx].decode()
                d_idx += 1

        for c in d.chunks:
            # frame order per chunk: embedding first, then the content payload
            if chunk_bytes and chunk_bytes[c_idx]:
                c.embedding.data = chunk_bytes[c_idx]
            c_idx += 1

            if chunk_byte_type == 'raw':
                c.raw = chunk_bytes[c_idx]
                c_idx += 1
            elif chunk_byte_type == 'blob':
                c.blob.data = chunk_bytes[c_idx]
                c_idx += 1
            elif chunk_byte_type == 'text':
                c.text = chunk_bytes[c_idx].decode()
                c_idx += 1


def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1,
raw_bytes_in_separate: bool = False, **kwargs) -> None:
squeeze_pb: bool = False, **kwargs) -> None:
try:
if timeout > 0:
sock.setsockopt(zmq.SNDTIMEO, timeout)
else:
sock.setsockopt(zmq.SNDTIMEO, -1)

if not raw_bytes_in_separate:
if not squeeze_pb:
sock.send_multipart([msg.envelope.client_id.encode(), b'0', msg.SerializeToString()])
else:
doc_bytes, chunk_bytes = extract_raw_bytes_from_msg(msg)
doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type = extract_bytes_from_msg(msg)
# now raw_bytes are removed from message, hoping for faster de/serialization
sock.send_multipart(
[msg.envelope.client_id.encode(),
b'1', msg.SerializeToString(),
b'%d' % len(doc_bytes), *doc_bytes,
b'%d' % len(chunk_bytes), *chunk_bytes])
[msg.envelope.client_id.encode(), # 0
b'1', msg.SerializeToString(), # 1, 2
doc_byte_type, chunk_byte_type, # 3, 4
b'%d' % len(doc_bytes), b'%d' % len(chunk_bytes), # 5, 6
*doc_bytes, *chunk_bytes]) # 7, 8
except zmq.error.Again:
raise TimeoutError(
'cannot send message to sock %s after timeout=%dms, please check the following:'
@@ -237,24 +293,15 @@ def recv_message(sock: 'zmq.Socket', timeout: int = -1, check_version: bool = Fa

msg = gnes_pb2.Message()
msg_data = sock.recv_multipart()
raw_bytes_in_separate = (msg_data[1] == b'1')
squeeze_pb = (msg_data[1] == b'1')
msg.ParseFromString(msg_data[2])

if check_version:
check_msg_version(msg)

# now we have a barebone msg, we need to fill in data
if raw_bytes_in_separate:
doc_bytes_len_pos = 3
doc_bytes_len = int(msg_data[doc_bytes_len_pos])
doc_bytes = msg_data[(doc_bytes_len_pos + 1):(doc_bytes_len_pos + 1 + doc_bytes_len)]
chunk_bytes_len_pos = doc_bytes_len_pos + 1 + doc_bytes_len
chunk_bytes_len = int(msg_data[chunk_bytes_len_pos])
chunk_bytes = msg_data[(chunk_bytes_len_pos + 1):]
if len(chunk_bytes) != chunk_bytes_len:
raise ValueError('"chunk_bytes_len"=%d in message, but the actual length is %d' % (
chunk_bytes_len, len(chunk_bytes)))
fill_raw_bytes_to_msg(msg, doc_bytes, chunk_bytes)
if squeeze_pb:
fill_raw_bytes_to_msg(msg, msg_data)
return msg

except ValueError:
@@ -311,7 +311,7 @@ def __init__(self, args):
self.send_recv_kwargs = dict(
check_version=self.args.check_version,
timeout=self.args.timeout,
raw_bytes_in_separate=self.args.raw_bytes_in_separate)
squeeze_pb=self.args.squeeze_pb)

def run(self):
try:
@@ -58,7 +58,7 @@ def __init__(self, args):
self.send_recv_kwargs = dict(
check_version=self.args.check_version,
timeout=self.args.timeout,
raw_bytes_in_separate=self.args.raw_bytes_in_separate)
squeeze_pb=self.args.squeeze_pb)

def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg = gnes_pb2.Message()
@@ -1,14 +1,16 @@
import random
import unittest

import numpy as np

from gnes.cli.parser import _set_client_parser
from gnes.client.base import ZmqClient
from gnes.helper import TimeContext
from gnes.proto import gnes_pb2
from gnes.proto import gnes_pb2, array2blob, blob2array
from gnes.service.base import SocketType


class TestProto(unittest.TestCase):
class TestSqueezedSendRecv(unittest.TestCase):
def setUp(self):
self.c1_args = _set_client_parser().parse_args([
'--port_in', str(5678),
@@ -40,7 +42,7 @@ def test_send_recv_raw_bytes(self):
for j in range(random.randint(10, 20)):
d = msg.request.index.docs.add()
d.raw_bytes = b'a' * random.randint(100, 1000)
c1.send_message(msg, raw_bytes_in_separate=True)
c1.send_message(msg, squeeze_pb=True)
r_msg = c2.recv_message()
for d, r_d in zip(msg.request.index.docs, r_msg.request.index.docs):
self.assertEqual(d.raw_bytes, r_d.raw_bytes)
@@ -52,11 +54,11 @@ def test_send_recv_response(self):
msg = gnes_pb2.Message()
msg.envelope.client_id = c1.args.identity
msg.response.train.status = 2
c1.send_message(msg, raw_bytes_in_separate=True)
c1.send_message(msg, squeeze_pb=True)
r_msg = c2.recv_message()
self.assertEqual(msg.response.train.status, r_msg.response.train.status)

def test_benchmark(self):
def build_msgs(self):
all_msgs = []
num_msg = 20
for j in range(num_msg):
@@ -67,38 +69,96 @@ def test_benchmark(self):
# each doc is about 1MB to 10MB
d.raw_bytes = b'a' * random.randint(1000000, 10000000)
all_msgs.append(msg)
return all_msgs

def build_msgs2(self, seed=0):
    """Build 20 deterministic messages whose chunks carry ndarray payloads.

    Each doc gets 10-20 chunks; every chunk carries both an ``embedding``
    and a ``blob`` ndarray (shape 10x20x30) so that both separately-shipped
    payload paths are exercised. Seeding both RNGs makes repeated calls
    produce identical messages (used by test_benchmark5 for comparison).
    """
    all_msgs = []
    num_msg = 20
    random.seed(seed)
    np.random.seed(seed)
    for j in range(num_msg):
        msg = gnes_pb2.Message()
        msg.envelope.client_id = 'abc'
        for _ in range(random.randint(10, 20)):
            d = msg.request.index.docs.add()
            for _ in range(random.randint(10, 20)):
                c = d.chunks.add()
                c.embedding.CopyFrom(array2blob(np.random.random([10, 20, 30])))
                c.blob.CopyFrom(array2blob(np.random.random([10, 20, 30])))
        all_msgs.append(msg)
    return all_msgs

def test_benchmark(self):
    """Time send and recv separately, with and without protobuf squeezing."""
    all_msgs = self.build_msgs()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send, squeeze_pb=False'):
            for m in all_msgs:
                c1.send_message(m)
        with TimeContext('recv, squeeze_pb=False'):
            for _ in all_msgs:
                c2.recv_message()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send, squeeze_pb=True'):
            for m in all_msgs:
                c1.send_message(m, squeeze_pb=True)
        with TimeContext('recv, squeeze_pb=True'):
            for _ in all_msgs:
                c2.recv_message()

def test_benchmark2(self):
    """Time the full send->recv round trip, with and without squeezing."""
    all_msgs = self.build_msgs()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send->recv, squeeze_pb=False'):
            for m in all_msgs:
                c1.send_message(m, squeeze_pb=False)
                c2.recv_message()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send->recv, squeeze_pb=True'):
            for m in all_msgs:
                c1.send_message(m, squeeze_pb=True)
                c2.recv_message()

def test_benchmark3(self):
    """Round-trip raw_bytes docs with squeezing and compare payloads.

    NOTE(review): send_message with squeeze_pb=True clears raw_bytes on the
    sender's message in place, so ``d.raw_bytes`` here is compared
    post-extraction — confirm this assertion checks what it intends to.
    """
    all_msgs = self.build_msgs()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        for m in all_msgs:
            c1.send_message(m, squeeze_pb=True)
            r_m = c2.recv_message()
            for d, r_d in zip(m.request.index.docs, r_m.request.index.docs):
                self.assertEqual(d.raw_bytes, r_d.raw_bytes)

def test_benchmark4(self):
    """Time the round trip for ndarray-heavy messages, both modes."""
    all_msgs = self.build_msgs2()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send->recv, squeeze_pb=False'):
            for m in all_msgs:
                c1.send_message(m, squeeze_pb=False)
                c2.recv_message()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send->recv, squeeze_pb=True'):
            for m in all_msgs:
                c1.send_message(m, squeeze_pb=True)
                c2.recv_message()

def test_benchmark5(self):
    """Verify ndarray payloads survive a squeezed round trip.

    A second, identically-seeded batch (``all_msgs_bak``) serves as the
    pristine reference, because send_message mutates the sent messages
    in place when squeeze_pb=True.
    """
    all_msgs = self.build_msgs2()
    all_msgs_bak = self.build_msgs2()

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send->recv, squeeze_pb=True'):
            for m, m1 in zip(all_msgs, all_msgs_bak):
                c1.send_message(m, squeeze_pb=True)
                r_m = c2.recv_message()

                for d, r_d in zip(m1.request.index.docs, r_m.request.index.docs):
                    for c, r_c in zip(d.chunks, r_d.chunks):
                        # bug fix: np.allclose returns a bool that was being
                        # discarded, so the comparison asserted nothing
                        self.assertTrue(np.allclose(blob2array(c.embedding), blob2array(r_c.embedding)))
                        self.assertTrue(np.allclose(blob2array(c.blob), blob2array(r_c.blob)))

0 comments on commit 0ea566f

Please sign in to comment.
You can’t perform that action at this time.