@@ -5,7 +5,6 @@ module Log = (val Logs.src_log src : Logs.LOG)
5
5
6
6
module Make (F : Mirage_flow.S ) (T : Mirage_time.S ) (M : Mirage_clock.MCLOCK ) = struct
7
7
8
- module FLOW = F
9
8
module MCLOCK = M
10
9
11
10
type error = [ `Msg of string
@@ -22,22 +21,63 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
22
21
| #Mirage_flow. write_error as e -> Mirage_flow. pp_write_error ppf e
23
22
| #error as e -> pp_error ppf e
24
23
24
+ (* this is the flow of a ssh-client. be aware that we're only using a single
25
+ channel.
26
+
27
+ the state `Read_closed is set (a) when a TCP.read returned `Eof,
28
+ and (b) when the application did a shutdown `read (or `read_write).
29
+ the state `Write_closed is set (a) when a TCP.write returned `Closed,
30
+ and (b) when the application did a shutdown `write (or `read_write).
31
+
32
+ If we're in `Write_closed, and do a shutdown `read, we'll end up in
33
+ `Closed, and attempt to (a) send a SSH_MSG_CHANNEL_CLOSE and (b) TCP.close.
34
+ This may fail, since on the TCP layer, the connection may have already be
35
+ half-closed (or fully closed) in the write direction. We ignore this error
36
+ from writev below in close.
37
+ *)
25
38
type flow = {
26
- flow : FLOW .flow ;
27
- mutable state : [ `Active of Awa.Client .t | `Eof | `Error of error ]
39
+ flow : F .flow ;
40
+ mutable state : [
41
+ | `Active of Awa.Client .t
42
+ | `Read_closed of Awa.Client .t
43
+ | `Write_closed of Awa.Client .t
44
+ | `Closed
45
+ | `Error of error ]
28
46
}
29
47
48
+ let half_close state mode =
49
+ match state, mode with
50
+ | `Active ssh , `read -> `Read_closed ssh
51
+ | `Active ssh , `write -> `Write_closed ssh
52
+ | `Active _ , `read_write -> `Closed
53
+ | `Read_closed ssh , `read -> `Read_closed ssh
54
+ | `Read_closed _ , (`write | `read_write ) -> `Closed
55
+ | `Write_closed ssh , `write -> `Write_closed ssh
56
+ | `Write_closed _ , (`read | `read_write ) -> `Closed
57
+ | (`Closed | `Error _ ) as e , (`read | `write | `read_write ) -> e
58
+
59
+ let inject_state ssh = function
60
+ | `Active _ -> `Active ssh
61
+ | `Read_closed _ -> `Read_closed ssh
62
+ | `Write_closed _ -> `Write_closed ssh
63
+ | (`Closed | `Error _ ) as e -> e
64
+
30
65
let write_flow t buf =
31
- FLOW. write t.flow buf >> = function
32
- | Ok () -> Lwt. return (Ok () )
66
+ F. write t.flow buf >> = function
67
+ | Ok _ as o -> Lwt. return o
68
+ | Error `Closed ->
69
+ Log. warn (fun m -> m " error closed while writing" );
70
+ t.state < - half_close t.state `write ;
71
+ Lwt. return (Error (`Write `Closed ))
33
72
| Error w ->
34
73
Log. warn (fun m -> m " error %a while writing" F. pp_write_error w);
35
- t.state < - `Error (`Write w) ; Lwt. return (Error (`Write w))
74
+ t.state < - `Error (`Write w);
75
+ Lwt. return (Error (`Write w))
36
76
37
77
let writev_flow t bufs =
38
78
Lwt_list. fold_left_s (fun r d ->
39
79
match r with
40
- | Error e -> Lwt. return ( Error e)
80
+ | Error _ as e -> Lwt. return e
41
81
| Ok () -> write_flow t d)
42
82
(Ok () ) bufs
43
83
@@ -46,25 +86,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
46
86
47
87
let read_react t =
48
88
match t.state with
49
- | `Eof | `Error _ -> Lwt. return (Error () )
50
- | `Active _ ->
51
- FLOW . read t.flow >> = function
89
+ | `Read_closed _ | `Closed | `Error _ -> Lwt. return (Error () )
90
+ | `Active _ | `Write_closed _ ->
91
+ F . read t.flow >> = function
52
92
| Error e ->
53
93
Log. warn (fun m -> m " error %a while reading" F. pp_error e);
54
94
t.state < - `Error (`Read e);
55
95
Lwt. return (Error () )
56
- | Ok `Eof -> t.state < - `Eof ; Lwt. return (Error () )
96
+ | Ok `Eof ->
97
+ t.state < - half_close t.state `read ;
98
+ Lwt. return (Error () )
57
99
| Ok (`Data data ) ->
58
100
match t.state with
59
- | `Active ssh ->
101
+ | `Active ssh | `Write_closed ssh ->
60
102
begin match Awa.Client. incoming ssh (now () ) data with
61
103
| Error msg ->
62
104
Log. warn (fun m -> m " error %s while processing data" msg);
63
105
t.state < - `Error (`Msg msg);
64
106
Lwt. return (Error () )
65
107
| Ok (ssh' , out , events ) ->
66
- let state' = if List. mem `Disconnected events then `Eof else `Active ssh' in
67
- t.state < - state' ;
108
+ t. state < -
109
+ inject_state ssh' ( if List. mem `Disconnected events then half_close t.state `read else t. state) ;
68
110
writev_flow t out >> = fun _ ->
69
111
Lwt. return (Ok events)
70
112
end
@@ -74,15 +116,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
74
116
read_react t >> = function
75
117
| Ok es ->
76
118
begin match t.state, List. filter (function `Established _ -> true | _ -> false ) es with
77
- | `Eof , _ -> Lwt. return (Error (`Msg " disconnected" ))
119
+ | ( `Read_closed _ | `Closed ) , _ -> Lwt. return (Error (`Msg " disconnected" ))
78
120
| `Error e , _ -> Lwt. return (Error e)
79
- | `Active _ , [ `Established id ] -> Lwt. return (Ok id)
80
- | `Active _ , _ -> drain_handshake t
121
+ | ( `Active _ | `Write_closed _ ) , [ `Established id ] -> Lwt. return (Ok id)
122
+ | ( `Active _ | `Write_closed _ ) , _ -> drain_handshake t
81
123
end
82
124
| Error () -> match t.state with
83
125
| `Error e -> Lwt. return (Error e)
84
- | `Eof -> Lwt. return (Error (`Msg " disconnected" ))
85
- | `Active _ -> assert false
126
+ | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt. return (Error (`Msg " disconnected" ))
86
127
87
128
let rec read t =
88
129
read_react t >> = function
@@ -107,32 +148,57 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
107
148
end
108
149
| Error () -> match t.state with
109
150
| `Error e -> Lwt. return (Error e)
110
- | `Eof -> Lwt. return (Ok `Eof )
111
- | `Active _ -> assert false
151
+ | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt. return (Ok `Eof )
112
152
113
153
let close t =
114
- (* TODO ssh session teardown (send some protocol messages) *)
115
- FLOW. close t.flow > |= fun () ->
116
- t.state < - `Eof
154
+ (match t.state with
155
+ | `Active ssh | `Read_closed ssh | `Write_closed ssh ->
156
+ let ssh, msg = Awa.Client. close ssh in
157
+ t.state < - inject_state ssh t.state;
158
+ t.state < - `Closed ;
159
+ (* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
160
+ writev_flow t (Option. to_list msg) > |= ignore
161
+ | `Error _ | `Closed -> Lwt. return_unit) >> = fun () ->
162
+ F. close t.flow
163
+
164
+ let shutdown t mode =
165
+ match t.state with
166
+ | `Active ssh | `Read_closed ssh | `Write_closed ssh ->
167
+ let ssh, msg =
168
+ match t.state, mode with
169
+ | (`Active ssh | `Read_closed ssh ), `write -> Awa.Client. eof ssh
170
+ | _ , `read_write -> Awa.Client. close ssh
171
+ | _ -> ssh, None
172
+ in
173
+ t.state < - inject_state ssh (half_close t.state mode);
174
+ (* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
175
+ writev_flow t (Option. to_list msg) >> = fun _ ->
176
+ (* we don't [FLOW.shutdown _ mode] because we still need to read/write
177
+ channel_eof/channel_close unless both directions are closed *)
178
+ (match t.state with
179
+ | `Closed -> F. close t.flow
180
+ | _ -> Lwt. return_unit)
181
+ | `Error _ | `Closed ->
182
+ F. close t.flow
117
183
118
184
let writev t bufs =
119
185
let open Lwt_result.Infix in
120
186
match t.state with
121
- | `Active ssh ->
187
+ | `Active ssh | `Read_closed ssh ->
122
188
Lwt_list. fold_left_s (fun r data ->
123
189
match r with
124
190
| Error e -> Lwt. return (Error e)
125
191
| Ok ssh ->
126
192
match Awa.Client. outgoing_data ssh data with
127
193
| Ok (ssh' , datas ) ->
128
- t.state < - `Active ssh';
194
+ t.state < - inject_state ssh' t.state ;
129
195
writev_flow t datas > |= fun () ->
130
196
ssh'
131
197
| Error msg ->
132
198
t.state < - `Error (`Msg msg) ;
133
199
Lwt. return (Error (`Msg msg)))
134
200
(Ok ssh) bufs > |= fun _ -> ()
135
- | `Eof -> Lwt. return (Error `Closed )
201
+ | `Write_closed _ | `Closed -> Lwt. return (Error `Closed )
136
202
| `Error e -> Lwt. return (Error (e :> write_error ))
137
203
138
204
let write t buf = writev t [buf]
@@ -146,12 +212,17 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
146
212
} in
147
213
writev_flow t msgs >> = fun () ->
148
214
drain_handshake t >> = fun id ->
149
- (* TODO that's a bit hardcoded... *)
150
- let ssh = match t.state with `Active t -> t | _ -> assert false in
151
- (match Awa.Client. outgoing_request ssh ~id req with
152
- | Error msg -> t.state < - `Error (`Msg msg) ; Lwt. return (Error (`Msg msg))
153
- | Ok (ssh' , data ) -> t.state < - `Active ssh' ; write_flow t data) > |= fun () ->
154
- t
215
+ match t.state with
216
+ | `Active ssh ->
217
+ (match Awa.Client. outgoing_request ssh ~id req with
218
+ | Error msg -> t.state < - `Error (`Msg msg) ; Lwt. return (Error (`Msg msg))
219
+ | Ok (ssh' , data ) -> t.state < - `Active ssh' ; write_flow t data) > |= fun () ->
220
+ t
221
+ | `Read_closed _ -> Lwt. return (Error (`Msg " read closed" ))
222
+ | `Write_closed _ -> Lwt. return (Error (`Msg " write closed" ))
223
+ | `Closed -> Lwt. return (Error (`Msg " closed" ))
224
+ | `Error e -> Lwt. return (Error e)
225
+
155
226
156
227
(* copy from awa_lwt.ml and unix references removed in favor to FLOW *)
157
228
type nexus_msg =
@@ -195,10 +266,10 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
195
266
let send_msg flow server msg =
196
267
wrapr (Awa.Server. output_msg server msg)
197
268
>> = fun (server , msg_buf ) ->
198
- FLOW . write flow msg_buf >> = function
269
+ F . write flow msg_buf >> = function
199
270
| Ok () -> Lwt. return server
200
271
| Error w ->
201
- Log. err (fun m -> m " error %a while writing" FLOW . pp_write_error w);
272
+ Log. err (fun m -> m " error %a while writing" F . pp_write_error w);
202
273
Lwt. return server
203
274
204
275
let rec send_msgs fd server = function
@@ -209,9 +280,9 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
209
280
| [] -> Lwt. return server
210
281
211
282
let net_read flow =
212
- FLOW . read flow >> = function
283
+ F . read flow >> = function
213
284
| Error e ->
214
- Log. err (fun m -> m " read error %a" FLOW . pp_error e);
285
+ Log. err (fun m -> m " read error %a" F . pp_error e);
215
286
Lwt. return Net_eof
216
287
| Ok `Eof ->
217
288
Lwt. return Net_eof
0 commit comments