diff --git a/doc/interfaces/connection.md b/doc/interfaces/connection.md index 433a34bb..05b884ad 100644 --- a/doc/interfaces/connection.md +++ b/doc/interfaces/connection.md @@ -31,9 +31,14 @@ The HTTP version number of the connection as a number. Completes the connection to the remote server using the address specified, HTTP version and any options specified in the `connection.new` constructor. The `connect` function will yield until the connection attempt finishes (success or failure) or until `timeout` is exceeded. Connecting may include DNS lookups, TLS negotiation and HTTP2 settings exchange. Returns `true` on success. On error, returns `nil`, an error message and an error number. +### `connection:starttls(ctx, timeout)` {#connection:starttls} + +Starts a TLS (Transport Layer Security) negotiation with the remote server. `ctx` should be a luaossl SSL context or a luasec SSL context. Returns `true` on success. On error, returns `nil` an error message and an error number. + + ### `connection:checktls()` {#connection:checktls} -Checks the socket for a valid Transport Layer Security connection. Returns the luaossl ssl object if the connection is secured. Returns `nil` and an error message if there is no active TLS session. Please see the [luaossl website](http://25thandclement.com/~william/projects/luaossl.html) for more information about the ssl object. +Checks the socket for a valid TLS (Transport Layer Security) connection. Returns the luaossl ssl object if the connection is secured. Returns `nil` and an error message if there is no active TLS session. Please see the [luaossl website](http://25thandclement.com/~william/projects/luaossl.html) for more information about the ssl object. ### `connection:localname()` {#connection:localname} diff --git a/http/client.lua b/http/client.lua index 5000d675..386714ca 100644 --- a/http/client.lua +++ b/http/client.lua @@ -2,9 +2,8 @@ local ca = require "cqueues.auxlib" local cs = require "cqueues.socket" local http_tls = require "http.tls" local connection_common = require "http.connection_common" -local onerror = connection_common.onerror -local new_h1_connection = require "http.h1_connection".new -local new_h2_connection = require "http.h2_connection".new +local h1_connection = require "http.h1_connection" +local h2_connection = require "http.h2_connection" local openssl_ssl = require "openssl.ssl" local openssl_ctx = require "openssl.ssl.context" local openssl_verify_param = require "openssl.x509.verify_param" @@ -17,8 +16,10 @@ local IPaddress = (IPv4address + IPv6addrz) * EOF -- Create a shared 'default' TLS context local default_ctx = http_tls.new_client_context() -local function negotiate(s, options, timeout) - s:onerror(onerror) +local function negotiate(self, options, timeout) + if cs.type(self) then -- passing cqueues socket + self = connection_common.new(self, "client") + end local tls = options.tls local version = options.version if tls then @@ -56,13 +57,13 @@ local function negotiate(s, options, timeout) old:inherit(params) ssl:setParam(old) end - local ok, err, errno = s:starttls(ssl, timeout) + local ok, err, errno = self:starttls(ssl, timeout) if not ok then return nil, err, errno end end if version == nil then - local ssl = s:checktls() + local ssl = self:checktls() if ssl then if http_tls.has_alpn and ssl:getAlpnSelected() == "h2" then version = 2 @@ -75,9 +76,9 @@ local function negotiate(s, options, timeout) end end if version < 2 then - return new_h1_connection(s, "client", version) + return h1_connection.new_from_common(self, version) elseif version == 2 then - return new_h2_connection(s, "client", options.h2_settings) + return h2_connection.new_from_common(self, options.h2_settings) else error("Unknown HTTP version: " .. tostring(version)) end diff --git a/http/connection_common.lua b/http/connection_common.lua index 432acfe4..4c6a26f4 100644 --- a/http/connection_common.lua +++ b/http/connection_common.lua @@ -3,6 +3,10 @@ local ca = require "cqueues.auxlib" local ce = require "cqueues.errno" local connection_methods = {} +local connection_mt = { + __name = "http.connection_common"; + __index = connection_methods; +} local function onerror(socket, op, why, lvl) -- luacheck: ignore 212 local err = string.format("%s: %s", op, ce.strerror(why)) @@ -25,6 +29,25 @@ local function onerror(socket, op, why, lvl) -- luacheck: ignore 212 return err, why end +-- assumes ownership of the socket +local function new_connection(socket, conn_type) + assert(socket, "must provide a socket") + if conn_type ~= "client" and conn_type ~= "server" then + error('invalid connection type. must be "client" or "server"') + end + local self = setmetatable({ + socket = socket; + type = conn_type; + version = nil; + -- A function that will be called if the connection becomes idle + onidle_ = nil; + }, connection_mt) + socket:setvbuf("full", math.huge) -- 'infinite' buffering; no write locks needed + socket:setmode("b", "bf") + socket:onerror(onerror) + return self +end + function connection_methods:pollfd() if self.socket == nil then return nil @@ -68,6 +91,17 @@ function connection_methods:connect(timeout) return true end +function connection_methods:starttls(ctx, timeout) + if self.socket == nil then + return nil + end + local ok, err, errno = self.socket:starttls(ctx, timeout) + if not ok then + return nil, err, errno + end + return true +end + function connection_methods:checktls() if self.socket == nil then return nil @@ -106,5 +140,7 @@ end return { onerror = onerror; + new = new_connection; methods = connection_methods; + mt = connection_mt; } diff --git a/http/h1_connection.lua b/http/h1_connection.lua index 753e31b4..4c2a4005 100644 --- a/http/h1_connection.lua +++ b/http/h1_connection.lua @@ -24,36 +24,27 @@ function connection_mt:__tostring() self.type, self.version) end --- assumes ownership of the socket -local function new_connection(socket, conn_type, version) - assert(socket, "must provide a socket") - if conn_type ~= "client" and conn_type ~= "server" then - error('invalid connection type. must be "client" or "server"') - end +local function new_from_common(self, version) assert(version == 1 or version == 1.1, "unsupported version") - local self = setmetatable({ - socket = socket; - type = conn_type; - version = version; + self.version = version - -- for server: streams waiting to go out - -- for client: streams waiting for a response - pipeline = new_fifo(); - -- pipeline condition is stored in stream itself + -- for server: streams waiting to go out + -- for client: streams waiting for a response + self.pipeline = new_fifo() + -- pipeline condition is stored in stream itself - -- for server: held while request being read - -- for client: held while writing request - req_locked = nil; - -- signaled when unlocked - req_cond = cc.new(); + -- for server: held while request being read + -- for client: held while writing request + -- self.req_locked = nil + -- signaled when unlocked + self.req_cond = cc.new() - -- A function that will be called if the connection becomes idle - onidle_ = nil; - }, connection_mt) - socket:setvbuf("full", math.huge) -- 'infinite' buffering; no write locks needed - socket:setmode("b", "bf") - socket:onerror(onerror) - return self + return setmetatable(self, connection_mt) +end + +local function new_connection(socket, conn_type, version) + local self = connection_common.new(socket, conn_type) + return new_from_common(self, version) end function connection_methods:clearerr(...) @@ -410,6 +401,7 @@ function connection_methods:write_body_plain(body, timeout) end return { + new_from_common = new_from_common; new = new_connection; methods = connection_methods; mt = connection_mt; diff --git a/http/h2_connection.lua b/http/h2_connection.lua index 0eeb1474..511bc956 100644 --- a/http/h2_connection.lua +++ b/http/h2_connection.lua @@ -94,12 +94,8 @@ local function socket_has_preface(socket, unget, timeout) return is_h2 end -local function new_connection(socket, conn_type, settings) - if conn_type ~= "client" and conn_type ~= "server" then - error('invalid connection type. must be "client" or "server"') - end - - local ssl = socket:checktls() +local function new_from_common(self, settings) + local ssl = self.socket:checktls() if ssl then local cipher = ssl:getCipherInfo() if h2_banned_ciphers[cipher.name] then @@ -107,61 +103,42 @@ local function new_connection(socket, conn_type, settings) end end - local self = setmetatable({ - socket = socket; - type = conn_type; - version = 2; -- for compat with h1_connection - - streams = setmetatable({}, {__mode="kv"}); - n_active_streams = 0; - onidle_ = nil; - stream0 = nil; -- store separately with a strong reference - - has_confirmed_preface = false; - has_first_settings = false; - had_eagain = false; - - -- For continuations - need_continuation = nil; -- stream - promised_stream = nil; -- stream - recv_headers_buffer = nil; - recv_headers_buffer_pos = nil; - recv_headers_buffer_pad_len = nil; - recv_headers_buffer_items = nil; - recv_headers_buffer_length = nil; - - highest_odd_stream = -1; - highest_odd_non_priority_stream = -1; - highest_even_stream = -2; - highest_even_non_priority_stream = -2; - send_goaway_lowest = nil; - recv_goaway_lowest = nil; - recv_goaway = cc.new(); - new_streams = new_fifo(); - new_streams_cond = cc.new(); - peer_settings = {}; - peer_settings_cond = cc.new(); -- signaled when the peer has changed their settings - acked_settings = {}; - send_settings = {n = 0}; - send_settings_ack_cond = cc.new(); -- for when server ACKs our settings - send_settings_acked = 0; - peer_flow_credits = 65535; -- 5.2.1 - peer_flow_credits_increase = cc.new(); - encoding_context = nil; - decoding_context = nil; - pongs = {}; -- pending pings we've sent. keyed by opaque 8 byte payload - }, connection_mt) - self:new_stream(0) + self.version = 2 -- for compat with h1_connection + self.streams = setmetatable({}, {__mode="kv"}) + self.n_active_streams = 0 + self.has_confirmed_preface = false + self.has_first_settings = false + self.had_eagain = false + self.highest_odd_stream = -1 + self.highest_odd_non_priority_stream = -1 + self.highest_even_stream = -2 + self.highest_even_non_priority_stream = -2 + -- self.send_goaway_lowest = nil + -- self.recv_goaway_lowest = nil + self.recv_goaway = cc.new() + self.new_streams = new_fifo() + self.new_streams_cond = cc.new() + self.peer_settings = {} merge_settings(self.peer_settings, default_settings) + self.peer_settings_cond = cc.new() -- signaled when the peer has changed their settings + self.acked_settings = {} merge_settings(self.acked_settings, default_settings) + self.send_settings = {n = 0} + self.send_settings_ack_cond = cc.new() -- for when server ACKs our settings + self.send_settings_acked = 0 + self.peer_flow_credits = 65535 -- 5.2.1 + self.peer_flow_credits_increase = cc.new() self.encoding_context = hpack.new(default_settings[known_settings.HEADER_TABLE_SIZE]) self.decoding_context = hpack.new(default_settings[known_settings.HEADER_TABLE_SIZE]) + self.pongs = {} -- pending pings we've sent. keyed by opaque 8 byte payload + + setmetatable(self, connection_mt) - socket:setvbuf("full", math.huge) -- 'infinite' buffering; no write locks needed - socket:setmode("b", "bna") -- writes that don't explicitly buffer will now flush the buffer. autoflush on - socket:onerror(onerror) + self:new_stream(0) -- sets self.stream0 as a strong reference + + self.socket:setmode("b", "bna") -- writes that don't explicitly buffer will now flush the buffer. autoflush on if self.type == "client" then - assert(socket:xwrite(preface, "f", 0)) + assert(self.socket:xwrite(preface, "f", 0)) end assert(self.stream0:write_settings_frame(false, settings or {}, 0, "f")) -- note that the buffer is *not* flushed right now @@ -169,6 +146,11 @@ local function new_connection(socket, conn_type, settings) return self end +local function new_connection(socket, conn_type, settings) + local self = connection_common.new(socket, conn_type) + return new_from_common(self, settings) +end + function connection_methods:timeout() if not self.had_eagain then return 0 @@ -483,6 +465,7 @@ end return { preface = preface; socket_has_preface = socket_has_preface; + new_from_common = new_from_common; new = new_connection; methods = connection_methods; mt = connection_mt;