diff --git a/pymongo/pool.py b/pymongo/pool.py index e2f9698212..d68ba238f2 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -15,7 +15,6 @@ import collections import contextlib import copy -import ipaddress import os import platform import socket @@ -61,20 +60,7 @@ from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import HAS_SNI as _HAVE_SNI -from pymongo.ssl_support import IPADDR_SAFE as _IPADDR_SAFE -from pymongo.ssl_support import SSLError as _SSLError - - -# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are -# not permitted for SNI hostname. -def is_ip_address(address): - try: - ipaddress.ip_address(address) - return True - except (ValueError, UnicodeError): # noqa: B014 - return False - +from pymongo.ssl_support import HAS_SNI, SSLError try: from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl @@ -263,7 +249,7 @@ def _raise_connection_failure( msg = msg_prefix + msg if isinstance(error, socket.timeout): raise NetworkTimeout(msg) from error - elif isinstance(error, _SSLError) and "timed out" in str(error): + elif isinstance(error, SSLError) and "timed out" in str(error): # Eventlet does not distinguish TLS network timeouts from other # SSLErrors (https://github.com/eventlet/eventlet/issues/692). # Luckily, we can work around this limitation because the phrase @@ -924,7 +910,7 @@ def _raise_connection_failure(self, error): reason = ConnectionClosedReason.ERROR self.close_socket(reason) # SSLError from PyOpenSSL inherits directly from Exception. - if isinstance(error, (IOError, OSError, _SSLError)): + if isinstance(error, (IOError, OSError, SSLError)): _raise_connection_failure(self.address, error) else: raise @@ -1024,14 +1010,9 @@ def _configured_socket(address, options): if ssl_context is not None: host = address[0] try: - # According to RFC6066, section 3, IPv4 and IPv6 literals are - # not permitted for SNI hostname. - # Previous to Python 3.7 wrap_socket would blindly pass - # IP addresses as SNI hostname. - # https://bugs.python.org/issue32185 # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. - if _HAVE_SNI and (not is_ip_address(host) or _IPADDR_SAFE): + if HAS_SNI: sock = ssl_context.wrap_socket(sock, server_hostname=host) else: sock = ssl_context.wrap_socket(sock) @@ -1040,7 +1021,7 @@ def _configured_socket(address, options): # Raise _CertificateError directly like we do after match_hostname # below. raise - except (IOError, OSError, _SSLError) as exc: # noqa: B014 + except (IOError, OSError, SSLError) as exc: # noqa: B014 sock.close() # We raise AutoReconnect for transient and permanent SSL handshake # failures alike. Permanent handshake failures, like protocol @@ -1048,7 +1029,7 @@ def _configured_socket(address, options): _raise_connection_failure(address, exc, "SSL handshake failed: ") if ( ssl_context.verify_mode - and not getattr(ssl_context, "check_hostname", False) + and not ssl_context.check_hostname and not options.tls_allow_invalid_hostnames ): try: @@ -1336,7 +1317,7 @@ def connect(self): self.address, conn_id, ConnectionClosedReason.ERROR ) - if isinstance(error, (IOError, OSError, _SSLError)): + if isinstance(error, (IOError, OSError, SSLError)): _raise_connection_failure(self.address, error) raise diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 3736a4f381..1a57ff4f2b 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -70,6 +70,8 @@ _REVERSE_VERIFY_MAP = dict((value, key) for key, value in _VERIFY_MAP.items()) +# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are +# not permitted for SNI hostname. def _is_ip_address(address): try: _ip_address(address) @@ -104,8 +106,17 @@ def _call(self, call, *args, **kwargs): while True: try: return call(*args, **kwargs) - except _RETRY_ERRORS: - self.socket_checker.select(self, True, True, timeout) + except _RETRY_ERRORS as exc: + if isinstance(exc, _SSL.WantReadError): + want_read = True + want_write = False + elif isinstance(exc, _SSL.WantWriteError): + want_read = False + want_write = True + else: + want_read = True + want_write = True + self.socket_checker.select(self, want_read, want_write, timeout) if timeout and _time.monotonic() - start > timeout: raise _socket.timeout("timed out") continue diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 6adf629ad3..d1381ce0e4 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -14,8 +14,6 @@ """Support for SSL in PyMongo.""" -import sys - from pymongo.errors import ConfigurationError HAVE_SSL = True @@ -38,7 +36,7 @@ from ssl import CERT_NONE, CERT_REQUIRED HAS_SNI = _ssl.HAS_SNI - IPADDR_SAFE = _ssl.IS_PYOPENSSL or sys.version_info[:2] >= (3, 7) + IPADDR_SAFE = True SSLError = _ssl.SSLError def get_ssl_context( @@ -53,12 +51,10 @@ def get_ssl_context( """Create and return an SSLContext object.""" verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23) - # SSLContext.check_hostname was added in CPython 3.4. - if hasattr(ctx, "check_hostname"): - if verify_mode != CERT_NONE: - ctx.check_hostname = not allow_invalid_hostnames - else: - ctx.check_hostname = False + if verify_mode != CERT_NONE: + ctx.check_hostname = not allow_invalid_hostnames + else: + ctx.check_hostname = False if hasattr(ctx, "check_ocsp_endpoint"): ctx.check_ocsp_endpoint = not disable_ocsp_endpoint_check if hasattr(ctx, "options"): diff --git a/test/test_encryption.py b/test/test_encryption.py index 366c406b03..c0d278d577 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -145,13 +145,10 @@ def test_init_kms_tls_options(self): self.assertEqual(opts._kms_ssl_contexts, {}) opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}}) ctx = opts._kms_ssl_contexts["kmip"] - # On < 3.7 we check hostnames manually. - if sys.version_info[:2] >= (3, 7): - self.assertEqual(ctx.check_hostname, True) + self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) ctx = opts._kms_ssl_contexts["aws"] - if sys.version_info[:2] >= (3, 7): - self.assertEqual(ctx.check_hostname, True) + self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) opts = AutoEncryptionOpts( {}, @@ -159,8 +156,7 @@ def test_init_kms_tls_options(self): kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}, ) ctx = opts._kms_ssl_contexts["kmip"] - if sys.version_info[:2] >= (3, 7): - self.assertEqual(ctx.check_hostname, True) + self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) diff --git a/test/test_ssl.py b/test/test_ssl.py index 0c45275fac..9b58c2251b 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -65,8 +65,6 @@ CRL_PEM = os.path.join(CERT_PATH, "crl.pem") MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client" -_PY37PLUS = sys.version_info[:2] >= (3, 7) - # To fully test this start a mongod instance (built with SSL support) like so: # mongod --dbpath /path/to/data/directory --sslOnNormalPorts \ # --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \ @@ -306,10 +304,7 @@ def test_cert_ssl_validation_hostname_matching(self): ctx = get_ssl_context(None, None, None, None, False, True, False) self.assertFalse(ctx.check_hostname) ctx = get_ssl_context(None, None, None, None, False, False, False) - if _PY37PLUS or _HAVE_PYOPENSSL: - self.assertTrue(ctx.check_hostname) - else: - self.assertFalse(ctx.check_hostname) + self.assertTrue(ctx.check_hostname) response = self.client.admin.command(HelloCompat.LEGACY_CMD)