Skip to content

Commit

Permalink
Disconnect reason code.
Browse files Browse the repository at this point in the history
  • Loading branch information
eerimoq committed May 8, 2019
1 parent eec5a68 commit 42de10e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 54 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
test:
python3.7 setup.py test
python3 setup.py test

release-to-pypi:
python3.7 setup.py sdist
python3.7 setup.py bdist_wheel --universal
python3 setup.py sdist
python3 setup.py bdist_wheel --universal
twine upload dist/*
74 changes: 25 additions & 49 deletions mqttools/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,9 @@ def unpack_connack(payload):
return session_present, reason, properties


def pack_disconnect():
def pack_disconnect(reason):
packed = pack_fixed_header(ControlPacketType.DISCONNECT, 0, 2)
packed += struct.pack('B', DisconnectReasonCode.NORMAL_DISCONNECTION)
packed += struct.pack('B', reason)
packed += pack_variable_integer(0)

return packed
Expand Down Expand Up @@ -982,6 +982,7 @@ def __init__(self,
self._broker_receive_maximum = None
self._broker_receive_maximum_semaphore = None
self._on_publish_qos_2_transactions = None
self._disconnect_reason = None

if keep_alive_s == 0:
self._ping_period_s = None
Expand Down Expand Up @@ -1047,6 +1048,7 @@ async def start(self):
self._broker_receive_maximum = None
self._broker_receive_maximum_semaphore = None
self._on_publish_qos_2_transactions = {}
self._disconnect_reason = DisconnectReasonCode.NORMAL_DISCONNECTION
self._reader, self._writer = await asyncio.open_connection(
self._host,
self._port,
Expand Down Expand Up @@ -1140,7 +1142,11 @@ async def connect(self):
}

def disconnect(self):
self._write_packet(pack_disconnect())
if self._disconnect_reason is None:
return

self._write_packet(pack_disconnect(self._disconnect_reason))
self._disconnect_reason = None

async def subscribe(self, topic, qos):
"""Subscribe to given topic with given QoS.
Expand Down Expand Up @@ -1262,14 +1268,9 @@ async def on_publish_qos_2_timer(self, packet_identifier):

async def on_publish(self, flags, payload):
qos = ((flags >> 1) & 0x3)

try:
packet_identifier, topic, message, properties = unpack_publish(
payload,
qos)
except MalformedPacketError:
LOGGER.debug('Discarding malformed PUBLISH packet.')
return
packet_identifier, topic, message, properties = unpack_publish(
payload,
qos)

if PropertyIds.TOPIC_ALIAS in properties:
alias = properties[PropertyIds.TOPIC_ALIAS]
Expand Down Expand Up @@ -1310,11 +1311,7 @@ async def on_publish(self, flags, payload):
LOGGER.debug('Received invalid QoS %d.', qos)

def on_puback(self, payload):
try:
packet_identifier, reason = unpack_puback(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed PUBACK packet.')
return
packet_identifier, reason = unpack_puback(payload)

if packet_identifier in self.transactions:
self.transactions[packet_identifier].set_completed(reason)
Expand All @@ -1324,11 +1321,7 @@ def on_puback(self, payload):
packet_identifier)

def on_pubrec(self, payload):
try:
packet_identifier, reason = unpack_pubrec(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed PUBREC packet.')
return
packet_identifier, reason = unpack_pubrec(payload)

if packet_identifier in self.transactions:
self.transactions[packet_identifier].set_response(reason)
Expand All @@ -1338,11 +1331,7 @@ def on_pubrec(self, payload):
packet_identifier)

async def on_pubrel(self, payload):
try:
packet_identifier, reason = unpack_pubrel(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed PUBREL packet.')
return
packet_identifier, reason = unpack_pubrel(payload)

if packet_identifier in self._on_publish_qos_2_transactions:
if reason == PubrelReasonCode.SUCCESS:
Expand All @@ -1359,11 +1348,7 @@ async def on_pubrel(self, payload):
PubcompReasonCode.PACKET_IDENTIFIER_NOT_FOUND))

def on_pubcomp(self, payload):
try:
packet_identifier, reason = unpack_pubcomp(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed PUBCOMP packet.')
return
packet_identifier, reason = unpack_pubcomp(payload)

if packet_identifier in self.transactions:
self.transactions[packet_identifier].set_completed(reason)
Expand All @@ -1373,11 +1358,7 @@ def on_pubcomp(self, payload):
packet_identifier)

def on_suback(self, payload):
try:
packet_identifier, properties = unpack_suback(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed SUBACK packet.')
return
packet_identifier, properties = unpack_suback(payload)

if packet_identifier in self.transactions:
self.transactions[packet_identifier].set_completed(None)
Expand All @@ -1387,11 +1368,7 @@ def on_suback(self, payload):
packet_identifier)

def on_unsuback(self, payload):
try:
packet_identifier, properties = unpack_unsuback(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed UNSUBACK packet.')
return
packet_identifier, properties = unpack_unsuback(payload)

if packet_identifier in self.transactions:
self.transactions[packet_identifier].set_completed(None)
Expand All @@ -1404,11 +1381,7 @@ def on_pingresp(self):
self._pingresp_event.set()

async def on_disconnect(self, payload):
try:
reason, properties = unpack_disconnect(payload)
except MalformedPacketError:
LOGGER.debug('Discarding malformed DISCONNECT packet.')
return
reason, properties = unpack_disconnect(payload)

if reason != DisconnectReasonCode.NORMAL_DISCONNECTION:
LOGGER.info("Abnormal disconnect reason %s.", reason)
Expand Down Expand Up @@ -1444,9 +1417,7 @@ async def reader_loop(self):
elif packet_type == ControlPacketType.DISCONNECT:
await self.on_disconnect(payload)
else:
LOGGER.warning("Unsupported packet type %s with data %s.",
control_packet_type_to_string(packet_type),
payload.getvalue())
raise MalformedPacketError(f'Invalid packet type {packet_type}.')

async def _reader_main(self):
"""Read packets from the broker.
Expand All @@ -1457,6 +1428,10 @@ async def _reader_main(self):
await self.reader_loop()
except Exception as e:
LOGGER.info('Reader task stopped by %r.', e)

if isinstance(e, MalformedPacketError):
self._disconnect_reason = DisconnectReasonCode.MALFORMED_PACKET

await self._close()

async def keep_alive_loop(self):
Expand Down Expand Up @@ -1526,5 +1501,6 @@ def alloc_packet_identifier(self):
return packet_identifier

async def _close(self):
self.disconnect()
self._writer.close()
await self._messages.put((None, None))
19 changes: 17 additions & 2 deletions tests/test_reconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def client(self):
response_timeout=1,
keep_alive_s=1)

for _ in range(5):
for _ in range(6):
await client.start()
self.messages.append(await client.messages.get())
await client.stop()
Expand Down Expand Up @@ -115,11 +115,26 @@ def test_reconnect(self):
b'\x10\x10\x00\x04MQTT\x05\x02\x00\x01\x00\x00\x03goo')
# CONNACK
self.assertEqual(client2.send(b'\x20\x03\x00\x00\x00'), 5)
# PUBLISH malformed packet with non-UTF-8 topic.
self.assertEqual(client2.send(b'\x30\x07\x00\x01\xff\x00apa'), 9)
# DISCONNECT with reason malformed packet
self.assertEqual(client2.recv(4), b'\xe0\x02\x81\x00')

# Wait for another connection.
client, _ = listener.accept()
client2.close()

# CONNECT
self.assertEqual(
client.recv(18),
b'\x10\x10\x00\x04MQTT\x05\x02\x00\x01\x00\x00\x03goo')
# CONNACK
self.assertEqual(client.send(b'\x20\x03\x00\x00\x00'), 5)

client.close()
listener.close()
client_thread.done.wait()
self.assertEqual(client_thread.messages, 5 * [(None, None)])
self.assertEqual(client_thread.messages, 6 * [(None, None)])


logging.basicConfig(level=logging.DEBUG)
Expand Down

0 comments on commit 42de10e

Please sign in to comment.