Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
copas/src/copas.lua
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2008 lines (1629 sloc)
57.8 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ------------------------------------------------------------------------------- | |
| -- Copas - Coroutine Oriented Portable Asynchronous Services | |
| -- | |
| -- A dispatcher based on coroutines that can be used by TCP/IP servers. | |
| -- Uses LuaSocket as the interface with the TCP/IP stack. | |
| -- | |
| -- Authors: Andre Carregal, Javier Guerra, and Fabio Mascarenhas | |
| -- Contributors: Diego Nehab, Mike Pall, David Burgess, Leonardo Godinho, | |
| -- Thomas Harning Jr., and Gary NG | |
| -- | |
| -- Copyright 2005-2013 - Kepler Project (www.keplerproject.org), 2015-2023 Thijs Schreijer | |
| -- | |
| -- $Id: copas.lua,v 1.37 2009/04/07 22:09:52 carregal Exp $ | |
| ------------------------------------------------------------------------------- | |
-- On plain Lua 5.1 copas must be loaded before socket.http / copas.http;
-- refuse to continue when either of them is already in package.loaded.
if _VERSION == "Lua 5.1" then -- obsolete: only for Lua 5.1 compatibility
  for _, modname in ipairs({ "socket.http", "copas.http" }) do
    if package.loaded[modname] then
      error("you must require copas before require'ing " .. modname)
    end
  end
end

local socket = require("socket")
local binaryheap = require("binaryheap")

local gettime = socket.gettime
local ssl -- only loaded upon demand

-- tunables / limits
local WATCH_DOG_TIMEOUT = 120
local UDP_DATAGRAM_MAX = socket._DATAGRAMSIZE or 8192
local TIMEOUT_PRECISION = 0.1 -- 100ms

-- shared no-op function
local fnil = function() end

local coroutine_create, coroutine_running, coroutine_yield, coroutine_resume, coroutine_status =
      coroutine.create, coroutine.running, coroutine.yield, coroutine.resume, coroutine.status

