From 2f70b4ae4212a65472ebdf5a9079b6e7645b7422 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 23 May 2018 23:55:58 +1000 Subject: [PATCH] http/websocket.lua: Fix behaviour of read_frame on error - Fixes incorrect return values on error (missing error message #107) - Now retry safe --- http/websocket.lua | 108 +++++++++++++++++++++++++++------------- spec/websocket_spec.lua | 24 +++++++++ 2 files changed, 97 insertions(+), 35 deletions(-) diff --git a/http/websocket.lua b/http/websocket.lua index 51763b12..fbc2bfce 100644 --- a/http/websocket.lua +++ b/http/websocket.lua @@ -35,6 +35,7 @@ local http_patts = require "lpeg_patterns.http" local rand = require "openssl.rand" local digest = require "openssl.digest" local bit = require "http.bit" +local onerror = require "http.connection_common".onerror local new_headers = require "http.headers".new local http_request = require "http.request" @@ -177,13 +178,18 @@ local function build_close(code, message, mask) end local function read_frame(sock, deadline) - local frame do - local first_2, err, errno = sock:xread(2, "b", deadline and (deadline-monotime())) + local frame, first_2 do + local err, errno + first_2, err, errno = sock:xread(2, "b", deadline and (deadline-monotime())) if not first_2 then return nil, err, errno elseif #first_2 ~= 2 then sock:seterror("r", ce.EILSEQ) - return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ + local ok, errno2 = sock:unget(first_2) + if not ok then + return nil, onerror(sock, "unget", errno2) + end + return nil, onerror(sock, "read_frame", ce.EILSEQ) end local byte1, byte2 = first_2:byte(1, 2) frame = { @@ -200,50 +206,80 @@ local function read_frame(sock, deadline) } end - if frame.length == 126 then - local length, err, errno = sock:xread(2, "b", deadline and (deadline-monotime())) - if not length or #length ~= 2 then - if err == nil then + local fill_length = frame.length + if fill_length == 126 then + fill_length = 2 + elseif fill_length == 127 then + fill_length = 8 + end + if frame.MASK then + fill_length = fill_length + 4 + end + do + local ok, err, errno = sock:fill(fill_length, 0) + if not ok then + local unget_ok1, unget_errno1 = sock:unget(first_2) + if not unget_ok1 then + return nil, onerror(sock, "unget", unget_errno1) + end + if errno == ce.ETIMEDOUT then + local timeout = deadline and deadline-monotime() + if cqueues.poll(sock, timeout) ~= timeout then + -- retry + return read_frame(sock, deadline) + end + elseif err == nil then sock:seterror("r", ce.EILSEQ) - return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ + return nil, onerror(sock, "read_frame", ce.EILSEQ) end return nil, err, errno end - frame.length = sunpack(">I2", length) + end + + -- if `fill` succeeded these shouldn't be able to fail + local extra_fill_unget + if frame.length == 126 then + extra_fill_unget = assert(sock:xread(2, "b", 0)) + frame.length = sunpack(">I2", extra_fill_unget) + fill_length = fill_length - 2 elseif frame.length == 127 then - local length, err, errno = sock:xread(8, "b", deadline and (deadline-monotime())) - if not length or #length ~= 8 then - if err == nil then + extra_fill_unget = assert(sock:xread(8, "b", 0)) + frame.length = sunpack(">I8", extra_fill_unget) + fill_length = fill_length - 8 + frame.length + end + + if extra_fill_unget then + local ok, err, errno = sock:fill(fill_length, 0) + if not ok then + local unget_ok1, unget_errno1 = sock:unget(extra_fill_unget) + if not unget_ok1 then + return nil, onerror(sock, "unget", unget_errno1) + end + local unget_ok2, unget_errno2 = sock:unget(first_2) + if not unget_ok2 then + return nil, onerror(sock, "unget", unget_errno2) + end + if errno == ce.ETIMEDOUT then + local timeout = deadline and deadline-monotime() + if cqueues.poll(sock, timeout) ~= timeout then + -- retry + return read_frame(sock, deadline) + end + elseif err == nil then sock:seterror("r", ce.EILSEQ) - return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ + return nil, onerror(sock, "read_frame", ce.EILSEQ) end return nil, err, errno end - frame.length = sunpack(">I8", length) end if frame.MASK then - local key, err, errno = sock:xread(4, "b", deadline and (deadline-monotime())) - if not key or #key ~= 4 then - if err == nil then - sock:seterror("r", ce.EILSEQ) - return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ - end - return nil, err, errno - end + local key = assert(sock:xread(4, "b", 0)) frame.key = { key:byte(1, 4) } end do - local data, err, errno = sock:xread(frame.length, "b", deadline and (deadline-monotime())) - if data == nil or #data ~= frame.length then - if err == nil then - sock:seterror("r", ce.EILSEQ) - return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ - end - return nil, err, errno - end - + local data = assert(sock:xread(frame.length, "b", 0)) if frame.MASK then frame.data = apply_mask(data, frame.key) else @@ -267,9 +303,9 @@ end function websocket_methods:send_frame(frame, timeout) if self.readyState < 1 then - return nil, ce.strerror(ce.ENOTCONN), ce.ENOTCONN + return nil, onerror(self.socket, "send_frame", ce.ENOTCONN) elseif self.readyState > 2 then - return nil, ce.strerror(ce.EPIPE), ce.EPIPE + return nil, onerror(self.socket, "send_frame", ce.EPIPE) end local ok, err, errno = self.socket:xwrite(build_frame(frame), "bn", timeout) if not ok then @@ -349,9 +385,9 @@ end function websocket_methods:receive(timeout) if self.readyState < 1 then - return nil, ce.strerror(ce.ENOTCONN), ce.ENOTCONN + return nil, onerror(self.socket, "receive", ce.ENOTCONN) elseif self.readyState > 2 then - return nil, ce.strerror(ce.EPIPE), ce.EPIPE + return nil, onerror(self.socket, "receive", ce.EPIPE) end local deadline = timeout and (monotime()+timeout) while true do @@ -638,6 +674,7 @@ local function handle_websocket_response(self, headers, stream) -- Success! assert(self.socket == nil, "websocket:connect called twice") self.socket = assert(stream.connection:take_socket()) + self.socket:onerror(onerror) self.request = nil self.headers = headers self.readyState = 1 @@ -776,6 +813,7 @@ function websocket_methods:accept(options, timeout) end self.socket = assert(self.stream.connection:take_socket()) + self.socket:onerror(onerror) self.stream = nil self.readyState = 1 self.protocol = chosen_protocol diff --git a/spec/websocket_spec.lua b/spec/websocket_spec.lua index fbbdf52f..8dc1ed6a 100644 --- a/spec/websocket_spec.lua +++ b/spec/websocket_spec.lua @@ -169,14 +169,18 @@ describe("http.websocket", function() end) end) describe("http.websocket module two sided tests", function() + local onerror = require "http.connection_common".onerror local server = require "http.server" local util = require "http.util" local websocket = require "http.websocket" local cqueues = require "cqueues" local ca = require "cqueues.auxlib" + local ce = require "cqueues.errno" local cs = require "cqueues.socket" local function new_pair() local s, c = ca.assert(cs.pair()) + s:onerror(onerror) + c:onerror(onerror) local ws_server = websocket.new("server") ws_server.socket = s ws_server.readyState = 1 @@ -201,6 +205,26 @@ describe("http.websocket module two sided tests", function() assert_loop(cq, TEST_TIMEOUT) assert.truthy(cq:empty()) end) + it("timeouts return nil, err, errno", function() + local cq = cqueues.new() + local c, s = new_pair() + local ok, _, errno = c:receive(0) + assert.same(nil, ok) + assert.same(ce.ETIMEDOUT, errno) + -- Check it still works afterwards + cq:wrap(function() + assert(c:send("hello")) + assert.same("world", c:receive()) + assert(c:close()) + end) + cq:wrap(function() + assert.same("hello", s:receive()) + assert(s:send("world")) + assert(s:close()) + end) + assert_loop(cq, TEST_TIMEOUT) + assert.truthy(cq:empty()) + end) it("doesn't fail when data contains a \\r\\n", function() local cq = cqueues.new() local c, s = new_pair()