Skip to content

Commit

Permalink
Add missing load_default_certs() call.
Browse files Browse the repository at this point in the history
Fixes: celery#349
  • Loading branch information
moisesguimaraes committed Jan 26, 2021
1 parent 32d2b7e commit cbc53f5
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 56 deletions.
8 changes: 8 additions & 0 deletions amqp/transport.py
Expand Up @@ -535,6 +535,14 @@ def _wrap_socket_sni(self, sock, keyfile=None, certfile=None,
except AttributeError:
pass # ask forgiveness not permission

if ca_certs is None and context.verify_mode != ssl.CERT_NONE:
purpose = (
ssl.Purpose.CLIENT_AUTH
if server_side
else ssl.Purpose.SERVER_AUTH
)
context.load_default_certs(purpose)

sock = context.wrap_socket(**opts)
return sock

Expand Down
145 changes: 89 additions & 56 deletions t/unit/test_transport.py
@@ -1,6 +1,7 @@
import errno
import os
import re
import ssl
import socket
import struct
from struct import pack
Expand Down Expand Up @@ -639,112 +640,144 @@ def test_wrap_context(self):

def test_wrap_socket_sni(self):
# testing default values of _wrap_socket_sni()
sock = Mock()
with patch('ssl.SSLContext') as mock_ssl_context_class:
wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET
sock = Mock()
context = mock_ssl_context_class()
context.wrap_socket.return_value = sentinel.WRAPPED_SOCKET
ret = self.t._wrap_socket_sni(sock)

mock_ssl_context_class.load_cert_chain.assert_not_called()
mock_ssl_context_class.load_verify_locations.assert_not_called()
mock_ssl_context_class.set_ciphers.assert_not_called()
mock_ssl_context_class.verify_mode.assert_not_called()
wrap_socket_method_mock.assert_called_with(
sock=sock,
server_side=False,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
server_hostname=None
)
assert ret == sentinel.WRAPPED_SOCKET
context.load_cert_chain.assert_not_called()
context.load_verify_locations.assert_not_called()
context.set_ciphers.assert_not_called()
context.verify_mode.assert_not_called()

context.load_default_certs.assert_called_with(
ssl.Purpose.SERVER_AUTH
)
context.wrap_socket.assert_called_with(
sock=sock,
server_side=False,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
server_hostname=None
)
assert ret == sentinel.WRAPPED_SOCKET

def test_wrap_socket_sni_certfile(self):
# testing _wrap_socket_sni() with parameters certfile and keyfile
with patch('ssl.SSLContext') as mock_ssl_context_class:
load_cert_chain_method_mock = \
mock_ssl_context_class().load_cert_chain
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(
Mock(), keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE
sock, keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE
)

load_cert_chain_method_mock.assert_called_with(
sentinel.CERTFILE, sentinel.KEYFILE
)
context.load_default_certs.assert_called_with(
ssl.Purpose.SERVER_AUTH
)
context.load_cert_chain.assert_called_with(
sentinel.CERTFILE, sentinel.KEYFILE
)

def test_wrap_socket_ca_certs(self):
# testing _wrap_socket_sni() with parameter ca_certs
with patch('ssl.SSLContext') as mock_ssl_context_class:
load_verify_locations_method_mock = \
mock_ssl_context_class().load_verify_locations
self.t._wrap_socket_sni(Mock(), ca_certs=sentinel.CA_CERTS)
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(sock, ca_certs=sentinel.CA_CERTS)

load_verify_locations_method_mock.assert_called_with(sentinel.CA_CERTS)
context.load_default_certs.assert_not_called()
context.load_verify_locations.assert_called_with(sentinel.CA_CERTS)

def test_wrap_socket_ciphers(self):
# testing _wrap_socket_sni() with parameter ciphers
with patch('ssl.SSLContext') as mock_ssl_context_class:
set_ciphers_method_mock = mock_ssl_context_class().set_ciphers
self.t._wrap_socket_sni(Mock(), ciphers=sentinel.CIPHERS)
sock = Mock()
context = mock_ssl_context_class()
set_ciphers_method_mock = context.set_ciphers
self.t._wrap_socket_sni(sock, ciphers=sentinel.CIPHERS)

set_ciphers_method_mock.assert_called_with(sentinel.CIPHERS)
set_ciphers_method_mock.assert_called_with(sentinel.CIPHERS)

def test_wrap_socket_sni_cert_reqs(self):
# testing _wrap_socket_sni() with parameter cert_reqs
# testing _wrap_socket_sni() with parameter cert_reqs == ssl.CERT_NONE
with patch('ssl.SSLContext') as mock_ssl_context_class:
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(sock, cert_reqs=ssl.CERT_NONE)

context.load_default_certs.assert_not_called()
assert context.verify_mode == ssl.CERT_NONE

# testing _wrap_socket_sni() with parameter cert_reqs != ssl.CERT_NONE
with patch('ssl.SSLContext') as mock_ssl_context_class:
self.t._wrap_socket_sni(Mock(), cert_reqs=sentinel.CERT_REQS)
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(sock, cert_reqs=sentinel.CERT_REQS)

assert mock_ssl_context_class().verify_mode == sentinel.CERT_REQS
context.load_default_certs.assert_called_with(
ssl.Purpose.SERVER_AUTH
)
assert context.verify_mode == sentinel.CERT_REQS

def test_wrap_socket_sni_setting_sni_header(self):
# testing _wrap_socket_sni() without parameter server_hostname

# SSL module supports SNI
with patch('ssl.SSLContext') as mock_ssl_context_class, \
patch('ssl.HAS_SNI', new=True):
self.t._wrap_socket_sni(Mock())
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(sock)

assert mock_ssl_context_class().check_hostname is False
assert context.check_hostname is False

# SSL module does not support SNI
with patch('ssl.SSLContext') as mock_ssl_context_class, \
patch('ssl.HAS_SNI', new=False):
self.t._wrap_socket_sni(Mock())
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(sock)

assert mock_ssl_context_class().check_hostname is False
assert context.check_hostname is False

# testing _wrap_socket_sni() with parameter server_hostname
sock = Mock()

# SSL module supports SNI
with patch('ssl.SSLContext') as mock_ssl_context_class, \
patch('ssl.HAS_SNI', new=True):
# SSL module supports SNI
wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(
sock, server_hostname=sentinel.SERVER_HOSTNAME
)

wrap_socket_method_mock.assert_called_with(
sock=sock,
server_side=False,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
server_hostname=sentinel.SERVER_HOSTNAME
)
assert mock_ssl_context_class().check_hostname is True
context.wrap_socket.assert_called_with(
sock=sock,
server_side=False,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
server_hostname=sentinel.SERVER_HOSTNAME
)
assert context.check_hostname is True

# SSL module does not support SNI
with patch('ssl.SSLContext') as mock_ssl_context_class, \
patch('ssl.HAS_SNI', new=False):
# SSL module does not support SNI
wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
sock = Mock()
context = mock_ssl_context_class()
self.t._wrap_socket_sni(
sock, server_hostname=sentinel.SERVER_HOSTNAME
)
wrap_socket_method_mock.assert_called_with(
sock=sock,
server_side=False,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
server_hostname=sentinel.SERVER_HOSTNAME
)
assert mock_ssl_context_class().check_hostname is False

context.wrap_socket.assert_called_with(
sock=sock,
server_side=False,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
server_hostname=sentinel.SERVER_HOSTNAME
)
assert context.check_hostname is False

def test_shutdown_transport(self):
self.t.sock = None
Expand Down

0 comments on commit cbc53f5

Please sign in to comment.