From dee5835761b8d4e9c4e1044c20633b2349a06f21 Mon Sep 17 00:00:00 2001 From: Geert Jansen Date: Mon, 12 Jun 2017 20:23:56 -0400 Subject: [PATCH] core: add file descriptor passing At the transport level, Transport.write() and SslTransport.write() acquired a *handle* argument that can be used to pass a handle. At the endpoint level, both Client.connect() and Server.listen() got an *ipc* argument that can be set to True to enable handle passing. --- lib/gruvi/compat.py | 47 ++++++++++ lib/gruvi/endpoints.py | 128 +++++++++++++--------------- lib/gruvi/ssl.py | 12 ++- lib/gruvi/transports.py | 9 +- tests/test_endpoints.py | 184 +++++++++++++++++++++++++++++++++++++++- 5 files changed, 307 insertions(+), 73 deletions(-) diff --git a/lib/gruvi/compat.py b/lib/gruvi/compat.py index 79e5ae0..fa43137 100644 --- a/lib/gruvi/compat.py +++ b/lib/gruvi/compat.py @@ -8,9 +8,13 @@ from __future__ import absolute_import, print_function +import os import io import sys import threading +import socket +import errno +import pyuv # Some compatibility stuff that is not in six. @@ -71,3 +75,46 @@ def writelines(self, seq): super(TextIOWrapper, self).writelines(seq) if self._write_through: self.flush() + + +# Needed until pyuv accepts PR #249 and #250 + +def pyuv_pipe_helper(handle, handle_args, op, addr): + if not isinstance(handle, pyuv.Pipe): + return False + # Store the 'ipc' constructor argument. + if handle_args and not hasattr(handle, 'ipc'): + handle.ipc = handle_args[0] + if not sys.platform.startswith('linux') or '\x00' not in addr: + return False + # Connect or bind the socket. + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.setblocking(False) + try: + if op == 'connect': + sock.connect(addr) + elif op == 'bind': + sock.bind(addr) + fd = os.dup(sock.fileno()) + except IOError as e: + # Connecting to an AF_UNIX socket never gives EAGAIN on Linux. + assert e.errno != errno.EAGAIN + # Convert from Unix errno -> libuv errno via the symbolic error name + errname = 'UV_{}'.format(errno.errocode.get(e.errno, 'UNKNOWN')) + errnum = getattr(pyuv.errno, errname, pyuv.errno.UV_UNKNOWN) + raise pyuv.error.PipeError(errnum, os.strerror(e.errno)) + finally: + sock.close() + handle.open(fd) + # Work around a bug in pyuv where abstract sockets names are reported as + # bytes by dynamically patching getsockname(). The above PRs should fix this. + if PY3: + self = handle + encoding = sys.getfilesystemencoding() + def getsockname(): + value = type(handle).getsockname(self) + if isinstance(value, bytes): + value = value.decode(encoding) + return value + handle.getsockname = getsockname + return True diff --git a/lib/gruvi/endpoints.py b/lib/gruvi/endpoints.py index 5f02611..970298b 100644 --- a/lib/gruvi/endpoints.py +++ b/lib/gruvi/endpoints.py @@ -9,14 +9,12 @@ from __future__ import absolute_import, print_function import os -import sys import socket import functools import pyuv import six -import errno -from . import logging +from . import logging, compat from .hub import get_hub, switchpoint, switch_back from .sync import Event from .errors import Timeout @@ -28,45 +26,10 @@ __all__ = ['create_connection', 'create_server', 'Endpoint', 'Client', 'Server'] -def _use_af_unix(addr): - """Return whether to open a :class:`pyuv.Pipe` via an AF_UNIX socket.""" - # This is used on Linux only to support abstract sockets. - if isinstance(addr, six.text_type) and u'\x00' not in addr \ - or isinstance(addr, six.binary_type) and b'\x00' not in addr: - return False - return sys.platform in ('linux', 'linux2', 'linux3') - -def _af_unix_helper(handle, address, op): - """Connect or bind a :class:`pyuv.Pipe` to an AF_UNIX socket. - - We use this on Linux to work around the limitation in the libuv API that - socket names cannot have null bytes in them (required for abstract - sockets on Linux). - """ - # Note that on Linux, connect() to an abstract socket never returns EAGAIN. - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.setblocking(False) - try: - if op == 'connect': - sock.connect(address) - elif op == 'bind': - sock.bind(address) - fd = os.dup(sock.fileno()) - except IOError as e: - # Connecting to an AF_UNIX socket never gives EAGAIN on Linux. - assert e.errno != errno.EAGAIN - # Convert from Unix errno -> libuv errno via the symbolic error name - errname = 'UV_{}'.format(errno.errocode.get(e.errno, 'UNKNOWN')) - errnum = getattr(pyuv.errno, errname, pyuv.errno.UV_UNKNOWN) - raise pyuv.error.PipeError(errnum, os.strerror(e.errno)) - finally: - sock.close() - handle.open(fd) - - @switchpoint def create_connection(protocol_factory, address, ssl=False, server_hostname=None, - local_address=None, family=0, flags=0, timeout=None, mode='rw'): + local_address=None, family=0, flags=0, ipc=False, timeout=None, + mode='rw'): """Create a new client connection. This method creates a new :class:`pyuv.Handle`, connects it to *address*, @@ -122,8 +85,10 @@ def create_connection(protocol_factory, address, ssl=False, server_hostname=None """ hub = get_hub() log = logging.get_logger() + handle_args = () if isinstance(address, (six.binary_type, six.text_type)): handle_type = pyuv.Pipe + handle_args = (ipc,) addresses = [address] elif isinstance(address, tuple): handle_type = pyuv.TCP @@ -141,8 +106,8 @@ def create_connection(protocol_factory, address, ssl=False, server_hostname=None raise ValueError("mode: must be either 'r' or 'w' for tty") handle = pyuv.TTY(hub.loop, address, mode == 'r') else: - handle = pyuv.Pipe(hub.loop, True) - handle.open(address) + handle = pyuv.Pipe(hub.loop, ipc) + handle.open(address) addresses = []; error = None elif isinstance(address, pyuv.Stream): handle = address @@ -151,16 +116,15 @@ def create_connection(protocol_factory, address, ssl=False, server_hostname=None raise TypeError('expecting a string, tuple, fd, or pyuv.Stream') for addr in addresses: log.debug('trying address {}', saddr(addr)) - handle = handle_type(hub.loop) + handle = handle_type(hub.loop, *handle_args) + error = None try: - error = None - if handle_type is pyuv.Pipe and _use_af_unix(addr): - _af_unix_helper(handle, addr, 'connect') - else: - with switch_back(timeout) as switcher: - handle.connect(addr, switcher) - result = hub.switch() - _, error = result[0] + if compat.pyuv_pipe_helper(handle, handle_args, 'connect', addr): + break + with switch_back(timeout) as switcher: + handle.connect(addr, switcher) + result = hub.switch() + _, error = result[0] except pyuv.error.UVError as e: error = e[0] except Timeout: @@ -177,7 +141,8 @@ def create_connection(protocol_factory, address, ssl=False, server_hostname=None protocol = protocol_factory() protocol._timeout = timeout if ssl: - context = ssl if hasattr(ssl, 'set_ciphers') else ssl(handle) if callable(ssl) \ + context = ssl if hasattr(ssl, 'set_ciphers') \ + else ssl(handle) if callable(ssl) \ else create_default_context(False) transport = SslTransport(handle, context, False, server_hostname) else: @@ -192,7 +157,7 @@ def create_connection(protocol_factory, address, ssl=False, server_hostname=None @switchpoint def create_server(protocol_factory, address=None, ssl=False, family=0, flags=0, - backlog=128): + ipc=False, backlog=128): """ Create a new network server. @@ -219,8 +184,14 @@ def create_server(protocol_factory, address=None, ssl=False, family=0, flags=0, The *family* and *flags* keyword arguments are used to customize address resolution for TCP handles as described in :func:`socket.getaddrinfo`. - The *backlog* parameter specifies the listen backlog i.e the maximum - number of not yet accepted connections to queue. + The *ipc* parameter indicates whether this server will accept new + connections via file descriptor passing. This works for `pyuv.Pipe` handles + only, and the user is required to call :meth:`Server.accept_connection` + whenever a new connection is pending. + + The *backlog* parameter specifies the listen backlog i.e the maximum number + of not yet accepted active opens to queue. To disable listening for new + connections (useful when *ipc* was set), set the backlog to ``None``. The return value is a :class:`Server` instance. """ @@ -322,13 +293,30 @@ def connections(self): """An iterator yielding the (transport, protocol) pairs for each connection.""" return self._connections.items() + def accept_connection(self, handle, ssl=False): + """Accept a new connection on *handle*. This method needs to be called + when a connection was passed via file descriptor passing.""" + self._on_new_connection(handle, None, ssl) + def _on_new_connection(self, handle, error, ssl): # Callback used with handle.listen(). - assert handle in self._handles + #assert handle in self._handles if error: self._log.warning('error {} in listen() callback', error) return - client = type(handle)(self._hub.loop) + # Pipes can listen for new connections but they can also accept handles + # of different types via file-descriptor passing. + if hasattr(handle, 'pending_handle_type'): + uvtype = handle.pending_handle_type() + handle_type = pyuv.TCP if uvtype == pyuv.UV_TCP \ + else pyuv.UDP if uvtype == pyuv.UV_UDP \ + else pyuv.Pipe + handle_args = (handle.ipc,) if hasattr(handle, 'ipc') \ + and handle_type is pyuv.Pipe else () + else: + handle_type = type(handle) + handle_args = () + client = handle_type(self._hub.loop, *handle_args) handle.accept(client) if self.max_connections is not None and len(self._connections) >= self.max_connections: self._log.warning('max connections reached, dropping new connection') @@ -344,7 +332,8 @@ def handle_connection(self, client, ssl): intended to be called directly. """ if ssl: - context = ssl if hasattr(ssl, 'set_ciphers') else ssl(client) if callable(ssl) \ + context = ssl if hasattr(ssl, 'set_ciphers') \ + else ssl(client) if callable(ssl) \ else create_default_context(True) transport = SslTransport(client, context, True) else: @@ -376,7 +365,7 @@ def connection_lost(self, transport, protocol, exc=None): """Called when a connection is lost.""" @switchpoint - def listen(self, address, ssl=False, family=0, flags=0, backlog=128): + def listen(self, address, ssl=False, family=0, flags=0, ipc=False, backlog=128): """Create a new transport, bind it to *address*, and start listening for new connections. @@ -384,8 +373,10 @@ def listen(self, address, ssl=False, family=0, flags=0, backlog=128): supported keyword arguments. """ handles = [] + handle_args = () if isinstance(address, six.string_types): handle_type = pyuv.Pipe + handle_args = (ipc,) addresses = [address] elif isinstance(address, tuple): handle_type = pyuv.TCP @@ -398,22 +389,23 @@ def listen(self, address, ssl=False, family=0, flags=0, backlog=128): else: raise TypeError('expecting a string, tuple or pyuv.Stream') for addr in addresses: - handle = handle_type(self._hub.loop) + handle = handle_type(self._hub.loop, *handle_args) try: - if handle_type is pyuv.Pipe and _use_af_unix(addr): - _af_unix_helper(handle, addr, 'bind') - else: - handle.bind(addr) + if compat.pyuv_pipe_helper(handle, handle_args, 'bind', addr): + handles.append(handle) + break + handle.bind(addr) except pyuv.error.UVError as e: self._log.warning('bind error {!r}, skipping {}', e[0], saddr(addr)) continue handles.append(handle) addresses = [] for handle in handles: - callback = functools.partial(self._on_new_connection, ssl=ssl) - handle.listen(callback, backlog) - addr = handle.getsockname() - self._log.debug('listen on {}', saddr(addr)) + if backlog is not None: + callback = functools.partial(self._on_new_connection, ssl=ssl) + handle.listen(callback, backlog) + addr = handle.getsockname() + self._log.debug('listen on {}', saddr(addr)) addresses.append(addr) self._handles += handles self._addresses += addresses diff --git a/lib/gruvi/ssl.py b/lib/gruvi/ssl.py index 3467ee5..d2dfe91 100644 --- a/lib/gruvi/ssl.py +++ b/lib/gruvi/ssl.py @@ -274,6 +274,7 @@ def __init__(self, handle, context, server_side, server_hostname=None, self._do_handshake_on_connect = do_handshake_on_connect self._close_on_unwrap = close_on_unwrap self._write_backlog = [] + self._handle_backlog = [] self._ssl_active = Event() def start(self, protocol): @@ -308,11 +309,13 @@ def get_extra_info(self, name, default=None): else: return super(SslTransport, self).get_extra_info(name, default) - def write(self, data): + def write(self, data, handle=None): # Write *data* to the transport. if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError("data: expecting a bytes-like instance, got {!r}" .format(type(data).__name__)) + if handle is not None and not isinstance(self._handle, pyuv.Pipe): + raise ValueError('handle: can only be sent over pyuv.Pipe') if self._error: raise compat.saved_exc(self._error) elif self._closing or self._handle.closed: @@ -320,6 +323,8 @@ def write(self, data): elif len(data) == 0: return self._write_backlog.append([data, 0]) + if handle: + self._handle_backlog.append(handle) self._process_write_backlog() def _process_write_backlog(self): @@ -339,7 +344,10 @@ def _process_write_backlog(self): # Write the ssl data that came out of the SSL pipe to the handle. # Note that flow control is done at the record level data. for chunk in ssldata: - super(SslTransport, self).write(chunk) + if self._handle_backlog: + super(SslTransport, self).write(chunk, self._handle_backlog.pop(0)) + else: + super(SslTransport, self).write(chunk) self._closing = saved if offset < len(data): self._write_backlog[0][1] = offset diff --git a/lib/gruvi/transports.py b/lib/gruvi/transports.py index 3ffe626..a31585b 100644 --- a/lib/gruvi/transports.py +++ b/lib/gruvi/transports.py @@ -284,18 +284,23 @@ def _on_write_complete(self, handle, error): self._maybe_resume_protocol() self._maybe_close() - def write(self, data): + def write(self, data, handle=None): """Write *data* to the transport.""" if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError("data: expecting a bytes-like instance, got {!r}" .format(type(data).__name__)) + if handle is not None and not isinstance(self._handle, pyuv.Pipe): + raise ValueError('handle: can only be sent over pyuv.Pipe') self._check_status() if not self._writable: raise TransportError('transport is not writable') if self._closing: raise TransportError('transport is closing') try: - self._handle.write(data, self._on_write_complete) + if handle: + self._handle.write(data, self._on_write_complete, handle) + else: + self._handle.write(data, self._on_write_complete) except pyuv.error.UVError as e: self._error = TransportError.from_errno(e.args[0]) self.abort() diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 94b3756..93d759c 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -11,14 +11,21 @@ import ssl import socket import unittest +import pyuv from unittest import SkipTest import gruvi +from gruvi.hub import switchpoint, get_hub from gruvi.stream import StreamProtocol from gruvi.endpoints import create_server, create_connection, getaddrinfo +from gruvi.endpoints import Client, Server from gruvi.transports import TransportError +from gruvi.protocols import Protocol +from gruvi.sync import Queue +from gruvi.util import delegate_method +from gruvi.fibers import spawn -from support import UnitTest +from support import UnitTest, socketpair class StreamProtocolNoSslHandshake(StreamProtocol): @@ -176,5 +183,180 @@ def get_resolve_timeout(self): self.assertRaises(gruvi.Timeout, getaddrinfo, 'localhost', timeout=0) +class IpcProtocol(Protocol): + """A very simple line based protocol that is used to test passing file + descriptors.""" + + def __init__(self, server=None, server_context=None): + self._server = server + self._server_context = server_context + self._queue = Queue() + self._transport = None + + @property + def transport(self): + return self._transport + + def connection_made(self, transport): + self._transport = transport + self._buffer = b'' + + def data_received(self, data): + self._buffer += data + p0 = p1 = 0 + while True: + p1 = self._buffer.find(b'\n', p0) + if p1 == -1: + break + command = self._buffer[p0:p1].decode('ascii') + self.command_received(command) + p0 = p1+1 + if p0: + self._buffer = self._buffer[p0:] + + def command_received(self, command): + if self._server is None: + self._queue.put(command) + elif command == 'handle': + handle = self._transport.get_extra_info('handle') + assert handle is not None + self._server.accept_connection(handle) + self.send('ok') + elif command == 'ssl_handle': + handle = self._transport.get_extra_info('handle') + assert handle is not None + self._server.accept_connection(handle, ssl=self._server_context) + self.send('ok') + elif command == 'ping': + self.send('pong') + elif command == 'type': + handle = self._transport.get_extra_info('handle') + assert handle is not None + htype = 'tcp' if isinstance(handle, pyuv.TCP) \ + else 'udp' if isinstance(handle, pyuv.UDP) \ + else 'pipe' + self.send(htype) + + def send(self, command, handle=None): + line = '{}\n'.format(command).encode('ascii') + if handle is not None: + self._transport.write(line, handle) + else: + self._transport.write(line) + + @switchpoint + def call(self, command, handle=None): + self.send(command, handle) + value = self._queue.get() + return value + + +class IpcServer(Server): + + def __init__(self, server_context=None): + def protocol_factory(): + return IpcProtocol(self, server_context) + super(IpcServer, self).__init__(protocol_factory) + + +class IpcClient(Client): + + def __init__(self): + super(IpcClient, self).__init__(IpcProtocol) + + protocol = Client.protocol + delegate_method(protocol, IpcProtocol.send) + delegate_method(protocol, IpcProtocol.call) + + +class TestIpc(UnitTest): + + def test_simple(self): + server = IpcServer() + pipe = self.pipename() + server.listen(pipe) + client = IpcClient() + client.connect(pipe) + self.assertEqual(client.call('ping'), 'pong') + client.close() + server.close() + + def test_pass_handle(self): + server = IpcServer() + pipe = self.pipename() + server.listen(pipe, ipc=True) + client = IpcClient() + client.connect(pipe, ipc=True) + self.assertEqual(client.call('type'), 'pipe') + s1, s2 = socketpair() + h2 = pyuv.Pipe(get_hub().loop) + h2.open(s2.fileno()) + client.call('handle', h2) + c1 = IpcClient() + c1.connect(s1.fileno()) + self.assertEqual(c1.call('ping'), 'pong') + self.assertEqual(c1.call('type'), 'tcp') + s1.close(); s2.close() + c1.close(); h2.close() + client.close(); server.close() + + def test_pass_handle_over_ssl(self): + server = IpcServer() + pipe = self.pipename() + server.listen(pipe, ipc=True, **self.ssl_s_args) + client = IpcClient() + client.connect(pipe, ipc=True, **self.ssl_cp_args) + self.assertEqual(client.call('type'), 'pipe') + s1, s2 = socketpair() + h2 = pyuv.Pipe(get_hub().loop) + h2.open(s2.fileno()) + client.call('handle', h2) + c1 = IpcClient() + c1.connect(s1.fileno()) + self.assertEqual(c1.call('ping'), 'pong') + self.assertEqual(c1.call('type'), 'tcp') + s1.close(); s2.close() + c1.close(); h2.close() + client.close(); server.close() + + def test_pass_ssl_handle(self): + server = IpcServer(self.ssl_s_args['ssl']) + pipe = self.pipename() + server.listen(pipe, ipc=True) + client = IpcClient() + client.connect(pipe, ipc=True) + self.assertEqual(client.call('type'), 'pipe') + s1, s2 = socketpair() + h2 = pyuv.Pipe(get_hub().loop) + h2.open(s2.fileno()) + client.call('ssl_handle', h2) + c1 = IpcClient() + c1.connect(s1.fileno(), **self.ssl_cp_args) + self.assertEqual(c1.call('ping'), 'pong') + self.assertEqual(c1.call('type'), 'tcp') + s1.close(); s2.close() + c1.close(); h2.close() + client.close(); server.close() + + def test_pass_ssl_handle_over_ssl(self): + server = IpcServer(self.ssl_s_args['ssl']) + pipe = self.pipename() + server.listen(pipe, ipc=True, **self.ssl_s_args) + client = IpcClient() + client.connect(pipe, ipc=True, **self.ssl_cp_args) + self.assertEqual(client.call('type'), 'pipe') + s1, s2 = socketpair() + h2 = pyuv.Pipe(get_hub().loop) + h2.open(s2.fileno()) + client.call('ssl_handle', h2) + c1 = IpcClient() + c1.connect(s1.fileno(), **self.ssl_cp_args) + self.assertEqual(c1.call('ping'), 'pong') + self.assertEqual(c1.call('type'), 'tcp') + s1.close(); s2.close() + c1.close(); h2.close() + client.close(); server.close() + + if __name__ == '__main__': unittest.main()