Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
Add tests for the hybi-10 decoder and request handler.
Browse files Browse the repository at this point in the history
  • Loading branch information
wulczer committed Jul 31, 2011
1 parent e223509 commit 29ec4bd
Showing 1 changed file with 336 additions and 2 deletions.
338 changes: 336 additions & 2 deletions test_websocket.py
Expand Up @@ -4,14 +4,17 @@
"""
Tests for L{twisted.web.websocket}.
"""
import base64
from hashlib import sha1

from twisted.internet.main import CONNECTION_DONE
from twisted.internet.error import ConnectionDone
from twisted.python.failure import Failure

from websocket import WebSocketHandler, WebSocketFrameDecoder
from websocket import WebSocketSite, WebSocketTransport
from websocket import DecodingError
from websocket import WebSocketHybiFrameDecoder
from websocket import WebSocketSite, WebSocketTransport, WebSocketHybiTransport
from websocket import DecodingError, OPCODE_PING, OPCODE_TEXT

from twisted.web.resource import Resource
from twisted.web.server import Request, Site
Expand Down Expand Up @@ -45,13 +48,28 @@ class TestHandler(WebSocketHandler):
def __init__(self, request):
WebSocketHandler.__init__(self, request)
self.frames = []
self.binaryFrames = []
self.pongs = []
self.closes = []
self.lostReason = None


def frameReceived(self, frame):
self.frames.append(frame)


def binaryFrameReceived(self, frame):
self.binaryFrames.append(frame)


def pongReceived(self, data):
self.pongs.append(data)


def closeReceived(self, code, msg):
self.closes.append((code, msg))


def connectionLost(self, reason):
self.lostReason = reason

Expand Down Expand Up @@ -82,6 +100,8 @@ def renderRequest(self, headers=None, url="/test", ssl=False,
channel.transport = channel.SSL()
channel.site = self.site
request = self.site.requestFactory(channel, queued)
# store the reference to the request, so the tests can access it
channel.request = request
for k, v in headers:
request.requestHeaders.addRawHeader(k, v)
request.gotLength(0)
Expand Down Expand Up @@ -359,6 +379,91 @@ def test_addHandlerWithoutSlash(self):
ValueError, self.site.addHandler, "test", TestHandler)


def test_render_handShakeHybi(self):
"""
Test a hybi-10 handshake.
"""
# the key is a base64 encoded 16-bit integer, here chosen to be 14
key = "AA4="
headers = [
("Upgrade", "websocket"), ("Connection", "Upgrade"),
("Host", "localhost"), ("Origin", "http://localhost/"),
("Sec-WebSocket-Version", "8"), ("Sec-WebSocket-Key", key)]
channel = self.renderRequest(headers=headers)

self.assertTrue(channel.raw)

result = channel.transport.written.getvalue()
headers, response = result.split('\r\n\r\n')

guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
accept = base64.b64encode(sha1(key + guid).digest())
self.assertEquals(
headers,
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s" % accept)

self.assertFalse(channel.transport.disconnected)
self.assertFalse(channel.request.finished)


def test_hybiWrongVersion(self):
"""
A handshake that requests an unsupported version of the protocol
results in HTTP 426.
"""
key = "AA4="
headers = [
("Upgrade", "websocket"), ("Connection", "Upgrade"),
("Host", "localhost"), ("Origin", "http://localhost/"),
("Sec-WebSocket-Version", "9"), ("Sec-WebSocket-Key", key)]
channel = self.renderRequest(headers=headers)

result = channel.transport.written.getvalue()

self.assertIn("HTTP/1.1 426", result)
# Twisted canonicalizes header names (see
# http_headers.Headers._canonicalNameCaps), so it's not
# Sec-WebSocket-Version, but Sec-Websocket-Version, but clients
# understand it anyway
self.assertIn("Sec-Websocket-Version: 8", result)
self.assertTrue(channel.request.finished)


def test_hybiNoKey(self):
"""
A handshake without a websocket key results in HTTP 400.
"""
headers = [
("Upgrade", "websocket"), ("Connection", "Upgrade"),
("Host", "localhost"), ("Origin", "http://localhost/"),
("Sec-WebSocket-Version", "8")]
channel = self.renderRequest(headers=headers)

result = channel.transport.written.getvalue()

self.assertIn("HTTP/1.1 400", result)
self.assertTrue(channel.request.finished)


def test_hybiNotFound(self):
"""
A request for an unknown endpoint results in HTTP 404.
"""
key = "AA4="
headers = [
("Upgrade", "websocket"), ("Connection", "Upgrade"),
("Host", "localhost"), ("Origin", "http://localhost/"),
("Sec-WebSocket-Version", "8"), ("Sec-WebSocket-Key", key)]
channel = self.renderRequest(headers=headers, url="/foo")

result = channel.transport.written.getvalue()

self.assertIn("HTTP/1.1 404", result)
self.assertTrue(channel.request.finished)


class WebSocketFrameDecoderTestCase(TestCase):
"""
Expand Down Expand Up @@ -562,6 +667,172 @@ def test_emptyFrame(self):
self.assertFalse(self.channel.transport.disconnected)


class WebSocketHybiFrameDecoderTestCase(TestCase):
"""
Test for C{WebSocketHybiFrameDecoder}.
"""

def setUp(self):
self.channel = DummyChannel()
request = Request(self.channel, False)
transport = WebSocketHybiTransport(request)
handler = TestHandler(transport)
transport._attachHandler(handler)
self.decoder = WebSocketHybiFrameDecoder(request, handler)
self.decoder.MAX_LENGTH = 100
self.decoder.MAX_BINARY_LENGTH = 1000
# taken straight from the IETF draft, masking added where appropriate
self.hello = "\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58"
self.frag_hello = ("\x01\x83\x12\x21\x65\x23\x5a\x44\x09",
"\x80\x82\x63\x34\xf1\x00\x0f\x5b")
self.binary_orig = "\x3f" * 256
self.binary = ("\x82\xfe\x01\x00\x12\x6d\xa6\x23" +
"\x2d\x52\x99\x1c" * 64)
self.ping = "\x89\x85\x56\x23\x88\x23\x1e\x46\xe4\x4f\x39"
self.pong = "\x8a\x85\xde\x41\x0f\x34\x96\x24\x63\x58\xb1"
self.pong_unmasked = "\x8a\x05\x48\x65\x6c\x6c\x6f"
# code 1000, message "Normal Closure"
self.close = ("\x88\x90\x34\x23\x87\xde\x37\xcb\xc9\xb1\x46"
"\x4e\xe6\xb2\x14\x60\xeb\xb1\x47\x56\xf5\xbb")
self.empty_unmasked_close = "\x88\x00"
self.empty_text = "\x81\x80\x00\x01\x02\x03"
self.cont_empty_text = "\x00\x80\x00\x01\x02\x03"


def assertOneDecodingError(self):
"""
Assert that exactly one L{DecodingError} has been logged and return
that error.
"""
errors = self.flushLoggedErrors(DecodingError)
self.assertEquals(len(errors), 1)
return errors[0]


def test_oneTextFrame(self):
"""
We can send one frame handled with one C{dataReceived} call.
"""
self.decoder.dataReceived(self.hello)
self.assertEquals(self.decoder.handler.frames, ["Hello"])


def test_chunkedTextFrame(self):
"""
We can send one text frame handled with multiple C{dataReceived} calls.
"""
# taken straight from the IETF draft
for part in (self.hello[:1], self.hello[1:3],
self.hello[3:7], self.hello[7:]):
self.decoder.dataReceived(part)
self.assertEquals(self.decoder.handler.frames, ["Hello"])


def test_fragmentedTextFrame(self):
"""
We can send a fragmented frame handled with one C{dataReceived} call.
"""
self.decoder.dataReceived("".join(self.frag_hello))
self.assertEquals(self.decoder.handler.frames, ["Hello"])


def test_chunkedfragmentedTextFrame(self):
"""
We can send a fragmented text frame handled with multiple
C{dataReceived} calls.
"""
# taken straight from the IETF draft
for part in (self.frag_hello[0][:3], self.frag_hello[0][3:]):
self.decoder.dataReceived(part)
for part in (self.frag_hello[1][:1], self.frag_hello[1][1:]):
self.decoder.dataReceived(part)
self.assertEquals(self.decoder.handler.frames, ["Hello"])


def test_twoFrames(self):
"""
We can send two frames together and they will be correctly parsed.
"""
self.decoder.dataReceived("".join(self.frag_hello) + self.hello)
self.assertEquals(self.decoder.handler.frames, ["Hello"] * 2)


def test_controlInterleaved(self):
"""
A control message (in this case a pong) can appear between the
fragmented frames.
"""
data = self.frag_hello[0] + self.pong + self.frag_hello[1]
for part in data[:2], data[2:7], data[7:8], data[8:14], data[14:]:
self.decoder.dataReceived(part)
self.assertEquals(self.decoder.handler.frames, ["Hello"])
self.assertEquals(self.decoder.handler.pongs, ["Hello"])


def test_binaryFrame(self):
"""
We can send a binary frame that uses a longer length field.
"""
data = self.binary
for part in data[:3], data[3:4], data[4:]:
self.decoder.dataReceived(part)
self.assertEquals(self.decoder.handler.binaryFrames,
[self.binary_orig])


def test_pingInterleaved(self):
"""
We can get a ping frame in the middle of a fragmented frame and we'll
correctly send a pong resonse.
"""
data = self.frag_hello[0] + self.ping + self.frag_hello[1]
for part in data[:12], data[12:16], data[16:]:
self.decoder.dataReceived(part)
self.assertEquals(self.decoder.handler.frames, ["Hello"])

result = self.channel.transport.written.getvalue()
headers, response = result.split('\r\n\r\n')

self.assertEquals(response, self.pong_unmasked)


def test_close(self):
"""
A close frame causes the remaining data to be discarded and the
connection to be closed.
"""
self.decoder.dataReceived(self.hello + self.close + "crap" * 20)
self.assertEquals(self.decoder.handler.frames, ["Hello"])

result = self.channel.transport.written.getvalue()
headers, response = result.split('\r\n\r\n')

self.assertEquals(response, self.empty_unmasked_close)
self.assertTrue(self.channel.transport.disconnected)


def test_emptyFrame(self):
"""
An empty text frame is correctly parsed.
"""
self.decoder.dataReceived(self.empty_text)
self.assertEquals(self.decoder.handler.frames, [""])


def test_emptyFrameInterleaved(self):
"""
An empty fragmented frame and a interleaved pong message are received
and parsed.
"""
data = (self.frag_hello[0] + self.cont_empty_text +
self.pong + self.frag_hello[1])
for part in data[:1], data[1:8], data[8:17], data[17:]:
self.decoder.dataReceived(part)

self.assertEquals(self.decoder.handler.frames, ["Hello"])
self.assertEquals(self.decoder.handler.pongs, ["Hello"])


class WebSocketHandlerTestCase(TestCase):
"""
Tests for L{WebSocketHandler}.
Expand Down Expand Up @@ -605,3 +876,66 @@ def test_connectionLost(self):
"""
self.request.connectionLost(Failure(CONNECTION_DONE))
self.handler.lostReason.trap(ConnectionDone)


class WebSocketHybiHandlerTestCase(TestCase):
"""
Tests for L{WebSocketHandler} using the hybi-10 protocol.
"""

def setUp(self):
self.channel = DummyChannel()
self.request = request = Request(self.channel, False)
# Simulate request handling
request.startedWriting = True
transport = WebSocketHybiTransport(request)
self.handler = TestHandler(transport)
transport._attachHandler(self.handler)


def test_write(self):
"""
L{WebSocketHybiTransport.write} wraps the data in a text frame and
writes it to the request.
"""
self.handler.transport.write("Hello")
self.handler.transport.write("World")
self.assertEquals(
self.channel.transport.written.getvalue(),
"\x81\x05\x48\x65\x6c\x6c\x6f" + "\x81\x05\x57\x6f\x72\x6c\x64")
self.assertFalse(self.channel.transport.disconnected)


def test_sendFrame(self):
"""
L{WebSocketHybiTransport.sendFrame} creates an unmasked hybi-10 frame
and writes it to the request
"""
self.handler.transport.sendFrame(OPCODE_PING, "ping")
self.assertEquals(
self.channel.transport.written.getvalue(),
"\x89\x04\x70\x69\x6e\x67")
self.assertFalse(self.channel.transport.disconnected)


def test_sendLongFrame(self):
"""
Sending a frame with a payload longer than 125 bytes results in a
longer length field written to the request.
"""
self.handler.transport.sendFrame(OPCODE_TEXT, "crap" * 20000)
self.assertEquals(
self.channel.transport.written.getvalue(),
"\x81\x7f\x00\x00\x00\x00\x00\x01\x38\x80" + "crap" * 20000)
self.assertFalse(self.channel.transport.disconnected)


def test_sendFragmentedFrame(self):
"""
Sending a frame with the fragmented flag makes the correct flag unset.
"""
self.handler.transport.sendFrame(OPCODE_TEXT, "Hello", fragmented=True)
self.assertEquals(
self.channel.transport.written.getvalue(),
"\x01\x05\x48\x65\x6c\x6c\x6f")
self.assertFalse(self.channel.transport.disconnected)

0 comments on commit 29ec4bd

Please sign in to comment.