Skip to content

Commit

Permalink
Directly use Happy_eyeballs_lwt instead of a copy of it (#346)
Browse files Browse the repository at this point in the history
* Directly use Happy_eyeballs_lwt instead of a copy of it

This commit delete a dependency cycle between happy-eyeballs,
happy-eyeballs-lwt, dns and dns-client-lwt. The basic happy-eyeballs-lwt
implementation is not able yet to resolve domain-name but the user can
inject a getaddrinfo which may come from dns-client-lwt. The idea is:
1) create a happy-eyeballs-lwt instance
2) create a dns-client-lwt instance
3) inject Dns_client_lwt.getaddrinfo into our happy-eyeballs-lwt
   instance

This patch delete a duplicate code about happy-eyeballs implementations.

* Add happy-eyeballs with the .dev version to satisfy the OPAM solver

* Remove useless he_timer_interval into dns-client-lwt

* Update tools with the new intf. of dns-client-lwt

* Use the happy-eyeballs impl. which use Unix.getaddrinfo as the default
resolver and provide a new function, [create_happy_eyeballs], which does
the injection of the [ocaml-dns] DNS resolver.

* more

* adapt opam

* minor

* adjust docstring

* dns-client-mirage: use happy-eyeballs-mirage

* add he deps

* minor nit

* [WIP] a possible solution for happy-eyeballs/ocaml-dns and mirage

* adapt to intentioned use case

* dns-client-mirage: use connect_timeout for connecting to the remote

* Remove the optional argument of happy-eyeballs for dns-client-mirage

* remove unneeded types and values

* adapt dns-client-lwt to the dns-client-mirage interface

---------

Co-authored-by: Hannes Mehnert <hannes@mehnert.org>
  • Loading branch information
dinosaure and hannesm committed May 29, 2024
1 parent 1a80bd4 commit 2238017
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 298 deletions.
6 changes: 4 additions & 2 deletions app/odns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ let pp_nameserver ppf = function
((Tls.Config.of_client tls_cfg).Tls.Config.peer_name)

let do_a nameservers domains () =
let t = Dns_client_lwt.create ?nameservers () in
let happy_eyeballs = Happy_eyeballs_lwt.create () in
let t = Dns_client_lwt.create ?nameservers happy_eyeballs in
let (_, ns) = Dns_client_lwt.nameservers t in
Logs.info (fun m -> m "querying NS %a for A records of %a"
pp_nameserver (List.hd ns) Fmt.(list ~sep:(any ", ") Domain_name.pp) domains);
Expand Down Expand Up @@ -65,7 +66,8 @@ let for_all_domains nameservers ~domains typ f =
(* [for_all_domains] is a utility function that lets us avoid duplicating
this block of code in all the subcommands.
We leave {!do_a} simple to provide a more readable example. *)
let t = Dns_client_lwt.create ?nameservers () in
let happy_eyeballs = Happy_eyeballs_lwt.create () in
let t = Dns_client_lwt.create ?nameservers happy_eyeballs in
let _, ns = Dns_client_lwt.nameservers t in
Logs.info (fun m -> m "NS: %a" pp_nameserver (List.hd ns));
let open Lwt in
Expand Down
3 changes: 2 additions & 1 deletion app/odnssec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ let jump () hostname typ ns =
| None -> None
| Some ip -> Some (`Tcp, [ `Plaintext (ip, 53) ])
in
let t = Dns_client_lwt.create ?nameservers ~edns:(`Manual edns) () in
let happy_eyeballs = Happy_eyeballs_lwt.create () in
let t = Dns_client_lwt.create ?nameservers ~edns:(`Manual edns) happy_eyeballs in
let (_, ns) = Dns_client_lwt.nameservers t in
Logs.info (fun m -> m "querying NS %a for A records of %a"
pp_nameserver (List.hd ns) Domain_name.pp hostname);
Expand Down
2 changes: 2 additions & 0 deletions client/dns_client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ struct
edns : [ `None | `Auto | `Manual of Dns.Edns.t ] ;
}

let transport { transport ; _ } = transport

(* TODO eventually use Auto, and retry without on FormErr *)
let create ?(cache_size = 32) ?(edns = `None) ?nameservers ?(timeout = Duration.of_sec 5) stack =
{ cache = Dns_cache.empty cache_size ;
Expand Down
4 changes: 4 additions & 0 deletions client/dns_client.mli
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ module Make : functor (T : S) ->
sig

type t
(** The abstract type of a DNS client. *)

val transport : t -> T.t
(** [transport t] is the transport of [t]. *)

val create : ?cache_size:int ->
?edns:[ `None | `Auto | `Manual of Dns.Edns.t ] ->
Expand Down
3 changes: 2 additions & 1 deletion dns-client-lwt.opam
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ depends: [
"lwt" {>= "4.2.1"}
"mtime" {>= "1.2.0"}
"mirage-crypto-rng-lwt" {>= "0.11.0"}
"happy-eyeballs" {>= "0.6.0"}
"happy-eyeballs-lwt" {>= "1.1.0"}
"happy-eyeballs" {>= "1.0.0"}
"tls-lwt" {>= "0.16.0"}
"ca-certs"
]
Expand Down
3 changes: 2 additions & 1 deletion dns-client-mirage.opam
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ depends: [
"mirage-random" {>= "2.0.0"}
"mirage-time" {>= "2.0.0"}
"mirage-clock" {>= "3.0.0"}
"happy-eyeballs" {>= "0.6.0"}
"happy-eyeballs-mirage" {>= "1.1.0"}
"happy-eyeballs" {>= "1.0.0"}
"tls-mirage" {>= "0.16.0"}
"x509" {>= "0.16.0"}
"ca-certs-nss"
Expand Down
164 changes: 31 additions & 133 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ module Log = (val Logs.src_log src : Logs.LOG)
module Transport : Dns_client.S
with type io_addr = [ `Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int ]
and type +'a io = 'a Lwt.t
and type stack = unit
and type stack = Happy_eyeballs_lwt.t
= struct
type io_addr = [ `Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int ]
type +'a io = 'a Lwt.t
type stack = unit
type stack = Happy_eyeballs_lwt.t
type nameservers =
| Static of io_addr list
| Resolv_conf of {
Expand All @@ -26,10 +26,7 @@ module Transport : Dns_client.S
mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ;
mutable connected_condition : (unit, [ `Msg of string ]) result Lwt_condition.t option ;
mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable he : Happy_eyeballs.t ;
mutable cancel_connecting : (int * unit Lwt.u) list Happy_eyeballs.Waiter_map.t;
mutable waiters : ((Ipaddr.t * int) * Lwt_unix.file_descr, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ;
timer_condition : unit Lwt_condition.t ;
he : Happy_eyeballs_lwt.t ;
}
type context = t

