Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing memory leaks when a simultaneous close is happening #489

Merged
merged 11 commits into from
Jul 27, 2022
12 changes: 7 additions & 5 deletions src/tcp/flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ module Log = (val Logs.src_log src : Logs.LOG)
module Make(Ip: Tcpip.Ip.S)(Time:Mirage_time.S)(Clock:Mirage_clock.MCLOCK)(Random:Mirage_random.S) =
struct

module RXS = Segment.Rx(Time)
module TXS = Segment.Tx(Time)(Clock)
module ACK = Ack.Immediate
module RXS = Segment.Rx(Time)(ACK)
module TXS = Segment.Tx(Time)(Clock)
module UTX = User_buffer.Tx(Time)(Clock)
module WIRE = Wire.Make(Ip)
module STATE = State.Make(Time)
Expand Down Expand Up @@ -75,6 +75,8 @@ struct
connects: (WIRE.t, ((connection, error) result Lwt.u * Sequence.t * Tcpip.Tcp.Keepalive.t option)) Hashtbl.t;
}

let num_open_channels t = Hashtbl.length t.channels

let listen t ~port ?keepalive cb =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
Expand Down Expand Up @@ -356,11 +358,11 @@ struct
let txq, _tx_t =
TXS.create ~xmit:(Tx.xmit_pcb t.ip id) ~wnd ~state ~rx_ack ~tx_ack ~tx_wnd_update
in
(* The user application transmit buffer *)
let utx = UTX.create ~wnd ~txq ~max_size:16384l in
let rxq = RXS.create ~rx_data ~wnd ~state ~tx_ack in
(* Set up ACK module *)
let ack = ACK.t ~send_ack ~last:(Sequence.succ rx_isn) in
(* The user application transmit buffer *)
let utx = UTX.create ~wnd ~txq ~max_size:16384l in
let rxq = RXS.create ~rx_data ~ack ~wnd ~state ~tx_ack in
(* Set up the keepalive state if requested *)
let keepalive = match keepalive with
| None -> None
Expand Down
6 changes: 6 additions & 0 deletions src/tcp/flow.mli
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,10 @@ module Make (IP:Tcpip.Ip.S)
(R:Mirage_random.S) : sig
include Tcpip.Tcp.S with type ipaddr = IP.ipaddr
val connect : IP.t -> t Lwt.t

(**/**)
(* the number of open connections *)
val num_open_channels : t -> int
(**/**)

end
14 changes: 6 additions & 8 deletions src/tcp/segment.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ let rec reset_seq segs =
It also looks for control messages and dispatches them to
the Rtx queue to ack messages or close channels.
*)
module Rx(Time:Mirage_time.S) = struct
module Rx(Time:Mirage_time.S)(ACK: Ack.M) = struct
open Tcp_packet
module StateTick = State.Make(Time)

Expand Down Expand Up @@ -84,14 +84,15 @@ module Rx(Time:Mirage_time.S) = struct
type t = {
mutable segs: S.t;
rx_data: (Cstruct.t list option * Sequence.t option) Lwt_mvar.t; (* User receive channel *)
ack: ACK.t;
tx_ack: (Sequence.t * int) Lwt_mvar.t; (* Acks of our transmitted segs *)
wnd: Window.t;
state: State.t;
}

let create ~rx_data ~wnd ~state ~tx_ack =
let create ~rx_data ~ack ~wnd ~state ~tx_ack =
let segs = S.empty in
{ segs; rx_data; tx_ack; wnd; state }
{ segs; rx_data; ack; tx_ack; wnd; state }

let pp fmt t =
let pp_v fmt seg =
Expand Down Expand Up @@ -135,10 +136,7 @@ module Rx(Time:Mirage_time.S) = struct

let send_challenge_ack q =
(* TODO: rfc5961 ACK Throttling *)
(* Is this the correct way trigger an ack? *)
if Lwt_mvar.is_empty q.rx_data
then Lwt_mvar.put q.rx_data (Some [], Some Sequence.zero)
else Lwt.return_unit
ACK.pushack q.ack Sequence.zero

