Skip to content

Commit

Permalink
More tests for Session.send/recv.
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisonbg committed Jul 14, 2011
1 parent 28a146d commit 88dae27
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 7 deletions.
5 changes: 3 additions & 2 deletions IPython/zmq/session.py
Expand Up @@ -244,7 +244,7 @@ def _unpacker_changed(self, name, old, new):
def _session_default(self):
return bytes(uuid.uuid4())

username = Unicode(os.environ.get('USER','username'), config=True,
username = Unicode(os.environ.get('USER',u'username'), config=True,
help="""Username for the Session. Default is your system username.""")

# message signature related traits:
Expand Down Expand Up @@ -455,7 +455,8 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
The socket-like object used to send the data.
msg_or_type : str or Message/dict
Normally, msg_or_type will be a msg_type unless a message is being
sent more than once.
sent more than once. If a header is supplied, this can be set to
None and the msg_type will be pulled from the header.
content : dict or None
The content of the message (ignored if msg_or_type is a message).
Expand Down
67 changes: 62 additions & 5 deletions IPython/zmq/tests/test_session.py
Expand Up @@ -26,6 +26,22 @@ def setUp(self):
BaseZMQTestCase.setUp(self)
self.session = ss.Session()


class MockSocket(zmq.Socket):

def __init__(self, *args, **kwargs):
super(MockSocket,self).__init__(*args,**kwargs)
self.data = []

def send_multipart(self, msgparts, *args, **kwargs):
self.data.extend(msgparts)

def send(self, part, *args, **kwargs):
self.data.append(part)

def recv_multipart(self, *args, **kwargs):
return self.data

class TestSession(SessionTestCase):

def test_msg(self):
Expand All @@ -40,7 +56,7 @@ def test_msg(self):
self.assertEquals(msg['header']['msg_type'], 'execute')

def test_serialize(self):
msg = self.session.msg('execute')
msg = self.session.msg('execute',content=dict(a=10))
msg_list = self.session.serialize(msg, ident=b'foo')
ident, msg_list = self.session.feed_identities(msg_list)
new_msg = self.session.unserialize(msg_list)
Expand All @@ -49,22 +65,63 @@ def test_serialize(self):
self.assertEquals(new_msg['content'],msg['content'])
self.assertEquals(new_msg['parent_header'],msg['parent_header'])

def test_send(self):
socket = MockSocket(zmq.Context.instance(),zmq.PAIR)

msg = self.session.msg('execute', content=dict(a=10))
self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
ident, msg_list = self.session.feed_identities(socket.data)
new_msg = self.session.unserialize(msg_list)
self.assertEquals(ident[0], b'foo')
self.assertEquals(new_msg['header'],msg['header'])
self.assertEquals(new_msg['content'],msg['content'])
self.assertEquals(new_msg['parent_header'],msg['parent_header'])
self.assertEquals(new_msg['buffers'],[b'bar'])

socket.data = []

content = msg['content']
header = msg['header']
parent = msg['parent_header']
msg_type = header['msg_type']
self.session.send(socket, None, content=content, parent=parent,
header=header, ident=b'foo', buffers=[b'bar'])
ident, msg_list = self.session.feed_identities(socket.data)
new_msg = self.session.unserialize(msg_list)
self.assertEquals(ident[0], b'foo')
self.assertEquals(new_msg['header'],msg['header'])
self.assertEquals(new_msg['content'],msg['content'])
self.assertEquals(new_msg['parent_header'],msg['parent_header'])
self.assertEquals(new_msg['buffers'],[b'bar'])

socket.data = []

self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
ident, new_msg = self.session.recv(socket)
self.assertEquals(ident[0], b'foo')
self.assertEquals(new_msg['header'],msg['header'])
self.assertEquals(new_msg['content'],msg['content'])
self.assertEquals(new_msg['parent_header'],msg['parent_header'])
self.assertEquals(new_msg['buffers'],[b'bar'])

socket.close()

def test_args(self):
"""initialization arguments for Session"""
s = self.session
self.assertTrue(s.pack is ss.default_packer)
self.assertTrue(s.unpack is ss.default_unpacker)
self.assertEquals(s.username, os.environ.get('USER', 'username'))
self.assertEquals(s.username, os.environ.get('USER', u'username'))

s = ss.Session()
self.assertEquals(s.username, os.environ.get('USER', 'username'))
self.assertEquals(s.username, os.environ.get('USER', u'username'))

self.assertRaises(TypeError, ss.Session, pack='hi')
self.assertRaises(TypeError, ss.Session, unpack='hi')
u = str(uuid.uuid4())
s = ss.Session(username='carrot', session=u)
s = ss.Session(username=u'carrot', session=u)
self.assertEquals(s.session, u)
self.assertEquals(s.username, 'carrot')
self.assertEquals(s.username, u'carrot')

def test_tracking(self):
"""test tracking messages"""
Expand Down

0 comments on commit 88dae27

Please sign in to comment.