Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 7 additions & 26 deletions pymongo/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import collections
import contextlib
import copy
import ipaddress
import os
import platform
import socket
Expand Down Expand Up @@ -61,20 +60,7 @@
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI as _HAVE_SNI
from pymongo.ssl_support import IPADDR_SAFE as _IPADDR_SAFE
from pymongo.ssl_support import SSLError as _SSLError


# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def is_ip_address(address):
try:
ipaddress.ip_address(address)
return True
except (ValueError, UnicodeError): # noqa: B014
return False

from pymongo.ssl_support import HAS_SNI, SSLError

try:
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
Expand Down Expand Up @@ -263,7 +249,7 @@ def _raise_connection_failure(
msg = msg_prefix + msg
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, _SSLError) and "timed out" in str(error):
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
Expand Down Expand Up @@ -924,7 +910,7 @@ def _raise_connection_failure(self, error):
reason = ConnectionClosedReason.ERROR
self.close_socket(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, _SSLError)):
if isinstance(error, (IOError, OSError, SSLError)):
_raise_connection_failure(self.address, error)
else:
raise
Expand Down Expand Up @@ -1024,14 +1010,9 @@ def _configured_socket(address, options):
if ssl_context is not None:
host = address[0]
try:
# According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
# Previous to Python 3.7 wrap_socket would blindly pass
# IP addresses as SNI hostname.
# https://bugs.python.org/issue32185
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if _HAVE_SNI and (not is_ip_address(host) or _IPADDR_SAFE):
if HAS_SNI:
sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
sock = ssl_context.wrap_socket(sock)
Expand All @@ -1040,15 +1021,15 @@ def _configured_socket(address, options):
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (IOError, OSError, _SSLError) as exc: # noqa: B014
except (IOError, OSError, SSLError) as exc: # noqa: B014
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
_raise_connection_failure(address, exc, "SSL handshake failed: ")
if (
ssl_context.verify_mode
and not getattr(ssl_context, "check_hostname", False)
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
Expand Down Expand Up @@ -1336,7 +1317,7 @@ def connect(self):
self.address, conn_id, ConnectionClosedReason.ERROR
)

if isinstance(error, (IOError, OSError, _SSLError)):
if isinstance(error, (IOError, OSError, SSLError)):
_raise_connection_failure(self.address, error)

raise
Expand Down
15 changes: 13 additions & 2 deletions pymongo/pyopenssl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
_REVERSE_VERIFY_MAP = dict((value, key) for key, value in _VERIFY_MAP.items())


# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def _is_ip_address(address):
try:
_ip_address(address)
Expand Down Expand Up @@ -104,8 +106,17 @@ def _call(self, call, *args, **kwargs):
while True:
try:
return call(*args, **kwargs)
except _RETRY_ERRORS:
self.socket_checker.select(self, True, True, timeout)
except _RETRY_ERRORS as exc:
if isinstance(exc, _SSL.WantReadError):
want_read = True
want_write = False
elif isinstance(exc, _SSL.WantWriteError):
want_read = False
want_write = True
else:
want_read = True
want_write = True
self.socket_checker.select(self, want_read, want_write, timeout)
if timeout and _time.monotonic() - start > timeout:
raise _socket.timeout("timed out")
continue
Expand Down
14 changes: 5 additions & 9 deletions pymongo/ssl_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

"""Support for SSL in PyMongo."""

import sys

from pymongo.errors import ConfigurationError

HAVE_SSL = True
Expand All @@ -38,7 +36,7 @@
from ssl import CERT_NONE, CERT_REQUIRED

HAS_SNI = _ssl.HAS_SNI
IPADDR_SAFE = _ssl.IS_PYOPENSSL or sys.version_info[:2] >= (3, 7)
IPADDR_SAFE = True
SSLError = _ssl.SSLError

def get_ssl_context(
Expand All @@ -53,12 +51,10 @@ def get_ssl_context(
"""Create and return an SSLContext object."""
verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED
ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23)
# SSLContext.check_hostname was added in CPython 3.4.
if hasattr(ctx, "check_hostname"):
if verify_mode != CERT_NONE:
ctx.check_hostname = not allow_invalid_hostnames
else:
ctx.check_hostname = False
if verify_mode != CERT_NONE:
ctx.check_hostname = not allow_invalid_hostnames
else:
ctx.check_hostname = False
if hasattr(ctx, "check_ocsp_endpoint"):
ctx.check_ocsp_endpoint = not disable_ocsp_endpoint_check
if hasattr(ctx, "options"):
Expand Down
10 changes: 3 additions & 7 deletions test/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,18 @@ def test_init_kms_tls_options(self):
self.assertEqual(opts._kms_ssl_contexts, {})
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
ctx = opts._kms_ssl_contexts["kmip"]
# On < 3.7 we check hostnames manually.
if sys.version_info[:2] >= (3, 7):
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
ctx = opts._kms_ssl_contexts["aws"]
if sys.version_info[:2] >= (3, 7):
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
opts = AutoEncryptionOpts(
{},
"k.d",
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
)
ctx = opts._kms_ssl_contexts["kmip"]
if sys.version_info[:2] >= (3, 7):
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)


Expand Down
7 changes: 1 addition & 6 deletions test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@
CRL_PEM = os.path.join(CERT_PATH, "crl.pem")
MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client"

_PY37PLUS = sys.version_info[:2] >= (3, 7)

# To fully test this start a mongod instance (built with SSL support) like so:
# mongod --dbpath /path/to/data/directory --sslOnNormalPorts \
# --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \
Expand Down Expand Up @@ -306,10 +304,7 @@ def test_cert_ssl_validation_hostname_matching(self):
ctx = get_ssl_context(None, None, None, None, False, True, False)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, False, False)
if _PY37PLUS or _HAVE_PYOPENSSL:
self.assertTrue(ctx.check_hostname)
else:
self.assertFalse(ctx.check_hostname)
self.assertTrue(ctx.check_hostname)

response = self.client.admin.command(HelloCompat.LEGACY_CMD)

Expand Down