Skip to content

Commit 7c66137

Browse files
hannesmreynirdinosaure
authored
Adapt to mirage-flow 4 API (#70)
* mirage: avoid an assert false, properly return an error * provide close and shutdown in Awa_mirage * simplify - a shutdown \`read_write is a close * mirage: preserve half-closed connections, and deal with them properly * mirage: avoid assertions * address @reynir review - and use inject_state * mirage: revise close and shutdown first to the ssh teardown, then do the underlying flow teardown * shutdown: don't shutdown the flow unless closed If we are in `Read_closed we may still want to read channel-close and when we are in `Write_closed we may still want to write channel-close. * mirage: set closed earlier in close(); also remove TODO comment * mirage: add comment about states and why errors may occur that we ignore (thanks to @dinosaure) * minor tweaks * shutdown: if in closed/error state, call close on the underlying flow nevertheless --------- Co-authored-by: Reynir Björnsson <reynir@reynir.dk> Co-authored-by: Romain Calascibetta <romain.calascibetta@gmail.com>
1 parent 389c1f3 commit 7c66137

File tree

5 files changed

+135
-42
lines changed

5 files changed

+135
-42
lines changed

awa-mirage.opam

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ depends: [
2222
"lwt" {>= "5.3.0"}
2323
"mirage-time" {>= "2.0.0"}
2424
"duration" {>= "0.2.0"}
25-
"mirage-flow" {>= "2.0.0"}
25+
"mirage-flow" {>= "4.0.0"}
2626
"mirage-clock" {>= "3.0.0"}
2727
"logs"
2828
]

lib/client.ml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,23 @@ let outgoing_data t ?(id = 0l) data =
523523
let* c, frags = Channel.output_data c data in
524524
let t' = { t with channels = Channel.update c t.channels } in
525525
Ok (output_msgs t' frags)
526+
527+
let eof ?(id = 0l) t =
528+
match
529+
let* () = guard (established t) "not yet established" in
530+
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
531+
let msg = Ssh.Msg_channel_eof c.them.id in
532+
Ok (output_msg t msg)
533+
with
534+
| Error _ -> t, None
535+
| Ok (t, msg) -> t, Some msg
536+
537+
let close ?(id = 0l) t =
538+
match
539+
let* () = guard (established t) "not yet established" in
540+
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
541+
let msg = Ssh.Msg_channel_close c.them.id in
542+
Ok (output_msg t msg)
543+
with
544+
| Error _ -> t, None
545+
| Ok (t, msg) -> t, Some msg

lib/client.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,7 @@ val outgoing_request : t -> ?id:int32 -> ?want_reply:bool ->
3838

3939
val outgoing_data : t -> ?id:int32 -> Cstruct.t ->
4040
(t * Cstruct.t list, string) result
41+
42+
val eof : ?id:int32 -> t -> t * Cstruct.t option
43+
44+
val close : ?id:int32 -> t -> t * Cstruct.t option

mirage/awa_mirage.ml

Lines changed: 108 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ module Log = (val Logs.src_log src : Logs.LOG)
55

66
module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = struct
77

8-
module FLOW = F
98
module MCLOCK = M
109

1110
type error = [ `Msg of string
@@ -22,22 +21,63 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
2221
| #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e
2322
| #error as e -> pp_error ppf e
2423

24+
(* this is the flow of a ssh-client. be aware that we're only using a single
25+
channel.
26+
27+
the state `Read_closed is set (a) when a TCP.read returned `Eof,
28+
and (b) when the application did a shutdown `read (or `read_write).
29+
the state `Write_closed is set (a) when a TCP.write returned `Closed,
30+
and (b) when the application did a shutdown `write (or `read_write).
31+
32+
If we're in `Write_closed, and do a shutdown `read, we'll end up in
33+
`Closed, and attempt to (a) send a SSH_MSG_CHANNEL_CLOSE and (b) TCP.close.
34+
This may fail, since on the TCP layer, the connection may have already be
35+
half-closed (or fully closed) in the write direction. We ignore this error
36+
from writev below in close.
37+
*)
2538
type flow = {
26-
flow : FLOW.flow ;
27-
mutable state : [ `Active of Awa.Client.t | `Eof | `Error of error ]
39+
flow : F.flow ;
40+
mutable state : [
41+
| `Active of Awa.Client.t
42+
| `Read_closed of Awa.Client.t
43+
| `Write_closed of Awa.Client.t
44+
| `Closed
45+
| `Error of error ]
2846
}
2947

48+
let half_close state mode =
49+
match state, mode with
50+
| `Active ssh, `read -> `Read_closed ssh
51+
| `Active ssh, `write -> `Write_closed ssh
52+
| `Active _, `read_write -> `Closed
53+
| `Read_closed ssh, `read -> `Read_closed ssh
54+
| `Read_closed _, (`write | `read_write) -> `Closed
55+
| `Write_closed ssh, `write -> `Write_closed ssh
56+
| `Write_closed _, (`read | `read_write) -> `Closed
57+
| (`Closed | `Error _) as e, (`read | `write | `read_write) -> e
58+
59+
let inject_state ssh = function
60+
| `Active _ -> `Active ssh
61+
| `Read_closed _ -> `Read_closed ssh
62+
| `Write_closed _ -> `Write_closed ssh
63+
| (`Closed | `Error _) as e -> e
64+
3065
let write_flow t buf =
31-
FLOW.write t.flow buf >>= function
32-
| Ok () -> Lwt.return (Ok ())
66+
F.write t.flow buf >>= function
67+
| Ok _ as o -> Lwt.return o
68+
| Error `Closed ->
69+
Log.warn (fun m -> m "error closed while writing");
70+
t.state <- half_close t.state `write;
71+
Lwt.return (Error (`Write `Closed))
3372
| Error w ->
3473
Log.warn (fun m -> m "error %a while writing" F.pp_write_error w);
35-
t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w))
74+
t.state <- `Error (`Write w);
75+
Lwt.return (Error (`Write w))
3676

3777
let writev_flow t bufs =
3878
Lwt_list.fold_left_s (fun r d ->
3979
match r with
40-
| Error e -> Lwt.return (Error e)
80+
| Error _ as e -> Lwt.return e
4181
| Ok () -> write_flow t d)
4282
(Ok ()) bufs
4383

@@ -46,25 +86,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
4686

4787
let read_react t =
4888
match t.state with
49-
| `Eof | `Error _ -> Lwt.return (Error ())
50-
| `Active _ ->
51-
FLOW.read t.flow >>= function
89+
| `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ())
90+
| `Active _ | `Write_closed _ ->
91+
F.read t.flow >>= function
5292
| Error e ->
5393
Log.warn (fun m -> m "error %a while reading" F.pp_error e);
5494
t.state <- `Error (`Read e);
5595
Lwt.return (Error ())
56-
| Ok `Eof -> t.state <- `Eof ; Lwt.return (Error ())
96+
| Ok `Eof ->
97+
t.state <- half_close t.state `read;
98+
Lwt.return (Error ())
5799
| Ok (`Data data) ->
58100
match t.state with
59-
| `Active ssh ->
101+
| `Active ssh | `Write_closed ssh ->
60102
begin match Awa.Client.incoming ssh (now ()) data with
61103
| Error msg ->
62104
Log.warn (fun m -> m "error %s while processing data" msg);
63105
t.state <- `Error (`Msg msg);
64106
Lwt.return (Error ())
65107
| Ok (ssh', out, events) ->
66-
let state' = if List.mem `Disconnected events then `Eof else `Active ssh' in
67-
t.state <- state';
108+
t.state <-
109+
inject_state ssh' (if List.mem `Disconnected events then half_close t.state `read else t.state);
68110
writev_flow t out >>= fun _ ->
69111
Lwt.return (Ok events)
70112
end
@@ -74,15 +116,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
74116
read_react t >>= function
75117
| Ok es ->
76118
begin match t.state, List.filter (function `Established _ -> true | _ -> false) es with
77-
| `Eof, _ -> Lwt.return (Error (`Msg "disconnected"))
119+
| (`Read_closed _ | `Closed), _ -> Lwt.return (Error (`Msg "disconnected"))
78120
| `Error e, _ -> Lwt.return (Error e)
79-
| `Active _, [ `Established id ] -> Lwt.return (Ok id)
80-
| `Active _, _ -> drain_handshake t
121+
| (`Active _ | `Write_closed _), [ `Established id ] -> Lwt.return (Ok id)
122+
| (`Active _ | `Write_closed _), _ -> drain_handshake t
81123
end
82124
| Error () -> match t.state with
83125
| `Error e -> Lwt.return (Error e)
84-
| `Eof -> Lwt.return (Error (`Msg "disconnected"))
85-
| `Active _ -> assert false
126+
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Error (`Msg "disconnected"))
86127

