From 1e6cf6208b87ce290568aec128c4e6a73919f836 Mon Sep 17 00:00:00 2001 From: Thomas Gazagnaire Date: Thu, 3 Aug 2017 17:38:12 +0200 Subject: [PATCH 1/2] Remove the lwt backend Only use uwt which is more efficient on Windows and work pretty well on Linux/macOS as well. Signed-off-by: Thomas Gazagnaire --- Makefile | 4 +- src/bin/connect.ml | 4 +- src/bin/connect.mli | 4 +- src/bin/main.ml | 34 +- src/hostnet/{host_uwt.ml => host.ml} | 0 src/hostnet/{host_lwt_unix.mli => host.mli} | 0 src/hostnet/host_lwt_unix.ml | 806 -------------------- src/hostnet/host_uwt.mli | 1 - src/hostnet/slirp.ml | 1 - src/hostnet/slirp.mli | 1 - src/hostnet/vmnet.ml | 3 +- src/hostnet_test/forwarding.ml | 3 - src/hostnet_test/jbuild | 8 +- src/hostnet_test/{main_uwt.ml => main.ml} | 4 +- src/hostnet_test/main_lwt.ml | 19 - src/hostnet_test/slirp_stack.ml | 4 +- src/hostnet_test/suite.ml | 20 +- src/hostnet_test/test_half_close.ml | 6 - src/hostnet_test/test_http.ml | 6 - src/hostnet_test/test_nat.ml | 7 +- 20 files changed, 26 insertions(+), 909 deletions(-) rename src/hostnet/{host_uwt.ml => host.ml} (100%) rename src/hostnet/{host_lwt_unix.mli => host.mli} (100%) delete mode 100644 src/hostnet/host_lwt_unix.ml delete mode 100644 src/hostnet/host_uwt.mli rename src/hostnet_test/{main_uwt.ml => main.ml} (93%) delete mode 100644 src/hostnet_test/main_lwt.ml diff --git a/Makefile b/Makefile index 226b6f345..3d138a8d7 100644 --- a/Makefile +++ b/Makefile @@ -48,8 +48,8 @@ vpnkit.exe: src/bin/depends.ml .PHONY: test test: - jbuilder build --dev src/hostnet_test/main_uwt.exe - ./_build/default/src/hostnet_test/main_uwt.exe + jbuilder build --dev src/hostnet_test/main.exe + ./_build/default/src/hostnet_test/main.exe .PHONY: OSS-LICENSES OSS-LICENSES: diff --git a/src/bin/connect.ml b/src/bin/connect.ml index a36593b2e..ca64a0dab 100644 --- a/src/bin/connect.ml +++ b/src/bin/connect.ml @@ -13,7 +13,7 @@ let (/) = Filename.concat let home = try Sys.getenv "HOME" with Not_found -> "/Users/root" let vsock_port = 62373l -module Make_unix(Host: Sig.HOST) = struct +module Unix = struct let vsock_path = ref (home / "Library/Containers/com.docker.docker/Data/@connect") @@ -41,7 +41,7 @@ module Make_unix(Host: Sig.HOST) = struct Fmt.kstrf Lwt.fail_with "%a" pp_write_error e end -module Make_hvsock(Host: Sig.HOST) = struct +module Hvsock = struct (* Avoid using `detach` because we don't want to exhaust the thread pool since this will block the main TCP/IP stack. *) module F = diff --git a/src/bin/connect.mli b/src/bin/connect.mli index f5f8faaff..2ca715a4a 100644 --- a/src/bin/connect.mli +++ b/src/bin/connect.mli @@ -1,10 +1,10 @@ -module Make_unix(Host: Sig.HOST): sig +module Unix: sig include Sig.Connector val vsock_path: string ref end -module Make_hvsock(Host: Sig.HOST): sig +module Hvsock: sig include Sig.Connector val set_port_forward_addr: Hvsock.sockaddr -> unit diff --git a/src/bin/main.ml b/src/bin/main.ml index 3ea9dfb48..9baf15dd9 100644 --- a/src/bin/main.ml +++ b/src/bin/main.ml @@ -46,18 +46,16 @@ let hvsock_addr_of_uri ~default_serviceid uri = in { Hvsock.vmid; serviceid } -module Main(Host: Sig.HOST) = struct - module Vnet = Basic_backend.Make - module Connect_unix = Connect.Make_unix(Host) - module Connect_hvsock = Connect.Make_hvsock(Host) + module Connect_unix = Connect.Unix + module Connect_hvsock = Connect.Hvsock module Bind = Bind.Make(Host.Sockets) module Dns_policy = Hostnet_dns.Policy(Host.Files) module Config = Active_config.Make(Host.Time)(Host.Sockets.Stream.Unix) module Forward_unix = Forward.Make(Mclock)(Connect_unix)(Bind) module Forward_hvsock = Forward.Make(Mclock)(Connect_hvsock)(Bind) module HV = Flow_lwt_hvsock.Make(Host.Time)(Host.Fn) - module Hosts = Hosts.Make(Host.Files) + module HostsFile = Hosts.Make(Host.Files) let file_descr_of_int (x: int) : Unix.file_descr = if Sys.os_type <> "Unix" @@ -256,7 +254,7 @@ module Main(Host: Sig.HOST) = struct ~config:(`Upstream { servers; search = []; assume_offline_after_drops = None }) ); - let etc_hosts_watch = match Hosts.watch ~path:hosts () with + let etc_hosts_watch = match HostsFile.watch ~path:hosts () with | Ok watch -> Some watch | Error (`Msg m) -> Log.err (fun f -> f "Failed to watch hosts file %s: %s" hosts m); @@ -330,7 +328,7 @@ module Main(Host: Sig.HOST) = struct | Some "hyperv-connect" -> let module Slirp_stack = Slirp.Make(Config)(Vmnet.Make(HV))(Dns_policy) - (Mclock)(Stdlibrandom)(Host)(Vnet) + (Mclock)(Stdlibrandom)(Vnet) in let sockaddr = hvsock_addr_of_uri ~default_serviceid:ethernet_serviceid @@ -352,7 +350,7 @@ module Main(Host: Sig.HOST) = struct | _ -> let module Slirp_stack = Slirp.Make(Config)(Vmnet.Make(Host.Sockets.Stream.Unix))(Dns_policy) - (Mclock)(Stdlibrandom)(Host)(Vnet) + (Mclock)(Stdlibrandom)(Vnet) in unix_listen socket_url >>= fun server -> ( match config with @@ -370,7 +368,7 @@ module Main(Host: Sig.HOST) = struct let wait_forever, _ = Lwt.task () in wait_forever >|= fun () -> match etc_hosts_watch with - | Some watch -> Hosts.unwatch watch + | Some watch -> HostsFile.unwatch watch | None -> () let main @@ -382,18 +380,6 @@ module Main(Host: Sig.HOST) = struct (main_t socket_url port_control_url introspection_url diagnostics_url max_connections vsock_path db_path db_branch dns hosts host_names listen_backlog debug) -end - -let main - socket port_control introspection_url diagnostics_url max_connections - vsock_path db_path db_branch dns hosts host_names select listen_backlog - debug - = - let module Use_lwt_unix = Main(Host_lwt_unix) in - let module Use_uwt = Main(Host_uwt) in - (if select then Use_lwt_unix.main else Use_uwt.main) - socket port_control introspection_url diagnostics_url max_connections - vsock_path db_path db_branch dns hosts host_names listen_backlog debug open Cmdliner @@ -510,10 +496,6 @@ let host_names = in Arg.(value & opt string "vpnkit.host" doc) -let select = - let doc = "Use a select event loop rather than the default libuv-based one" in - Arg.(value & flag & info [ "select" ] ~doc) - let listen_backlog = let doc = "Specify a maximum listen(2) backlog. If no override is specified \ then we will use SOMAXCONN." in @@ -533,7 +515,7 @@ let command = Term.(pure main $ socket $ port_control_path $ introspection_path $ diagnostics_path $ max_connections $ vsock_path $ db_path $ db_branch $ dns $ hosts - $ host_names $ select $ listen_backlog $ debug), + $ host_names $ listen_backlog $ debug), Term.info (Filename.basename Sys.argv.(0)) ~version:Depends.version ~doc ~man let () = diff --git a/src/hostnet/host_uwt.ml b/src/hostnet/host.ml similarity index 100% rename from src/hostnet/host_uwt.ml rename to src/hostnet/host.ml diff --git a/src/hostnet/host_lwt_unix.mli b/src/hostnet/host.mli similarity index 100% rename from src/hostnet/host_lwt_unix.mli rename to src/hostnet/host.mli diff --git a/src/hostnet/host_lwt_unix.ml b/src/hostnet/host_lwt_unix.ml deleted file mode 100644 index f54105416..000000000 --- a/src/hostnet/host_lwt_unix.ml +++ /dev/null @@ -1,806 +0,0 @@ -open Lwt.Infix - -let src = - let src = - Logs.Src.create "Lwt_unix" ~doc:"Host interface based on Lwt_unix" - in - Logs.Src.set_level src (Some Logs.Info); - src - -module Log = (val Logs.src_log src : Logs.LOG) - -let default_read_buffer_size = 65536 - -let log_exception_continue description f = - Lwt.catch - (fun () -> f ()) - (fun e -> - Log.err (fun f -> f "%s: caught %s" description (Printexc.to_string e)); - Lwt.return () - ) - -module Common = struct - (** FLOW boilerplate *) - - type 'a io = 'a Lwt.t - type buffer = Cstruct.t - type error = [`Msg of string] - type write_error = [Mirage_flow.write_error | error] - let pp_error ppf (`Msg x) = Fmt.string ppf x - - let pp_write_error ppf = function - | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e - | #error as e -> pp_error ppf e - - let errorf fmt = Fmt.kstrf (fun s -> Lwt_result.fail (`Msg s)) fmt -end - -module Sockets = struct - - let max_connections = ref None - - let set_max_connections x = max_connections := x - - let next_connection_idx = - let idx = ref 0 in - fun () -> - let next = !idx in - incr idx; - next - - exception Too_many_connections - - let connection_table = Hashtbl.create 511 - - let get_num_connections () = Hashtbl.length connection_table - - let connections () = - let xs = Hashtbl.fold (fun _ c acc -> c :: acc) connection_table [] in - Vfs.File.ro_of_string (String.concat "\n" xs) - - let register_connection_no_limit description = - let idx = next_connection_idx () in - Hashtbl.replace connection_table idx description; - idx - - let register_connection = - let last_error_log = ref 0. in - fun description -> match !max_connections with - | Some m when Hashtbl.length connection_table >= m -> - let now = Unix.gettimeofday () in - if (now -. !last_error_log) > 30. then begin - (* Avoid hammering the logging system *) - Log.err (fun f -> - f "exceeded maximum number of forwarded connections (%d)" m); - last_error_log := now; - end; - Lwt.fail Too_many_connections - | _ -> - let idx = register_connection_no_limit description in - Lwt.return idx - - let deregister_connection idx = - Hashtbl.remove connection_table idx - - let address_of_sockaddr = function - | Lwt_unix .ADDR_INET(ip, port) -> - (try Some (Ipaddr.of_string_exn @@ Unix.string_of_inet_addr ip, port) - with _ -> None) - | _ -> None - - let string_of_sockaddr = function - | Lwt_unix.ADDR_INET(ip, port) -> - Fmt.strf "%s:%d" (Unix.string_of_inet_addr ip) port - | Lwt_unix.ADDR_UNIX path -> path - - let string_of_address (dst, dst_port) = - Fmt.strf "%s:%d" (Ipaddr.to_string dst) dst_port - - let sockaddr_of_address (dst, dst_port) = - Unix.ADDR_INET(Unix.inet_addr_of_string @@ Ipaddr.to_string dst, dst_port) - - let unix_bind_one ?(description="") pf ty ip port = - let protocol = match pf, ty with - | (Unix.PF_INET | Unix.PF_INET6), Unix.SOCK_STREAM -> "tcp:" - | (Unix.PF_INET | Unix.PF_INET6), Unix.SOCK_DGRAM -> "udp:" - | _, _ -> "unknown:" in - let description = - Fmt.strf "%s%a:%d %s" protocol Ipaddr.pp_hum ip port description - in - register_connection description >>= fun idx -> - let addr = - Lwt_unix.ADDR_INET(Unix.inet_addr_of_string @@ Ipaddr.to_string ip, port) - in - let fd = - try Lwt_unix.socket pf ty 0 - with e -> deregister_connection idx; raise e - in - Lwt.catch (fun () -> - Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; - Lwt_unix.bind fd addr >|= fun () -> - idx, fd - ) (fun e -> - Lwt_unix.close fd - >>= fun () -> - deregister_connection idx; - Lwt.fail e - ) - - let unix_bind ?description ty (local_ip, local_port) = - let pf = match local_ip with - | Ipaddr.V4 _ -> Lwt_unix.PF_INET - | Ipaddr.V6 _ -> Lwt_unix.PF_INET6 in - unix_bind_one ?description pf ty local_ip local_port - >>= fun (idx, fd) -> - let local_port = match local_port, Lwt_unix.getsockname fd with - | 0, Unix.ADDR_INET(_, local_port) -> local_port - | 0, _ -> assert false (* common only uses ADDR_INET *) - | _ -> local_port in - (* On some systems localhost will resolve to ::1 first and this - can cause performance problems (particularly on - Windows). Perform a best-effort bind to the ::1 address. *) - Lwt.catch (fun () -> - if Ipaddr.compare local_ip (Ipaddr.V4 Ipaddr.V4.localhost) = 0 - || Ipaddr.compare local_ip (Ipaddr.V4 Ipaddr.V4.any) = 0 - then begin - Log.info (fun f -> - f "attempting a best-effort bind of ::1:%d" local_port); - unix_bind_one - ?description Lwt_unix.PF_INET6 ty Ipaddr.(V6 V6.localhost) local_port - >|= fun (idx, fd) -> - [ idx, fd ] - end else - Lwt.return [] - ) (fun e -> - Log.info (fun f -> - f "ignoring failed bind to ::1:%d (%a)" local_port Fmt.exn e); - Lwt.return [] - ) - >|= fun extra -> - (idx, fd) :: extra - - module Datagram = struct - type address = Ipaddr.t * int - - module Udp = struct - include Common - - type flow = { - mutable idx: int option; - description: string; - mutable fd: Lwt_unix.file_descr option; - read_buffer_size: int; - mutable already_read: Cstruct.t option; - sockaddr: Unix.sockaddr; - address: address; - } - - type address = Ipaddr.t * int - - let string_of_flow t = Fmt.strf "udp -> %s" (string_of_address t.address) - - let of_fd ?idx ~description ?(read_buffer_size = Constants.max_udp_length) - ?(already_read = None) sockaddr address fd = - { idx; description; fd = Some fd; read_buffer_size; already_read; - sockaddr; address } - - let connect ?read_buffer_size address = - let description = "udp:" ^ string_of_address address in - register_connection description - >>= fun idx -> - let pf, addr = match fst address with - | Ipaddr.V4 _ -> Lwt_unix.PF_INET, Unix.inet_addr_any - | Ipaddr.V6 _ -> Lwt_unix.PF_INET6, Unix.inet6_addr_any in - let fd = Lwt_unix.socket pf Lwt_unix.SOCK_DGRAM 0 in - (* Win32 requires all sockets to be bound however macOS and - Linux don't *) - Lwt.catch (fun () -> - Lwt_unix.bind fd (Lwt_unix.ADDR_INET(addr, 0)) - ) (fun _ -> Lwt.return_unit) - >|= fun () -> - let sockaddr = sockaddr_of_address address in - Ok (of_fd ~idx ~description ?read_buffer_size sockaddr address fd) - - let read t = match t.fd, t.already_read with - | None, _ -> Lwt.return (Ok `Eof) - | Some _, Some data when Cstruct.len data > 0 -> - t.already_read <- Some (Cstruct.sub data 0 0); (* next read is `Eof *) - Lwt.return (Ok (`Data data)) - | Some _, Some _ -> - Lwt.return (Ok `Eof) - | Some fd, None -> - let buffer = Cstruct.create t.read_buffer_size in - let bytes = Bytes.make t.read_buffer_size '\000' in - Lwt.catch (fun () -> - (* Lwt on Win32 doesn't support Lwt_bytes.recvfrom *) - Lwt_unix.recvfrom fd bytes 0 (Bytes.length bytes) [] - >>= fun (n, _) -> - Cstruct.blit_from_bytes bytes 0 buffer 0 n; - let response = Cstruct.sub buffer 0 n in - Lwt.return (Ok (`Data response)) - ) (fun e -> - Log.err (fun f -> - f "%s: recvfrom caught %a returning Eof" (string_of_flow t) - Fmt.exn e); - Lwt.return (Ok `Eof) - ) - - let write t buf = match t.fd with - | None -> Lwt.return (Error `Closed) - | Some fd -> - Lwt.catch (fun () -> - (* Lwt on Win32 doesn't support Lwt_bytes.sendto *) - let bytes = Bytes.make (Cstruct.len buf) '\000' in - Cstruct.blit_to_bytes buf 0 bytes 0 (Cstruct.len buf); - Lwt_unix.sendto fd bytes 0 (Bytes.length bytes) [] t.sockaddr - >|= fun _n -> - Ok () - ) (fun e -> - Log.err (fun f -> - f "%s: sendto caught %a returning Eof" (string_of_flow t) - Fmt.exn e); - Lwt.return (Error `Closed) - ) - - let writev t bufs = write t (Cstruct.concat bufs) - - let close t = match t.fd with - | None -> Lwt.return_unit - | Some fd -> - t.fd <- None; - Log.debug (fun f -> f "%s: close" (string_of_flow t)); - (match t.idx with Some idx -> deregister_connection idx | None -> ()); - Lwt_unix.close fd - - let shutdown_read _t = Lwt.return_unit - let shutdown_write _t = Lwt.return_unit - - type server = { - idx: int; - fd: Lwt_unix.file_descr; - mutable closed: bool; - mutable disable_connection_tracking: bool; - } - - let make ~idx fd = - { idx; fd; closed = false; disable_connection_tracking = false } - - let disable_connection_tracking server = - server.disable_connection_tracking <- true - - let bind ?description (ip, port) = - let pf = match ip with - | Ipaddr.V4 _ -> Lwt_unix.PF_INET - | Ipaddr.V6 _ -> Lwt_unix.PF_INET6 in - unix_bind_one ?description pf Lwt_unix.SOCK_DGRAM ip port - >|= fun (idx, fd) -> - make ~idx fd - - let of_bound_fd ?read_buffer_size:_ fd = - let description = match Unix.getsockname fd with - | Lwt_unix.ADDR_INET(iaddr, port) -> - Fmt.strf "udp:%s:%d" (Unix.string_of_inet_addr iaddr) port - | _ -> "unknown:" - in - let idx = register_connection_no_limit description in - make ~idx (Lwt_unix.of_unix_file_descr fd) - - let getsockname { fd; _ } = - match Lwt_unix.getsockname fd with - | Lwt_unix.ADDR_INET(iaddr, port) -> - Ipaddr.of_string_exn (Unix.string_of_inet_addr iaddr), port - | _ -> invalid_arg "Udp.getsockname passed a non-INET socket" - - let shutdown server = - if not server.closed then begin - server.closed <- true; - Lwt_unix.close server.fd >|= fun () -> - deregister_connection server.idx - end else - Lwt.return_unit - - let listen t flow_cb = - let bytes = Bytes.make Constants.max_udp_length '\000' in - let rec loop () = - Lwt.catch (fun () -> - (* Lwt on Win32 doesn't support Lwt_bytes.recvfrom *) - Lwt_unix.recvfrom t.fd bytes 0 (Bytes.length bytes) [] - >>= fun (n, sockaddr) -> - (* Allocate a fresh buffer because the packet will be processed - in a background thread *) - let data = Cstruct.create n in - Cstruct.blit_from_bytes bytes 0 data 0 n; - (* construct a flow with this buffer available for reading *) - ( match address_of_sockaddr sockaddr with - | Some address -> Lwt.return address - | None -> Lwt.fail_with "failed to discover incoming socket address" - ) >>= fun address -> - (* No new fd so no new idx *) - let description = Fmt.strf "udp:%s" (string_of_address address) in - let flow = - of_fd ~description ~read_buffer_size:0 ~already_read:(Some data) - sockaddr address t.fd - in - Lwt.async (fun () -> - Lwt.catch - (fun () -> flow_cb flow) - (fun e -> - Log.info (fun f -> - f "Udp.listen callback caught: %a" Fmt.exn e); - Lwt.return_unit - )); - Lwt.return true - ) (fun e -> - Log.err (fun f -> - f "Udp.listen caught %a shutting down server" Fmt.exn e); - Lwt.return false - ) - >>= function - | false -> Lwt.return_unit - | true -> loop () - in - Lwt.async loop - - let recvfrom server buf = - (* Lwt on Win32 doesn't support Lwt_bytes.sendto *) - let str = Bytes.create (Cstruct.len buf) in - Lwt_unix.recvfrom server.fd str 0 (String.length str) [] - >|= fun (len, sockaddr) -> - Cstruct.blit_from_string str 0 buf 0 len; - let address = match sockaddr with - | Lwt_unix.ADDR_INET(ip, port) -> - Ipaddr.of_string_exn @@ Unix.string_of_inet_addr ip, port - | _ -> - invalid_arg "recvfrom returned wrong sockaddr type" - in - len, address - - let sendto server (ip, port) buf = - (* Lwt on Win32 doesn't support Lwt_bytes.sendto *) - let len = Cstruct.len buf in - let str = Bytes.create len in - Cstruct.blit_to_bytes buf 0 str 0 len; - let sockaddr = - Lwt_unix.ADDR_INET - (Unix.inet_addr_of_string @@ Ipaddr.to_string ip, port) - in - Lwt_unix.sendto server.fd str 0 len [] sockaddr - >|= fun _ -> () - end - - end - - module Stream = struct - - (* Using Lwt_unix we share an implementation across various - transport types *) - module Fd = struct - - include Common - - type flow = { - idx: int; - description: string; - fd: Lwt_unix.file_descr; - read_buffer_size: int; - mutable read_buffer: Cstruct.t; - mutable closed: bool; - } - - let of_fd - ~idx ?(read_buffer_size = default_read_buffer_size) ~description fd - = - let read_buffer = Cstruct.create read_buffer_size in - let closed = false in - { idx; description; fd; read_buffer; read_buffer_size; closed } - - let shutdown_read { description; fd; closed; _ } = - try - if not closed then Lwt_unix.shutdown fd Unix.SHUTDOWN_RECEIVE; - Lwt.return () - with - | Unix.Unix_error(Unix.ENOTCONN, _, _) -> Lwt.return () - | e -> - Log.err (fun f -> - f "Socket.TCPV4.shutdown_read %s: caught %a returning Eof" - description Fmt.exn e); - Lwt.return () - - let shutdown_write { description; fd; closed; _ } = - try - if not closed then Lwt_unix.shutdown fd Unix.SHUTDOWN_SEND; - Lwt.return () - with - | Unix.Unix_error(Unix.ENOTCONN, _, _) -> Lwt.return () - | e -> - Log.err (fun f -> - f "Socket.TCPV4.shutdown_write %s: caught %a returning Eof" - description Fmt.exn e); - Lwt.return () - - let read t = - if t.closed then Lwt.return (Ok `Eof) - else begin - if Cstruct.len t.read_buffer = 0 - then t.read_buffer <- Cstruct.create t.read_buffer_size; - Lwt.catch (fun () -> - Lwt_bytes.read t.fd t.read_buffer.Cstruct.buffer - t.read_buffer.Cstruct.off t.read_buffer.Cstruct.len - >|= function - | 0 -> Ok `Eof - | n -> - let results = Cstruct.sub t.read_buffer 0 n in - t.read_buffer <- Cstruct.shift t.read_buffer n; - Ok (`Data results) - ) (fun e -> - Log.err (fun f -> - f "Socket.TCPV4.read %s: caught %a returning Eof" - t.description Fmt.exn e); - Lwt.return (Ok `Eof) - ) - end - - let read_into t buffer = - if t.closed then Lwt.return (Ok `Eof) - else Lwt.catch (fun () -> - Lwt_cstruct.(complete (read t.fd) buffer) >|= fun () -> - Ok (`Data ()) - ) (fun _e -> Lwt.return (Ok `Eof)) - - let write t buf = - if t.closed then Lwt.return (Error `Closed) - else Lwt.catch (fun () -> - Lwt_cstruct.(complete (write t.fd) buf) >|= fun () -> - Ok () - ) (fun e -> - Log.err (fun f -> - f "Socket.TCPV4.write %s: caught %a returning Eof" t.description - Fmt.exn e); - Lwt.return (Error `Closed) - ) - - let writev t bufs = - let rec loop = function - | [] -> Lwt.return (Ok ()) - | buf :: bufs -> - if t.closed then Lwt.return (Error `Closed) - else - Lwt_cstruct.(complete (write t.fd) buf) >>= fun () -> - loop bufs - in - Lwt.catch - (fun () -> loop bufs) - (fun e -> - Log.err (fun f -> - f "Socket.TCPV4.writev %s: caught %a returning Eof" - t.description Fmt.exn e); - Lwt.return (Error `Closed) - ) - - let close t = - if not t.closed then begin - t.closed <- true; - Lwt_unix.close t.fd >|= fun () -> - deregister_connection t.idx - end else - Lwt.return () - - let connect - description ?(read_buffer_size = default_read_buffer_size) - sock_domain sock_ty sockaddr - = - register_connection description >>= fun idx -> - let fd = Lwt_unix.socket sock_domain sock_ty 0 in - Lwt.catch (fun () -> - Log.debug (fun f -> f "%s: connecting" description); - Lwt_unix.connect fd sockaddr >|= fun () -> - Ok (of_fd ~idx ~read_buffer_size ~description fd) - ) (fun e -> - Lwt_unix.close fd >>= fun () -> - deregister_connection idx; - errorf "%s: Lwt_unix.connect: caught %a" description Fmt.exn e - ) - - type server = { - mutable listening_fds: (int * Lwt_unix.file_descr) list; - read_buffer_size: int; - path: string; (* only for Win32 *) - mutable closed: bool; - mutable disable_connection_tracking: bool; - } - - let make - ?(read_buffer_size = default_read_buffer_size) ?(path="") - listening_fds - = - { listening_fds; read_buffer_size; path; closed = false; - disable_connection_tracking = false } - - let disable_connection_tracking server = - server.disable_connection_tracking <- true - - let shutdown server = - let fds = server.listening_fds in - server.listening_fds <- []; - server.closed <- true; - Lwt_list.iter_s (fun (idx, fd) -> - Lwt_unix.close fd >|= fun () -> - deregister_connection idx - ) fds - - let of_bound_fd ?(read_buffer_size = default_read_buffer_size) fd = - let description = match Unix.getsockname fd with - | Lwt_unix.ADDR_INET(iaddr, port) -> - Fmt.strf "udp:%s:%d" (Unix.string_of_inet_addr iaddr) port - | _ -> "unknown:" - in - let idx = register_connection_no_limit description in - make ~read_buffer_size [ idx, Lwt_unix.of_unix_file_descr fd ] - - let listen server cb = - let rec loop fd = - Lwt_unix.accept fd >>= fun (client, sockaddr) -> - let read_buffer_size = server.read_buffer_size in - let description = string_of_sockaddr sockaddr in - Lwt.async (fun () -> - Lwt.catch (fun () -> - ( if server.disable_connection_tracking - then Lwt.return @@ register_connection_no_limit description - else register_connection description ) - >|= fun idx -> - Some (of_fd ~idx ~read_buffer_size ~description client) - ) (fun _e -> Lwt_unix.close client >|= fun () -> None) - >>= function - | None -> Lwt.return_unit - | Some flow -> - Lwt.finalize (fun () -> - log_exception_continue "Socket.Stream" (fun () -> cb flow) - ) (fun () -> close flow) - ); - loop fd - in - List.iter (fun (_idx, fd) -> - Lwt.async (fun () -> - log_exception_continue "Socket.Stream" (fun () -> - Lwt.finalize (fun () -> - Lwt_unix.listen fd (!Utils.somaxconn); - loop fd - ) (fun () -> shutdown server) - ) - ) - ) server.listening_fds - - end - - module Tcp = struct - include Fd - - type address = Ipaddr.t * int - - let connect ?read_buffer_size (ip, port) = - let description = Fmt.strf "%a:%d" Ipaddr.pp_hum ip port in - let sockaddr = - Unix.ADDR_INET (Unix.inet_addr_of_string @@ Ipaddr.to_string ip, port) - in - let pf = match ip with - | Ipaddr.V4 _ -> Lwt_unix.PF_INET - | Ipaddr.V6 _ -> Lwt_unix.PF_INET6 - in - connect description ?read_buffer_size pf Lwt_unix.SOCK_STREAM sockaddr - - let bind ?description (ip, port) = - unix_bind ?description Lwt_unix.SOCK_STREAM (ip, port) >|= make - - let getsockname server = match server.listening_fds with - | [] -> failwith "Tcp.getsockname: socket is closed" - | (_idx, fd) :: _ -> - match Lwt_unix.getsockname fd with - | Lwt_unix.ADDR_INET(iaddr, port) -> - Ipaddr.of_string_exn (Unix.string_of_inet_addr iaddr), port - | _ -> invalid_arg "Tcp.getsockname passed a non-INET socket" - end - - module Unix = struct - include Fd - - type address = string - - let is_win32 = Sys.os_type = "win32" - - let connect ?read_buffer_size path = - let description = "unix:" ^ path in - if is_win32 then - register_connection description>>= fun idx -> - Named_pipe_lwt.Client.openpipe path >|= fun p -> - let fd = Named_pipe_lwt.Client.to_fd p in - Ok (of_fd ~idx ?read_buffer_size ~description fd) - else - let sockaddr = Unix.ADDR_UNIX path in - connect description ?read_buffer_size Lwt_unix.PF_UNIX - Lwt_unix.SOCK_STREAM sockaddr - - let bind ?(description="") path = - let description = Fmt.strf "unix:%s %s" path description in - if is_win32 - then Lwt.return (make ~path []) - else - Lwt.catch - (fun () -> Lwt_unix.unlink path) - (function - | Unix.Unix_error(Unix.ENOENT, _, _) -> Lwt.return () - | e -> Lwt.fail e) - >>= fun () -> - register_connection description >>= fun idx -> - let s = Lwt_unix.socket Lwt_unix.PF_UNIX Lwt_unix.SOCK_STREAM 0 in - Lwt.catch (fun () -> - Lwt_unix.bind s (Lwt_unix.ADDR_UNIX path) >|= fun () -> - make ~path [ idx, s ] - ) (fun e -> - Lwt_unix.close s >>= fun () -> - deregister_connection idx; - Lwt.fail e - ) - - let listen server cb = - let rec loop () = - if server.closed - then Lwt.return_unit - else Lwt.catch (fun () -> - let p = Named_pipe_lwt.Server.create server.path in - Named_pipe_lwt.Server.connect p >>= fun () -> - let description = "named-pipe:" ^ server.path in - let read_buffer_size = server.read_buffer_size in - let fd = Named_pipe_lwt.Server.to_fd p in - Lwt.catch (fun () -> - (if server.disable_connection_tracking - then Lwt.return @@ register_connection_no_limit description - else register_connection description ) - >|= fun idx -> - Some (of_fd ~idx ~read_buffer_size ~description fd) - ) (fun _e -> Lwt_unix.close fd >|= fun () -> None) - >>= function - | None -> Lwt.return_unit - | Some flow -> - Lwt.async (fun () -> - Lwt.finalize (fun () -> - log_exception_continue "Socket.Stream.Unix" - (fun () -> cb flow) - ) (fun () -> close flow) - ); - loop () - ) (fun e -> - Log.err (fun f -> - f "Named-pipe connection failed on %s: %a" - server.path Fmt.exn e); - Lwt.return () - ) - in - if not is_win32 - then listen server cb - else Lwt.async (fun () -> - log_exception_continue "Socket.Stream.Unix" (fun () -> loop ()) - ) - - let getsockname server = server.path - - let unsafe_get_raw_fd t = - (* By default Lwt sets fds to non-blocking mode. Reverse this - to avoid surprising the caller. *) - Lwt_unix.set_blocking ~set_flags:true t.fd true; - Lwt_unix.unix_file_descr t.fd - - end - - end -end - -module Files = struct - - let read_file path = - Lwt.catch (fun () -> - Lwt_unix.openfile path [ Lwt_unix.O_RDONLY ] 0 >>= fun fd -> - let buffer = Buffer.create 128 in - let frag = Bytes.make 1024 ' ' in - Lwt.finalize (fun () -> - let rec loop () = - Lwt_unix.read fd frag 0 (Bytes.length frag) >>= function - | 0 -> Lwt_result.return (Buffer.contents buffer) - | n -> Buffer.add_substring buffer frag 0 n; loop () - in - loop () - ) (fun () -> Lwt_unix.close fd) - ) (fun e -> - Lwt_result.fail (`Msg (Fmt.strf "reading %s: %a" path Fmt.exn e)) - ) - - type watch = unit Lwt.t - - let watch_file path callback = - (* Poll the file every 5s seconds *) - let start () = - Lwt_unix.stat path - >>= function { Lwt_unix.st_mtime; _ } -> - callback (); - let rec poll st_mtime' = - Lwt_unix.stat path >>= fun { Lwt_unix.st_mtime; _ } -> - if st_mtime' <> st_mtime then callback (); - Lwt_unix.sleep 5. >>= fun () -> - poll st_mtime - in - poll st_mtime - in - (* On failure, wait another 5s and try again *) - let rec loop () = - Lwt.catch start (fun e -> - Log.err (fun f -> f "While watching %s: %a" path Fmt.exn e); - Lwt.return () - ) - >>= fun () -> - Lwt_unix.sleep 5. >>= fun () -> - loop () - in - Ok (loop ()) - - let unwatch = Lwt.cancel -end - -module Time = Time - -module Dns = struct - - (* FIXME: error handling completely missing *) - let getaddrinfo host domain = - let opts = [ Unix.AI_FAMILY domain ] in - let service = "" in - Lwt_unix.getaddrinfo host service opts - >>= fun x -> - Lwt.return @@ - List.fold_left (fun acc addr_info -> match addr_info.Unix.ai_addr with - | Unix.ADDR_INET(ip, _) -> - begin match Ipaddr.of_string @@ Unix.string_of_inet_addr ip with - | Some ip -> ip :: acc - | None -> acc - end - | _ -> acc - ) [] x - - let localhost_local = Dns.Name.of_string "localhost.local" - - let resolve question = - let open Dns.Packet in - begin match question with - | { q_class = Q_IN; q_name; _ } when q_name = localhost_local -> - Log.debug (fun f -> f "DNS lookup of localhost.local: return NXDomain"); - Lwt.return (q_name, []) - | { q_class = Q_IN; q_type = Q_A; q_name; _ } -> - getaddrinfo (Dns.Name.to_string q_name) Unix.PF_INET >|= fun ips -> - (q_name, ips) - | { q_class = Q_IN; q_type = Q_AAAA; q_name; _ } -> - getaddrinfo (Dns.Name.to_string q_name) Unix.PF_INET6 >|= fun ips -> - (q_name, ips) - | _ -> Lwt.return (Dns.Name.of_string "", []) - end - >>= function - | _, [] -> Lwt.return [] - | q_name, ips -> - let answers = List.map (function - | Ipaddr.V4 v4 -> - { name = q_name; cls = RR_IN; flush = false; ttl = 0l; rdata = A v4 } - | Ipaddr.V6 v6 -> - { name = q_name; cls = RR_IN; flush = false; ttl = 0l; rdata = AAAA v6 } - ) ips - in - Lwt.return answers -end - -module Main = struct - let run = Lwt_main.run - let run_in_main = Lwt_preemptive.run_in_main -end - -module Fn = struct - type ('request, 'response) t = 'request -> 'response - let create f = f - let destroy _ = () - let fn = Lwt_preemptive.detach -end diff --git a/src/hostnet/host_uwt.mli b/src/hostnet/host_uwt.mli deleted file mode 100644 index d87182be9..000000000 --- a/src/hostnet/host_uwt.mli +++ /dev/null @@ -1 +0,0 @@ -include Sig.HOST diff --git a/src/hostnet/slirp.ml b/src/hostnet/slirp.ml index 71443c9f0..08e2e2358 100644 --- a/src/hostnet/slirp.ml +++ b/src/hostnet/slirp.ml @@ -98,7 +98,6 @@ module Make val connect: unit -> t Lwt.t end) (Random: Mirage_random.C) - (Host: Sig.HOST) (Vnet : Vnetif.BACKEND with type macaddr = Macaddr.t) = struct (* module Tcpip_stack = Tcpip_stack.Make(Vmnet)(Host.Time) *) diff --git a/src/hostnet/slirp.mli b/src/hostnet/slirp.mli index f46f9c9f7..59d963ccd 100644 --- a/src/hostnet/slirp.mli +++ b/src/hostnet/slirp.mli @@ -42,7 +42,6 @@ module Make val connect: unit -> t Lwt.t end) (Random: Mirage_random.C) - (Host: Sig.HOST) (Vnet : Vnetif.BACKEND with type macaddr = Macaddr.t) : sig diff --git a/src/hostnet/vmnet.ml b/src/hostnet/vmnet.ml index f132f88b0..2cb2f83ae 100644 --- a/src/hostnet/vmnet.ml +++ b/src/hostnet/vmnet.ml @@ -464,8 +464,7 @@ module Make(C: Sig.CONN) = struct let callback buf = Lwt.catch (fun () -> t.callback buf) (function - | Host_uwt.Sockets.Too_many_connections - | Host_lwt_unix.Sockets.Too_many_connections -> + | Host.Sockets.Too_many_connections -> (* No need to log this again *) Lwt.return_unit | e -> diff --git a/src/hostnet_test/forwarding.ml b/src/hostnet_test/forwarding.ml index 02eefbd7f..7e00e83f1 100644 --- a/src/hostnet_test/forwarding.ml +++ b/src/hostnet_test/forwarding.ml @@ -11,8 +11,6 @@ let (>>*=) m f = m >>= function | Ok x -> f x | Error (`Msg m) -> failwith m -module Make(Host: Sig.HOST) = struct - let run ?(timeout=Duration.of_sec 60) t = let timeout = Host.Time.sleep_ns timeout >>= fun () -> @@ -306,4 +304,3 @@ module Make(Host: Sig.HOST) = struct `Quick, test_10_connections ]; ] -end diff --git a/src/hostnet_test/jbuild b/src/hostnet_test/jbuild index 4d9859c2d..bfd34c544 100644 --- a/src/hostnet_test/jbuild +++ b/src/hostnet_test/jbuild @@ -1,10 +1,10 @@ (jbuild_version 1) -(executables - ((names (main_lwt main_uwt)) - (libraries ( +(executable + ((name main) + (libraries ( hostnet cmdliner alcotest lwt.unix logs.fmt protocol-9p mirage-dns lwt.preemptive uwt.preemptive mirage-clock-unix charrua-client-mirage )) - (preprocess no_preprocessing))) + (preprocess no_preprocessing))) diff --git a/src/hostnet_test/main_uwt.ml b/src/hostnet_test/main.ml similarity index 93% rename from src/hostnet_test/main_uwt.ml rename to src/hostnet_test/main.ml index 9addea7a0..413605640 100644 --- a/src/hostnet_test/main_uwt.ml +++ b/src/hostnet_test/main.ml @@ -5,8 +5,6 @@ let src = module Log = (val Logs.src_log src : Logs.LOG) -module Tests = Suite.Make(Host_uwt) - let ppf, flush = let b = Buffer.create 255 in let flush () = let s = Buffer.contents b in Buffer.clear b; s in @@ -38,4 +36,4 @@ let () = Printf.fprintf stderr "Starting test case %s\n%!" case; fn () ) cases - ) (Tests.tests @ Tests.scalability) + ) (Suite.tests @ Suite.scalability) diff --git a/src/hostnet_test/main_lwt.ml b/src/hostnet_test/main_lwt.ml deleted file mode 100644 index 84384c78f..000000000 --- a/src/hostnet_test/main_lwt.ml +++ /dev/null @@ -1,19 +0,0 @@ -let src = - let src = Logs.Src.create "test" ~doc:"Test the slirp stack" in - Logs.Src.set_level src (Some Logs.Debug); - src - -module Log = (val Logs.src_log src : Logs.LOG) - -module Tests = Suite.Make(Host_lwt_unix) - -(* Run it *) -let () = - Logs.set_reporter (Logs_fmt.reporter ()); - Lwt.async_exception_hook := (fun exn -> - Log.err (fun f -> f "Lwt.async failure %s: %s" - (Printexc.to_string exn) - (Printexc.get_backtrace ()) - ) - ); - Alcotest.run "Hostnet" Tests.tests diff --git a/src/hostnet_test/slirp_stack.ml b/src/hostnet_test/slirp_stack.ml index 769e2f8d9..fba5f534f 100644 --- a/src/hostnet_test/slirp_stack.ml +++ b/src/hostnet_test/slirp_stack.ml @@ -62,12 +62,11 @@ module Dns_policy = struct end -module Make(Host: Sig.HOST) = struct module VMNET = Vmnet.Make(Host.Sockets.Stream.Tcp) module Config = Active_config.Make(Host.Time)(Host.Sockets.Stream.Unix) module Vnet = Basic_backend.Make module Slirp_stack = - Slirp.Make(Config)(VMNET)(Dns_policy)(Mclock)(Stdlibrandom)(Host)(Vnet) + Slirp.Make(Config)(VMNET)(Dns_policy)(Mclock)(Stdlibrandom)(Vnet) module Client = struct module Netif = VMNET @@ -206,4 +205,3 @@ module Make(Host: Sig.HOST) = struct (* Server will close when it gets EOF *) VMNET.disconnect client' ) -end diff --git a/src/hostnet_test/suite.ml b/src/hostnet_test/suite.ml index c227417fc..e6a2148a9 100644 --- a/src/hostnet_test/suite.ml +++ b/src/hostnet_test/suite.ml @@ -1,4 +1,5 @@ open Lwt.Infix +open Slirp_stack let src = let src = Logs.Src.create "test" ~doc:"Test the slirp stack" in @@ -10,12 +11,6 @@ module Log = (val Logs.src_log src : Logs.LOG) let pp_ips = Fmt.(list ~sep:(unit ", ") Ipaddr.pp_hum) let pp_ip4s = Fmt.(list ~sep:(unit ", ") Ipaddr.V4.pp_hum) -module Make(Host: Sig.HOST) = struct - - module Dns_policy = Slirp_stack.Dns_policy - module Slirp_stack = Slirp_stack.Make(Host) - open Slirp_stack - let run_test ?(timeout=Duration.of_sec 60) t = let timeout = Host.Time.sleep_ns timeout >>= fun () -> @@ -386,20 +381,13 @@ module Make(Host: Sig.HOST) = struct *) ] - module F = Forwarding.Make(Host) - module N = Test_nat.Make(Host) - module H = Test_http.Make(Host) - module T = Test_half_close.Make(Host) - let tests = - Hosts_test.tests @ F.tests @ test_dhcp + Hosts_test.tests @ Forwarding.tests @ test_dhcp @ (test_dns true) @ (test_dns false) - @ test_tcp @ N.tests @ H.tests @ Test_http.Exclude.tests - @ T.tests + @ test_tcp @ Test_nat.tests @ Test_http.tests @ Test_http.Exclude.tests + @ Test_half_close.tests let scalability = [ "1026conns", [ "Test many connections", `Quick, test_many_connections (1024 + 2) ]; ] - -end diff --git a/src/hostnet_test/test_half_close.ml b/src/hostnet_test/test_half_close.ml index 9f3c84b37..f6c6dc99e 100644 --- a/src/hostnet_test/test_half_close.ml +++ b/src/hostnet_test/test_half_close.ml @@ -9,10 +9,6 @@ let failf fmt = Fmt.kstrf failwith fmt module Log = (val Logs.src_log src : Logs.LOG) -module Make(Host: Sig.HOST) = struct - - module Slirp_stack = Slirp_stack.Make(Host) - module Server = struct type t = { server: Host.Sockets.Stream.Tcp.server; @@ -182,5 +178,3 @@ module Make(Host: Sig.HOST) = struct test_host_half_close ]; ] - -end diff --git a/src/hostnet_test/test_http.ml b/src/hostnet_test/test_http.ml index fd09124fd..730eb97cc 100644 --- a/src/hostnet_test/test_http.ml +++ b/src/hostnet_test/test_http.ml @@ -71,10 +71,6 @@ module Exclude = struct ] end -module Make(Host: Sig.HOST) = struct - - module Slirp_stack = Slirp_stack.Make(Host) - module Server = struct type t = { server: Host.Sockets.Stream.Tcp.server; @@ -330,5 +326,3 @@ module Make(Host: Sig.HOST) = struct "HTTP: CONNECT", [ "check that HTTP CONNECT works for HTTPS", `Quick, test_http_connect ]; ] - -end diff --git a/src/hostnet_test/test_nat.ml b/src/hostnet_test/test_nat.ml index df3f9701d..febfc48c6 100644 --- a/src/hostnet_test/test_nat.ml +++ b/src/hostnet_test/test_nat.ml @@ -1,4 +1,5 @@ open Lwt.Infix +open Slirp_stack let src = let src = Logs.Src.create "test" ~doc:"Test the slirp stack" in @@ -7,8 +8,6 @@ let src = module Log = (val Logs.src_log src : Logs.LOG) -module Make(Host: Sig.HOST) = struct - let run ?(timeout=Duration.of_sec 60) t = let timeout = Host.Time.sleep_ns timeout >>= fun () -> @@ -16,9 +15,6 @@ module Make(Host: Sig.HOST) = struct in Host.Main.run @@ Lwt.pick [ timeout; t ] - module Slirp_stack = Slirp_stack.Make(Host) - open Slirp_stack - module EchoServer = struct (* Receive UDP packets and copy them back to all senders. Roughly simulates a chat protocol, in particular this allows us to test many replies to one @@ -382,4 +378,3 @@ module Make(Host: Sig.HOST) = struct "NAT: punch", [ "", `Quick, test_nat_punch ]; "NAT: source ports", [ "", `Quick, test_source_ports ]; ] -end From d5b3934908d8a45e20fb0249e52060c66cf411a4 Mon Sep 17 00:00:00 2001 From: Thomas Gazagnaire Date: Thu, 3 Aug 2017 17:39:52 +0200 Subject: [PATCH 2/2] Run ocp-indent -i on all the source code --- src/hostnet/arp.ml | 6 +- src/hostnet/host.ml | 38 +- src/hostnet/mux.ml | 10 +- src/hostnet/sig.ml | 6 +- src/hostnet/slirp.ml | 2 +- src/hostnet/vmnet.ml | 2 +- src/hostnet_test/forwarding.ml | 564 +++++++++++------------ src/hostnet_test/main.ml | 2 +- src/hostnet_test/slirp_stack.ml | 284 ++++++------ src/hostnet_test/suite.ml | 670 +++++++++++++-------------- src/hostnet_test/test_half_close.ml | 330 +++++++------- src/hostnet_test/test_http.ml | 454 +++++++++---------- src/hostnet_test/test_nat.ml | 676 ++++++++++++++-------------- 13 files changed, 1522 insertions(+), 1522 deletions(-) diff --git a/src/hostnet/arp.ml b/src/hostnet/arp.ml index 9a2534323..887365512 100644 --- a/src/hostnet/arp.ml +++ b/src/hostnet/arp.ml @@ -143,7 +143,7 @@ module Make (Ethif: Mirage_protocols_lwt.ETHIF) = struct f "error while reading ARP packet: %a" Ethif.pp_error e); end else Lwt.return_unit |2 -> (* Reply *) - (* the requested address *) + (* the requested address *) let spa = Ipaddr.V4.of_int32 (get_arp_tpa frame) in Log.debug (fun f -> f "ARP ignoring reply %s" (Ipaddr.V4.to_string spa)); Lwt.return_unit @@ -156,8 +156,8 @@ module Make (Ethif: Mirage_protocols_lwt.ETHIF) = struct let connect ~table ethif = let table = List.fold_left (fun acc (ip, mac) -> - Table.add ip mac acc - ) Table.empty table + Table.add ip mac acc + ) Table.empty table in { table; ethif } diff --git a/src/hostnet/host.ml b/src/hostnet/host.ml index 45fc51114..08ce39370 100644 --- a/src/hostnet/host.ml +++ b/src/hostnet/host.ml @@ -409,30 +409,30 @@ module Sockets = struct let connect ?(read_buffer_size = default_read_buffer_size) (ip, port) = let description = Fmt.strf "tcp:%a:%d" Ipaddr.pp_hum ip port in let label = match ip with - | Ipaddr.V4 _ -> "TCPv4" - | Ipaddr.V6 _ -> "TCPv6" in + | Ipaddr.V4 _ -> "TCPv4" + | Ipaddr.V6 _ -> "TCPv6" in register_connection_noexn description >>= function | None -> errorf "Socket.%s.connect %s: hit connection limit" label description | Some idx -> - let fd = - try match ip with - | Ipaddr.V4 _ -> Uwt.Tcp.init_ipv4_exn () - | Ipaddr.V6 _ -> Uwt.Tcp.init_ipv6_exn () - with e -> deregister_connection idx; raise e in - Lwt.catch (fun () -> - let sockaddr = make_sockaddr (ip, port) in - Uwt.Tcp.connect fd ~addr:sockaddr >>= fun () -> - of_fd ~idx ~label ~read_buffer_size ~description fd - |> Lwt_result.return - ) (fun e -> - deregister_connection idx; - log_exception_continue "Tcp.connect Uwt.Tcp.close_wait" - (fun () -> Uwt.Tcp.close_wait fd) - >>= fun () -> - errorf "Socket.%s.connect %s: caught %a" label description Fmt.exn e - ) + let fd = + try match ip with + | Ipaddr.V4 _ -> Uwt.Tcp.init_ipv4_exn () + | Ipaddr.V6 _ -> Uwt.Tcp.init_ipv6_exn () + with e -> deregister_connection idx; raise e in + Lwt.catch (fun () -> + let sockaddr = make_sockaddr (ip, port) in + Uwt.Tcp.connect fd ~addr:sockaddr >>= fun () -> + of_fd ~idx ~label ~read_buffer_size ~description fd + |> Lwt_result.return + ) (fun e -> + deregister_connection idx; + log_exception_continue "Tcp.connect Uwt.Tcp.close_wait" + (fun () -> Uwt.Tcp.close_wait fd) + >>= fun () -> + errorf "Socket.%s.connect %s: caught %a" label description Fmt.exn e + ) let shutdown_read _ = Lwt.return () diff --git a/src/hostnet/mux.ml b/src/hostnet/mux.ml index 6d8889774..dd1e98083 100644 --- a/src/hostnet/mux.ml +++ b/src/hostnet/mux.ml @@ -93,11 +93,11 @@ module Make (Netif: Mirage_net_lwt.S) = struct let t = { netif; rules; default_callback } in Lwt.async (fun () -> - Netif.listen netif @@ callback t >>= function - | Ok () -> Lwt.return_unit - | Error _e -> - Log.err (fun f -> f "Mux.connect calling Netif.listen: failed"); - Lwt.return_unit + Netif.listen netif @@ callback t >>= function + | Ok () -> Lwt.return_unit + | Error _e -> + Log.err (fun f -> f "Mux.connect calling Netif.listen: failed"); + Lwt.return_unit ); Lwt.return (Ok t) diff --git a/src/hostnet/sig.ml b/src/hostnet/sig.ml index 8f9f48f04..1eb1be917 100644 --- a/src/hostnet/sig.ml +++ b/src/hostnet/sig.ml @@ -4,7 +4,7 @@ module type READ_INTO = sig val read_into: flow -> Cstruct.t -> (unit Mirage_flow.or_eof, error) result Lwt.t - (** Completely fills the given buffer with data from [fd] *) + (** Completely fills the given buffer with data from [fd] *) end module type FLOW_CLIENT = sig @@ -14,8 +14,8 @@ module type FLOW_CLIENT = sig val connect: ?read_buffer_size:int -> address -> (flow, [`Msg of string]) result Lwt.t - (** [connect address] creates a connection to [address] and returns - he connected flow. *) + (** [connect address] creates a connection to [address] and returns + he connected flow. *) end module type CONN = sig diff --git a/src/hostnet/slirp.ml b/src/hostnet/slirp.ml index 08e2e2358..94122bae2 100644 --- a/src/hostnet/slirp.ml +++ b/src/hostnet/slirp.ml @@ -102,7 +102,7 @@ module Make struct (* module Tcpip_stack = Tcpip_stack.Make(Vmnet)(Host.Time) *) -module Filteredif = Filter.Make(Vmnet) + module Filteredif = Filter.Make(Vmnet) module Netif = Capture.Make(Filteredif) module Recorder = (Netif: Sig.RECORDER with type t = Netif.t) module Switch = Mux.Make(Netif) diff --git a/src/hostnet/vmnet.ml b/src/hostnet/vmnet.ml index 2cb2f83ae..a8dc614c7 100644 --- a/src/hostnet/vmnet.ml +++ b/src/hostnet/vmnet.ml @@ -331,7 +331,7 @@ module Make(C: Sig.CONN) = struct server_negotiate ~fd:channel ~client_macaddr_of_uuid ~mtu >>= fun (client_uuid, client_macaddr) -> let t = make ~client_macaddr ~server_macaddr ~mtu ~client_uuid - ~log_prefix:server_log_prefix channel in + ~log_prefix:server_log_prefix channel in Lwt_result.return t let client_of_fd ~uuid ~server_macaddr flow = diff --git a/src/hostnet_test/forwarding.ml b/src/hostnet_test/forwarding.ml index 7e00e83f1..ce961ad7c 100644 --- a/src/hostnet_test/forwarding.ml +++ b/src/hostnet_test/forwarding.ml @@ -11,296 +11,296 @@ let (>>*=) m f = m >>= function | Ok x -> f x | Error (`Msg m) -> failwith m - let run ?(timeout=Duration.of_sec 60) t = - let timeout = - Host.Time.sleep_ns timeout >>= fun () -> - Lwt.fail_with "timeout" - in - Host.Main.run @@ Lwt.pick [ timeout; t ] - - module Channel = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) - - module ForwardServer = struct - (** Accept connections, read the forwarding header and run a proxy *) - - module Proxy = - Mirage_flow_lwt.Proxy - (Mclock)(Host.Sockets.Stream.Tcp)(Host.Sockets.Stream.Tcp) - - let accept flow = - let sizeof = 1 + 2 + 4 + 2 in - let header = Cstruct.create sizeof in - Host.Sockets.Stream.Tcp.read_into flow header >>= function - | Ok `Eof -> failwith "EOF" - | Error e -> Fmt.kstrf failwith "%a" Host.Sockets.Stream.Tcp.pp_error e - | Ok (`Data ()) -> - let ip_len = Cstruct.LE.get_uint16 header 1 in - let ip = - let bytes = Cstruct.(to_string @@ sub header 3 ip_len) in - if String.length bytes = 4 - then Ipaddr.V4.of_bytes_exn bytes - else assert false in (* IPv4 only *) - let port = Cstruct.LE.get_uint16 header 7 in - assert (Cstruct.get_uint8 header 0 == 1); (* TCP only *) - - Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 ip, port) >>= function - | Error (`Msg x) -> failwith x - | Ok remote -> - Mclock.connect () >>= fun clock -> - Lwt.finalize (fun () -> - Proxy.proxy clock flow remote >>= function - | Error e -> Fmt.kstrf failwith "%a" Proxy.pp_error e - | Ok (_l_stats, _r_stats) -> Lwt.return () - ) (fun () -> - Host.Sockets.Stream.Tcp.close remote - ) +let run ?(timeout=Duration.of_sec 60) t = + let timeout = + Host.Time.sleep_ns timeout >>= fun () -> + Lwt.fail_with "timeout" + in + Host.Main.run @@ Lwt.pick [ timeout; t ] + +module Channel = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) + +module ForwardServer = struct + (** Accept connections, read the forwarding header and run a proxy *) + + module Proxy = + Mirage_flow_lwt.Proxy + (Mclock)(Host.Sockets.Stream.Tcp)(Host.Sockets.Stream.Tcp) + + let accept flow = + let sizeof = 1 + 2 + 4 + 2 in + let header = Cstruct.create sizeof in + Host.Sockets.Stream.Tcp.read_into flow header >>= function + | Ok `Eof -> failwith "EOF" + | Error e -> Fmt.kstrf failwith "%a" Host.Sockets.Stream.Tcp.pp_error e + | Ok (`Data ()) -> + let ip_len = Cstruct.LE.get_uint16 header 1 in + let ip = + let bytes = Cstruct.(to_string @@ sub header 3 ip_len) in + if String.length bytes = 4 + then Ipaddr.V4.of_bytes_exn bytes + else assert false in (* IPv4 only *) + let port = Cstruct.LE.get_uint16 header 7 in + assert (Cstruct.get_uint8 header 0 == 1); (* TCP only *) + + Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 ip, port) >>= function + | Error (`Msg x) -> failwith x + | Ok remote -> + Mclock.connect () >>= fun clock -> + Lwt.finalize (fun () -> + Proxy.proxy clock flow remote >>= function + | Error e -> Fmt.kstrf failwith "%a" Proxy.pp_error e + | Ok (_l_stats, _r_stats) -> Lwt.return () + ) (fun () -> + Host.Sockets.Stream.Tcp.close remote + ) - let port = - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) - >>= fun server -> - let _, local_port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server accept; - Lwt.return local_port - - type t = { - local_port: int; - server: Host.Sockets.Stream.Tcp.server; - } - end - - module Forward = Forward.Make(Mclock)(struct - include Host.Sockets.Stream.Tcp - - open Lwt.Infix - - let connect () = - ForwardServer.port >>= fun port -> - Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 Ipaddr.V4.localhost, port) - >>= function - | Error (`Msg m) -> failwith m - | Ok x -> Lwt.return x - end)(Host.Sockets) - - let localhost = Ipaddr.V4.localhost - - module PortsServer = struct - module Ports = Active_list.Make(Forward) - module Server = Protocol_9p.Server.Make(Log)(Host.Sockets.Stream.Tcp)(Ports) - - let with_server f = - Mclock.connect () >>= fun clock -> - let ports = Ports.make clock in - Ports.set_context ports ""; - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 localhost, 0) - >>= fun server -> - let _, port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server - (fun conn -> - Server.connect ports conn () - >>= function - | Error (`Msg m) -> - Log.err (fun f -> f "failed to establish 9P connection: %s" m); - Lwt.return () - | Ok server -> - Server.after_disconnect server - ); - f port - >>= fun () -> - Host.Sockets.Stream.Tcp.shutdown server - end - - module LocalClient = struct - let connect (ip, port) = - Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 ip, port) - >>= function - | Ok fd -> Lwt.return fd - | Error (`Msg m) -> failwith m - let disconnect fd = Host.Sockets.Stream.Tcp.close fd - end - - let read_http ch = - let rec loop acc = - Channel.read_line ch >>= function - | Ok `Eof - | Error _ -> Lwt.return acc - | Ok (`Data bufs) -> - let txt = Cstruct.(to_string (concat bufs)) in - if txt = "" - then Lwt.return acc - else loop (acc ^ txt) - in - loop "" - - module LocalServer = struct - type t = { - local_port: int; - server: Host.Sockets.Stream.Tcp.server; - } - - let accept flow = - let ch = Channel.create flow in - read_http ch >>= fun request -> - if not(Astring.String.is_prefix ~affix:"GET" request) - then failwith (Printf.sprintf "unrecognised HTTP GET: [%s]" request); - let response = "HTTP/1.0 404 Not found\r\ncontent-length: 0\r\n\r\n" in - Channel.write_string ch response 0 (String.length response); - Channel.flush ch >|= function - | Ok () -> () - | Error e -> Fmt.kstrf failwith "%a" Channel.pp_write_error e - - let create () = - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 localhost, 0) - >|= fun server -> - let _, local_port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server accept; - { local_port; server } - - let to_string t = Printf.sprintf "tcp:127.0.0.1:%d" t.local_port - let destroy t = Host.Sockets.Stream.Tcp.shutdown t.server - let with_server f = - create () >>= fun server -> - Lwt.finalize (fun () -> f server) (fun () -> destroy server) - end - - module ForwardControl = struct - module Log = (val Logs.src_log Logs.default) - module Client = Protocol_9p.Client.Make(Log)(Host.Sockets.Stream.Tcp) - - type t = { - ninep: Client.t - } - - let connect ports_port = - Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 localhost, ports_port) + let port = + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) + >>= fun server -> + let _, local_port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server accept; + Lwt.return local_port + + type t = { + local_port: int; + server: Host.Sockets.Stream.Tcp.server; + } +end + +module Forward = Forward.Make(Mclock)(struct + include Host.Sockets.Stream.Tcp + + open Lwt.Infix + + let connect () = + ForwardServer.port >>= fun port -> + Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 Ipaddr.V4.localhost, port) >>= function | Error (`Msg m) -> failwith m - | Ok flow -> - Client.connect flow () >>*= fun ninep -> - Lwt.return { ninep } - - let disconnect { ninep } = Client.disconnect ninep - - let with_connection ports_port f = - connect ports_port >>= fun c -> - Lwt.finalize (fun () -> f c) (fun () -> disconnect c) - - type forward = { - t: t; - fid: Protocol_9p.Types.Fid.t; - ip: Ipaddr.V4.t; - port: int; - } - - let create t string = - let mode = Protocol_9p.Types.FileMode.make ~is_directory:true - ~owner:[`Read; `Write; `Execute] ~group:[`Read; `Execute] - ~other:[`Read; `Execute ] () in - Client.mkdir t.ninep [] string mode - >>*= fun () -> - Client.LowLevel.allocate_fid t.ninep - >>*= fun fid -> - Client.walk_from_root t.ninep fid [ string; "ctl" ] - >>*= fun _walk -> - Client.LowLevel.openfid t.ninep fid Protocol_9p.Types.OpenMode.read_write - >>*= fun _open -> - let buf = Cstruct.create (String.length string) in - Cstruct.blit_from_string string 0 buf 0 (String.length string); - Client.LowLevel.write t.ninep fid 0L buf - >>*= fun _write -> - Client.LowLevel.read t.ninep fid 0L 1024l - >>*= fun read -> - let response = Cstruct.to_string read.Protocol_9p.Response.Read.data in - if Astring.String.is_prefix ~affix:"OK " response then begin - let line = String.sub response 3 (String.length response - 3) in - (* tcp:127.0.0.1:64500:tcp:127.0.0.1:64499 *) - match Astring.String.cuts ~sep:":" line with - | "tcp" :: ip :: port :: _ -> - let port = int_of_string port in - let ip = Ipaddr.V4.of_string_exn ip in - Lwt.return { t; fid; ip; port } - | _ -> failwith ("failed to parse response: " ^ line) - end else failwith response - let destroy { t; fid; _ } = - Client.LowLevel.clunk t.ninep fid - >>*= fun _clunk -> - Lwt.return () - let with_forward t string f = - create t string - >>= fun forward -> - Lwt.finalize (fun () -> f forward.ip forward.port) (fun () -> destroy forward) - end - - let http_get flow = + | Ok x -> Lwt.return x + end)(Host.Sockets) + +let localhost = Ipaddr.V4.localhost + +module PortsServer = struct + module Ports = Active_list.Make(Forward) + module Server = Protocol_9p.Server.Make(Log)(Host.Sockets.Stream.Tcp)(Ports) + + let with_server f = + Mclock.connect () >>= fun clock -> + let ports = Ports.make clock in + Ports.set_context ports ""; + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 localhost, 0) + >>= fun server -> + let _, port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server + (fun conn -> + Server.connect ports conn () + >>= function + | Error (`Msg m) -> + Log.err (fun f -> f "failed to establish 9P connection: %s" m); + Lwt.return () + | Ok server -> + Server.after_disconnect server + ); + f port + >>= fun () -> + Host.Sockets.Stream.Tcp.shutdown server +end + +module LocalClient = struct + let connect (ip, port) = + Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 ip, port) + >>= function + | Ok fd -> Lwt.return fd + | Error (`Msg m) -> failwith m + let disconnect fd = Host.Sockets.Stream.Tcp.close fd +end + +let read_http ch = + let rec loop acc = + Channel.read_line ch >>= function + | Ok `Eof + | Error _ -> Lwt.return acc + | Ok (`Data bufs) -> + let txt = Cstruct.(to_string (concat bufs)) in + if txt = "" + then Lwt.return acc + else loop (acc ^ txt) + in + loop "" + +module LocalServer = struct + type t = { + local_port: int; + server: Host.Sockets.Stream.Tcp.server; + } + + let accept flow = let ch = Channel.create flow in - let message = "GET / HTTP/1.0\r\nconnection: close\r\n\r\n" in - Channel.write_string ch message 0 (String.length message); - Channel.flush ch >>= function + read_http ch >>= fun request -> + if not(Astring.String.is_prefix ~affix:"GET" request) + then failwith (Printf.sprintf "unrecognised HTTP GET: [%s]" request); + let response = "HTTP/1.0 404 Not found\r\ncontent-length: 0\r\n\r\n" in + Channel.write_string ch response 0 (String.length response); + Channel.flush ch >|= function + | Ok () -> () | Error e -> Fmt.kstrf failwith "%a" Channel.pp_write_error e - | Ok () -> - Host.Sockets.Stream.Tcp.shutdown_write flow - >>= fun () -> - read_http ch - >|= fun response -> - if not(Astring.String.is_prefix ~affix:"HTTP" response) - then failwith (Printf.sprintf "unrecognised HTTP response: [%s]" response) - - let test_one_forward () = - let t = LocalServer.with_server (fun server -> - PortsServer.with_server (fun ports_port -> - ForwardControl.with_connection ports_port (fun connection -> - let name = "tcp:127.0.0.1:0:" ^ LocalServer.to_string server in - ForwardControl.with_forward connection name (fun ip port -> + + let create () = + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 localhost, 0) + >|= fun server -> + let _, local_port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server accept; + { local_port; server } + + let to_string t = Printf.sprintf "tcp:127.0.0.1:%d" t.local_port + let destroy t = Host.Sockets.Stream.Tcp.shutdown t.server + let with_server f = + create () >>= fun server -> + Lwt.finalize (fun () -> f server) (fun () -> destroy server) +end + +module ForwardControl = struct + module Log = (val Logs.src_log Logs.default) + module Client = Protocol_9p.Client.Make(Log)(Host.Sockets.Stream.Tcp) + + type t = { + ninep: Client.t + } + + let connect ports_port = + Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 localhost, ports_port) + >>= function + | Error (`Msg m) -> failwith m + | Ok flow -> + Client.connect flow () >>*= fun ninep -> + Lwt.return { ninep } + + let disconnect { ninep } = Client.disconnect ninep + + let with_connection ports_port f = + connect ports_port >>= fun c -> + Lwt.finalize (fun () -> f c) (fun () -> disconnect c) + + type forward = { + t: t; + fid: Protocol_9p.Types.Fid.t; + ip: Ipaddr.V4.t; + port: int; + } + + let create t string = + let mode = Protocol_9p.Types.FileMode.make ~is_directory:true + ~owner:[`Read; `Write; `Execute] ~group:[`Read; `Execute] + ~other:[`Read; `Execute ] () in + Client.mkdir t.ninep [] string mode + >>*= fun () -> + Client.LowLevel.allocate_fid t.ninep + >>*= fun fid -> + Client.walk_from_root t.ninep fid [ string; "ctl" ] + >>*= fun _walk -> + Client.LowLevel.openfid t.ninep fid Protocol_9p.Types.OpenMode.read_write + >>*= fun _open -> + let buf = Cstruct.create (String.length string) in + Cstruct.blit_from_string string 0 buf 0 (String.length string); + Client.LowLevel.write t.ninep fid 0L buf + >>*= fun _write -> + Client.LowLevel.read t.ninep fid 0L 1024l + >>*= fun read -> + let response = Cstruct.to_string read.Protocol_9p.Response.Read.data in + if Astring.String.is_prefix ~affix:"OK " response then begin + let line = String.sub response 3 (String.length response - 3) in + (* tcp:127.0.0.1:64500:tcp:127.0.0.1:64499 *) + match Astring.String.cuts ~sep:":" line with + | "tcp" :: ip :: port :: _ -> + let port = int_of_string port in + let ip = Ipaddr.V4.of_string_exn ip in + Lwt.return { t; fid; ip; port } + | _ -> failwith ("failed to parse response: " ^ line) + end else failwith response + let destroy { t; fid; _ } = + Client.LowLevel.clunk t.ninep fid + >>*= fun _clunk -> + Lwt.return () + let with_forward t string f = + create t string + >>= fun forward -> + Lwt.finalize (fun () -> f forward.ip forward.port) (fun () -> destroy forward) +end + +let http_get flow = + let ch = Channel.create flow in + let message = "GET / HTTP/1.0\r\nconnection: close\r\n\r\n" in + Channel.write_string ch message 0 (String.length message); + Channel.flush ch >>= function + | Error e -> Fmt.kstrf failwith "%a" Channel.pp_write_error e + | Ok () -> + Host.Sockets.Stream.Tcp.shutdown_write flow + >>= fun () -> + read_http ch + >|= fun response -> + if not(Astring.String.is_prefix ~affix:"HTTP" response) + then failwith (Printf.sprintf "unrecognised HTTP response: [%s]" response) + +let test_one_forward () = + let t = LocalServer.with_server (fun server -> + PortsServer.with_server (fun ports_port -> + ForwardControl.with_connection ports_port (fun connection -> + let name = "tcp:127.0.0.1:0:" ^ LocalServer.to_string server in + ForwardControl.with_forward connection name (fun ip port -> + LocalClient.connect (ip, port) + >>= fun client -> + http_get client + >>= fun () -> + LocalClient.disconnect client + ) + ) + ) + ) in + run t + +let test_10_connections () = + let t = LocalServer.with_server (fun server -> + PortsServer.with_server (fun ports_port -> + ForwardControl.with_connection ports_port (fun connection -> + let name = "tcp:127.0.0.1:0:" ^ LocalServer.to_string server in + ForwardControl.with_forward connection name (fun ip port -> + let rec loop = function + | 0 -> Lwt.return () + | n -> LocalClient.connect (ip, port) >>= fun client -> http_get client >>= fun () -> LocalClient.disconnect client - ) - ) - ) - ) in - run t - - let test_10_connections () = - let t = LocalServer.with_server (fun server -> - PortsServer.with_server (fun ports_port -> - ForwardControl.with_connection ports_port (fun connection -> - let name = "tcp:127.0.0.1:0:" ^ LocalServer.to_string server in - ForwardControl.with_forward connection name (fun ip port -> - let rec loop = function - | 0 -> Lwt.return () - | n -> - LocalClient.connect (ip, port) - >>= fun client -> - http_get client - >>= fun () -> - LocalClient.disconnect client - >>= fun () -> - loop (n - 1) - in - let start = Unix.gettimeofday () in - loop 10 >>= fun () -> - let time = Unix.gettimeofday () -. start in - (* NOTE(djs55): on my MBP this is almost immediate *) - if time > 1. then - Fmt.kstrf failwith "10 connections took %.02f (> 1) \ - seconds" time; - Lwt.return () - ) - ) - ) - ) in - run t - - let tests = [ - "Ports: 1 port forward", - [ "Perform an HTTP GET through a port forward", - `Quick, - test_one_forward ]; - - "Ports: 10 port forwards", - [ "Perform 10 HTTP GETs through a port forward", - `Quick, - test_10_connections ]; - ] + loop (n - 1) + in + let start = Unix.gettimeofday () in + loop 10 + >>= fun () -> + let time = Unix.gettimeofday () -. start in + (* NOTE(djs55): on my MBP this is almost immediate *) + if time > 1. then + Fmt.kstrf failwith "10 connections took %.02f (> 1) \ + seconds" time; + Lwt.return () + ) + ) + ) + ) in + run t + +let tests = [ + "Ports: 1 port forward", + [ "Perform an HTTP GET through a port forward", + `Quick, + test_one_forward ]; + + "Ports: 10 port forwards", + [ "Perform 10 HTTP GETs through a port forward", + `Quick, + test_10_connections ]; +] diff --git a/src/hostnet_test/main.ml b/src/hostnet_test/main.ml index 413605640..7fe7efdbf 100644 --- a/src/hostnet_test/main.ml +++ b/src/hostnet_test/main.ml @@ -18,7 +18,7 @@ let reporter = msgf @@ fun ?header:_ ?tags:_ fmt -> let t = Unix.gettimeofday () -. start in Format.kfprintf k ppf ("%.5f [%a] @[" ^^ fmt ^^ "@]@.") t Logs.pp_level level in - { Logs.report } + { Logs.report } (* Run it *) let () = diff --git a/src/hostnet_test/slirp_stack.ml b/src/hostnet_test/slirp_stack.ml index fba5f534f..add48cbb3 100644 --- a/src/hostnet_test/slirp_stack.ml +++ b/src/hostnet_test/slirp_stack.ml @@ -62,146 +62,146 @@ module Dns_policy = struct end - module VMNET = Vmnet.Make(Host.Sockets.Stream.Tcp) - module Config = Active_config.Make(Host.Time)(Host.Sockets.Stream.Unix) - module Vnet = Basic_backend.Make - module Slirp_stack = - Slirp.Make(Config)(VMNET)(Dns_policy)(Mclock)(Stdlibrandom)(Vnet) - - module Client = struct - module Netif = VMNET - module Ethif1 = Ethif.Make(Netif) - module Arpv41 = Arpv4.Make(Ethif1)(Mclock)(Host.Time) - - module Dhcp_client_mirage1 = Dhcp_client_mirage.Make(Host.Time)(Netif) - module Ipv41 = Dhcp_ipv4.Make(Dhcp_client_mirage1)(Ethif1)(Arpv41) - module Icmpv41 = Icmpv4.Make(Ipv41) - module Udp1 = Udp.Make(Ipv41)(Stdlibrandom) - module Tcp1 = Tcp.Flow.Make(Ipv41)(Host.Time)(Mclock)(Stdlibrandom) - include Tcpip_stack_direct.Make(Host.Time) - (Stdlibrandom)(Netif)(Ethif1)(Arpv41)(Ipv41)(Icmpv41)(Udp1)(Tcp1) - - let or_error name m = - m >>= function - | `Error _ -> Fmt.kstrf failwith "Failed to connect %s device" name - | `Ok x -> Lwt.return x - - let connect (interface: VMNET.t) = - Ethif1.connect interface >>= fun ethif -> - Mclock.connect () >>= fun clock -> - Arpv41.connect ethif clock >>= fun arp -> - Dhcp_client_mirage1.connect interface >>= fun dhcp -> - Ipv41.connect dhcp ethif arp >>= fun ipv4 -> - Icmpv41.connect ipv4 >>= fun icmpv4 -> - Udp1.connect ipv4 >>= fun udp4 -> - Tcp1.connect ipv4 clock >>= fun tcp4 -> - let cfg = { - Mirage_stack_lwt.name = "stackv4_ip"; - interface; - } in - connect cfg ethif arp ipv4 icmpv4 udp4 tcp4 - >>= fun stack -> - Lwt.return stack - end - - module DNS = Dns_resolver_mirage.Make(Host.Time)(Client) - - let primary_dns_ip = Ipaddr.V4.of_string_exn "192.168.65.1" - - let extra_dns_ip = List.map Ipaddr.V4.of_string_exn [ - "192.168.65.3"; "192.168.65.4"; "192.168.65.5"; "192.168.65.6"; - "192.168.65.7"; "192.168.65.8"; "192.168.65.9"; "192.168.65.10"; - ] - - let peer_ip = Ipaddr.V4.of_string_exn "192.168.65.2" - let local_ip = Ipaddr.V4.of_string_exn "192.168.65.1" - let highest_ip = Ipaddr.V4.of_string_exn "192.168.65.254" - let server_macaddr = Slirp.default_server_macaddr - - let global_arp_table : Slirp.arp_table = - { Slirp.mutex = Lwt_mutex.create (); - table = [(local_ip, Slirp.default_server_macaddr)] - } - - let client_uuids : Slirp.uuid_table = - { Slirp.mutex = Lwt_mutex.create (); - table = Hashtbl.create 50; - } - - let config_without_bridge = - Mclock.connect () >|= fun clock -> - { - Slirp.peer_ip; - local_ip; - highest_ip; - extra_dns_ip; - server_macaddr; - get_domain_search = (fun () -> []); - get_domain_name = (fun () -> "local"); - client_uuids; - bridge_connections = false; - global_arp_table; - mtu = 1500; - host_names = []; - clock; - } - - (* This is a hacky way to get a hancle to the server side of the stack. *) - let slirp_stack = ref None - let slirp_stack_c = Lwt_condition.create () - - let rec get_slirp_stack () = - match !slirp_stack with - | None -> Lwt_condition.wait slirp_stack_c >>= get_slirp_stack - | Some x -> Lwt.return x - - let set_slirp_stack c = - slirp_stack := Some c; - Lwt_condition.signal slirp_stack_c () - - let start_stack l2_switch config () = - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) - >|= fun server -> - let _, port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server (fun flow -> - Slirp_stack.connect config flow l2_switch >>= fun stack -> - set_slirp_stack stack; - Log.info (fun f -> f "stack connected"); - Slirp_stack.after_disconnect stack >|= fun () -> - Log.info (fun f -> f "stack disconnected") - ); - port - - let connection = - config_without_bridge >>= fun config -> - start_stack (Vnet.create ()) config () - - let with_stack f = - connection >>= fun port -> - Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 Ipaddr.V4.localhost, port) +module VMNET = Vmnet.Make(Host.Sockets.Stream.Tcp) +module Config = Active_config.Make(Host.Time)(Host.Sockets.Stream.Unix) +module Vnet = Basic_backend.Make +module Slirp_stack = + Slirp.Make(Config)(VMNET)(Dns_policy)(Mclock)(Stdlibrandom)(Vnet) + +module Client = struct + module Netif = VMNET + module Ethif1 = Ethif.Make(Netif) + module Arpv41 = Arpv4.Make(Ethif1)(Mclock)(Host.Time) + + module Dhcp_client_mirage1 = Dhcp_client_mirage.Make(Host.Time)(Netif) + module Ipv41 = Dhcp_ipv4.Make(Dhcp_client_mirage1)(Ethif1)(Arpv41) + module Icmpv41 = Icmpv4.Make(Ipv41) + module Udp1 = Udp.Make(Ipv41)(Stdlibrandom) + module Tcp1 = Tcp.Flow.Make(Ipv41)(Host.Time)(Mclock)(Stdlibrandom) + include Tcpip_stack_direct.Make(Host.Time) + (Stdlibrandom)(Netif)(Ethif1)(Arpv41)(Ipv41)(Icmpv41)(Udp1)(Tcp1) + + let or_error name m = + m >>= function + | `Error _ -> Fmt.kstrf failwith "Failed to connect %s device" name + | `Ok x -> Lwt.return x + + let connect (interface: VMNET.t) = + Ethif1.connect interface >>= fun ethif -> + Mclock.connect () >>= fun clock -> + Arpv41.connect ethif clock >>= fun arp -> + Dhcp_client_mirage1.connect interface >>= fun dhcp -> + Ipv41.connect dhcp ethif arp >>= fun ipv4 -> + Icmpv41.connect ipv4 >>= fun icmpv4 -> + Udp1.connect ipv4 >>= fun udp4 -> + Tcp1.connect ipv4 clock >>= fun tcp4 -> + let cfg = { + Mirage_stack_lwt.name = "stackv4_ip"; + interface; + } in + connect cfg ethif arp ipv4 icmpv4 udp4 tcp4 + >>= fun stack -> + Lwt.return stack +end + +module DNS = Dns_resolver_mirage.Make(Host.Time)(Client) + +let primary_dns_ip = Ipaddr.V4.of_string_exn "192.168.65.1" + +let extra_dns_ip = List.map Ipaddr.V4.of_string_exn [ + "192.168.65.3"; "192.168.65.4"; "192.168.65.5"; "192.168.65.6"; + "192.168.65.7"; "192.168.65.8"; "192.168.65.9"; "192.168.65.10"; + ] + +let peer_ip = Ipaddr.V4.of_string_exn "192.168.65.2" +let local_ip = Ipaddr.V4.of_string_exn "192.168.65.1" +let highest_ip = Ipaddr.V4.of_string_exn "192.168.65.254" +let server_macaddr = Slirp.default_server_macaddr + +let global_arp_table : Slirp.arp_table = + { Slirp.mutex = Lwt_mutex.create (); + table = [(local_ip, Slirp.default_server_macaddr)] + } + +let client_uuids : Slirp.uuid_table = + { Slirp.mutex = Lwt_mutex.create (); + table = Hashtbl.create 50; + } + +let config_without_bridge = + Mclock.connect () >|= fun clock -> + { + Slirp.peer_ip; + local_ip; + highest_ip; + extra_dns_ip; + server_macaddr; + get_domain_search = (fun () -> []); + get_domain_name = (fun () -> "local"); + client_uuids; + bridge_connections = false; + global_arp_table; + mtu = 1500; + host_names = []; + clock; + } + +(* This is a hacky way to get a hancle to the server side of the stack. *) +let slirp_stack = ref None +let slirp_stack_c = Lwt_condition.create () + +let rec get_slirp_stack () = + match !slirp_stack with + | None -> Lwt_condition.wait slirp_stack_c >>= get_slirp_stack + | Some x -> Lwt.return x + +let set_slirp_stack c = + slirp_stack := Some c; + Lwt_condition.signal slirp_stack_c () + +let start_stack l2_switch config () = + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) + >|= fun server -> + let _, port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server (fun flow -> + Slirp_stack.connect config flow l2_switch >>= fun stack -> + set_slirp_stack stack; + Log.info (fun f -> f "stack connected"); + Slirp_stack.after_disconnect stack >|= fun () -> + Log.info (fun f -> f "stack disconnected") + ); + port + +let connection = + config_without_bridge >>= fun config -> + start_stack (Vnet.create ()) config () + +let with_stack f = + connection >>= fun port -> + Host.Sockets.Stream.Tcp.connect (Ipaddr.V4 Ipaddr.V4.localhost, port) + >>= function + | Error (`Msg x) -> failwith x + | Ok flow -> + Log.info (fun f -> f "Made a loopback connection"); + let client_macaddr = Slirp.default_client_macaddr in + let uuid = + match Uuidm.of_string "d1d9cd61-d0dc-4715-9bb3-4c11da7ad7a5" with + | Some x -> x + | None -> failwith "unable to parse test uuid" + in + VMNET.client_of_fd ~uuid ~server_macaddr:client_macaddr flow >>= function - | Error (`Msg x) -> failwith x - | Ok flow -> - Log.info (fun f -> f "Made a loopback connection"); - let client_macaddr = Slirp.default_client_macaddr in - let uuid = - match Uuidm.of_string "d1d9cd61-d0dc-4715-9bb3-4c11da7ad7a5" with - | Some x -> x - | None -> failwith "unable to parse test uuid" - in - VMNET.client_of_fd ~uuid ~server_macaddr:client_macaddr flow - >>= function - | Error (`Msg x ) -> - (* Server will close when it gets EOF *) - Host.Sockets.Stream.Tcp.close flow >>= fun () -> - failwith x - | Ok client' -> - Lwt.finalize (fun () -> - Log.info (fun f -> f "Initialising client TCP/IP stack"); - Client.connect client' >>= fun client -> - get_slirp_stack () >>= fun slirp_stack -> - f slirp_stack client - ) (fun () -> - (* Server will close when it gets EOF *) - VMNET.disconnect client' - ) + | Error (`Msg x ) -> + (* Server will close when it gets EOF *) + Host.Sockets.Stream.Tcp.close flow >>= fun () -> + failwith x + | Ok client' -> + Lwt.finalize (fun () -> + Log.info (fun f -> f "Initialising client TCP/IP stack"); + Client.connect client' >>= fun client -> + get_slirp_stack () >>= fun slirp_stack -> + f slirp_stack client + ) (fun () -> + (* Server will close when it gets EOF *) + VMNET.disconnect client' + ) diff --git a/src/hostnet_test/suite.ml b/src/hostnet_test/suite.ml index e6a2148a9..d3964c97d 100644 --- a/src/hostnet_test/suite.ml +++ b/src/hostnet_test/suite.ml @@ -11,362 +11,362 @@ module Log = (val Logs.src_log src : Logs.LOG) let pp_ips = Fmt.(list ~sep:(unit ", ") Ipaddr.pp_hum) let pp_ip4s = Fmt.(list ~sep:(unit ", ") Ipaddr.V4.pp_hum) - let run_test ?(timeout=Duration.of_sec 60) t = - let timeout = - Host.Time.sleep_ns timeout >>= fun () -> - Lwt.fail_with "timeout" - in - Host.Main.run @@ Lwt.pick [ timeout; t ] +let run_test ?(timeout=Duration.of_sec 60) t = + let timeout = + Host.Time.sleep_ns timeout >>= fun () -> + Lwt.fail_with "timeout" + in + Host.Main.run @@ Lwt.pick [ timeout; t ] - let run ?timeout t = run_test ?timeout (with_stack t) +let run ?timeout t = run_test ?timeout (with_stack t) - let test_dhcp_query () = - let t _ stack = - let ips = Client.IPV4.get_ip (Client.ipv4 stack) in - Log.info (fun f -> f "Got an IP: %a" pp_ip4s ips); - Lwt.return () - in - run t +let test_dhcp_query () = + let t _ stack = + let ips = Client.IPV4.get_ip (Client.ipv4 stack) in + Log.info (fun f -> f "Got an IP: %a" pp_ip4s ips); + Lwt.return () + in + run t - let set_dns_policy ?host_names use_host = - Mclock.connect () >|= fun clock -> - Dns_policy.remove ~priority:3; - Dns_policy.add ~priority:3 - ~config:(if use_host then `Host else Dns_policy.google_dns); - Slirp_stack.Debug.update_dns ?host_names clock +let set_dns_policy ?host_names use_host = + Mclock.connect () >|= fun clock -> + Dns_policy.remove ~priority:3; + Dns_policy.add ~priority:3 + ~config:(if use_host then `Host else Dns_policy.google_dns); + Slirp_stack.Debug.update_dns ?host_names clock - let test_dns_query server use_host () = - let t _ stack = - set_dns_policy use_host >>= fun () -> - let resolver = DNS.create stack in - DNS.gethostbyname ~server resolver "www.google.com" >|= function - | (_ :: _) as ips -> - Log.info (fun f -> f "www.google.com has IPs: %a" pp_ips ips); - | _ -> - Log.err (fun f -> f "Failed to lookup www.google.com"); - failwith "Failed to lookup www.google.com" - in - run t +let test_dns_query server use_host () = + let t _ stack = + set_dns_policy use_host >>= fun () -> + let resolver = DNS.create stack in + DNS.gethostbyname ~server resolver "www.google.com" >|= function + | (_ :: _) as ips -> + Log.info (fun f -> f "www.google.com has IPs: %a" pp_ips ips); + | _ -> + Log.err (fun f -> f "Failed to lookup www.google.com"); + failwith "Failed to lookup www.google.com" + in + run t - let test_builtin_dns_query server use_host () = - let name = "experimental.host.name.localhost" in - let t _ stack = - set_dns_policy ~host_names:[ Dns.Name.of_string name ] use_host - >>= fun () -> - let resolver = DNS.create stack in - DNS.gethostbyname ~server resolver name >>= function - | (_ :: _) as ips -> - Log.info (fun f -> f "%s has IPs: %a" name pp_ips ips); - Lwt.return () - | _ -> - Log.err (fun f -> f "Failed to lookup %s" name); - failwith ("Failed to lookup " ^ name) - in - run t +let test_builtin_dns_query server use_host () = + let name = "experimental.host.name.localhost" in + let t _ stack = + set_dns_policy ~host_names:[ Dns.Name.of_string name ] use_host + >>= fun () -> + let resolver = DNS.create stack in + DNS.gethostbyname ~server resolver name >>= function + | (_ :: _) as ips -> + Log.info (fun f -> f "%s has IPs: %a" name pp_ips ips); + Lwt.return () + | _ -> + Log.err (fun f -> f "Failed to lookup %s" name); + failwith ("Failed to lookup " ^ name) + in + run t - let test_etc_hosts_query server use_host () = - let test_name = "vpnkit.is.cool.yes.really" in - let t _ stack = - set_dns_policy use_host >>= fun () -> - let resolver = DNS.create stack in - DNS.gethostbyname ~server resolver test_name >>= function +let test_etc_hosts_query server use_host () = + let test_name = "vpnkit.is.cool.yes.really" in + let t _ stack = + set_dns_policy use_host >>= fun () -> + let resolver = DNS.create stack in + DNS.gethostbyname ~server resolver test_name >>= function + | (_ :: _) as ips -> + Log.err (fun f -> + f "This test relies on the name %s not existing but it really \ + has IPs: %a" test_name pp_ips ips); + Fmt.kstrf failwith "Test name %s really does exist" test_name + | _ -> + Hosts.etc_hosts := [ + test_name, Ipaddr.V4 (Ipaddr.V4.localhost); + ]; + DNS.gethostbyname ~server resolver test_name >|= function | (_ :: _) as ips -> - Log.err (fun f -> - f "This test relies on the name %s not existing but it really \ - has IPs: %a" test_name pp_ips ips); - Fmt.kstrf failwith "Test name %s really does exist" test_name + Log.info (fun f -> f "Name %s has IPs: %a" test_name pp_ips ips); + Hosts.etc_hosts := [] | _ -> - Hosts.etc_hosts := [ - test_name, Ipaddr.V4 (Ipaddr.V4.localhost); - ]; - DNS.gethostbyname ~server resolver test_name >|= function - | (_ :: _) as ips -> - Log.info (fun f -> f "Name %s has IPs: %a" test_name pp_ips ips); - Hosts.etc_hosts := [] - | _ -> - Log.err (fun f -> f "Failed to lookup name from /etc/hosts"); - Hosts.etc_hosts := []; - failwith "failed to lookup name from /etc/hosts" - in - run t + Log.err (fun f -> f "Failed to lookup name from /etc/hosts"); + Hosts.etc_hosts := []; + failwith "failed to lookup name from /etc/hosts" + in + run t - let test_max_connections () = - let t _ stack = - Lwt.finalize (fun () -> - let resolver = DNS.create stack in - DNS.gethostbyname ~server:primary_dns_ip resolver "www.google.com" - >>= function - | Ipaddr.V4 ip :: _ -> - Log.info (fun f -> f "Setting max connections to 0"); - Host.Sockets.set_max_connections (Some 0); - begin - Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) - >|= function - | Ok _ -> - Log.err (fun f -> - f "Connected to www.google.com, max_connections exceeded"); - failwith "too many connections" - | Error _ -> - Log.debug (fun f -> - f "Expected failure to connect to www.google.com") - end - >>= fun () -> - Log.info (fun f -> f "Removing connection limit"); - Host.Sockets.set_max_connections None; - (* Check that connections work again *) - begin - Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) - >|= function - | Ok _ -> - Log.debug (fun f -> f "Connected to www.google.com"); - | Error _ -> - Log.debug (fun f -> - f "Failure to connect to www.google.com: removing \ - max_connections limit didn't work"); - failwith "wrong max connections limit" - end - | _ -> - Log.err (fun f -> - f "Failed to look up an IPv4 address for www.google.com"); - failwith "http_fetch dns" - ) (fun () -> +let test_max_connections () = + let t _ stack = + Lwt.finalize (fun () -> + let resolver = DNS.create stack in + DNS.gethostbyname ~server:primary_dns_ip resolver "www.google.com" + >>= function + | Ipaddr.V4 ip :: _ -> + Log.info (fun f -> f "Setting max connections to 0"); + Host.Sockets.set_max_connections (Some 0); + begin + Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) + >|= function + | Ok _ -> + Log.err (fun f -> + f "Connected to www.google.com, max_connections exceeded"); + failwith "too many connections" + | Error _ -> + Log.debug (fun f -> + f "Expected failure to connect to www.google.com") + end + >>= fun () -> Log.info (fun f -> f "Removing connection limit"); Host.Sockets.set_max_connections None; - Lwt.return_unit - ) - in - run ~timeout:(Duration.of_sec 240) t + (* Check that connections work again *) + begin + Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) + >|= function + | Ok _ -> + Log.debug (fun f -> f "Connected to www.google.com"); + | Error _ -> + Log.debug (fun f -> + f "Failure to connect to www.google.com: removing \ + max_connections limit didn't work"); + failwith "wrong max connections limit" + end + | _ -> + Log.err (fun f -> + f "Failed to look up an IPv4 address for www.google.com"); + failwith "http_fetch dns" + ) (fun () -> + Log.info (fun f -> f "Removing connection limit"); + Host.Sockets.set_max_connections None; + Lwt.return_unit + ) + in + run ~timeout:(Duration.of_sec 240) t - let test_http_fetch () = - let t _ stack = - let resolver = DNS.create stack in - DNS.gethostbyname resolver "www.google.com" >>= function - | Ipaddr.V4 ip :: _ -> - begin - Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) - >>= function +let test_http_fetch () = + let t _ stack = + let resolver = DNS.create stack in + DNS.gethostbyname resolver "www.google.com" >>= function + | Ipaddr.V4 ip :: _ -> + begin + Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) + >>= function + | Error _ -> + Log.err (fun f -> f "Failed to connect to www.google.com:80"); + failwith "http_fetch" + | Ok flow -> + Log.info (fun f -> f "Connected to www.google.com:80"); + let page = Io_page.(to_cstruct (get 1)) in + let http_get = "GET / HTTP/1.0\nHost: anil.recoil.org\n\n" in + Cstruct.blit_from_string http_get 0 page 0 (String.length http_get); + let buf = Cstruct.sub page 0 (String.length http_get) in + Client.TCPV4.write flow buf >>= function + | Error `Closed -> + Log.err (fun f -> + f "EOF writing HTTP request to www.google.com:80"); + failwith "EOF on writing HTTP GET" | Error _ -> - Log.err (fun f -> f "Failed to connect to www.google.com:80"); - failwith "http_fetch" - | Ok flow -> - Log.info (fun f -> f "Connected to www.google.com:80"); - let page = Io_page.(to_cstruct (get 1)) in - let http_get = "GET / HTTP/1.0\nHost: anil.recoil.org\n\n" in - Cstruct.blit_from_string http_get 0 page 0 (String.length http_get); - let buf = Cstruct.sub page 0 (String.length http_get) in - Client.TCPV4.write flow buf >>= function - | Error `Closed -> - Log.err (fun f -> - f "EOF writing HTTP request to www.google.com:80"); - failwith "EOF on writing HTTP GET" - | Error _ -> - Log.err (fun f -> - f "Failure writing HTTP request to www.google.com:80"); - failwith "Failure on writing HTTP GET" - | Ok () -> - let rec loop total_bytes = - Client.TCPV4.read flow >>= function - | Ok `Eof -> Lwt.return total_bytes - | Error _ -> - Log.err (fun f -> - f "Failure read HTTP response from www.google.com:80"); - failwith "Failure on reading HTTP GET" - | Ok (`Data buf) -> - Log.info (fun f -> - f "Read %d bytes from www.google.com:80" (Cstruct.len buf)); - Log.info (fun f -> f "%s" (Cstruct.to_string buf)); - loop (total_bytes + (Cstruct.len buf)) - in - loop 0 >|= fun total_bytes -> - Log.info (fun f -> f "Response had %d total bytes" total_bytes); - if total_bytes == 0 then failwith "response was empty" - end - | _ -> - Log.err (fun f -> - f "Failed to look up an IPv4 address for www.google.com"); - failwith "http_fetch dns" - in - run t + Log.err (fun f -> + f "Failure writing HTTP request to www.google.com:80"); + failwith "Failure on writing HTTP GET" + | Ok () -> + let rec loop total_bytes = + Client.TCPV4.read flow >>= function + | Ok `Eof -> Lwt.return total_bytes + | Error _ -> + Log.err (fun f -> + f "Failure read HTTP response from www.google.com:80"); + failwith "Failure on reading HTTP GET" + | Ok (`Data buf) -> + Log.info (fun f -> + f "Read %d bytes from www.google.com:80" (Cstruct.len buf)); + Log.info (fun f -> f "%s" (Cstruct.to_string buf)); + loop (total_bytes + (Cstruct.len buf)) + in + loop 0 >|= fun total_bytes -> + Log.info (fun f -> f "Response had %d total bytes" total_bytes); + if total_bytes == 0 then failwith "response was empty" + end + | _ -> + Log.err (fun f -> + f "Failed to look up an IPv4 address for www.google.com"); + failwith "http_fetch dns" + in + run t + +module DevNullServer = struct + (* Accept local TCP connections, throw away all incoming data and then return + the total number of bytes processed. *) + type t = { + local_port: int; + server: Host.Sockets.Stream.Tcp.server; + } - module DevNullServer = struct - (* Accept local TCP connections, throw away all incoming data and then return - the total number of bytes processed. *) - type t = { - local_port: int; - server: Host.Sockets.Stream.Tcp.server; - } + let accept flow = + let module Channel = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) in + let ch = Channel.create flow in + (* XXX: this looks like it isn't tail recursive to me *) + let rec drop_all_data count = + Channel.read_some ch >>= function + | Error e -> Fmt.kstrf Lwt.fail_with "%a" Channel.pp_error e + | Ok `Eof -> Lwt.return count + | Ok (`Data buffer) -> + drop_all_data Int64.(add count (of_int (Cstruct.len buffer))) + in + drop_all_data 0L + >>= fun total -> + let response = Cstruct.create 8 in + Cstruct.LE.set_uint64 response 0 total; + Channel.write_buffer ch response; + Channel.flush ch >>= function + | Error e -> Fmt.kstrf Lwt.fail_with "%a" Channel.pp_write_error e + | Ok () -> Lwt.return_unit - let accept flow = - let module Channel = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) in - let ch = Channel.create flow in - (* XXX: this looks like it isn't tail recursive to me *) - let rec drop_all_data count = - Channel.read_some ch >>= function - | Error e -> Fmt.kstrf Lwt.fail_with "%a" Channel.pp_error e - | Ok `Eof -> Lwt.return count - | Ok (`Data buffer) -> - drop_all_data Int64.(add count (of_int (Cstruct.len buffer))) - in - drop_all_data 0L - >>= fun total -> - let response = Cstruct.create 8 in - Cstruct.LE.set_uint64 response 0 total; - Channel.write_buffer ch response; - Channel.flush ch >>= function - | Error e -> Fmt.kstrf Lwt.fail_with "%a" Channel.pp_write_error e - | Ok () -> Lwt.return_unit + let create () = + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) + >|= fun server -> + let _, local_port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server accept; + { local_port; server } - let create () = - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) - >|= fun server -> - let _, local_port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server accept; - { local_port; server } + let to_string t = Printf.sprintf "tcp:127.0.0.1:%d" t.local_port + let destroy t = Host.Sockets.Stream.Tcp.shutdown t.server + let with_server f = + create () >>= fun server -> + Lwt.finalize (fun () -> f server) (fun () -> destroy server) +end - let to_string t = Printf.sprintf "tcp:127.0.0.1:%d" t.local_port - let destroy t = Host.Sockets.Stream.Tcp.shutdown t.server - let with_server f = - create () >>= fun server -> - Lwt.finalize (fun () -> f server) (fun () -> destroy server) - end +let rec count = function 0 -> [] | n -> () :: (count (n - 1)) - let rec count = function 0 -> [] | n -> () :: (count (n - 1)) +let run' ?timeout t = + run ?timeout (fun x b -> + DevNullServer.with_server (fun { DevNullServer.local_port; _ } -> + t local_port x b) + ) - let run' ?timeout t = - run ?timeout (fun x b -> - DevNullServer.with_server (fun { DevNullServer.local_port; _ } -> - t local_port x b) - ) +let test_many_connections n () = + let t local_port _ stack = + (* Note that the stack will consume a small number of file + descriptors itself and each loopback connection will consume + 2: one for client and one for server. *) + (* Instead of counting calls to `connect` and trying to + calculate overheads, we connect until the system tells us + we've hit the target number of connections. *) + let rec loop acc i = + if Host.Sockets.get_num_connections () >= n + then Lwt.return acc + else + Client.TCPV4.create_connection (Client.tcpv4 stack) + (Ipaddr.V4.localhost, local_port) + >>= function + | Ok c -> + Log.info (fun f -> + f "Connected %d, total tracked connections %d" i + (Host.Sockets.get_num_connections ())); + loop (c :: acc) (i + 1) + | Error _ -> + Fmt.kstrf failwith + "Connection %d failed, total tracked connections %d" i + (Host.Sockets.get_num_connections ()) + in + loop [] 0 >|= fun flows -> + Log.info (fun f -> + f "Connected %d, total tracked connections %d" + (List.length flows) (Host.Sockets.get_num_connections ())); + (* How many connections is this? *) + in + run' ~timeout:(Duration.of_sec 240) t - let test_many_connections n () = - let t local_port _ stack = - (* Note that the stack will consume a small number of file - descriptors itself and each loopback connection will consume - 2: one for client and one for server. *) - (* Instead of counting calls to `connect` and trying to - calculate overheads, we connect until the system tells us - we've hit the target number of connections. *) - let rec loop acc i = - if Host.Sockets.get_num_connections () >= n - then Lwt.return acc - else +let test_stream_data connections length () = + let t local_port _ stack = + Lwt_list.iter_p (fun () -> + let rec connect () = Client.TCPV4.create_connection (Client.tcpv4 stack) (Ipaddr.V4.localhost, local_port) >>= function - | Ok c -> - Log.info (fun f -> - f "Connected %d, total tracked connections %d" i - (Host.Sockets.get_num_connections ())); - loop (c :: acc) (i + 1) - | Error _ -> - Fmt.kstrf failwith - "Connection %d failed, total tracked connections %d" i - (Host.Sockets.get_num_connections ()) - in - loop [] 0 >|= fun flows -> - Log.info (fun f -> - f "Connected %d, total tracked connections %d" - (List.length flows) (Host.Sockets.get_num_connections ())); - (* How many connections is this? *) - in - run' ~timeout:(Duration.of_sec 240) t - - let test_stream_data connections length () = - let t local_port _ stack = - Lwt_list.iter_p (fun () -> - let rec connect () = - Client.TCPV4.create_connection (Client.tcpv4 stack) - (Ipaddr.V4.localhost, local_port) - >>= function - | Error `Refused -> - Log.info (fun f -> f "DevNullServer Refused connection"); - Host.Time.sleep_ns (Duration.of_ms 200) - >>= fun () -> - connect () - | Error `Timeout -> - Log.err (fun f -> f "DevNullServer connection timeout"); - failwith "DevNullServer connection timeout"; - | Error e -> + | Error `Refused -> + Log.info (fun f -> f "DevNullServer Refused connection"); + Host.Time.sleep_ns (Duration.of_ms 200) + >>= fun () -> + connect () + | Error `Timeout -> + Log.err (fun f -> f "DevNullServer connection timeout"); + failwith "DevNullServer connection timeout"; + | Error e -> + Log.err (fun f -> + f "DevNullServer connnection failure: %a" + Client.TCPV4.pp_error e); + Fmt.kstrf failwith "%a" Client.TCPV4.pp_error e + | Ok flow -> + Log.info (fun f -> f "Connected to local server"); + Lwt.return flow + in + connect () + >>= fun flow -> + let page = Io_page.(to_cstruct (get 1)) in + Cstruct.memset page 0; + let rec loop remaining = + if remaining = 0 + then Lwt.return () + else begin + let this_time = min remaining (Cstruct.len page) in + let buf = Cstruct.sub page 0 this_time in + Client.TCPV4.write flow buf >>= function + | Error `Closed -> Log.err (fun f -> - f "DevNullServer connnection failure: %a" - Client.TCPV4.pp_error e); - Fmt.kstrf failwith "%a" Client.TCPV4.pp_error e - | Ok flow -> - Log.info (fun f -> f "Connected to local server"); - Lwt.return flow - in - connect () - >>= fun flow -> - let page = Io_page.(to_cstruct (get 1)) in - Cstruct.memset page 0; - let rec loop remaining = - if remaining = 0 - then Lwt.return () - else begin - let this_time = min remaining (Cstruct.len page) in - let buf = Cstruct.sub page 0 this_time in - Client.TCPV4.write flow buf >>= function - | Error `Closed -> - Log.err (fun f -> - f "EOF writing to DevNullServerwith %d bytes left" - remaining); - (* failwith "EOF on writing to DevNullServer" *) - Lwt.return () - | Error _ -> - Log.err (fun f -> - f "Failure writing to DevNullServer with %d bytes left" - remaining); - (* failwith "Failure on writing to DevNullServer" *) - Lwt.return () - | Ok () -> - loop (remaining - this_time) - end - in - loop length >>= fun () -> - Client.TCPV4.close flow >>= fun () -> - Client.TCPV4.read flow >|= function - | Ok `Eof -> - Log.err (fun f -> f "EOF reading result from DevNullServer"); - (* failwith "EOF reading result from DevNullServer" *) - | Error _ -> - Log.err (fun f -> f "Failure reading result from DevNullServer"); - (* failwith "Failure on reading result from DevNullServer" *) - | Ok (`Data buf) -> - Log.info (fun f -> - f "Read %d bytes from DevNullServer" (Cstruct.len buf)); - let response = Cstruct.LE.get_uint64 buf 0 in - if Int64.to_int response != length - then Fmt.kstrf failwith - "Response was %Ld while expected %d" response length; - ) (count connections) - in - run' t + f "EOF writing to DevNullServerwith %d bytes left" + remaining); + (* failwith "EOF on writing to DevNullServer" *) + Lwt.return () + | Error _ -> + Log.err (fun f -> + f "Failure writing to DevNullServer with %d bytes left" + remaining); + (* failwith "Failure on writing to DevNullServer" *) + Lwt.return () + | Ok () -> + loop (remaining - this_time) + end + in + loop length >>= fun () -> + Client.TCPV4.close flow >>= fun () -> + Client.TCPV4.read flow >|= function + | Ok `Eof -> + Log.err (fun f -> f "EOF reading result from DevNullServer"); + (* failwith "EOF reading result from DevNullServer" *) + | Error _ -> + Log.err (fun f -> f "Failure reading result from DevNullServer"); + (* failwith "Failure on reading result from DevNullServer" *) + | Ok (`Data buf) -> + Log.info (fun f -> + f "Read %d bytes from DevNullServer" (Cstruct.len buf)); + let response = Cstruct.LE.get_uint64 buf 0 in + if Int64.to_int response != length + then Fmt.kstrf failwith + "Response was %Ld while expected %d" response length; + ) (count connections) + in + run' t - let test_dhcp = [ - "DHCP: simple query", - ["check that the DHCP server works", `Quick, test_dhcp_query]; - ] +let test_dhcp = [ + "DHCP: simple query", + ["check that the DHCP server works", `Quick, test_dhcp_query]; +] - let test_dns use_host = - let prefix = if use_host then "Host resolver" else "DNS forwarder" in [ - prefix ^ ": lookup ", - ["", `Quick, test_dns_query primary_dns_ip use_host]; +let test_dns use_host = + let prefix = if use_host then "Host resolver" else "DNS forwarder" in [ + prefix ^ ": lookup ", + ["", `Quick, test_dns_query primary_dns_ip use_host]; - prefix ^ ": builtins", - [ "", `Quick, test_builtin_dns_query primary_dns_ip use_host ]; + prefix ^ ": builtins", + [ "", `Quick, test_builtin_dns_query primary_dns_ip use_host ]; - prefix ^ ": _etc_hosts", - [ "", `Quick, test_etc_hosts_query primary_dns_ip use_host ]; - ] + prefix ^ ": _etc_hosts", + [ "", `Quick, test_etc_hosts_query primary_dns_ip use_host ]; + ] - let test_tcp = [ - "HTTP GET", [ "HTTP GET http://www.google.com/", `Quick, test_http_fetch ]; +let test_tcp = [ + "HTTP GET", [ "HTTP GET http://www.google.com/", `Quick, test_http_fetch ]; - "Max connections", - [ "HTTP GET fails beyond max connections", `Quick, test_max_connections ]; + "Max connections", + [ "HTTP GET fails beyond max connections", `Quick, test_max_connections ]; - "TCP streaming", - [ "1 TCP connection transferring 1 KiB", `Quick, test_stream_data 1 1024 ]; + "TCP streaming", + [ "1 TCP connection transferring 1 KiB", `Quick, test_stream_data 1 1024 ]; (* "10 TCP connections each transferring 1 KiB", `Quick, test_stream_data 10 1024; @@ -379,15 +379,15 @@ let pp_ip4s = Fmt.(list ~sep:(unit ", ") Ipaddr.V4.pp_hum) "32 TCP connections each transferring 1 GiB", `Slow, test_stream_data 32 (1024*1024*1024); *) - ] +] - let tests = - Hosts_test.tests @ Forwarding.tests @ test_dhcp - @ (test_dns true) @ (test_dns false) - @ test_tcp @ Test_nat.tests @ Test_http.tests @ Test_http.Exclude.tests - @ Test_half_close.tests +let tests = + Hosts_test.tests @ Forwarding.tests @ test_dhcp + @ (test_dns true) @ (test_dns false) + @ test_tcp @ Test_nat.tests @ Test_http.tests @ Test_http.Exclude.tests + @ Test_half_close.tests - let scalability = [ - "1026conns", - [ "Test many connections", `Quick, test_many_connections (1024 + 2) ]; - ] +let scalability = [ + "1026conns", + [ "Test many connections", `Quick, test_many_connections (1024 + 2) ]; +] diff --git a/src/hostnet_test/test_half_close.ml b/src/hostnet_test/test_half_close.ml index f6c6dc99e..d6091b535 100644 --- a/src/hostnet_test/test_half_close.ml +++ b/src/hostnet_test/test_half_close.ml @@ -9,172 +9,172 @@ let failf fmt = Fmt.kstrf failwith fmt module Log = (val Logs.src_log src : Logs.LOG) - module Server = struct - type t = { - server: Host.Sockets.Stream.Tcp.server; - port: int; - } - let create on_accept = - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) - >>= fun server -> - let _, port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server on_accept; - Lwt.return { server; port } - let destroy t = - Host.Sockets.Stream.Tcp.shutdown t.server - end - let with_server on_accept f = - Server.create on_accept +module Server = struct + type t = { + server: Host.Sockets.Stream.Tcp.server; + port: int; + } + let create on_accept = + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) >>= fun server -> - Lwt.finalize (fun () -> f server) (fun () -> Server.destroy server) - - module Outgoing = struct - module C = Mirage_channel_lwt.Make(Slirp_stack.Client.TCPV4) + let _, port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server on_accept; + Lwt.return { server; port } + let destroy t = + Host.Sockets.Stream.Tcp.shutdown t.server +end +let with_server on_accept f = + Server.create on_accept + >>= fun server -> + Lwt.finalize (fun () -> f server) (fun () -> Server.destroy server) + +module Outgoing = struct + module C = Mirage_channel_lwt.Make(Slirp_stack.Client.TCPV4) +end +module Incoming = struct + module C = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) +end + +let request = "Hello there" +let response = "And hello to you" + +let data = function +| Ok (`Data x) -> x +| Ok `Eof -> failwith "data: eof" +| Error _ -> failwith "data: error" + +let unit = function +| Ok () -> () +| Error _ -> failwith "unit: error" + +let flow ip port = function +| Ok flow -> flow +| Error _ -> + Log.err (fun f -> f "Failed to connect to %a:%d" Ipaddr.V4.pp_hum ip port); + failwith "Client.TCPV4.create_connection" + +(* Run a simple server on localhost and connect to it via vpnkit. + The Mirage client will call `close` to trigger a half-close of + the TCP connection before reading the response. This verifies + that the other side of the connection remains open. *) +let test_mirage_half_close () = + Host.Main.run begin + let forwarded, forwarded_u = Lwt.task () in + Slirp_stack.with_stack (fun _ stack -> with_server (fun flow -> + (* Read the request until EOF *) + let ic = Incoming.C.create flow in + Incoming.C.read_line ic >|= data >>= fun bufs -> + let txt = Cstruct.(to_string @@ concat bufs) in + if txt <> request + then failf "Expected to read '%s', got '%s'" request txt; + Incoming.C.read_line ic >|= data >>= fun bufs -> + assert (Cstruct.(len @@ concat bufs) = 0); + Log.info (fun f -> f "Read the request (up to and including EOF)"); + + (* Write a response. If the connection is fully closed + rather than half-closed then this will fail. *) + Incoming.C.write_line ic response; + Incoming.C.flush ic >|= unit >>= fun () -> + Log.info (fun f -> f "Written response"); + Lwt.wakeup_later forwarded_u (); + Lwt.return_unit + ) (fun server -> + (* Now that the server is running, connect to it and send a + request. *) + let open Slirp_stack in + let ip = Ipaddr.V4.localhost in + let port = server.Server.port in + Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, port) + >|= flow ip port >>= fun flow -> + Log.info (fun f -> f "Connected to %a:%d" Ipaddr.V4.pp_hum ip port); + let oc = Outgoing.C.create flow in + Outgoing.C.write_line oc request; + Outgoing.C.flush oc >|= unit >>= fun () -> + + (* This will perform a TCP half-close *) + Client.TCPV4.close flow >>= fun () -> + + (* Verify the response is still intact *) + Outgoing.C.read_line oc >|= data >>= fun bufs -> + let txt = Cstruct.(to_string @@ concat bufs) in + if txt <> response + then failf "Expected to read '%s', got '%s'" response txt; + Log.info (fun f -> f "Read the response. Waiting for cleanup"); + Lwt.pick [ + (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> `Timeout); + (forwarded >|= fun x -> `Result x) ] + ) >>= function + | `Timeout -> failwith "TCP half close test timed-out" + | `Result x -> Lwt.return x + ) end - module Incoming = struct - module C = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) + +(* Run a simple server on localhost and connect to it via vpnkit. + The server on the host will call `close` to trigger a half-close + of the TCP connection before reading the response. This verifies + that the other side of the connection remains open. *) +let test_host_half_close () = + Host.Main.run begin + let forwarded, forwarded_u = Lwt.task () in + Slirp_stack.with_stack (fun _ stack -> with_server (fun flow -> + (* Write a request *) + let ic = Incoming.C.create flow in + Incoming.C.write_line ic request; + Incoming.C.flush ic >|= unit >>= fun () -> + + (* This will perform a TCP half-close *) + Host.Sockets.Stream.Tcp.shutdown_write flow >>= fun () -> + + (* Read the response from the other side of the connection *) + Incoming.C.read_line ic >|= data + >>= fun bufs -> + let txt = Cstruct.(to_string @@ concat bufs) in + if txt <> response + then failf "Expected to read '%s', got '%s'" response txt; + Log.info (fun f -> f "Read the response, signalling complete"); + Lwt.wakeup_later forwarded_u (); + Lwt.return_unit + ) (fun server -> + (* Now that the server is running, connect to it and send a + request. *) + let open Slirp_stack in + let ip = Ipaddr.V4.localhost in + let port = server.Server.port in + Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, port) + >|= flow ip port >>= fun flow -> + Log.info (fun f -> f "Connected to %a:%d" Ipaddr.V4.pp_hum ip port); + let oc = Outgoing.C.create flow in + (* Read the request *) + Outgoing.C.read_line oc >|= data >>= fun bufs -> + let txt = Cstruct.(to_string @@ concat bufs) in + if txt <> request + then failf "Expected to read '%s', got '%s'" request txt; + (* Check we're at EOF *) + Outgoing.C.read_line oc >|= data >>= fun bufs -> + assert (Cstruct.(len @@ concat bufs) = 0); + Log.info (fun f -> f "Read the request (up to and including EOF)"); + (* Write response *) + Outgoing.C.write_line oc response; + Outgoing.C.flush oc >|= unit >>= fun () -> + Log.info (fun f -> f "Written response and will wait."); + Lwt.pick [ + (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> `Timeout); + (forwarded >|= fun x -> `Result x) ] + ) >>= function + | `Timeout -> failwith "TCP half close test timed-out" + | `Result x -> Lwt.return x + ) end - let request = "Hello there" - let response = "And hello to you" - - let data = function - | Ok (`Data x) -> x - | Ok `Eof -> failwith "data: eof" - | Error _ -> failwith "data: error" - - let unit = function - | Ok () -> () - | Error _ -> failwith "unit: error" - - let flow ip port = function - | Ok flow -> flow - | Error _ -> - Log.err (fun f -> f "Failed to connect to %a:%d" Ipaddr.V4.pp_hum ip port); - failwith "Client.TCPV4.create_connection" - - (* Run a simple server on localhost and connect to it via vpnkit. - The Mirage client will call `close` to trigger a half-close of - the TCP connection before reading the response. This verifies - that the other side of the connection remains open. *) - let test_mirage_half_close () = - Host.Main.run begin - let forwarded, forwarded_u = Lwt.task () in - Slirp_stack.with_stack (fun _ stack -> with_server (fun flow -> - (* Read the request until EOF *) - let ic = Incoming.C.create flow in - Incoming.C.read_line ic >|= data >>= fun bufs -> - let txt = Cstruct.(to_string @@ concat bufs) in - if txt <> request - then failf "Expected to read '%s', got '%s'" request txt; - Incoming.C.read_line ic >|= data >>= fun bufs -> - assert (Cstruct.(len @@ concat bufs) = 0); - Log.info (fun f -> f "Read the request (up to and including EOF)"); - - (* Write a response. If the connection is fully closed - rather than half-closed then this will fail. *) - Incoming.C.write_line ic response; - Incoming.C.flush ic >|= unit >>= fun () -> - Log.info (fun f -> f "Written response"); - Lwt.wakeup_later forwarded_u (); - Lwt.return_unit - ) (fun server -> - (* Now that the server is running, connect to it and send a - request. *) - let open Slirp_stack in - let ip = Ipaddr.V4.localhost in - let port = server.Server.port in - Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, port) - >|= flow ip port >>= fun flow -> - Log.info (fun f -> f "Connected to %a:%d" Ipaddr.V4.pp_hum ip port); - let oc = Outgoing.C.create flow in - Outgoing.C.write_line oc request; - Outgoing.C.flush oc >|= unit >>= fun () -> - - (* This will perform a TCP half-close *) - Client.TCPV4.close flow >>= fun () -> - - (* Verify the response is still intact *) - Outgoing.C.read_line oc >|= data >>= fun bufs -> - let txt = Cstruct.(to_string @@ concat bufs) in - if txt <> response - then failf "Expected to read '%s', got '%s'" response txt; - Log.info (fun f -> f "Read the response. Waiting for cleanup"); - Lwt.pick [ - (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> `Timeout); - (forwarded >|= fun x -> `Result x) ] - ) >>= function - | `Timeout -> failwith "TCP half close test timed-out" - | `Result x -> Lwt.return x - ) - end - - (* Run a simple server on localhost and connect to it via vpnkit. - The server on the host will call `close` to trigger a half-close - of the TCP connection before reading the response. This verifies - that the other side of the connection remains open. *) - let test_host_half_close () = - Host.Main.run begin - let forwarded, forwarded_u = Lwt.task () in - Slirp_stack.with_stack (fun _ stack -> with_server (fun flow -> - (* Write a request *) - let ic = Incoming.C.create flow in - Incoming.C.write_line ic request; - Incoming.C.flush ic >|= unit >>= fun () -> - - (* This will perform a TCP half-close *) - Host.Sockets.Stream.Tcp.shutdown_write flow >>= fun () -> - - (* Read the response from the other side of the connection *) - Incoming.C.read_line ic >|= data - >>= fun bufs -> - let txt = Cstruct.(to_string @@ concat bufs) in - if txt <> response - then failf "Expected to read '%s', got '%s'" response txt; - Log.info (fun f -> f "Read the response, signalling complete"); - Lwt.wakeup_later forwarded_u (); - Lwt.return_unit - ) (fun server -> - (* Now that the server is running, connect to it and send a - request. *) - let open Slirp_stack in - let ip = Ipaddr.V4.localhost in - let port = server.Server.port in - Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, port) - >|= flow ip port >>= fun flow -> - Log.info (fun f -> f "Connected to %a:%d" Ipaddr.V4.pp_hum ip port); - let oc = Outgoing.C.create flow in - (* Read the request *) - Outgoing.C.read_line oc >|= data >>= fun bufs -> - let txt = Cstruct.(to_string @@ concat bufs) in - if txt <> request - then failf "Expected to read '%s', got '%s'" request txt; - (* Check we're at EOF *) - Outgoing.C.read_line oc >|= data >>= fun bufs -> - assert (Cstruct.(len @@ concat bufs) = 0); - Log.info (fun f -> f "Read the request (up to and including EOF)"); - (* Write response *) - Outgoing.C.write_line oc response; - Outgoing.C.flush oc >|= unit >>= fun () -> - Log.info (fun f -> f "Written response and will wait."); - Lwt.pick [ - (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> `Timeout); - (forwarded >|= fun x -> `Result x) ] - ) >>= function - | `Timeout -> failwith "TCP half close test timed-out" - | `Result x -> Lwt.return x - ) - end - - let tests = [ - - "TCP: test Mirage half close", [ - "check that Mirage half-close isn't a full-close", `Quick, - test_mirage_half_close - ] ; - - "TCP: test Host half close", [ - "check that the Host half-close isn't a full-close", `Quick, - test_host_half_close - ]; - ] +let tests = [ + + "TCP: test Mirage half close", [ + "check that Mirage half-close isn't a full-close", `Quick, + test_mirage_half_close + ] ; + + "TCP: test Host half close", [ + "check that the Host half-close isn't a full-close", `Quick, + test_host_half_close + ]; +] diff --git a/src/hostnet_test/test_http.ml b/src/hostnet_test/test_http.ml index 730eb97cc..aca42cfda 100644 --- a/src/hostnet_test/test_http.ml +++ b/src/hostnet_test/test_http.ml @@ -71,53 +71,184 @@ module Exclude = struct ] end - module Server = struct - type t = { - server: Host.Sockets.Stream.Tcp.server; - port: int; - } - let create on_accept = - Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) - >>= fun server -> - let _, port = Host.Sockets.Stream.Tcp.getsockname server in - Host.Sockets.Stream.Tcp.listen server on_accept; - Lwt.return { server; port } - let destroy t = - Host.Sockets.Stream.Tcp.shutdown t.server - end - let with_server on_accept f = - Server.create on_accept +module Server = struct + type t = { + server: Host.Sockets.Stream.Tcp.server; + port: int; + } + let create on_accept = + Host.Sockets.Stream.Tcp.bind (Ipaddr.V4 Ipaddr.V4.localhost, 0) >>= fun server -> - Lwt.finalize (fun () -> f server) (fun () -> Server.destroy server) + let _, port = Host.Sockets.Stream.Tcp.getsockname server in + Host.Sockets.Stream.Tcp.listen server on_accept; + Lwt.return { server; port } + let destroy t = + Host.Sockets.Stream.Tcp.shutdown t.server +end +let with_server on_accept f = + Server.create on_accept + >>= fun server -> + Lwt.finalize (fun () -> f server) (fun () -> Server.destroy server) + +module Outgoing = struct + module C = Mirage_channel_lwt.Make(Slirp_stack.Client.TCPV4) + module IO = Cohttp_mirage_io.Make(C) + module Request = Cohttp.Request.Make(IO) + module Response = Cohttp.Response.Make(IO) +end +module Incoming = struct + module C = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) + module IO = Cohttp_mirage_io.Make(C) + module Request = Cohttp.Request.Make(IO) + module Response = Cohttp.Response.Make(IO) +end - module Outgoing = struct - module C = Mirage_channel_lwt.Make(Slirp_stack.Client.TCPV4) - module IO = Cohttp_mirage_io.Make(C) - module Request = Cohttp.Request.Make(IO) - module Response = Cohttp.Response.Make(IO) +let send_http_request stack ip request = + let open Slirp_stack in + Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) + >>= function + | Ok flow -> + Log.info (fun f -> f "Connected to %s:80" (Ipaddr.V4.to_string ip)); + let oc = Outgoing.C.create flow in + Outgoing.Request.write ~flush:true (fun _writer -> Lwt.return_unit) + request oc + | Error _ -> + Log.err (fun f -> f "Failed to connect to %s:80" (Ipaddr.V4.to_string ip)); + failwith "http_fetch" + +let intercept request = + let forwarded, forwarded_u = Lwt.task () in + Slirp_stack.with_stack (fun _ stack -> + with_server (fun flow -> + let ic = Incoming.C.create flow in + Incoming.Request.read ic >>= function + | `Eof -> + Log.err (fun f -> f "Failed to request"); + failwith "Failed to read request" + | `Invalid x -> + Log.err (fun f -> f "Failed to parse request: %s" x); + failwith ("Failed to parse request: " ^ x) + | `Ok req -> + (* parse the response *) + Lwt.wakeup_later forwarded_u req; + Lwt.return_unit + ) (fun server -> + let json = + Ezjsonm.from_string (" { \"http\": \"127.0.0.1:" ^ + (string_of_int server.Server.port) ^ "\" }") + in + Slirp_stack.Slirp_stack.Debug.update_http_json json () + >>= function + | Error (`Msg m) -> failwith ("Failed to enable HTTP proxy: " ^ m) + | Ok () -> + send_http_request stack (Ipaddr.V4.of_string_exn "127.0.0.1") + request + >>= fun () -> + Lwt.pick [ + (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> + `Timeout); + (forwarded >>= fun x -> Lwt.return (`Result x)) + ] + ) + >|= function + | `Timeout -> failwith "HTTP interception failed" + | `Result x -> x + ) + +(* Test that HTTP interception works at all *) +let test_interception () = + Host.Main.run begin + let request = + Cohttp.Request.make + (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) + in + intercept request >>= fun result -> + Log.info (fun f -> + f "original was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); + Log.info (fun f -> + f "proxied was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); + Alcotest.check Alcotest.string "method" + (Cohttp.Code.string_of_method request.Cohttp.Request.meth) + (Cohttp.Code.string_of_method result.Cohttp.Request.meth); + Alcotest.check Alcotest.string "version" + (Cohttp.Code.string_of_version request.Cohttp.Request.version) + (Cohttp.Code.string_of_version result.Cohttp.Request.version); + Lwt.return () end - module Incoming = struct - module C = Mirage_channel_lwt.Make(Host.Sockets.Stream.Tcp) - module IO = Cohttp_mirage_io.Make(C) - module Request = Cohttp.Request.Make(IO) - module Response = Cohttp.Response.Make(IO) + +(* Test that the URI becomes absolute *) +let test_uri_absolute () = + Host.Main.run begin + let request = + Cohttp.Request.make + (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) + in + intercept request >>= fun result -> + Log.info (fun f -> + f "original was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); + Log.info (fun f -> + f "proxied was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); + let uri = Uri.of_string result.Cohttp.Request.resource in + Alcotest.check Alcotest.(option string) "scheme" + (Some "http") (Uri.scheme uri); + Lwt.return () end - let send_http_request stack ip request = - let open Slirp_stack in - Client.TCPV4.create_connection (Client.tcpv4 stack) (ip, 80) - >>= function - | Ok flow -> - Log.info (fun f -> f "Connected to %s:80" (Ipaddr.V4.to_string ip)); - let oc = Outgoing.C.create flow in - Outgoing.Request.write ~flush:true (fun _writer -> Lwt.return_unit) - request oc - | Error _ -> - Log.err (fun f -> f "Failed to connect to %s:80" (Ipaddr.V4.to_string ip)); - failwith "http_fetch" +(* Verify that a custom X- header is preserved *) +let test_x_header_preserved () = + Host.Main.run begin + let headers = + Cohttp.Header.add (Cohttp.Header.init ()) "X-dave-is-cool" "true" + in + let request = + Cohttp.Request.make ~headers + (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) + in + intercept request >>= fun result -> + Log.info (fun f -> + f "original was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); + Log.info (fun f -> + f "proxied was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); + Alcotest.check Alcotest.(option string) "X-header" + (Some "true") + (Cohttp.Header.get result.Cohttp.Request.headers "X-dave-is-cool"); + Lwt.return () + end - let intercept request = - let forwarded, forwarded_u = Lwt.task () in +(* Verify that the user-agent is preserved. In particular we don't want our + http library to leak here. *) +let test_user_agent_preserved () = + Host.Main.run begin + let headers = + Cohttp.Header.add (Cohttp.Header.init ()) "user-agent" "whatever" + in + let request = + Cohttp.Request.make ~headers + (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) + in + intercept request >>= fun result -> + Log.info (fun f -> + f "original was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); + Log.info (fun f -> + f "proxied was: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); + Alcotest.check Alcotest.(option string) "user-agent" (Some "whatever") + (Cohttp.Header.get result.Cohttp.Request.headers "user-agent"); + Lwt.return () + end + +let err_flush e = Fmt.kstrf failwith "%a" Incoming.C.pp_write_error e + +let test_http_connect () = + let test_dst_ip = Ipaddr.V4.of_string_exn "1.2.3.4" in + Host.Main.run begin Slirp_stack.with_stack (fun _ stack -> with_server (fun flow -> let ic = Incoming.C.create flow in @@ -129,200 +260,69 @@ end Log.err (fun f -> f "Failed to parse request: %s" x); failwith ("Failed to parse request: " ^ x) | `Ok req -> - (* parse the response *) - Lwt.wakeup_later forwarded_u req; - Lwt.return_unit + Log.info (fun f -> + f "received: %s" + (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t req))); + Alcotest.check Alcotest.string "method" + (Cohttp.Code.string_of_method `CONNECT) + (Cohttp.Code.string_of_method req.Cohttp.Request.meth); + let uri = Cohttp.Request.uri req in + Alcotest.check Alcotest.(option string) "host" + (Some (Ipaddr.V4.to_string test_dst_ip)) (Uri.host uri); + Alcotest.check Alcotest.(option int) "port" (Some 443) + (Uri.port uri); + (* Unfortunately cohttp always adds transfer-encoding: chunked + so we write the header ourselves *) + Incoming.C.write_line ic "HTTP/1.0 200 OK\r"; + Incoming.C.write_line ic "\r"; + Incoming.C.flush ic >>= function + | Error e -> err_flush e + | Ok () -> + Incoming.C.write_line ic "hello"; + Incoming.C.flush ic >|= function + | Error e -> err_flush e + | Ok () -> () ) (fun server -> - let json = - Ezjsonm.from_string (" { \"http\": \"127.0.0.1:" ^ - (string_of_int server.Server.port) ^ "\" }") - in - Slirp_stack.Slirp_stack.Debug.update_http_json json () + Slirp_stack.Slirp_stack.Debug.update_http + ~https:("127.0.0.1:" ^ (string_of_int server.Server.port)) () >>= function | Error (`Msg m) -> failwith ("Failed to enable HTTP proxy: " ^ m) | Ok () -> - send_http_request stack (Ipaddr.V4.of_string_exn "127.0.0.1") - request - >>= fun () -> - Lwt.pick [ - (Host.Time.sleep_ns (Duration.of_sec 100) >|= fun () -> - `Timeout); - (forwarded >>= fun x -> Lwt.return (`Result x)) - ] + let open Slirp_stack in + Client.TCPV4.create_connection (Client.tcpv4 stack) + (test_dst_ip, 443) + >>= function + | Error _ -> + Log.err (fun f -> + f "TCPV4.create_connection %a:443 failed" + Ipaddr.V4.pp_hum test_dst_ip); + failwith "TCPV4.create_connection" + | Ok flow -> + let ic = Outgoing.C.create flow in + Outgoing.C.read_some ~len:5 ic >>= function + | Error e -> Fmt.kstrf failwith "%a" Outgoing.C.pp_error e + | Ok `Eof -> failwith "EOF" + | Ok (`Data buf) -> + let txt = Cstruct.to_string buf in + Alcotest.check Alcotest.string "message" "hello" txt; + Lwt.return_unit ) - >|= function - | `Timeout -> failwith "HTTP interception failed" - | `Result x -> x ) + end - (* Test that HTTP interception works at all *) - let test_interception () = - Host.Main.run begin - let request = - Cohttp.Request.make - (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) - in - intercept request >>= fun result -> - Log.info (fun f -> - f "original was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); - Log.info (fun f -> - f "proxied was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); - Alcotest.check Alcotest.string "method" - (Cohttp.Code.string_of_method request.Cohttp.Request.meth) - (Cohttp.Code.string_of_method result.Cohttp.Request.meth); - Alcotest.check Alcotest.string "version" - (Cohttp.Code.string_of_version request.Cohttp.Request.version) - (Cohttp.Code.string_of_version result.Cohttp.Request.version); - Lwt.return () - end - - (* Test that the URI becomes absolute *) - let test_uri_absolute () = - Host.Main.run begin - let request = - Cohttp.Request.make - (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) - in - intercept request >>= fun result -> - Log.info (fun f -> - f "original was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); - Log.info (fun f -> - f "proxied was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); - let uri = Uri.of_string result.Cohttp.Request.resource in - Alcotest.check Alcotest.(option string) "scheme" - (Some "http") (Uri.scheme uri); - Lwt.return () - end - - (* Verify that a custom X- header is preserved *) - let test_x_header_preserved () = - Host.Main.run begin - let headers = - Cohttp.Header.add (Cohttp.Header.init ()) "X-dave-is-cool" "true" - in - let request = - Cohttp.Request.make ~headers - (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) - in - intercept request >>= fun result -> - Log.info (fun f -> - f "original was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); - Log.info (fun f -> - f "proxied was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); - Alcotest.check Alcotest.(option string) "X-header" - (Some "true") - (Cohttp.Header.get result.Cohttp.Request.headers "X-dave-is-cool"); - Lwt.return () - end - - (* Verify that the user-agent is preserved. In particular we don't want our - http library to leak here. *) - let test_user_agent_preserved () = - Host.Main.run begin - let headers = - Cohttp.Header.add (Cohttp.Header.init ()) "user-agent" "whatever" - in - let request = - Cohttp.Request.make ~headers - (Uri.make ~scheme:"http" ~host:"dave.recoil.org" ~path:"/" ()) - in - intercept request >>= fun result -> - Log.info (fun f -> - f "original was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t request))); - Log.info (fun f -> - f "proxied was: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t result))); - Alcotest.check Alcotest.(option string) "user-agent" (Some "whatever") - (Cohttp.Header.get result.Cohttp.Request.headers "user-agent"); - Lwt.return () - end - - let err_flush e = Fmt.kstrf failwith "%a" Incoming.C.pp_write_error e - - let test_http_connect () = - let test_dst_ip = Ipaddr.V4.of_string_exn "1.2.3.4" in - Host.Main.run begin - Slirp_stack.with_stack (fun _ stack -> - with_server (fun flow -> - let ic = Incoming.C.create flow in - Incoming.Request.read ic >>= function - | `Eof -> - Log.err (fun f -> f "Failed to request"); - failwith "Failed to read request" - | `Invalid x -> - Log.err (fun f -> f "Failed to parse request: %s" x); - failwith ("Failed to parse request: " ^ x) - | `Ok req -> - Log.info (fun f -> - f "received: %s" - (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t req))); - Alcotest.check Alcotest.string "method" - (Cohttp.Code.string_of_method `CONNECT) - (Cohttp.Code.string_of_method req.Cohttp.Request.meth); - let uri = Cohttp.Request.uri req in - Alcotest.check Alcotest.(option string) "host" - (Some (Ipaddr.V4.to_string test_dst_ip)) (Uri.host uri); - Alcotest.check Alcotest.(option int) "port" (Some 443) - (Uri.port uri); - (* Unfortunately cohttp always adds transfer-encoding: chunked - so we write the header ourselves *) - Incoming.C.write_line ic "HTTP/1.0 200 OK\r"; - Incoming.C.write_line ic "\r"; - Incoming.C.flush ic >>= function - | Error e -> err_flush e - | Ok () -> - Incoming.C.write_line ic "hello"; - Incoming.C.flush ic >|= function - | Error e -> err_flush e - | Ok () -> () - ) (fun server -> - Slirp_stack.Slirp_stack.Debug.update_http - ~https:("127.0.0.1:" ^ (string_of_int server.Server.port)) () - >>= function - | Error (`Msg m) -> failwith ("Failed to enable HTTP proxy: " ^ m) - | Ok () -> - let open Slirp_stack in - Client.TCPV4.create_connection (Client.tcpv4 stack) - (test_dst_ip, 443) - >>= function - | Error _ -> - Log.err (fun f -> - f "TCPV4.create_connection %a:443 failed" - Ipaddr.V4.pp_hum test_dst_ip); - failwith "TCPV4.create_connection" - | Ok flow -> - let ic = Outgoing.C.create flow in - Outgoing.C.read_some ~len:5 ic >>= function - | Error e -> Fmt.kstrf failwith "%a" Outgoing.C.pp_error e - | Ok `Eof -> failwith "EOF" - | Ok (`Data buf) -> - let txt = Cstruct.to_string buf in - Alcotest.check Alcotest.string "message" "hello" txt; - Lwt.return_unit - ) - ) - end - - let tests = [ - "HTTP: interception", - [ "", `Quick, test_interception ]; +let tests = [ + "HTTP: interception", + [ "", `Quick, test_interception ]; - "HTTP: URI", - [ "check that URIs are rewritten", `Quick, test_uri_absolute ]; + "HTTP: URI", + [ "check that URIs are rewritten", `Quick, test_uri_absolute ]; - "HTTP: custom header", - ["check that custom headers are preserved", `Quick, test_x_header_preserved]; + "HTTP: custom header", + ["check that custom headers are preserved", `Quick, test_x_header_preserved]; - "HTTP: user-agent", - [ "check that user-agent is preserved", `Quick, test_user_agent_preserved ]; + "HTTP: user-agent", + [ "check that user-agent is preserved", `Quick, test_user_agent_preserved ]; - "HTTP: CONNECT", - [ "check that HTTP CONNECT works for HTTPS", `Quick, test_http_connect ]; - ] + "HTTP: CONNECT", + [ "check that HTTP CONNECT works for HTTPS", `Quick, test_http_connect ]; +] diff --git a/src/hostnet_test/test_nat.ml b/src/hostnet_test/test_nat.ml index febfc48c6..affa54096 100644 --- a/src/hostnet_test/test_nat.ml +++ b/src/hostnet_test/test_nat.ml @@ -8,373 +8,373 @@ let src = module Log = (val Logs.src_log src : Logs.LOG) - let run ?(timeout=Duration.of_sec 60) t = - let timeout = - Host.Time.sleep_ns timeout >>= fun () -> - Lwt.fail_with "timeout" - in - Host.Main.run @@ Lwt.pick [ timeout; t ] +let run ?(timeout=Duration.of_sec 60) t = + let timeout = + Host.Time.sleep_ns timeout >>= fun () -> + Lwt.fail_with "timeout" + in + Host.Main.run @@ Lwt.pick [ timeout; t ] - module EchoServer = struct - (* Receive UDP packets and copy them back to all senders. Roughly simulates - a chat protocol, in particular this allows us to test many replies to one - request. *) - type t = { - local_port: int; - server: Host.Sockets.Datagram.Udp.server; - mutable seen_addresses: Host.Sockets.Datagram.address list; - } +module EchoServer = struct + (* Receive UDP packets and copy them back to all senders. Roughly simulates + a chat protocol, in particular this allows us to test many replies to one + request. *) + type t = { + local_port: int; + server: Host.Sockets.Datagram.Udp.server; + mutable seen_addresses: Host.Sockets.Datagram.address list; + } - let create () = - Host.Sockets.Datagram.Udp.bind (Ipaddr.(V4 V4.localhost), 0) - >>= fun server -> - let _, local_port = Host.Sockets.Datagram.Udp.getsockname server in - (* Start a background echo thread. This will naturally fail when the - file descriptor is closed underneath it from `shutdown` *) - let seen_addresses = [] in - let t = { local_port; server; seen_addresses } in - let _ = - let buf = Cstruct.create 2048 in - let rec loop () = - Host.Sockets.Datagram.Udp.recvfrom server buf - >>= fun (len, address) -> - t.seen_addresses <- address :: t.seen_addresses; - Lwt_list.iter_p (fun address -> - Host.Sockets.Datagram.Udp.sendto server address - (Cstruct.sub buf 0 len) - ) t.seen_addresses - >>= - loop - in - loop () + let create () = + Host.Sockets.Datagram.Udp.bind (Ipaddr.(V4 V4.localhost), 0) + >>= fun server -> + let _, local_port = Host.Sockets.Datagram.Udp.getsockname server in + (* Start a background echo thread. This will naturally fail when the + file descriptor is closed underneath it from `shutdown` *) + let seen_addresses = [] in + let t = { local_port; server; seen_addresses } in + let _ = + let buf = Cstruct.create 2048 in + let rec loop () = + Host.Sockets.Datagram.Udp.recvfrom server buf + >>= fun (len, address) -> + t.seen_addresses <- address :: t.seen_addresses; + Lwt_list.iter_p (fun address -> + Host.Sockets.Datagram.Udp.sendto server address + (Cstruct.sub buf 0 len) + ) t.seen_addresses + >>= + loop in - Lwt.return t - - let get_seen_addresses t = t.seen_addresses + loop () + in + Lwt.return t - let to_string t = - Printf.sprintf "udp:127.0.0.1:%d" t.local_port - let destroy t = Host.Sockets.Datagram.Udp.shutdown t.server - let with_server f = - create () >>= fun server -> - Lwt.finalize (fun () -> f server) (fun () -> destroy server) - end + let get_seen_addresses t = t.seen_addresses - module UdpServer = struct - module PortSet = - Set.Make(struct type t = int let compare = Pervasives.compare end) + let to_string t = + Printf.sprintf "udp:127.0.0.1:%d" t.local_port + let destroy t = Host.Sockets.Datagram.Udp.shutdown t.server + let with_server f = + create () >>= fun server -> + Lwt.finalize (fun () -> f server) (fun () -> destroy server) +end - type t = { - port: int; - mutable highest: int; (* highest packet payload received *) - mutable seen_source_ports: PortSet.t; (* all source addresses seen *) - c: unit Lwt_condition.t; - } - let make stack port = - let highest = 0 in - let c = Lwt_condition.create () in - let seen_source_ports = PortSet.empty in - let t = { port; highest; seen_source_ports; c } in - Client.listen_udpv4 stack ~port (fun ~src:_ ~dst:_ ~src_port buffer -> - t.highest <- max t.highest (Cstruct.get_uint8 buffer 0); - t.seen_source_ports <- PortSet.add src_port t.seen_source_ports; - Log.debug (fun f -> - f "Received UDP %d -> %d highest %d" src_port port t.highest); - Lwt_condition.signal c (); - Lwt.return_unit - ); - t - let wait_for_data ~highest t = - if t.highest < highest then begin - Lwt.pick [ Lwt_condition.wait t.c; - Host.Time.sleep_ns (Duration.of_sec 1) ] - >>= fun () -> - Lwt.return (t.highest >= highest) - end else Lwt.return true - let wait_for_ports ~num t = - if PortSet.cardinal t.seen_source_ports < num then begin - Lwt.pick [ Lwt_condition.wait t.c; - Host.Time.sleep_ns (Duration.of_sec 1) ] - >|= fun () -> - PortSet.cardinal t.seen_source_ports >= num - end else Lwt.return true - end +module UdpServer = struct + module PortSet = + Set.Make(struct type t = int let compare = Pervasives.compare end) - let err_udp e = Fmt.kstrf failwith "%a" Client.UDPV4.pp_error e + type t = { + port: int; + mutable highest: int; (* highest packet payload received *) + mutable seen_source_ports: PortSet.t; (* all source addresses seen *) + c: unit Lwt_condition.t; + } + let make stack port = + let highest = 0 in + let c = Lwt_condition.create () in + let seen_source_ports = PortSet.empty in + let t = { port; highest; seen_source_ports; c } in + Client.listen_udpv4 stack ~port (fun ~src:_ ~dst:_ ~src_port buffer -> + t.highest <- max t.highest (Cstruct.get_uint8 buffer 0); + t.seen_source_ports <- PortSet.add src_port t.seen_source_ports; + Log.debug (fun f -> + f "Received UDP %d -> %d highest %d" src_port port t.highest); + Lwt_condition.signal c (); + Lwt.return_unit + ); + t + let wait_for_data ~highest t = + if t.highest < highest then begin + Lwt.pick [ Lwt_condition.wait t.c; + Host.Time.sleep_ns (Duration.of_sec 1) ] + >>= fun () -> + Lwt.return (t.highest >= highest) + end else Lwt.return true + let wait_for_ports ~num t = + if PortSet.cardinal t.seen_source_ports < num then begin + Lwt.pick [ Lwt_condition.wait t.c; + Host.Time.sleep_ns (Duration.of_sec 1) ] + >|= fun () -> + PortSet.cardinal t.seen_source_ports >= num + end else Lwt.return true +end - (* Start a local UDP echo server, send traffic to it and listen for - a response *) - let test_udp () = - let t = EchoServer.with_server (fun { EchoServer.local_port; _ } -> - with_stack (fun _ stack -> - let buffer = Cstruct.create 1024 in - (* Send '1' *) - Cstruct.set_uint8 buffer 0 1; - let udpv4 = Client.udpv4 stack in - let virtual_port = 1024 in - let server = UdpServer.make stack virtual_port in - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port local_port - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~src_port:virtual_port - ~dst:Ipaddr.V4.localhost - ~dst_port:local_port udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> - UdpServer.wait_for_data ~highest:1 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) - in - loop 5 - )) - in - run t +let err_udp e = Fmt.kstrf failwith "%a" Client.UDPV4.pp_error e - (* Start a local UDP mult-echo server, send traffic to it from one - source port, wait for the response, send traffic to it from - another source port, expect responses to *both* source ports. *) - let test_udp_2 () = - let t = EchoServer.with_server (fun { EchoServer.local_port; _ } -> - with_stack (fun _ stack -> - let buffer = Cstruct.create 1024 in - (* Send '1' *) - Cstruct.set_uint8 buffer 0 1; - let udpv4 = Client.udpv4 stack in +(* Start a local UDP echo server, send traffic to it and listen for + a response *) +let test_udp () = + let t = EchoServer.with_server (fun { EchoServer.local_port; _ } -> + with_stack (fun _ stack -> + let buffer = Cstruct.create 1024 in + (* Send '1' *) + Cstruct.set_uint8 buffer 0 1; + let udpv4 = Client.udpv4 stack in + let virtual_port = 1024 in + let server = UdpServer.make stack virtual_port in + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port local_port + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) + in + loop 5 + )) + in + run t - (* Listen on one virtual source port and count received packets *) - let virtual_port1 = 1024 in - let server1 = UdpServer.make stack virtual_port1 in +(* Start a local UDP mult-echo server, send traffic to it from one + source port, wait for the response, send traffic to it from + another source port, expect responses to *both* source ports. *) +let test_udp_2 () = + let t = EchoServer.with_server (fun { EchoServer.local_port; _ } -> + with_stack (fun _ stack -> + let buffer = Cstruct.create 1024 in + (* Send '1' *) + Cstruct.set_uint8 buffer 0 1; + let udpv4 = Client.udpv4 stack in - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port1 local_port - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~src_port:virtual_port1 - ~dst:Ipaddr.V4.localhost - ~dst_port:local_port udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> - UdpServer.wait_for_data ~highest:1 server1 >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) - in - loop 5 >>= fun () -> - (* Listen on a second virtual source port and count - received packets *) - (* Send '2' *) - Cstruct.set_uint8 buffer 0 2; - let virtual_port2 = 1025 in - let server2 = UdpServer.make stack virtual_port2 in - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port2 local_port - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~src_port:virtual_port2 - ~dst:Ipaddr.V4.localhost - ~dst_port:local_port udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> - UdpServer.wait_for_data ~highest:2 server2 >>= fun ok2 -> - (* The server should "multicast" the packet to the - original "connection" *) - UdpServer.wait_for_data ~highest:2 server1 >>= fun ok1 -> - if ok1 && ok2 then Lwt.return_unit else loop (remaining - 1) - in - loop 5 - ) - ) in - run t + (* Listen on one virtual source port and count received packets *) + let virtual_port1 = 1024 in + let server1 = UdpServer.make stack virtual_port1 in - (* Start a local UDP echo server, send some traffic to it over the - virtual interface. Send traffic to the outside address on a - second physical interface, check that this external third party - can traverse the NAT *) - let test_nat_punch () = - let t = EchoServer.with_server (fun echoserver -> - with_stack (fun _ stack -> - let buffer = Cstruct.create 1024 in - (* Send '1' *) - Cstruct.set_uint8 buffer 0 1; - let udpv4 = Client.udpv4 stack in + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port1 local_port + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port1 + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server1 >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) + in + loop 5 >>= fun () -> + (* Listen on a second virtual source port and count + received packets *) + (* Send '2' *) + Cstruct.set_uint8 buffer 0 2; + let virtual_port2 = 1025 in + let server2 = UdpServer.make stack virtual_port2 in + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port2 local_port + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port2 + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:2 server2 >>= fun ok2 -> + (* The server should "multicast" the packet to the + original "connection" *) + UdpServer.wait_for_data ~highest:2 server1 >>= fun ok1 -> + if ok1 && ok2 then Lwt.return_unit else loop (remaining - 1) + in + loop 5 + ) + ) in + run t - (* Listen on one virtual source port and count received packets *) - let virtual_port1 = 1024 in - let server1 = UdpServer.make stack virtual_port1 in +(* Start a local UDP echo server, send some traffic to it over the + virtual interface. Send traffic to the outside address on a + second physical interface, check that this external third party + can traverse the NAT *) +let test_nat_punch () = + let t = EchoServer.with_server (fun echoserver -> + with_stack (fun _ stack -> + let buffer = Cstruct.create 1024 in + (* Send '1' *) + Cstruct.set_uint8 buffer 0 1; + let udpv4 = Client.udpv4 stack in - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - let dst_port = echoserver.EchoServer.local_port in - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port1 dst_port - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~src_port:virtual_port1 - ~dst:Ipaddr.V4.localhost - ~dst_port udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> - UdpServer.wait_for_data ~highest:1 server1 >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) - in - loop 5 >>= fun () -> + (* Listen on one virtual source port and count received packets *) + let virtual_port1 = 1024 in + let server1 = UdpServer.make stack virtual_port1 in - (* Using the physical outside interface, send traffic to - the address and see if this traffic will also be sent - via the NAT. *) - (* Send '2' *) - Cstruct.set_uint8 buffer 0 2; - Host.Sockets.Datagram.Udp.bind (Ipaddr.(V4 V4.localhost), 0) - >>= fun client -> - let _, source_port = Host.Sockets.Datagram.Udp.getsockname client in - let address = List.hd (EchoServer.get_seen_addresses echoserver) in - let _, dest_port = address in - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - Log.debug (fun f -> - f "Sending %d -> %d value %d" source_port dest_port - (Cstruct.get_uint8 buffer 0)); - Host.Sockets.Datagram.Udp.sendto client address buffer - >>= fun () -> - UdpServer.wait_for_data ~highest:2 server1 >>= function + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + let dst_port = echoserver.EchoServer.local_port in + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port1 dst_port + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port1 + ~dst:Ipaddr.V4.localhost + ~dst_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server1 >>= function | true -> Lwt.return_unit | false -> loop (remaining - 1) - in - loop 5)) - in - run t + in + loop 5 >>= fun () -> - (* The NAT table rule should be associated with the virtual address, - rather than physical address. Check if we have 2 physical servers - we have only a single NAT rule *) - let test_shared_nat_rule () = - let t = EchoServer.with_server (fun { EchoServer.local_port; _ } -> - with_stack (fun slirp_server stack -> - let buffer = Cstruct.create 1024 in - (* Send '1' *) - Cstruct.set_uint8 buffer 0 1; - let udpv4 = Client.udpv4 stack in - let virtual_port = 1024 in - let server = UdpServer.make stack virtual_port in - let init_table_size = - Slirp_stack.Debug.get_nat_table_size slirp_server - in + (* Using the physical outside interface, send traffic to + the address and see if this traffic will also be sent + via the NAT. *) + (* Send '2' *) + Cstruct.set_uint8 buffer 0 2; + Host.Sockets.Datagram.Udp.bind (Ipaddr.(V4 V4.localhost), 0) + >>= fun client -> + let _, source_port = Host.Sockets.Datagram.Udp.getsockname client in + let address = List.hd (EchoServer.get_seen_addresses echoserver) in + let _, dest_port = address in + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" source_port dest_port + (Cstruct.get_uint8 buffer 0)); + Host.Sockets.Datagram.Udp.sendto client address buffer + >>= fun () -> + UdpServer.wait_for_data ~highest:2 server1 >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) + in + loop 5)) + in + run t - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port local_port - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~src_port:virtual_port - ~dst:Ipaddr.V4.localhost - ~dst_port:local_port udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> - UdpServer.wait_for_data ~highest:1 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) - in - loop 5 >>= fun () -> - Alcotest.(check int) "One NAT rule" 1 - (Slirp_stack.Debug.get_nat_table_size slirp_server - - init_table_size); - (* Send '2' *) - Cstruct.set_uint8 buffer 0 2; - (* Create another physical server and send traffic from - the same virtual address *) - EchoServer.with_server (fun { EchoServer.local_port; _ } -> - let rec loop remaining = - if remaining = 0 then - failwith "Timed-out before UDP response arrived"; - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port local_port - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write ~src_port:virtual_port - ~dst:Ipaddr.V4.localhost - ~dst_port:local_port udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> - UdpServer.wait_for_data ~highest:2 server >>= function - | true -> Lwt.return_unit - | false -> loop (remaining - 1) - in - loop 5 >|= fun () -> - Alcotest.(check int) "Still one NAT rule" 1 - (Slirp_stack.Debug.get_nat_table_size slirp_server - - init_table_size) - ))) - in - run t +(* The NAT table rule should be associated with the virtual address, + rather than physical address. Check if we have 2 physical servers + we have only a single NAT rule *) +let test_shared_nat_rule () = + let t = EchoServer.with_server (fun { EchoServer.local_port; _ } -> + with_stack (fun slirp_server stack -> + let buffer = Cstruct.create 1024 in + (* Send '1' *) + Cstruct.set_uint8 buffer 0 1; + let udpv4 = Client.udpv4 stack in + let virtual_port = 1024 in + let server = UdpServer.make stack virtual_port in + let init_table_size = + Slirp_stack.Debug.get_nat_table_size slirp_server + in - (* If we have two physical servers but send data from the same source port, - we should see both physical server source ports *) - let test_source_ports () = - let t = EchoServer.with_server - (fun { EchoServer.local_port = local_port1; _ } -> - EchoServer.with_server - (fun { EchoServer.local_port = local_port2; _ } -> - with_stack (fun _ stack -> - let buffer = Cstruct.create 1024 in - let udpv4 = Client.udpv4 stack in - (* This is the port we shall send from *) - let virtual_port = 1024 in - let server = UdpServer.make stack virtual_port in - let rec loop remaining = - Printf.fprintf stderr "remaining=%d\n%!" remaining; - if remaining = 0 then - failwith "Timed-out before both UDP ports were seen"; + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port local_port + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:1 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) + in + loop 5 >>= fun () -> + Alcotest.(check int) "One NAT rule" 1 + (Slirp_stack.Debug.get_nat_table_size slirp_server + - init_table_size); + (* Send '2' *) + Cstruct.set_uint8 buffer 0 2; + (* Create another physical server and send traffic from + the same virtual address *) + EchoServer.with_server (fun { EchoServer.local_port; _ } -> + let rec loop remaining = + if remaining = 0 then + failwith "Timed-out before UDP response arrived"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port local_port + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> + UdpServer.wait_for_data ~highest:2 server >>= function + | true -> Lwt.return_unit + | false -> loop (remaining - 1) + in + loop 5 >|= fun () -> + Alcotest.(check int) "Still one NAT rule" 1 + (Slirp_stack.Debug.get_nat_table_size slirp_server + - init_table_size) + ))) + in + run t + +(* If we have two physical servers but send data from the same source port, + we should see both physical server source ports *) +let test_source_ports () = + let t = EchoServer.with_server + (fun { EchoServer.local_port = local_port1; _ } -> + EchoServer.with_server + (fun { EchoServer.local_port = local_port2; _ } -> + with_stack (fun _ stack -> + let buffer = Cstruct.create 1024 in + let udpv4 = Client.udpv4 stack in + (* This is the port we shall send from *) + let virtual_port = 1024 in + let server = UdpServer.make stack virtual_port in + let rec loop remaining = + Printf.fprintf stderr "remaining=%d\n%!" remaining; + if remaining = 0 then + failwith "Timed-out before both UDP ports were seen"; + Log.debug (fun f -> + f "Sending %d -> %d value %d" virtual_port local_port1 + (Cstruct.get_uint8 buffer 0)); + Client.UDPV4.write + ~src_port:virtual_port + ~dst:Ipaddr.V4.localhost + ~dst_port:local_port1 udpv4 buffer + >>= function + | Error e -> err_udp e + | Ok () -> Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port local_port1 + f "Sending %d -> %d value %d" virtual_port local_port2 (Cstruct.get_uint8 buffer 0)); Client.UDPV4.write ~src_port:virtual_port ~dst:Ipaddr.V4.localhost - ~dst_port:local_port1 udpv4 buffer + ~dst_port:local_port2 udpv4 buffer >>= function | Error e -> err_udp e | Ok () -> - Log.debug (fun f -> - f "Sending %d -> %d value %d" virtual_port local_port2 - (Cstruct.get_uint8 buffer 0)); - Client.UDPV4.write - ~src_port:virtual_port - ~dst:Ipaddr.V4.localhost - ~dst_port:local_port2 udpv4 buffer - >>= function - | Error e -> err_udp e - | Ok () -> UdpServer.wait_for_ports ~num:2 server >>= function | true -> Lwt.return_unit | false -> loop (remaining - 1) - in - loop 5))) - in - Host.Main.run t + in + loop 5))) + in + Host.Main.run t - let tests = [ - "NAT: shared rule", [ "", `Quick, test_shared_nat_rule ]; - "NAT: 1 UDP connection", [ "", `Quick, test_udp ]; - "NAT: 2 UDP connections", [ "", `Quick, test_udp_2 ]; - "NAT: punch", [ "", `Quick, test_nat_punch ]; - "NAT: source ports", [ "", `Quick, test_source_ports ]; - ] +let tests = [ + "NAT: shared rule", [ "", `Quick, test_shared_nat_rule ]; + "NAT: 1 UDP connection", [ "", `Quick, test_udp ]; + "NAT: 2 UDP connections", [ "", `Quick, test_udp_2 ]; + "NAT: punch", [ "", `Quick, test_nat_punch ]; + "NAT: source ports", [ "", `Quick, test_source_ports ]; +]