Skip to content

Commit

Permalink
Merge pull request #20 from asvetlov/master
Browse files Browse the repository at this point in the history
asyncio integration
  • Loading branch information
ajdavis committed Aug 26, 2015
2 parents e08b8f1 + f0f3640 commit cb7a7ca
Show file tree
Hide file tree
Showing 21 changed files with 1,878 additions and 168 deletions.
21 changes: 10 additions & 11 deletions motor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,17 @@ def connect(self, force=False):
self.queue.append(waiter)

if self.wait_queue_timeout is not None:
deadline = self.io_loop.time() + self.wait_queue_timeout

def on_timeout():
if waiter in self.queue:
self.queue.remove(waiter)

t = self.waiter_timeouts.pop(waiter)
self.io_loop.remove_timeout(t)
self._framework.call_later_cancel(self.io_loop, t)
child_gr.throw(self._create_wait_queue_timeout())

timeout = self.io_loop.add_timeout(deadline, on_timeout)
timeout = self._framework.call_later(
self.io_loop, self.wait_queue_timeout, on_timeout)

# timeout = self.io_loop.add_timeout(
# deadline,
# functools.partial(
Expand Down Expand Up @@ -323,10 +323,10 @@ def maybe_return_socket(self, sock_info):
waiter = self.queue.popleft()
if waiter in self.waiter_timeouts:
timeout = self.waiter_timeouts.pop(waiter)
self.io_loop.remove_timeout(timeout)
self._framework.call_later_cancel(self.io_loop, timeout)

# TODO: with stack_context.NullContext():
self.io_loop.add_callback(functools.partial(waiter, sock_info))
self._framework.call_soon(self.io_loop,
functools.partial(waiter, sock_info))

elif (self.motor_sock_counter <= self.max_size
and sock_info.pool_id == self.pool_id):
Expand Down Expand Up @@ -450,9 +450,7 @@ def __init__(self, io_loop, *args, **kwargs):
delegate = self.__delegate_class__(*args, **kwargs)
super(AgnosticClientBase, self).__init__(delegate)
if io_loop:
if not self._framework.is_event_loop(io_loop):
raise TypeError(
"io_loop must be instance of IOLoop, not %r" % io_loop)
self._framework.check_event_loop(io_loop)
self.io_loop = io_loop
else:
self.io_loop = self._framework.get_event_loop()
Expand Down Expand Up @@ -522,7 +520,8 @@ def __init__(self, *args, **kwargs):
else:
io_loop = self._framework.get_event_loop()

event_class = functools.partial(util.MotorGreenletEvent, io_loop, self._framework)
event_class = functools.partial(util.MotorGreenletEvent, io_loop,
self._framework)
kwargs['_event_class'] = event_class

# Our class is not actually AgnosticClient here, it's the version of
Expand Down
177 changes: 64 additions & 113 deletions motor/frameworks/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,30 @@
import functools
import socket
import ssl
import sys

import greenlet
import collections


def get_event_loop():
return asyncio.get_event_loop()


def is_event_loop(loop):
# TODO: is there any way to assure that this is an event loop?
return True
return isinstance(loop, asyncio.AbstractEventLoop)


def check_event_loop(loop):
if not is_event_loop(loop):
raise TypeError(
"io_loop must be instance of asyncio-compatible event loop,"
"not %r" % loop)


def return_value(value):
# In Python 3.3, StopIteration can accept a value.
raise StopIteration(value)


# TODO: rename?
def get_future(loop):
return asyncio.Future(loop=loop)

Expand All @@ -50,10 +53,10 @@ def is_future(f):


def call_soon(loop, callback, *args, **kwargs):
if args or kwargs:
if kwargs:
loop.call_soon(functools.partial(callback, *args, **kwargs))
else:
loop.call_soon(callback)
loop.call_soon(callback, *args)


def call_soon_threadsafe(loop, callback):
Expand All @@ -62,14 +65,10 @@ def call_soon_threadsafe(loop, callback):

def call_later(loop, delay, callback, *args, **kwargs):
if kwargs:
return loop.call_later(
delay,
functools.partial(callback, *args, **kwargs))
return loop.call_later(delay,
functools.partial(callback, *args, **kwargs))
else:
return loop.call_later(
loop.time() + delay,
callback,
*args)
return loop.call_later(delay, callback, *args)


def call_later_cancel(loop, handle):
Expand All @@ -86,18 +85,6 @@ def get_resolver(loop):
return None


def resolve(resolver, loop, host, port, family, callback, errback):
def done_callback(future):
try:
addresses = future.result()
callback(addresses)
except:
errback(*sys.exc_info())

future = loop.getaddrinfo(host, port, family=family)
future.add_done_callback(done_callback)


def close_resolver(resolver):
pass

