diff --git a/oslo_messaging/_drivers/impl_rabbit.py b/oslo_messaging/_drivers/impl_rabbit.py index 86fed1f4b..c66daf7ed 100644 --- a/oslo_messaging/_drivers/impl_rabbit.py +++ b/oslo_messaging/_drivers/impl_rabbit.py @@ -477,7 +477,6 @@ def __init__(self, conf, url): # max retry-interval = 30 seconds self.interval_max = 30 - self._ssl_params = self._fetch_ssl_params() self._login_method = self.driver_conf.rabbit_login_method if url.virtual_host is not None: @@ -529,7 +528,8 @@ def __init__(self, conf, url): self.channel = None self.connection = kombu.connection.Connection( - self._url, ssl=self._ssl_params, login_method=self._login_method, + self._url, ssl=self._fetch_ssl_params(), + login_method=self._login_method, failover_strategy="shuffle") LOG.info(_LI('Connecting to AMQP server on %(hostname)s:%(port)d'), @@ -581,24 +581,24 @@ def _fetch_ssl_params(self): """Handles fetching what ssl params should be used for the connection (if any). """ - ssl_params = dict() - - # http://docs.python.org/library/ssl.html - ssl.wrap_socket - if self.driver_conf.kombu_ssl_version: - ssl_params['ssl_version'] = self.validate_ssl_version( - self.driver_conf.kombu_ssl_version) - if self.driver_conf.kombu_ssl_keyfile: - ssl_params['keyfile'] = self.driver_conf.kombu_ssl_keyfile - if self.driver_conf.kombu_ssl_certfile: - ssl_params['certfile'] = self.driver_conf.kombu_ssl_certfile - if self.driver_conf.kombu_ssl_ca_certs: - ssl_params['ca_certs'] = self.driver_conf.kombu_ssl_ca_certs - # We might want to allow variations in the - # future with this? - ssl_params['cert_reqs'] = ssl.CERT_REQUIRED - - # Return the extended behavior or just have the default behavior - return ssl_params or None + if self.driver_conf.rabbit_use_ssl: + ssl_params = dict() + + # http://docs.python.org/library/ssl.html - ssl.wrap_socket + if self.driver_conf.kombu_ssl_version: + ssl_params['ssl_version'] = self.validate_ssl_version( + self.driver_conf.kombu_ssl_version) + if self.driver_conf.kombu_ssl_keyfile: + ssl_params['keyfile'] = self.driver_conf.kombu_ssl_keyfile + if self.driver_conf.kombu_ssl_certfile: + ssl_params['certfile'] = self.driver_conf.kombu_ssl_certfile + if self.driver_conf.kombu_ssl_ca_certs: + ssl_params['ca_certs'] = self.driver_conf.kombu_ssl_ca_certs + # We might want to allow variations in the + # future with this? + ssl_params['cert_reqs'] = ssl.CERT_REQUIRED + return ssl_params or True + return False def ensure(self, error_callback, method, retry=None, timeout_is_error=True): diff --git a/oslo_messaging/tests/drivers/test_impl_rabbit.py b/oslo_messaging/tests/drivers/test_impl_rabbit.py index e60bd3b2a..df0f3b3df 100644 --- a/oslo_messaging/tests/drivers/test_impl_rabbit.py +++ b/oslo_messaging/tests/drivers/test_impl_rabbit.py @@ -13,6 +13,7 @@ # under the License. import datetime +import ssl import sys import threading import time @@ -77,6 +78,39 @@ def test_driver_load(self, fake_ensure, fake_reset): self.assertEqual(self.url, url) +class TestRabbitDriverLoadSSL(test_utils.BaseTestCase): + scenarios = [ + ('no_ssl', dict(options=dict(), expected=False)), + ('no_ssl_with_options', dict(options=dict(kombu_ssl_version='TLSv1'), + expected=False)), + ('just_ssl', dict(options=dict(rabbit_use_ssl=True), + expected=True)), + ('ssl_with_options', dict(options=dict(rabbit_use_ssl=True, + kombu_ssl_version='TLSv1', + kombu_ssl_keyfile='foo', + kombu_ssl_certfile='bar', + kombu_ssl_ca_certs='foobar'), + expected=dict(ssl_version=3, + keyfile='foo', + certfile='bar', + ca_certs='foobar', + cert_reqs=ssl.CERT_REQUIRED))), + ] + + @mock.patch('oslo_messaging._drivers.impl_rabbit.Connection.ensure') + @mock.patch('kombu.connection.Connection') + def test_driver_load(self, connection_klass, fake_ensure): + self.config(group="oslo_messaging_rabbit", **self.options) + transport = oslo_messaging.get_transport(self.conf, + 'kombu+memory:////') + self.addCleanup(transport.cleanup) + + transport._driver._get_connection() + connection_klass.assert_called_once_with( + 'memory:///', ssl=self.expected, + login_method='AMQPLAIN', failover_strategy="shuffle") + + class TestRabbitIterconsume(test_utils.BaseTestCase): def test_iterconsume_timeout(self):