Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement qrexec protocol version 3 #60

Merged
merged 6 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions lib/formats.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ module Qrexec = struct
} [@@little_endian]
]

[%%cstruct
type trigger_service_params3 = {
target_domain : uint8_t [@len 64];
request_id : uint8_t [@len 32];
(* rest of message is service name *)
} [@@little_endian]
]

type msg_type =
[ `Exec_cmdline
| `Just_exec
| `Service_connect
| `Service_refused
| `Trigger_service
| `Connection_terminated
| `Trigger_service3
| `Hello
| `Data_stdin
| `Data_stdout
Expand All @@ -66,6 +75,7 @@ module Qrexec = struct
| 0x203l -> `Service_refused
| 0x210l -> `Trigger_service
| 0x211l -> `Connection_terminated
| 0x212l -> `Trigger_service3
| 0x300l -> `Hello
| x -> `Unknown x

Expand All @@ -80,6 +90,7 @@ module Qrexec = struct
| `Service_refused -> 0x203l
| `Trigger_service -> 0x210l
| `Connection_terminated -> 0x211l
| `Trigger_service3 -> 0x212l
| `Hello -> 0x300l
| `Unknown x -> x

Expand All @@ -94,9 +105,24 @@ module Qrexec = struct
| `Service_refused -> "MSG_SERVICE_REFUSED"
| `Trigger_service -> "MSG_TRIGGER_SERVICE"
| `Connection_terminated -> "MSG_CONNECTION_TERMINATED"
| `Trigger_service3 -> "MSG_TRIGGER_SERVICE3"
| `Hello -> "MSG_HELLO"
| `Unknown x -> "Unknown message: " ^ (Int32.to_string x)

type version =
[ `V2
| `V3 ]

let version_of_int = function
| 2l -> `V2
| 3l -> `V3
| x -> `Unknown_version x

let int_of_version = function
| `V2 -> 2l
| `V3 -> 3l
| `Unknown_version x -> x


module Framing = struct
let header_size = sizeof_msg_header
Expand Down
145 changes: 91 additions & 54 deletions lib/rExec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@ let vchan_base_port =
| Error (`Msg msg) -> failwith msg
| Ok port -> port

let max_data_chunk = 4096

let max_data_chunk_v2 = 4096
(** Max size for data chunks. See MAX_DATA_CHUNK in qubes-linux-utils/qrexec-lib/qrexec.h *)

let rec send t ~ty data =
let data, data' = Cstruct.split data (min max_data_chunk (Cstruct.length data)) in
let max_data_chunk_v3 = 65536
(** protocol version 3+ *)

let max_data_chunk : Formats.Qrexec.version -> int = function
| `V2 -> max_data_chunk_v2
| `V3 -> max_data_chunk_v3

let rec send t ~version ~ty data =
let data, data' = Cstruct.split data (min (max_data_chunk version) (Cstruct.length data)) in
let hdr = Cstruct.create sizeof_msg_header in
set_msg_header_ty hdr (int_of_type ty);
set_msg_header_len hdr (Cstruct.length data |> Int32.of_int);
Expand All @@ -41,7 +49,7 @@ let rec send t ~ty data =
else QV.send t [hdr; data] >>= function
| `Eof -> Lwt.return `Eof
| `Ok () ->
send t ~ty data'
send t ~version ~ty data'

let recv t =
QV.recv t >>!= fun (hdr, data) ->
Expand All @@ -53,16 +61,18 @@ module Flow = struct
dstream : QV.t;
mutable stdin_buf : Cstruct.t;
ty : [`Just_exec | `Exec_cmdline];
version : Formats.Qrexec.version;
}

let create ~ty dstream = {dstream; stdin_buf = Cstruct.create 0; ty}
let create ~version ~ty dstream =
{dstream; stdin_buf = Cstruct.create 0; ty; version}

let push ~stream flow buf =
match flow.ty with
| `Just_exec -> Lwt.return_unit
| `Exec_cmdline ->
if Cstruct.length buf > 0 then
send flow.dstream ~ty:stream buf >>= or_fail
send flow.dstream ~version:flow.version ~ty:stream buf >>= or_fail
else
Lwt.return_unit

Expand Down Expand Up @@ -115,8 +125,8 @@ module Flow = struct
set_exit_status_return_code msg (Int32.of_int return_code);
Lwt.finalize
(fun () ->
send flow.dstream ~ty:`Data_stdout (Cstruct.create 0) >>!= fun () ->
send flow.dstream ~ty:`Data_exit_code msg >>!= fun () ->
send flow.dstream ~version:flow.version ~ty:`Data_stdout (Cstruct.create 0) >>!= fun () ->
send flow.dstream ~version:flow.version ~ty:`Data_exit_code msg >>!= fun () ->
Lwt.return (`Ok ())
)
(fun () -> QV.disconnect flow.dstream)
Expand All @@ -127,16 +137,18 @@ module Client_flow = struct
dstream : QV.t;
mutable stdout_buf : Cstruct.t;
mutable stderr_buf : Cstruct.t;
version : Formats.Qrexec.version;
}

let create dstream = { dstream; stdout_buf = Cstruct.empty;
stderr_buf = Cstruct.empty }
let create ~version dstream =
{ dstream; stdout_buf = Cstruct.empty;
stderr_buf = Cstruct.empty; version }

let write t data = send ~ty:`Data_stdin t.dstream data
let write t data = send ~version:t.version ~ty:`Data_stdin t.dstream data

let writef t fmt =
fmt |> Printf.ksprintf @@ fun s ->
send ~ty:`Data_stdin t.dstream (Cstruct.of_string s)
send ~version:t.version ~ty:`Data_stdin t.dstream (Cstruct.of_string s)

let next_msg t =
recv t.dstream >>= function
Expand Down Expand Up @@ -186,6 +198,7 @@ type t = {
t : QV.t;
clients : (identifier, client) Hashtbl.t;
mutable counter : int;
version : version;
}

let disconnect t =
Expand All @@ -194,18 +207,37 @@ let disconnect t =
type handler = user:string -> string -> Flow.t -> int Lwt.t

let send_hello t =
let version = `V3 in
let hello = Cstruct.create sizeof_peer_info in
set_peer_info_version hello 2l;
send t ~ty:`Hello hello >>= function
| `Eof -> Lwt.fail_with "End-of-file sending msg_hello"
set_peer_info_version hello (int_of_version version);
send t ~version ~ty:`Hello hello >>= function
| `Eof -> Fmt.failwith "End-of-file sending msg_hello"
| `Ok () -> Lwt.return_unit

let recv_hello t =
recv t >>= function
| `Eof -> Lwt.fail_with "End-of-file waiting for msg_hello"
| `Ok (`Hello, resp) -> Lwt.return (get_peer_info_version resp)
| `Eof -> Fmt.failwith "End-of-file waiting for msg_hello"
| `Ok (`Hello, resp) ->
let peer_version = get_peer_info_version resp in
Lwt.return (version_of_int peer_version)
| `Ok (ty, _) -> Fmt.failwith "Expected msg_hello, got %ld" (int_of_type ty)

let negotiate_version (peer_version : [ version | `Unknown_version of int32 ])
: version =
let version =
match peer_version with
| `Unknown_version x -> if x < int_of_version `V2
then Fmt.failwith "Unsupported qrexec version %lu" x
else `V3
| #version as version -> version
in
Log.debug (fun f -> f "remote end wants to use protocol version %lu, \
continuing with version %lu"
(int_of_version peer_version) (int_of_version version));
version



let try_close flow return_code =
Flow.close flow return_code >|= function
| `Ok () -> ()
Expand All @@ -217,14 +249,10 @@ let with_flow ~ty ~domid ~port fn =
QV.client ~domid ~port () >>= fun client ->
Lwt.catch
(fun () ->
recv_hello client >>= function
| version when version < 2l -> Fmt.failwith "Unsupported qrexec version %ld" version
| version ->
Log.info (fun f -> f "client connected, \
other end wants to use protocol version %lu, \
continuing with version 2" version);
send_hello client >|= fun () ->
Flow.create ~ty client
recv_hello client >>= fun peer_version ->
send_hello client >|= fun () ->
let version = negotiate_version peer_version in
Flow.create ~version ~ty client
)
(fun ex -> QV.disconnect client >>= fun () -> Lwt.fail ex)
)
Expand Down Expand Up @@ -276,7 +304,7 @@ let exec t ~ty ~handler msg =
)
(fun () ->
let reply = Cstruct.sub msg 0 sizeof_exec_params in
send t.t ~ty:`Connection_terminated reply >|= function
send t.t ~version:t.version ~ty:`Connection_terminated reply >|= function
| `Ok () | `Eof -> ()
)
)
Expand All @@ -301,13 +329,14 @@ let start_connection params clients =
| Ok port ->
QV.server ~domid:(Int32.to_int domid) ~port () >>= fun remote ->
send_hello remote >>= fun () ->
recv_hello remote >>= fun version ->
recv_hello remote >>= fun peer_version ->
let version = negotiate_version peer_version in
Log.debug (fun f -> f "server connected on port %s, using protocol vers
ion %ld" (Vchan.Port.to_string port) version);
ion %ld" (Vchan.Port.to_string port) (int_of_version version));
match Hashtbl.find_opt clients request_id with
| Some client ->
Hashtbl.remove clients request_id;
client (`Ok (Client_flow.create remote))
client (`Ok (Client_flow.create ~version remote))
| None ->
Log.debug (fun f -> f "request_id %S without client" request_id);
Lwt.return_unit
Expand Down Expand Up @@ -343,30 +372,42 @@ let listen t handler =
Lwt.return `Done in
loop () >|= fun `Done -> ()

let service_params ~version ~service ~vm ~request_id =
let zero_pad s len =
String.init len (fun i -> if i < String.length s then s.[i] else '\000')
in
match version with
| `V2 ->
let service_len = 64
and target_domain_len = 32 in
if String.length service >= service_len ||
String.length vm >= target_domain_len
then raise (Invalid_argument "Qubes.RExec.qrexec: vm or service arguments too long");
let buf = Cstruct.create sizeof_trigger_service_params in
set_trigger_service_params_service_name (zero_pad service service_len) 0 buf;
set_trigger_service_params_target_domain (zero_pad vm target_domain_len) 0 buf;
set_trigger_service_params_request_id request_id 0 buf;
`Trigger_service, buf
| `V3 ->
let target_domain_len = 64 in
let buf = Cstruct.create (sizeof_trigger_service_params3 + String.length request_id) in
set_trigger_service_params3_target_domain (zero_pad vm target_domain_len) 0 buf;
set_trigger_service_params3_request_id request_id 0 buf;
Cstruct.blit_from_string request_id 0 buf sizeof_trigger_service_params3
(String.length request_id);
`Trigger_service3, buf

let qrexec t ~vm ~service client =
let service_len = 64
and target_domain_len = 32 in
if String.length service >= service_len ||
String.length vm >= target_domain_len
then raise (Invalid_argument "Qubes.RExec.qrexec: vm or service arguments too long");
(* XXX: This *should* be unique. The counter could overflow, though *)
let request_id =
let id = t.counter in
t.counter <- id + 1;
(* a '\000' terminated string of length 32 including '\000' *)
Printf.sprintf "MIRAGE%025u\000" id in
let trigger_service_params =
let zero_pad s len =
String.init len (fun i -> if i < String.length s then s.[i] else '\000')
in
let buf = Cstruct.create sizeof_trigger_service_params in
set_trigger_service_params_service_name (zero_pad service service_len) 0 buf;
set_trigger_service_params_target_domain (zero_pad vm target_domain_len) 0 buf;
set_trigger_service_params_request_id request_id 0 buf;
buf
in
let ty, trigger_service_params =
service_params ~version:t.version ~service ~vm ~request_id in
Hashtbl.add t.clients request_id client;
send t.t ~ty:`Trigger_service trigger_service_params >>= function
send t.t ~version:t.version ~ty trigger_service_params >>= function
| `Eof ->
(* XXX: Should we handle this differently? *)
Lwt.async (fun () -> client (`Error "dom0 closed connection"));
Expand All @@ -377,12 +418,8 @@ let qrexec t ~vm ~service client =
let connect ~domid () =
Log.info (fun f -> f "waiting for client...");
QV.server ~domid ~port:vchan_base_port () >>= fun t ->
let t = { t; clients = Hashtbl.create 4; counter = 0; } in
send_hello t.t >>= fun () ->
recv_hello t.t >>= function
| version when version < 2l -> Fmt.failwith "Unsupported qrexec version %ld" version
| version ->
Log.info (fun f -> f "client connected, \
other end wants to use protocol version %lu, \
continuing with version 2" version);
Lwt.return t
send_hello t >>= fun () ->
recv_hello t >>= fun peer_version ->
let version = negotiate_version peer_version in
Log.info (fun f -> f "client connected, using protocol version %ld" (int_of_version version));
Lwt.return { t; clients = Hashtbl.create 4; counter = 0; version; }
2 changes: 1 addition & 1 deletion mirage-qubes-ipv4.opam
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ depends: [
"ipaddr" { >= "3.0.0" }
"mirage-random" {>= "2.0.0"}
"mirage-clock" {>= "3.0.0"}
"cstruct" { >= "1.9.0" }
"cstruct" { >= "6.0.0" }
"lwt"
"logs" { >= "0.5.0" }
"ocaml" { >= "4.06.0" }
Expand Down