/
TCPListeners.py
235 lines (206 loc) · 7.07 KB
/
TCPListeners.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
*******************************************************************
Copyright (c) 2013, 2017 IBM Corp.
All rights reserved. This program and the accompanying materials
are made available under the terms of the Eclipse Public License v1.0
and Eclipse Distribution License v1.0 which accompany this distribution.
The Eclipse Public License is available at
http://www.eclipse.org/legal/epl-v10.html
and the Eclipse Distribution License is available at
http://www.eclipse.org/org/documents/edl-v10.php.
Contributors:
Ian Craggs - initial implementation and/or documentation
Ian Craggs - add websockets support
Ian Craggs - add TLS support
*******************************************************************
"""
import socketserver, select, sys, traceback, socket, logging, getopt, hashlib, base64
import threading, ssl
from mqtt.brokers.V311 import MQTTBrokers as MQTTV3Brokers
from mqtt.brokers.V5 import MQTTBrokers as MQTTV5Brokers
from mqtt.formats.MQTTV311 import MQTTException as MQTTV3Exception
from mqtt.formats.MQTTV5 import MQTTException as MQTTV5Exception
# Listener most recently built by create(); connection handlers keep looping
# only while it is set and its `terminate` flag is false.
server = None
logger = logging.getLogger("MQTT broker")
class BufferedSockets:
    """Socket wrapper adding a push-back buffer and optional WebSocket framing.

    Bytes can be "un-read" with rebuffer() so that the next recv() returns
    them first.  Once ``websockets`` is set to True, recv() transparently
    strips incoming WebSocket frame headers (and unmasks payloads), and
    send() wraps outgoing data in a single binary frame.  Any attribute not
    defined here (fileno, close, ...) is forwarded to the wrapped socket.
    """

    def __init__(self, socket):
        self.socket = socket
        # Bytes already read from the socket but not yet returned to a caller.
        self.buffer = bytearray()
        # Set to True by the connection handler after the WebSocket handshake.
        self.websockets = False

    def rebuffer(self, data):
        """Push *data* back so it is returned first by the next recv()."""
        self.buffer = data + self.buffer

    def _recv_exactly(self, count):
        """Read exactly *count* bytes from the wrapped socket.

        socket.recv() may return fewer bytes than requested, so keep reading
        until the full amount has arrived.

        Raises:
            ConnectionError: the peer closed the connection (recv() returned
                b"") before *count* bytes were received.
        """
        data = bytearray()
        while len(data) < count:
            chunk = self.socket.recv(count - len(data))
            if not chunk:
                raise ConnectionError("socket closed while reading WebSocket frame")
            data += chunk
        return data

    def wsrecv(self):
        """Read one WebSocket frame and append its (unmasked) payload to the buffer.

        Decodes the base framing header: 7-bit payload length, or the
        126 (16-bit) / 127 (64-bit) extended lengths in network byte order,
        plus the optional 4-byte masking key.  The FIN/opcode byte is not
        inspected, so control-frame payloads are appended like data.
        """
        header = self._recv_exactly(2)
        # header[0] carries FIN/RSV/opcode and is not used (see docstring).
        masked = (header[1] & 0x80) == 0x80
        length = header[1] & 0x7F
        if length == 126:
            length = int.from_bytes(self._recv_exactly(2), "big")
        elif length == 127:
            length = int.from_bytes(self._recv_exactly(8), "big")
        mask = self._recv_exactly(4) if masked else None
        payload = self._recv_exactly(length)
        if mask is not None:
            payload = bytearray(byte ^ mask[index % 4] for index, byte in enumerate(payload))
        self.buffer += payload

    def recv(self, bufsize):
        """Return up to *bufsize* bytes, consuming pushed-back data first.

        In WebSocket mode this blocks until *bufsize* payload bytes are
        available, reading as many frames as needed.  Otherwise it behaves
        like socket.recv() and may return fewer bytes than requested.
        """
        if self.websockets:
            while len(self.buffer) < bufsize:
                self.wsrecv()
            out = self.buffer[:bufsize]
            self.buffer = self.buffer[bufsize:]
        elif bufsize <= len(self.buffer):
            out = self.buffer[:bufsize]
            self.buffer = self.buffer[bufsize:]
        else:
            out = self.buffer + self.socket.recv(bufsize - len(self.buffer))
            self.buffer = bytes()
        return out

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.  Guard
        # "socket" so a half-initialised instance raises AttributeError
        # instead of recursing forever.
        if name == "socket":
            raise AttributeError(name)
        return getattr(self.socket, name)

    def send(self, data):
        """Send *data*, framed as one unmasked binary WebSocket frame if enabled.

        Returns whatever the wrapped socket's send() returns (bytes written,
        frame header included), which may be less than the full frame.
        """
        header = bytearray()
        if self.websockets:
            header.append(0x82)  # FIN set, opcode 0x2 (binary frame)
            length = len(data)
            if length < 126:
                header.append(length)
            elif length <= 0xFFFF:
                header.append(126)
                header += length.to_bytes(2, "big")
            else:
                header.append(127)
                header += length.to_bytes(8, "big")
        return self.socket.send(header + data)
class WebSocketTCPHandler(socketserver.StreamRequestHandler):
def getheaders(self, data):
headers = {}
lines = data.splitlines()
for curline in lines[1:]:
if curline.find(":") != -1:
key, value = curline.split(": ", 1)
headers[key] = value
return headers
def handshake(self, client):
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
data = client.recv(1024).decode('utf-8')
headers = self.getheaders(data)
digest = base64.b64encode(hashlib.sha1((headers['Sec-WebSocket-Key'] + GUID).encode("utf-8")).digest())
resp = b"HTTP/1.1 101 Switching Protocols\r\n" +\
b"Upgrade: websocket\r\n" +\
b"Connection: Upgrade\r\n" +\
b"Sec-WebSocket-Protocol: mqtt\r\n" +\
b"Sec-WebSocket-Accept: " + digest +b"\r\n\r\n"
return client.send(resp)
def handle(self):
global server
first = True
broker = None
sock = BufferedSockets(self.request)
sock_no = sock.fileno()
terminate = keptalive = False
logger.info("Starting communications for socket %d", sock_no)
while not terminate and server and not server.terminate:
try:
if not keptalive:
logger.info("Waiting for request")
(i, o, e) = select.select([sock], [], [], 1)
if i == [sock]:
if first:
char = sock.recv(1)
sock.rebuffer(char)
if char == b"G": # should be websocket connection
self.handshake(sock)
sock.websockets = True
if sock.websockets and first:
pass
else:
if broker == None:
connbuf = sock.recv(1)
if connbuf == b'\x10': # connect packet
while connbuf[-4:] != b"MQTT" and len(connbuf) < 10:
connbuf += sock.recv(1)
connbuf += sock.recv(1)
version = connbuf[-1]
if version == 4:
broker = broker3
elif version == 5:
broker = broker5
sock.rebuffer(connbuf)
if broker == None:
terminate = True
else:
terminate = broker.handleRequest(sock)
keptalive = False
first = False
elif (i, o, e) == ([], [], []):
broker3.keepalive(sock)
broker5.keepalive(sock)
keptalive = True
else:
break
except UnicodeDecodeError:
logger.error("[MQTT-1.4.0-1] Unicode field encoding error")
break
except MQTTV3Exception as exc:
logger.error(exc.args[0])
break
except MQTTV5Exception as exc:
logger.error(exc.args[0])
break
except AssertionError as exc:
if (len(exc.args) > 0):
logger.error(exc.args[0])
else:
logger.error("")
break
except:
logger.exception("WebSocketTCPHandler")
break
logger.info("Finishing communications for socket %d", sock_no)
class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that handles each accepted connection in its own thread.

    A plain combination of the standard ThreadingMixIn and TCPServer; no
    behaviour is added or overridden.
    """
