diff --git a/lib/tcp/window.ml b/lib/tcp/window.ml index d1d236bd8..b6a8c51ff 100644 --- a/lib/tcp/window.ml +++ b/lib/tcp/window.ml @@ -60,9 +60,6 @@ let count_ackd_segs = MProf.Counter.make ~name:"tcp-ackd-segs" let default_mss = 536 -let alpha = 0.125 (* see RFC 2988 *) -let beta = 0.25 (* see RFC 2988 *) - (* To string for debugging *) let pp fmt t = Format.fprintf fmt @@ -195,14 +192,16 @@ module Make(Clock:Mirage_clock.MCLOCK) = struct t.rttvar <- Int64.div rtt_m 2L; t.srtt <- rtt_m; end else begin - let adjusted_rttvar = (1.0 -. beta) *. (Int64.to_float t.rttvar) in - let rttvar_addition = beta *. Int64.(sub t.srtt rtt_m |> abs |> to_float) in - let adjusted_srtt = (1.0 -. alpha) *. (Int64.to_float t.srtt) in - let srtt_addition = alpha *. (Int64.to_float rtt_m) in - t.rttvar <- Int64.of_float (adjusted_rttvar +. rttvar_addition); - t.srtt <- Int64.of_float (adjusted_srtt +. srtt_addition); + let (/) = Int64.div + and ( * ) = Int64.mul + and (-) = Int64.sub + and (+) = Int64.add + in + (* RFC2988 2.3 *) + t.rttvar <- (3L * t.rttvar / 4L) + (Int64.abs (t.srtt - rtt_m) / 4L); + t.srtt <- (7L * t.srtt / 8L) + (rtt_m / 8L) end; - t.rto <- (max (Duration.of_ms 667) (Int64.add t.srtt (Int64.mul 4L t.rttvar))); + t.rto <- max (Duration.of_ms 667) Int64.(add t.srtt (mul t.rttvar 4L)); end; end; let cwnd_incr = match t.cwnd < t.ssthresh with diff --git a/lib_test/test_tcp_window.ml b/lib_test/test_tcp_window.ml index ceece58ce..872d20032 100644 --- a/lib_test/test_tcp_window.ml +++ b/lib_test/test_tcp_window.ml @@ -1,14 +1,20 @@ open Lwt.Infix module Clock = struct + (* Mirage_device.S *) type error = string type t = { time: int64 } type 'a io = 'a Lwt.t - let tick {time} = {time = Int64.add time 1L} let disconnect _ = Lwt.return_unit let connect () = Lwt.return { time = 0L } + + (* Mirage_clock.MCLOCK *) let period_ns _ = None let elapsed_ns {time} = time + + (* Test-related function: advance by 1 ns *) + let tick {time} = { time = Int64.add time 1L } + let tick_for {time} duration = { time = Int64.add time duration } end module Timed_window = Tcp.Window.Make(Clock) @@ -90,8 +96,35 @@ let recover_fast () = Alcotest.(check bool) "once entering fast recovery, we can send >0 packets" true ((Int32.compare (Tcp.Window.tx_available window) 0l) > 0); Lwt.return_unit - + +let rto_calculation () = + let window = default_window () in + (* RFC 2988 2.1 *) + Alcotest.(check int64) "initial rto is 2/3 second" (Duration.of_ms 667) @@ Tcp.Window.rto window; + let receive_window = Tcp.Window.ack_win window in + Clock.connect () >>= fun clock -> + Timed_window.tx_advance clock window (Tcp.Window.tx_nxt window); + let clock = Clock.tick_for clock (Duration.of_ms 400) in + let max_size = Tcp.Window.tx_available window |> Tcp.Sequence.of_int32 in + let sz = Tcp.Sequence.add max_size @@ (Tcp.Window.tx_nxt window) in + Timed_window.tx_ack clock window sz receive_window; + (* RFC 2988 2.2 *) + Alcotest.(check int64) "After one RTT measurement, the calculated rto is 400 + (4 * 200) = 1200ms" (Duration.of_ms 1200) @@ Tcp.Window.rto window; + + (* RFC 2988 2.3 *) + Timed_window.tx_advance clock window (Tcp.Window.tx_nxt window); + let receive_window = Tcp.Window.ack_win window in + let clock = Clock.tick_for clock (Duration.of_ms 300) in + let max_size = Tcp.Window.tx_available window |> Tcp.Sequence.of_int32 in + let sz = Tcp.Sequence.add max_size @@ (Tcp.Window.tx_nxt window) in + Timed_window.tx_ack clock window sz receive_window; + Alcotest.(check int64) "After subsequent RTT measurement, the calculated rto is 1087.5ms" (Duration.of_us 1087500) @@ Tcp.Window.rto window; + + Lwt.return_unit + + let suite = [ "fresh window is sensible", `Quick, fresh_window; "fast recovery recovers fast", `Quick, recover_fast; + "smoothed rtt, rtt variation and retransmission timer are calculated according to RFC2988", `Quick, rto_calculation; ]