Expand All @@ -118,139 +105,103 @@ def asyncio_motor_sock_method(method):
when I/O is ready.
"""
@functools.wraps(method)
def _motor_sock_method(self, *args, **kwargs):
def wrapped_method(self, *args, **kwargs):
child_gr = greenlet.getcurrent()
main = child_gr.parent
assert main is not None, "Should be on child greenlet"

future = None
timeout_handle = None

if self.timeout:
def timeout_err():
if future:
future.cancel()

if self._transport:
self._transport.abort()

child_gr.throw(socket.error("timed out"))

timeout_handle = self.loop.call_later(self.timeout, timeout_err)

# This is run by the event loop on the main greenlet when operation
# completes; switch back to child to continue processing
def callback(_):
if timeout_handle:
timeout_handle.cancel()

try:
child_gr.switch(future.result())
except asyncio.CancelledError:
# Timeout. We've already thrown an error on the child greenlet.
pass
res = future.result()
except asyncio.TimeoutError:
child_gr.throw(socket.timeout("timed out"))
except Exception as ex:
child_gr.throw(socket.error(str(ex)))
child_gr.throw(ex)
else:
child_gr.switch(res)

future = asyncio.async(method(self, *args, **kwargs), loop=self.loop)
coro = method(self, *args, **kwargs)
if self.timeout:
coro = asyncio.wait_for(coro, self.timeout, loop=self.loop)
future = asyncio.async(coro, loop=self.loop)
future.add_done_callback(callback)
return main.switch()

return _motor_sock_method
return wrapped_method


class AsyncioMotorSocket(asyncio.Protocol):
class AsyncioMotorSocket:
"""A fake socket instance that pauses and resumes the current greenlet.
Pauses the calling greenlet when making blocking calls, and uses the
asyncio event loop to schedule the greenlet for resumption when I/O is ready.
asyncio event loop to schedule the greenlet for resumption when I/O
is ready.
We only implement those socket methods actually used by PyMongo.
"""
def __init__(self, loop, options):
self.loop = loop
self.options = options
self.timeout = None
self.ctx = None
self._transport = None
self._connected_future = asyncio.Future(loop=self.loop)
self._buffer = collections.deque()
self._buffer_len = 0
self._recv_future = asyncio.Future(loop=self.loop)

if options.use_ssl:
# TODO: cache at Pool level.
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
if options.certfile is not None:
ctx.load_cert_chain(options.certfile, options.keyfile)
if options.ca_certs is not None:
ctx.load_verify_locations(options.ca_certs)
if options.cert_reqs is not None:
ctx.verify_mode = options.cert_reqs
if ctx.verify_mode in (ssl.CERT_OPTIONAL, ssl.CERT_REQUIRED):
ctx.check_hostname = True

self.ctx = ctx
self._writer = None
self._reader = None

def settimeout(self, timeout):
self.timeout = timeout

@asyncio_motor_sock_method
@asyncio.coroutine
def connect(self):
protocol_factory = lambda: self

# TODO: will call getaddrinfo again.
host, port = self.options.address
self._transport, protocol = yield from self.loop.create_connection(
protocol_factory, host, port,
ssl=self.ctx)
is_unix_socket = (self.options.family == getattr(socket,
'AF_UNIX', None))
if self.options.use_ssl:
# TODO: cache at Pool level.
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
if self.options.certfile is not None:
ctx.load_cert_chain(self.options.certfile,
self.options.keyfile)
if self.options.ca_certs is not None:
ctx.load_verify_locations(self.options.ca_certs)
if self.options.cert_reqs is not None:
ctx.verify_mode = self.options.cert_reqs
if ctx.verify_mode in (ssl.CERT_OPTIONAL, ssl.CERT_REQUIRED):
ctx.check_hostname = True
else:
ctx = None

if is_unix_socket:
path = self.options.address[0]
reader, writer = yield from asyncio.open_unix_connection(
path, loop=self.loop, ssl=ctx)
else:
host, port = self.options.address
reader, writer = yield from asyncio.open_connection(
host=host, port=port, ssl=ctx, loop=self.loop)
sock = writer.transport.get_extra_info('socket')
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE,
self.options.socket_keepalive)
self._reader, self._writer = reader, writer

def sendall(self, data):
assert greenlet.getcurrent().parent is not None,\
"Should be on child greenlet"

# TODO: backpressure? errors?
self._transport.write(data)
self._writer.write(data)

@asyncio_motor_sock_method
@asyncio.coroutine
def recv(self, num_bytes):
while self._buffer_len < num_bytes:
yield from self._recv_future

data = bytes().join(self._buffer)
rv = data[:num_bytes]
remainder = data[num_bytes:]

self._buffer.clear()
if remainder:
self._buffer.append(remainder)

self._buffer_len = len(remainder)

rv = yield from self._reader.readexactly(num_bytes)
return rv

def close(self):
if self._transport:
self._transport.close()

# Protocol interface.
def connection_made(self, transport):
pass
# self._connected_future.set_result(None)

def data_received(self, data):
self._buffer_len += len(data)
self._buffer.append(data)

# TODO: comment
future = self._recv_future
self._recv_future = asyncio.Future(loop=self.loop)
future.set_result(None)
if self._writer:
self._writer.close()

def connection_lost(self, exc):
pass

# A create_socket() function is part of Motor's framework interface.
create_socket = AsyncioMotorSocket
6 changes: 6 additions & 0 deletions motor/frameworks/tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def is_event_loop(loop):
return isinstance(loop, ioloop.IOLoop)


def check_event_loop(loop):
if not is_event_loop(loop):
raise TypeError(
"io_loop must be instance of IOLoop, not %r" % loop)


def return_value(value):
raise gen.Return(value)

Expand Down
2 changes: 1 addition & 1 deletion motor/motor_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ def __init__(
self.keyfile = keyfile
self.ca_certs = ca_certs
self.cert_reqs = cert_reqs
self.socket_keepalive=socket_keepalive
self.socket_keepalive = socket_keepalive
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def run(self):

setup(name='motor',
version='0.5.dev0',
packages=['motor'],
packages=['motor', 'motor.frameworks'],
description=description,
long_description=long_description,
author='A. Jesse Jiryu Davis',
Expand Down
2 changes: 2 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def setup_package(tornado_warnings):
if not tornado_warnings:
suppress_tornado_warnings()


def is_server_resolvable():
"""Returns True if 'server' is resolvable."""
socket_timeout = socket.getdefaulttimeout()
Expand All @@ -71,6 +72,7 @@ def is_server_resolvable():
finally:
socket.setdefaulttimeout(socket_timeout)


def teardown_package():
if env.auth:
env.sync_cx.admin.remove_user(db_user)
Expand Down
Loading

0 comments on commit cb7a7ca

Please sign in to comment.