Skip to content

Commit

Permalink
remove cstruct from dns, dns-client, dns-client-lwt
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesm committed Apr 3, 2024
1 parent 4c32e12 commit 8a85e10
Show file tree
Hide file tree
Showing 10 changed files with 557 additions and 558 deletions.
32 changes: 16 additions & 16 deletions client/dns_client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module Pure = struct

let make_query rng protocol ?(dnssec = false) edns hostname
: 'xy ->
Cstruct.t * 'xy query_state =
string * 'xy query_state =
(* SRV records: Service + Protocol are case-insensitive, see RFC2728 pg2. *)
fun record_type ->
let edns = match edns with
Expand All @@ -37,9 +37,9 @@ module Pure = struct
begin match protocol with
| `Udp -> cs
| `Tcp ->
let len_field = Cstruct.create 2 in
Cstruct.BE.set_uint16 len_field 0 (Cstruct.length cs) ;
Cstruct.concat [len_field ; cs]
let len_field = Bytes.create 2 in
Bytes.set_uint16_be len_field 0 (String.length cs) ;
String.concat "" [Bytes.unsafe_to_string len_field ; cs]
end, { protocol ; query ; key = record_type }

(* name: the originally requested domain name. *)
Expand Down Expand Up @@ -71,16 +71,16 @@ module Pure = struct
function (* consume TCP two-byte length prefix: *)
| `Udp -> Ok buf
| `Tcp ->
match Cstruct.BE.get_uint16 buf 0 with
match String.get_uint16_be buf 0 with
| exception Invalid_argument _ -> Error () (* TODO *)
| pkt_len when pkt_len > Cstruct.length buf -2 ->
| pkt_len when pkt_len > String.length buf -2 ->
Log.debug (fun m -> m "Partial: %d >= %d-2"
pkt_len (Cstruct.length buf));
pkt_len (String.length buf));
Error () (* TODO return remaining # *)
| pkt_len ->
if 2 + pkt_len < Cstruct.length buf then
if 2 + pkt_len < String.length buf then
Log.warn (fun m -> m "Extraneous data in DNS response");
Ok (Cstruct.sub buf 2 pkt_len)
Ok (String.sub buf 2 pkt_len)

let find_soa authority =
Domain_name.Map.fold (fun k rr_map acc ->
Expand Down Expand Up @@ -143,7 +143,7 @@ module Pure = struct
to_msg t (Packet.reply_matches_request ~request:state.query t)

let parse_response (type requested)
: requested Rr_map.key query_state -> Cstruct.t ->
: requested Rr_map.key query_state -> string ->
(Packet.reply,
[> `Partial
| `Msg of string]) result =
Expand All @@ -153,7 +153,7 @@ module Pure = struct
| Error () -> Error `Partial

let handle_response (type requested)
: requested Rr_map.key query_state -> Cstruct.t ->
: requested Rr_map.key query_state -> string ->
( [ `Data of requested
| `Partial
| `No_data of [`raw] Domain_name.t * Soa.t
Expand Down Expand Up @@ -183,11 +183,11 @@ module type S = sig
val create : ?nameservers:(Dns.proto * io_addr list) -> timeout:int64 -> stack -> t

val nameservers : t -> Dns.proto * io_addr list
val rng : int -> Cstruct.t
val rng : int -> string
val clock : unit -> int64

val connect : t -> (Dns.proto * context, [> `Msg of string ]) result io
val send_recv : context -> Cstruct.t -> (Cstruct.t, [> `Msg of string ]) result io
val send_recv : context -> string -> (string, [> `Msg of string ]) result io
val close : context -> unit io

val bind : 'a io -> ('a -> 'b io) -> 'b io
Expand Down Expand Up @@ -261,8 +261,8 @@ struct
in
(Transport.send_recv socket tx >>| fun recv_buffer ->
Log.debug (fun m -> m "Read @[<v>%d bytes@]"
(Cstruct.length recv_buffer)) ;
Log.debug (fun m -> m "received: %a" Cstruct.hexdump_pp recv_buffer);
(String.length recv_buffer)) ;
Log.debug (fun m -> m "received: %a" (Ohex.pp_hexdump ()) recv_buffer);
Transport.lift (Pure.parse_response state recv_buffer)) >>= fun r ->
Transport.close socket >>= fun () ->
Transport.lift r
Expand Down Expand Up @@ -291,7 +291,7 @@ struct
in
(Transport.send_recv socket tx >>| fun recv_buffer ->
Log.debug (fun m -> m "Read @[<v>%d bytes@]"
(Cstruct.length recv_buffer)) ;
(String.length recv_buffer)) ;
let update_cache entry =
let rank = Dns_cache.NonAuthoritativeAnswer in
let cache =
Expand Down
12 changes: 6 additions & 6 deletions client/dns_client.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ module type S = sig
the underlying context, can be used if the user does not want to
bother with configuring their own.*)