87128
let rec read t =
88129
read_react t >>= function
@@ -107,32 +148,57 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
107148
end
108149
| Error () -> match t.state with
109150
| `Error e -> Lwt.return (Error e)
110-
| `Eof -> Lwt.return (Ok `Eof)
111-
| `Active _ -> assert false
151+
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof)
112152

113153
let close t =
114-
(* TODO ssh session teardown (send some protocol messages) *)
115-
FLOW.close t.flow >|= fun () ->
116-
t.state <- `Eof
154+
(match t.state with
155+
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
156+
let ssh, msg = Awa.Client.close ssh in
157+
t.state <- inject_state ssh t.state;
158+
t.state <- `Closed;
159+
(* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
160+
writev_flow t (Option.to_list msg) >|= ignore
161+
| `Error _ | `Closed -> Lwt.return_unit) >>= fun () ->
162+
F.close t.flow
163+
164+
let shutdown t mode =
165+
match t.state with
166+
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
167+
let ssh, msg =
168+
match t.state, mode with
169+
| (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh
170+
| _, `read_write -> Awa.Client.close ssh
171+
| _ -> ssh, None
172+
in
173+
t.state <- inject_state ssh (half_close t.state mode);
174+
(* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
175+
writev_flow t (Option.to_list msg) >>= fun _ ->
176+
(* we don't [FLOW.shutdown _ mode] because we still need to read/write
177+
channel_eof/channel_close unless both directions are closed *)
178+
(match t.state with
179+
| `Closed -> F.close t.flow
180+
| _ -> Lwt.return_unit)
181+
| `Error _ | `Closed ->
182+
F.close t.flow
117183

