View
@@ -3,13 +3,15 @@
import pytest
import socket
import warnings
from case import ContextMock, Mock, call
from amqp import Connection
from amqp import spec
from amqp.connection import SSLError
from amqp.exceptions import ConnectionError, NotFound, ResourceError
from amqp.five import items
from amqp.sasl import SASL, AMQPLAIN, PLAIN
from amqp.transport import TCPTransport
@@ -22,16 +24,35 @@ def setup_conn(self):
self.conn = Connection(
frame_handler=self.frame_handler,
frame_writer=self.frame_writer,
authentication=AMQPLAIN('foo', 'bar'),
)
self.conn.Channel = Mock(name='Channel')
self.conn.Transport = Mock(name='Transport')
self.conn.transport = self.conn.Transport.return_value
self.conn.send_method = Mock(name='send_method')
self.conn.frame_writer = Mock(name='frame_writer')
def test_login_response(self):
self.conn = Connection(login_response='foo')
assert self.conn.login_response == 'foo'
def test_sasl_authentication(self):
authentication = SASL()
self.conn = Connection(authentication=authentication)
assert self.conn.authentication == (authentication,)
def test_sasl_authentication_iterable(self):
authentication = SASL()
self.conn = Connection(authentication=(authentication,))
assert self.conn.authentication == (authentication,)
def test_amqplain(self):
self.conn = Connection(userid='foo', password='bar')
assert isinstance(self.conn.authentication[1], AMQPLAIN)
assert self.conn.authentication[1].username == 'foo'
assert self.conn.authentication[1].password == 'bar'
def test_plain(self):
self.conn = Connection(userid='foo', password='bar')
assert isinstance(self.conn.authentication[2], PLAIN)
assert self.conn.authentication[2].username == 'foo'
assert self.conn.authentication[2].password == 'bar'
def test_enter_exit(self):
self.conn.connect = Mock(name='connect')
@@ -68,47 +89,80 @@ def test_connect__already_connected(self):
callback.assert_called_with()
def test_on_start(self):
self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z', 'en_US en_GB')
self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN',
'en_US en_GB')
assert self.conn.version_major == 3
assert self.conn.version_minor == 4
assert self.conn.server_properties == {'foo': 'bar'}
assert self.conn.mechanisms == ['x', 'y', 'z']
assert self.conn.mechanisms == [b'x', b'y', b'z',
b'AMQPLAIN', b'PLAIN']
assert self.conn.locales == ['en_US', 'en_GB']
self.conn.send_method.assert_called_with(
spec.Connection.StartOk, 'FsSs', (
self.conn.client_properties, self.conn.login_method,
self.conn.login_response, self.conn.locale,
self.conn.client_properties, b'AMQPLAIN',
self.conn.authentication[0].start(self.conn), self.conn.locale,
),
)
def test_missing_credentials(self):
with pytest.raises(ValueError):
self.conn = Connection(userid=None, password=None)
with pytest.raises(ValueError):
self.conn = Connection(password=None)
def test_mechanism_mismatch(self):
with pytest.raises(ConnectionError):
self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z',
'en_US en_GB')
def test_login_method_response(self):
# An old way of doing things.:
login_method, login_response = b'foo', b'bar'
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.conn = Connection(login_method=login_method,
login_response=login_response)
self.conn.send_method = Mock(name='send_method')
self.conn._on_start(3, 4, {'foo': 'bar'}, login_method,
'en_US en_GB')
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
self.conn.send_method.assert_called_with(
spec.Connection.StartOk, 'FsSs', (
self.conn.client_properties, login_method,
login_response, self.conn.locale,
),
)
def test_on_start__consumer_cancel_notify(self):
self.conn._on_start(
3, 4, {'capabilities': {'consumer_cancel_notify': 1}},
'', '',
b'AMQPLAIN', '',
)
cap = self.conn.client_properties['capabilities']
assert cap['consumer_cancel_notify']
def test_on_start__connection_blocked(self):
self.conn._on_start(
3, 4, {'capabilities': {'connection.blocked': 1}},
'', '',
b'AMQPLAIN', '',
)
cap = self.conn.client_properties['capabilities']
assert cap['connection.blocked']
def test_on_start__authentication_failure_close(self):
self.conn._on_start(
3, 4, {'capabilities': {'authentication_failure_close': 1}},
'', '',
b'AMQPLAIN', '',
)
cap = self.conn.client_properties['capabilities']
assert cap['authentication_failure_close']
def test_on_start__authentication_failure_close__disabled(self):
self.conn._on_start(
3, 4, {'capabilities': {}},
'', '',
b'AMQPLAIN', '',
)
assert 'capabilities' not in self.conn.client_properties
View
@@ -0,0 +1,149 @@
from __future__ import absolute_import, unicode_literals
import contextlib
import socket
from io import BytesIO
from case import Mock, patch, call
import pytest
import sys
from amqp import sasl
from amqp.serialization import _write_table
class test_SASL:
def test_sasl_notimplemented(self):
mech = sasl.SASL()
with pytest.raises(NotImplementedError):
mech.mechanism
with pytest.raises(NotImplementedError):
mech.start(None)
def test_plain(self):
username, password = 'foo', 'bar'
mech = sasl.PLAIN(username, password)
response = mech.start(None)
assert isinstance(response, bytes)
assert response.split(b'\0') == \
[b'', username.encode('utf-8'), password.encode('utf-8')]
def test_amqplain(self):
username, password = 'foo', 'bar'
mech = sasl.AMQPLAIN(username, password)
response = mech.start(None)
assert isinstance(response, bytes)
login_response = BytesIO()
_write_table({b'LOGIN': username, b'PASSWORD': password},
login_response.write, [])
expected_response = login_response.getvalue()[4:]
assert response == expected_response
def test_gssapi_missing(self):
gssapi = sys.modules.pop('gssapi', None)
GSSAPI = sasl._get_gssapi_mechanism()
with pytest.raises(NotImplementedError):
GSSAPI()
if gssapi is not None:
sys.modules['gssapi'] = gssapi
@contextlib.contextmanager
def fake_gssapi(self):
orig_gssapi = sys.modules.pop('gssapi', None)
orig_gssapi_raw = sys.modules.pop('gssapi.raw', None)
orig_gssapi_raw_misc = sys.modules.pop('gssapi.raw.misc', None)
gssapi = sys.modules['gssapi'] = Mock()
sys.modules['gssapi.raw'] = gssapi.raw
sys.modules['gssapi.raw.misc'] = gssapi.raw.misc
class GSSError(Exception):
pass
gssapi.raw.misc.GSSError = GSSError
try:
yield gssapi
finally:
if orig_gssapi is None:
del sys.modules['gssapi']
else:
sys.modules['gssapi'] = orig_gssapi
if orig_gssapi_raw is None:
del sys.modules['gssapi.raw']
else:
sys.modules['gssapi.raw'] = orig_gssapi_raw
if orig_gssapi_raw_misc is None:
del sys.modules['gssapi.raw.misc']
else:
sys.modules['gssapi.raw.misc'] = orig_gssapi_raw_misc
@patch('socket.gethostbyaddr')
def test_gssapi_rdns(self, gethostbyaddr):
with self.fake_gssapi() as gssapi:
connection = Mock()
connection.transport.sock.getpeername.return_value = ('192.0.2.0',
5672)
connection.transport.sock.family = socket.AF_INET
gethostbyaddr.return_value = ('broker.example.org', (), ())
GSSAPI = sasl._get_gssapi_mechanism()
mech = GSSAPI(rdns=True)
mech.start(connection)
connection.transport.sock.getpeername.assert_called()
gethostbyaddr.assert_called_with('192.0.2.0')
gssapi.Name.assert_called_with(b'amqp@broker.example.org',
gssapi.NameType.hostbased_service)
def test_gssapi_no_rdns(self):
with self.fake_gssapi() as gssapi:
connection = Mock()
connection.transport.host = 'broker.example.org'
GSSAPI = sasl._get_gssapi_mechanism()
mech = GSSAPI()
mech.start(connection)
gssapi.Name.assert_called_with(b'amqp@broker.example.org',
gssapi.NameType.hostbased_service)
def test_gssapi_step_without_client_name(self):
with self.fake_gssapi() as gssapi:
context = Mock()
context.step.return_value = b'secrets'
name = Mock()
gssapi.SecurityContext.return_value = context
gssapi.Name.return_value = name
connection = Mock()
connection.transport.host = 'broker.example.org'
GSSAPI = sasl._get_gssapi_mechanism()
mech = GSSAPI()
response = mech.start(connection)
gssapi.SecurityContext.assert_called_with(name=name, creds=None)
context.step.assert_called_with(None)
assert response == b'secrets'
def test_gssapi_step_with_client_name(self):
with self.fake_gssapi() as gssapi:
context = Mock()
context.step.return_value = b'secrets'
client_name, service_name, credentials = Mock(), Mock(), Mock()
gssapi.SecurityContext.return_value = context
gssapi.Credentials.return_value = credentials
gssapi.Name.side_effect = [client_name, service_name]
connection = Mock()
connection.transport.host = 'broker.example.org'
GSSAPI = sasl._get_gssapi_mechanism()
mech = GSSAPI(client_name='amqp-client/client.example.org')
response = mech.start(connection)
gssapi.Name.assert_has_calls([
call(b'amqp-client/client.example.org'),
call(b'amqp@broker.example.org',
gssapi.NameType.hostbased_service)])
gssapi.Credentials.assert_called_with(name=client_name)
gssapi.SecurityContext.assert_called_with(name=service_name,
creds=credentials)
context.step.assert_called_with(None)
assert response == b'secrets'
View
@@ -323,7 +323,7 @@ def test_setup_transport(self):
self.t.sock.do_handshake.assert_called_with()
assert self.t._quick_recv is self.t.sock.read
@patch('ssl.wrap_socket', create=True)
@patch('ssl.wrap_socket')
def test_wrap_socket(self, wrap_socket):
sock = Mock()
self.t._wrap_context = Mock()