Skip to content

Commit

Permalink
Merge pull request #487 from hannesm/mirage-flow4
Browse files Browse the repository at this point in the history
cleanups (mainly of mirage layer)
  • Loading branch information
hannesm committed Jan 4, 2024
2 parents 96b7f2e + 1ad08b6 commit 75d8c34
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 73 deletions.
35 changes: 12 additions & 23 deletions lib/engine.ml
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ let early_data s =

let rec separate_handshakes buf =
match Reader.parse_handshake_frame buf with
| None, rest -> Ok ([], rest)
| None, rest -> [], rest
| Some hs, rest ->
let* rt, frag = separate_handshakes rest in
Ok (hs :: rt, frag)
let rt, frag = separate_handshakes rest in
hs :: rt, frag

let handle_change_cipher_spec = function
| Client cs -> Handshake_client.handle_change_cipher_spec cs
Expand Down Expand Up @@ -443,7 +443,7 @@ let handle_packet hs buf = function
Ok (hs, items, None, `No_err)

| Packet.HANDSHAKE ->
let* hss, hs_fragment = separate_handshakes (hs.hs_fragment <+> buf) in
let hss, hs_fragment = separate_handshakes (hs.hs_fragment <+> buf) in
let hs = { hs with hs_fragment } in
let* hs, items =
List.fold_left (fun acc raw ->
Expand All @@ -454,8 +454,6 @@ let handle_packet hs buf = function
in
Ok (hs, items, None, `No_err)

| Packet.HEARTBEAT -> Error (`Fatal `NoHeartbeat)

let decrement_early_data hs ty buf =
let bytes left cipher =
let count = Cstruct.length buf - fst (Ciphersuite.kn_13 (Ciphersuite.privprot13 cipher)) in
Expand All @@ -481,11 +479,11 @@ let handle_raw_record state (hdr, buf as record : raw_record) =
let version = hs.protocol_version in
let* () =
match hs.machina, version with
| Client (AwaitServerHello _), _ -> Ok ()
| Server AwaitClientHello , _ -> Ok ()
| Server13 AwaitClientHelloHRR13, _ -> Ok ()
| _ , `TLS_1_3 -> guard (hdr.version = `TLS_1_2) (`Fatal (`BadRecordVersion hdr.version))
| _ , v -> guard (version_eq hdr.version v) (`Fatal (`BadRecordVersion hdr.version))
| Client (AwaitServerHello _), _ -> Ok ()
| Server AwaitClientHello, _ -> Ok ()
| Server13 AwaitClientHelloHRR13, _ -> Ok ()
| _, `TLS_1_3 -> guard (hdr.version = `TLS_1_2) (`Fatal (`BadRecordVersion hdr.version))
| _, v -> guard (version_eq hdr.version v) (`Fatal (`BadRecordVersion hdr.version))
in
let trial = match hs.machina with
| Server13 (AwaitEndOfEarlyData13 _) | Server13 Established13 -> false
Expand Down Expand Up @@ -608,18 +606,9 @@ let send_close_notify st = send_records st [Alert.close_notify]

