Skip to content

Commit

Permalink
Prefer SSL connections by default (MagicStack#660)
Browse files Browse the repository at this point in the history
Switch the default SSL mode from 'disabled' to 'prefer'.  This matches
libpq's behavior and is a sensible thing to do.

Fixes: MagicStack#654
  • Loading branch information
elprans authored Nov 29, 2020
1 parent ddadce9 commit 16183aa
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 37 deletions.
19 changes: 7 additions & 12 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
passfile=passfile)

addrs = []
have_tcp_addrs = False
for h, p in zip(host, port):
if h.startswith('/'):
# UNIX socket name
Expand All @@ -389,6 +390,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
# TCP host/port
addrs.append((h, p))
have_tcp_addrs = True

if not addrs:
raise ValueError(
Expand All @@ -397,6 +399,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None:
ssl = os.getenv('PGSSLMODE')

if ssl is None and have_tcp_addrs:
ssl = 'prefer'

# ssl_is_advisory is only allowed to come from the sslmode parameter.
ssl_is_advisory = None
if isinstance(ssl, str):
Expand Down Expand Up @@ -435,14 +440,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if sslmode <= SSLMODES['require']:
ssl.verify_mode = ssl_module.CERT_NONE
ssl_is_advisory = sslmode <= SSLMODES['prefer']

if ssl:
for addr in addrs:
if isinstance(addr, str):
# UNIX socket
raise exceptions.InterfaceError(
'`ssl` parameter can only be enabled for TCP addresses, '
'got a UNIX socket path: {!r}'.format(addr))
elif ssl is True:
ssl = ssl_module.create_default_context()

if server_settings is not None and (
not isinstance(server_settings, dict) or
Expand Down Expand Up @@ -542,9 +541,6 @@ def connection_lost(self, exc):
async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):

if ssl_context is True:
ssl_context = ssl_module.create_default_context()

tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
ssl_context, ssl_is_advisory),
Expand Down Expand Up @@ -625,7 +621,6 @@ async def _connect_addr(

if isinstance(addr, str):
# UNIX socket
assert not params.ssl
connector = loop.create_unix_connection(proto_factory, addr)
elif params.ssl:
connector = _create_ssl_connection(
Expand Down
26 changes: 25 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,28 @@ async def connect(dsn=None, *,
Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
require an SSL connection. If ``True``, a default SSL context
returned by `ssl.create_default_context() <create_default_context_>`_
will be used.
will be used. The value can also be one of the following strings:
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
if SSL connection fails
- ``'allow'`` - currently equivalent to ``'prefer'``
- ``'require'`` - only try an SSL connection. Certificate
verifiction errors are ignored
- ``'verify-ca'`` - only try an SSL connection, and verify
that the server certificate is issued by a trusted certificate
authority (CA)
- ``'verify-full'`` - only try an SSL connection, verify
that the server certificate is issued by a trusted CA and
that the requested server host name matches that in the
certificate.
The default is ``'prefer'``: try an SSL connection and fallback to
non-SSL connection if that fails.
.. note::
*ssl* is ignored for Unix domain socket communication.
:param dict server_settings:
An optional dict of server runtime parameters. Refer to
Expand Down Expand Up @@ -1926,6 +1947,9 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
.. versionchanged:: 0.22.0
The *ssl* argument now defaults to ``'prefer'``.
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
Expand Down
48 changes: 24 additions & 24 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
'result': ([('host', 123)], {
'user': 'user',
'password': 'passw',
'database': 'testdb'})
'database': 'testdb',
'ssl': True,
'ssl_is_advisory': True})
},

{
Expand Down Expand Up @@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
'user': 'user3',
'password': '123123',
'database': 'abcdef',
'ssl': ssl.SSLContext,
'ssl': True,
'ssl_is_advisory': True})
},

Expand Down Expand Up @@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
'user': 'me',
'password': 'ask',
'database': 'db',
'ssl': ssl.SSLContext,
'ssl': True,
'ssl_is_advisory': False})
},

Expand Down Expand Up @@ -545,6 +547,7 @@ class TestConnectParams(tb.TestCase):
{
'user': 'user',
'database': 'user',
'ssl': None
}
)
},
Expand Down Expand Up @@ -574,7 +577,9 @@ class TestConnectParams(tb.TestCase):
('localhost', 5433)
], {
'user': 'spam',
'database': 'db'
'database': 'db',
'ssl': True,
'ssl_is_advisory': True
}
)
},
Expand Down Expand Up @@ -617,7 +622,7 @@ def run_testcase(self, testcase):
password = testcase.get('password')
passfile = testcase.get('passfile')
database = testcase.get('database')
ssl = testcase.get('ssl')
sslmode = testcase.get('ssl')
server_settings = testcase.get('server_settings')

expected = testcase.get('result')
Expand All @@ -640,21 +645,26 @@ def run_testcase(self, testcase):

addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=ssl,
passfile=passfile, database=database, ssl=sslmode,
connect_timeout=None, server_settings=server_settings)

params = {k: v for k, v in params._asdict().items()
if v is not None}
params = {
k: v for k, v in params._asdict().items()
if v is not None or (expected is not None and k in expected[1])
}

if isinstance(params.get('ssl'), ssl.SSLContext):
params['ssl'] = True

result = (addrs, params)

if expected is not None:
for k, v in expected[1].items():
# If `expected` contains a type, allow that to "match" any
# instance of that type tyat `result` may contain. We need
# this because different SSLContexts don't compare equal.
if isinstance(v, type) and isinstance(result[1].get(k), v):
result[1][k] = v
if 'ssl' not in expected[1]:
# Avoid the hassle of specifying the default SSL mode
# unless explicitly tested for.
params.pop('ssl', None)
params.pop('ssl_is_advisory', None)

self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))

def test_test_connect_params_environ(self):
Expand Down Expand Up @@ -1063,16 +1073,6 @@ async def verify_fails(sslmode):
await verify_fails('verify-ca')
await verify_fails('verify-full')

async def test_connection_ssl_unix(self):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)

with self.assertRaisesRegex(asyncpg.InterfaceError,
'can only be enabled for TCP addresses'):
await self.connect(
host='/tmp',
ssl=ssl_context)

async def test_connection_implicit_host(self):
conn_spec = self.get_connection_spec()
con = await asyncpg.connect(
Expand Down

0 comments on commit 16183aa

Please sign in to comment.