From 29ec4bd68e9ef670ca69143715130cc4a74f427b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Urba=C5=84ski?= Date: Sun, 31 Jul 2011 21:18:19 +0200 Subject: [PATCH] Add tests for the hybi-10 decoder and request handler. --- test_websocket.py | 338 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 336 insertions(+), 2 deletions(-) diff --git a/test_websocket.py b/test_websocket.py index 9877d16..5235ea5 100644 --- a/test_websocket.py +++ b/test_websocket.py @@ -4,14 +4,17 @@ """ Tests for L{twisted.web.websocket}. """ +import base64 +from hashlib import sha1 from twisted.internet.main import CONNECTION_DONE from twisted.internet.error import ConnectionDone from twisted.python.failure import Failure from websocket import WebSocketHandler, WebSocketFrameDecoder -from websocket import WebSocketSite, WebSocketTransport -from websocket import DecodingError +from websocket import WebSocketHybiFrameDecoder +from websocket import WebSocketSite, WebSocketTransport, WebSocketHybiTransport +from websocket import DecodingError, OPCODE_PING, OPCODE_TEXT from twisted.web.resource import Resource from twisted.web.server import Request, Site @@ -45,6 +48,9 @@ class TestHandler(WebSocketHandler): def __init__(self, request): WebSocketHandler.__init__(self, request) self.frames = [] + self.binaryFrames = [] + self.pongs = [] + self.closes = [] self.lostReason = None @@ -52,6 +58,18 @@ def frameReceived(self, frame): self.frames.append(frame) + def binaryFrameReceived(self, frame): + self.binaryFrames.append(frame) + + + def pongReceived(self, data): + self.pongs.append(data) + + + def closeReceived(self, code, msg): + self.closes.append((code, msg)) + + def connectionLost(self, reason): self.lostReason = reason @@ -82,6 +100,8 @@ def renderRequest(self, headers=None, url="/test", ssl=False, channel.transport = channel.SSL() channel.site = self.site request = self.site.requestFactory(channel, queued) + # store the reference to the request, so the tests can access it + channel.request = request for k, v in headers: request.requestHeaders.addRawHeader(k, v) request.gotLength(0) @@ -359,6 +379,91 @@ def test_addHandlerWithoutSlash(self): ValueError, self.site.addHandler, "test", TestHandler) + def test_render_handShakeHybi(self): + """ + Test a hybi-10 handshake. + """ + # the key is a base64 encoded 16-bit integer, here chosen to be 14 + key = "AA4=" + headers = [ + ("Upgrade", "websocket"), ("Connection", "Upgrade"), + ("Host", "localhost"), ("Origin", "http://localhost/"), + ("Sec-WebSocket-Version", "8"), ("Sec-WebSocket-Key", key)] + channel = self.renderRequest(headers=headers) + + self.assertTrue(channel.raw) + + result = channel.transport.written.getvalue() + headers, response = result.split('\r\n\r\n') + + guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + accept = base64.b64encode(sha1(key + guid).digest()) + self.assertEquals( + headers, + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s" % accept) + + self.assertFalse(channel.transport.disconnected) + self.assertFalse(channel.request.finished) + + + def test_hybiWrongVersion(self): + """ + A handshake that requests an unsupported version of the protocol + results in HTTP 426. + """ + key = "AA4=" + headers = [ + ("Upgrade", "websocket"), ("Connection", "Upgrade"), + ("Host", "localhost"), ("Origin", "http://localhost/"), + ("Sec-WebSocket-Version", "9"), ("Sec-WebSocket-Key", key)] + channel = self.renderRequest(headers=headers) + + result = channel.transport.written.getvalue() + + self.assertIn("HTTP/1.1 426", result) + # Twisted canonicalizes header names (see + # http_headers.Headers._canonicalNameCaps), so it's not + # Sec-WebSocket-Version, but Sec-Websocket-Version, but clients + # understand it anyway + self.assertIn("Sec-Websocket-Version: 8", result) + self.assertTrue(channel.request.finished) + + + def test_hybiNoKey(self): + """ + A handshake without a websocket key results in HTTP 400. + """ + headers = [ + ("Upgrade", "websocket"), ("Connection", "Upgrade"), + ("Host", "localhost"), ("Origin", "http://localhost/"), + ("Sec-WebSocket-Version", "8")] + channel = self.renderRequest(headers=headers) + + result = channel.transport.written.getvalue() + + self.assertIn("HTTP/1.1 400", result) + self.assertTrue(channel.request.finished) + + + def test_hybiNotFound(self): + """ + A request for an unknown endpoint results in HTTP 404. + """ + key = "AA4=" + headers = [ + ("Upgrade", "websocket"), ("Connection", "Upgrade"), + ("Host", "localhost"), ("Origin", "http://localhost/"), + ("Sec-WebSocket-Version", "8"), ("Sec-WebSocket-Key", key)] + channel = self.renderRequest(headers=headers, url="/foo") + + result = channel.transport.written.getvalue() + + self.assertIn("HTTP/1.1 404", result) + self.assertTrue(channel.request.finished) + class WebSocketFrameDecoderTestCase(TestCase): """ @@ -562,6 +667,172 @@ def test_emptyFrame(self): self.assertFalse(self.channel.transport.disconnected) +class WebSocketHybiFrameDecoderTestCase(TestCase): + """ + Test for C{WebSocketHybiFrameDecoder}. + """ + + def setUp(self): + self.channel = DummyChannel() + request = Request(self.channel, False) + transport = WebSocketHybiTransport(request) + handler = TestHandler(transport) + transport._attachHandler(handler) + self.decoder = WebSocketHybiFrameDecoder(request, handler) + self.decoder.MAX_LENGTH = 100 + self.decoder.MAX_BINARY_LENGTH = 1000 + # taken straight from the IETF draft, masking added where appropriate + self.hello = "\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58" + self.frag_hello = ("\x01\x83\x12\x21\x65\x23\x5a\x44\x09", + "\x80\x82\x63\x34\xf1\x00\x0f\x5b") + self.binary_orig = "\x3f" * 256 + self.binary = ("\x82\xfe\x01\x00\x12\x6d\xa6\x23" + + "\x2d\x52\x99\x1c" * 64) + self.ping = "\x89\x85\x56\x23\x88\x23\x1e\x46\xe4\x4f\x39" + self.pong = "\x8a\x85\xde\x41\x0f\x34\x96\x24\x63\x58\xb1" + self.pong_unmasked = "\x8a\x05\x48\x65\x6c\x6c\x6f" + # code 1000, message "Normal Closure" + self.close = ("\x88\x90\x34\x23\x87\xde\x37\xcb\xc9\xb1\x46" + "\x4e\xe6\xb2\x14\x60\xeb\xb1\x47\x56\xf5\xbb") + self.empty_unmasked_close = "\x88\x00" + self.empty_text = "\x81\x80\x00\x01\x02\x03" + self.cont_empty_text = "\x00\x80\x00\x01\x02\x03" + + + def assertOneDecodingError(self): + """ + Assert that exactly one L{DecodingError} has been logged and return + that error. + """ + errors = self.flushLoggedErrors(DecodingError) + self.assertEquals(len(errors), 1) + return errors[0] + + + def test_oneTextFrame(self): + """ + We can send one frame handled with one C{dataReceived} call. + """ + self.decoder.dataReceived(self.hello) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + + + def test_chunkedTextFrame(self): + """ + We can send one text frame handled with multiple C{dataReceived} calls. + """ + # taken straight from the IETF draft + for part in (self.hello[:1], self.hello[1:3], + self.hello[3:7], self.hello[7:]): + self.decoder.dataReceived(part) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + + + def test_fragmentedTextFrame(self): + """ + We can send a fragmented frame handled with one C{dataReceived} call. + """ + self.decoder.dataReceived("".join(self.frag_hello)) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + + + def test_chunkedfragmentedTextFrame(self): + """ + We can send a fragmented text frame handled with multiple + C{dataReceived} calls. + """ + # taken straight from the IETF draft + for part in (self.frag_hello[0][:3], self.frag_hello[0][3:]): + self.decoder.dataReceived(part) + for part in (self.frag_hello[1][:1], self.frag_hello[1][1:]): + self.decoder.dataReceived(part) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + + + def test_twoFrames(self): + """ + We can send two frames together and they will be correctly parsed. + """ + self.decoder.dataReceived("".join(self.frag_hello) + self.hello) + self.assertEquals(self.decoder.handler.frames, ["Hello"] * 2) + + + def test_controlInterleaved(self): + """ + A control message (in this case a pong) can appear between the + fragmented frames. + """ + data = self.frag_hello[0] + self.pong + self.frag_hello[1] + for part in data[:2], data[2:7], data[7:8], data[8:14], data[14:]: + self.decoder.dataReceived(part) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + self.assertEquals(self.decoder.handler.pongs, ["Hello"]) + + + def test_binaryFrame(self): + """ + We can send a binary frame that uses a longer length field. + """ + data = self.binary + for part in data[:3], data[3:4], data[4:]: + self.decoder.dataReceived(part) + self.assertEquals(self.decoder.handler.binaryFrames, + [self.binary_orig]) + + + def test_pingInterleaved(self): + """ + We can get a ping frame in the middle of a fragmented frame and we'll + correctly send a pong resonse. + """ + data = self.frag_hello[0] + self.ping + self.frag_hello[1] + for part in data[:12], data[12:16], data[16:]: + self.decoder.dataReceived(part) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + + result = self.channel.transport.written.getvalue() + headers, response = result.split('\r\n\r\n') + + self.assertEquals(response, self.pong_unmasked) + + + def test_close(self): + """ + A close frame causes the remaining data to be discarded and the + connection to be closed. + """ + self.decoder.dataReceived(self.hello + self.close + "crap" * 20) + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + + result = self.channel.transport.written.getvalue() + headers, response = result.split('\r\n\r\n') + + self.assertEquals(response, self.empty_unmasked_close) + self.assertTrue(self.channel.transport.disconnected) + + + def test_emptyFrame(self): + """ + An empty text frame is correctly parsed. + """ + self.decoder.dataReceived(self.empty_text) + self.assertEquals(self.decoder.handler.frames, [""]) + + + def test_emptyFrameInterleaved(self): + """ + An empty fragmented frame and a interleaved pong message are received + and parsed. + """ + data = (self.frag_hello[0] + self.cont_empty_text + + self.pong + self.frag_hello[1]) + for part in data[:1], data[1:8], data[8:17], data[17:]: + self.decoder.dataReceived(part) + + self.assertEquals(self.decoder.handler.frames, ["Hello"]) + self.assertEquals(self.decoder.handler.pongs, ["Hello"]) + + class WebSocketHandlerTestCase(TestCase): """ Tests for L{WebSocketHandler}. @@ -605,3 +876,66 @@ def test_connectionLost(self): """ self.request.connectionLost(Failure(CONNECTION_DONE)) self.handler.lostReason.trap(ConnectionDone) + + +class WebSocketHybiHandlerTestCase(TestCase): + """ + Tests for L{WebSocketHandler} using the hybi-10 protocol. + """ + + def setUp(self): + self.channel = DummyChannel() + self.request = request = Request(self.channel, False) + # Simulate request handling + request.startedWriting = True + transport = WebSocketHybiTransport(request) + self.handler = TestHandler(transport) + transport._attachHandler(self.handler) + + + def test_write(self): + """ + L{WebSocketHybiTransport.write} wraps the data in a text frame and + writes it to the request. + """ + self.handler.transport.write("Hello") + self.handler.transport.write("World") + self.assertEquals( + self.channel.transport.written.getvalue(), + "\x81\x05\x48\x65\x6c\x6c\x6f" + "\x81\x05\x57\x6f\x72\x6c\x64") + self.assertFalse(self.channel.transport.disconnected) + + + def test_sendFrame(self): + """ + L{WebSocketHybiTransport.sendFrame} creates an unmasked hybi-10 frame + and writes it to the request + """ + self.handler.transport.sendFrame(OPCODE_PING, "ping") + self.assertEquals( + self.channel.transport.written.getvalue(), + "\x89\x04\x70\x69\x6e\x67") + self.assertFalse(self.channel.transport.disconnected) + + + def test_sendLongFrame(self): + """ + Sending a frame with a payload longer than 125 bytes results in a + longer length field written to the request. + """ + self.handler.transport.sendFrame(OPCODE_TEXT, "crap" * 20000) + self.assertEquals( + self.channel.transport.written.getvalue(), + "\x81\x7f\x00\x00\x00\x00\x00\x01\x38\x80" + "crap" * 20000) + self.assertFalse(self.channel.transport.disconnected) + + + def test_sendFragmentedFrame(self): + """ + Sending a frame with the fragmented flag makes the correct flag unset. + """ + self.handler.transport.sendFrame(OPCODE_TEXT, "Hello", fragmented=True) + self.assertEquals( + self.channel.transport.written.getvalue(), + "\x01\x05\x48\x65\x6c\x6c\x6f") + self.assertFalse(self.channel.transport.disconnected)