Skip to content

Commit

Permalink
Improve, Fix & coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ehooo committed Apr 23, 2016
1 parent 0c7d3e0 commit 7bff076
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 68 deletions.
22 changes: 21 additions & 1 deletion django_mqtt/mosquitto/auth_plugin/test.py
@@ -1,6 +1,6 @@

from django.contrib.auth.models import User, Group
from django.test import TestCase, Client
from django.test import TestCase, Client, override_settings
from django.core.urlresolvers import reverse
from django.conf import settings

Expand Down Expand Up @@ -95,6 +95,26 @@ def setUp(self):
self.url_testing = reverse('mqtt_acl')
self.client = Client()

@override_settings(MQTT_ACL_ALLOW=False)
def test_topic_not_allow(self):
response = self.client.post(self.url_testing,
{'username': 'test',
'acc': models.PROTO_MQTT_ACC_PUB,
'topic': '/no/exist/topic'})
no_exist = models.Topic.objects.filter(name='/no/exist/topic').count()
self.assertEqual(no_exist, 1)
self.assertEqual(response.status_code, 403)

@override_settings(MQTT_ACL_ALLOW=True)
def test_topic_not_allow(self):
response = self.client.post(self.url_testing,
{'username': 'test',
'acc': models.PROTO_MQTT_ACC_PUB,
'topic': '/no/exist/topic'})
no_exist = models.Topic.objects.filter(name='/no/exist/topic').count()
self.assertEqual(no_exist, 1)
self.assertEqual(response.status_code, 200)

def test_no_topic(self):
no_exist = models.Topic.objects.filter(name='/no/exist/topic').count()
self.assertEqual(no_exist, 0)
Expand Down
2 changes: 1 addition & 1 deletion django_mqtt/publisher/models.py
Expand Up @@ -264,7 +264,7 @@ def update_remote(self):
cli.connect(self.client.server.host, self.client.server.port, self.client.keepalive)
mqtt_pre_publish.send(sender=Data.__class__, client=self.client,
topic=self.topic, payload=self.payload, qos=self.qos, retain=self.retain)
(rc, mid) = cli.publish(self.topic.name, self.payload, self.qos, self.retain)
(rc, mid) = cli.publish(self.topic.name, payload=self.payload, qos=self.qos, retain=self.retain)
self.client.server.status = rc
self.client.server.save()
mqtt_publish.send(sender=Client.__class__, client=self.client, userdata=cli._userdata, mid=mid)
Expand Down
13 changes: 9 additions & 4 deletions django_mqtt/server/management/commands/mqtt_server.py
Expand Up @@ -14,6 +14,7 @@
import ssl
import sys
import re
import os


