Skip to content

Commit

Permalink
fix evilkost#7: call callbacks through context.ret_call
Browse files Browse the repository at this point in the history
  • Loading branch information
evilkost committed Apr 16, 2011
1 parent 64c9930 commit d47eff4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ Input:
c = brukva.Client()
c.connect()
loop = c.connection._stream.io_loop

def on_result(result):
print result
c.set('foo', 'bar', on_result)
c.get('foo', on_result)
c.hgetall('foo', [on_result, lambda r: loop.stop()] )

c.hgetall('foo', [on_result, lambda r: loop.stop()])
loop.start() # start tornado mainloop

Output:
Expand Down
60 changes: 31 additions & 29 deletions brukva/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import socket
from functools import partial
from itertools import izip
import contextlib
import logging
from collections import Iterable
import weakref
Expand All @@ -16,11 +15,19 @@

log = logging.getLogger('brukva.client')

class ForwardErrorManager(object):
class ExecutionContext(object):
def __init__(self, callbacks):
self.callbacks = callbacks
self.is_active = True

def _call_callbacks(self, value):
if self.callbacks:
if isinstance(self.callbacks, Iterable):
for cb in self.callbacks:
cb(value)
else:
self.callbacks(value)

def __enter__(self):
return self

Expand All @@ -29,11 +36,7 @@ def __exit__(self, type, value, tb):
return True

if self.is_active:
if isinstance(self.callbacks, Iterable):
for cb in self.callbacks:
cb(value)
else:
self.callbacks(value)
self._call_callbacks(value)
return True
else:
return False
Expand All @@ -44,19 +47,23 @@ def disable(self):
def enable(self):
self.is_active = True

def forward_error(callbacks):
def ret_call(self, value):
self.is_active = False
self._call_callbacks(value)
self.is_active = True

def execution_context(callbacks):
"""
Syntax sugar.
If some error occurred inside with block,
it will be suppressed and forwarded to callbacks.
Error handling can be disabled using context.disable(),
and re enabled again using context.enable().
Use contex.ret_call(value) method to call callbacks.
@type callbacks: callable or iterator over callables
@rtype: context
"""
return ForwardErrorManager(callbacks)
return ExecutionContext(callbacks)

class Message(object):
def __init__(self, kind, channel, body):
Expand Down Expand Up @@ -372,8 +379,7 @@ def call_callbacks(self, callbacks, *args, **kwargs):

@process
def execute_command(self, cmd, callbacks, *args, **kwargs):
result = None
with forward_error(callbacks):
with execution_context(callbacks) as ctx:
if callbacks is None:
callbacks = []
elif not hasattr(callbacks, '__iter__'):
Expand All @@ -400,13 +406,12 @@ def execute_command(self, cmd, callbacks, *args, **kwargs):
result = self.format_reply(cmd_line, response)

self.connection.read_done()

self.call_callbacks(callbacks, result)
ctx.ret_call(result)

@async
@process
def process_data(self, data, cmd_line, callback):
with forward_error(callback):
with execution_context(callback) as ctx:
data = data[:-2] # strip \r\n

if data == '$-1':
Expand All @@ -432,13 +437,12 @@ def process_data(self, data, cmd_line, callback):
response = ResponseError(tail, cmd_line)
else:
raise ResponseError('Unknown response type %s' % head, cmd_line)

callback(response)
ctx.ret_call(response)

@async
@process
def consume_multibulk(self, length, cmd_line, callback):
with forward_error(callback):
with execution_context(callback) as ctx:
tokens = []
while len(tokens) < length:
data = yield async(self.connection.readline)()
Expand All @@ -449,20 +453,21 @@ def consume_multibulk(self, length, cmd_line, callback):
)
token = yield self.process_data(data, cmd_line) #FIXME error
tokens.append( token )
callback(tokens)

ctx.ret_call(tokens)

@async
@process
def consume_bulk(self, length, callback):
with forward_error(callback):
with execution_context(callback) as ctx:
data = yield async(self.connection.read)(length)
if isinstance(data, Exception):
raise data
if not data:
raise ResponseError('EmptyResponse')
else:
data = data[:-2]
callback(data)
ctx.ret_call(data)
####

### MAINTENANCE
Expand Down Expand Up @@ -849,7 +854,7 @@ def publish(self, channel, message, callbacks=None):
@process
def listen(self, callbacks=None):
# 'LISTEN' is just for receiving information, it is not actually sent anywhere
with forward_error(callbacks) as forward:
with execution_context(callbacks) as ctx:
callbacks = callbacks or []
if not hasattr(callbacks, '__iter__'):
callbacks = [callbacks]
Expand All @@ -865,10 +870,8 @@ def listen(self, callbacks=None):
if isinstance(response, Exception):
raise response
result = self.format_reply(cmd_listen, response)
ctx.ret_call(result)

forward.disable()
self.call_callbacks(callbacks, result)
forward.enable()
### CAS
def watch(self, key, callbacks=None):
self.execute_command('WATCH', callbacks, key)
Expand Down Expand Up @@ -902,8 +905,7 @@ def auth(self, password, callbacks=None):

@process
def execute(self, callbacks):
results = None
with forward_error(callbacks):
with execution_context(callbacks) as ctx:
command_stack = self.command_stack
self.command_stack = []

Expand Down Expand Up @@ -967,4 +969,4 @@ def format_replies(cmd_lines, responses):
else:
results = format_replies(command_stack, responses)

self.call_callbacks(callbacks, results)
ctx.ret_call(results)

0 comments on commit d47eff4

Please sign in to comment.