-- nil-safe versions for pack/unpack: when a table carries an `n` field it
-- decides how many values are unpacked, so trailing nils survive a round trip
local _unpack = unpack or table.unpack
local function unpack(t, i, j)
  return _unpack(t, i or 1, j or t.n or #t)
end
local function pack(...)
  return { n = select("#", ...), ... }
end

local pcall = pcall
if _VERSION == "Lua 5.1" and not jit then -- obsolete: only for Lua 5.1 compatibility
  -- plain Lua 5.1 (not LuaJIT): use coxpcall's coroutine safe pcall/running
  local coxpcall = require("coxpcall")
  pcall = coxpcall.pcall
  coroutine_running = coxpcall.running
end
do
  -- Redefines LuaSocket's protect/newtry/try with coroutine safe versions
  -- (pure Lua, using the `pcall` selected above).
  -- (this allows the use of socket.http from within copas)

  -- Metatable that marks the intermediate error tables thrown by a `try`
  -- function, so `protect` can tell them apart from genuine runtime errors.
  local err_mt = {
    __tostring = function(self)
      -- convert the payload first: it may be nil or a non-string value, which
      -- cannot be concatenated directly
      return "Copas 'try' error intermediate table: '" .. tostring(self[1]) .. "'"
    end,
  }

  -- Maps the results of a pcall onto the LuaSocket `protect` convention:
  -- on success all results pass through; a `try` failure becomes `nil, err`;
  -- any other error is re-raised unchanged.
  local function statusHandler(status, ...)
    if status then return ... end
    local err = (...)
    if type(err) == "table" and getmetatable(err) == err_mt then
      return nil, err[1]
    end
    -- level 0: re-raise as-is, do not prepend another "file:line:" prefix
    error(err, 0)
  end

  -- Wraps `func` so that errors thrown by `try` functions inside it are
  -- returned as `nil, err` instead of propagating.
  function socket.protect(func)
    return function (...)
      return statusHandler(pcall(func, ...))
    end
  end

  -- Creates a `try` function: it returns all its arguments when the first one
  -- is truthy; otherwise it calls the optional finalizer (errors in it are
  -- ignored) and throws the second argument wrapped in an err_mt table.
  function socket.newtry(finalizer)
    return function (...)
      local status = (...)
      if not status then
        pcall(finalizer or fnil, select(2, ...))
        error(setmetatable({ (select(2, ...)) }, err_mt), 0)
      end
      return ...
    end
  end

  socket.try = socket.newtry()
end
-- Setup the Copas module table. Reading a key that names a known submodule
-- loads `copas.<name>` on first access, and calling the table (`copas(...)`)
-- runs `copas.loop(...)`.
local copas do
  -- submodules eligible for auto-loading (set: name -> true); an entry is
  -- dropped once that submodule has been loaded
  local lazy = {
    ftp = true, http = true, lock = true, queue = true,
    semaphore = true, smtp = true, timer = true,
  }

  copas = setmetatable({}, {
    __index = function(self, key)
      if not lazy[key] then
        return
      end
      self[key] = require("copas." .. key)
      lazy[key] = nil
      return rawget(self, key)
    end,
    __call = function(self, ...)
      return self.loop(...)
    end,
  })
end

-- Meta information is public even if beginning with an "_"
copas._COPYRIGHT = "Copyright (C) 2005-2013 Kepler Project, 2015-2023 Thijs Schreijer"
copas._DESCRIPTION = "Coroutine Oriented Portable Asynchronous Services"
copas._VERSION = "Copas 4.7.0"

-- Close the socket associated with the current connection after the handler finishes
copas.autoclose = true

-- indicator for the loop running
copas.running = false
-------------------------------------------------------------------------------
-- Object names, to track names of thread/coroutines and sockets
-------------------------------------------------------------------------------
-- Weak-keyed, so a name never keeps its thread/socket alive. Looking up an
-- object without a custom name yields (and caches) its tostring() value.
local object_names = setmetatable({}, {
  __mode = "k",
  __index = function(names, obj)
    local default_name = tostring(obj)
    if obj == nil then
      return default_name -- nil cannot be used as a key, so don't cache it
    end
    rawset(names, obj, default_name)
    return default_name
  end,
})
-------------------------------------------------------------------------------
-- Simple set implementation
-- adds a FIFO queue for each socket in the set
-------------------------------------------------------------------------------
-- Creates a new set. Members live in the array part of the returned table
-- (so `#set` and `set[i]` work); the methods live in its hash part.
local function newsocketset()
  local set = {}
  local index_of = {} -- socket -> its position in the array part of `set`

  -- Adds a socket to the set, does nothing if it exists.
  -- @return skt if added, or nothing if it already was a member
  function set:insert(skt)
    if index_of[skt] then
      return
    end
    local n = #self + 1
    self[n] = skt
    index_of[skt] = n
    return skt
  end

  -- Removes socket from the set, does nothing if not found. The last member
  -- is moved into the freed slot so the array part stays without holes.
  -- @return skt if removed, or nothing if it wasn't in the set
  function set:remove(skt)
    local pos = index_of[skt]
    if not pos then
      return
    end
    index_of[skt] = nil
    local last = #self
    local moved = self[last]
    self[last] = nil
    if moved ~= skt then
      self[pos] = moved
      index_of[moved] = pos
    end
    return skt
  end

  -- one FIFO queue per socket, created on first use; weak keys, so a queue
  -- goes away together with its socket
  local queues = setmetatable({}, {
    __mode = "k",
    __index = function(all, skt)
      local q = {}
      all[skt] = q
      return q
    end,
  })

  -- Appends an item to the FIFO queue of the socket.
  function set:push(skt, itm)
    local q = queues[skt]
    q[#q + 1] = itm
  end

  -- Removes and returns the oldest item from the FIFO queue of the socket.
  function set:pop(skt)
    return table.remove(queues[skt], 1)
  end

  return set
end
-- Threads immediately resumable. Coroutines are queued with push(); the loop
-- takes the whole batch at once with clear_resumelist().
local _resumable = {} do
  local resumelist = {}

  -- Queues a coroutine for resumption.
  function _resumable:push(co)
    resumelist[#resumelist + 1] = co
  end

  -- Hands out the current batch and starts a fresh, empty one.
  -- @return array of the coroutines queued so far
  function _resumable:clear_resumelist()
    local lst = resumelist
    resumelist = {}
    return lst
  end

  -- @return true if no coroutine is queued
  function _resumable:done()
    return resumelist[1] == nil
  end

  -- @return number of coroutines currently queued
  function _resumable:count()
    -- only `resumelist` holds coroutines; `_resumable` itself stores methods only
    return #resumelist
  end
end
-- Similar to the socket set above, but tailored for the use of sleeping
-- threads. Timed sleepers are kept in a binary min-heap keyed on their
-- wakeup time; sleepers without a wakeup time go into the weak `lethargy` set.
local _sleeping = {} do
  local heap = binaryheap.minUnique()
  -- coroutines sleeping without a wakeup time (weak keys)
  local lethargy = setmetatable({}, { __mode = "k" })

  -- Required base implementation (presumably the socket-set interface the
  -- loop expects); deliberately no-ops here
  _sleeping.insert = fnil
  _sleeping.remove = fnil

  -- Puts `co` to sleep: a negative time sleeps until explicitly woken, 0 makes
  -- it resumable right away, any other value schedules it `sleeptime` seconds
  -- from now.
  function _sleeping:push(sleeptime, co)
    if sleeptime < 0 then
      lethargy[co] = true
      return
    end
    if sleeptime == 0 then
      _resumable:push(co)
      return
    end
    heap:insert(gettime() + sleeptime, co)
  end

  -- Finds the thread that should wake up at `time`, if any; nothing is
  -- returned while the earliest wakeup time is still in the future.
  function _sleeping:pop(time)
    local next_wakeup = heap:peekValue() or math.huge
    if time < next_wakeup then
      return
    end
    return heap:pop()
  end

  -- additional methods for time management
  -----------------------------------------

  -- @return delay until the next sleep expires, or nothing if there is none
  function _sleeping:getnext()
    local wake_at = heap:peekValue()
    if not wake_at then
      return
    end
    -- never report less than 0, because select() might block
    return math.max(wake_at - gettime(), 0)
  end

  -- Wakes a sleeping coroutine early, making it resumable; does nothing if
  -- `co` is not sleeping.
  function _sleeping:wakeup(co)
    if lethargy[co] then
      lethargy[co] = nil
    elseif not heap:remove(co) then
      return
    end
    _resumable:push(co)
  end

  -- Removes `co` from both sleeping structures without resuming it.
  function _sleeping:cancel(co)
    lethargy[co] = nil
    heap:remove(co)
  end

  -- @param tos number of timeouts running
  -- @return true if we have nothing more to do
  function _sleeping:done(tos)
    -- the timeout task doesn't qualify as work (fallbacks only),
    -- the lethargy also doesn't qualify as work ('dead' tasks),
    -- but the combination of a timeout + a lethargy can be work
    if heap:size() ~= 1 then -- 1 means only the timeout-timer task is running
      return false
    end
    return not (tos > 0 and next(lethargy))
  end

  -- @return number of threads in the binaryheap, number of threads in lethargy
  function _sleeping:status()
    local lethargic = 0
    for _ in pairs(lethargy) do
      lethargic = lethargic + 1
    end
    return heap:size(), lethargic
  end
end -- _sleeping
-------------------------------------------------------------------------------
-- Tracking coroutines and sockets
-------------------------------------------------------------------------------
local _servers = newsocketset() -- servers being handled

-- weak registries, so nothing here keeps a thread or socket alive
local _threads = setmetatable({}, { __mode = "k" })      -- registered threads added with addthread()
local _canceled = setmetatable({}, { __mode = "k" })     -- threads that are canceled and pending removal
local _autoclose = setmetatable({}, { __mode = "kv" })   -- sockets (value) to close when a thread (key) exits
local _autoclose_r = setmetatable({}, { __mode = "kv" }) -- reverse: sockets (key) to close when a thread (value) exits

-- Last read / last write time per socket (indexed by the socket), so the
-- watchdog can follow up when an operation takes too long.
local _reading_log, _writing_log = {}, {}

local _closed = {} -- track sockets that have been closed (list/array)

local _reading = newsocketset() -- sockets currently being read
local _writing = newsocketset() -- sockets currently being written

-- set of errors indicating a socket-timeout
local _isSocketTimeout = {
  timeout = true,   -- default LuaSocket timeout
  wantread = true,  -- LuaSec specific timeout
  wantwrite = true, -- LuaSec specific timeout
}

-------------------------------------------------------------------------------
-- Coroutine based socket timeouts.
-------------------------------------------------------------------------------
-- Per-socket user timeouts (seconds) for connect, send and receive.
local user_timeouts_connect, user_timeouts_send, user_timeouts_receive
do
  -- Shared by all three tables: weak socket keys, and a socket without an
  -- explicit timeout gets math.huge (block forever) stored on first lookup.
  local timeout_mt = {
    __mode = "k",
    __index = function(timeouts, skt)
      timeouts[skt] = math.huge
      return math.huge
    end,
  }
  user_timeouts_connect = setmetatable({}, timeout_mt)
  user_timeouts_send = setmetatable({}, timeout_mt)
  user_timeouts_receive = setmetatable({}, timeout_mt)
end
-- Per coroutine: when truthy, a timed-out operation reports the raw socket
-- error (e.g. "wantread") instead of a plain "timeout" (see sto_error).
local useSocketTimeoutErrors = setmetatable({}, { __mode = "k" })

-- sto = socket-time-out
-- The running coroutine registers the socket it waits on and the queue it
-- waits in; when copas.timeout fires, the callback flags the coroutine as
-- timed out, makes it resumable and takes the socket out of that queue.
local sto_timeout, sto_timed_out, sto_change_queue, sto_error do
  local socket_register = setmetatable({}, { __mode = "k" })    -- socket, by coroutine
  local operation_register = setmetatable({}, { __mode = "k" }) -- operation "read"/"write", by coroutine
  local timeout_flags = setmetatable({}, { __mode = "k" })      -- true if timed out, by coroutine

  -- Timeout callback handed to copas.timeout(); `co` is looked up in the
  -- registers filled by sto_timeout().
  local function socket_callback(co)
    local skt, queue = socket_register[co], operation_register[co]

    -- flag the timeout and resume the coroutine
    timeout_flags[co] = true
    _resumable:push(co)

    -- clear the socket from the current queue
    if queue == "read" then
      _reading:remove(skt)
    elseif queue == "write" then
      _writing:remove(skt)
    else
      error("bad queue name; expected 'read'/'write', got: "..tostring(queue))
    end
  end

  -- Sets a socket timeout for the running coroutine.
  -- Calling it as `sto_timeout()` will cancel the timeout.
  -- @param skt (socket) the socket on which to operate
  -- @param queue (string) the queue the socket is currently in, must be either "read" or "write"
  -- @param use_connect_to (bool) if truthy the connect timeout is used, otherwise
  -- the receive ("read" queue) or send timeout
  -- @return true
  function sto_timeout(skt, queue, use_connect_to)
    local co = coroutine_running()
    socket_register[co] = skt
    operation_register[co] = queue
    timeout_flags[co] = nil

    if not skt then
      copas.timeout(0)
      return true
    end

    local to = use_connect_to and user_timeouts_connect[skt]
    if not to then
      if queue == "read" then
        to = user_timeouts_receive[skt]
      end
      to = to or user_timeouts_send[skt]
    end
    copas.timeout(to, socket_callback)
    return true
  end

  -- Changes the timeout to a different queue (read/write).
  -- Only useful with ssl-handshakes and "wantread", "wantwrite" errors, when
  -- the queue has to be changed, so the timeout handler knows where to find the socket.
  -- @param queue (string) the new queue the socket is in, must be either "read" or "write"
  -- @return true
  function sto_change_queue(queue)
    local co = coroutine_running()
    operation_register[co] = queue
    return true
  end

  -- @return true if the running coroutine's operation timed-out
  function sto_timed_out()
    local co = coroutine_running()
    return timeout_flags[co]
  end

  -- @return the error to report for a timed-out operation: `err` itself when
  -- the running coroutine opted into raw socket errors, otherwise "timeout"
  function sto_error(err)
    if useSocketTimeoutErrors[coroutine_running()] and err then
      return err
    end
    return "timeout"
  end
end
-------------------------------------------------------------------------------
-- Coroutine based socket I/O functions.
-------------------------------------------------------------------------------
-- Classifies a socket by the first 3 characters of its tostring() value:
-- returns "tcp" for plain TCP and "ssl" for ssl-wrapped sockets (truthy for
-- tcp based), and nil for anything else, e.g. udp (falsy).
local isTCP do
  local kind_by_prefix = {
    tcp = "tcp",
    SSL = "ssl",
  }

  function isTCP(skt)
    local prefix = tostring(skt):sub(1, 3)
    return kind_by_prefix[prefix]
  end
end
-- Closes a socket, recording it in the `_closed` list first.
-- @return whatever the socket's own close() returns
function copas.close(skt, ...)
  local n = #_closed + 1
  _closed[n] = skt
  return skt:close(...)
end
-- Sets one copas user-timeout for connect, send and receive alike.
-- nil or negative is indefinitely.
-- @return true, or nil + error message when `timeout` is neither nil nor a number
function copas.settimeout(skt, timeout)
  local t = timeout or -1
  if type(t) ~= "number" then
    return nil, "timeout must be 'nil' or a number"
  end
  return copas.settimeouts(skt, t, t, t)
end
-- Sets the copas user-timeouts of a socket, per operation.
-- Negative is indefinitely, nil means do not change.
-- All arguments are validated before anything is stored, so a call with an
-- invalid argument leaves the socket's timeouts untouched.
-- @param skt the socket
-- @param connect (number|nil) connect timeout in seconds
-- @param send (number|nil) send timeout in seconds
-- @param read (number|nil) receive timeout in seconds
-- @return true, or nil + error message when an argument has the wrong type
function copas.settimeouts(skt, connect, send, read)
  if connect ~= nil and type(connect) ~= "number" then
    return nil, "connect timeout must be 'nil' or a number"
  end
  if send ~= nil and type(send) ~= "number" then
    return nil, "send timeout must be 'nil' or a number"
  end
  if read ~= nil and type(read) ~= "number" then
    return nil, "read timeout must be 'nil' or a number"
  end

  -- a negative value clears the entry; the timeout tables then report
  -- math.huge (block forever) on the next lookup
  if connect ~= nil then
    if connect < 0 then connect = nil end
    user_timeouts_connect[skt] = connect
  end
  if send ~= nil then
    if send < 0 then send = nil end
    user_timeouts_send[skt] = send
  end
  if read ~= nil then
    if read < 0 then read = nil end
    user_timeouts_receive[skt] = read
  end
  return true
end
-- Reads `pattern` from a client, yielding to the reading set while the socket
-- would block (or to the writing set when LuaSec reports "wantwrite", which
-- may be returned during SSL renegotiations).
-- UDP: a UDP socket expects a second argument to be a number, so it MUST
-- be provided as the 'pattern' below defaults to a string. Will throw a
-- 'bad argument' error if omitted.
-- @return the results of the final client:receive() call, or
-- nil + timeout error + partial data when the copas user-timeout expired
function copas.receive(client, pattern, part)
  pattern = pattern or "*l"
  local log = _reading_log -- wait-time log matching the set we last waited in
  sto_timeout(client, "read")

  while true do
    local data, err
    data, err, part = client:receive(pattern, part)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    -- finished: either data, or an error that is not a "would block" condition
    if data or not _isSocketTimeout[err] then
      log[client] = nil
      sto_timeout()
      return data, err, part
    end

    -- the copas user-timeout expired while waiting
    if sto_timed_out() then
      log[client] = nil
      return nil, sto_error(err), part
    end

    if err == "wantwrite" then
      log = _writing_log
      log[client] = gettime()
      sto_change_queue("write")
      coroutine_yield(client, _writing)
    else
      log = _reading_log
      log[client] = gettime()
      sto_change_queue("read")
      coroutine_yield(client, _reading)
    end
  end
end
-- Receives a datagram from a client over UDP. Not available for TCP.
-- (a copy of receive(), adapted for receivefrom() use)
-- @param size maximum datagram size, defaults to UDP_DATAGRAM_MAX
-- @return the results of the final client:receivefrom() call (on success the
-- 2nd value holds the ip address), or nil + timeout error when the copas
-- user-timeout expired
function copas.receivefrom(client, size)
  size = size or UDP_DATAGRAM_MAX
  sto_timeout(client, "read")

  while true do
    local data, err, port = client:receivefrom(size)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    if data or err ~= "timeout" then
      _reading_log[client] = nil
      sto_timeout()
      return data, err, port
    end

    if sto_timed_out() then
      _reading_log[client] = nil
      return nil, sto_error(err), port
    end

    _reading_log[client] = gettime()
    coroutine_yield(client, _reading)
  end
end
-- Like copas.receive, but with special treatment when reading chunks: it
-- returns as soon as any new data arrived (the partial result grew beyond the
-- size of the `part` passed in), even if `pattern` is not yet satisfied.
-- @return the results of the final client:receive() call, or
-- nil + timeout error + partial data when the copas user-timeout expired
function copas.receivepartial(client, pattern, part)
  pattern = pattern or "*l"
  local initial_len = #(part or "")
  local log = _reading_log
  sto_timeout(client, "read")

  while true do
    local data, err
    data, err, part = client:receive(pattern, part)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    local got_more = type(part) == "string" and #part > initial_len
    if data or got_more or not _isSocketTimeout[err] then
      log[client] = nil
      sto_timeout()
      return data, err, part
    end

    if sto_timed_out() then
      log[client] = nil
      return nil, sto_error(err), part
    end

    -- LuaSec may need to write (renegotiation) before it can read again
    local need_write = (err == "wantwrite")
    log = need_write and _writing_log or _reading_log
    log[client] = gettime()
    sto_change_queue(need_write and "write" or "read")
    coroutine_yield(client, need_write and _writing or _reading)
  end
end
copas.receivePartial = copas.receivepartial -- compat: receivePartial is deprecated
-- Sends data to a client, yielding to the writing set while the socket would
-- block (or to the reading set when LuaSec reports "wantread").
-- Note: from and to parameters will be ignored by/for UDP sockets
-- @return the results of the final client:send() call, or
-- nil + timeout error + last index sent when the copas user-timeout expired
function copas.send(client, data, from, to)
  local sent = (from or 1) - 1 -- index of the last byte handed to the socket
  local log = _writing_log
  sto_timeout(client, "write")

  while true do
    local ok, err
    ok, err, sent = client:send(data, sent + 1, to)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    if ok or not _isSocketTimeout[err] then
      log[client] = nil
      sto_timeout()
      return ok, err, sent
    end

    if sto_timed_out() then
      log[client] = nil
      return nil, sto_error(err), sent
    end

    -- LuaSec may need to read (renegotiation) before it can write again
    local need_read = (err == "wantread")
    log = need_read and _reading_log or _writing_log
    log[client] = gettime()
    sto_change_queue(need_read and "read" or "write")
    coroutine_yield(client, need_read and _reading or _writing)
  end
end
-- Deprecated; for backward compatibility only, since UDP doesn't block on
-- sending. Simply delegates to the socket's own sendto().
function copas.sendto(client, data, ip, port)
  local udp = client
  return udp.sendto(udp, data, ip, port)
end
-- Connects `skt` to host:port, yielding to the writing set until the
-- non-blocking connect completes or the connect timeout expires.
-- @return the result of skt:connect() (or 1, see below), or nil + error
function copas.connect(skt, host, port)
  skt:settimeout(0)
  local retrying = false
  sto_timeout(skt, "write", true)

  while true do
    local ret, err = skt:connect(host, port)

    -- non-blocking connect on Windows results in error "Operation already
    -- in progress" to indicate that it is completing the request async. So
    -- essentially it is the same as "timeout"
    local pending = (err == "timeout" or err == "Operation already in progress")

    if ret or not pending then
      _writing_log[skt] = nil
      sto_timeout()
      -- Once the async connect completes, Windows returns the error "already
      -- connected" to indicate it is done, so that error should be ignored.
      -- Except when it is the first call to connect, then it was already
      -- connected to something else and the error should be returned
      if not ret and err == "already connected" and retrying then
        return 1
      end
      return ret, err
    end

    if sto_timed_out() then
      _writing_log[skt] = nil
      return nil, sto_error(err)
    end

    retrying = true
    _writing_log[skt] = gettime()
    coroutine_yield(skt, _writing)
  end
end
-- Wraps a tcp socket in an ssl socket and configures it. If the socket was
-- already wrapped, it does nothing and returns the socket.
-- The copas user-timeouts, the autoclose registration and a custom object name
-- of the original socket are carried over to the wrapped one.
-- @param wrap_params the parameters for the ssl-context
-- @return wrapped socket, or throws an error
local function ssl_wrap(skt, wrap_params)
  if isTCP(skt) == "ssl" then
    return skt -- was already wrapped
  end
  if not wrap_params then
    error("cannot wrap socket into a secure socket (using 'ssl.wrap()') without parameters/context")
  end
  if not ssl then
    ssl = require("ssl")
  end

  -- assert, because we do not want to silently ignore this one!!
  local wrapped = assert(ssl.wrap(skt, wrap_params))
  wrapped:settimeout(0) -- non-blocking on the ssl-socket

  -- copy copas user-timeout to newly wrapped one
  copas.settimeouts(wrapped,
    user_timeouts_connect[skt],
    user_timeouts_send[skt],
    user_timeouts_receive[skt])

  -- socket registered for autoclose? then move registration to wrapped one
  local owner = _autoclose_r[skt]
  if owner then
    _autoclose_r[skt] = nil
    _autoclose[owner] = wrapped
    _autoclose_r[wrapped] = owner
  end

  -- socket had a custom name? then copy it over
  local name = object_names[skt]
  if name ~= tostring(skt) then
    object_names[wrapped] = name
  end

  return wrapped
end
-- For each luasec method we have a subtable, allows for future extension.
-- Required structure:
-- {
--   wrap = ...          -- parameter to 'wrap()'; the ssl parameter table, or the context object
--   sni = {             -- parameters to 'sni()'
--     names = string | table  -- 1st parameter
--     strict = bool           -- 2nd parameter
--   }
-- }
-- Accepted inputs: nil (no ssl), a table in the structure above, or (backward
-- compatibility) a plain ssl-params table (has `mode` or `protocol`) or an
-- ssl-context userdata, both used as the `wrap` parameter.
-- @return table with `wrap` and `sni` (each `false` when absent); reading any
-- other key throws, as a sanity check (this is security stuff)
local function normalize_sslt(sslt)
  local t = type(sslt)
  local wrap, sni = false, false

  if t == "table" then
    if sslt.mode or sslt.protocol then
      -- has the mandatory fields for the ssl-params table for handshake
      -- backward compatibility
      wrap = sslt
    else
      -- has the target definition, copy our known keys
      -- ('or false' because we do not want nils)
      wrap = sslt.wrap or false
      sni = sslt.sni or false
    end
  elseif t == "userdata" then
    -- it's an ssl-context object for the handshake
    -- backward compatibility
    wrap = sslt
  elseif t ~= "nil" then
    -- report the offending *type* (the message previously showed the value)
    error("ssl parameters; did not expect type " .. t)
  end

  return setmetatable({ wrap = wrap, sni = sni }, {
    __index = function(_, key)
      -- a bug if this happens, here as a sanity check, just being careful since
      -- this is security stuff
      error("accessing unknown 'ssl_params' table key: " .. tostring(key))
    end,
  })
end
---
-- Performs an (async) ssl handshake on a connected TCP client socket.
-- NOTE: if not ssl-wrapped already, then replace all previous socket references
-- with the returned new ssl wrapped socket.
-- Handshake failures throw an error instead of returning nil+error, as the
-- latter might silently fail in code like this;
--   copas.addserver(s1, function(skt)
--       skt = copas.wrap(skt, sparams)
--       skt:dohandshake()   --> without explicit error checking, this fails silently and
--       skt:send(body)      --> continues unencrypted
-- Only an expired copas connect-timeout returns nil + error instead.
-- @param skt Regular LuaSocket CLIENT socket object (or an already ssl-wrapped one)
-- @param wrap_params Table with ssl parameters (or ssl context), used by ssl_wrap
-- when the socket is not wrapped yet
-- @return wrapped ssl socket; or nil + "timeout" (the raw LuaSec error if the
-- coroutine opted into socket timeout errors) on timeout; throws on failure
function copas.dohandshake(skt, wrap_params)
  ssl = ssl or require("ssl")
  local nskt = ssl_wrap(skt, wrap_params)
  sto_timeout(nskt, "write", true)

  repeat
    local success, err = nskt:dohandshake()
    if success then
      sto_timeout()
      return nskt
    elseif not _isSocketTimeout[err] then
      sto_timeout()
      error("TLS/SSL handshake failed: " .. tostring(err))
    elseif sto_timed_out() then
      return nil, sto_error(err)
    end

    -- LuaSec tells which direction the handshake has to wait for
    local queue
    if err == "wantwrite" then
      sto_change_queue("write")
      queue = _writing
    elseif err == "wantread" then
      sto_change_queue("read")
      queue = _reading
    else
      -- e.g. a plain "timeout": there is no direction to wait in, so give up.
      -- Cancel the armed timeout first (like the failure branch above does),
      -- otherwise it stays registered for this coroutine after the error.
      sto_timeout()
      error("TLS/SSL handshake failed: " .. tostring(err))
    end

    coroutine_yield(nskt, queue)
  until false
end
-- Flushes a client write buffer (deprecated).
-- A no-op, kept for backward compatibility.
copas.flush = function()
  -- intentionally empty
end
-- Wraps a TCP socket to use Copas methods (send, receive, flush, settimeout, ...).
-- Wrapper objects look like `{ socket = <socket>, ssl_params = <normalized> }`.
local _skt_mt_tcp do
  local methods = {}

  function methods:send(data, from, to)
    return copas.send(self.socket, data, from, to)
  end

  -- a receive timeout of exactly 0 switches to receivepartial, which returns
  -- as soon as any data was received
  function methods:receive(pattern, prefix)
    if user_timeouts_receive[self.socket] == 0 then
      return copas.receivepartial(self.socket, pattern, prefix)
    end
    return copas.receive(self.socket, pattern, prefix)
  end

  function methods:receivepartial(pattern, prefix)
    return copas.receivepartial(self.socket, pattern, prefix)
  end

  function methods:flush()
    return copas.flush(self.socket)
  end

  function methods:settimeout(time)
    return copas.settimeout(self.socket, time)
  end

  function methods:settimeouts(connect, send, receive)
    return copas.settimeouts(self.socket, connect, send, receive)
  end

  -- TODO: socket.connect is a shortcut, and must be provided with an alternative
  -- if ssl parameters are available, it will also include a handshake
  function methods:connect(...)
    local res, err = copas.connect(self.socket, ...)
    if not res then
      return res, err
    end
    local params = self.ssl_params
    if params.sni then self:sni() end
    if params.wrap then res, err = self:dohandshake() end
    return res, err
  end

  function methods:close(...)
    return copas.close(self.socket, ...)
  end

  -- TODO: socket.bind is a shortcut, and must be provided with an alternative
  function methods:bind(...) return self.socket:bind(...) end

  -- TODO: is this DNS related? hence blocking?
  function methods:getsockname(...)
    local skt = self.socket
    local ok, ip, port, family = pcall(skt.getsockname, skt, ...)
    if not ok then
      return nil, "not implemented by LuaSec"
    end
    return ip, port, family
  end

  function methods:getstats(...) return self.socket:getstats(...) end
  function methods:setstats(...) return self.socket:setstats(...) end
  function methods:listen(...) return self.socket:listen(...) end
  function methods:accept(...) return self.socket:accept(...) end

  function methods:setoption(...)
    local skt = self.socket
    local ok, res, err = pcall(skt.setoption, skt, ...)
    if not ok then
      return nil, "not implemented by LuaSec"
    end
    return res, err
  end

  function methods:getoption(...)
    local skt = self.socket
    local ok, val, err = pcall(skt.getoption, skt, ...)
    if not ok then
      return nil, "not implemented by LuaSec"
    end
    return val, err
  end

  -- TODO: is this DNS related? hence blocking?
  function methods:getpeername(...)
    local skt = self.socket
    local ok, ip, port, family = pcall(skt.getpeername, skt, ...)
    if not ok then
      return nil, "not implemented by LuaSec"
    end
    return ip, port, family
  end

  function methods:shutdown(...) return self.socket:shutdown(...) end

  -- ssl-wraps the socket (if needed) and sets SNI; without explicit names the
  -- ones from the wrapper's ssl_params are used
  function methods:sni(names, strict)
    local params = self.ssl_params
    self.socket = ssl_wrap(self.socket, params.wrap)
    if names == nil then
      names, strict = params.sni.names, params.sni.strict
    end
    return self.socket:sni(names, strict)
  end

  function methods:dohandshake(wrap_params)
    local nskt, err = copas.dohandshake(self.socket, wrap_params or self.ssl_params.wrap)
    if nskt then
      self.socket = nskt -- replace internal socket with the newly wrapped ssl one
      return self
    end
    return nskt, err
  end

  function methods:getalpn(...)
    local skt = self.socket
    local ok, proto, err = pcall(skt.getalpn, skt, ...)
    if not ok then
      return nil, "not a tls socket"
    end
    return proto, err
  end

  function methods:getsniname(...)
    local skt = self.socket
    local ok, name, err = pcall(skt.getsniname, skt, ...)
    if not ok then
      return nil, "not a tls socket"
    end
    return name, err
  end

  _skt_mt_tcp = {
    __tostring = function(self)
      return tostring(self.socket) .. " (copas wrapped)"
    end,
    __index = methods,
  }
end
-- Wraps a UDP socket: starts as a copy of the TCP wrapper, then overrides
-- the methods that differ for UDP.
local _skt_mt_udp = { __index = {} }
do
  -- copy the metamethods (e.g. __tostring), keeping our own __index table
  for key, value in pairs(_skt_mt_tcp) do
    _skt_mt_udp[key] = _skt_mt_udp[key] or value
  end

  local m = _skt_mt_udp.__index
  for name, fn in pairs(_skt_mt_tcp.__index) do
    m[name] = fn
  end

  function m:send(...) return self.socket:send(...) end
  function m:sendto(...) return self.socket:sendto(...) end

  -- the size defaults to the maximum datagram size
  function m:receive(size)
    return copas.receive(self.socket, size or UDP_DATAGRAM_MAX)
  end
  function m:receivefrom(size)
    return copas.receivefrom(self.socket, size or UDP_DATAGRAM_MAX)
  end

  -- TODO: is this DNS related? hence blocking?
  function m:setpeername(...) return self.socket:setpeername(...) end
  function m:setsockname(...) return self.socket:setsockname(...) end

  -- do not close client, as it is also the server for udp.
  function m:close(...) return true end

  function m:settimeouts(connect, send, receive)
    return copas.settimeouts(self.socket, connect, send, receive)
  end
end
---
-- Wraps a LuaSocket socket object in an async Copas based socket object.
-- A socket that is already Copas-wrapped (TCP or UDP) is returned unchanged;
-- otherwise the socket is switched to non-blocking mode (`settimeout(0)`).
-- @param skt The socket to wrap
-- @param sslt (optional) Table with ssl parameters (TCP only), use an empty table to use ssl with defaults
-- @return wrapped socket object
function copas.wrap (skt, sslt)
  local mt = getmetatable(skt)
  if mt == _skt_mt_tcp or mt == _skt_mt_udp then
    return skt -- already wrapped
  end

  skt:settimeout(0)
  if not isTCP(skt) then
    return setmetatable({ socket = skt }, _skt_mt_udp)
  end
  return setmetatable({ socket = skt, ssl_params = normalize_sslt(sslt) }, _skt_mt_tcp)
end
--- Wraps a handler in a function that deals with wrapping the socket and doing the
-- optional ssl handshake.
-- @param handler function called as `handler(wrapped_skt, ...)`
-- @param sslparams (optional) ssl parameters, passed on to `copas.wrap`
-- @return function `(skt, ...)` to be used as server handler. If the TLS
-- handshake fails, `handler` is NOT called and the function returns `nil, err`.
function copas.handler(handler, sslparams)
  -- TODO: pass a timeout value to set, and use during handshake
  return function (skt, ...)
    skt = copas.wrap(skt, sslparams) -- this call will normalize the sslparams table
    -- UDP wrappers carry no ssl_params, so there is no TLS setup for them
    local sslp = skt.ssl_params
    if sslp then
      if sslp.sni then
        skt:sni(sslp.sni.names, sslp.sni.strict)
      end
      if sslp.wrap then
        -- dohandshake returns the wrapper on success, or nil + error
        local ok, err = skt:dohandshake(sslp.wrap)
        if not ok then
          -- never hand a socket without an established TLS session to the handler
          return nil, err
        end
      end
    end
    return handler(skt, ...)
  end
end
| -------------------------------------------------- | |
| -- Error handling | |
| -------------------------------------------------- | |
| local _errhandlers = setmetatable({}, { __mode = "k" }) -- error handler per coroutine | |
--- Builds a traceback for an error that occurred in Copas.
-- The message is suffixed with the names of the coroutine and socket involved.
-- @param msg the error message (may be nil)
-- @param co the coroutine the error occurred in (may be nil)
-- @param skt the socket involved (may be nil)
-- @return traceback string
function copas.gettraceback(msg, co, skt)
  -- names may be unset (nil); tostring keeps `%s` working on Lua 5.1 as well
  local co_str = co == nil and "nil" or tostring(copas.getthreadname(co))
  local skt_str = skt == nil and "nil" or tostring(copas.getsocketname(skt))
  local msg_str = msg == nil and "" or tostring(msg)
  if msg_str == "" then
    msg_str = ("(coroutine: %s, socket: %s)"):format(co_str, skt_str)
  else
    msg_str = ("%s (coroutine: %s, socket: %s)"):format(msg_str, co_str, skt_str)
  end

  if type(co) == "thread" then
    -- regular Copas coroutine
    return debug.traceback(co, msg_str)
  end

  -- not a coroutine, but the main thread, this happens if a timeout callback
  -- (see `copas.timeout`) causes an error (those callbacks run on the main thread).
  return debug.traceback(msg_str, 2)
end
-- Default error handler: prints the traceback of the failing coroutine.
-- Can be replaced through `copas.seterrorhandler(handler, true)`.
local _deferror = function(msg, co, skt)
  local trace = copas.gettraceback(msg, co, skt)
  print(trace)
end
--- Sets the error handler for the running coroutine, or the default handler.
-- @param err handler `function(msg, co, skt)`, or nil to clear the handler of the running coroutine
-- @param default if truthy, `err` becomes the default handler (then it must not be nil)
function copas.seterrorhandler(err, default)
  assert(err == nil or type(err) == "function", "Expected the handler to be a function, or nil")
  if not default then
    _errhandlers[coroutine_running()] = err
    return
  end
  assert(err ~= nil, "Expected the handler to be a function when setting the default")
  _deferror = err
end
copas.setErrorHandler = copas.seterrorhandler -- deprecated; old casing
--- Returns the error handler for a coroutine.
-- @param co (optional) the coroutine, defaults to the running one
-- @return the handler set for that coroutine, or else the default handler
function copas.geterrorhandler(co)
  local handler = _errhandlers[co or coroutine_running()]
  if handler then
    return handler
  end
  return _deferror
end
--- Configures how timeouts are reported for the running coroutine.
-- If `bool` is truthy, then the original socket errors will be returned in case of
-- timeouts; `timeout, wantread, wantwrite, Operation already in progress`. If falsy,
-- it will always return `timeout`.
-- @param bool enable/disable returning the original socket errors
function copas.useSocketTimeoutErrors(bool)
  local co = coroutine_running()
  if bool then
    useSocketTimeoutErrors[co] = true
  else
    useSocketTimeoutErrors[co] = false
  end
end
| ------------------------------------------------------------------------------- | |
| -- Thread handling | |
| ------------------------------------------------------------------------------- | |
-- Resumes coroutine `co` for one step, passing it `skt` and any extra arguments.
-- If it yields to one of the Copas queues (_reading, _writing, _sleeping) it is
-- parked there. Otherwise it has ended: errors (including a bare
-- coroutine.yield from user code) go to its error handler, and its
-- auto-close socket and error handler registration are cleaned up.
local function _doTick (co, skt, ...)
  if not co then
    return
  end

  -- a canceled/removed coroutine is never resumed again; just forget it
  if _canceled[co] then
    _canceled[co] = nil
    _threads[co] = nil
    return
  end

  -- `target`: the queue yielded to; `arg`: socket to wait on, or time to sleep
  local ok, arg, target = coroutine_resume(co, skt, ...)
  local yielded_to_copas = target == _reading or target == _writing or target == _sleeping
  if yielded_to_copas then
    target:insert(arg)
    target:push(arg, co)
    return
  end

  -- the coroutine is ending (it returned, or raised an error)
  if ok and coroutine_status(co) ~= "dead" then
    -- it called coroutine.yield from a non-Copas function which is unexpected
    ok, arg = false, "coroutine.yield was called without a resume first, user-code cannot yield to Copas"
  end

  if not ok then
    local handler = _errhandlers[co] or _deferror
    local handled, handler_err = pcall(handler, arg, co, skt)
    if not handled then
      print("Failed executing error handler: " .. tostring(handler_err))
    end
  end

  local owned_skt = _autoclose[co]
  if owned_skt then
    owned_skt:close()
    _autoclose[co] = nil
    _autoclose_r[owned_skt] = nil
  end
  _errhandlers[co] = nil
end
-- Accepts a pending connection on a server socket and starts a handler
-- coroutine for it. Client sockets and handlers are named after the server
-- socket, numbered per server.
local _accept do
  local client_counters = setmetatable({}, { __mode = "k" })

  function _accept(server_skt, handler)
    local client_skt = server_skt:accept()
    if not client_skt then
      return
    end

    local count = (client_counters[server_skt] or 0) + 1
    client_counters[server_skt] = count
    object_names[client_skt] = object_names[server_skt] .. ":client_" .. count

    client_skt:settimeout(0)
    -- the client inherits the timeout settings of the server socket
    copas.settimeouts(client_skt,
      user_timeouts_connect[server_skt],
      user_timeouts_send[server_skt],
      user_timeouts_receive[server_skt])

    local co = coroutine_create(handler)
    object_names[co] = object_names[server_skt] .. ":handler_" .. count
    if copas.autoclose then
      -- registered both ways, so the socket is closed when the handler ends
      _autoclose[co] = client_skt
      _autoclose_r[client_skt] = co
    end
    _doTick(co, client_skt)
  end
end
| ------------------------------------------------------------------------------- | |
| -- Adds a server/handler pair to Copas dispatcher | |
| ------------------------------------------------------------------------------- | |
do
  -- Registers a TCP server socket; incoming connections are accepted by the
  -- readable task, which looks the handler up in `_servers`.
  local function addTCPserver(server, handler, timeout, name)
    server:settimeout(0)
    if name then
      object_names[server] = name
    end
    _servers[server] = handler
    _reading:insert(server)
    if timeout then
      copas.settimeout(server, timeout)
    end
  end

  -- Registers a UDP server socket; one handler coroutine serves the socket
  -- and is started right away.
  local function addUDPserver(server, handler, timeout, name)
    server:settimeout(0)
    local handler_co = coroutine_create(handler)
    if name then
      object_names[server] = name
    end
    object_names[handler_co] = object_names[server] .. ":handler"
    _reading:insert(server)
    if timeout then
      copas.settimeout(server, timeout)
    end
    _doTick(handler_co, server)
  end

  --- Adds a server/handler pair to the Copas dispatcher.
  -- @param server a bound TCP or UDP LuaSocket socket
  -- @param handler the handler function
  -- @param timeout (optional) timeout to set on the server socket
  -- @param name (optional) name for the server socket
  function copas.addserver(server, handler, timeout, name)
    local add = isTCP(server) and addTCPserver or addUDPserver
    add(server, handler, timeout, name)
  end
end
--- Removes a server from the Copas dispatcher.
-- @param server the server socket (plain LuaSocket, or Copas-wrapped)
-- @param keep_open if truthy the socket is not closed
-- @return `true` when kept open, otherwise the results of `server:close()`
function copas.removeserver(server, keep_open)
  local skt = server
  local mt = getmetatable(server)
  if mt == _skt_mt_tcp or mt == _skt_mt_udp then
    skt = server.socket
  end

  _servers:remove(skt)
  -- addserver stores the handler directly under the socket key
  -- (`_servers[server] = handler`), so clear that entry explicitly as well;
  -- otherwise the stale handler remains referenced after removal.
  _servers[skt] = nil
  _reading:remove(skt)

  if keep_open then
    return true
  end
  return server:close()
end
| ------------------------------------------------------------------------------- | |
| -- Adds an new coroutine thread to Copas dispatcher | |
| ------------------------------------------------------------------------------- | |
--- Adds a new, named, coroutine thread to the Copas dispatcher.
-- The thread is started immediately but pauses first, so `handler` itself
-- only runs on a following step.
-- @param name name for the thread (nil for none)
-- @param handler the function to run
-- @param ... arguments passed to `handler`
-- @return the coroutine
function copas.addnamedthread(name, handler, ...)
  -- backward compatibility: the old signature was (handler, name)
  if type(name) == "function" and type(handler) == "string" then
    name, handler = handler, name
  end

  -- the scheduler always passes the socket as first argument (nil for a
  -- task/thread), so skip it before calling the user handler
  local function body(_, ...)
    copas.pause()
    return handler(...)
  end
  local thread = coroutine_create(body)

  if name then
    object_names[thread] = name
  end
  _threads[thread] = true -- register this thread so it can be removed
  _doTick(thread, nil, ...)
  return thread
end
--- Adds a new, unnamed, coroutine thread to the Copas dispatcher.
-- Same as `copas.addnamedthread(nil, handler, ...)`.
function copas.addthread(handler, ...)
  local unnamed = nil
  return copas.addnamedthread(unnamed, handler, ...)
end
--- Removes a thread from the Copas dispatcher.
-- If the coroutine is a registered thread, it is marked as canceled, so the
-- next time it would be resumed it exits instead. It is also removed from the
-- sleeping queue.
-- @param thread the coroutine to remove
function copas.removethread(thread)
  local registered = _threads[thread or 0]
  _canceled[thread] = registered
  _sleeping:cancel(thread)
end
| ------------------------------------------------------------------------------- | |
| -- Sleep/pause management functions | |
| ------------------------------------------------------------------------------- | |
-- Yields the current coroutine and wakes it after 'sleeptime' seconds (default 0).
-- If sleeptime < 0 then it sleeps until explicitly woken up using 'wakeup'.
-- TODO: deprecated, remove in next major
function copas.sleep(sleeptime)
  local delay = sleeptime or 0
  coroutine_yield(delay, _sleeping)
end
-- Yields the current coroutine and wakes it after 'sleeptime' seconds.
-- A missing, zero or negative sleeptime yields for 0 seconds.
function copas.pause(sleeptime)
  local delay = 0
  if sleeptime and sleeptime > 0 then
    delay = sleeptime
  end
  coroutine_yield(delay, _sleeping)
end
-- Yields the current coroutine until it is explicitly woken up using 'wakeup'.
function copas.pauseforever()
  local forever = -1 -- a negative sleep time means: only 'wakeup' resumes it
  coroutine_yield(forever, _sleeping)
end
-- Wakes up the sleeping coroutine 'co'.
function copas.wakeup(co)
  local sleepers = _sleeping
  sleepers:wakeup(co)
end
| ------------------------------------------------------------------------------- | |
| -- Timeout management | |
| ------------------------------------------------------------------------------- | |
do
  local timeout_register = setmetatable({}, { __mode = "k" })
  local time_out_thread
  local timerwheel = require("timerwheel").new({
    precision = TIMEOUT_PRECISION,
    ringsize = math.floor(60*60*24/TIMEOUT_PRECISION), -- ring size 1 day
    err_handler = function(err)
      return _deferror(err, time_out_thread)
    end,
  })

  -- core thread driving the timerwheel every TIMEOUT_PRECISION seconds
  time_out_thread = copas.addnamedthread("copas_core_timer", function()
    while true do
      copas.pause(TIMEOUT_PRECISION)
      timerwheel:step()
    end
  end)

  -- get the number of timeouts running
  function copas.gettimeouts()
    return timerwheel:count()
  end

  --- Sets the timeout for the current coroutine.
  -- Any timeout previously set for the coroutine is canceled first.
  -- @param delay delay (seconds), use 0 (or math.huge) to cancel the timeout
  -- @param callback function with signature: `function(coroutine)` where coroutine is the routine that timed-out
  -- @return true
  function copas.timeout(delay, callback)
    -- validate first, so an invalid call leaves an existing timeout untouched
    -- (`delay ~= delay` catches NaN)
    if type(delay) ~= "number" or delay ~= delay or delay < 0 then
      error("timeout value must be greater than or equal to 0, got: "..tostring(delay), 2)
    end

    local co = coroutine_running()
    local existing_timer = timeout_register[co]
    if existing_timer then
      timerwheel:cancel(existing_timer)
    end

    if delay > 0 and delay ~= math.huge then
      timeout_register[co] = timerwheel:set(delay, callback, co)
    else
      -- 0 or math.huge: only cancel
      timeout_register[co] = nil
    end
    return true
  end
end
| ------------------------------------------------------------------------------- | |
| -- main tasks: manage readable and writable socket sets | |
| ------------------------------------------------------------------------------- | |
-- List of tasks executed on every Copas step, in registration order.
-- A task is an object with a required method `step()` that deals with a
-- single step for that task.
local _tasks = {} do
  function _tasks:add(tsk)
    self[#self + 1] = tsk
  end
end
-- Task handling sockets that are ready to read: on a server socket a new
-- connection is accepted, otherwise the coroutine waiting on it is resumed.
local _readable_task = {} do
  local function tick(skt)
    local server_handler = _servers[skt]
    if not server_handler then
      _reading:remove(skt)
      _doTick(_reading:pop(skt), skt)
      return
    end
    _accept(skt, server_handler)
  end

  function _readable_task:step()
    local events = self._events
    for _, ready_skt in ipairs(events) do
      tick(ready_skt)
    end
  end

  _tasks:add(_readable_task)
end
-- Task resuming the coroutines waiting on sockets that are ready to write.
local _writable_task = {} do
  function _writable_task:step()
    for _, ready_skt in ipairs(self._events) do
      _writing:remove(ready_skt)
      _doTick(_writing:pop(ready_skt), ready_skt)
    end
  end

  _tasks:add(_writable_task)
end
-- Task moving sleeping coroutines whose wake-up time has passed to the
-- resumable list.
local _sleeping_task = {} do
  function _sleeping_task:step()
    local now = gettime()
    -- Coroutines are pushed to _resumable instead of being run here. That list
    -- is replaced before it is executed (by the _resumable task, registered
    -- next), which prevents a coroutine doing pause(0) from running twice in a row.
    repeat
      local co = _sleeping:pop(now)
      if co then
        _resumable:push(co)
      end
    until not co
  end

  _tasks:add(_sleeping_task)
end
-- Task resuming the coroutines in the resumable list (e.g. those moved there
-- by the sleeping task).
local _resumable_task = {} do
  function _resumable_task:step()
    -- Take the current resume list (it gets replaced), so coroutines that
    -- become resumable while this loop runs are handled in the next Copas
    -- step instead of looping within this one.
    local batch = _resumable:clear_resumelist()
    for _, resumable_co in ipairs(batch) do
      _doTick(resumable_co)
    end
  end

  _tasks:add(_resumable_task)
end
| ------------------------------------------------------------------------------- | |
| -- Checks for reads and writes on sockets | |
| ------------------------------------------------------------------------------- | |
local _select_plain do
  local last_cleansing = 0
  local duration = function(t2, t1) return t2-t1 end

  -- Removes the sockets listed in `_closed` from the read/write sets, so
  -- select cannot hang on them. Each entry is replaced by the pair of results
  -- of the removals, to be injected as events after select.
  local function unregister_closed()
    for i, skt in ipairs(_closed) do
      _closed[i] = { _reading:remove(skt), _writing:remove(skt) }
    end
  end

  -- Appends the closed sockets to the event lists (and empties `_closed`),
  -- so the coroutines waiting on them get resumed and can error out properly.
  local function inject_closed(r_events, w_events)
    for i, pair in ipairs(_closed) do
      _closed[i] = nil
      r_events[#r_events+1] = pair[1]
      w_events[#w_events+1] = pair[2]
    end
  end

  -- Watchdog: sockets in `log` that have been waiting for longer than
  -- WATCH_DOG_TIMEOUT without being reported by select are added to `events`
  -- anyway (and removed from the log), so they get tried regardless.
  local function add_overdue(log, events, now)
    for skt, since in pairs(log) do
      if not events[skt] and duration(now, since) > WATCH_DOG_TIMEOUT then
        log[skt] = nil
        events[#events + 1] = skt
        events[skt] = #events
      end
    end
  end

  -- Waits (at most `timeout` seconds) for sockets to become readable/writable
  -- and stores the results as the `_events` of the readable/writable tasks.
  -- Returns nil, or the select error ("timeout" only if no events at all).
  _select_plain = function(timeout)
    local now = gettime()

    if _closed[1] then
      unregister_closed()
    end

    local r_events, w_events, err = socket.select(_reading, _writing, timeout)
    _readable_task._events, _writable_task._events = r_events, w_events

    if _closed[1] then
      inject_closed(r_events, w_events)
    end

    if duration(now, last_cleansing) > WATCH_DOG_TIMEOUT then
      last_cleansing = now
      add_overdue(_reading_log, r_events, now)
      add_overdue(_writing_log, w_events, now)
    end

    if err == "timeout" and #r_events + #w_events > 0 then
      return nil
    end
    return err
  end
end
| ------------------------------------------------------------------------------- | |
| -- Dispatcher loop step. | |
| -- Listen to client requests and handles them | |
| -- Returns false if no socket-data was handled, or true if there was data | |
| -- handled (or nil + error message) | |
| ------------------------------------------------------------------------------- | |
local copas_stats             -- stats of the current collection period, nil when not started
local min_ever, max_ever      -- step duration extremes over all collection periods
local _select = _select_plain -- select used by copas.step, swapped when stats are toggled

-- Instrumented version of _select() that collects stats. The duration of a
-- step is the time from the end of the previous select until this call,
-- i.e. the time spent outside of select.
local _select_instrumented = function(timeout)
  if not copas_stats then
    copas_stats = {
      duration_max = -1,
      duration_min = 999999,
      duration_tot = 0,
      steps = 0,
    }
  else
    local stats = copas_stats
    local step_duration = gettime() - stats.step_start
    stats.duration_max = math.max(stats.duration_max, step_duration)
    stats.duration_min = math.min(stats.duration_min, step_duration)
    stats.duration_tot = stats.duration_tot + step_duration
    stats.steps = stats.steps + 1
  end

  local err = _select_plain(timeout)

  local now = gettime()
  copas_stats.time_start = copas_stats.time_start or now
  copas_stats.step_start = now
  return err
end
--- Executes a single dispatcher step: waits for socket events (not longer than
-- `timeout`, nor past the next sleeping coroutine's wake-up time; not at all if
-- coroutines are ready to resume) and then runs all tasks.
-- @param timeout (optional) maximum time to wait, in seconds
-- @return `true` if events were handled, `false` on a timeout, or `nil, err`
function copas.step(timeout)
  -- Need to wake up the select call in time for the next sleeping event
  if _resumable:done() then
    timeout = math.min(_sleeping:getnext(), timeout or math.huge)
  else
    timeout = 0
  end

  local err = _select(timeout)
  for _, task in ipairs(_tasks) do
    task:step()
  end

  if not err then
    return true
  end
  if err ~= "timeout" then
    return nil, err
  end

  -- we were idle, so occasionally do a GC sweep to ensure lingering
  -- sockets are closed, and we don't accidentally block the loop from exiting
  if timeout + 0.01 > TIMEOUT_PRECISION and math.random(100) > 90 then
    collectgarbage()
  end
  return false
end
-------------------------------------------------------------------------------
-- Check whether there is something left to do.
-- Returns true if there are no sockets for read/write nor tasks scheduled
-- (which means Copas would be in an empty spin), false otherwise.
-------------------------------------------------------------------------------
function copas.finished()
  if #_reading ~= 0 or #_writing ~= 0 then
    return false
  end
  return _resumable:done() and _sleeping:done(copas.gettimeouts())
end
local _getstats do
  local _getstats_instrumented, _getstats_plain

  -- Stats getter while collection is disabled; always returns an empty table.
  -- Passing `true` enables collection: the instrumented select and getter are
  -- installed and all previously collected stats are reset.
  function _getstats_plain(enable)
    -- this function gets hit if turned off, so turn on if true
    if enable == true then
      _select = _select_instrumented
      _getstats = _getstats_instrumented
      -- reset stats
      min_ever = nil
      max_ever = nil
      copas_stats = nil
    end
    return {}
  end

  -- convert from seconds to millisecs, with microsec precision
  local function useconds(t)
    return math.floor((t * 1000000) + 0.5) / 1000
  end

  -- convert from seconds to seconds, with millisec precision
  local function mseconds(t)
    return math.floor((t * 1000) + 0.5) / 1000
  end

  -- Stats getter while collection is enabled. Passing `false` disables
  -- collection again. Otherwise returns the stats collected since the previous
  -- call and starts a new collection period (min/max "ever" span all periods).
  function _getstats_instrumented(enable)
    if enable == false then
      _select = _select_plain
      _getstats = _getstats_plain
      -- instrumentation disabled, so switch to the plain implementation
      return _getstats(enable)
    end

    -- The stats table is created on the first instrumented select, but steps are
    -- only counted from the next select on. Until then there is nothing to
    -- report: `steps` is 0 (averages would be nan/inf) and duration_min/max
    -- still hold their sentinel values (which would corrupt min_ever/max_ever).
    if (not copas_stats) or (copas_stats.steps == 0) then
      return {}
    end

    local stats = copas_stats
    copas_stats = nil

    min_ever = math.min(min_ever or 9999999, stats.duration_min)
    max_ever = math.max(max_ever or 0, stats.duration_max)
    stats.duration_min_ever = min_ever
    stats.duration_max_ever = max_ever
    stats.duration_avg = stats.duration_tot / stats.steps
    stats.step_start = nil
    stats.time_end = gettime()
    stats.time_tot = stats.time_end - stats.time_start
    stats.time_avg = stats.time_tot / stats.steps

    stats.duration_avg = useconds(stats.duration_avg)
    stats.duration_max = useconds(stats.duration_max)
    stats.duration_max_ever = useconds(stats.duration_max_ever)
    stats.duration_min = useconds(stats.duration_min)
    stats.duration_min_ever = useconds(stats.duration_min_ever)
    stats.duration_tot = useconds(stats.duration_tot)
    stats.time_avg = useconds(stats.time_avg)
    stats.time_start = mseconds(stats.time_start)
    stats.time_end = mseconds(stats.time_end)
    stats.time_tot = mseconds(stats.time_tot)
    return stats
  end

  _getstats = _getstats_plain
end
--- Returns a status table of the dispatcher.
-- @param enable_stats `true` to enable stats collection, `false` to disable it,
-- nil to leave it as is
-- @return table with the collected stats (if enabled) and current counters
function copas.status(enable_stats)
  local res = _getstats(enable_stats)
  res.running = copas.running and true or false
  res.timeout = copas.gettimeouts()
  res.timer, res.inactive = _sleeping:status()
  res.read, res.write = #_reading, #_writing
  res.active = _resumable:count()
  return res
end
-------------------------------------------------------------------------------
-- Dispatcher loop.
-- Listens to client requests and handles them until there is nothing left
-- to do (see `copas.finished`).
-------------------------------------------------------------------------------
--- Runs the dispatcher.
-- @param initializer (optional) function started as thread "copas_initializer";
-- if it is not a function, it is used as the `timeout` instead
-- @param timeout (optional) timeout passed to every `copas.step`
function copas.loop(initializer, timeout)
  if type(initializer) ~= "function" then
    timeout = initializer or timeout
  else
    copas.addnamedthread("copas_initializer", initializer)
  end

  copas.running = true
  while true do
    if copas.finished() then
      break
    end
    copas.step(timeout)
  end
  copas.running = false
end
-------------------------------------------------------------------------------
-- Naming sockets and coroutines.
-------------------------------------------------------------------------------
do
  -- Returns the LuaSocket socket inside a Copas-wrapped socket, or the
  -- argument itself if it is not Copas-wrapped.
  local function realsocket(skt)
    local mt = getmetatable(skt)
    if mt ~= _skt_mt_tcp and mt ~= _skt_mt_udp then
      return skt
    end
    return skt.socket
  end

  --- Sets the name of a socket (plain or Copas-wrapped).
  function copas.setsocketname(name, skt)
    assert(type(name) == "string", "expected arg #1 to be a string")
    local real = assert(realsocket(skt), "expected arg #2 to be a socket")
    object_names[real] = name
  end

  --- Returns the name of a socket (plain or Copas-wrapped).
  function copas.getsocketname(skt)
    local real = assert(realsocket(skt), "expected arg #1 to be a socket")
    return object_names[real]
  end
end
--- Sets the name of a coroutine.
-- @param name the name
-- @param coro (optional) the coroutine, defaults to the running one
function copas.setthreadname(name, coro)
  assert(type(name) == "string", "expected arg #1 to be a string")
  local target = coro or coroutine_running()
  assert(type(target) == "thread", "expected arg #2 to be a coroutine or nil")
  object_names[target] = name
end
--- Returns the name of a coroutine.
-- @param coro (optional) the coroutine, defaults to the running one
function copas.getthreadname(coro)
  local target = coro or coroutine_running()
  assert(type(target) == "thread", "expected arg #1 to be a coroutine or nil")
  return object_names[target]
end
| ------------------------------------------------------------------------------- | |
| -- Debug functionality. | |
| ------------------------------------------------------------------------------- | |
| do | |
copas.debug = {}
local debug_log -- function used as logger (called with the message parts as separate arguments)
local log_core  -- if truthy, the core-timer will also be logged
-- Logging replacement for `coroutine.yield`: reports which queue the running
-- coroutine yields to, then yields. The core timer is skipped unless `log_core`.
local debug_yield = function(skt, queue)
  local name = object_names[coroutine_running()]
  if not log_core and name == "copas_core_timer" then
    return coroutine.yield(skt, queue)
  end

  if queue == _sleeping then
    debug_log("yielding '", name, "' to SLEEP for ", skt," seconds")
  elseif queue == _writing then
    debug_log("yielding '", name, "' to WRITE on '", object_names[skt], "'")
  elseif queue == _reading then
    debug_log("yielding '", name, "' to READ on '", object_names[skt], "'")
  else
    debug_log("thread '", name, "' yielding to unexpected queue; ", tostring(queue), " (", type(queue), ")", debug.traceback())
  end
  return coroutine.yield(skt, queue)
end
-- Logging replacement for `coroutine.resume`: reports the coroutine being
-- resumed (and the socket, if any), then resumes it. A socket-less resume of
-- the core timer is only logged if `log_core` is set.
local debug_resume = function(coro, skt, ...)
  local name = object_names[coro]
  if skt then
    debug_log("resuming '", name, "' for socket '", object_names[skt], "'")
  elseif log_core or name ~= "copas_core_timer" then
    debug_log("resuming '", name, "'")
  end
  return coroutine.resume(coro, skt, ...)
end
-- Logging replacement for `coroutine.create`: the created coroutine logs when
-- its function returns, and passes all return values on.
local debug_create = function(f)
  local f_wrapped = function(...)
    local results = pack(f(...))
    debug_log("exiting '", object_names[coroutine_running()], "'")
    -- use the count from `pack`: results may contain nils (e.g. `nil, err`),
    -- and `#` is not reliable on tables with holes
    return unpack(results, 1, results.n)
  end
  return coroutine.create(f_wrapped)
end
debug_log = fnil

--- Enables debug output for all coroutine operations, by replacing the
-- coroutine functions used by Copas with logging versions.
-- @param logger (optional) logging function, called with the message parts as
-- separate arguments; defaults to `print`
-- @param core (optional) if truthy, the core timer coroutine is logged as well
function copas.debug.start(logger, core)
  debug_log = logger or print
  log_core = core
  coroutine_create = debug_create
  coroutine_resume = debug_resume
  coroutine_yield = debug_yield
end
--- Disables debug output for coroutine operations, restoring the standard
-- coroutine functions.
function copas.debug.stop()
  coroutine_create = coroutine.create
  coroutine_resume = coroutine.resume
  coroutine_yield = coroutine.yield
  debug_log = fnil
end
| do | |
local call_id = 0 -- incremented per proxied call; prefixes its "calling"/"results" output

-- Description table of socket functions for debug output.
-- Each socket function name has TWO entries, 'name_in' and 'name_out', each
-- an array of labels for respectively the input parameters and return values.
-- If either table has a 'callback' key, that function is called with the
-- parameters/return-values for further inspection, e.g.:
--   callback = function(...) print(debug.traceback("called from:", 4)) end
local args = {
  settimeout_in  = { "socket ", "seconds", "mode " },
  settimeout_out = { "success", "error " },
  connect_in     = { "socket ", "address", "port " },
  connect_out    = { "success", "error " },
  getfd_in       = { "socket " },
  getfd_out      = { "fd" },
  send_in        = { "socket ", "data ", "idx-start", "idx-end " },
  send_out       = { "last-idx-send ", "error ", "err-last-idx-send" },
  receive_in     = { "socket ", "pattern", "prefix " },
  receive_out    = { "received ", "error ", "partial data" },
  dirty_in       = { "socket" },
  dirty_out      = { "data in read-buffer" },
  close_in       = { "socket" },
  close_out      = { "success", "error" },
}
-- Prints `msg`, then one line per value, labelled from `args[func]` (or by
-- position if there is no label). Strings are shown truncated to 30 chars.
-- Finally calls `args[func].callback` with the values, if present.
local function print_call(func, msg, ...)
  print(msg)
  local values = pack(...)
  local labels = args[func] or {}
  for i = 1, math.max(values.n, #labels) do
    local value = values[i]
    local label = labels[i] or i
    local kind = type(value)
    if kind == "string" then
      local shown = value:sub(1, 30)
      if shown ~= value then
        shown = shown .. "(...truncated)"
      end
      print("\t" .. label .. ": '" .. tostring(shown) .. "' (" .. kind .. " #" .. #value .. ")")
    else
      print("\t" .. label .. ": '" .. tostring(value) .. "' (" .. kind .. ")")
    end
  end
  local callback = labels.callback
  if callback then
    callback(...)
  end
end
-- Metatable for debug socket proxies: every function looked up on the proxy
-- is wrapped, so its arguments and results are printed using `print_call`.
local debug_mt = {
  __index = function(self, key)
    local value = self.__original_socket[key]
    if type(value) ~= "function" then
      return value
    end
    return function(self2, ...)
      local my_id = call_id + 1
      call_id = my_id
      local results

      if self2 ~= self then
        -- not called as a method on this proxy (e.g. `f(other_skt, ...)`):
        -- pass all arguments on unchanged, including the first one
        print_call(tostring(key).."_in", my_id .. "-calling '"..tostring(key) .. "' with; ", self2, ...)
        results = pack(value(self2, ...))
      else
        -- method call on the proxy: substitute the real socket for the proxy
        print_call(tostring(key).."_in", my_id .. "-calling '" .. tostring(key) .. "' with; ", self.__original_socket, ...)
        results = pack(value(self.__original_socket, ...))
      end

      -- results may contain nils (e.g. `nil, "timeout", partial`); use the packed count
      print_call(tostring(key).."_out", my_id .. "-results '"..tostring(key) .. "' returned; ", unpack(results, 1, results.n))
      return unpack(results, 1, results.n)
    end
  end,
  __tostring = function(self)
    return tostring(self.__original_socket)
  end
}
--- Wraps a socket (copas or luasocket) in a debug version printing all calls
-- and their parameters/return values. Extremely noisy!
-- NOTE: only for plain sockets, will not support TLS
-- @param original_skt the socket to wrap
-- @return the wrapped socket; for a Copas socket this is the same object,
-- with its inner luasocket socket wrapped
function copas.debug.socket(original_skt)
  local mt = getmetatable(original_skt)
  if mt ~= _skt_mt_tcp and mt ~= _skt_mt_udp then
    return setmetatable({ __original_socket = original_skt }, debug_mt)
  end
  -- already wrapped as Copas socket, so recurse with the original luasocket one
  original_skt.socket = copas.debug.socket(original_skt.socket)
  return original_skt
end
| end | |
| end | |
| return copas |