naiveip_re = re.compile(r"""^(?:
Expand Down Expand Up @@ -159,8 +160,8 @@ def manage_connection(self, addr, port, certfile=None, keyfile=None, ipv6=False,
bind_socket.listen(backlog)

forks = []
try:
while self.is_running:
while self.is_running:
try:
sock, from_addr = bind_socket.accept()
conn = sock
if context:
Expand All @@ -172,5 +173,9 @@ def manage_connection(self, addr, port, certfile=None, keyfile=None, ipv6=False,
forks.append(th)
else:
th.run()
finally:
pass
except Exception as ex:
import traceback
self.stdout.write(traceback.format_exc())
self.stdout.write(str(ex))
finally:
pass
20 changes: 14 additions & 6 deletions django_mqtt/server/packets.py
@@ -1,6 +1,7 @@
from django_mqtt.protocol import *
import struct
import random
import logging as logger


class MQTTException(Exception):
Expand Down Expand Up @@ -161,6 +162,9 @@ def get_payload(self):
return ""
raise NotImplemented # pragma: no cover

def __str__(self):
return unicode(self).encode('latin-1')

def __unicode__(self):
msg = self.get_variable_header()
msg += self.get_payload()
Expand Down Expand Up @@ -392,11 +396,10 @@ def parse_body(self, body):
padding = size + 2
self.proto_level, self.conn_flags, self.keep_alive = struct.unpack_from("!BBH", body, padding)
padding += 4
if not self.is_clean():
s = body[padding:]
self.client_id = get_string(body[padding:])
(size, ) = struct.unpack_from("!H", body, padding)
padding += 2 + size
s = body[padding:]
self.client_id = get_string(body[padding:])
(size, ) = struct.unpack_from("!H", body, padding)
padding += 2 + size
if self.has_flag():
self._topic = get_string(body[padding:])
(size, ) = struct.unpack_from("!H", body, padding)
Expand All @@ -415,8 +418,9 @@ def parse_body(self, body):
padding += 2
self.auth_password = body[padding: padding+size]
padding += size

if len(body) > padding:
raise MQTTException('Body too big')
raise MQTTException('Body too big size(%s) expected(%s)' % (len(body), padding))

def check_integrity(self):
super(Connect, self).check_integrity()
Expand Down Expand Up @@ -782,10 +786,14 @@ def parse_raw(connection):
cls.parse_body(body)
return cls
except struct.error as s_ex:
logger.exception(s_ex)
raise MQTTException('Invalid format', exception=s_ex)
except UnicodeDecodeError as u_ex:
logger.exception(u_ex)
raise MQTTException('Invalid encode', exception=u_ex)
except ValueError as v_ex:
logger.exception(v_ex)
raise MQTTException('Invalid value', exception=v_ex)
except TypeError as t_ex:
logger.exception(t_ex)
raise MQTTException('Invalid type', exception=t_ex)
83 changes: 42 additions & 41 deletions django_mqtt/server/service.py
Expand Up @@ -7,7 +7,6 @@
from django.db import transaction
from django.conf import settings

from datetime import datetime
from threading import Thread
import logging
import socket
Expand All @@ -17,7 +16,7 @@


class MqttServiceThread(Thread):
def __init__(self, connection, publish_callback=None,*args, **kwargs):
def __init__(self, connection, publish_callback=None, *args, **kwargs):
if not isinstance(connection, socket.socket):
raise ValueError('socket expected')
super(MqttServiceThread, self).__init__(*args, **kwargs)
Expand All @@ -31,6 +30,8 @@ def __init__(self, connection, publish_callback=None,*args, **kwargs):
def next_packet(self):
pkg = parse_raw(self._connection)
pkg.check_integrity()
if self._session:
self._session.ping()
return pkg

def notify_publish(self, publish_pk):
Expand All @@ -52,7 +53,17 @@ def send_publish(self, qos, topic, msg, pack_identifier=None):
publish_pkg = Publish(topic=topic, msg=msg, qos=qos, pack_identifier=pack_identifier)
if qos != MQTT_QoS0:
self._last_publication = publish_pkg
self._connection.sendall(unicode(publish_pkg))
self._connection.sendall(str(publish_pkg))

def stop(self):
self.disconnect = True
if self._session:
self._session = None
if self._connection:
self._connection.setblocking(0)
self._connection.shutdown(socket.SHUT_RDWR)
self._connection.close()
self._connection = None

def run(self):
self.disconnect = False
Expand All @@ -63,7 +74,6 @@ def run(self):
self.process_new_connection(conn_pkg)
while not self.disconnect:
pkg = self.next_packet()
self._session.ping()
if isinstance(pkg, Connect):
self.process_new_connection(pkg)
elif isinstance(pkg, ConnAck):
Expand All @@ -82,7 +92,7 @@ def run(self):
if self._last_publication.QoS != MQTT_QoS2:
raise MQTTException(_('Packer QoS2 not expected'))
resp = PubRel(pack_identifier=pkg.pack_identifier)
self._connection.sendall(unicode(resp))
self._connection.sendall(str(resp))
elif isinstance(pkg, PubRel):
raise MQTTException(_('Packer QoS2 not expected'))
elif isinstance(pkg, PubComp):
Expand All @@ -101,12 +111,11 @@ def run(self):
raise MQTTException(_('Client cannot use UnsubAck packer'))
elif isinstance(pkg, PingReq):
resp = PingResp()
self._connection.sendall(unicode(resp))
self._connection.sendall(str(resp))
elif isinstance(pkg, PingResp):
raise MQTTException(_('Client cannot use PingResp packer'))
elif isinstance(pkg, Disconnect):
self._session.active = False
self._session.save()
self._session.disconnect()
self.disconnect = True
# TODO manager timeoout or disconecctions
except MQTTProtocolException as ex:
Expand All @@ -118,52 +127,43 @@ def run(self):
logging.warning("%s" % ex)
self.disconnect = True
finally:
if self.disconnect:
self._session = None
self._connection.shutdown(socket.SHUT_RDWR)
self._connection.close()
self.stop()

def process_unsubscription(self, unsubscription_pkg):
logger.info('%(client_id)s unsubscription %(topics)s' % {
'client_id': self._session.client_id,
'topics': unsubscription_pkg.topic_list
})
resp = UnsubAck(pack_identifier=unsubscription_pkg.pack_identifier)
for topic in unsubscription_pkg.topic_list:
subs = self._session.subscriptions.filter(topic__name=topic)
if subs.count() > 0:
self._session.subscriptions.remove(subs)
else:
self._session.unsubscriptions.add(subs)
self._connection.sendall(unicode(resp))
self._session.unsubscribe(topic)
self._connection.sendall(str(resp))

def process_subscription(self, subscription_pkg):
resp = SubAck(pack_identifier=subscription_pkg.pack_identifier)
for topic in subscription_pkg.topic_list:
qos = subscription_pkg.topic_list[topic]
subs = None
code = MQTT_SUBACK_FAILURE
try:
topic, new_topic = Topic.objects.get_or_create(name=topic)
acl = ACL.get_acl(topic, PROTO_MQTT_ACC_SUS)
if self._session and acl and acl.has_permission(user=self._session.user):
if qos is MQTT_QoS0:
subs, new_subs = Channel.objects.get_or_create(topic=topic, qos=qos)
self._session = Session()
self._session.subscriptions.add(subs)
self._session.unsubscriptions.remove(topic)
resp.add_response(MQTT_SUBACK_QoS0)
code=MQTT_SUBACK_QoS0
elif qos is MQTT_QoS1:
subs, new_subs = Channel.objects.get_or_create(topic=topic, qos=qos)
self._session.subscriptions.add(subs)
self._session.unsubscriptions.remove(topic)
resp.add_response(MQTT_SUBACK_QoS1)
code=MQTT_SUBACK_QoS1
elif qos is MQTT_QoS2:
subs, new_subs = Channel.objects.get_or_create(topic=topic, qos=qos)
self._session.subscriptions.add(subs)
self._session.unsubscriptions.remove(topic)
resp.add_response(MQTT_SUBACK_QoS2)
else:
resp.add_response(MQTT_SUBACK_FAILURE)
else:
resp.add_response(MQTT_SUBACK_FAILURE)
except ValidationError:
resp.add_response(MQTT_SUBACK_FAILURE)
self._connection.sendall(unicode(resp))
code=MQTT_SUBACK_QoS2
except ValidationError as ex:
logger.exception(ex)
if subs:
self._session.subscribe(channel=subs)
resp.add_response(code)
self._connection.sendall(str(resp))

@transaction.atomic
def process_new_publish_qos2(self, publish_pkg, channel):
Expand All @@ -173,7 +173,7 @@ def process_new_publish_qos2(self, publish_pkg, channel):
publication.message = publish_pkg.msg
publication.save()
resp = PubRec(pack_identifier=publish_pkg.pack_identifier)
self._connection.sendall(unicode(resp))
self._connection.sendall(str(resp))

pkg = self.next_packet()
if not isinstance(pkg, PubRel):
Expand All @@ -185,7 +185,7 @@ def process_new_publish_qos2(self, publish_pkg, channel):

transaction.commit()
resp = PubComp(pack_identifier=pkg.pack_identifier)
self._connection.sendall(unicode(resp))
self._connection.sendall(str(resp))
if self._publish_callback:
self._publish_callback(publication.pk)
except MQTTException as ex:
Expand Down Expand Up @@ -214,7 +214,7 @@ def process_new_publish(self, publish_pkg):
if self._publish_callback:
self._publish_callback(publication.pk)
resp = PubAck(pack_identifier=publish_pkg.pack_identifier)
self._connection.sendall(unicode(resp))
self._connection.sendall(str(resp))
elif publish_pkg.QoS == MQTT_QoS2:
self.process_new_publish_qos2(publish_pkg, channel)

Expand Down Expand Up @@ -327,11 +327,12 @@ def process_new_connection(self, conn_pkg):
else:
conn_ack.set_flags(sp=True)
conn_ack.ret_code = mqtt.CONNACK_ACCEPTED
logger.info(_('New connection accepted id:%(client_id)s user:%(user)s') %
{'client_id': cli_id, 'user': user})
self._connection.settimeout(conn_pkg.keep_alive)
logger.info(_('New connection accepted id:%(client_id)s user:%(user)s "keep alive":%(keep_alive)s') %
{'client_id': cli_id, 'user': user, 'keep_alive': conn_pkg.keep_alive})
except MQTTProtocolException as ex:
conn_ack = ex.get_nack()
raise MQTTException(exception=ex)
finally:
if conn_ack:
self._connection.sendall(unicode(conn_ack))
self._connection.sendall(str(conn_ack))
1 change: 1 addition & 0 deletions django_mqtt/server/test/__init__.py
@@ -0,0 +1 @@

File renamed without changes.
@@ -1,7 +1,3 @@
from django.test import TestCase

from django_mqtt.server.test_service import *
from django_mqtt.server.test_models import *
from django_mqtt.test_models import *
from django_mqtt.server.packets import *
from django_mqtt.protocol import *
Expand Down

0 comments on commit 7bff076

Please sign in to comment.