Skip to content

Commit

Permalink
Support Unix domain paths
Browse files Browse the repository at this point in the history
  • Loading branch information
ajdavis committed Nov 2, 2018
1 parent 3a8bac2 commit 18ac49b
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Changelog
Next Release
------------

Support for Unix domain paths with ``uds_path`` parameter.

The ``interactive_server()`` function now prepares the server to autorespond to
the ``getFreeMonitoringStatus`` command from the mongo shell.

Expand Down
94 changes: 81 additions & 13 deletions mockupdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
except ImportError:
from cStringIO import StringIO

try:
from urllib.parse import quote_plus
except ImportError:
# Python 2
from urllib import quote_plus

# Pure-Python bson lib vendored in from PyMongo 3.0.3.
from mockupdb import _bson
import mockupdb._bson.codec_options as _codec_options
Expand Down Expand Up @@ -432,7 +438,12 @@ def request_id(self):
@property
def client_port(self):
"""Client connection's TCP port."""
return self._client.getpeername()[1]
address = self._client.getpeername()
if isinstance(address, tuple):
return address[1]

# Maybe a Unix domain socket connection.
return 0

@property
def server(self):
Expand Down Expand Up @@ -1194,11 +1205,24 @@ class MockupDB(object):
if `auto_ismaster` is True, default 0.
- `max_wire_version`: the maxWireVersion to include in ismaster responses
if `auto_ismaster` is True, default 6.
- `uds_path`: a Unix domain socket path. MockupDB will attempt to delete
the path if it already exists.
"""
def __init__(self, port=None, verbose=False,
request_timeout=10, auto_ismaster=None,
ssl=False, min_wire_version=0, max_wire_version=6):
self._address = ('localhost', port)
ssl=False, min_wire_version=0, max_wire_version=6,
uds_path=None):
if port is not None and uds_path is not None:
raise TypeError(
("You can't pass port=%s and uds_path=%s,"
" pass only one or neither") % (port, uds_path))

self._uds_path = uds_path
if uds_path:
self._address = (uds_path, 0)
else:
self._address = ('localhost', port)

self._verbose = verbose
self._label = None
self._ssl = ssl
Expand Down Expand Up @@ -1231,8 +1255,12 @@ def __init__(self, port=None, verbose=False,

@_synchronized
def run(self):
"""Begin serving. Returns the bound port."""
self._listening_sock, self._address = bind_socket(self._address)
"""Begin serving. Returns the bound port, or 0 for domain socket."""
self._listening_sock, self._address = (
bind_domain_socket(self._address)
if self._uds_path
else bind_tcp_socket(self._address))

if self._ssl:
certfile = os.path.join(os.path.dirname(__file__), 'server.pem')
self._listening_sock = _ssl.wrap_socket(
Expand Down Expand Up @@ -1266,6 +1294,12 @@ def stop(self):
for thread in threads:
thread.join(10)

if self._uds_path:
try:
os.unlink(self._uds_path)
except OSError:
pass

def receives(self, *args, **kwargs):
"""Pop the next `Request` and assert it matches.
Expand Down Expand Up @@ -1481,7 +1515,7 @@ def address(self):
@property
def address_string(self):
"""The listening "host:port"."""
return '%s:%d' % self._address
return format_addr(self._address)

@property
def host(self):
Expand All @@ -1496,8 +1530,10 @@ def port(self):
@property
def uri(self):
"""Connection string to pass to `~pymongo.mongo_client.MongoClient`."""
assert self.host and self.port
uri = 'mongodb://%s:%s' % self._address
if self._uds_path:
uri = 'mongodb://%s' % (quote_plus(self._uds_path),)
else:
uri = 'mongodb://%s' % (format_addr(self._address),)
return uri + '/?ssl=true' if self._ssl else uri