let reneg ?authenticator ?acceptable_cas ?cert st =
let config = st.handshake.config in
let config = match authenticator with
| None -> config
| Some auth -> Config.with_authenticator config auth
in
let config = match acceptable_cas with
| None -> config
| Some cas -> Config.with_acceptable_cas config cas
in
let config = match cert with
| None -> config
| Some cert -> Config.with_own_certificates config cert
in
let config = Option.fold ~none:config ~some:(Config.with_authenticator config) authenticator in
let config = Option.fold ~none:config ~some:(Config.with_acceptable_cas config) acceptable_cas in
let config = Option.fold ~none:config ~some:(Config.with_own_certificates config) cert in
let hs = { st.handshake with config } in
match hs.machina with
| Server Established ->
Expand Down
3 changes: 2 additions & 1 deletion lib/engine.mli
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
layer security} in OCaml. TLS is a widely used security protocol
which establishes an end-to-end secure channel (with optional
(mutual) authentication) between two endpoints. It uses TCP/IP as
transport. This library supports all three versions of TLS:
transport. This library supports all four versions of TLS:
{{:https://tools.ietf.org/html/rfc8446}1.3, RFC8446},
{{:https://tools.ietf.org/html/rfc5246}1.2, RFC5246},
{{:https://tools.ietf.org/html/rfc4346}1.1, RFC4346}, and
{{:https://tools.ietf.org/html/rfc2246}1.0, RFC2246}. SSL, the
Expand Down
4 changes: 0 additions & 4 deletions lib/packet.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,24 @@ type content_type =
| ALERT
| HANDSHAKE
| APPLICATION_DATA
| HEARTBEAT

let content_type_to_int = function
| CHANGE_CIPHER_SPEC -> 20
| ALERT -> 21
| HANDSHAKE -> 22
| APPLICATION_DATA -> 23
| HEARTBEAT -> 24
and int_to_content_type = function
| 20 -> Some CHANGE_CIPHER_SPEC
| 21 -> Some ALERT
| 22 -> Some HANDSHAKE
| 23 -> Some APPLICATION_DATA
| 24 -> Some HEARTBEAT
| _ -> None

let pp_content_type ppf = function
| CHANGE_CIPHER_SPEC -> Fmt.string ppf "change cipher spec"
| ALERT -> Fmt.string ppf "alert"
| HANDSHAKE -> Fmt.string ppf "handshake"
| APPLICATION_DATA -> Fmt.string ppf "application data"
| HEARTBEAT -> Fmt.string ppf "heartbeat"

(* TLS alert level *)
type alert_level =
Expand Down
68 changes: 36 additions & 32 deletions mirage/tls_mirage.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Lwt
open Lwt.Infix

module Make (F : Mirage_flow.S) = struct

Expand Down Expand Up @@ -46,7 +46,7 @@ module Make (F : Mirage_flow.S) = struct
( match flow.state, res with
| `Active _, (`Eof | `Error _ as e) ->
flow.state <- e ; FLOW.close flow.flow
| _ -> return_unit ) >|= fun () ->
| _ -> Lwt.return_unit ) >|= fun () ->
match f_res with
| Ok () -> Ok ()
| Error e -> Error (`Write e :> write_error)
Expand All @@ -61,43 +61,43 @@ module Make (F : Mirage_flow.S) = struct
| `Eof -> `Eof
| `Alert alert -> tls_alert alert );
( match resp with
| None -> return @@ Ok ()
| None -> Lwt.return @@ Ok ()
| Some buf -> FLOW.write flow.flow buf >>= check_write flow ) >>= fun _ ->
( match res with
| `Ok _ -> return_unit
| `Ok _ -> Lwt.return_unit
| _ -> FLOW.close flow.flow ) >>= fun () ->
return @@ `Ok data
Lwt.return @@ `Ok data
| Error (fail, `Response resp) ->
let reason = tls_fail fail in
flow.state <- reason ;
FLOW.(write flow.flow resp >>= fun _ -> close flow.flow) >>= fun () -> return reason
FLOW.(write flow.flow resp >>= fun _ -> close flow.flow) >>= fun () -> Lwt.return reason
in
match flow.state with
| `Eof | `Error _ as e -> return e
| `Eof | `Error _ as e -> Lwt.return e
| `Active _ ->
FLOW.read flow.flow >|= lift_read_result >>=
function
| `Eof | `Error _ as e -> flow.state <- e ; return e
| `Eof | `Error _ as e -> flow.state <- e ; Lwt.return e
| `Data buf -> match flow.state with
| `Active tls -> handle tls buf
| `Eof | `Error _ as e -> return e
| `Eof | `Error _ as e -> Lwt.return e

let rec read flow =
match flow.linger with
| [] ->
( read_react flow >>= function
| `Ok None -> read flow
| `Ok (Some buf) -> return @@ Ok (`Data buf)
| `Eof -> return @@ Ok `Eof
| `Error e -> return @@ Error e )
| `Ok (Some buf) -> Lwt.return @@ Ok (`Data buf)
| `Eof -> Lwt.return @@ Ok `Eof
| `Error e -> Lwt.return @@ Error e )
| bufs ->
flow.linger <- [] ;
return @@ Ok (`Data (Cstruct.concat @@ List.rev bufs))
Lwt.return @@ Ok (`Data (Cstruct.concat @@ List.rev bufs))

let writev flow bufs =
match flow.state with
| `Eof -> return @@ Error `Closed
| `Error e -> return @@ Error (e :> write_error)
| `Eof -> Lwt.return @@ Error `Closed
| `Error e -> Lwt.return @@ Error (e :> write_error)
| `Active tls ->
match Tls.Engine.send_application_data tls bufs with
| Some (tls, answer) ->
Expand All @@ -119,51 +119,55 @@ module Make (F : Mirage_flow.S) = struct
let rec drain_handshake flow =
match flow.state with
| `Active tls when not (Tls.Engine.handshake_in_progress tls) ->
return @@ Ok flow
Lwt.return @@ Ok flow
| _ ->
(* read_react re-throws *)
read_react flow >>= function
| `Ok mbuf ->
flow.linger <- Option.to_list mbuf @ flow.linger ;
drain_handshake flow
| `Error e -> return @@ Error (e :> write_error)
| `Eof -> return @@ Error `Closed
| `Error e -> Lwt.return @@ Error (e :> write_error)
| `Eof -> Lwt.return @@ Error `Closed

type wr_or_msg = [ write_error | `Msg of string ]

let underlying flow = flow.flow

let reneg ?authenticator ?acceptable_cas ?cert ?(drop = true) flow =
match flow.state with
| `Eof -> return @@ Error `Closed
| `Error e -> return @@ Error (e :> write_error)
| `Eof -> Lwt.return @@ Error `Closed
| `Error e -> Lwt.return @@ Error (e :> wr_or_msg)
| `Active tls ->
match Tls.Engine.reneg ?authenticator ?acceptable_cas ?cert tls with
| None ->
(* XXX make this impossible to reach *)
invalid_arg "Renegotiation already in progress"
| None -> Lwt.return (Error (`Msg "Renegotiation already in progress"))
| Some (tls', buf) ->
if drop then flow.linger <- [] ;
flow.state <- `Active tls' ;
FLOW.write flow.flow buf >>= fun _ ->
drain_handshake flow >|= function
| Ok _ -> Ok ()
| Error _ as e -> e
| Ok _ -> Ok ()
| Error e -> Error (e :> wr_or_msg)

let key_update ?request flow =
match flow.state with
| `Eof -> return @@ Error `Closed
| `Error e -> return @@ Error (e :> write_error)
| `Eof -> Lwt.return @@ Error `Closed
| `Error e -> Lwt.return @@ Error (e :> wr_or_msg)
| `Active tls ->
match Tls.Engine.key_update ?request tls with
| Error _ -> invalid_arg "Key update failed"
| Error _ -> Lwt.return (Error (`Msg "Key update failed"))
| Ok (tls', buf) ->
flow.state <- `Active tls' ;
FLOW.write flow.flow buf >>= check_write flow
FLOW.write flow.flow buf >>= check_write flow >|= function
| Ok _ as o -> o
| Error e -> Error (e :> wr_or_msg)

let close flow =
match flow.state with
| `Active tls ->
flow.state <- `Eof ;
let (_, buf) = Tls.Engine.send_close_notify tls in
FLOW.(write flow.flow buf >>= fun _ -> close flow.flow)
| _ -> return_unit
| _ -> Lwt.return_unit

let client_of_flow conf ?host flow =
let conf' = match host with
Expand Down Expand Up @@ -212,8 +216,8 @@ module X509 (KV : Mirage_kv.RO) (C: Mirage_clock.PCLOCK) = struct
let default_cert = "server"

let err_fail pp = function
| Ok x -> return x
| Error e -> Fmt.kstr fail_with "%a" pp e
| Ok x -> Lwt.return x
| Error e -> Fmt.kstr Lwt.fail_with "%a" pp e

let pp_msg ppf = function `Msg m -> Fmt.string ppf m

Expand Down
7 changes: 5 additions & 2 deletions mirage/tls_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ module Make (F : Mirage_flow.S) : sig
with type error := error
and type write_error := write_error

(** [underlying t] returns the underlying flow. This is useful to extract
information such as [src] and [dst] of that flow. *)
val underlying : flow -> FLOW.flow

(** [reneg ~authenticator ~acceptable_cas ~cert ~drop t] renegotiates the
session, and blocks until the renegotiation finished. Optionally, a new
Expand All @@ -28,12 +31,12 @@ module Make (F : Mirage_flow.S) : sig
application data received before the renegotiation finished is dropped. *)
val reneg : ?authenticator:X509.Authenticator.t ->
?acceptable_cas:X509.Distinguished_name.t list -> ?cert:Tls.Config.own_cert ->
?drop:bool -> flow -> (unit, write_error) result Lwt.t
?drop:bool -> flow -> (unit, [ write_error | `Msg of string ]) result Lwt.t

(** [key_update ~request t] updates the traffic key and requests a traffic key
update from the peer if [request] is provided and [true] (the default).
This is only supported in TLS 1.3. *)
val key_update : ?request:bool -> flow -> (unit, write_error) result Lwt.t
val key_update : ?request:bool -> flow -> (unit, [ write_error | `Msg of string ]) result Lwt.t

(** [client_of_flow client ~host flow] upgrades the existing connection
to TLS using the [client] configuration, using [host] as peer name. *)
Expand Down
1 change: 0 additions & 1 deletion tests/readertests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ let good_records =
([ 21 ; 3 ; 2 ; 0 ; 0 ], `Record (({ content_type = ALERT ; version = `TLS_1_1 }, empty), empty) ) ;
([ 22 ; 3 ; 3 ; 0 ; 0 ], `Record (({ content_type = HANDSHAKE ; version = `TLS_1_2 }, empty), empty) ) ;
([ 23 ; 3 ; 0 ; 0 ; 0 ], `Record (({ content_type = APPLICATION_DATA ; version = `SSL_3 }, empty), empty) ) ;
([ 24 ; 3 ; 4 ; 0 ; 0 ], `Record (({ content_type = HEARTBEAT ; version = `TLS_1_3 }, empty), empty) ) ;
([ 16 ; 3 ; 1 ; 0 ; 0 ], `UnknownContent 16 ) ;
([ 19 ; 3 ; 1 ; 0 ; 0 ], `UnknownContent 19 ) ;
([ 20 ; 5 ; 1 ; 0 ; 0 ], `UnknownVersion (5, 1) ) ;
Expand Down
4 changes: 0 additions & 4 deletions tests/readerwritertests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ let header_tests =
"ReadWrite header" >:: readerwriter_header (`TLS_1_1, Packet.CHANGE_CIPHER_SPEC, a) ;
"ReadWrite header" >:: readerwriter_header (`TLS_1_2, Packet.CHANGE_CIPHER_SPEC, a) ;

"ReadWrite header" >:: readerwriter_header (`TLS_1_0, Packet.HEARTBEAT, a) ;
"ReadWrite header" >:: readerwriter_header (`TLS_1_1, Packet.HEARTBEAT, a) ;
"ReadWrite header" >:: readerwriter_header (`TLS_1_2, Packet.HEARTBEAT, a) ;

"ReadWrite header" >:: readerwriter_header (`TLS_1_0, Packet.ALERT, a) ;
"ReadWrite header" >:: readerwriter_header (`TLS_1_1, Packet.ALERT, a) ;
"ReadWrite header" >:: readerwriter_header (`TLS_1_2, Packet.ALERT, a) ;
Expand Down
5 changes: 0 additions & 5 deletions tests/writertests.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ let hdr_assembler_tests = [
(`TLS_1_1, Packet.APPLICATION_DATA, [], [23; 3; 2; 0; 0]) ;
(`TLS_1_0, Packet.APPLICATION_DATA, [], [23; 3; 1; 0; 0]) ;
(`TLS_1_2, Packet.APPLICATION_DATA, [0; 0; 0], [23; 3; 3; 0; 3; 0; 0; 0]) ;

(`TLS_1_2, Packet.HEARTBEAT, [], [24; 3; 3; 0; 0]) ;
(`TLS_1_1, Packet.HEARTBEAT, [], [24; 3; 2; 0; 0]) ;
(`TLS_1_0, Packet.HEARTBEAT, [], [24; 3; 1; 0; 0]) ;
(`TLS_1_2, Packet.HEARTBEAT, [0; 0; 0], [24; 3; 3; 0; 3; 0; 0; 0]) ;
]

let hdr_tests =
Expand Down
2 changes: 1 addition & 1 deletion tls-mirage.opam
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ depends: [
"x509" {>= "0.13.0"}
"fmt" {>= "0.8.7"}
"lwt" {>= "3.0.0"}
"mirage-flow" {>= "2.0.0"}
"mirage-flow" {>= "2.0.0" & < "4.0.0"}
"mirage-kv" {>= "3.0.0"}
"mirage-clock" {>= "3.0.0"}
"ptime" {>= "0.8.1"}
Expand Down

0 comments on commit 75d8c34

Please sign in to comment.