Skip to content

Commit

Permalink
added ZMQStream.flush()
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Oct 14, 2010
1 parent 4a8ffff commit 65d7769
Showing 1 changed file with 69 additions and 17 deletions.
86 changes: 69 additions & 17 deletions zmq/eventloop/zmqstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,20 @@ class ZMQStream(object):

socket = None
io_loop = None
poller = None

def __init__(self, socket, io_loop=None):
self.socket = socket
self.io_loop = io_loop or ioloop.IOLoop.instance()
self.poller = zmq.Poller()

self._send_queue = Queue()
self._recv_callback = None
self._send_callback = None
self._close_callback = None
self._errback = None
self._recv_copy = False
self._flushed = False

self._state = zmq.POLLERR
with stack_context.NullContext():
Expand Down Expand Up @@ -163,10 +166,6 @@ def on_send(self, callback):
if callback is None, send callbacks are disabled.
"""
self._send_callback = stack_context.wrap(callback)
if callback is None:
self._drop_io_state(zmq.POLLOUT)
else:
self._add_io_state(zmq.POLLOUT)

def on_err(self, callback):
"""register a callback to be called on POLLERR events
Expand All @@ -178,7 +177,6 @@ def on_err(self, callback):
callback : callable
callback will be passed no arguments.
"""
# self._add_io_state(zmq.POLLOUT)
self._errback = stack_context.wrap(callback)


Expand All @@ -200,6 +198,7 @@ def send_multipart(self, msg, flags=0, copy=False, callback=None):
else:
# noop callback
self.on_send(lambda *args: None)
self._add_io_state(zmq.POLLOUT)

def send_unicode(self, u, flags=0, encoding='utf-8', callback=None):
"""Send a unicode message with an encoding.
Expand Down Expand Up @@ -227,6 +226,56 @@ def send_pyobj(self, obj, flags=0, protocol=-1, callback=None):
msg = pickle.dumps(obj, protocol)
return self.send(msg, flags, callback=callback)

def _finish_flushing(self):
"""callback for unsetting _flushed flag."""
self._flushed = False

def flush(self, limit=None):
"""Flush pending messages.
This method safely handles all pending incoming/outgoing messages, bypassing the inner loop.
A limit can be specified, to prevent blocking under high load.
All the usual callbacks are called.
Parameters
----------
limit : None or int
The maximum number of messages to send or receive.
If specified, flush will return when *either*
send or recv reaches the limit (not the sum).
Returns
-------
(int, int) : (msgs_received, msgs_sent)
"""
# initialize counters
sent = recvd = 0

flag = (self.receiving() and zmq.POLLIN) | (self.sending() and zmq.POLLOUT)
self.poller.register(self.socket, flag)
events = self.poller.poll(0)
while events and (not limit or (sent < limit and recvd < limit)):
s,event = events[0]
if event & zmq.POLLIN: # receiving
self._handle_recv()
recvd += 1
if event & zmq.POLLOUT and self.sending():
self._handle_send()
sent += 1
flag = (self.receiving() and zmq.POLLIN) | (self.sending() and zmq.POLLOUT)
self.poller.register(self.socket, flag)

events = self.poller.poll(0)
# skip send/recv callbacks this iteration,
# but reregister them at the end of the loop
if recvd or sent:
# only bypass if we did something here
self._flushed = True
dc = ioloop.DelayedCallback(self._finish_flushing, 0, self.io_loop)
dc.start()
return recvd, sent

def set_close_callback(self, callback):
"""Call the given callback when the stream is closed."""
self._close_callback = stack_context.wrap(callback)
Expand Down Expand Up @@ -279,17 +328,20 @@ def _handle_events(self, fd, events):
if not self.socket:
logging.warning("Got events for closed stream %s", fd)
return
# dispatch events:
if events & zmq.POLLERR:
self._handle_error()
return
if events & zmq.POLLIN:
self._handle_recv()
if not self.socket:
return
if not self.socket:
return
if events & zmq.POLLOUT:
self._handle_send()
if not self.socket:
return
if events & zmq.POLLERR:
self._handle_error()
return
if not self.socket:
return

# rebuild the poll state
state = zmq.POLLERR
if self.receiving():
state |= zmq.POLLIN
Expand All @@ -301,6 +353,8 @@ def _handle_events(self, fd, events):

def _handle_recv(self):
"""Handle a recv event."""
if self._flushed:
return
try:
msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
except zmq.ZMQError, e:
Expand All @@ -320,7 +374,10 @@ def _handle_recv(self):

def _handle_send(self):
"""Handle a send event."""
if self._flushed:
return
if not self.sending():
logging.error("Shouldn't have handled a send event")
return

msg = self._send_queue.get()
Expand All @@ -332,15 +389,10 @@ def _handle_send(self):
callback = self._send_callback
self._run_callback(callback, msg, status)

# unregister from event loop:
if not self.sending():
self._drop_io_state(zmq.POLLOUT)

# self.update_state()

def _handle_error(self):
"""Handle a POLLERR event."""
# if evt & zmq.POLLERR:
logging.error("handling error..")
if self._errback is not None:
self._errback()
Expand Down

0 comments on commit 65d7769

Please sign in to comment.