Skip to content

Commit

Permalink
Close partially-established TCP connections
Browse files Browse the repository at this point in the history
If a client sends SYN, we connect the external socket and reply with
SYN ACK. Previously, if the client then responded with RST ACK, we would
leak the connection.

This patch refactors the connection closing mechanism, creating an
idempotent `close_flow` function which is called:

- on normal close when the proxy receives `FIN` etc
- on a reset, including during the handshake
- when a switch port is being timed out.

This replaces the previous `on_destroy` promise, which was used in
`Lwt.pick`. The promise is no longer needed because closing the connection
should cause the proxy to receive EOF.

Related to [docker/for-mac#1132]

Signed-off-by: David Scott <dave.scott@docker.com>
  • Loading branch information
djs55 committed Jun 6, 2018
1 parent 3a8cfb4 commit 10f154c
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 63 deletions.
5 changes: 3 additions & 2 deletions src/hostnet/frame.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ and t =
| Icmp: { ty: int; code: int; raw: Cstruct.t; icmp: icmp } -> t
| Ipv4: ipv4 -> t
| Udp: { src: int; dst: int; len: int; raw: Cstruct.t; payload: t } -> t
| Tcp: { src: int; dst: int; syn: bool; raw: Cstruct.t; payload: t } -> t
| Tcp: { src: int; dst: int; syn: bool; rst: bool; raw: Cstruct.t; payload: t } -> t
| Payload: Cstruct.t -> t
| Unknown: t

Expand Down Expand Up @@ -97,7 +97,8 @@ let rec ipv4 inner =
let payload = Cstructs.shift inner ((offres lsr 4) * 4)
|> Cstructs.to_cstruct in
let syn = (flags land (1 lsl 1)) > 0 in
Ok (Tcp { src; dst; syn; raw = Cstructs.to_cstruct inner;
let rst = (flags land (1 lsl 2)) > 0 in
Ok (Tcp { src; dst; syn; rst; raw = Cstructs.to_cstruct inner;
payload = Payload payload })
| 17 ->
let raw = Cstructs.to_cstruct inner in
Expand Down
2 changes: 1 addition & 1 deletion src/hostnet/frame.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ and t =
| Icmp: { ty: int; code: int; raw: Cstruct.t; icmp: icmp } -> t
| Ipv4: ipv4 -> t
| Udp: { src: int; dst: int; len: int; raw: Cstruct.t; payload: t } -> t
| Tcp: { src: int; dst: int; syn: bool; raw: Cstruct.t; payload: t } -> t
| Tcp: { src: int; dst: int; syn: bool; rst: bool; raw: Cstruct.t; payload: t } -> t
| Payload: Cstruct.t -> t
| Unknown: t

Expand Down
7 changes: 2 additions & 5 deletions src/hostnet/hostnet_dns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ struct
| Ok buffer ->
Udp.write ~src_port:53 ~dst:src ~dst_port:src_port udp buffer

let handle_tcp ~t ~close =
let handle_tcp ~t =
(* FIXME: need to record the upstream request *)
let listeners _ =
Log.debug (fun f -> f "DNS TCP handshake complete");
Expand Down Expand Up @@ -424,10 +424,7 @@ struct
Lwt.async queries;
loop ()
in
Lwt.pick [
loop ();
close
]
loop ()
in
Some f
in
Expand Down
2 changes: 1 addition & 1 deletion src/hostnet/hostnet_dns.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ sig
t:t -> udp:Udp.t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> src_port:int ->
Cstruct.t -> (unit, Udp.error) result Lwt.t

val handle_tcp: t:t -> close:(unit Lwt.t) -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t
val handle_tcp: t:t -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t

val destroy: t -> unit Lwt.t
end
128 changes: 74 additions & 54 deletions src/hostnet/slirp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ struct
t
(* [remove id] rebinds the shared [all] table to a copy without the entry
   for [id]. *)
let remove id =
all := Id.Map.remove id !all
(* [mem id] is [true] when [id] currently has an entry in [all]. *)
let mem id = Id.Map.mem id !all
(* [find id] returns the entry stored under [id] in [all]. [close_flow]
   guards its call with [mem].
   NOTE(review): presumably raises [Not_found] when [id] is absent, as
   stdlib maps do - confirm against the definition of [Id.Map]. *)
let find id = Id.Map.find id !all
let touch id =
if Id.Map.mem id !all
then (Id.Map.find id !all).last_active_time <- Unix.gettimeofday ()
Expand All @@ -239,8 +241,8 @@ struct
clock: Clock.t;
mutable pending: Tcp.Id.Set.t;
mutable last_active_time: float;
(* Tasks that will be signalled if the endpoint is destroyed *)
mutable on_destroy: unit Lwt.u Tcp.Id.Map.t;
(* Used to shutdown connections when the endpoint is removed from the switch. *)
mutable established: Tcp.Id.Set.t;
}
(** A generic TCP/IP endpoint *)

Expand All @@ -263,16 +265,49 @@ struct

let pending = Tcp.Id.Set.empty in
let last_active_time = Unix.gettimeofday () in
let on_destroy = Tcp.Id.Map.empty in
let established = Tcp.Id.Set.empty in
let tcp_stack =
{ recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
last_active_time; clock; on_destroy }
last_active_time; clock; established }
in
Lwt.return tcp_stack

(* [close_flow t ~id reason] tears down the flow registered under [id], if
   there is one: it logs why the flow is going away, unregisters it from the
   flow table and from [t.established], and closes its host socket.

   [reason] is one of [`Port_disconnect], [`Reset] or [`Fin] and only
   affects the log message.

   The function is idempotent: once the flow has been removed from the
   table, further calls (e.g. a regular RST/FIN racing with a switch port
   timeout) only emit the initial debug line and return. *)
let close_flow t ~id reason =
  Log.debug (fun f -> f "%s: close_flow" (string_of_id id));
  if not (Tcp.Flow.mem id) then
    (* Nothing is registered under this id: the flow was already closed. *)
    Lwt.return_unit
  else begin
    let flow = Tcp.Flow.find id in
    (match reason with
     | `Port_disconnect ->
       Log.warn (fun m ->
           m "%s closing flow due to idle port disconnection" (string_of_id id))
     | `Reset ->
       Log.debug (fun m ->
           m "%s: closing flow due to TCP RST" (string_of_id id))
     | `Fin ->
       Log.debug (fun m ->
           m "%s: closing flow due to TCP FIN" (string_of_id id)));
    (* Unregister first, so a concurrent call sees the flow as gone. *)
    Tcp.Flow.remove id;
    t.established <- Tcp.Id.Set.remove id t.established;
    match flow.Tcp.Flow.socket with
    | None ->
      (* A flow that is still in the table is expected to own a socket;
         reaching this branch means the bookkeeping has drifted. *)
      Log.warn (fun m ->
          m "%s: no socket registered, possible socket leak" (string_of_id id));
      Lwt.return_unit
    | Some socket ->
      (* Closing the socket should make the proxy exit cleanly. *)
      flow.Tcp.Flow.socket <- None;
      Host.Sockets.Stream.Tcp.close socket
  end

let destroy t =
Tcp.Id.Map.iter (fun _ u -> Lwt.wakeup_later u ()) t.on_destroy;
t.on_destroy <- Tcp.Id.Map.empty
let all = Tcp.Id.Set.fold (fun id acc -> id :: acc) t.established [] in
Lwt_list.iter_s (fun id -> close_flow t ~id `Port_disconnect) all
>>= fun () ->
t.established <- Tcp.Id.Set.empty;
Lwt.return_unit

let intercept_tcp_syn t ~id ~syn on_syn_callback (buf: Cstruct.t) =
if syn then begin
Expand All @@ -284,14 +319,10 @@ struct
Lwt.return_unit
end else begin
t.pending <- Tcp.Id.Set.add id t.pending;
(* Add a task to the "on_destroy" list which will be signalled if
the Endpoint is disconnected from the switch and we should close
connections. *)
let close, close_request = Lwt.task () in
t.on_destroy <- Tcp.Id.Map.add id close_request t.on_destroy;
t.established <- Tcp.Id.Set.add id t.established;
Lwt.finalize
(fun () ->
on_syn_callback close
on_syn_callback ()
>>= fun listeners ->
let src = Stack_tcp_wire.dst id in
let dst = Stack_tcp_wire.src id in
Expand All @@ -312,8 +343,14 @@ struct
module Proxy =
Mirage_flow_lwt.Proxy(Clock)(Stack_tcp)(Host.Sockets.Stream.Tcp)

let input_tcp t ~id ~syn (ip, port) (buf: Cstruct.t) =
intercept_tcp_syn t ~id ~syn (fun close ->
let input_tcp t ~id ~syn ~rst (ip, port) (buf: Cstruct.t) =
(* Note that we must clean up even when the connection is reset before it
is fully established. *)
( if rst
then close_flow t ~id `Reset
else Lwt.return_unit )
>>= fun () ->
intercept_tcp_syn t ~id ~syn (fun () ->
Host.Sockets.Stream.Tcp.connect (ip, port)
>>= function
| Error (`Msg m) ->
Expand All @@ -324,7 +361,7 @@ struct
| Ok socket ->
let tcp = Tcp.Flow.create id socket in
let listeners port =
Log.debug (fun f ->
Log.info (fun f ->
f "%a:%d handshake complete" Ipaddr.pp_hum ip port);
let f flow =
match tcp.Tcp.Flow.socket with
Expand All @@ -335,35 +372,19 @@ struct
Lwt.return_unit
| Some socket ->
Lwt.finalize (fun () ->
Lwt.pick [
Lwt.map
(function Error e -> Error (`Proxy e) | Ok x -> Ok x)
(Proxy.proxy t.clock flow socket);
Lwt.map
(fun () -> Error `Close)
close
]
>>= function
| Error (`Close) ->
Log.info (fun f ->
f "%s proxy closed due to switch port disconnection"
(Tcp.Flow.to_string tcp));
Lwt.return_unit
| Error (`Proxy e) ->
Log.debug (fun f ->
f "%s proxy failed with %a"
(Tcp.Flow.to_string tcp) Proxy.pp_error e);
Lwt.return_unit
| Ok (_l_stats, _r_stats) ->
Lwt.return_unit
) (fun () ->
Proxy.proxy t.clock flow socket
>>= function
| Error e ->
Log.debug (fun f ->
f "closing flow %s" (string_of_id tcp.Tcp.Flow.id));
tcp.Tcp.Flow.socket <- None;
t.on_destroy <- Tcp.Id.Map.remove id t.on_destroy;
Tcp.Flow.remove tcp.Tcp.Flow.id;
Host.Sockets.Stream.Tcp.close socket
)
f "%s proxy failed with %a"
(Tcp.Flow.to_string tcp) Proxy.pp_error e);
Lwt.return_unit
| Ok (_l_stats, _r_stats) ->
Lwt.return_unit
) (fun () ->
Log.info (fun f -> f "%s proxy terminated" (Tcp.Flow.to_string tcp));
close_flow t ~id `Fin
)
in
Some f
in
Expand Down Expand Up @@ -502,12 +523,12 @@ struct

(* TCP to local ports *)
| Ipv4 { src; dst;
payload = Tcp { src = src_port; dst = dst_port; syn; raw;
payload = Tcp { src = src_port; dst = dst_port; syn; rst; raw;
payload = Payload _; _ }; _ } ->
let id =
Stack_tcp_wire.v ~src_port:dst_port ~dst:src ~src:dst ~dst_port:src_port
in
Endpoint.input_tcp t.endpoint ~id ~syn
Endpoint.input_tcp t.endpoint ~id ~syn ~rst
(Ipaddr.V4 Ipaddr.V4.localhost, dst_port) raw
>|= ok
| _ ->
Expand Down Expand Up @@ -592,9 +613,9 @@ struct
let id =
Stack_tcp_wire.v ~src_port:53 ~dst:src ~src:dst ~dst_port:src_port
in
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
!dns >>= fun t ->
Dns_forwarder.handle_tcp ~t ~close
Dns_forwarder.handle_tcp ~t
) raw
>|= ok

Expand Down Expand Up @@ -673,7 +694,7 @@ struct
(* Transparent HTTP intercept? *)
| Ipv4 { src = dest_ip ; dst = local_ip;
payload = Tcp { src = dest_port;
dst = local_port; syn; raw; _ }; _ } ->
dst = local_port; syn; rst; raw; _ }; _ } ->
let id =
Stack_tcp_wire.v
~src_port:local_port ~dst:dest_ip ~src:local_ip ~dst_port:dest_port
Expand All @@ -684,7 +705,7 @@ struct
in
begin match callback with
| None ->
Endpoint.input_tcp t.endpoint ~id ~syn (Ipaddr.V4 local_ip, local_port)
Endpoint.input_tcp t.endpoint ~id ~syn ~rst (Ipaddr.V4 local_ip, local_port)
raw (* common case *)
>|= ok
| Some cb ->
Expand Down Expand Up @@ -891,12 +912,11 @@ struct
let age = now -. endpoint.Endpoint.last_active_time in
if age > (float_of_int port_max_idle_time) then (ip, endpoint) :: acc else acc
) t.endpoints [] in
List.iter (fun (ip, endpoint) ->
Lwt_list.iter_s (fun (ip, endpoint) ->
Switch.remove t.switch ip;
Endpoint.destroy endpoint;
t.endpoints <- IPMap.remove ip t.endpoints
) old_ips;
Lwt.return_unit
t.endpoints <- IPMap.remove ip t.endpoints;
Endpoint.destroy endpoint
) old_ips
)
>>= fun () ->
delete_unused_endpoints t ~port_max_idle_time ()
Expand Down

0 comments on commit 10f154c

Please sign in to comment.