diff --git a/socket.exclude b/socket.exclude index 44c18a9..55df588 100644 --- a/socket.exclude +++ b/socket.exclude @@ -15,7 +15,7 @@ !/csrc/ !/media/ -!/socket* +!/socket.* !/mime.lua !/ltn12.lua !/socket/ diff --git a/socket.lua b/socket.lua index 3913e6f..8f2d159 100644 --- a/socket.lua +++ b/socket.lua @@ -39,16 +39,16 @@ function _M.bind(host, port, backlog) if not sock then return nil, err end sock:setoption("reuseaddr", true) res, err = sock:bind(alt.addr, port) - if not res then + if not res then sock:close() - else + else res, err = sock:listen(backlog) - if not res then + if not res then sock:close() else return sock end - end + end end return nil, err end diff --git a/socket/http.lua b/socket/http.lua index 1d0eb50..a16b3c7 100644 --- a/socket/http.lua +++ b/socket/http.lua @@ -44,12 +44,12 @@ local function receiveheaders(sock, headers) if not (name and value) then return nil, "malformed reponse headers" end name = string.lower(name) -- get next line (value might be folded) - line, err = sock:receive() + line, err = sock:receive() if err then return nil, err end -- unfold any folded values while string.find(line, "^%s") do value = value .. line - line = sock:receive() + line, err = sock:receive() if err then return nil, err end end -- save pair in table @@ -76,7 +76,7 @@ socket.sourcet["http-chunked"] = function(sock, headers) -- was it the last chunk? if size > 0 then -- if not, get chunk and skip terminating CRLF - local chunk, err, part = sock:receive(size) + local chunk, err = sock:receive(size) if chunk then sock:receive() end return chunk, err else @@ -219,7 +219,7 @@ local function adjustheaders(reqt) } -- if we have authentication information, pass it along if reqt.user and reqt.password then - lower["authorization"] = + lower["authorization"] = "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) end -- override with user headers @@ -243,7 +243,7 @@ local function adjustrequest(reqt) -- explicit components override url for i,v in base.pairs(reqt) do nreqt[i] = v end if nreqt.port == "" then nreqt.port = 80 end - socket.try(nreqt.host and nreqt.host ~= "", + socket.try(nreqt.host and nreqt.host ~= "", "invalid host '" .. base.tostring(nreqt.host) .. "'") -- compute uri if user hasn't overriden nreqt.uri = reqt.uri or adjusturi(nreqt) @@ -281,10 +281,10 @@ local trequest, tredirect source = reqt.source, sink = reqt.sink, headers = reqt.headers, - proxy = reqt.proxy, + proxy = reqt.proxy, nredirects = (reqt.nredirects or 0) + 1, create = reqt.create - } + } -- pass location header back as a hint we redirected headers = headers or {} headers.location = headers.location or location @@ -301,7 +301,7 @@ end h:sendheaders(nreqt.headers) -- if there is a body, send it if nreqt.source then - h:sendbody(nreqt.headers, nreqt.source, nreqt.step) + h:sendbody(nreqt.headers, nreqt.source, nreqt.step) end local code, status = h:receivestatusline() -- if it is an HTTP/0.9 server, simply get the body and we are done @@ -311,13 +311,13 @@ end end local headers -- ignore any 100-continue messages - while code == 100 do + while code == 100 do headers = h:receiveheaders() code, status = h:receivestatusline() end headers = h:receiveheaders() -- at this point we should have a honest reply from the server - -- we can't redirect if we already used the source, so we report the error + -- we can't redirect if we already used the source, so we report the error if shouldredirect(nreqt, code, headers) and not nreqt.source then h:close() return tredirect(reqt, headers.location) @@ -348,9 +348,22 @@ local function srequest(u, b) return table.concat(t), code, headers, status end +local function pack(ok, ...) + return ok, {n = select('#', ...), ...} +end + +--reimplement socket.protect in Lua so we can yield across C-stack boundaries. +function socket.protect(f) + return function(...) + local ok, ret = pack(pcall(f, ...)) + if ok then return unpack(ret, 1, ret.n) + else return nil, ret[1] end + end +end + _M.request = socket.protect(function(reqt, body) if base.type(reqt) == "string" then return srequest(reqt, body) else return trequest(reqt) end end) -return _M \ No newline at end of file +return _M