diff --git a/mirage/server/dns_server_mirage.ml b/mirage/server/dns_server_mirage.ml index 9db507745..3ee3f8f15 100644 --- a/mirage/server/dns_server_mirage.ml +++ b/mirage/server/dns_server_mirage.ml @@ -178,7 +178,7 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t end in - let rec close ip = + let rec close ~timer ip = (match Ipaddr.Map.find_opt ip !tcp_out with | None -> Lwt.return_unit | Some f -> T.close f) >>= fun () -> @@ -187,12 +187,15 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t let elapsed = M.elapsed_ns () in let state', out = Dns_server.Secondary.closed !state now elapsed ip in state := state' ; - request (ip, out) - and read_and_handle ip f = + if not timer then + request ~timer (ip, out) + else + Lwt.return_unit + and read_and_handle ~timer ip f = Dns.read_tcp f >>= function | Error () -> Log.debug (fun m -> m "removing %a from tcp_out" Ipaddr.pp ip) ; - close ip + close ~timer ip | Ok data -> inc `Tcp_query; let now = Ptime.v (P.now_d_ps ()) in @@ -208,15 +211,15 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t Dns.send_tcp (Dns.flow f) x >>= function | Error () -> Log.debug (fun m -> m "removing %a from tcp_out" Ipaddr.pp ip) ; - close ip >|= fun () -> Error () + close ~timer ip >|= fun () -> Error () | Ok () -> Lwt.return (Ok ())) >>= fun r -> (match out with | None -> Lwt.return_unit - | Some (ip, data) -> request_one (ip, data)) >>= fun () -> + | Some (ip, data) -> request_one ~timer (ip, data)) >>= fun () -> match r with - | Ok () -> read_and_handle ip f + | Ok () -> read_and_handle ~timer ip f | Error () -> Lwt.return_unit - and request (ip, data) = + and request ~timer (ip, data) = inc `Notify; let dport = 53 in match Ipaddr.Map.find_opt ip !tcp_out with @@ -228,13 +231,13 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t | Error e -> Log.err (fun m -> m "error %a while establishing tcp connection to %a:%d" T.pp_error e Ipaddr.pp ip dport) ; - close ip + close ~timer ip | Ok flow -> tcp_out := Ipaddr.Map.add ip flow !tcp_out ; Dns.send_tcp_multiple flow data >>= function - | Error () -> close ip + | Error () -> close ~timer ip | Ok () -> - Lwt.async (fun () -> read_and_handle ip (Dns.of_flow flow)) ; + Lwt.async (fun () -> read_and_handle ~timer ip (Dns.of_flow flow)) ; Lwt.return_unit end | Some flow -> @@ -245,8 +248,8 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t Ipaddr.pp ip dport) ; T.close flow >>= fun () -> tcp_out := Ipaddr.Map.remove ip !tcp_out ; - request (ip, data) - and request_one (ip, d) = request (ip, [ d ]) + request ~timer (ip, data) + and request_one ~timer (ip, d) = request ~timer (ip, [ d ]) in let udp_cb ~src ~dst:_ ~src_port buf = @@ -260,7 +263,7 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t maybe_update_state t >>= fun () -> (match out with | None -> () - | Some (ip, cs) -> Lwt.async (fun () -> request_one (ip, cs))) ; + | Some (ip, cs) -> Lwt.async (fun () -> request_one ~timer:false (ip, cs))) ; match answer with | None -> Lwt.return_unit | Some out -> inc `Udp_answer; Dns.send_udp stack port src src_port out @@ -287,7 +290,7 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t maybe_update_state t >>= fun () -> (match out with | None -> () - | Some (ip, cs) -> Lwt.async (fun () -> request_one (ip, cs))); + | Some (ip, cs) -> Lwt.async (fun () -> request_one ~timer:false (ip, cs))); match answer with | None -> Log.warn (fun m -> m "no TCP output") ; @@ -309,7 +312,7 @@ module Make (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (TIME : Mirage_t let t, out = Dns_server.Secondary.timer !state now elapsed in maybe_update_state t >>= fun () -> List.iter (fun (ip, cs) -> - Lwt.async (fun () -> request (ip, cs))) out ; + Lwt.async (fun () -> request ~timer:true (ip, cs))) out ; TIME.sleep_ns (Duration.of_sec timer) >>= fun () -> time () in