Skip to content

Commit

Permalink
Adding two tests, for Connection.collect and Transport.close.
Browse files Browse the repository at this point in the history
  • Loading branch information
VinayGValsaraj authored and auvipy committed Dec 12, 2021
1 parent 1cf468c commit d4c879f
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 27 deletions.
32 changes: 14 additions & 18 deletions amqp/connection.py
Expand Up @@ -466,24 +466,20 @@ def connected(self):
return self._transport and self._transport.connected

def collect(self):
try:
if self._transport:
self._transport.close()

if self.channels:
# Copy all the channels except self since the channels
# dictionary changes during the collection process.
channels = [
ch for ch in self.channels.values()
if ch is not self
]

for ch in channels:
ch.collect()
except OSError:
pass # connection already closed on the other end
finally:
self._transport = self.connection = self.channels = None
if self._transport:
self._transport.close()

if self.channels:
# Copy all the channels except self since the channels
# dictionary changes during the collection process.
channels = [
ch for ch in self.channels.values()
if ch is not self
]

for ch in channels:
ch.collect()
self._transport = self.connection = self.channels = None

def _get_free_channel_id(self):
try:
Expand Down
11 changes: 7 additions & 4 deletions amqp/transport.py
Expand Up @@ -276,7 +276,10 @@ def close(self):
# Call shutdown first to make sure that pending messages
# reach the AMQP broker if the program exits after
# calling this method.
self.sock.shutdown(socket.SHUT_RDWR)
try:
self.sock.shutdown(socket.SHUT_RDWR)
except OSError:
pass
self.sock.close()
self.sock = None
self.connected = False
Expand Down Expand Up @@ -525,8 +528,8 @@ def _wrap_socket_sni(self, sock, keyfile=None, certfile=None,
context.load_verify_locations(ca_certs)
if ciphers is not None:
context.set_ciphers(ciphers)
# Set SNI headers if supported.
# Must set context.check_hostname before setting context.verify_mode
# Set SNI headers if supported.
# Must set context.check_hostname before setting context.verify_mode
# to avoid setting context.verify_mode=ssl.CERT_NONE while
# context.check_hostname is still True (the default value in context
# if client-side) which results in the following exception:
Expand All @@ -539,7 +542,7 @@ def _wrap_socket_sni(self, sock, keyfile=None, certfile=None,
except AttributeError:
pass # ask forgiveness not permission

# See note above re: ordering for context.check_hostname and
# See note above re: ordering for context.check_hostname and
# context.verify_mode assignments.
if cert_reqs is not None:
context.verify_mode = cert_reqs
Expand Down
13 changes: 10 additions & 3 deletions t/unit/test_connection.py
Expand Up @@ -323,10 +323,17 @@ def test_collect(self):
channel.collect.assert_called_with()
assert self.conn._transport is None

def test_collect__channel_raises_socket_error(self):
self.conn.channels = self.conn.channels = {1: Mock(name='c1')}
self.conn.channels[1].collect.side_effect = socket.error()
def test_collect__transport_socket_raises_os_error(self):
self.conn.transport = TCPTransport('localhost:5672')
sock = self.conn.transport.sock = Mock(name='sock')
channel = Mock(name='c1')
self.conn.channels = {1: channel}
sock.shutdown.side_effect = OSError
self.conn.collect()
channel.collect.assert_called_with()
sock.close.assert_called_with()
assert self.conn._transport is None
assert self.conn.channels is None

def test_collect_no_transport(self):
self.conn = Connection()
Expand Down
11 changes: 9 additions & 2 deletions t/unit/test_transport.py
Expand Up @@ -282,6 +282,13 @@ def test_close(self):
self.t.close()
assert self.t.sock is None and self.t.connected is False

def test_close_os_error(self):
sock = self.t.sock = Mock()
sock.shutdown.side_effect = OSError
self.t.close()
sock.close.assert_called_with()
assert self.t.sock is None and self.t.connected is False

def test_read_frame__timeout(self):
self.t._read = Mock()
self.t._read.side_effect = socket.timeout()
Expand Down Expand Up @@ -719,7 +726,7 @@ def test_wrap_socket_sni_cert_reqs(self):
)
assert context.verify_mode == sentinel.CERT_REQS

# testing context creation inside _wrap_socket_sni() with parameter
# testing context creation inside _wrap_socket_sni() with parameter
# cert_reqs == ssl.CERT_NONE. Previously raised ValueError because
# code path attempted to set context.verify_mode=ssl.CERT_NONE before
# setting context.check_hostname = False which raised a ValueError
Expand All @@ -740,7 +747,7 @@ def test_wrap_socket_sni_cert_reqs(self):
)
mock_load_default_certs.assert_not_called()
mock_wrap_socket.assert_called_once()

with patch('ssl.SSLContext.wrap_socket') as mock_wrap_socket:
with patch('ssl.SSLContext.load_default_certs') as mock_load_default_certs:
sock = Mock()
Expand Down

0 comments on commit d4c879f

Please sign in to comment.