(* Given an input segment, the window information, and a receive
queue, update the window, extract any ready segments into the
Expand Down Expand Up @@ -287,7 +285,7 @@ module Tx (Time:Mirage_time.S) (Clock:Mirage_clock.MCLOCK) = struct
let ontimer xmit st segs wnd seq =
match State.state st with
| State.Syn_rcvd _ | State.Established | State.Fin_wait_1 _
| State.Close_wait | State.Last_ack _ ->
| State.Close_wait | State.Closing _ | State.Last_ack _ ->
begin match peek_opt_l segs with
| None -> Lwt.return Tcptimer.Stoptimer
| Some rexmit_seg ->
Expand Down
3 changes: 2 additions & 1 deletion src/tcp/segment.mli
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
the Rtx queue to ack messages or close channels.
*)

module Rx (T:Mirage_time.S) : sig
module Rx (T:Mirage_time.S)(ACK:Ack.M) : sig

type segment = { header: Tcp_packet.t; payload: Cstruct.t }
(** Individual received TCP segment *)
Expand All @@ -38,6 +38,7 @@ module Rx (T:Mirage_time.S) : sig

val create:
rx_data:(Cstruct.t list option * Sequence.t option) Lwt_mvar.t ->
ack:ACK.t ->
wnd:Window.t ->
state:State.t ->
tx_ack:(Sequence.t * int) Lwt_mvar.t ->
Expand Down
13 changes: 9 additions & 4 deletions src/tcp/state.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ module Make(Time:Mirage_time.S) = struct
t.on_close ();
Lwt.return_unit

let transition_to_timewait t =
Lwt.async (fun () -> timewait t time_wait_time);
Time_wait

let tick t (i:action) =
let diffone x y = Sequence.succ y = x in
let tstr s (i:action) =
Expand Down Expand Up @@ -148,10 +152,11 @@ module Make(Time:Mirage_time.S) = struct
| Fin_wait_1 _, Recv_rst -> t.on_close (); Reset
| Fin_wait_2 i, Recv_ack _ -> Fin_wait_2 (i + 1)
| Fin_wait_2 _, Recv_rst -> t.on_close (); Reset
| Fin_wait_2 _, Recv_fin ->
Lwt.async (fun () -> timewait t time_wait_time);
Time_wait
| Closing a, Recv_ack b -> if diffone b a then Time_wait else Closing a
| Fin_wait_2 _, Recv_fin -> transition_to_timewait t
| Closing a, Recv_ack b ->
if diffone b a then
transition_to_timewait t
else Closing a
| Closing _, Timeout -> t.on_close (); Closed
| Closing _, Recv_rst -> t.on_close (); Reset
| Time_wait, Timeout -> t.on_close (); Closed
Expand Down
150 changes: 150 additions & 0 deletions test/low_level.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
open Lwt.Infix

(*
* Connects two stacks to the same backend.
* One is a complete v4 stack (the system under test, referred to as [sut]).
* The other gives us low level access to inject crafted TCP packets,
* and sends and receives crafted packets to check the [sut] behavior.
*)
module VNETIF_STACK = Vnetif_common.VNETIF_STACK(Vnetif_backends.Basic)

module Time = Vnetif_common.Time
module V = Vnetif.Make(Vnetif_backends.Basic)
module E = Ethernet.Make(V)
module A = Arp.Make(E)(Time)
module I = Static_ipv4.Make(Mirage_random_test)(Vnetif_common.Clock)(E)(A)
module Wire = Tcp.Wire
module WIRE = Wire.Make(I)
module Tcp_wire = Tcp.Tcp_wire
module Tcp_unmarshal = Tcp.Tcp_packet.Unmarshal
module Sequence = Tcp.Sequence

let sut_cidr = Ipaddr.V4.Prefix.of_string_exn "10.0.0.101/24"
let server_ip = Ipaddr.V4.of_string_exn "10.0.0.100"
let server_cidr = Ipaddr.V4.Prefix.make 24 server_ip
let gateway = Ipaddr.V4.of_string_exn "10.0.0.1"

let header_size = Ethernet.Packet.sizeof_ethernet



(* defaults when injecting packets *)
let options = []
let window = 5120

(* Helper functions *)
let reply_id_from ~src ~dst data =
let sport = Tcp_wire.get_tcp_src_port data in
let dport = Tcp_wire.get_tcp_dst_port data in
WIRE.v ~dst_port:sport ~dst:src ~src_port:dport ~src:dst

