Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(unittest): fix unit test for send recv
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 25, 2019
1 parent 31f53bc commit 4314501
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/test_raw_bytes_send.py
@@ -1,3 +1,4 @@
import copy
import random import random
import unittest import unittest


Expand Down Expand Up @@ -42,12 +43,12 @@ def test_send_recv_raw_bytes(self):
for j in range(random.randint(10, 20)): for j in range(random.randint(10, 20)):
d = msg.request.index.docs.add() d = msg.request.index.docs.add()
d.raw_bytes = b'a' * random.randint(100, 1000) d.raw_bytes = b'a' * random.randint(100, 1000)
raw_bytes = [d.raw_bytes for d in msg.request.index.docs] raw_bytes = copy.deepcopy([d.raw_bytes for d in msg.request.index.docs])
c1.send_message(msg, squeeze_pb=True) c1.send_message(msg, squeeze_pb=True)
r_msg = c2.recv_message() r_msg = c2.recv_message()
for d, r_d in zip(msg.request.index.docs, r_msg.request.index.docs): for d, o_d, r_d in zip(msg.request.index.docs, raw_bytes, r_msg.request.index.docs):
self.assertEqual(d.raw_bytes, b'') self.assertEqual(d.raw_bytes, b'')
self.assertEqual(raw_bytes, r_d.raw_bytes) self.assertEqual(o_d, r_d.raw_bytes)
print('.', end='') print('.', end='')
print('checked %d docs' % len(msg.request.index.docs)) print('checked %d docs' % len(msg.request.index.docs))


Expand Down Expand Up @@ -127,15 +128,15 @@ def test_benchmark2(self):


def test_benchmark3(self): def test_benchmark3(self):
all_msgs = self.build_msgs() all_msgs = self.build_msgs()
all_msgs_bak = copy.deepcopy(all_msgs)


with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2: with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
for m in all_msgs: for m, m1 in zip(all_msgs, all_msgs_bak):
raw_bytes = [d.raw_bytes for d in m.request.index.docs]
c1.send_message(m, squeeze_pb=True) c1.send_message(m, squeeze_pb=True)
r_m = c2.recv_message() r_m = c2.recv_message()
for d, r_d in zip(m.request.index.docs, r_m.request.index.docs): for d, o_d, r_d in zip(m.request.index.docs, m1.request.index.docs, r_m.request.index.docs):
self.assertEqual(d.raw_bytes, b'') self.assertEqual(d.raw_bytes, b'')
self.assertEqual(raw_bytes, r_d.raw_bytes) self.assertEqual(o_d.raw_bytes, r_d.raw_bytes)


def test_benchmark4(self): def test_benchmark4(self):
all_msgs = self.build_msgs2() all_msgs = self.build_msgs2()
Expand All @@ -154,7 +155,7 @@ def test_benchmark4(self):


def test_benchmark5(self): def test_benchmark5(self):
all_msgs = self.build_msgs2() all_msgs = self.build_msgs2()
all_msgs_bak = self.build_msgs2() all_msgs_bak = copy.deepcopy(all_msgs)


with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2: with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
with TimeContext('send->recv, squeeze_pb=True'): with TimeContext('send->recv, squeeze_pb=True'):
Expand All @@ -165,4 +166,4 @@ def test_benchmark5(self):
for d, r_d in zip(m1.request.index.docs, r_m.request.index.docs): for d, r_d in zip(m1.request.index.docs, r_m.request.index.docs):
for c, r_c in zip(d.chunks, r_d.chunks): for c, r_c in zip(d.chunks, r_d.chunks):
np.allclose(blob2array(c.embedding), blob2array(r_c.embedding)) np.allclose(blob2array(c.embedding), blob2array(r_c.embedding))
np.allclose(blob2array(c.blob), blob2array(r_c.blob)) np.allclose(blob2array(c.blob), blob2array(r_c.blob))

0 comments on commit 4314501

Please sign in to comment.