118184
let writev t bufs =
119185
let open Lwt_result.Infix in
120186
match t.state with
121-
| `Active ssh ->
187+
| `Active ssh | `Read_closed ssh ->
122188
Lwt_list.fold_left_s (fun r data ->
123189
match r with
124190
| Error e -> Lwt.return (Error e)
125191
| Ok ssh ->
126192
match Awa.Client.outgoing_data ssh data with
127193
| Ok (ssh', datas) ->
128-
t.state <- `Active ssh';
194+
t.state <- inject_state ssh' t.state;
129195
writev_flow t datas >|= fun () ->
130196
ssh'
131197
| Error msg ->
132198
t.state <- `Error (`Msg msg) ;
133199
Lwt.return (Error (`Msg msg)))
134200
(Ok ssh) bufs >|= fun _ -> ()
135-
| `Eof -> Lwt.return (Error `Closed)
201+
| `Write_closed _ | `Closed -> Lwt.return (Error `Closed)
136202
| `Error e -> Lwt.return (Error (e :> write_error))
137203

138204
let write t buf = writev t [buf]
@@ -146,12 +212,17 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
146212
} in
147213
writev_flow t msgs >>= fun () ->
148214
drain_handshake t >>= fun id ->
149-
(* TODO that's a bit hardcoded... *)
150-
let ssh = match t.state with `Active t -> t | _ -> assert false in
151-
(match Awa.Client.outgoing_request ssh ~id req with
152-
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
153-
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
154-
t
215+
match t.state with
216+
| `Active ssh ->
217+
(match Awa.Client.outgoing_request ssh ~id req with
218+
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
219+
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
220+
t
221+
| `Read_closed _ -> Lwt.return (Error (`Msg "read closed"))
222+
| `Write_closed _ -> Lwt.return (Error (`Msg "write closed"))
223+
| `Closed -> Lwt.return (Error (`Msg "closed"))
224+
| `Error e -> Lwt.return (Error e)
225+
155226

156227
(* copy from awa_lwt.ml and unix references removed in favor to FLOW *)
157228
type nexus_msg =
@@ -195,10 +266,10 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
195266
let send_msg flow server msg =
196267
wrapr (Awa.Server.output_msg server msg)
197268
>>= fun (server, msg_buf) ->
198-
FLOW.write flow msg_buf >>= function
269+
F.write flow msg_buf >>= function
199270
| Ok () -> Lwt.return server
200271
| Error w ->
201-
Log.err (fun m -> m "error %a while writing" FLOW.pp_write_error w);
272+
Log.err (fun m -> m "error %a while writing" F.pp_write_error w);
202273
Lwt.return server
203274

204275
let rec send_msgs fd server = function
@@ -209,9 +280,9 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
209280
| [] -> Lwt.return server
210281

211282
let net_read flow =
212-
FLOW.read flow >>= function
283+
F.read flow >>= function
213284
| Error e ->
214-
Log.err (fun m -> m "read error %a" FLOW.pp_error e);
285+
Log.err (fun m -> m "read error %a" F.pp_error e);
215286
Lwt.return Net_eof
216287
| Ok `Eof ->
217288
Lwt.return Net_eof

mirage/awa_mirage.mli

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
(** SSH module given a flow *)
44
module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : sig
55

6-
module FLOW : Mirage_flow.S
7-
86
(** possible errors: incoming alert, processing failure, or a
97
problem in the underlying flow. *)
108
type error = [ `Msg of string
@@ -24,7 +22,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
2422
sends the channel request. *)
2523
val client_of_flow : ?authenticator:Awa.Keys.authenticator -> user:string ->
2624
[ `Pubkey of Awa.Hostkey.priv | `Password of string ] ->
27-
Awa.Ssh.channel_request -> FLOW.flow -> (flow, error) result Lwt.t
25+
Awa.Ssh.channel_request -> F.flow -> (flow, error) result Lwt.t
2826

2927
type t
3028

@@ -64,4 +62,4 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
6462
{b NOTE}: Even if the [ssh_channel_handler] is fulfilled, [spawn_server]
6563
continues to handle SSH channels. Only [stop] can really stop the internal
6664
SSH channels handler. *)
67-
end with module FLOW = F
65+
end

0 commit comments

Comments
 (0)