Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 73 additions & 35 deletions http/websocket.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ local http_patts = require "lpeg_patterns.http"
local rand = require "openssl.rand"
local digest = require "openssl.digest"
local bit = require "http.bit"
local onerror = require "http.connection_common".onerror
local new_headers = require "http.headers".new
local http_request = require "http.request"

Expand Down Expand Up @@ -177,13 +178,18 @@ local function build_close(code, message, mask)
end

local function read_frame(sock, deadline)
local frame do
local first_2, err, errno = sock:xread(2, "b", deadline and (deadline-monotime()))
local frame, first_2 do
local err, errno
first_2, err, errno = sock:xread(2, "b", deadline and (deadline-monotime()))
if not first_2 then
return nil, err, errno
elseif #first_2 ~= 2 then
sock:seterror("r", ce.EILSEQ)
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
local ok, errno2 = sock:unget(first_2)
if not ok then
return nil, onerror(sock, "unget", errno2)
end
return nil, onerror(sock, "read_frame", ce.EILSEQ)
end
local byte1, byte2 = first_2:byte(1, 2)
frame = {
Expand All @@ -200,50 +206,80 @@ local function read_frame(sock, deadline)
}
end

if frame.length == 126 then
local length, err, errno = sock:xread(2, "b", deadline and (deadline-monotime()))
if not length or #length ~= 2 then
if err == nil then
local fill_length = frame.length
if fill_length == 126 then
fill_length = 2
elseif fill_length == 127 then
fill_length = 8
end
if frame.MASK then
fill_length = fill_length + 4
end
do
local ok, err, errno = sock:fill(fill_length, 0)
if not ok then
local unget_ok1, unget_errno1 = sock:unget(first_2)
if not unget_ok1 then
return nil, onerror(sock, "unget", unget_errno1)
end
if errno == ce.ETIMEDOUT then
local timeout = deadline and deadline-monotime()
if cqueues.poll(sock, timeout) ~= timeout then
-- retry
return read_frame(sock, deadline)
end
elseif err == nil then
sock:seterror("r", ce.EILSEQ)
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
return nil, onerror(sock, "read_frame", ce.EILSEQ)
end
return nil, err, errno
end
frame.length = sunpack(">I2", length)
end

-- if `fill` succeeded these shouldn't be able to fail
local extra_fill_unget
if frame.length == 126 then
extra_fill_unget = assert(sock:xread(2, "b", 0))
frame.length = sunpack(">I2", extra_fill_unget)
fill_length = fill_length - 2
elseif frame.length == 127 then
local length, err, errno = sock:xread(8, "b", deadline and (deadline-monotime()))
if not length or #length ~= 8 then
if err == nil then
extra_fill_unget = assert(sock:xread(8, "b", 0))
frame.length = sunpack(">I8", extra_fill_unget)
fill_length = fill_length - 8 + frame.length
end

if extra_fill_unget then
local ok, err, errno = sock:fill(fill_length, 0)
if not ok then
local unget_ok1, unget_errno1 = sock:unget(extra_fill_unget)
if not unget_ok1 then
return nil, onerror(sock, "unget", unget_errno1)
end
local unget_ok2, unget_errno2 = sock:unget(first_2)
if not unget_ok2 then
return nil, onerror(sock, "unget", unget_errno2)
end
if errno == ce.ETIMEDOUT then
local timeout = deadline and deadline-monotime()
if cqueues.poll(sock, timeout) ~= timeout then
-- retry
return read_frame(sock, deadline)
end
elseif err == nil then
sock:seterror("r", ce.EILSEQ)
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
return nil, onerror(sock, "read_frame", ce.EILSEQ)
end
return nil, err, errno
end
frame.length = sunpack(">I8", length)
end

if frame.MASK then
local key, err, errno = sock:xread(4, "b", deadline and (deadline-monotime()))
if not key or #key ~= 4 then
if err == nil then
sock:seterror("r", ce.EILSEQ)
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
end
return nil, err, errno
end
local key = assert(sock:xread(4, "b", 0))
frame.key = { key:byte(1, 4) }
end