let ack_for data =
match Tcp_unmarshal.of_cstruct data with
| Error s -> Alcotest.fail ("attempting to ack data: " ^ s)
| Ok (packet, data) ->
let open Tcp.Tcp_packet in
let data_len =
Sequence.of_int ((Cstruct.length data) +
(if packet.fin then 1 else 0) +
(if packet.syn then 1 else 0)) in
let sequence = packet.sequence in
let ack_n = Sequence.(add sequence data_len) in
ack_n

let ack data =
Some(ack_for data)

let ack_in_future data off =
Some Sequence.(add (ack_for data) (of_int off))

let ack_from_past data off =
Some Sequence.(sub (ack_for data) (of_int off))

let fail_result_not_expected fail = function
| Error _err ->
fail "error not expected"
| Ok `Eof ->
fail "eof"
| Ok (`Data data) ->
Alcotest.fail (Format.asprintf "data not expected but received: %a"
Cstruct.hexdump_pp data)



let create_sut_stack backend =
VNETIF_STACK.create_stack ~cidr:sut_cidr ~gateway backend

let create_raw_stack backend =
V.connect backend >>= fun netif ->
E.connect netif >>= fun ethif ->
A.connect ethif >>= fun arpv4 ->
I.connect ~cidr:server_cidr ~gateway ethif arpv4 >>= fun ip ->
Lwt.return (netif, ethif, arpv4, ip)

type 'state fsm_result =
| Fsm_next of 'state
| Fsm_done
| Fsm_error of string

(* This could be moved to a common module and reused for other low level tcp tests *)

(* setups network and run a given sut and raw fsm *)
let run backend fsm sut () =
let initial_state, fsm_handler = fsm in
create_sut_stack backend >>= fun stackv4 ->
create_raw_stack backend >>= fun (netif, ethif, arp, rawip) ->
let error_mbox = Lwt_mvar.create_empty () in
let stream, pushf = Lwt_stream.create () in
Lwt.pick [
VNETIF_STACK.Stackv4.listen stackv4;

(* Consume TCP packets one by one, in sequence *)
let rec fsm_thread state =
Lwt_stream.next stream >>= fun (src, dst, data) ->
fsm_handler rawip state ~src ~dst data >>= function
| Fsm_next s ->
fsm_thread s
| Fsm_done ->
Lwt.return_unit
| Fsm_error err ->
Lwt_mvar.put error_mbox err >>= fun () ->
(* it will be terminated anyway when the error is picked up *)
fsm_thread state in

Lwt.async (fun () ->
(V.listen netif ~header_size
(E.input
~arpv4:(A.input arp)
~ipv4:(I.input
~tcp: (fun ~src ~dst data -> pushf (Some(src,dst,data)); Lwt.return_unit)
~udp:(fun ~src:_ ~dst:_ _data -> Lwt.return_unit)
~default:(fun ~proto ~src ~dst _data ->
Logs.debug (fun f -> f "default handler invoked for packet from %a to %a, protocol %d -- dropping" Ipaddr.V4.pp src Ipaddr.V4.pp dst proto); Lwt.return_unit)
rawip
)
~ipv6:(fun _buf ->
Logs.debug (fun f -> f "IPv6 packet -- dropping");
Lwt.return_unit)
ethif) ) >|= fun _ -> ());

(* Either both fsm and the sut terminates, or a timeout occurs, or one of the sut/fsm informs an error *)
Lwt.pick [
(Time.sleep_ns (Duration.of_sec 5) >>= fun () ->
Lwt.return_some "timed out");

(Lwt.join [
(fsm_thread initial_state);

(* time to let the other end connects to the network and listen.
* Otherwise initial syn might need to be repeated slowing down the test *)
(Time.sleep_ns (Duration.of_ms 100) >>= fun () ->
sut stackv4 (Lwt_mvar.put error_mbox) >>= fun _ ->
Time.sleep_ns (Duration.of_ms 100));
] >>= fun () -> Lwt.return_none);

(Lwt_mvar.take error_mbox >>= fun cause ->
Lwt.return_some cause);
] >|= function
| None -> ()
| Some err -> Alcotest.fail err
]
1 change: 1 addition & 0 deletions test/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ let suite = [
"iperf" , Test_iperf.suite ;
"iperf_ipv6" , Test_iperf_ipv6.suite ;
"keepalive" , Test_keepalive.suite ;
"simultaneous_close", Test_simulatenous_close.suite
]

let run test () =
Expand Down
Loading