val rng : int -> Cstruct.t
val rng : int -> string
(** [rng t] is a random number generator. *)

val clock : unit -> int64
Expand All @@ -49,7 +49,7 @@ module type S = sig
val connect : t -> (Dns.proto * context, [> `Msg of string ]) result io
(** [connect t] is a new connection ([context]) to [t], or an error. *)

val send_recv : context -> Cstruct.t -> (Cstruct.t, [> `Msg of string ]) result io
val send_recv : context -> string -> (string, [> `Msg of string ]) result io
(** [send_recv context buffer] sends [buffer] to the [context] upstream, and
then reads a buffer. *)

Expand Down Expand Up @@ -144,16 +144,16 @@ module Pure : sig
application. *)

val make_query :
(int -> Cstruct.t) -> Dns.proto -> ?dnssec:bool ->
(int -> string) -> Dns.proto -> ?dnssec:bool ->
[ `None | `Auto | `Manual of Dns.Edns.t ] ->
'a Domain_name.t ->
'query_type Dns.Rr_map.key ->
Cstruct.t * 'query_type Dns.Rr_map.key query_state
string * 'query_type Dns.Rr_map.key query_state
(** [make_query rng protocol name query_type] is [query, query_state]
where [query] is the serialized DNS query to send to the name server,
and [query_state] is the information required to validate the response. *)

val parse_response : 'query_type Dns.Rr_map.key query_state -> Cstruct.t ->
val parse_response : 'query_type Dns.Rr_map.key query_state -> string ->
(Dns.Packet.reply, [ `Partial | `Msg of string]) result
(** [parse_response query_state response] is the information contained in
[response] parsed using [query_state] when the query was successful, or
Expand All @@ -169,7 +169,7 @@ module Pure : sig
In a UDP usage context the [`Partial] means information was lost, due to
an incomplete packet. *)

val handle_response : 'query_type Dns.Rr_map.key query_state -> Cstruct.t ->
val handle_response : 'query_type Dns.Rr_map.key query_state -> string ->
( [ `Data of 'query_type
| `Partial
| `No_data of [`raw] Domain_name.t * Dns.Soa.t
Expand Down
1 change: 0 additions & 1 deletion dns.opam
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ depends: [
"fmt" {>= "0.8.8"}
"domain-name" {>= "0.4.0"}
"gmap" {>= "0.3.0"}
"cstruct" {>= "6.0.0"}
"ipaddr" {>= "5.2.0"}
"alcotest" {with-test}
"lru" {>= "0.3.0"}
Expand Down
36 changes: 18 additions & 18 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ module Transport : Dns_client.S
(* TODO: avoid race, use a mvar instead of condition *)
mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ;
mutable connected_condition : (unit, [ `Msg of string ]) result Lwt_condition.t option ;
mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable requests : (string * (string, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable he : Happy_eyeballs.t ;
mutable cancel_connecting : (int * unit Lwt.u) list Happy_eyeballs.Waiter_map.t;
mutable waiters : ((Ipaddr.t * int) * Lwt_unix.file_descr, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ;
Expand Down Expand Up @@ -283,9 +283,9 @@ module Transport : Dns_client.S
Lwt.catch (fun () ->
match fd with
| `Plain fd ->
Lwt_unix.send fd (Cstruct.to_bytes tx) 0
(Cstruct.length tx) [] >>= fun res ->
if res <> Cstruct.length tx then
Lwt_unix.send fd (Bytes.unsafe_of_string tx) 0
(String.length tx) [] >>= fun res ->
if res <> String.length tx then
Lwt_result.fail (`Msg ("oops" ^ (string_of_int res)))
else
Lwt_result.return ()
Expand All @@ -294,11 +294,11 @@ module Transport : Dns_client.S
(fun e -> Lwt.return (Error (`Msg (Printexc.to_string e))))

let send_recv (t : context) tx =
if Cstruct.length tx > 4 then
if String.length tx > 4 then
match t.fd with
| None -> Lwt.return (Error (`Msg "no connection to the nameserver established"))
| Some fd ->
let id = Cstruct.BE.get_uint16 tx 2 in
let id = String.get_uint16_be tx 2 in
with_timeout t.timeout_ns
(let open Lwt_result.Infix in
send_query fd tx >>= fun () ->
Expand All @@ -315,20 +315,20 @@ module Transport : Dns_client.S
let bind = Lwt.bind
let lift = Lwt.return

let rec read_loop ?(linger = Cstruct.empty) (t : t) fd =
let rec read_loop ?(linger = "") (t : t) fd =
Lwt.catch (fun () ->
match fd with
| `Plain fd ->
let recv_buffer = Bytes.make 2048 '\000' in
let recv_buffer = Bytes.create 2048 in
Lwt_unix.recv fd recv_buffer 0 (Bytes.length recv_buffer) [] >|= fun r ->
(r, Cstruct.of_bytes recv_buffer)
(r, recv_buffer)
| `Tls fd ->
let recv_buffer = Cstruct.create 2048 in
let recv_buffer = Bytes.create 2048 in
Tls_lwt.Unix.read fd recv_buffer >|= fun r ->
(r, recv_buffer))
(fun e ->
Log.err (fun m -> m "error %s reading from resolver" (Printexc.to_string e));
Lwt.return (0, Cstruct.empty)) >>= function
Lwt.return (0, Bytes.empty)) >>= function
| (0, _) ->
(match fd with
| `Plain fd -> close_socket fd
Expand All @@ -338,16 +338,16 @@ module Transport : Dns_client.S
Log.info (fun m -> m "end of file reading from resolver")
| (read_len, cs) ->
let rec handle_data data =
let cs_len = Cstruct.length data in
let cs_len = String.length data in
if cs_len > 2 then
let len = Cstruct.BE.get_uint16 data 0 in
let len = String.get_uint16_be data 0 in
if cs_len - 2 >= len then
let packet, rest =
if cs_len - 2 = len
then data, Cstruct.empty
else Cstruct.split data (len + 2)
then data, ""
else String.sub data 0 (len + 2), String.sub data (len + 2) (String.length data - len - 2)
in
let id = Cstruct.BE.get_uint16 packet 2 in
let id = String.get_uint16_be packet 2 in
(match IM.find_opt id t.requests with
| None -> Log.warn (fun m -> m "received unsolicited data, ignoring")
| Some (_, cond) ->
Expand All @@ -358,8 +358,8 @@ module Transport : Dns_client.S
else
read_loop ~linger:data t fd
in
let cs = Cstruct.sub cs 0 read_len in
handle_data (if Cstruct.length linger = 0 then cs else Cstruct.append linger cs)
let cs = String.sub (Bytes.unsafe_to_string cs) 0 read_len in
handle_data (if String.length linger = 0 then cs else linger ^ cs)

let req_all fd t =
IM.fold (fun _id (data, _) r ->
Expand Down
Loading

0 comments on commit 8a85e10

Please sign in to comment.