do
local data, err, errno = sock:xread(frame.length, "b", deadline and (deadline-monotime()))
if data == nil or #data ~= frame.length then
if err == nil then
sock:seterror("r", ce.EILSEQ)
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
end
return nil, err, errno
end

local data = assert(sock:xread(frame.length, "b", 0))
if frame.MASK then
frame.data = apply_mask(data, frame.key)
else
Expand All @@ -267,9 +303,9 @@ end

function websocket_methods:send_frame(frame, timeout)
if self.readyState < 1 then
return nil, ce.strerror(ce.ENOTCONN), ce.ENOTCONN
return nil, onerror(self.socket, "send_frame", ce.ENOTCONN)
elseif self.readyState > 2 then
return nil, ce.strerror(ce.EPIPE), ce.EPIPE
return nil, onerror(self.socket, "send_frame", ce.EPIPE)
end
local ok, err, errno = self.socket:xwrite(build_frame(frame), "bn", timeout)
if not ok then
Expand Down Expand Up @@ -349,9 +385,9 @@ end

function websocket_methods:receive(timeout)
if self.readyState < 1 then
return nil, ce.strerror(ce.ENOTCONN), ce.ENOTCONN
return nil, onerror(self.socket, "receive", ce.ENOTCONN)
elseif self.readyState > 2 then
return nil, ce.strerror(ce.EPIPE), ce.EPIPE
return nil, onerror(self.socket, "receive", ce.EPIPE)
end
local deadline = timeout and (monotime()+timeout)
while true do
Expand Down Expand Up @@ -638,6 +674,7 @@ local function handle_websocket_response(self, headers, stream)
-- Success!
assert(self.socket == nil, "websocket:connect called twice")
self.socket = assert(stream.connection:take_socket())
self.socket:onerror(onerror)
self.request = nil
self.headers = headers
self.readyState = 1
Expand Down Expand Up @@ -776,6 +813,7 @@ function websocket_methods:accept(options, timeout)
end

self.socket = assert(self.stream.connection:take_socket())
self.socket:onerror(onerror)
self.stream = nil
self.readyState = 1
self.protocol = chosen_protocol
Expand Down
24 changes: 24 additions & 0 deletions spec/websocket_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,18 @@ describe("http.websocket", function()
end)
end)
describe("http.websocket module two sided tests", function()
local onerror = require "http.connection_common".onerror
local server = require "http.server"
local util = require "http.util"
local websocket = require "http.websocket"
local cqueues = require "cqueues"
local ca = require "cqueues.auxlib"
local ce = require "cqueues.errno"
local cs = require "cqueues.socket"
local function new_pair()
local s, c = ca.assert(cs.pair())
s:onerror(onerror)
c:onerror(onerror)
local ws_server = websocket.new("server")
ws_server.socket = s
ws_server.readyState = 1
Expand All @@ -201,6 +205,26 @@ describe("http.websocket module two sided tests", function()
assert_loop(cq, TEST_TIMEOUT)
assert.truthy(cq:empty())
end)
it("timeouts return nil, err, errno", function()
local cq = cqueues.new()
local c, s = new_pair()
local ok, _, errno = c:receive(0)
assert.same(nil, ok)
assert.same(ce.ETIMEDOUT, errno)
-- Check it still works afterwards
cq:wrap(function()
assert(c:send("hello"))
assert.same("world", c:receive())
assert(c:close())
end)
cq:wrap(function()
assert.same("hello", s:receive())
assert(s:send("world"))
assert(s:close())
end)
assert_loop(cq, TEST_TIMEOUT)
assert.truthy(cq:empty())
end)
it("doesn't fail when data contains a \\r\\n", function()
local cq = cqueues.new()
local c, s = new_pair()
Expand Down