Expand All @@ -51,119 +48,9 @@ module Transport : Dns_client.S

let clock = Mtime_clock.elapsed_ns

let he_timer_interval = Duration.of_ms 10

let close_socket fd =
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)

let try_connect ip port =
Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number ->
let fam =
Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6))
in
let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in
let open Lwt_result.Infix in
Lwt.catch (fun () ->
let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in
Lwt_result.ok (Lwt_unix.connect socket addr) >|= fun () ->
socket)
(fun e ->
Lwt_result.ok (close_socket socket) >>= fun () ->
let err =
Fmt.str "error %s connecting to nameserver %a:%d"
(Printexc.to_string e) Ipaddr.pp ip port
in
Lwt.return (Error (`Msg err)))

let handle_one_action t = function
| Happy_eyeballs.Connect (host, id, attempt, (ip, port)) ->
let cancelled, cancel = Lwt.task () in
let entry = attempt, cancel in
t.cancel_connecting <-
Happy_eyeballs.Waiter_map.update id
(function None -> Some [ entry ] | Some c -> Some (entry :: c))
t.cancel_connecting;
let conn =
try_connect ip port >>= function
| Ok fd ->
let cancel_connecting, others =
Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting
in
t.cancel_connecting <- cancel_connecting;
List.iter (fun (att, w) -> if att <> attempt then Lwt.wakeup_later w ())
(Option.value ~default:[] others);
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
Lwt.wakeup_later waiter (Ok ((ip, port), fd));
Lwt.return_unit
| None -> close_socket fd
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, (ip, port)))
| Error `Msg err ->
t.cancel_connecting <-
Happy_eyeballs.Waiter_map.update id
(function None -> None | Some c ->
match List.filter (fun (att, _) -> not (att = attempt)) c with
| [] -> None
| c -> Some c)
t.cancel_connecting;
Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err)))
in
Lwt.pick [ conn ; (cancelled >|= fun () -> None); ]
| Connect_failed (host, id, reason) ->
let cancel_connecting, others =
Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting
in
t.cancel_connecting <- cancel_connecting;
List.iter (fun (_, w) -> Lwt.wakeup_later w ()) (Option.value ~default:[] others);
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
let host_or_ip v =
match Ipaddr.of_domain_name v with
| None -> Domain_name.to_string v
| Some ip -> Ipaddr.to_string ip
in
let err =
Fmt.str "connection to %s failed: %s" (host_or_ip host) reason
in
Lwt.wakeup_later waiter (Error (`Msg err))
| None -> ()
end;
Lwt.return None
| Resolve_a _ | Resolve_aaaa _ as a ->
Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a);
Lwt.return None

let rec handle_action t action =
handle_one_action t action >>= function
| None -> Lwt.return_unit
| Some event ->
let he, actions = Happy_eyeballs.event t.he (clock ()) event in
t.he <- he;
Lwt_list.iter_p (handle_action t) actions

let handle_timer_actions t actions =
Lwt.async (fun () -> Lwt_list.iter_p (fun a -> handle_action t a) actions)

let rec he_timer t =
let open Lwt.Infix in
let rec loop () =
let he, cont, actions = Happy_eyeballs.timer t.he (clock ()) in
t.he <- he ;
handle_timer_actions t actions ;
match cont with
| `Suspend -> he_timer t
| `Act ->
Lwt_unix.sleep (Duration.to_f he_timer_interval) >>= fun () ->
loop ()
in
Lwt_condition.wait t.timer_condition >>= fun () ->
loop ()

let authenticator =
let authenticator_ref = ref None in
fun () ->
Expand Down Expand Up @@ -242,7 +129,7 @@ module Transport : Dns_client.S
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolver ()

let create ?nameservers ~timeout () =
let create ?nameservers ~timeout happy_eyeballs =
let nameservers =
match nameservers with
| Some (`Udp, _) -> invalid_arg "UDP is not supported"
Expand All @@ -252,19 +139,14 @@ module Transport : Dns_client.S
| Error _ -> Resolv_conf { nameservers = default_resolver (); digest = None }
| Ok (ips, digest) -> Resolv_conf { nameservers = ips; digest = Some digest }
in
let t = {
{
nameservers ;
timeout_ns = timeout ;
fd = None ;
connected_condition = None ;
requests = IM.empty ;
he = Happy_eyeballs.create ~connect_timeout:timeout (clock ()) ;
cancel_connecting = Happy_eyeballs.Waiter_map.empty ;
waiters = Happy_eyeballs.Waiter_map.empty ;
timer_condition = Lwt_condition.create () ;
} in
Lwt.async (fun () -> he_timer t);
t
he = happy_eyeballs ;
}

let nameservers { nameservers; _ } = `Tcp, nameserver_ips nameservers

Expand Down Expand Up @@ -377,15 +259,14 @@ module Transport : Dns_client.S
ns

let rec connect_to_ns_list (t : t) connected_condition nameservers =
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
(* The connect_timeout given here is a bit too much, since it should
be (a) connect to the remote NS (b) send query, receive answer.
At the moment, how this is done, is that we use the connect_timeout
for (a) and another separate one for (b). Since we do connection
pooling, it is slightly tricky to use only a single connect_timeout. *)
Happy_eyeballs_lwt.connect_ip ~connect_timeout:t.timeout_ns t.he ns >>= function
| Error `Msg msg ->
let err =
Error (`Msg (Fmt.str "error %s connecting to resolver %a"
Expand Down Expand Up @@ -458,5 +339,22 @@ end
that goes on top of it: *)
include Dns_client.Make(Transport)

let create ?cache_size ?edns ?nameservers ?timeout happy_eyeballs =
let dns = create ?cache_size ?edns ?nameservers ?timeout happy_eyeballs in
let getaddrinfo record domain_name =
let open Lwt_result.Infix in
match record with
| `A ->
getaddrinfo dns Dns.Rr_map.A domain_name >|= fun (_ttl, set) ->
Ipaddr.V4.Set.fold (fun ipv4 -> Ipaddr.Set.add (Ipaddr.V4 ipv4))
set Ipaddr.Set.empty
| `AAAA ->
getaddrinfo dns Dns.Rr_map.Aaaa domain_name >|= fun (_ttl, set) ->
Ipaddr.V6.Set.fold (fun ipv6 -> Ipaddr.Set.add (Ipaddr.V6 ipv6))
set Ipaddr.Set.empty
in
Happy_eyeballs_lwt.inject happy_eyeballs getaddrinfo;
dns

(* initialize the RNG *)
let () = Mirage_crypto_rng_lwt.initialize (module Mirage_crypto_rng.Fortuna)
2 changes: 1 addition & 1 deletion lwt/client/dns_client_lwt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
module Transport : Dns_client.S
with type io_addr = [ `Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int ]
and type +'a io = 'a Lwt.t
and type stack = unit
and type stack = Happy_eyeballs_lwt.t

include module type of Dns_client.Make(Transport)
2 changes: 1 addition & 1 deletion lwt/client/dune
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
(name dns_client_lwt)
(modules dns_client_lwt)
(public_name dns-client-lwt)
(libraries lwt lwt.unix dns dns-client dns-client.resolvconf mtime.clock.os mirage-crypto-rng-lwt ipaddr.unix happy-eyeballs tls-lwt ca-certs)
(libraries lwt lwt.unix dns dns-client dns-client.resolvconf mtime.clock.os mirage-crypto-rng-lwt ipaddr.unix happy-eyeballs happy-eyeballs-lwt tls-lwt ca-certs)
(wrapped false))
Loading

0 comments on commit 2238017

Please sign in to comment.