def setBrokers(aBroker3, aBroker5):
    """Register the broker objects that connection handlers dispatch to.

    Stores *aBroker3* (used for protocol level 4 CONNECTs) and *aBroker5*
    (protocol level 5) in the module globals ``broker3`` and ``broker5``.
    """
    global broker5, broker3
    broker3, broker5 = aBroker3, aBroker5
def _make_server_ssl_context(cert_reqs, ca_certs, certfile, keyfile):
    """Return a server-side SSLContext configured like the old ssl.wrap_socket() call.

    ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12.
    """
    if not certfile:
        # ssl.wrap_socket(server_side=True) rejected a missing certfile too.
        raise ValueError("certfile must be specified for server-side operations")
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.verify_mode = cert_reqs
    if ca_certs:
        context.load_verify_locations(ca_certs)
    context.load_cert_chain(certfile, keyfile)
    return context


def create(port, host="", TLS=False, serve_forever=False,
           cert_reqs=ssl.CERT_REQUIRED,
           ca_certs=None, certfile=None, keyfile=None):
    """Create and start the threaded MQTT listener.

    Args:
        port: TCP port to listen on.
        host: address to bind; "" or "INADDR_ANY" binds all interfaces.
        TLS: wrap the listening socket in TLS when true.
        serve_forever: if true, serve in the calling thread (blocks until the
            server is shut down); otherwise serve from a daemon thread.
        cert_reqs: client-certificate requirement (ssl.CERT_NONE,
            ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED).
        ca_certs: CA file used to verify client certificates.
        certfile, keyfile: server certificate chain and private key.

    Returns:
        The ThreadingTCPServer, also stored in the module global ``server``.

    Raises:
        ValueError: TLS requested without a certfile.
        OSError, ssl.SSLError: binding, listening or certificate loading
            failed; the listening socket is closed before re-raising.
    """
    global server
    logger.info("Starting MQTT server on address '%s' port %d %s", host, port, "with TLS support" if TLS else "")
    bind_address = "" if host in ["", "INADDR_ANY"] else host
    # bind_and_activate=False: TLS wrapping and allow_reuse_address must be
    # applied before bind()/listen().
    new_server = ThreadingTCPServer((bind_address, port), WebSocketTCPHandler, False)
    try:
        if TLS:
            context = _make_server_ssl_context(cert_reqs, ca_certs, certfile, keyfile)
            new_server.socket = context.wrap_socket(new_server.socket, server_side=True)
        new_server.terminate = False
        new_server.allow_reuse_address = True
        new_server.server_bind()
        new_server.server_activate()
    except BaseException:
        # Don't leak the listening socket when TLS setup, bind or listen fails.
        new_server.server_close()
        raise
    # Publish only a fully set-up server; handlers poll server.terminate.
    server = new_server
    if serve_forever:
        server.serve_forever()
    else:
        thread = threading.Thread(target=server.serve_forever)
        thread.daemon = True
        thread.start()
    return server