@property
Expand Down Expand Up @@ -1555,7 +1591,7 @@ def _accept_loop(self):
if select.select([self._listening_sock.fileno()], [], [], 1):
client, client_addr = self._listening_sock.accept()
client.setblocking(True)
self._log('connection from %s:%s' % client_addr)
self._log('connection from %s' % format_addr(client_addr))
server_thread = threading.Thread(
target=functools.partial(
self._server_loop, client, client_addr))
Expand Down Expand Up @@ -1611,7 +1647,7 @@ def _server_loop(self, client, client_addr):
traceback.print_exc()
break

self._log('disconnected: %s:%d' % client_addr)
self._log('disconnected: %s' % format_addr(client_addr))
client.close()

def _log(self, msg):
Expand Down Expand Up @@ -1642,10 +1678,24 @@ def next(self):
__next__ = next

def __repr__(self):
if self._uds_path:
return 'MockupDB(uds_path=%s)' % (self._uds_path,)

return 'MockupDB(%s, %s)' % self._address


def bind_socket(address):
def format_addr(address):
"""Turn a TCP or Unix domain socket address into a string."""
if isinstance(address, tuple):
if address[1]:
return '%s:%d' % address
else:
return address[0]

return address


def bind_tcp_socket(address):
"""Takes (host, port) and returns (socket_object, (host, port)).
If the passed-in port is None, bind an unused port and return it.
Expand All @@ -1669,6 +1719,20 @@ def bind_socket(address):
raise socket.error('could not bind socket')


def bind_domain_socket(address):
"""Takes (socket path, 0) and returns (socket_object, (path, 0))."""
path, _ = address
try:
os.unlink(path)
except OSError:
pass

sock = socket.socket(socket.AF_UNIX)
sock.bind(path)
sock.listen(128)
return sock, (path, 0)


OPCODES = {OP_MSG: OpMsg,
OP_QUERY: OpQuery,
OP_INSERT: OpInsert,
Expand Down Expand Up @@ -1923,7 +1987,7 @@ def raise_args_err(message='bad arguments', error_class=TypeError):


def interactive_server(port=27017, verbose=True, all_ok=False, name='MockupDB',
ssl=False):
ssl=False, uds_path=None):
"""A `MockupDB` that the mongo shell can connect to.
Call `~.MockupDB.run` on the returned server, and clean it up with
Expand All @@ -1932,11 +1996,15 @@ def interactive_server(port=27017, verbose=True, all_ok=False, name='MockupDB',
If ``all_ok`` is True, replies {ok: 1} to anything unmatched by a specific
responder.
"""
if uds_path is not None:
port = None

server = MockupDB(port=port,
verbose=verbose,
request_timeout=int(1e6),
ssl=ssl,
auto_ismaster=True)
auto_ismaster=True,
uds_path=uds_path)
if all_ok:
server.autoresponds({})
server.autoresponds('whatsmyuri', you='localhost:12345')
Expand Down
24 changes: 24 additions & 0 deletions tests/test_mockupdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"""Test MockupDB."""

import contextlib
import os
import ssl
import sys
import tempfile

if sys.version_info[0] < 3:
from io import BytesIO as StringIO
Expand Down Expand Up @@ -318,6 +320,28 @@ def test_wire_version(self):
self.assertEqual(ismaster['minWireVersion'], 1)
self.assertEqual(ismaster['maxWireVersion'], 42)

@unittest.skipIf(sys.platform == 'win32', 'Windows')
def test_unix_domain_socket(self):
tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.sock')
tmp.close()
server = MockupDB(auto_ismaster={'maxWireVersion': 3},
uds_path=tmp.name)
server.run()
self.assertTrue(server.uri.endswith('.sock'),
'Expected URI "%s" to end with ".sock"' % (server.uri,))
self.assertEqual(server.host, tmp.name)
self.assertEqual(server.port, 0)
self.assertEqual(server.address, (tmp.name, 0))
self.assertEqual(server.address_string, tmp.name)
client = MongoClient(server.uri)
with going(client.test.command, {'foo': 1}) as future:
server.receives().ok()

response = future()
self.assertEqual(1, response['ok'])
server.stop()
self.assertFalse(os.path.exists(tmp.name))


class TestResponse(unittest.TestCase):
def test_ok(self):
Expand Down

0 comments on commit 18ac49b

Please sign in to comment.