From 8a697622495fd319582cd1c604e7eb2cc0ac0ef6 Mon Sep 17 00:00:00 2001 From: Pierre Ossman Date: Thu, 15 Sep 2016 19:51:26 +0200 Subject: [PATCH] Separate out raw WebSocket protocol handling --- tests/echo.py | 4 +- tests/echo_client.py | 70 ++ tests/load.py | 14 +- tests/test_websocket.py | 386 +------- tests/test_websocketproxy.py | 4 +- tests/test_websocketserver.py | 347 +++++++ websockify/websocket.py | 1609 ++++++++++++++------------------- websockify/websocketproxy.py | 20 +- websockify/websocketserver.py | 819 +++++++++++++++++ 9 files changed, 1968 insertions(+), 1305 deletions(-) create mode 100755 tests/echo_client.py create mode 100644 tests/test_websocketserver.py create mode 100644 websockify/websocketserver.py diff --git a/tests/echo.py b/tests/echo.py index e6a68515..3d81e04b 100755 --- a/tests/echo.py +++ b/tests/echo.py @@ -12,7 +12,7 @@ import os, sys, select, optparse, logging sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer, WebSocketRequestHandler +from websockify.websocketserver import WebSocketServer, WebSocketRequestHandler class WebSocketEcho(WebSocketRequestHandler): """ @@ -48,7 +48,7 @@ def new_websocket_client(self): cqueue.extend(frames) if closed: - self.send_close() + break if __name__ == '__main__': parser = optparse.OptionParser(usage="%prog [options] listen_port") diff --git a/tests/echo_client.py b/tests/echo_client.py new file mode 100755 index 00000000..6d745ecd --- /dev/null +++ b/tests/echo_client.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +import os +import sys +import optparse +import select + +sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) +from websockify.websocket import WebSocket, \ + WebSocketWantReadError, WebSocketWantWriteError + +parser = optparse.OptionParser(usage="%prog URL") +(opts, args) = parser.parse_args() + +try: + if len(args) != 1: raise + URL = args[0] +except: + parser.error("Invalid arguments") + +sock = WebSocket() +print("Connecting to %s..." % URL) +sock.connect(URL) +print("Connected.") + +def send(msg): + while True: + try: + sock.sendmsg(msg) + break + except WebSocketWantReadError: + msg = '' + ins, outs, excepts = select.select([sock], [], []) + if excepts: raise Exception("Socket exception") + except WebSocketWantWriteError: + msg = '' + ins, outs, excepts = select.select([], [sock], []) + if excepts: raise Exception("Socket exception") + +def read(): + while True: + try: + return sock.recvmsg() + except WebSocketWantReadError: + ins, outs, excepts = select.select([sock], [], []) + if excepts: raise Exception("Socket exception") + except WebSocketWantWriteError: + ins, outs, excepts = select.select([], [sock], []) + if excepts: raise Exception("Socket exception") + +counter = 1 +while True: + msg = "Message #%d" % counter + counter += 1 + send(msg) + print("Sent message: %r" % msg) + + while True: + ins, outs, excepts = select.select([sock], [], [], 1.0) + if excepts: raise Exception("Socket exception") + + if ins == []: + break + + while True: + msg = read() + print("Received message: %r" % msg) + + if not sock.pending(): + break diff --git a/tests/load.py b/tests/load.py index c76feb12..caf6b58d 100755 --- a/tests/load.py +++ b/tests/load.py @@ -8,7 +8,7 @@ import sys, os, select, random, time, optparse, logging sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer, WebSocketRequestHandler +from websockify.websocketserver import WebSocketServer, WebSocketRequestHandler class WebSocketLoadServer(WebSocketServer): @@ -35,12 +35,10 @@ def new_websocket_client(self): self.send_cnt = 0 self.recv_cnt = 0 - try: - self.responder(self.request) - except: - print "accumulated errors:", self.errors - self.errors = 0 - raise + self.responder(self.request) + + print "accumulated errors:", self.errors + self.errors = 0 def responder(self, client): c_pend = 0 @@ -62,7 +60,7 @@ def responder(self, client): print err if closed: - self.send_close() + break now = time.time() * 1000 if client in outs: diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 545fa1c2..77d0eca1 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -15,418 +15,74 @@ # under the License. """ Unit tests for websocket """ -import errno -import os -import logging -import select -import shutil -import socket -import ssl -from mox3 import stubout -import sys -import tempfile import unittest -import socket -import signal from websockify import websocket -try: - from SimpleHTTPServer import SimpleHTTPRequestHandler -except ImportError: - from http.server import SimpleHTTPRequestHandler - -try: - from StringIO import StringIO - BytesIO = StringIO -except ImportError: - from io import StringIO - from io import BytesIO - - - - -def raise_oserror(*args, **kwargs): - raise OSError('fake error') - - -class FakeSocket(object): - def __init__(self, data=''): - if isinstance(data, bytes): - self._data = data - else: - self._data = data.encode('latin_1') - - def recv(self, amt, flags=None): - res = self._data[0:amt] - if not (flags & socket.MSG_PEEK): - self._data = self._data[amt:] - - return res - - def makefile(self, mode='r', buffsize=None): - if 'b' in mode: - return BytesIO(self._data) - else: - return StringIO(self._data.decode('latin_1')) - - -class WebSocketRequestHandlerTestCase(unittest.TestCase): - def setUp(self): - super(WebSocketRequestHandlerTestCase, self).setUp() - self.stubs = stubout.StubOutForTesting() - self.tmpdir = tempfile.mkdtemp('-websockify-tests') - # Mock this out cause it screws tests up - self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - lambda *args, **kwargs: None) - - def tearDown(self): - """Called automatically after each test.""" - self.stubs.UnsetAll() - os.rmdir(self.tmpdir) - super(WebSocketRequestHandlerTestCase, self).tearDown() - - def _get_server(self, handler_class=websocket.WebSocketRequestHandler, - **kwargs): - web = kwargs.pop('web', self.tmpdir) - return websocket.WebSocketServer( - handler_class, listen_host='localhost', - listen_port=80, key=self.tmpdir, web=web, - record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, - **kwargs) - - def test_normal_get_with_only_upgrade_returns_error(self): - server = self._get_server(web=None) - handler = websocket.WebSocketRequestHandler( - FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) - - def fake_send_response(self, code, message=None): - self.last_code = code - - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - fake_send_response) - - handler.do_GET() - self.assertEqual(handler.last_code, 405) - - def test_list_dir_with_file_only_returns_error(self): - server = self._get_server(file_only=True) - handler = websocket.WebSocketRequestHandler( - FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) - - def fake_send_response(self, code, message=None): - self.last_code = code - - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - fake_send_response) - - handler.path = '/' - handler.do_GET() - self.assertEqual(handler.last_code, 404) - - -class WebSocketServerTestCase(unittest.TestCase): - def setUp(self): - super(WebSocketServerTestCase, self).setUp() - self.stubs = stubout.StubOutForTesting() - self.tmpdir = tempfile.mkdtemp('-websockify-tests') - # Mock this out cause it screws tests up - self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) - - def tearDown(self): - """Called automatically after each test.""" - self.stubs.UnsetAll() - os.rmdir(self.tmpdir) - super(WebSocketServerTestCase, self).tearDown() - - def _get_server(self, handler_class=websocket.WebSocketRequestHandler, - **kwargs): - return websocket.WebSocketServer( - handler_class, listen_host='localhost', - listen_port=80, key=self.tmpdir, web=self.tmpdir, - record=self.tmpdir, **kwargs) - - def test_daemonize_raises_error_while_closing_fds(self): - server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: 0) - self.stubs.Set(signal, 'signal', lambda *args: None) - self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(os, 'close', raise_oserror) - self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') - - def test_daemonize_ignores_ebadf_error_while_closing_fds(self): - def raise_oserror_ebadf(fd): - raise OSError(errno.EBADF, 'fake error') - - server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: 0) - self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(signal, 'signal', lambda *args: None) - self.stubs.Set(os, 'close', raise_oserror_ebadf) - self.stubs.Set(os, 'open', raise_oserror) - self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') - - def test_handshake_fails_on_not_ready(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - FakeSocket(), '127.0.0.1') - - def test_empty_handshake_fails(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - - sock = FakeSocket('') - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_handshake_policy_request(self): - # TODO(directxman12): implement - pass - - def test_handshake_ssl_only_without_ssl_raises_error(self): - server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - - sock = FakeSocket('some initial data') - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_do_handshake_no_ssl(self): - class FakeHandler(object): - CALLED = False - def __init__(self, *args, **kwargs): - type(self).CALLED = True - - FakeHandler.CALLED = False - - server = self._get_server( - handler_class=FakeHandler, daemon=True, - ssl_only=0, idle_timeout=1) - - sock = FakeSocket('some initial data') - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) - self.assertTrue(FakeHandler.CALLED, True) - - def test_do_handshake_ssl(self): - # TODO(directxman12): implement this - pass - - def test_do_handshake_ssl_without_ssl_raises_error(self): - # TODO(directxman12): implement this - pass - - def test_do_handshake_ssl_without_cert_raises_error(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, - cert='afdsfasdafdsafdsafdsafdas') - - sock = FakeSocket("\x16some ssl data") - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_do_handshake_ssl_error_eof_raises_close_error(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - - sock = FakeSocket("\x16some ssl data") - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - def fake_wrap_socket(*args, **kwargs): - raise ssl.SSLError(ssl.SSL_ERROR_EOF) - - self.stubs.Set(select, 'select', fake_select) - self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_fallback_sigchld_handler(self): - # TODO(directxman12): implement this - pass - - def test_start_server_error(self): - server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) - sock = server.socket('localhost') - - def fake_select(rlist, wlist, xlist, timeout=None): - raise Exception("fake error") - - self.stubs.Set(websocket.WebSocketServer, 'socket', - lambda *args, **kwargs: sock) - self.stubs.Set(websocket.WebSocketServer, 'daemonize', - lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', fake_select) - server.start_server() - - def test_start_server_keyboardinterrupt(self): - server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost') - - def fake_select(rlist, wlist, xlist, timeout=None): - raise KeyboardInterrupt - - self.stubs.Set(websocket.WebSocketServer, 'socket', - lambda *args, **kwargs: sock) - self.stubs.Set(websocket.WebSocketServer, 'daemonize', - lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', fake_select) - server.start_server() - - def test_start_server_systemexit(self): - server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost') - - def fake_select(rlist, wlist, xlist, timeout=None): - sys.exit() - - self.stubs.Set(websocket.WebSocketServer, 'socket', - lambda *args, **kwargs: sock) - self.stubs.Set(websocket.WebSocketServer, 'daemonize', - lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', fake_select) - server.start_server() - - def test_socket_set_keepalive_options(self): - keepcnt = 12 - keepidle = 34 - keepintvl = 56 - - server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) - - if hasattr(socket, 'TCP_KEEPCNT'): - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPCNT), keepcnt) - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPIDLE), keepidle) - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPINTVL), keepintvl) - - sock = server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) - - if hasattr(socket, 'TCP_KEEPCNT'): - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPCNT), keepcnt) - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPIDLE), keepidle) - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPINTVL), keepintvl) - - class HyBiEncodeDecodeTestCase(unittest.TestCase): def test_decode_hybi_text(self): buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' - res = websocket.WebSocketRequestHandler.decode_hybi(buf) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x1) self.assertEqual(res['masked'], True) - self.assertEqual(res['length'], 5) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], b'Hello') - self.assertEqual(res['left'], 0) def test_decode_hybi_binary(self): buf = b'\x82\x04\x01\x02\x03\x04' - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], 4) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], b'\x01\x02\x03\x04') - self.assertEqual(res['left'], 0) def test_decode_hybi_extended_16bit_binary(self): data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 buf = b'\x82\x7e\x01\x04' + data - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], 260) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], data) - self.assertEqual(res['left'], 0) def test_decode_hybi_extended_64bit_binary(self): data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], 260) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], data) - self.assertEqual(res['left'], 0) def test_decode_hybi_multi(self): buf1 = b'\x01\x03\x48\x65\x6c' buf2 = b'\x80\x02\x6c\x6f' - res1 = websocket.WebSocketRequestHandler.decode_hybi(buf1, strict=False) + ws = websocket.WebSocket() + + res1 = ws._decode_hybi(buf1) self.assertEqual(res1['fin'], 0) self.assertEqual(res1['opcode'], 0x1) - self.assertEqual(res1['length'], 3) + self.assertEqual(res1['length'], len(buf1)) self.assertEqual(res1['payload'], b'Hel') - self.assertEqual(res1['left'], 0) - res2 = websocket.WebSocketRequestHandler.decode_hybi(buf2, strict=False) + res2 = ws._decode_hybi(buf2) self.assertEqual(res2['fin'], 1) self.assertEqual(res2['opcode'], 0x0) - self.assertEqual(res2['length'], 2) + self.assertEqual(res2['length'], len(buf2)) self.assertEqual(res2['payload'], b'lo') - self.assertEqual(res2['left'], 0) def test_encode_hybi_basic(self): - res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1) - expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0) + ws = websocket.WebSocket() + res = ws._encode_hybi(0x1, b'Hello') + expected = b'\x81\x05\x48\x65\x6c\x6c\x6f' self.assertEqual(res, expected) - - def test_strict_mode_refuses_unmasked_client_frames(self): - buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' - self.assertRaises(websocket.WebSocketRequestHandler.CClose, - websocket.WebSocketRequestHandler.decode_hybi, - buf) - - def test_no_strict_mode_accepts_unmasked_client_frames(self): - buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) - - self.assertEqual(res['fin'], 1) - self.assertEqual(res['opcode'], 0x1) - self.assertEqual(res['masked'], False) - self.assertEqual(res['length'], 5) - self.assertEqual(res['payload'], b'Hello') diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index b48796e4..ac08dfaa 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -22,7 +22,7 @@ from mox3 import stubout -from websockify import websocket +from websockify import websocketserver from websockify import websocketproxy from websockify import token_plugins from websockify import auth_plugins @@ -75,7 +75,7 @@ def setUp(self): FakeSocket(''), "127.0.0.1", FakeServer()) self.handler.path = "https://localhost:6080/websockify?token=blah" self.handler.headers = None - self.stubs.Set(websocket.WebSocketServer, 'socket', + self.stubs.Set(websocketserver.WebSocketServer, 'socket', staticmethod(lambda *args, **kwargs: None)) def tearDown(self): diff --git a/tests/test_websocketserver.py b/tests/test_websocketserver.py new file mode 100644 index 00000000..aaeeee69 --- /dev/null +++ b/tests/test_websocketserver.py @@ -0,0 +1,347 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright(c)2013 NTT corp. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" Unit tests for websocketserver """ +import errno +import os +import logging +import select +import shutil +import socket +import ssl +from mox3 import stubout +import sys +import tempfile +import unittest +import socket +import signal +from websockify import websocketserver + +try: + from SimpleHTTPServer import SimpleHTTPRequestHandler +except ImportError: + from http.server import SimpleHTTPRequestHandler + +try: + from StringIO import StringIO + BytesIO = StringIO +except ImportError: + from io import StringIO + from io import BytesIO + + + + +def raise_oserror(*args, **kwargs): + raise OSError('fake error') + + +class FakeSocket(object): + def __init__(self, data=''): + if isinstance(data, bytes): + self._data = data + else: + self._data = data.encode('latin_1') + + def recv(self, amt, flags=None): + res = self._data[0:amt] + if not (flags & socket.MSG_PEEK): + self._data = self._data[amt:] + + return res + + def makefile(self, mode='r', buffsize=None): + if 'b' in mode: + return BytesIO(self._data) + else: + return StringIO(self._data.decode('latin_1')) + + +class WebSocketRequestHandlerTestCase(unittest.TestCase): + def setUp(self): + super(WebSocketRequestHandlerTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.tmpdir = tempfile.mkdtemp('-websockify-tests') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + lambda *args, **kwargs: None) + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + os.rmdir(self.tmpdir) + super(WebSocketRequestHandlerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocketserver.WebSocketRequestHandler, + **kwargs): + web = kwargs.pop('web', self.tmpdir) + return websocketserver.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=web, + record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, + **kwargs) + + def test_normal_get_with_only_upgrade_returns_error(self): + server = self._get_server(web=None) + handler = websocketserver.WebSocketRequestHandler( + FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) + + def fake_send_response(self, code, message=None): + self.last_code = code + + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) + + handler.do_GET() + self.assertEqual(handler.last_code, 405) + + def test_list_dir_with_file_only_returns_error(self): + server = self._get_server(file_only=True) + handler = websocketserver.WebSocketRequestHandler( + FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) + + def fake_send_response(self, code, message=None): + self.last_code = code + + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) + + handler.path = '/' + handler.do_GET() + self.assertEqual(handler.last_code, 404) + + +class WebSocketServerTestCase(unittest.TestCase): + def setUp(self): + super(WebSocketServerTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.tmpdir = tempfile.mkdtemp('-websockify-tests') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + os.rmdir(self.tmpdir) + super(WebSocketServerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocketserver.WebSocketRequestHandler, + **kwargs): + return websocketserver.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=self.tmpdir, + record=self.tmpdir, **kwargs) + + def test_daemonize_raises_error_while_closing_fds(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) + self.stubs.Set(signal, 'signal', lambda *args: None) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(os, 'close', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + + def test_daemonize_ignores_ebadf_error_while_closing_fds(self): + def raise_oserror_ebadf(fd): + raise OSError(errno.EBADF, 'fake error') + + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(signal, 'signal', lambda *args: None) + self.stubs.Set(os, 'close', raise_oserror_ebadf) + self.stubs.Set(os, 'open', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + + def test_handshake_fails_on_not_ready(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + FakeSocket(), '127.0.0.1') + + def test_empty_handshake_fails(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket('') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_handshake_policy_request(self): + # TODO(directxman12): implement + pass + + def test_handshake_ssl_only_without_ssl_raises_error(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_no_ssl(self): + class FakeHandler(object): + CALLED = False + def __init__(self, *args, **kwargs): + type(self).CALLED = True + + FakeHandler.CALLED = False + + server = self._get_server( + handler_class=FakeHandler, daemon=True, + ssl_only=0, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) + self.assertTrue(FakeHandler.CALLED, True) + + def test_do_handshake_ssl(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_ssl_raises_error(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_cert_raises_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, + cert='afdsfasdafdsafdsafdsafdas') + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_ssl_error_eof_raises_close_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + def fake_wrap_socket(*args, **kwargs): + raise ssl.SSLError(ssl.SSL_ERROR_EOF) + + self.stubs.Set(select, 'select', fake_select) + self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_fallback_sigchld_handler(self): + # TODO(directxman12): implement this + pass + + def test_start_server_error(self): + server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise Exception("fake error") + + self.stubs.Set(websocketserver.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) + self.stubs.Set(websocketserver.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_start_server_keyboardinterrupt(self): + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise KeyboardInterrupt + + self.stubs.Set(websocketserver.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) + self.stubs.Set(websocketserver.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_start_server_systemexit(self): + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + sys.exit() + + self.stubs.Set(websocketserver.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) + self.stubs.Set(websocketserver.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_socket_set_keepalive_options(self): + keepcnt = 12 + keepidle = 34 + keepintvl = 56 + + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) + + if hasattr(socket, 'TCP_KEEPCNT'): + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPCNT), keepcnt) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPIDLE), keepidle) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPINTVL), keepintvl) + + sock = server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) + + if hasattr(socket, 'TCP_KEEPCNT'): + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPCNT), keepcnt) + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPIDLE), keepidle) + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPINTVL), keepintvl) diff --git a/websockify/websocket.py b/websockify/websocket.py index 00ee4916..72a269ce 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -1,413 +1,266 @@ #!/usr/bin/env python ''' -Python WebSocket library with support for "wss://" encryption. +Python WebSocket library Copyright 2011 Joel Martin +Copyright 2016 Pierre Ossman Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Supports following protocol versions: - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-07 - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 - http://tools.ietf.org/html/rfc6455 - -You can make a cert/key with openssl using: -openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem -as taken from http://docs.python.org/dev/library/ssl.html#certificates - ''' -import os, sys, time, errno, signal, socket, select, logging -import array, ssl, multiprocessing +import sys +import array +import email +import errno +import random +import socket +import ssl +import struct from base64 import b64encode from hashlib import sha1 -from struct import pack, unpack_from -# Imports that vary by python version +try: + import numpy +except ImportError: + import warnings + warnings.warn("no 'numpy' module, HyBi protocol will be slower") + numpy = None # python 3.0 differences -if sys.hexversion > 0x3000000: - b2s = lambda buf: buf.decode('latin_1') - s2b = lambda s: s.encode('latin_1') - s2a = lambda s: s -else: - b2s = lambda buf: buf # No-op - s2b = lambda s: s # No-op - s2a = lambda s: [ord(c) for c in s] -try: from io import StringIO -except: from cStringIO import StringIO -try: from http.server import SimpleHTTPRequestHandler -except: from SimpleHTTPServer import SimpleHTTPRequestHandler - -# Degraded functionality if these imports are missing -for mod, msg in [('numpy', 'HyBi protocol will be slower'), - ('resource', 'daemonizing is disabled')]: - try: - globals()[mod] = __import__(mod) - except ImportError: - globals()[mod] = None - print("WARNING: no '%s' module, %s" % (mod, msg)) - -if sys.platform == 'win32': - # make sockets pickle-able/inheritable - import multiprocessing.reduction - - -# HTTP handler with WebSocket upgrade support -class WebSocketRequestHandler(SimpleHTTPRequestHandler): - """ - WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler. - Must be sub-classed with new_websocket_client method definition. - The request handler can be configured by setting optional - attributes on the server object: - - * only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled, - only websocket is allowed. - * verbose: If true, verbose logging is activated. - * daemon: Running as daemon, do not write to console etc - * record: Record raw frame data as JavaScript array into specified filename - * run_once: Handle a single request - * handler_id: A sequence number for this connection, appended to record filename +try: from urllib.parse import urlparse +except: from urlparse import urlparse + +# SSLWant*Error is 2.7.9+ +try: + class WebSocketWantReadError(ssl.SSLWantReadError): + pass + class WebSocketWantWriteError(ssl.SSLWantWriteError): + pass +except: + class WebSocketWantReadError(OSError): + def __init__(self): + OSError.__init__(self, errno.EWOULDBLOCK) + class WebSocketWantWriteError(OSError): + def __init__(self): + OSError.__init__(self, errno.EWOULDBLOCK) + +class WebSocket(object): + """WebSocket protocol socket like class. + + This provides access to the WebSocket protocol by behaving much + like a real socket would. It shares many similarities with + ssl.SSLSocket. + + The WebSocket protocols requires extra data to be sent and received + compared to the application level data. This means that a socket + that is ready to be read may not hold enough data to decode any + application data, and a socket that is ready to be written to may + not have enough space for an entire WebSocket frame. This is + handled by the exceptions WebSocketWantReadError and + WebSocketWantWriteError. When these are raised the caller must wait + for the socket to become ready again and call the relevant function + again. + + A connection is established by using either connect() or accept(), + depending on if a client or server session is desired. See the + respective functions for details. + + The following methods are passed on to the underlying socket: + + - fileno + - getpeername, getsockname + - getsockopt, setsockopt + - gettimeout, settimeout + - setblocking """ - buffer_size = 65536 GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_version = "WebSockify" + def __init__(self): + """Creates an unconnected WebSocket""" - protocol_version = "HTTP/1.1" + self._state = "new" - # An exception while the WebSocket client was connected - class CClose(Exception): - pass + self._partial_msg = ''.encode("ascii") - def __init__(self, req, addr, server): - # Retrieve a few configuration variables from the server - self.only_upgrade = getattr(server, "only_upgrade", False) - self.verbose = getattr(server, "verbose", False) - self.daemon = getattr(server, "daemon", False) - self.record = getattr(server, "record", False) - self.run_once = getattr(server, "run_once", False) - self.rec = None - self.handler_id = getattr(server, "handler_id", False) - self.file_only = getattr(server, "file_only", False) - self.traffic = getattr(server, "traffic", False) - self.auto_pong = getattr(server, "auto_pong", False) - self.strict_mode = getattr(server, "strict_mode", True) - - self.logger = getattr(server, "logger", None) - if self.logger is None: - self.logger = WebSocketServer.get_logger() - - SimpleHTTPRequestHandler.__init__(self, req, addr, server) - - def log_message(self, format, *args): - self.logger.info("%s - - [%s] %s" % (self.address_string(), self.log_date_time_string(), format % args)) - - @staticmethod - def unmask(buf, hlen, plen): - pstart = hlen + 4 - pend = pstart + plen - if numpy: - b = c = s2b('') - if plen >= 4: - dtype=numpy.dtype('') - mask = numpy.frombuffer(buf, dtype, offset=hlen, count=1) - data = numpy.frombuffer(buf, dtype, offset=pstart, - count=int(plen / 4)) - #b = numpy.bitwise_xor(data, mask).data - b = numpy.bitwise_xor(data, mask).tostring() + self._recv_buffer = ''.encode("ascii") + self._recv_queue = [] + self._send_buffer = ''.encode("ascii") - if plen % 4: - #self.msg("Partial unmask") - dtype=numpy.dtype('B') - if sys.byteorder == 'big': - dtype = dtype.newbyteorder('>') - mask = numpy.frombuffer(buf, dtype, offset=hlen, - count=(plen % 4)) - data = numpy.frombuffer(buf, dtype, - offset=pend - (plen % 4), count=(plen % 4)) - c = numpy.bitwise_xor(data, mask).tostring() - return b + c - else: - # Slower fallback - mask = buf[hlen:hlen+4] - data = array.array('B') - mask = s2a(mask) - data.fromstring(buf[pstart:pend]) - for i in range(len(data)): - data[i] ^= mask[i % 4] - return data.tostring() + self._sent_close = False + self._received_close = False - @staticmethod - def encode_hybi(buf, opcode): - """ Encode a HyBi style WebSocket frame. - Optional opcode: - 0x0 - continuation - 0x1 - text frame (base64 encode buf) - 0x2 - binary frame (use raw buf) - 0x8 - connection close - 0x9 - ping - 0xA - pong - """ + self.close_code = None + self.close_reason = None - b1 = 0x80 | (opcode & 0x0f) # FIN + opcode - payload_len = len(buf) - if payload_len <= 125: - header = pack('>BB', b1, payload_len) - elif payload_len > 125 and payload_len < 65536: - header = pack('>BBH', b1, 126, payload_len) - elif payload_len >= 65536: - header = pack('>BBQ', b1, 127, payload_len) + self.socket = None - #self.msg("Encoded: %s", repr(header + buf)) + def __getattr__(self, name): + # These methods are just redirected to the underlying socket + if name in ["fileno", + "getpeername", "getsockname", + "getsockopt", "setsockopt", + "gettimeout", "settimeout", + "setblocking"]: + assert self.socket is not None + return getattr(self.socket, name) + else: + raise AttributeError("%s instance has no attribute '%s'" % + (self.__class__.__name__, name)) - return header + buf, len(header), 0 + def connect(self, uri, origin=None, protocols=[]): + """Establishes a new connection to a WebSocket server. - @staticmethod - def decode_hybi(buf, logger=None, strict=True): - """ Decode HyBi style WebSocket packets. - Returns: - {'fin' : 0_or_1, - 'opcode' : number, - 'masked' : boolean, - 'hlen' : header_bytes_number, - 'length' : payload_bytes_number, - 'payload' : decoded_buffer, - 'left' : bytes_left_number, - 'close_code' : number, - 'close_reason' : string} + This method connects to the host specified by uri and + negotiates a WebSocket connection. origin should be specified + in accordance with RFC 6454 if known. A list of valid + sub-protocols can be specified in the protocols argument. + + The data will be sent in the clear if the "ws" scheme is used, + and encrypted if the "wss" scheme is used. + + Both WebSocketWantReadError and WebSocketWantWriteError can be + raised whilst negotiating the connection. Repeated calls to + connect() must retain the same arguments. """ - f = {'fin' : 0, - 'opcode' : 0, - 'masked' : False, - 'hlen' : 2, - 'length' : 0, - 'payload' : None, - 'left' : 0, - 'close_code' : 1000, - 'close_reason' : ''} + self.client = True; - if logger is None: - logger = WebSocketServer.get_logger() + uri = urlparse(uri) - blen = len(buf) - f['left'] = blen + port = uri.port + if uri.scheme in ("ws", "http"): + if not port: + port = 80 + elif uri.scheme in ("wss", "https"): + if not port: + port = 443 + else: + raise Exception("Unknown scheme '%s'" % uri.scheme) - if blen < f['hlen']: - return f # Incomplete frame header + # This is a state machine in order to handle + # WantRead/WantWrite events - b1, b2 = unpack_from(">BB", buf) - f['opcode'] = b1 & 0x0f - f['fin'] = (b1 & 0x80) >> 7 - f['masked'] = (b2 & 0x80) >> 7 + if self._state == "new": + self.socket = socket.create_connection((uri.hostname, port)) - f['length'] = b2 & 0x7f + if uri.scheme in ("wss", "https"): + self.socket = ssl.wrap_socket(self.socket) + self._state = "ssl_handshake" + else: + self._state = "headers" - if f['length'] == 126: - f['hlen'] = 4 - if blen < f['hlen']: - return f # Incomplete frame header - (f['length'],) = unpack_from('>xxH', buf) - elif f['length'] == 127: - f['hlen'] = 10 - if blen < f['hlen']: - return f # Incomplete frame header - (f['length'],) = unpack_from('>xxQ', buf) + if self._state == "ssl_handshake": + self.socket.do_handshake() + self._state = "headers" - full_len = f['hlen'] + f['masked'] * 4 + f['length'] + if self._state == "headers": + self._key = '' + for i in range(16): + self._key += chr(random.randrange(256)) + if sys.hexversion >= 0x3000000: + self._key = bytes(self._key, "latin-1") + self._key = b64encode(self._key).decode("ascii") - if blen < full_len: # Incomplete frame - return f # Incomplete frame header + path = uri.path + if not path: + path = "/" - # Number of bytes that are part of the next frame(s) - f['left'] = blen - full_len + self._queue_str("GET %s HTTP/1.1\r\n" % path) + self._queue_str("Host: %s\r\n" % uri.hostname) + self._queue_str("Upgrade: websocket\r\n") + self._queue_str("Connection: upgrade\r\n") + self._queue_str("Sec-WebSocket-Key: %s\r\n" % self._key) + self._queue_str("Sec-WebSocket-Version: 13\r\n") - # Process 1 frame - if f['masked']: - # unmask payload - f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'], - f['length']) - else: - logger.debug("Unmasked frame: %s" % repr(buf)) + if origin is not None: + self._queue_str("Origin: %s\r\n" % origin) + if len(protocols) > 0: + self._queue_str("Sec-WebSocket-Protocol: %s\r\n" % ", ".join(protocols)) - if strict: - raise WebSocketRequestHandler.CClose(1002, "The client sent an unmasked frame.") + self._queue_str("\r\n") - f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] + self._state = "send_headers" - if f['opcode'] == 0x08: - if f['length'] >= 2: - f['close_code'] = unpack_from(">H", f['payload'])[0] - if f['length'] > 3: - f['close_reason'] = f['payload'][2:] + if self._state == "send_headers": + self._flush() + self._state = "response" - return f + if self._state == "response": + if not self._recv(): + raise Exception("Socket closed unexpectedly") + if self._recv_buffer.find('\r\n\r\n'.encode("ascii")) == -1: + raise WebSocketWantReadError - # - # WebSocketRequestHandler logging/output functions - # - - def print_traffic(self, token="."): - """ Show traffic flow mode. """ - if self.traffic: - sys.stdout.write(token) - sys.stdout.flush() - - def msg(self, msg, *args, **kwargs): - """ Output message with handler_id prefix. """ - prefix = "% 3d: " % self.handler_id - self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) - - def vmsg(self, msg, *args, **kwargs): - """ Same as msg() but as debug. """ - prefix = "% 3d: " % self.handler_id - self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs) - - def warn(self, msg, *args, **kwargs): - """ Same as msg() but as warning. """ - prefix = "% 3d: " % self.handler_id - self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs) - - # - # Main WebSocketRequestHandler methods - # - def send_frames(self, bufs=None): - """ Encode and send WebSocket frames. Any frames already - queued will be sent first. If buf is not set then only queued - frames will be sent. Returns the number of pending frames that - could not be fully sent. If returned pending frames is greater - than 0, then the caller should call again when the socket is - ready. """ - - tdelta = int(time.time()*1000) - self.start_time - - if bufs: - for buf in bufs: - encbuf, lenhead, lentail = self.encode_hybi(buf, opcode=2) - - if self.rec: - self.rec.write("%s,\n" % - repr("{%s{" % tdelta - + encbuf[lenhead:len(encbuf)-lentail])) - - self.send_parts.append(encbuf) - - while self.send_parts: - # Send pending frames - buf = self.send_parts.pop(0) - sent = self.request.send(buf) - - if sent == len(buf): - self.print_traffic("<") - else: - self.print_traffic("<.") - self.send_parts.insert(0, buf[sent:]) - break + (request, self._recv_buffer) = self._recv_buffer.split('\r\n'.encode("ascii"), 1) + request = request.decode("latin-1") - return len(self.send_parts) + words = request.split() + if (len(words) < 2) or (words[0] != "HTTP/1.1"): + raise Exception("Invalid response") + if words[1] != "101": + raise Exception("WebSocket request denied: %s" % " ".join(words[1:])) - def recv_frames(self): - """ Receive and decode WebSocket frames. + (headers, self._recv_buffer) = self._recv_buffer.split('\r\n\r\n'.encode("ascii"), 1) + headers = headers.decode('latin-1') + '\r\n' + headers = email.message_from_string(headers) - Returns: - (bufs_list, closed_string) - """ + if headers.get("Upgrade", "").lower() != "websocket": + print(type(headers)) + raise Exception("Missing or incorrect upgrade header") - closed = False - bufs = [] - tdelta = int(time.time()*1000) - self.start_time - - buf = self.request.recv(self.buffer_size) - if len(buf) == 0: - closed = {'code': 1000, 'reason': "Client closed abruptly"} - return bufs, closed - - if self.recv_part: - # Add partially received frames to current read buffer - buf = self.recv_part + buf - self.recv_part = None - - while buf: - frame = self.decode_hybi(buf, - logger=self.logger, - strict=self.strict_mode) - #self.msg("Received buf: %s, frame: %s", repr(buf), frame) - - if frame['payload'] == None: - # Incomplete/partial frame - self.print_traffic("}.") - if frame['left'] > 0: - self.recv_part = buf[-frame['left']:] - break - else: - if frame['opcode'] == 0x8: # connection close - closed = {'code': frame['close_code'], - 'reason': frame['close_reason']} - break - elif self.auto_pong and frame['opcode'] == 0x9: # ping - self.print_traffic("} ping %s\n" % - repr(frame['payload'])) - self.send_pong(frame['payload']) - return [], False - elif frame['opcode'] == 0xA: # pong - self.print_traffic("} pong %s\n" % - repr(frame['payload'])) - return [], False - - self.print_traffic("}") - - if self.rec: - start = frame['hlen'] - end = frame['hlen'] + frame['length'] - if frame['masked']: - recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'], - frame['length']) - else: - recbuf = buf[frame['hlen']:frame['hlen'] + - frame['length']] - self.rec.write("%s,\n" % - repr("}%s}" % tdelta + recbuf)) + accept = headers.get('Sec-WebSocket-Accept') + if accept is None: + raise Exception("Missing Sec-WebSocket-Accept header"); + expected = sha1((self._key + self.GUID).encode("ascii")).digest() + expected = b64encode(expected).decode("ascii") - bufs.append(frame['payload']) + del self._key - if frame['left']: - buf = buf[-frame['left']:] - else: - buf = '' + if accept != expected: + raise Exception("Invalid Sec-WebSocket-Accept header"); + + self._state = "done" - return bufs, closed + return - def send_close(self, code=1000, reason=''): - """ Send a WebSocket orderly close frame. """ + raise Exception("WebSocket is in an invalid state") - msg = pack(">H%ds" % len(reason), code, s2b(reason)) - buf, h, t = self.encode_hybi(msg, opcode=0x08) - self.request.send(buf) + def accept(self, socket, headers): + """Establishes a new WebSocket session with a client. - def send_pong(self, data=''): - """ Send a WebSocket pong frame. """ - buf, h, t = self.encode_hybi(s2b(data), opcode=0x0A) - self.request.send(buf) + This method negotiates a WebSocket connection with an incoming + client. The caller must provide the client socket and the + headers from the HTTP request. + + A server can identify that a client is requesting a WebSocket + connection by looking at the "Upgrade" header. It will include + the value "websocket" in such cases. + + WebSocketWantWriteError can be raised if the response cannot be + sent right away. Repeated calls to accept() does not need to + retain the arguments. + """ - def send_ping(self, data=''): - """ Send a WebSocket ping frame. """ - buf, h, t = self.encode_hybi(s2b(data), opcode=0x09) - self.request.send(buf) + # This is a state machine in order to handle + # WantRead/WantWrite events - def do_websocket_handshake(self): - h = self.headers + if self._state == "new": + self.client = False + self.socket = socket - prot = 'WebSocket-Protocol' - protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') + if headers.get("upgrade", "").lower() != "websocket": + raise Exception("Missing or incorrect upgrade header") - ver = h.get('Sec-WebSocket-Version') - if ver: - # HyBi/IETF version of the protocol + ver = headers.get('Sec-WebSocket-Version') + if ver is None: + raise Exception("Missing Sec-WebSocket-Version header"); # HyBi-07 report version 7 # HyBi-08 - HyBi-12 report version 8 @@ -415,635 +268,551 @@ def do_websocket_handshake(self): if ver in ['7', '8', '13']: self.version = "hybi-%02d" % int(ver) else: - self.send_error(400, "Unsupported protocol version %s" % ver) - return False + raise Exception("Unsupported protocol version %s" % ver) - key = h['Sec-WebSocket-Key'] + key = headers.get('Sec-WebSocket-Key') + if key is None: + raise Exception("Missing Sec-WebSocket-Key header"); # Generate the hash value for the accept header - accept = b64encode(sha1(s2b(key + self.GUID)).digest()) + accept = sha1((key + self.GUID).encode("ascii")).digest() + accept = b64encode(accept).decode("ascii") - self.send_response(101, "Switching Protocols") - self.send_header("Upgrade", "websocket") - self.send_header("Connection", "Upgrade") - self.send_header("Sec-WebSocket-Accept", b2s(accept)) - self.end_headers() + self.protocol = '' + protocols = headers.get('Sec-WebSocket-Protocol', '').split(',') + if protocols: + self.protocol = self.select_subprotocol(protocols) - # Other requests cannot follow Websocket data - self.close_connection = True + self._queue_str("HTTP/1.1 101 Switching Protocols\r\n") + self._queue_str("Upgrade: websocket\r\n") + self._queue_str("Connection: Upgrade\r\n") + self._queue_str("Sec-WebSocket-Accept: %s\r\n" % accept) - return True - else: - self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + if self.protocol: + self._queue_str("Sec-WebSocket-Protocol: %s\r\n" % self.protocol) - return False + self._queue_str("\r\n") - def handle_websocket(self): - """Upgrade a connection to Websocket, if requested. If this succeeds, - new_websocket_client() will be called. Otherwise, False is returned. - """ + self._state = "flush" - if (self.headers.get('upgrade') and - self.headers.get('upgrade').lower() == 'websocket'): + if self._state == "flush": + self._flush() + self._state = "done" - # ensure connection is authorized, and determine the target - self.validate_connection() + return - if not self.do_websocket_handshake(): - return False + raise Exception("WebSocket is in an invalid state") - # Indicate to server that a Websocket upgrade was done - self.server.ws_connection = True - # Initialize per client settings - self.send_parts = [] - self.recv_part = None - self.start_time = int(time.time()*1000) + def select_subprotocol(self, protocols): + """Returns which sub-protocol should be used. - # client_address is empty with, say, UNIX domain sockets - client_addr = "" - is_ssl = False - try: - client_addr = self.client_address[0] - is_ssl = self.client_address[2] - except IndexError: - pass + This method does not select any sub-protocol by default and is + meant to be overridden by an implementation that wishes to make + use of sub-protocols. It will be called during handling of + accept(). + """ + return "" - if is_ssl: - self.stype = "SSL/TLS (wss://)" - else: - self.stype = "Plain non-SSL (ws://)" + def handle_ping(self, data): + """Called when a WebSocket ping message is received. - self.log_message("%s: %s WebSocket connection", client_addr, - self.stype) - if self.path != '/': - self.log_message("%s: Path: '%s'", client_addr, self.path) + This will be called whilst processing recv()/recvmsg(). The + default implementation sends a pong reply back.""" + self.pong(data) - if self.record: - # Record raw frame data as JavaScript array - fname = "%s.%s" % (self.record, - self.handler_id) - self.log_message("opening record file: %s", fname) - self.rec = open(fname, 'w+') - self.rec.write("var VNC_frame_data = [\n") + def handle_pong(self, data): + """Called when a WebSocket pong message is received. - try: - self.new_websocket_client() - except self.CClose: - # Close the client - _, exc, _ = sys.exc_info() - self.send_close(exc.args[0], exc.args[1]) - return True - else: - return False + This will be called whilst processing recv()/recvmsg(). The + default implementation does nothing.""" + pass - def do_GET(self): - """Handle GET request. Calls handle_websocket(). If unsuccessful, - and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called.""" - if not self.handle_websocket(): - if self.only_upgrade: - self.send_error(405, "Method Not Allowed") - else: - SimpleHTTPRequestHandler.do_GET(self) + def recv(self): + """Read data from the WebSocket. - def list_directory(self, path): - if self.file_only: - self.send_error(404, "No such file") - else: - return SimpleHTTPRequestHandler.list_directory(self, path) + This will return any available data on the socket. If the + socket is closed then an empty buffer will be returned. The + reason for the close is found in the 'close_code' and + 'close_reason' properties. - def new_websocket_client(self): - """ Do something with a WebSockets client connection. """ - raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + Unlike recvmsg() this method may return data from more than one + WebSocket message. It is however not guaranteed to return all + buffered data. Callers should continue calling recv() whilst + pending() returns True. - def validate_connection(self): - """ Ensure that the connection is a valid connection, and set the target. """ - pass + Both WebSocketWantReadError and WebSocketWantWriteError can be + raised when calling recv(). + """ + return self.recvmsg() - def do_HEAD(self): - if self.only_upgrade: - self.send_error(405, "Method Not Allowed") - else: - SimpleHTTPRequestHandler.do_HEAD(self) - - def finish(self): - if self.rec: - self.rec.write("'EOF'];\n") - self.rec.close() - - def handle(self): - # When using run_once, we have a single process, so - # we cannot loop in BaseHTTPRequestHandler.handle; we - # must return and handle new connections - if self.run_once: - self.handle_one_request() - else: - SimpleHTTPRequestHandler.handle(self) + def recvmsg(self): + """Read a single message from the WebSocket. - def log_request(self, code='-', size='-'): - if self.verbose: - SimpleHTTPRequestHandler.log_request(self, code, size) + This will return a single WebSocket message from the socket. + If the socket is closed then an empty buffer will be returned. + The reason for the close is found in the 'close_code' and + 'close_reason' properties. + Unlike recv() this method will not return data from more than + one WebSocket message. Callers should continue calling + recvmsg() whilst pending() returns True. -class WebSocketServer(object): - """ - WebSockets server class. - As an alternative, the standard library SocketServer can be used - """ + Both WebSocketWantReadError and WebSocketWantWriteError can be + raised when calling recvmsg(). + """ + # May have been called to flush out a close + if self._received_close: + self._flush() + return ''.encode("ascii") + + # Anything already queued? + if self.pending(): + msg = self._recvmsg() + if msg is not None: + return msg + + # Note: We cannot proceed to self._recv() here as we may + # have already called it once as part of the caller's + # "while websock.pending():" loop + raise WebSocketWantReadError + + # Nope, let's try to read a bit + if not self._recv_frames(): + return ''.encode("ascii") + + # Anything queued now? + msg = self._recvmsg() + if msg is not None: + return msg + + # Still nope + raise WebSocketWantReadError + + def pending(self): + """Check if any WebSocket data is pending. + + This method will return True as long as there are WebSocket + frames that have yet been processed. A single recv() from the + underlying socket may return multiple WebSocket frames and it + is therefore important that a caller continues calling recv() + or recvmsg() as long as pending() returns True. + + Note that this function merely tells if there are raw WebSocket + frames pending. Those frames may not contain any application + data. + """ + return len(self._recv_queue) > 0 - policy_response = """\n""" - log_prefix = "websocket" + def send(self, bytes): + """Write data to the WebSocket - # An exception before the WebSocket connection was established - class EClose(Exception): - pass + This will queue the given data and attempt to send it to the + peer. Unlike sendmsg() this method might coalesce the data with + data from other calls, or split it over multiple messages. - class Terminate(Exception): - pass + WebSocketWantWriteError can be raised if there is insufficient + space in the underlying socket. + """ + return self.sendmsg(bytes) - def __init__(self, RequestHandlerClass, listen_host='', - listen_port=None, source_is_ipv6=False, - verbose=False, cert='', key='', ssl_only=None, - daemon=False, record='', web='', - file_only=False, - run_once=False, timeout=0, idle_timeout=0, traffic=False, - tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, - tcp_keepintvl=None, auto_pong=False, strict_mode=True): - - # settings - self.RequestHandlerClass = RequestHandlerClass - self.verbose = verbose - self.listen_host = listen_host - self.listen_port = listen_port - self.prefer_ipv6 = source_is_ipv6 - self.ssl_only = ssl_only - self.daemon = daemon - self.run_once = run_once - self.timeout = timeout - self.idle_timeout = idle_timeout - self.traffic = traffic - self.file_only = file_only - self.strict_mode = strict_mode - - self.launch_time = time.time() - self.ws_connection = False - self.handler_id = 1 - - self.logger = self.get_logger() - self.tcp_keepalive = tcp_keepalive - self.tcp_keepcnt = tcp_keepcnt - self.tcp_keepidle = tcp_keepidle - self.tcp_keepintvl = tcp_keepintvl - - self.auto_pong = auto_pong - # Make paths settings absolute - self.cert = os.path.abspath(cert) - self.key = self.web = self.record = '' - if key: - self.key = os.path.abspath(key) - if web: - self.web = os.path.abspath(web) - if record: - self.record = os.path.abspath(record) - - if self.web: - os.chdir(self.web) - self.only_upgrade = not self.web - - # Sanity checks - if self.daemon and not resource: - raise Exception("Module 'resource' required to daemonize") - - # Show configuration - self.msg("WebSocket server settings:") - self.msg(" - Listen on %s:%s", - self.listen_host, self.listen_port) - self.msg(" - Flash security policy server") - if self.web: - if self.file_only: - self.msg(" - Web server (no directory listings). Web root: %s", self.web) - else: - self.msg(" - Web server. Web root: %s", self.web) - if os.path.exists(self.cert): - self.msg(" - SSL/TLS support") - if self.ssl_only: - self.msg(" - Deny non-SSL/TLS connections") - else: - self.msg(" - No SSL/TLS support (no cert file)") - if self.daemon: - self.msg(" - Backgrounding (daemon)") - if self.record: - self.msg(" - Recording to '%s.*'", self.record) - - # - # WebSocketServer static methods - # - - @staticmethod - def get_logger(): - return logging.getLogger("%s.%s" % ( - WebSocketServer.log_prefix, - WebSocketServer.__class__.__name__)) - - @staticmethod - def socket(host, port=None, connect=False, prefer_ipv6=False, - unix_socket=None, use_ssl=False, tcp_keepalive=True, - tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): - """ Resolve a host (and optional port) to an IPv4 or IPv6 - address. Create a socket. Bind to it if listen is set, - otherwise connect to it. Return the socket. + def sendmsg(self, msg): + """Write a single message to the WebSocket + + This will queue the given message and attempt to send it to the + peer. Unlike send() this method will preserve the data as a + single WebSocket message. + + WebSocketWantWriteError can be raised if there is insufficient + space in the underlying socket. """ - flags = 0 - if host == '': - host = None - if connect and not (port or unix_socket): - raise Exception("Connect mode requires a port") - if not connect and use_ssl: - raise Exception("SSL only supported in connect mode (for now)") - if not connect: - flags = flags | socket.AI_PASSIVE - - if not unix_socket: - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, - socket.IPPROTO_TCP, flags) - if not addrs: - raise Exception("Could not resolve host '%s'" % host) - addrs.sort(key=lambda x: x[0]) - if prefer_ipv6: - addrs.reverse() - sock = socket.socket(addrs[0][0], addrs[0][1]) - - if tcp_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if tcp_keepcnt: - if hasattr(socket, 'TCP_KEEPCNT'): - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, - tcp_keepcnt) - else: - self.msg('tcp_keepcnt not available on your system') - if tcp_keepidle: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, - tcp_keepidle) - if tcp_keepintvl: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, - tcp_keepintvl) - - if connect: - sock.connect(addrs[0][4]) - if use_ssl: - sock = ssl.wrap_socket(sock) - else: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(addrs[0][4]) - sock.listen(100) - else: - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(unix_socket) + if not self._sent_close: + # Only called to flush? + if msg: + self._sendmsg(0x2, msg) - return sock + self._flush() + return len(msg) - @staticmethod - def daemonize(keepfd=None, chdir='/'): - - if keepfd is None: - keepfd = [] + def ping(self, data=None): + """Write a ping message to the WebSocket.""" + self._sendmg(0x9, data) - os.umask(0) - if chdir: - os.chdir(chdir) - else: - os.chdir('/') - os.setgid(os.getgid()) # relinquish elevations - os.setuid(os.getuid()) # relinquish elevations - - # Double fork to daemonize - if os.fork() > 0: os._exit(0) # Parent exits - os.setsid() # Obtain new process group - if os.fork() > 0: os._exit(0) # Parent exits - - # Signal handling - signal.signal(signal.SIGTERM, signal.SIG_IGN) - signal.signal(signal.SIGINT, signal.SIG_IGN) - - # Close open files - maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] - if maxfd == resource.RLIM_INFINITY: maxfd = 256 - for fd in reversed(range(maxfd)): - try: - if fd not in keepfd: - os.close(fd) - except OSError: - _, exc, _ = sys.exc_info() - if exc.errno != errno.EBADF: raise - - # Redirect I/O to /dev/null - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) - - def do_handshake(self, sock, address): + def pong(self, data=None): + """Write a pong message to the WebSocket.""" + self._sendmg(0xA, data) + + def shutdown(self, how, code=1000, reason=None): + """Gracefully terminate the WebSocket connection. + + This will start the process to terminate the WebSocket + connection. The caller must continue to calling recv() or + recvmsg() after this function in order to wait for the peer to + acknowledge the close. Calls to send() and sendmsg() will be + ignored. + + WebSocketWantWriteError can be raised if there is insufficient + space in the underlying socket for the close message. + + The how argument is currently ignored. """ - do_handshake does the following: - - Peek at the first few bytes from the socket. - - If the connection is Flash policy request then answer it, - close the socket and return. - - If the connection is an HTTPS/SSL/TLS connection then SSL - wrap the socket. - - Read from the (possibly wrapped) socket. - - If we have received a HTTP GET request and the webserver - functionality is enabled, answer it, close the socket and - return. - - Assume we have a WebSockets connection, parse the client - handshake data. - - Send a WebSockets handshake server response. - - Return the socket for this WebSocket client. + + # Already closing? + if self._sent_close: + self._flush() + return + + # Special code to indicate that we closed the connection + if not self._received_close: + self.close_code = 1000 + self.close_reason = "Locally initiated close" + + self._sent_close = True + + msg = ''.encode('ascii') + if code is not None: + msg += struct.pack(">H", code) + if reason is not None: + msg += reason.encode("UTF-8") + + self._sendmsg(0x8, msg) + + def close(self, code=1000, reason=None): + """Terminate the WebSocket connection immediately. + + This will close the WebSocket connection directly after sending + a close message to the peer. + + WebSocketWantWriteError can be raised if there is insufficient + space in the underlying socket for the close message. """ - ready = select.select([sock], [], [], 3)[0] - - - if not ready: - raise self.EClose("ignoring socket not ready") - # Peek, but do not read the data so that we have a opportunity - # to SSL wrap the socket first - handshake = sock.recv(1024, socket.MSG_PEEK) - #self.msg("Handshake [%s]" % handshake) - - if not handshake: - raise self.EClose("ignoring empty handshake") - - elif handshake.startswith(s2b("")): - # Answer Flash policy request - handshake = sock.recv(1024) - sock.send(s2b(self.policy_response)) - raise self.EClose("Sending flash policy response") - - elif handshake[0] in ("\x16", "\x80", 22, 128): - # SSL wrap the connection - if not os.path.exists(self.cert): - raise self.EClose("SSL connection but '%s' not found" - % self.cert) - retsock = None + self.shutdown(socket.SHUT_RDWR, code, reason) + self._close() + + def _recv(self): + # Fetches more data from the socket to the buffer + assert self.socket is not None + + while True: try: - retsock = ssl.wrap_socket( - sock, - server_side=True, - certfile=self.cert, - keyfile=self.key) - except ssl.SSLError: - _, x, _ = sys.exc_info() - if x.args[0] == ssl.SSL_ERROR_EOF: - if len(x.args) > 1: - raise self.EClose(x.args[1]) - else: - raise self.EClose("Got SSL_ERROR_EOF") + data = self.socket.recv(4096) + except (socket.error, OSError): + exc = sys.exc_info()[1] + if hasattr(exc, 'errno'): + err = exc.errno else: - raise + err = exc[0] - elif self.ssl_only: - raise self.EClose("non-SSL connection received but disallowed") + if err == errno.EWOULDBLOCK: + raise WebSocketWantReadError - else: - retsock = sock + raise + + if len(data) == 0: + return False - # If the address is like (host, port), we are extending it - # with a flag indicating SSL. Not many other options - # available... - if len(address) == 2: - address = (address[0], address[1], (retsock != sock)) + self._recv_buffer += data - self.RequestHandlerClass(retsock, address, self) + # Support for SSLSocket like objects + if hasattr(self.socket, "pending"): + if not self.socket.pending(): + break + else: + break - # Return the WebSockets socket which may be SSL wrapped - return retsock + return True - # - # WebSocketServer logging/output functions - # + def _recv_frames(self): + # Fetches more data and decodes the frames + if not self._recv(): + if self.close_code is None: + self.close_code = 1006 + self.close_reason = "Connection closed abnormally" + self._sent_close = self._received_close = True + self._close() + return False - def msg(self, *args, **kwargs): - """ Output message as info """ - self.logger.log(logging.INFO, *args, **kwargs) + while True: + frame = self._decode_hybi(self._recv_buffer) + if frame is None: + break + self._recv_buffer = self._recv_buffer[frame['length']:] + self._recv_queue.append(frame) + + return True + + def _recvmsg(self): + # Process pending frames and returns any application data + while self._recv_queue: + frame = self._recv_queue.pop(0) + + if not self.client and not frame['masked']: + self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame not masked") + continue + if self.client and frame['masked']: + self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Frame masked") + continue + + if frame["opcode"] == 0x0: + if not self._partial_msg: + self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected continuation frame") + continue + + self._partial_msg += frame["payload"] + + if frame["fin"]: + msg = self._partial_msg + self._partial_msg = ''.decode("ascii") + return msg + elif frame["opcode"] == 0x2: + if self._partial_msg: + self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Unexpected new frame") + continue + + if frame["fin"]: + return frame["payload"] + else: + self._partial_msg = frame["payload"] + elif frame["opcode"] == 0x8: + if self._received_close: + continue + + self._received_close = True + + if self._sent_close: + self._close() + return ''.encode("ascii") + + if not frame["fin"]: + self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented close") + continue + + code = None + reason = None + if len(frame["payload"]) >= 2: + code = struct.unpack(">H", frame["payload"][:2]) + if len(frame["payload"]) > 2: + reason = frame["payload"][2:] + try: + reason = reason.decode("UTF-8") + except UnicodeDecodeError: + self.shutdown(socket.SHUT_RDWR, 1002, "Procotol error: Invalid UTF-8 in close") + continue + + if code is None: + self.close_code = 1005 + self.close_reason = "No close status code specified by peer" + else: + self.close_code = code + if reason is not None: + self.close_reason = reason + + self.shutdown(code, reason) + return ''.encode("ascii") + elif frame["opcode"] == 0x9: + if not frame["fin"]: + self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented ping") + continue + + self.handle_ping(frame["payload"]) + elif frame["opcode"] == 0xA: + if not frame["fin"]: + self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Fragmented pong") + continue + + self.handle_pong(frame["payload"]) + else: + self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Unknown opcode 0x%02x" % frame["opcode"]) - def vmsg(self, *args, **kwargs): - """ Same as msg() but as debug. """ - self.logger.log(logging.DEBUG, *args, **kwargs) + return None - def warn(self, *args, **kwargs): - """ Same as msg() but as warning. """ - self.logger.log(logging.WARN, *args, **kwargs) + def _flush(self): + # Writes pending data to the socket + if not self._send_buffer: + return + assert self.socket is not None - # - # Events that can/should be overridden in sub-classes - # - def started(self): - """ Called after WebSockets startup """ - self.vmsg("WebSockets server started") + try: + sent = self.socket.send(self._send_buffer) + except (socket.error, OSError): + exc = sys.exc_info()[1] + if hasattr(exc, 'errno'): + err = exc.errno + else: + err = exc[0] + + if err == errno.EWOULDBLOCK: + raise WebSocketWantWriteError + + raise + + self._send_buffer = self._send_buffer[sent:] + + if self._send_buffer: + raise WebSocketWantWriteError + + # We had a pending close and we've flushed the buffer, + # time to end things + if self._received_close and self._sent_close: + self._close() + + def _send(self, data): + # Queues data and attempts to send it + self._send_buffer += data + self._flush() + + def _queue_str(self, string): + # Queue some data to be sent later. + # Only used by the connecting methods. + self._send_buffer += string.encode("latin-1") + + def _sendmsg(self, opcode, msg): + # Sends a standard data message + if self.client: + mask = '' + for i in range(4): + mask += chr(random.randrange(256)) + if sys.hexversion >= 0x3000000: + mask = bytes(mask, "latin-1") + frame = self._encode_hybi(opcode, msg, mask) + else: + frame = self._encode_hybi(opcode, msg) - def poll(self): - """ Run periodically while waiting for connections. """ - #self.vmsg("Running poll()") - pass + return self._send(frame) - def terminate(self): - raise self.Terminate() + def _close(self): + # Close the underlying socket + self.socket.close() + self.socket = None - def multiprocessing_SIGCHLD(self, sig, stack): - # TODO: figure out a way to actually log this information without - # calling `log` in the signal handlers - multiprocessing.active_children() + def _mask(self, buf, mask): + # Mask a frame + return self._unmask(buf, mask) - def fallback_SIGCHLD(self, sig, stack): - # Reap zombies when using os.fork() (python 2.4) - # TODO: figure out a way to actually log this information without - # calling `log` in the signal handlers - try: - result = os.waitpid(-1, os.WNOHANG) - while result[0]: - self.vmsg("Reaped child process %s" % result[0]) - result = os.waitpid(-1, os.WNOHANG) - except (OSError): - pass - - def do_SIGINT(self, sig, stack): - # TODO: figure out a way to actually log this information without - # calling `log` in the signal handlers - self.terminate() - - def do_SIGTERM(self, sig, stack): - # TODO: figure out a way to actually log this information without - # calling `log` in the signal handlers - self.terminate() - - def top_new_client(self, startsock, address): - """ Do something with a WebSockets client connection. """ - # handler process - client = None - try: - try: - client = self.do_handshake(startsock, address) - except self.EClose: - _, exc, _ = sys.exc_info() - # Connection was not a WebSockets connection - if exc.args[0]: - self.msg("%s: %s" % (address[0], exc.args[0])) - except WebSocketServer.Terminate: - raise - except Exception: - _, exc, _ = sys.exc_info() - self.msg("handler exception: %s" % str(exc)) - self.vmsg("exception", exc_info=True) - finally: - - if client and client != startsock: - # Close the SSL wrapped socket - # Original socket closed by caller - client.close() - - def get_log_fd(self): - """ - Get file descriptors for the loggers. - They should not be closed when the process is forked. - """ - descriptors = [] - for handler in self.logger.parent.handlers: - if isinstance(handler, logging.FileHandler): - descriptors.append(handler.stream.fileno()) + def _unmask(self, buf, mask): + # Unmask a frame + if numpy: + plen = len(buf) + pstart = 0 + pend = plen + b = c = ''.encode('ascii') + if plen >= 4: + dtype=numpy.dtype('') + mask = numpy.frombuffer(mask, dtype, count=1) + data = numpy.frombuffer(buf, dtype, count=int(plen / 4)) + #b = numpy.bitwise_xor(data, mask).data + b = numpy.bitwise_xor(data, mask).tostring() - return descriptors + if plen % 4: + dtype=numpy.dtype('B') + if sys.byteorder == 'big': + dtype = dtype.newbyteorder('>') + mask = numpy.frombuffer(mask, dtype, count=(plen % 4)) + data = numpy.frombuffer(buf, dtype, + offset=plen - (plen % 4), count=(plen % 4)) + c = numpy.bitwise_xor(data, mask).tostring() + return b + c + else: + # Slower fallback + if sys.hexversion < 0x3000000: + mask = [ ord(c) for c in mask ] + data = array.array('B') + data.fromstring(buf) + for i in range(len(data)): + data[i] ^= mask[i % 4] + return data.tostring() - def start_server(self): + def _encode_hybi(self, opcode, buf, mask_key=None, fin=True): + """ Encode a HyBi style WebSocket frame. + Optional opcode: + 0x0 - continuation + 0x1 - text frame + 0x2 - binary frame + 0x8 - connection close + 0x9 - ping + 0xA - pong """ - Daemonize if requested. Listen for for connections. Run - do_handshake() method for each connection. If the connection - is a WebSockets client then call new_websocket_client() method (which must - be overridden) for each new client connection. + + b1 = opcode & 0x0f + if fin: + b1 |= 0x80 + + mask_bit = 0 + if mask_key is not None: + mask_bit = 0x80 + buf = self._mask(buf, mask_key) + + payload_len = len(buf) + if payload_len <= 125: + header = struct.pack('>BB', b1, payload_len | mask_bit) + elif payload_len > 125 and payload_len < 65536: + header = struct.pack('>BBH', b1, 126 | mask_bit, payload_len) + elif payload_len >= 65536: + header = struct.pack('>BBQ', b1, 127 | mask_bit, payload_len) + + if mask_key is not None: + return header + mask_key + buf + else: + return header + buf + + def _decode_hybi(self, buf): + """ Decode HyBi style WebSocket packets. + Returns: + {'fin' : boolean, + 'opcode' : number, + 'masked' : boolean, + 'length' : encoded_length, + 'payload' : decoded_buffer} """ - lsock = self.socket(self.listen_host, self.listen_port, False, - self.prefer_ipv6, - tcp_keepalive=self.tcp_keepalive, - tcp_keepcnt=self.tcp_keepcnt, - tcp_keepidle=self.tcp_keepidle, - tcp_keepintvl=self.tcp_keepintvl) - - if self.daemon: - keepfd = self.get_log_fd() - keepfd.append(lsock.fileno()) - self.daemonize(keepfd=keepfd, chdir=self.web) - - self.started() # Some things need to happen after daemonizing - - # Allow override of signals - original_signals = { - signal.SIGINT: signal.getsignal(signal.SIGINT), - signal.SIGTERM: signal.getsignal(signal.SIGTERM), - signal.SIGCHLD: signal.getsignal(signal.SIGCHLD), - } - signal.signal(signal.SIGINT, self.do_SIGINT) - signal.signal(signal.SIGTERM, self.do_SIGTERM) - # make sure that _cleanup is called when children die - # by calling active_children on SIGCHLD - signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) - - last_active_time = self.launch_time - try: - while True: - try: - try: - startsock = None - pid = err = 0 - child_count = 0 - - # Collect zombie child processes - child_count = len(multiprocessing.active_children()) - - time_elapsed = time.time() - self.launch_time - if self.timeout and time_elapsed > self.timeout: - self.msg('listener exit due to --timeout %s' - % self.timeout) - break - - if self.idle_timeout: - idle_time = 0 - if child_count == 0: - idle_time = time.time() - last_active_time - else: - idle_time = 0 - last_active_time = time.time() - - if idle_time > self.idle_timeout and child_count == 0: - self.msg('listener exit due to --idle-timeout %s' - % self.idle_timeout) - break - try: - self.poll() - - ready = select.select([lsock], [], [], 1)[0] - if lsock in ready: - startsock, address = lsock.accept() - else: - continue - except self.Terminate: - raise - except Exception: - _, exc, _ = sys.exc_info() - if hasattr(exc, 'errno'): - err = exc.errno - elif hasattr(exc, 'args'): - err = exc.args[0] - else: - err = exc[0] - if err == errno.EINTR: - self.vmsg("Ignoring interrupted syscall") - continue - else: - raise - - if self.run_once: - # Run in same process if run_once - self.top_new_client(startsock, address) - if self.ws_connection : - self.msg('%s: exiting due to --run-once' - % address[0]) - break - else: - self.vmsg('%s: new handler Process' % address[0]) - p = multiprocessing.Process( - target=self.top_new_client, - args=(startsock, address)) - p.start() - # child will not return - - # parent process - self.handler_id += 1 - - except (self.Terminate, SystemExit, KeyboardInterrupt): - self.msg("In exit") - # terminate all child processes - if not self.run_once: - children = multiprocessing.active_children() - - for child in children: - self.msg("Terminating child %s" % child.pid) - child.terminate() - - break - except Exception: - exc = sys.exc_info()[1] - self.msg("handler exception: %s", str(exc)) - self.vmsg("exception", exc_info=True) - - finally: - if startsock: - startsock.close() - finally: - # Close listen port - self.vmsg("Closing socket listening at %s:%s", - self.listen_host, self.listen_port) - lsock.close() - - # Restore signals - for sig, func in original_signals.items(): - signal.signal(sig, func) + f = {'fin' : 0, + 'opcode' : 0, + 'masked' : False, + 'length' : 0, + 'payload' : None} + + blen = len(buf) + hlen = 2 + + if blen < hlen: + return None + b1, b2 = struct.unpack(">BB", buf[:2]) + f['opcode'] = b1 & 0x0f + f['fin'] = not not (b1 & 0x80) + f['masked'] = not not (b2 & 0x80) + + if f['masked']: + hlen += 4 + if blen < hlen: + return None + + length = b2 & 0x7f + + if length == 126: + hlen += 2 + if blen < hlen: + return None + length, = struct.unpack('>H', buf[2:4]) + elif length == 127: + hlen += 8 + if blen < hlen: + return None + length, = struct.unpack('>Q', buf[2:10]) + + f['length'] = hlen + length + + if blen < f['length']: + return None + + if f['masked']: + # unmask payload + mask_key = buf[hlen-4:hlen] + f['payload'] = self._unmask(buf[hlen:(hlen+length)], mask_key) + else: + f['payload'] = buf[hlen:(hlen+length)] + + return f diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 32b3e637..117e751a 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -17,7 +17,7 @@ try: from http.server import HTTPServer except: from BaseHTTPServer import HTTPServer import select -from websockify import websocket +from websockify import websocketserver from websockify import auth_plugins as auth try: from urllib.parse import parse_qs, urlparse @@ -25,7 +25,9 @@ from cgi import parse_qs from urlparse import urlparse -class ProxyRequestHandler(websocket.WebSocketRequestHandler): +class ProxyRequestHandler(websocketserver.WebSocketRequestHandler): + + buffer_size = 65536 traffic_legend = """ Traffic Legend: @@ -86,9 +88,11 @@ def new_websocket_client(self): msg += " (using SSL)" self.log_message(msg) - tsock = websocket.WebSocketServer.socket(self.server.target_host, - self.server.target_port, - connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target) + tsock = websocketserver.WebSocketServer.socket(self.server.target_host, + self.server.target_port, + connect=True, + use_ssl=self.server.ssl_target, + unix_socket=self.server.unix_target) self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) if not self.server.wrap_cmd and not self.server.unix_target: @@ -217,7 +221,7 @@ def do_proxy(self, target): cqueue.append(buf) self.print_traffic("{") -class WebSocketProxy(websocket.WebSocketServer): +class WebSocketProxy(websocketserver.WebSocketServer): """ Proxy traffic to and from a WebSockets client to a normal TCP socket server target. @@ -270,7 +274,7 @@ def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): "REBIND_OLD_PORT": str(kwargs['listen_port']), "REBIND_NEW_PORT": str(self.target_port)}) - websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs) + websocketserver.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs) def run_wrap_cmd(self): self.msg("Starting '%s'", " ".join(self.wrap_cmd)) @@ -465,7 +469,7 @@ def websockify_init(): if len(args) > 2: parser.error("Too many arguments") - if not websocket.ssl and opts.ssl_target: + if not websocketserver.ssl and opts.ssl_target: parser.error("SSL target requested and Python SSL module not loaded."); if opts.ssl_only and not os.path.exists(opts.cert): diff --git a/websockify/websocketserver.py b/websockify/websocketserver.py new file mode 100644 index 00000000..249e0289 --- /dev/null +++ b/websockify/websocketserver.py @@ -0,0 +1,819 @@ +#!/usr/bin/env python + +''' +Python WebSocket server base with support for "wss://" encryption. +Copyright 2011 Joel Martin +Copyright 2016 Pierre Ossman +Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) + +You can make a cert/key with openssl using: +openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem +as taken from http://docs.python.org/dev/library/ssl.html#certificates + +''' + +import os, sys, time, errno, signal, socket, select, logging +import multiprocessing + +# Imports that vary by python version + +# python 3.0 differences +if sys.hexversion > 0x3000000: + s2b = lambda s: s.encode('latin_1') +else: + s2b = lambda s: s # No-op +try: from http.server import SimpleHTTPRequestHandler +except: from SimpleHTTPServer import SimpleHTTPRequestHandler + +# Degraded functionality if these imports are missing +for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'), + ('resource', 'daemonizing is disabled')]: + try: + globals()[mod] = __import__(mod) + except ImportError: + globals()[mod] = None + print("WARNING: no '%s' module, %s" % (mod, msg)) + +if sys.platform == 'win32': + # make sockets pickle-able/inheritable + import multiprocessing.reduction + +from websockify.websocket import WebSocket, WebSocketWantReadError, WebSocketWantWriteError + +# HTTP handler with WebSocket upgrade support +class WebSocketRequestHandler(SimpleHTTPRequestHandler): + """ + WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler. + Must be sub-classed with new_websocket_client method definition. + The request handler can be configured by setting optional + attributes on the server object: + + * only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled, + only websocket is allowed. + * verbose: If true, verbose logging is activated. + * daemon: Running as daemon, do not write to console etc + * record: Record raw frame data as JavaScript array into specified filename + * run_once: Handle a single request + * handler_id: A sequence number for this connection, appended to record filename + """ + server_version = "WebSockify" + + protocol_version = "HTTP/1.1" + + # An exception while the WebSocket client was connected + class CClose(Exception): + pass + + def __init__(self, req, addr, server): + # Retrieve a few configuration variables from the server + self.only_upgrade = getattr(server, "only_upgrade", False) + self.verbose = getattr(server, "verbose", False) + self.daemon = getattr(server, "daemon", False) + self.record = getattr(server, "record", False) + self.run_once = getattr(server, "run_once", False) + self.rec = None + self.handler_id = getattr(server, "handler_id", False) + self.file_only = getattr(server, "file_only", False) + self.traffic = getattr(server, "traffic", False) + + self.logger = getattr(server, "logger", None) + if self.logger is None: + self.logger = WebSocketServer.get_logger() + + SimpleHTTPRequestHandler.__init__(self, req, addr, server) + + def log_message(self, format, *args): + self.logger.info("%s - - [%s] %s" % (self.address_string(), self.log_date_time_string(), format % args)) + + # + # WebSocketRequestHandler logging/output functions + # + + def print_traffic(self, token="."): + """ Show traffic flow mode. """ + if self.traffic: + sys.stdout.write(token) + sys.stdout.flush() + + def msg(self, msg, *args, **kwargs): + """ Output message with handler_id prefix. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) + + def vmsg(self, msg, *args, **kwargs): + """ Same as msg() but as debug. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) + + def warn(self, msg, *args, **kwargs): + """ Same as msg() but as warning. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs) + + # + # Main WebSocketRequestHandler methods + # + def send_frames(self, bufs=None): + """ Encode and send WebSocket frames. Any frames already + queued will be sent first. If buf is not set then only queued + frames will be sent. Returns True if any frames could not be + fully sent, in which case the caller should call again when + the socket is ready. """ + + tdelta = int(time.time()*1000) - self.start_time + + if bufs: + for buf in bufs: + if self.rec: + self.rec.write("%s,\n" % repr("{%s{" % tdelta + buf)) + self.send_parts.append(buf) + + # Flush any previously queued data + try: + self.request.sendmsg('') + except WebSocketWantWriteError: + return True + + while self.send_parts: + # Send pending frames + buf = self.send_parts.pop(0) + try: + self.request.sendmsg(buf) + except WebSocketWantWriteError: + self.print_traffic("<.") + return True + self.print_traffic("<") + + return False + + def recv_frames(self): + """ Receive and decode WebSocket frames. + + Returns: + (bufs_list, closed_string) + """ + + closed = False + bufs = [] + tdelta = int(time.time()*1000) - self.start_time + + while True: + try: + buf = self.request.recvmsg() + except WebSocketWantReadError: + self.print_traffic("}.") + break + + if len(buf) == 0: + closed = {'code': self.request.close_code, + 'reason': self.request.close_reason} + return bufs, closed + + self.print_traffic("}") + + if self.rec: + self.rec.write("%s,\n" % repr("}%s}" % tdelta + buf)) + + bufs.append(buf) + + if not self.request.pending(): + break + + return bufs, closed + + def send_close(self, code=1000, reason=''): + """ Send a WebSocket orderly close frame. """ + self.request.shutdown(code, reason) + + def send_pong(self, data=''): + """ Send a WebSocket pong frame. """ + self.request.pong(data) + + def send_ping(self, data=''): + """ Send a WebSocket ping frame. """ + self.request.ping(data) + + def handle_websocket(self): + """Upgrade a connection to Websocket, if requested. If this succeeds, + new_websocket_client() will be called. Otherwise, False is returned. + """ + + if (self.headers.get('upgrade') and + self.headers.get('upgrade').lower() == 'websocket'): + + # ensure connection is authorized, and determine the target + self.validate_connection() + + websocket = WebSocket() + try: + websocket.accept(self.request, self.headers) + except Exception: + exc = sys.exc_info()[1] + self.send_error(400, str(exc)) + return False + + self.request = websocket + + # Other requests cannot follow Websocket data + self.close_connection = True + + # Indicate to server that a Websocket upgrade was done + self.server.ws_connection = True + # Initialize per client settings + self.send_parts = [] + self.recv_part = None + self.start_time = int(time.time()*1000) + + # client_address is empty with, say, UNIX domain sockets + client_addr = "" + is_ssl = False + try: + client_addr = self.client_address[0] + is_ssl = self.client_address[2] + except IndexError: + pass + + if is_ssl: + self.stype = "SSL/TLS (wss://)" + else: + self.stype = "Plain non-SSL (ws://)" + + self.log_message("%s: %s WebSocket connection", client_addr, + self.stype) + if self.path != '/': + self.log_message("%s: Path: '%s'", client_addr, self.path) + + if self.record: + # Record raw frame data as JavaScript array + fname = "%s.%s" % (self.record, + self.handler_id) + self.log_message("opening record file: %s", fname) + self.rec = open(fname, 'w+') + self.rec.write("var VNC_frame_data = [\n") + + try: + self.new_websocket_client() + except self.CClose: + # Close the client + _, exc, _ = sys.exc_info() + self.send_close(exc.args[0], exc.args[1]) + return True + else: + return False + + def do_GET(self): + """Handle GET request. Calls handle_websocket(). If unsuccessful, + and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called.""" + if not self.handle_websocket(): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_GET(self) + + def list_directory(self, path): + if self.file_only: + self.send_error(404, "No such file") + else: + return SimpleHTTPRequestHandler.list_directory(self, path) + + def new_websocket_client(self): + """ Do something with a WebSockets client connection. """ + raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + + def validate_connection(self): + """ Ensure that the connection is a valid connection, and set the target. """ + pass + + def do_HEAD(self): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_HEAD(self) + + def finish(self): + if self.rec: + self.rec.write("'EOF'];\n") + self.rec.close() + + def handle(self): + # When using run_once, we have a single process, so + # we cannot loop in BaseHTTPRequestHandler.handle; we + # must return and handle new connections + if self.run_once: + self.handle_one_request() + else: + SimpleHTTPRequestHandler.handle(self) + + def log_request(self, code='-', size='-'): + if self.verbose: + SimpleHTTPRequestHandler.log_request(self, code, size) + + +class WebSocketServer(object): + """ + WebSockets server class. + As an alternative, the standard library SocketServer can be used + """ + + policy_response = """\n""" + log_prefix = "websocket" + + # An exception before the WebSocket connection was established + class EClose(Exception): + pass + + class Terminate(Exception): + pass + + def __init__(self, RequestHandlerClass, listen_host='', + listen_port=None, source_is_ipv6=False, + verbose=False, cert='', key='', ssl_only=None, + daemon=False, record='', web='', + file_only=False, + run_once=False, timeout=0, idle_timeout=0, traffic=False, + tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, + tcp_keepintvl=None): + + # settings + self.RequestHandlerClass = RequestHandlerClass + self.verbose = verbose + self.listen_host = listen_host + self.listen_port = listen_port + self.prefer_ipv6 = source_is_ipv6 + self.ssl_only = ssl_only + self.daemon = daemon + self.run_once = run_once + self.timeout = timeout + self.idle_timeout = idle_timeout + self.traffic = traffic + self.file_only = file_only + + self.launch_time = time.time() + self.ws_connection = False + self.handler_id = 1 + + self.logger = self.get_logger() + self.tcp_keepalive = tcp_keepalive + self.tcp_keepcnt = tcp_keepcnt + self.tcp_keepidle = tcp_keepidle + self.tcp_keepintvl = tcp_keepintvl + + # Make paths settings absolute + self.cert = os.path.abspath(cert) + self.key = self.web = self.record = '' + if key: + self.key = os.path.abspath(key) + if web: + self.web = os.path.abspath(web) + if record: + self.record = os.path.abspath(record) + + if self.web: + os.chdir(self.web) + self.only_upgrade = not self.web + + # Sanity checks + if not ssl and self.ssl_only: + raise Exception("No 'ssl' module and SSL-only specified") + if self.daemon and not resource: + raise Exception("Module 'resource' required to daemonize") + + # Show configuration + self.msg("WebSocket server settings:") + self.msg(" - Listen on %s:%s", + self.listen_host, self.listen_port) + self.msg(" - Flash security policy server") + if self.web: + if self.file_only: + self.msg(" - Web server (no directory listings). Web root: %s", self.web) + else: + self.msg(" - Web server. Web root: %s", self.web) + if ssl: + if os.path.exists(self.cert): + self.msg(" - SSL/TLS support") + if self.ssl_only: + self.msg(" - Deny non-SSL/TLS connections") + else: + self.msg(" - No SSL/TLS support (no cert file)") + else: + self.msg(" - No SSL/TLS support (no 'ssl' module)") + if self.daemon: + self.msg(" - Backgrounding (daemon)") + if self.record: + self.msg(" - Recording to '%s.*'", self.record) + + # + # WebSocketServer static methods + # + + @staticmethod + def get_logger(): + return logging.getLogger("%s.%s" % ( + WebSocketServer.log_prefix, + WebSocketServer.__class__.__name__)) + + @staticmethod + def socket(host, port=None, connect=False, prefer_ipv6=False, + unix_socket=None, use_ssl=False, tcp_keepalive=True, + tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): + """ Resolve a host (and optional port) to an IPv4 or IPv6 + address. Create a socket. Bind to it if listen is set, + otherwise connect to it. Return the socket. + """ + flags = 0 + if host == '': + host = None + if connect and not (port or unix_socket): + raise Exception("Connect mode requires a port") + if not connect and use_ssl: + raise Exception("SSL only supported in connect mode (for now)") + if not connect: + flags = flags | socket.AI_PASSIVE + + if not unix_socket: + addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, + socket.IPPROTO_TCP, flags) + if not addrs: + raise Exception("Could not resolve host '%s'" % host) + addrs.sort(key=lambda x: x[0]) + if prefer_ipv6: + addrs.reverse() + sock = socket.socket(addrs[0][0], addrs[0][1]) + + if tcp_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if tcp_keepcnt: + if hasattr(socket, 'TCP_KEEPCNT'): + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, + tcp_keepcnt) + else: + self.msg('tcp_keepcnt not available on your system') + if tcp_keepidle: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, + tcp_keepidle) + if tcp_keepintvl: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, + tcp_keepintvl) + + if connect: + sock.connect(addrs[0][4]) + if use_ssl: + sock = ssl.wrap_socket(sock) + else: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addrs[0][4]) + sock.listen(100) + else: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(unix_socket) + + return sock + + @staticmethod + def daemonize(keepfd=None, chdir='/'): + + if keepfd is None: + keepfd = [] + + os.umask(0) + if chdir: + os.chdir(chdir) + else: + os.chdir('/') + os.setgid(os.getgid()) # relinquish elevations + os.setuid(os.getuid()) # relinquish elevations + + # Double fork to daemonize + if os.fork() > 0: os._exit(0) # Parent exits + os.setsid() # Obtain new process group + if os.fork() > 0: os._exit(0) # Parent exits + + # Signal handling + signal.signal(signal.SIGTERM, signal.SIG_IGN) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Close open files + maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if maxfd == resource.RLIM_INFINITY: maxfd = 256 + for fd in reversed(range(maxfd)): + try: + if fd not in keepfd: + os.close(fd) + except OSError: + _, exc, _ = sys.exc_info() + if exc.errno != errno.EBADF: raise + + # Redirect I/O to /dev/null + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) + + def do_handshake(self, sock, address): + """ + do_handshake does the following: + - Peek at the first few bytes from the socket. + - If the connection is Flash policy request then answer it, + close the socket and return. + - If the connection is an HTTPS/SSL/TLS connection then SSL + wrap the socket. + - Read from the (possibly wrapped) socket. + - If we have received a HTTP GET request and the webserver + functionality is enabled, answer it, close the socket and + return. + - Assume we have a WebSockets connection, parse the client + handshake data. + - Send a WebSockets handshake server response. + - Return the socket for this WebSocket client. + """ + ready = select.select([sock], [], [], 3)[0] + + + if not ready: + raise self.EClose("ignoring socket not ready") + # Peek, but do not read the data so that we have a opportunity + # to SSL wrap the socket first + handshake = sock.recv(1024, socket.MSG_PEEK) + #self.msg("Handshake [%s]" % handshake) + + if not handshake: + raise self.EClose("ignoring empty handshake") + + elif handshake.startswith(s2b("")): + # Answer Flash policy request + handshake = sock.recv(1024) + sock.send(s2b(self.policy_response)) + raise self.EClose("Sending flash policy response") + + elif handshake[0] in ("\x16", "\x80", 22, 128): + # SSL wrap the connection + if not ssl: + raise self.EClose("SSL connection but no 'ssl' module") + if not os.path.exists(self.cert): + raise self.EClose("SSL connection but '%s' not found" + % self.cert) + retsock = None + try: + retsock = ssl.wrap_socket( + sock, + server_side=True, + certfile=self.cert, + keyfile=self.key) + except ssl.SSLError: + _, x, _ = sys.exc_info() + if x.args[0] == ssl.SSL_ERROR_EOF: + if len(x.args) > 1: + raise self.EClose(x.args[1]) + else: + raise self.EClose("Got SSL_ERROR_EOF") + else: + raise + + elif self.ssl_only: + raise self.EClose("non-SSL connection received but disallowed") + + else: + retsock = sock + + # If the address is like (host, port), we are extending it + # with a flag indicating SSL. Not many other options + # available... + if len(address) == 2: + address = (address[0], address[1], (retsock != sock)) + + self.RequestHandlerClass(retsock, address, self) + + # Return the WebSockets socket which may be SSL wrapped + return retsock + + # + # WebSocketServer logging/output functions + # + + def msg(self, *args, **kwargs): + """ Output message as info """ + self.logger.log(logging.INFO, *args, **kwargs) + + def vmsg(self, *args, **kwargs): + """ Same as msg() but as debug. """ + self.logger.log(logging.INFO, *args, **kwargs) + + def warn(self, *args, **kwargs): + """ Same as msg() but as warning. """ + self.logger.log(logging.WARN, *args, **kwargs) + + + # + # Events that can/should be overridden in sub-classes + # + def started(self): + """ Called after WebSockets startup """ + self.vmsg("WebSockets server started") + + def poll(self): + """ Run periodically while waiting for connections. """ + #self.vmsg("Running poll()") + pass + + def terminate(self): + raise self.Terminate() + + def multiprocessing_SIGCHLD(self, sig, stack): + # TODO: figure out a way to actually log this information without + # calling `log` in the signal handlers + multiprocessing.active_children() + + def fallback_SIGCHLD(self, sig, stack): + # Reap zombies when using os.fork() (python 2.4) + # TODO: figure out a way to actually log this information without + # calling `log` in the signal handlers + try: + result = os.waitpid(-1, os.WNOHANG) + while result[0]: + self.vmsg("Reaped child process %s" % result[0]) + result = os.waitpid(-1, os.WNOHANG) + except (OSError): + pass + + def do_SIGINT(self, sig, stack): + # TODO: figure out a way to actually log this information without + # calling `log` in the signal handlers + self.terminate() + + def do_SIGTERM(self, sig, stack): + # TODO: figure out a way to actually log this information without + # calling `log` in the signal handlers + self.terminate() + + def top_new_client(self, startsock, address): + """ Do something with a WebSockets client connection. """ + # handler process + client = None + try: + try: + client = self.do_handshake(startsock, address) + except self.EClose: + _, exc, _ = sys.exc_info() + # Connection was not a WebSockets connection + if exc.args[0]: + self.msg("%s: %s" % (address[0], exc.args[0])) + except WebSocketServer.Terminate: + raise + except Exception: + _, exc, _ = sys.exc_info() + self.msg("handler exception: %s" % str(exc)) + self.vmsg("exception", exc_info=True) + finally: + + if client and client != startsock: + # Close the SSL wrapped socket + # Original socket closed by caller + client.close() + + def get_log_fd(self): + """ + Get file descriptors for the loggers. + They should not be closed when the process is forked. + """ + descriptors = [] + for handler in self.logger.parent.handlers: + if isinstance(handler, logging.FileHandler): + descriptors.append(handler.stream.fileno()) + + return descriptors + + def start_server(self): + """ + Daemonize if requested. Listen for for connections. Run + do_handshake() method for each connection. If the connection + is a WebSockets client then call new_websocket_client() method (which must + be overridden) for each new client connection. + """ + lsock = self.socket(self.listen_host, self.listen_port, False, + self.prefer_ipv6, + tcp_keepalive=self.tcp_keepalive, + tcp_keepcnt=self.tcp_keepcnt, + tcp_keepidle=self.tcp_keepidle, + tcp_keepintvl=self.tcp_keepintvl) + + if self.daemon: + keepfd = self.get_log_fd() + keepfd.append(lsock.fileno()) + self.daemonize(keepfd=keepfd, chdir=self.web) + + self.started() # Some things need to happen after daemonizing + + # Allow override of signals + original_signals = { + signal.SIGINT: signal.getsignal(signal.SIGINT), + signal.SIGTERM: signal.getsignal(signal.SIGTERM), + signal.SIGCHLD: signal.getsignal(signal.SIGCHLD), + } + signal.signal(signal.SIGINT, self.do_SIGINT) + signal.signal(signal.SIGTERM, self.do_SIGTERM) + # make sure that _cleanup is called when children die + # by calling active_children on SIGCHLD + signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) + + last_active_time = self.launch_time + try: + while True: + try: + try: + startsock = None + pid = err = 0 + child_count = 0 + + # Collect zombie child processes + child_count = len(multiprocessing.active_children()) + + time_elapsed = time.time() - self.launch_time + if self.timeout and time_elapsed > self.timeout: + self.msg('listener exit due to --timeout %s' + % self.timeout) + break + + if self.idle_timeout: + idle_time = 0 + if child_count == 0: + idle_time = time.time() - last_active_time + else: + idle_time = 0 + last_active_time = time.time() + + if idle_time > self.idle_timeout and child_count == 0: + self.msg('listener exit due to --idle-timeout %s' + % self.idle_timeout) + break + + try: + self.poll() + + ready = select.select([lsock], [], [], 1)[0] + if lsock in ready: + startsock, address = lsock.accept() + else: + continue + except self.Terminate: + raise + except Exception: + _, exc, _ = sys.exc_info() + if hasattr(exc, 'errno'): + err = exc.errno + elif hasattr(exc, 'args'): + err = exc.args[0] + else: + err = exc[0] + if err == errno.EINTR: + self.vmsg("Ignoring interrupted syscall") + continue + else: + raise + + if self.run_once: + # Run in same process if run_once + self.top_new_client(startsock, address) + if self.ws_connection : + self.msg('%s: exiting due to --run-once' + % address[0]) + break + else: + self.vmsg('%s: new handler Process' % address[0]) + p = multiprocessing.Process( + target=self.top_new_client, + args=(startsock, address)) + p.start() + # child will not return + + # parent process + self.handler_id += 1 + + except (self.Terminate, SystemExit, KeyboardInterrupt): + self.msg("In exit") + # terminate all child processes + if not self.run_once: + children = multiprocessing.active_children() + + for child in children: + self.msg("Terminating child %s" % child.pid) + child.terminate() + + break + except Exception: + exc = sys.exc_info()[1] + self.msg("handler exception: %s", str(exc)) + self.vmsg("exception", exc_info=True) + + finally: + if startsock: + startsock.close() + finally: + # Close listen port + self.vmsg("Closing socket listening at %s:%s", + self.listen_host, self.listen_port) + lsock.close() + + # Restore signals + for sig, func in original_signals.items(): + signal.signal(sig, func) + +