Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WiP: Introduce SNI and randomized load-balancing #25

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _tags
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ true : warn(+A-4-6-7-9-40-42-44-48)

<tlstunnel.{ml,byte,native}>: package(tls), package(x509), package(nocrypto), \
package(lwt), package(cmdliner), package(sexplib), \
package(lwt.unix), package(tls.lwt)
package(lwt.unix), package(tls.lwt), package(wildcard)
206 changes: 169 additions & 37 deletions tlstunnel.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

open Lwt.Infix

module SNITree = Wildcard.Labeltree(struct type value = (int ref * Lwt_unix.sockaddr array) end)

module Log = struct
let inet_to_string = function
| Lwt_unix.ADDR_INET (x, p) -> Unix.string_of_inet_addr x ^ ":" ^ string_of_int p
Expand All @@ -21,11 +23,9 @@ module Log = struct
let source = inet_to_string addr in
log_raw out (source ^ ": " ^ event)

let log_initial out back event front =
let listen = inet_to_string front
and forward = inet_to_string back
in
log_raw out (event ^ listen ^ ", forwarding to " ^ forward)
let log_initial out event front =
let listen = inet_to_string front in
log_raw out (event ^ listen)
end

module Stats = struct
Expand Down Expand Up @@ -108,9 +108,22 @@ module Haproxy1 = struct
header ^ "\r\n"
end

let server_config cert priv_key =
X509_lwt.private_of_pems ~cert ~priv_key >|= fun cert ->
Tls.Config.server ~certificates:(`Single cert) ()
let server_config ~sni certs priv_keys =
let combined = try List.combine certs priv_keys
with Invalid_argument _ ->
invalid_arg "You need to specify a --key for each --cert"
(* TODO how do we handle the case when certs have self-contained keys? currently this patch is breaking backwards compatibility in an undesirable way. one option might be to simply bruteforce the combination of (cert, key) and see which ones work ? *)
in
Lwt_list.map_p
(fun (cert, priv_key) -> X509_lwt.private_of_pems ~cert ~priv_key)
combined
>|= fun certchains ->
let default = List.hd certchains in
match sni with
| true ->
Tls.Config.server ~certificates:(`Multiple_default (default , certchains)) ()
| false ->
Tls.Config.server ~certificates:(`Single default) ()

let init_socket log_raw frontend =
Unix.handle_unix_error (fun () ->
Expand Down Expand Up @@ -160,16 +173,22 @@ let rec read_write debug log closing close cnt ic oc =
| Stop -> Lwt.return_unit
| Continue -> read_write debug log closing close cnt ic oc

type hostmapping =
{ hostnames : string list
; backends : Lwt_unix.sockaddr list
}

let tls_info t =
let v, c =
let v, c, n =
match Tls_lwt.Unix.epoch t with
| `Ok data -> (data.Tls.Core.protocol_version, data.Tls.Core.ciphersuite)
| `Ok data -> (data.Tls.Core.protocol_version, data.Tls.Core.ciphersuite, data.Tls.Core.own_name)
| `Error -> assert false
in
let version = Sexplib.Sexp.to_string_hum (Tls.Core.sexp_of_tls_version v)
and cipher = Sexplib.Sexp.to_string_hum (Tls.Ciphersuite.sexp_of_ciphersuite c)
and sni_info = match n with None -> "" | Some sni -> ", SNI: " ^ sni (* TODO sanitize this string to prevent special characters in log file *)
in
version ^ ", " ^ cipher
version ^ sni_info ^ ", " ^ cipher

let safe_close closing tls fd () =
closing := true ;
Expand All @@ -181,18 +200,43 @@ let safe_close closing tls fd () =
| None -> Lwt.return_unit) >>= fun () ->
safely Lwt_unix.close fd

