diff --git a/doc/index.md b/doc/index.md index 9cab6ed3..b99d1b63 100644 --- a/doc/index.md +++ b/doc/index.md @@ -908,6 +908,95 @@ Returns the time in HTTP preferred date format (See [RFC 7231 section 7.1.1.1](h Current version of lua-http as a string. +## http.websocket + + +### `new_from_uri(uri, protocols)` {#http.websocket.new_from_uri} + +Creates a new `http.websocket` object of type `"client"` from the given URI. + + - `protocols` (optional) should be a lua table containing a sequence of protocols to send to the server + + +### `new_from_stream(headers, stream)` {#http.websocket.new_from_stream} + +Attempts to create a new `http.websocket` object of type `"server"` from the given request headers and stream. + + - [`headers`](#http.headers) should be headers of a suspected websocket upgrade request from a HTTP 1 client. + - [`stream`](#http.h1_stream) should be a live HTTP 1 stream of the `"server"` type. + +This function does **not** have side effects, and is hence okay to use tentatively. + + +### `websocket.close_timeout` {#http.websocket.close_timeout} + +Amount of time (in seconds) to wait between sending a close frame and actually closing the connection. +Defaults to `3` seconds. + + +### `websocket:accept(protocols, timeout)` {#http.websocket:accept} + +Completes negotiation with a websocket client. + + - `protocols` (optional) should be a lua table containing a sequence of protocols to to allow from the client + +Usually called after a successful [`new_from_stream`](#http.websocket.new_from_stream) + + +### `websocket:connect(timeout)` {#http.websocket:connect} + +Connect to a websocket server. + +Usually called after a successful [`new_from_uri`](#http.websocket.new_from_uri) + + +### `websocket:receive(timeout)` {#http.websocket:receive} + +Reads and returns the next data frame plus its opcode. +Any ping frames received while reading will be responded to. + +The opcode `0x1` will be returned as `"text"` and `0x2` will be returned as `"binary"`. + + +### `websocket:each()` {#http.websocket:each} + +Iterator over [`websocket:receive()`](#http.websocket:receive). + + +### `websocket:send_frame(frame, timeout)` {#http.websocket:send_frame} + +Low level function to send a raw frame. + + +### `websocket:send(data, opcode, timeout)` {#http.websocket:send} + +Send the given `data` as a data frame. + + - `data` should be a string + - `opcode` can be a numeric opcode, `"text"` or `"binary"`. If `nil`, defaults to a text frame + + +### `websocket:close(code, reason, timeout)` {#http.websocket:close} + +Closes the websocket connection. + + - `code` defaults to `1000` + - `reason` is an optional string + + +### Example + +```lua +local websocket = require "http.websocket" +local ws = websocket.new_from_uri("wss://echo.websocket.org") +assert(ws:connect()) +assert(ws:send("koo-eee!")) +local data = assert(ws:receive()) +assert(data == "koo-eee!") +assert(ws:close()) +``` + + ## http.zlib An abstraction layer over the various lua zlib libraries. diff --git a/examples/websocket_client.lua b/examples/websocket_client.lua new file mode 100644 index 00000000..cc0cee93 --- /dev/null +++ b/examples/websocket_client.lua @@ -0,0 +1,20 @@ +--[[ +Example of websocket client usage + + - Connects to the coinbase feed. + - Sends a subscribe message + - Prints off 5 messages + - Close the socket and clean up. +]] + +local json = require "cjson" +local websocket = require "http.websocket" + +local ws = websocket.new_from_uri("ws://ws-feed.exchange.coinbase.com") +assert(ws:connect()) +assert(ws:send(json.encode({type = "subscribe", product_id = "BTC-USD"}))) +for _=1, 5 do + local data = assert(ws:receive()) + print(data) +end +assert(ws:close()) diff --git a/http-scm-0.rockspec b/http-scm-0.rockspec index 7b9db7df..f00e3a64 100644 --- a/http-scm-0.rockspec +++ b/http-scm-0.rockspec @@ -41,6 +41,7 @@ build = { ["http.tls"] = "http/tls.lua"; ["http.util"] = "http/util.lua"; ["http.version"] = "http/version.lua"; + ["http.websocket"] = "http/websocket.lua"; ["http.zlib"] = "http/zlib.lua"; ["http.compat.prosody"] = "http/compat/prosody.lua"; }; diff --git a/http/request.lua b/http/request.lua index e200a38b..571ff5fa 100644 --- a/http/request.lua +++ b/http/request.lua @@ -346,6 +346,7 @@ function request_methods:go(timeout) end return { + new_from_uri_t = new_from_uri_t; new_from_uri = new_from_uri; new_connect = new_connect; new_from_stream = new_from_stream; diff --git a/http/websocket.lua b/http/websocket.lua new file mode 100644 index 00000000..25dccbf3 --- /dev/null +++ b/http/websocket.lua @@ -0,0 +1,653 @@ +--[[ +WebSocket module + +Specified in RFC-6455 + +Design criteria: + - Client API must work without an event loop + - Borrow from the Browser Javascript WebSocket API when sensible + - server-side API should mirror client-side API + +This code is partially based on MIT/X11 code Copyright (C) 2012 Florian Zeitz +]] + +local basexx = require "basexx" +local spack = string.pack or require "compat53.string".pack +local sunpack = string.unpack or require "compat53.string".unpack +local unpack = table.unpack or unpack -- luacheck: ignore 113 +local cqueues = require "cqueues" +local monotime = cqueues.monotime +local ce = require "cqueues.errno" +local uri_patts = require "lpeg_patterns.uri" +local rand = require "openssl.rand" +local digest = require "openssl.digest" +local bit = require "http.bit" +local new_headers = require "http.headers".new +local http_request = require "http.request" + +local websocket_methods = { + -- Max seconds to wait after sending close frame until closing connection + close_timeout = 3; +} + +local websocket_mt = { + __name = "http.websocket"; + __index = websocket_methods; +} + +local magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +-- a nonce consisting of a randomly selected 16-byte value that has been base64-encoded +local function new_key() + return basexx.to_base64(rand.bytes(16)) +end + +local function base64_sha1(str) + return basexx.to_base64(digest.new("sha1"):final(str)) +end + +-- trim12 from http://lua-users.org/wiki/StringTrim +local function trim(s) + local from = s:match"^%s*()" + return from > #s and "" or s:match(".*%S", from) +end + +--[[ this value MUST be non-empty strings with characters in the range U+0021 +to U+007E not including separator characters as defined in [RFC2616] ]] +local function validate_protocol(p) + return p:match("^[\33\35-\39\42\43\45\46\48-\57\65-\90\94-\122\124\126\127]+$") +end + +-- XORs the string `str` with a 32bit key +local function apply_mask(str, key) + assert(#key == 4) + local data = {} + for i = 1, #str do + local key_index = (i-1)%4 + 1 + data[i] = string.char(bit.bxor(key[key_index], str:byte(i))) + end + return table.concat(data) +end + +local function build_frame(desc) + local data = desc.data or "" + + assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode") + if desc.opcode >= 0x8 then + -- RFC 6455 5.5 + assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.") + end + + local b1 = desc.opcode + if desc.FIN then + b1 = bit.bor(b1, 0x80) + end + if desc.RSV1 then + b1 = bit.bor(b1, 0x40) + end + if desc.RSV2 then + b1 = bit.bor(b1, 0x20) + end + if desc.RSV3 then + b1 = bit.bor(b1, 0x10) + end + + local b2 = #data + local length_extra + if b2 <= 125 then -- 7-bit length + length_extra = "" + elseif b2 <= 0xFFFF then -- 2-byte length + b2 = 126 + length_extra = spack(">I2", #data) + else -- 8-byte length + b2 = 127 + length_extra = spack(">I8", #data) + end + + local key = "" + if desc.MASK then + local key_a = desc.key + if key_a then + key = string.char(unpack(key_a, 1, 4)) + else + key = rand.bytes(4) + key_a = {key:byte(1,4)} + end + b2 = bit.bor(b2, 0x80) + data = apply_mask(data, key_a) + end + + return string.char(b1, b2) .. length_extra .. key .. data +end + +local function build_close(code, message, mask) + local data = spack(">I2", code) + if message then + assert(#message<=123, "Close reason must be <=123 bytes") + data = data .. message + end + return { + opcode = 0x8; + FIN = true; + MASK = mask; + data = data; + } +end + +local function read_frame(sock, deadline) + local frame do + local first_2, err, errno = sock:xread(2, deadline and (deadline-monotime())) + if not first_2 then + return nil, err, errno + end + local byte1, byte2 = first_2:byte(1, 2) + frame = { + FIN = bit.band(byte1, 0x80) ~= 0; + RSV1 = bit.band(byte1, 0x40) ~= 0; + RSV2 = bit.band(byte1, 0x20) ~= 0; + RSV3 = bit.band(byte1, 0x10) ~= 0; + opcode = bit.band(byte1, 0x0F); + + MASK = bit.band(byte2, 0x80) ~= 0; + length = bit.band(byte2, 0x7F); + + data = nil; + } + end + + if frame.length == 126 then + local length, err, errno = sock:xread(2, deadline and (deadline-monotime())) + if not length then + return nil, err, errno + end + frame.length = sunpack(">I2", length) + elseif frame.length == 127 then + local length, err, errno = sock:xread(8, deadline and (deadline-monotime())) + if not length then + return nil, err, errno + end + frame.length = sunpack(">I8", length) + end + + if frame.MASK then + local key, err, errno = sock:xread(4, deadline and (deadline-monotime())) + if not key then + return nil, err, errno + end + frame.key = { key:byte(1, 4) } + end + + do + local data, err, errno = sock:xread(frame.length, deadline and (deadline-monotime())) + if data == nil then + return nil, err, errno + end + + if frame.MASK then + frame.data = apply_mask(data, frame.key) + else + frame.data = data + end + end + + return frame +end + +local function parse_close(data) + local code, message + if #data >= 2 then + code = sunpack(">I2", data) + if #data > 2 then + message = data:sub(3) + end + end + return code, message +end + +function websocket_methods:send_frame(frame, timeout) + if self.readyState < 1 or self.readyState > 2 then + return nil, ce.strerror(ce.EPIPE), ce.EPIPE + end + local ok, err, errno = self.socket:xwrite(build_frame(frame), "n", timeout) + if not ok then + return nil, err, errno + end + if frame.opcode == 0x8 then + self.readyState = 2 + end + return true +end + +function websocket_methods:send(data, opcode, timeout) + if self.readyState >= 2 then + return nil, "WebSocket closed, unable to send data", ce.EPIPE + end + assert(type(data) == "string") + if opcode == "text" or opcode == nil then + opcode = 0x1 + elseif opcode == "binary" then + opcode = 0x2; + end + return self:send_frame({ + FIN = true; + --[[ RFC 6455 + 5.1: A server MUST NOT mask any frames that it sends to the client + 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked]] + MASK = self.type == "client"; + opcode = opcode; + data = data; + }, timeout) +end + +local function close_helper(self, code, reason, deadline) + if self.readyState == 3 then + return nil, ce.strerror(ce.EPIPE), ce.EPIPE + end + + if self.readyState < 2 then + local close_frame = build_close(code, reason, self.type == "client") + -- ignore failure + self:send_frame(close_frame, deadline and deadline-monotime()) + self.readyState = 2 + end + + -- Do not close socket straight away, wait for acknowledgement from server + local read_deadline = monotime() + self.close_timeout + if deadline then + read_deadline = math.min(read_deadline, deadline) + end + while not self.got_close_code do + if not self:receive(read_deadline-monotime()) then + break + end + end + + if self.readyState < 3 then + self.readyState = 3 + self.socket:shutdown() + cqueues.poll() + cqueues.poll() + self.socket:close() + end + + return nil, reason, code +end + +function websocket_methods:close(code, reason, timeout) + local deadline = timeout and (monotime()+timeout) + code = code or 1000 + close_helper(self, code, reason, deadline) + return true +end + +function websocket_methods:receive(timeout) + if self.readyState < 1 or self.readyState > 2 then + return nil, ce.strerror(ce.EPIPE), ce.EPIPE + end + local deadline = timeout and (monotime()+timeout) + while true do + local frame, err, errno = read_frame(self.socket, deadline and (deadline-monotime())) + if frame == nil then + return nil, err, errno + end + + -- Error cases + if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero + return close_helper(self, 1002, "Reserved bits not zero", deadline) + end + + if frame.opcode < 0x8 then + if frame.opcode == 0x0 then -- Continuation frames + if not self.databuffer then + return close_helper(self, 1002, "Unexpected continuation frame", deadline) + end + self.databuffer[#self.databuffer+1] = frame.data + elseif frame.opcode == 0x1 or frame.opcode == 0x2 then -- Text or Binary frame + if self.databuffer then + return close_helper(self, 1002, "Continuation frame expected", deadline) + end + self.databuffer = { frame.data } + self.databuffer_type = frame.opcode + else + return close_helper(self, 1002, "Reserved opcode", deadline) + end + if frame.FIN then + local databuffer_type = self.databuffer_type + self.databuffer_type = nil + if databuffer_type == 0x1 then + databuffer_type = "text" + elseif databuffer_type == 0x2 then + databuffer_type = "binary" + end + local databuffer = table.concat(self.databuffer) + self.databuffer = nil + return databuffer, databuffer_type + end + else -- Control frame + if frame.length > 125 then -- Control frame with too much payload + return close_helper(self, 1002, "Payload too large", deadline) + elseif not frame.FIN then -- Fragmented control frame + return close_helper(self, 1002, "Fragmented control frame", deadline) + end + if frame.opcode == 0x8 then -- Close request + if frame.length == 1 then + return close_helper(self, 1002, "Close frame with payload, but too short for status code", deadline) + end + local status_code, message = parse_close(frame.data) + if status_code == nil then + --[[ RFC 6455 7.4.1 + 1005 is a reserved value and MUST NOT be set as a status code in a + Close control frame by an endpoint. It is designated for use in + applications expecting a status code to indicate that no status + code was actually present.]] + status_code = 1005 + elseif status_code < 1000 then + return close_helper(self, 1002, "Closed with invalid status code", deadline) + elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then + return close_helper(self, 1002, "Closed with reserved status code", deadline) + end + self.got_close_code = status_code + self.got_close_message = message + --[[ RFC 6455 5.5.1 + When sending a Close frame in response, the endpoint typically + echos the status code it received.]] + return close_helper(self, status_code, message, deadline) + elseif frame.opcode == 0x9 then -- Ping frame + frame.opcode = 0xA + --[[ RFC 6455 + 5.1: A server MUST NOT mask any frames that it sends to the client + 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked]] + frame.MASK = self.type == "client"; + if not self:send_frame(frame, deadline and (deadline-monotime())) then + return close_helper(self, 1002, "Pong failed", deadline) + end + elseif frame.opcode == 0xA then -- luacheck: ignore 542 + -- Received unexpected pong frame + else + return close_helper(self, 1002, "Reserved opcode", deadline) + end + end + end +end + +function websocket_methods:each() + return function(self) -- luacheck: ignore 432 + return self:receive() + end, self +end + +local function new(type) + assert(type == "client" or type == "server") + local self = setmetatable({ + socket = nil; + type = type; + readyState = 0; + databuffer = nil; + databuffer_type = nil; + got_close_code = nil; + got_close_reason = nil; + key = nil; + protocol = nil; + protocols = nil; + -- only used by client: + request = nil; + -- only used by server: + stream = nil; + }, websocket_mt) + return self +end + +local function new_from_uri_t(uri_t, protocols) + local scheme = assert(uri_t.scheme, "URI missing scheme") + assert(scheme == "ws" or scheme == "wss", "scheme not websocket") + local self = new("client") + self.request = http_request.new_from_uri_t(uri_t) + self.request.version = 1.1 + self.request.headers:append("upgrade", "websocket") + self.request.headers:append("connection", "upgrade") + self.key = new_key() + self.request.headers:append("sec-websocket-key", self.key, true) + self.request.headers:append("sec-websocket-version", "13") + if protocols then + --[[ The request MAY include a header field with the name + Sec-WebSocket-Protocol. If present, this value indicates one + or more comma-separated subprotocol the client wishes to speak, + ordered by preference. The elements that comprise this value + MUST be non-empty strings with characters in the range U+0021 to + U+007E not including separator characters as defined in + [RFC2616] and MUST all be unique strings.]] + local n_protocols = #protocols + -- Copy the passed 'protocols' array so that caller is allowed to modify + local protocols_copy = {} + for i=1, n_protocols do + local v = protocols[i] + if protocols_copy[v] then + error("duplicate protocol") + end + assert(validate_protocol(v), "invalid protocol") + protocols_copy[v] = true + protocols_copy[i] = v + end + self.protocols = protocols_copy + self.request.headers:append("sec-websocket-protocol", table.concat(protocols_copy, ",", 1, n_protocols)) + end + return self +end + +local function new_from_uri(uri, ...) + local uri_t = assert(uri_patts.uri:match(uri), "invalid URI") + return new_from_uri_t(uri_t, ...) +end + +--[[ Takes a response to a websocket upgrade request, +and attempts to complete a websocket connection]] +local function handle_websocket_response(self, headers, stream) + assert(self.type == "client" and self.readyState == 0) + + if stream.connection.version < 1 or stream.connection.version >= 2 then + return nil, "websockets only supported with HTTP 1.x", ce.EINVAL + end + + --[[ If the status code received from the server is not 101, the + client handles the response per HTTP [RFC2616] procedures. In + particular, the client might perform authentication if it + receives a 401 status code; the server might redirect the client + using a 3xx status code (but clients are not required to follow + them), etc.]] + if headers:get(":status") ~= "101" then + return nil, "status code not 101", ce.EINVAL + end + + --[[ If the response lacks an Upgrade header field or the Upgrade + header field contains a value that is not an ASCII case- + insensitive match for the value "websocket", the client MUST + Fail the WebSocket Connection]] + local upgrade = headers:get("upgrade") + if not upgrade or upgrade:lower() ~= "websocket" then + return nil, "upgrade header not websocket", ce.EINVAL + end + + --[[ If the response lacks a Connection header field or the + Connection header field doesn't contain a token that is an + ASCII case-insensitive match for the value "Upgrade", the client + MUST Fail the WebSocket Connection]] + local has_connection_upgrade = false + local connection_header = headers:get_split_as_sequence("connection") + for i=1, connection_header.n do + if connection_header[i]:lower() == "upgrade" then + has_connection_upgrade = true + break + end + end + if not has_connection_upgrade then + return nil, "connection header doesn't contain upgrade", ce.EINVAL + end + + --[[ If the response lacks a Sec-WebSocket-Accept header field or + the Sec-WebSocket-Accept contains a value other than the + base64-encoded SHA-1 of the concatenation of the Sec-WebSocket- + Key (as a string, not base64-decoded) with the string "258EAFA5- + E914-47DA-95CA-C5AB0DC85B11" but ignoring any leading and + trailing whitespace, the client MUST Fail the WebSocket Connection]] + local sec_websocket_accept = headers:get("sec-websocket-accept") + if sec_websocket_accept == nil or + trim(sec_websocket_accept) ~= base64_sha1(self.key .. magic) + then + return nil, "sec-websocket-accept header incorrect", ce.EINVAL + end + + --[[ If the response includes a Sec-WebSocket-Extensions header field and + this header field indicates the use of an extension that was not present + in the client's handshake (the server has indicated an extension not + requested by the client), the client MUST Fail the WebSocket Connection]] + -- For now, we don't support any extensions + if headers:has("sec-websocket-extensions") then + return nil, "extensions not supported", ce.EINVAL + end + + --[[ If the response includes a Sec-WebSocket-Protocol header field and + this header field indicates the use of a subprotocol that was not present + in the client's handshake (the server has indicated a subprotocol not + requested by the client), the client MUST Fail the WebSocket Connection]] + local protocol = headers:get("sec-websocket-protocol") + if protocol then + local has_matching_protocol = self.protocols and self.protocols[protocol] + if not has_matching_protocol then + return nil, "unexpected protocol", ce.EINVAL + end + end + + -- Success! + assert(self.socket == nil, "websocket:connect called twice") + self.socket = assert(stream.connection:take_socket()) + self.request = nil + self.readyState = 1 + self.protocol = protocol + + return true +end + +function websocket_methods:connect(timeout) + assert(self.type == "client" and self.readyState == 0) + local headers, stream, errno = self.request:go(timeout) + if not headers then + return nil, stream, errno + end + return handle_websocket_response(self, headers, stream) +end + +-- Given an incoming HTTP1 request, attempts to upgrade it to a websocket connection +local function new_from_stream(headers, stream) + assert(stream.connection.type == "server") + + if stream.connection.version < 1 or stream.connection.version >= 2 then + return nil, "websockets only supported with HTTP 1.x", ce.EINVAL + end + + local upgrade = headers:get("upgrade") + if not upgrade or upgrade:lower() ~= "websocket" then + return nil, "upgrade header not websocket", ce.EINVAL + end + + local has_connection_upgrade = false + local connection_header = headers:get_split_as_sequence("connection") + for i=1, connection_header.n do + if connection_header[i]:lower() == "upgrade" then + has_connection_upgrade = true + break + end + end + if not has_connection_upgrade then + return nil, "connection header doesn't contain upgrade", ce.EINVAL + end + + local key = trim(headers:get("sec-websocket-key")) + if not key then + return nil, "missing sec-websocket-key", ce.EINVAL + end + + if headers:get("sec-websocket-version") ~= "13" then + return nil, "unsupported sec-websocket-version" + end + + local protocols_available + if headers:has("sec-websocket-protocol") then + local client_protocols = headers:get_split_as_sequence("sec-websocket-protocol") + --[[ The request MAY include a header field with the name + Sec-WebSocket-Protocol. If present, this value indicates one + or more comma-separated subprotocol the client wishes to speak, + ordered by preference. The elements that comprise this value + MUST be non-empty strings with characters in the range U+0021 to + U+007E not including separator characters as defined in + [RFC2616] and MUST all be unique strings.]] + protocols_available = {} + for i, protocol in ipairs(client_protocols) do + protocol = trim(protocol) + if protocols_available[protocol] then + return nil, "duplicate protocol", ce.EINVAL + end + if not validate_protocol(protocol) then + return nil, "invalid protocol", ce.EINVAL + end + protocols_available[protocol] = true + protocols_available[i] = protocol + end + end + + local self = new("server") + self.key = key + self.protocols = protocols_available + self.stream = stream + return self +end + +function websocket_methods:accept(protocols, timeout) + assert(self.type == "server" and self.readyState == 0) + + local response_headers = new_headers() + response_headers:upsert(":status", "101") + response_headers:upsert("upgrade", "websocket") + response_headers:upsert("connection", "upgrade") + response_headers:upsert("sec-websocket-accept", base64_sha1(self.key .. magic)) + + local chosen_protocol + if self.protocols then + if protocols then + for _, protocol in ipairs(protocols) do + if self.protocols[protocol] then + chosen_protocol = protocol + break + end + end + end + if not chosen_protocol then + return nil, "no matching protocol", ce.EPROTONOSUPPORT + end + response_headers:upsert("sec-websocket-protocol", chosen_protocol) + end + + do + local ok, err, errno = self.stream:write_headers(response_headers, false, timeout) + if not ok then + return ok, err, errno + end + end + + self.socket = assert(self.stream.connection:take_socket()) + self.stream = nil + self.readyState = 1 + self.protocol = chosen_protocol + return true +end + +return { + new_from_uri_t = new_from_uri_t; + new_from_uri = new_from_uri; + new_from_stream = new_from_stream; + + new = new; + build_frame = build_frame; + read_frame = read_frame; + build_close = build_close; + parse_close = parse_close; +} diff --git a/spec/websocket_spec.lua b/spec/websocket_spec.lua new file mode 100644 index 00000000..78f1cf21 --- /dev/null +++ b/spec/websocket_spec.lua @@ -0,0 +1,185 @@ +local TEST_TIMEOUT = 2 +describe("http.websocket module's internal functions work", function() + local websocket = require "http.websocket" + it("build_frame works for simple cases", function() + -- Examples from RFC 6455 Section 5.7 + + -- A single-frame unmasked text message + assert.same(string.char(0x81,0x05,0x48,0x65,0x6c,0x6c,0x6f), websocket.build_frame { + FIN = true; + MASK = false; + opcode = 0x1; + data = "Hello"; + }) + + -- A single-frame masked text message + assert.same(string.char(0x81,0x85,0x37,0xfa,0x21,0x3d,0x7f,0x9f,0x4d,0x51,0x58), websocket.build_frame { + FIN = true; + MASK = true; + key = {0x37,0xfa,0x21,0x3d}; + opcode = 0x1; + data = "Hello"; + }) + end) + it("build_frame validates opcode", function() + assert.has.errors(function() + websocket.build_frame { opcode = -1; } + end) + assert.has.errors(function() + websocket.build_frame { opcode = 16; } + end) + end) + it("build_frame validates data length", function() + assert.has.errors(function() + websocket.build_frame { + opcode = 0x8; + data = ("f"):rep(200); + } + end) + end) + it("build_close works for common case", function() + assert.same({ + opcode = 0x8; + FIN = true; + MASK = false; + data = "\3\232"; + }, websocket.build_close(1000, nil, false)) + + assert.same({ + opcode = 0x8; + FIN = true; + MASK = false; + data = "\3\232error"; + }, websocket.build_close(1000, "error", false)) + end) + it("build_close validates string length", function() + assert.has.errors(function() websocket.build_close(1000, ("f"):rep(200), false) end) + end) + it("parse_close works", function() + assert.same({nil, nil}, {websocket.parse_close ""}) + assert.same({1000, nil}, {websocket.parse_close "\3\232"}) + assert.same({1000, "error"}, {websocket.parse_close "\3\232error"}) + end) +end) +describe("http.websocket module two sided tests", function() + local server = require "http.server" + local util = require "http.util" + local websocket = require "http.websocket" + local cs = require "cqueues.socket" + local cqueues = require "cqueues" + local function assert_loop(cq, timeout) + local ok, err, _, thd = cq:loop(timeout) + if not ok then + if thd then + err = debug.traceback(thd, err) + end + error(err, 2) + end + end + local function new_pair() + local c, s = cs.pair() + local ws_client = websocket.new("client") + ws_client.socket = c + ws_client.readyState = 1 + local ws_server = websocket.new("server") + ws_server.socket = s + ws_server.readyState = 1 + return ws_client, ws_server + end + it("works with a socketpair", function() + local cq = cqueues.new() + local c, s = new_pair() + cq:wrap(function() + assert(c:send("hello")) + assert.same("world", c:receive()) + assert(c:close()) + end) + cq:wrap(function() + assert.same("hello", s:receive()) + assert(s:send("world")) + assert(s:close()) + end) + assert_loop(cq, TEST_TIMEOUT) + assert.truthy(cq:empty()) + end) + for _, flag in ipairs{"RSV1", "RSV2", "RSV3"} do + it("fails correctly on "..flag.." flag set", function() + local cq = cqueues.new() + local c, s = new_pair() + cq:wrap(function() + assert(c:send_frame({ + opcode = 1; + [flag] = true; + })) + assert(c:close()) + end) + cq:wrap(function() + local ok, _, errno = s:receive() + assert.same(nil, ok) + assert.same(1002, errno) + assert(s:close()) + end) + assert_loop(cq, TEST_TIMEOUT) + assert.truthy(cq:empty()) + end) + end + it("works when using uri string constructor", function() + local cq = cqueues.new() + local s = server.listen { + host = "localhost"; + port = 0; + } + assert(s:listen()) + local _, host, port = s:localname() + cq:wrap(function() + s:run(function (stream) + local headers = assert(stream:get_headers()) + s:pause() + local ws = websocket.new_from_stream(headers, stream) + assert(ws:accept()) + assert(ws:close()) + end) + s:close() + end) + cq:wrap(function() + local ws = websocket.new_from_uri("ws://"..util.to_authority(host, port, "ws")); + assert(ws:connect()) + assert(ws:close()) + end) + assert_loop(cq, TEST_TIMEOUT) + assert.truthy(cq:empty()) + end) + it("works when using uri table constructor and protocols", function() + local cq = cqueues.new() + local s = server.listen { + host = "localhost"; + port = 0; + } + assert(s:listen()) + local _, host, port = s:localname() + cq:wrap(function() + s:run(function (stream) + local headers = assert(stream:get_headers()) + s:pause() + local ws = websocket.new_from_stream(headers, stream) + assert(ws:accept({"my awesome-protocol", "foo"})) + -- Should prefer client protocol preference + assert.same("foo", ws.protocol) + assert(ws:close()) + end) + s:close() + end) + cq:wrap(function() + local ws = websocket.new_from_uri_t({ + scheme = "ws"; + host = host; + port = port; + }, {"foo", "my-awesome-protocol", "bar"}) + assert(ws:connect()) + assert.same("foo", ws.protocol) + assert(ws:close()) + end) + assert_loop(cq, TEST_TIMEOUT) + assert.truthy(cq:empty()) + end) +end)