Permalink
Browse files

More tests for Session.send/recv.

  • Loading branch information...
1 parent 28a146d commit 88dae276443ddc45b54f50b5f4283d17acbff418 @ellisonbg committed Jul 14, 2011
Showing with 65 additions and 7 deletions.
  1. +3 −2 IPython/zmq/session.py
  2. +62 −5 IPython/zmq/tests/test_session.py
@@ -244,7 +244,7 @@ def _unpacker_changed(self, name, old, new):
def _session_default(self):
return bytes(uuid.uuid4())
- username = Unicode(os.environ.get('USER','username'), config=True,
+ username = Unicode(os.environ.get('USER',u'username'), config=True,
help="""Username for the Session. Default is your system username.""")
# message signature related traits:
@@ -455,7 +455,8 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
The socket-like object used to send the data.
msg_or_type : str or Message/dict
Normally, msg_or_type will be a msg_type unless a message is being
- sent more than once.
+ sent more than once. If a header is supplied, this can be set to
+ None and the msg_type will be pulled from the header.
content : dict or None
The content of the message (ignored if msg_or_type is a message).
@@ -26,6 +26,22 @@ def setUp(self):
BaseZMQTestCase.setUp(self)
self.session = ss.Session()
+
+class MockSocket(zmq.Socket):
+
+ def __init__(self, *args, **kwargs):
+ super(MockSocket,self).__init__(*args,**kwargs)
+ self.data = []
+
+ def send_multipart(self, msgparts, *args, **kwargs):
+ self.data.extend(msgparts)
+
+ def send(self, part, *args, **kwargs):
+ self.data.append(part)
+
+ def recv_multipart(self, *args, **kwargs):
+ return self.data
+
class TestSession(SessionTestCase):
def test_msg(self):
@@ -40,7 +56,7 @@ def test_msg(self):
self.assertEquals(msg['header']['msg_type'], 'execute')
def test_serialize(self):
- msg = self.session.msg('execute')
+ msg = self.session.msg('execute',content=dict(a=10))
msg_list = self.session.serialize(msg, ident=b'foo')
ident, msg_list = self.session.feed_identities(msg_list)
new_msg = self.session.unserialize(msg_list)
@@ -49,22 +65,63 @@ def test_serialize(self):
self.assertEquals(new_msg['content'],msg['content'])
self.assertEquals(new_msg['parent_header'],msg['parent_header'])
+ def test_send(self):
+ socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
+
+ msg = self.session.msg('execute', content=dict(a=10))
+ self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
+ ident, msg_list = self.session.feed_identities(socket.data)
+ new_msg = self.session.unserialize(msg_list)
+ self.assertEquals(ident[0], b'foo')
+ self.assertEquals(new_msg['header'],msg['header'])
+ self.assertEquals(new_msg['content'],msg['content'])
+ self.assertEquals(new_msg['parent_header'],msg['parent_header'])
+ self.assertEquals(new_msg['buffers'],[b'bar'])
+
+ socket.data = []
+
+ content = msg['content']
+ header = msg['header']
+ parent = msg['parent_header']
+ msg_type = header['msg_type']
+ self.session.send(socket, None, content=content, parent=parent,
+ header=header, ident=b'foo', buffers=[b'bar'])
+ ident, msg_list = self.session.feed_identities(socket.data)
+ new_msg = self.session.unserialize(msg_list)
+ self.assertEquals(ident[0], b'foo')
+ self.assertEquals(new_msg['header'],msg['header'])
+ self.assertEquals(new_msg['content'],msg['content'])
+ self.assertEquals(new_msg['parent_header'],msg['parent_header'])
+ self.assertEquals(new_msg['buffers'],[b'bar'])
+
+ socket.data = []
+
+ self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
+ ident, new_msg = self.session.recv(socket)
+ self.assertEquals(ident[0], b'foo')
+ self.assertEquals(new_msg['header'],msg['header'])
+ self.assertEquals(new_msg['content'],msg['content'])
+ self.assertEquals(new_msg['parent_header'],msg['parent_header'])
+ self.assertEquals(new_msg['buffers'],[b'bar'])
+
+ socket.close()
+
def test_args(self):
"""initialization arguments for Session"""
s = self.session
self.assertTrue(s.pack is ss.default_packer)
self.assertTrue(s.unpack is ss.default_unpacker)
- self.assertEquals(s.username, os.environ.get('USER', 'username'))
+ self.assertEquals(s.username, os.environ.get('USER', u'username'))
s = ss.Session()
- self.assertEquals(s.username, os.environ.get('USER', 'username'))
+ self.assertEquals(s.username, os.environ.get('USER', u'username'))
self.assertRaises(TypeError, ss.Session, pack='hi')
self.assertRaises(TypeError, ss.Session, unpack='hi')
u = str(uuid.uuid4())
- s = ss.Session(username='carrot', session=u)
+ s = ss.Session(username=u'carrot', session=u)
self.assertEquals(s.session, u)
- self.assertEquals(s.username, 'carrot')
+ self.assertEquals(s.username, u'carrot')
def test_tracking(self):
"""test tracking messages"""

0 comments on commit 88dae27

Please sign in to comment.