let worker config backend log s haproxy1 logfds debug trace () =
let worker config log s haproxy1 logfds debug trace sni snimap () =
let closing = ref false in
Lwt.catch (fun () ->
Tls_lwt.Unix.server_of_fd config ?trace s >>= fun t ->
let ic, oc = Tls_lwt.of_t t in
log ("connection established (" ^ (tls_info t) ^ ")") ;
let stats = Stats.new_stats () in

let fd = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in
if logfds then Fd_logger.add_fd fd ;
let close = safe_close closing (Some t) fd in

let backend =
begin match
begin match sni, Tls_lwt.Unix.epoch t with
| true , `Error -> failwith "Error reading client TLS epoch"
| false, _ -> SNITree.match_string ~wildcard:true snimap "*" (* case: we do not handle SNI (no --sni) *)
| true, `Ok data ->
begin match data.Tls.Core.own_name with
| Some sni_name -> (* handle SNI: *)
SNITree.match_string snimap sni_name
| None -> (* we handle SNI (--sni), but client didn't ask: *)
SNITree.match_string ~wildcard:true snimap "*"
end
end
with
| Some (round_robin_counter , backend_arr) ->
round_robin_counter := succ !round_robin_counter mod Array.(length backend_arr)
; backend_arr.(!round_robin_counter)
| None
-> failwith ("No backend configured for " ^ (tls_info t)
^ " - you can do this with: --map '* = host:port'")
| exception Not_found
-> failwith ("No backend configured for " ^ (tls_info t)
^ " - you can do this with: --map '* = host:port'")
end
in
log ("client connection established (" ^ (tls_info t)
^ ", forwarding to " ^ Log.(inet_to_string backend) ^ ")") ;
Lwt.catch (fun () ->
Lwt_unix.connect fd backend >>= fun () ->
let pic = Lwt_io.of_fd ~close ~mode:Lwt_io.Input fd
Expand All @@ -213,7 +257,7 @@ let worker config backend log s haproxy1 logfds debug trace () =
(function
| Unix.Unix_error (e, f, _) ->
let msg = Unix.error_message e in
log ("backend refused connection: " ^ msg ^ " while calling " ^ f) ;
log ("backend " ^ Log.(inet_to_string backend) ^ " refused connection: " ^ msg ^ " while calling " ^ f) ;
close ()
| exn ->
close () >|= fun () ->
Expand All @@ -234,13 +278,13 @@ let init out =
Lwt.async_exception_hook := (fun exn ->
Printf.fprintf out "async error %s\n%!" (Printexc.to_string exn))

let accept_loop s log_raw log_conn tls_config backend haproxy1 logfds debug trace =
let accept_loop s log_raw log_conn tls_config haproxy1 logfds debug trace sni (snimap : SNITree.t) =
let rec loop () =
Lwt.catch (fun () ->
Lwt_unix.accept s >>= fun (client_socket, addr) ->
(* log_conn addr "accepted incoming connection" ; *)
if logfds then Fd_logger.add_fd client_socket ;
Lwt.async (worker tls_config backend (log_conn addr) client_socket haproxy1 logfds debug trace) ;
Lwt.async (worker tls_config (log_conn addr) client_socket haproxy1 logfds debug trace sni snimap) ;
loop ())
(function
| Unix.Unix_error (e, f, _) ->
Expand All @@ -253,17 +297,15 @@ let accept_loop s log_raw log_conn tls_config backend haproxy1 logfds debug trac
in
loop ()

let serve (fip, fport) (bip, bport) certificate privkey haproxy1 logfd logfds debug =
let serve (fip, fport) certificates privkeys haproxy1 logfd logfds debug sni snimap =
let logchan = match logfd with
| Some fd -> Some (Unix.out_channel_of_descr fd)
| None -> None
in
init logchan ;
let frontend = Lwt_unix.ADDR_INET (fip, fport)
and backend = Lwt_unix.ADDR_INET (bip, bport)
in
server_config certificate privkey >>= fun tls_config ->
let server_socket = init_socket (Log.log_initial logchan backend) frontend in
let frontend = Lwt_unix.ADDR_INET (fip, fport) in
server_config ~sni certificates privkeys >>= fun tls_config ->
let server_socket = init_socket (Log.log_initial logchan) frontend in
let raw_log = Log.log_raw logchan in
if logfds then ignore (Fd_logger.start raw_log ()) ;
let trace =
Expand All @@ -277,22 +319,47 @@ let serve (fip, fport) (bip, bport) certificate privkey haproxy1 logfd logfds de
None
in
(* drop privileges here! *)
accept_loop server_socket raw_log (Log.log logchan) tls_config backend haproxy1 logfds debug trace
accept_loop server_socket raw_log (Log.log logchan) tls_config haproxy1 logfds debug trace sni snimap

let run_server frontend backend certificate privkey haproxy1 log quiet logfds debug =
let run_server frontend backend (certificates : string list) privkeys haproxy1 log quiet logfds debug sni hostmap =
Sys.(set_signal sigpipe Signal_ignore) ;
let logfd = match quiet, log with
| true, None -> None
| false, None -> Some Unix.stdout
| false, Some x -> Some (Unix.openfile x [Unix.O_WRONLY ; Unix.O_APPEND; Unix.O_CREAT] 0o640)
| true, Some _ -> invalid_arg "cannot specify logfile and quiet"
in
let c, p = match certificate, privkey with
| Some c, Some p -> (c, p)
| Some c, None -> (c, c)
| None, _ -> invalid_arg "missing certificate file"
in
Lwt_main.run (serve frontend backend c p haproxy1 logfd logfds debug)
if 0 = List.length certificates then
invalid_arg("You must specify a server certificate using --cert") else
if not sni && 1 < List.length certificates then
invalid_arg("You must set --sni for tlstunnel to be able to handle multiple certificates")
else
let snimap =
let populated =
SNITree.of_list_with_val @@
List.(fold_left
(fun a o ->
let data = Some (ref 0 , Array.of_list o.backends) in
map (fun hs -> hs , data) o.hostnames
|> append a
)
[]
hostmap
)
(* TODO should check that the passed certificates correspond to the --map mappings *)
in
match populated with
| `Component_error _ -> invalid_arg "fuck TODO"
| `Ok r when r = SNITree.empty ->
begin match backend with (h,p)
(* if --sni is passed, but no --map, the user probably wants to use --backend for all: *)
-> begin match "*" |> SNITree.add_with_val SNITree.empty @@ Some (ref 0, Array.of_list [Lwt_unix.ADDR_INET (h,p)] ) with
| `Component_error _ -> assert false (* TODO *)
| `Ok ok -> ok end
end
| `Ok ok -> ok
in
Lwt_main.run (serve frontend certificates privkeys haproxy1 logfd logfds debug sni snimap)

open Cmdliner

Expand Down Expand Up @@ -325,6 +392,58 @@ let host_port default : (Unix.inet_addr * int) Arg.converter =
in
parse, fun ppf (h, p) -> Format.fprintf ppf "%s:%d" (Unix.string_of_inet_addr h) p

let hostmap_parse : hostmapping Arg.converter =
let parse s : [`Error of string | `Ok of 'a]=
let split_by_char c s =
let rec _split_by_char c acc i s : string list =
match String.index_from s i c with
| next ->
let acc = String.(sub s i (next -i)) :: acc in
let next = succ next in
if next = String.length s
then List.rev acc
else _split_by_char c acc next s
| exception Not_found ->
List.rev (String.(sub s i (length s - i)) :: acc)
in _split_by_char c [] 0 s
|> List.filter (function "" -> false | _ -> true) (* remove empty elements*)
in
let parse_and_resolve_host acc (hostpair : string) =
begin match acc with
| `Error _ -> acc
| `Ok (acc) ->
begin match split_by_char ':' hostpair with
| [hostname ; port] ->
begin try
let s, p = resolve hostname , int_of_string port in
`Ok ((Lwt_unix.ADDR_INET (s,p)) :: acc)
with
(* catch int_of_string: *)
| Failure s -> `Error (s ^ ": " ^ hostname ^ ": " ^ port)
| Not_found -> `Error ("unable to resolve hostname: " ^ hostname)
end
| _ -> `Error ("Unable to find host:port in '" ^ hostpair ^ "'")
end
end
in
let parts = split_by_char '=' s in
if List.length parts <> 2 then
`Error "hostmap must contain exactly one '=' character"
else
let backends = split_by_char ' ' List.(tl parts |> hd) in
let () = List.iter (fun s -> Printf.printf "-%s> " s) backends in
let backends = List.fold_left parse_and_resolve_host (`Ok []) backends in
let hostnames = `Ok (split_by_char ' ' List.(hd parts)) in
match hostnames, backends with
| `Error , _ (* TODO should validate hostnames? *)
| `Ok [] , _ -> `Error "No hosts specified"
| _ , `Error e -> `Error ("Error parsing backends in hostmap: " ^ e)
| _ , `Ok [] -> `Error "No backends found in hostmap"
| `Ok hostnames , `Ok backends -> `Ok {hostnames ; backends}
in
parse
, fun ppf (_) -> Format.fprintf ppf "I'M A HOSTMAP"

let backend =
let default = Unix.inet_addr_loopback in
let hp = host_port default in
Expand All @@ -339,12 +458,22 @@ let frontend =
~docv:"frontend"
~doc:"The hostname and port to listen on for incoming connections (default is [*]:4433")

let certificate =
Arg.(value & opt (some string) None & info ["cert"] ~docv:"FILE"
let certificates =
Arg.(value & opt_all string [] & info ["cert"] ~docv:"FILE"
~doc:"The full path to PEM encoded certificate chain FILE (may also include the private key)")

let privkey =
Arg.(value & opt (some string) None & info ["key"] ~docv:"FILE"
let hostmap =
Arg.(value & opt_all hostmap_parse []
& info ["map"] ~docv:"MAPPINGS"
~doc:("Specify the mappings from hostnames to backends (host:port) to connect to."
^"\nThe format used looks like {[SNI-hosts] [=] [list of backends HOST:PORT]}"
^"\n--map 'docs.mirage.io docs.ocaml.org = docs.local:80'"
^"\n--map 'www.mirage.io www.ocaml.org = mirror1.local:80 mirror2.local:80'"
)
)

let privkeys =
Arg.(value & opt_all string [] & info ["key"] ~docv:"FILE"
~doc:"The full path to PEM encoded unencrypted private key in FILE (defaults to certificate_chain)")

let haproxy1 =
Expand All @@ -364,6 +493,9 @@ let quiet =
Arg.(value & flag & info ["q"; "quiet"]
~doc:"Be quiet, no logging of accesses.")

let sni =
Arg.(value & flag & info ["sni"] ~doc: "Enable Server Name Indication (SNI)")

let cmd =
let doc = "Proxy TLS connections to a standard TCP service" in
let man = [
Expand All @@ -374,7 +506,7 @@ let cmd =
`S "SEE ALSO" ;
`P "$(b,stunnel)(1), $(b,stud)(1)" ]
in
Term.(pure run_server $ frontend $ backend $ certificate $ privkey $ haproxy1 $ log $ quiet $ logfds $ debug),
Term.(pure run_server $ frontend $ backend $ certificates $ privkeys $ haproxy1 $ log $ quiet $ logfds $ debug $ sni $ hostmap),
Term.info "tlstunnel" ~version:"0.1.3" ~doc ~man

let () =
Expand Down