diff --git a/compiler/bin-wasm_of_ocaml/compile.ml b/compiler/bin-wasm_of_ocaml/compile.ml index cd91ca450a..85f9a33255 100644 --- a/compiler/bin-wasm_of_ocaml/compile.ml +++ b/compiler/bin-wasm_of_ocaml/compile.ml @@ -81,6 +81,42 @@ let with_runtime_files ~runtime_wasm_files f = in Wat_preprocess.with_preprocessed_files ~variables:[] ~inputs f +let build_runtime ~runtime_file = + (* Keep this variables in sync with gen/gen.ml *) + let variables = + [ ( "effects" + , Wat_preprocess.String + (match Config.effects () with + | `Jspi -> "jspi" + | `Cps -> "cps" + | `Disabled | `Double_translation -> assert false) ) + ] + in + match + List.find_opt Runtime_files.precompiled_runtimes ~f:(fun (flags, _) -> + assert ( + List.length flags = List.length variables + && List.for_all2 ~f:(fun (k, _) (k', _) -> String.equal k k') flags variables); + Poly.equal flags variables) + with + | Some (_, contents) -> Fs.write_file ~name:runtime_file ~contents + | None -> + let inputs = + List.map + ~f:(fun (module_name, contents) -> + { Wat_preprocess.module_name + ; file = module_name ^ ".wat" + ; source = Contents contents + }) + Runtime_files.wat_files + in + Runtime.build + ~link_options:[ "-g" ] + ~opt_options:[ "-g"; "-O2" ] + ~variables + ~inputs + ~output_file:runtime_file + let link_and_optimize ~profile ~sourcemap_root @@ -99,7 +135,7 @@ let link_and_optimize let enable_source_maps = Option.is_some opt_sourcemap_file in Fs.with_intermediate_file (Filename.temp_file "runtime" ".wasm") @@ fun runtime_file -> - Fs.write_file ~name:runtime_file ~contents:Runtime_files.wasm_runtime; + build_runtime ~runtime_file; Fs.with_intermediate_file (Filename.temp_file "wasm-merged" ".wasm") @@ fun temp_file -> opt_with @@ -145,7 +181,7 @@ let link_and_optimize let link_runtime ~profile runtime_wasm_files output_file = if List.is_empty runtime_wasm_files - then Fs.write_file ~name:output_file ~contents:Runtime_files.wasm_runtime + then build_runtime ~runtime_file:output_file else Fs.with_intermediate_file (Filename.temp_file "extra_runtime" ".wasm") @@ fun extra_runtime -> @@ -167,7 +203,7 @@ let link_runtime ~profile runtime_wasm_files output_file = (); Fs.with_intermediate_file (Filename.temp_file "runtime" ".wasm") @@ fun runtime_file -> - Fs.write_file ~name:runtime_file ~contents:Runtime_files.wasm_runtime; + build_runtime ~runtime_file; Binaryen.link ~opt_output_sourcemap:None ~inputs: diff --git a/compiler/bin-wasm_of_ocaml/dune b/compiler/bin-wasm_of_ocaml/dune index 48619f0fe4..1870a60f7c 100644 --- a/compiler/bin-wasm_of_ocaml/dune +++ b/compiler/bin-wasm_of_ocaml/dune @@ -25,9 +25,10 @@ (target runtime_files.ml) (deps gen/gen.exe - ../../runtime/wasm/runtime.wasm ../../runtime/wasm/runtime.js - ../../runtime/wasm/deps.json) + ../../runtime/wasm/deps.json + (glob_files ../../runtime/wasm/*.wat) + (glob_files ../../runtime/wasm/runtime-*.wasm)) (action (with-stdout-to %{target} diff --git a/compiler/bin-wasm_of_ocaml/gen/gen.ml b/compiler/bin-wasm_of_ocaml/gen/gen.ml index b7a20c4e3e..6cc0146bc6 100644 --- a/compiler/bin-wasm_of_ocaml/gen/gen.ml +++ b/compiler/bin-wasm_of_ocaml/gen/gen.ml @@ -1,13 +1,77 @@ let read_file ic = really_input_string ic (in_channel_length ic) +(* Keep the two variables below in sync with function build_runtime in + ../compile.ml *) + +let default_flags = [] + +let interesting_runtimes = [ [ "effects", `S "jspi" ]; [ "effects", `S "cps" ] ] + +let name_runtime standard l = + let flags = + List.filter_map + (fun (k, v) -> + match v with + | `S s -> Some s + | `B b -> if b then Some k else None) + l + in + String.concat "-" ("runtime" :: (if standard then [ "standard" ] else flags)) ^ ".wasm" + +let print_flags f flags = + Format.fprintf + f + "@[<2>[ %a ]@]" + (Format.pp_print_list + ~pp_sep:(fun f () -> Format.fprintf f ";@ ") + (fun f (k, v) -> + Format.fprintf + f + "@[\"%s\",@ %a@]" + k + (fun f v -> + match v with + | `S s -> Format.fprintf f "Wat_preprocess.String \"%s\"" s + | `B b -> + Format.fprintf f "Wat_preprocess.Bool %s" (if b then "true" else "false")) + v)) + flags + let () = let () = set_binary_mode_out stdout true in + Format.printf "open Wasm_of_ocaml_compiler@."; Format.printf - "let wasm_runtime = \"%s\"@." + "let js_runtime = \"%s\"@." (String.escaped (read_file (open_in_bin Sys.argv.(1)))); Format.printf - "let js_runtime = \"%s\"@." + "let dependencies = \"%s\"@." (String.escaped (read_file (open_in_bin Sys.argv.(2)))); + let wat_files, runtimes = + List.partition + (fun f -> Filename.check_suffix f ".wat") + (Array.to_list (Array.sub Sys.argv 3 (Array.length Sys.argv - 3))) + in Format.printf - "let dependencies = \"%s\"@." - (String.escaped (read_file (open_in_bin Sys.argv.(3)))) + "let wat_files = [%a]@." + (Format.pp_print_list (fun f file -> + Format.fprintf + f + "\"%s\", \"%s\"; " + Filename.(chop_suffix (basename file) ".wat") + (String.escaped (read_file (open_in_bin file))))) + wat_files; + Format.printf + "let precompiled_runtimes = [%a]@." + (Format.pp_print_list (fun f (standard, flags) -> + let flags = flags @ default_flags in + let name = name_runtime standard flags in + match List.find_opt (fun file -> Filename.basename file = name) runtimes with + | None -> failwith ("Missing runtime " ^ name) + | Some file -> + Format.fprintf + f + "%a, \"%s\"; " + print_flags + flags + (String.escaped (read_file (open_in_bin file))))) + (List.mapi (fun i flags -> i = 0, flags) interesting_runtimes) diff --git a/compiler/lib-wasm/gc_target.ml b/compiler/lib-wasm/gc_target.ml index cd82902c31..e450183da1 100644 --- a/compiler/lib-wasm/gc_target.ml +++ b/compiler/lib-wasm/gc_target.ml @@ -1703,18 +1703,6 @@ let post_process_function_body = Initialize_locals.f let entry_point ~toplevel_fun = let code = - let* () = - match Config.effects () with - | `Cps | `Double_translation -> - let* f = - register_import - ~name:"caml_cps_initialize_effects" - (Fun { W.params = []; result = [] }) - in - instr (W.CallInstr (f, [])) - | `Jspi -> return () - | `Disabled -> assert false - in let* main = register_import ~name:"caml_main" diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 45f76c6f27..46ba5e4bf4 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -1189,9 +1189,6 @@ let init () = ] in Primitive.register "caml_array_of_uniform_array" `Mutable None None; - let l = - if effects_cps () then ("caml_alloc_stack", "caml_cps_alloc_stack") :: l else l - in List.iter ~f:(fun (nm, nm') -> Primitive.alias nm nm') l (* Make sure we can use [br_table] for switches *) diff --git a/runtime/wasm/dune b/runtime/wasm/dune index 0df18be889..778660b290 100644 --- a/runtime/wasm/dune +++ b/runtime/wasm/dune @@ -1,10 +1,13 @@ (install (section lib) (package wasm_of_ocaml-compiler) - (files runtime.wasm runtime.js)) + (files + (glob_files *.wat) + (glob_files runtime-*.wasm) + runtime.js)) (rule - (target runtime.wasm) + (target runtime-standard.wasm) (deps args (glob_files *.wat)) @@ -13,6 +16,21 @@ ../../compiler/bin-wasm_of_ocaml/wasmoo_link_wasm.exe --binaryen=-g --binaryen-opt=-O3 + --set=effects=jspi + %{target} + %{read-lines:args}))) + +(rule + (target runtime-cps.wasm) + (deps + args + (glob_files *.wat)) + (action + (run + ../../compiler/bin-wasm_of_ocaml/wasmoo_link_wasm.exe + --binaryen=-g + --binaryen-opt=-O3 + --set=effects=cps %{target} %{read-lines:args}))) diff --git a/runtime/wasm/effect.wat b/runtime/wasm/effect.wat index 2a579fd1c3..e71be4f60c 100644 --- a/runtime/wasm/effect.wat +++ b/runtime/wasm/effect.wat @@ -52,6 +52,88 @@ (sub $closure (struct (field (ref $function_1)) (field (ref $function_3))))) + ;; Generic fibers + + (type $generic_fiber + (sub + (struct + (field $value (mut (ref eq))) + (field $exn (mut (ref eq))) + (field $effect (mut (ref eq)))))) + + (@string $already_resumed "Effect.Continuation_already_resumed") + + (@string $effect_unhandled "Effect.Unhandled") + + (func $raise_unhandled + (param $eff (ref eq)) (param (ref eq)) (result (ref eq)) + (block $null + (call $caml_raise_with_arg + (br_on_null $null + (call $caml_named_value (global.get $effect_unhandled))) + (local.get $eff))) + (call $caml_raise_constant + (array.new_fixed $block 3 (ref.i31 (global.get $object_tag)) + (global.get $effect_unhandled) + (call $caml_fresh_oo_id (ref.i31 (i32.const 0))))) + (ref.i31 (i32.const 0))) + + (global $raise_unhandled (ref $closure) + (struct.new $closure (ref.func $raise_unhandled))) + + (global $effect_allowed (mut i32) (i32.const 1)) + + (func $caml_continuation_use_noexc (export "caml_continuation_use_noexc") + (param (ref eq)) (result (ref eq)) + (local $cont (ref $block)) + (local $stack (ref eq)) + (drop (block $used (result (ref eq)) + (local.set $cont (ref.cast (ref $block) (local.get 0))) + (local.set $stack + (br_on_cast_fail $used (ref eq) (ref $generic_fiber) + (array.get $block (local.get $cont) (i32.const 1)))) + (array.set $block (local.get $cont) (i32.const 1) + (ref.i31 (i32.const 0))) + (return (local.get $stack)))) + (ref.i31 (i32.const 0))) + + (func (export "caml_continuation_use_and_update_handler_noexc") + (param $cont (ref eq)) (param $hval (ref eq)) (param $hexn (ref eq)) + (param $heff (ref eq)) (result (ref eq)) + (local $stack (ref eq)) + (local $tail (ref $generic_fiber)) + (local.set $stack (call $caml_continuation_use_noexc (local.get $cont))) + (if (ref.test (ref $generic_fiber) (local.get $stack)) + (then + (local.set $tail + (ref.cast (ref $generic_fiber) + (array.get $block + (ref.cast (ref $block) (local.get $cont)) + (i32.const 2)))) + (struct.set $generic_fiber $value (local.get $tail) + (local.get $hval)) + (struct.set $generic_fiber $exn (local.get $tail) (local.get $hexn)) + (struct.set $generic_fiber $effect (local.get $tail) + (local.get $heff)))) + (local.get $stack)) + + (func (export "caml_get_continuation_callstack") + (param (ref eq) (ref eq)) (result (ref eq)) + (array.new_fixed $block 1 (ref.i31 (i32.const 0)))) + + (func (export "caml_is_continuation") (param (ref eq)) (result i32) + (drop (block $not_continuation (result (ref eq)) + (return + (ref.eq + (array.get $block + (br_on_cast_fail $not_continuation (ref eq) (ref $block) + (local.get 0)) + (i32.const 0)) + (ref.i31 (global.get $cont_tag)))))) + (i32.const 0)) + +(@if (= effects "jspi") +(@then ;; Apply a function f to a value v, both contained in a pair (f, v) (type $pair (struct (field (ref eq)) (field (ref eq)))) @@ -114,13 +196,6 @@ ;; Stack of fibers - (type $generic_fiber - (sub - (struct - (field $value (mut (ref eq))) - (field $exn (mut (ref eq))) - (field $effect (mut (ref eq)))))) - (type $fiber (sub final $generic_fiber (struct @@ -130,24 +205,6 @@ (field $cont (mut (ref $cont))) (field $next (mut (ref null $fiber)))))) - (@string $effect_unhandled "Effect.Unhandled") - - (func $raise_unhandled - (param $eff (ref eq)) (param (ref eq)) (result (ref eq)) - (block $null - (call $caml_raise_with_arg - (br_on_null $null - (call $caml_named_value (global.get $effect_unhandled))) - (local.get $eff))) - (call $caml_raise_constant - (array.new_fixed $block 3 (ref.i31 (global.get $object_tag)) - (global.get $effect_unhandled) - (call $caml_fresh_oo_id (ref.i31 (i32.const 0))))) - (ref.i31 (i32.const 0))) - - (global $raise_unhandled (ref $closure) - (struct.new $closure (ref.func $raise_unhandled))) - (func $initial_cont (param $p (ref $pair)) (param (ref eq)) (return_call $start_fiber (local.get $p))) @@ -214,8 +271,6 @@ (local.get $k) (struct.get $cont $cont_func (local.get $k)))) - (@string $already_resumed "Effect.Continuation_already_resumed") - (func $resume (export "%resume") (param $stack_head (ref eq)) (param $f (ref eq)) (param $v (ref eq)) (param $stack_tail (ref eq)) (result (ref eq)) @@ -280,8 +335,6 @@ (local.get $k1) (struct.get $cont $cont_func (local.get $k1)))) - (global $effect_allowed (mut i32) (i32.const 1)) - (func (export "%perform") (param $eff (ref eq)) (result (ref eq)) (if (i32.or (i32.eqz (global.get $effect_allowed)) (ref.is_null (struct.get $fiber $next (global.get $stack)))) @@ -380,60 +433,10 @@ (local.get $hv) (local.get $hx) (local.get $hf) (global.get $initial_cont) (ref.null $fiber))) +)) - ;; Other functions - - (func $caml_continuation_use_noexc (export "caml_continuation_use_noexc") - (param (ref eq)) (result (ref eq)) - (local $cont (ref $block)) - (local $stack (ref eq)) - (drop (block $used (result (ref eq)) - (local.set $cont (ref.cast (ref $block) (local.get 0))) - (local.set $stack - (br_on_cast_fail $used (ref eq) (ref $generic_fiber) - (array.get $block (local.get $cont) (i32.const 1)))) - (array.set $block (local.get $cont) (i32.const 1) - (ref.i31 (i32.const 0))) - (return (local.get $stack)))) - (ref.i31 (i32.const 0))) - - (func (export "caml_continuation_use_and_update_handler_noexc") - (param $cont (ref eq)) (param $hval (ref eq)) (param $hexn (ref eq)) - (param $heff (ref eq)) (result (ref eq)) - (local $stack (ref eq)) - (local $tail (ref $generic_fiber)) - (local.set $stack (call $caml_continuation_use_noexc (local.get $cont))) - (if (ref.test (ref $generic_fiber) (local.get $stack)) - (then - (local.set $tail - (ref.cast (ref $generic_fiber) - (array.get $block - (ref.cast (ref $block) (local.get $cont)) - (i32.const 2)))) - (struct.set $generic_fiber $value (local.get $tail) - (local.get $hval)) - (struct.set $generic_fiber $exn (local.get $tail) (local.get $hexn)) - (struct.set $generic_fiber $effect (local.get $tail) - (local.get $heff)))) - (local.get $stack)) - - (func (export "caml_get_continuation_callstack") - (param (ref eq) (ref eq)) (result (ref eq)) - (array.new_fixed $block 1 (ref.i31 (i32.const 0)))) - - (func (export "caml_is_continuation") (param (ref eq)) (result i32) - (drop (block $not_continuation (result (ref eq)) - (return - (ref.eq - (array.get $block - (br_on_cast_fail $not_continuation (ref eq) (ref $block) - (local.get 0)) - (i32.const 0)) - (ref.i31 (global.get $cont_tag)))))) - (i32.const 0)) - - ;; Effects through CPS transformation - +(@if (= effects "cps") +(@then (type $function_2 (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq)))) (type $function_4 @@ -562,7 +565,7 @@ (param (ref eq)) (param (ref eq)) (param (ref eq)) (result (ref eq)) (unreachable)) - (func $caml_trampoline (export "caml_trampoline") + (func (export "caml_trampoline") (param $f (ref eq)) (param $vargs (ref eq)) (result (ref eq)) (local $args (ref $block)) (local $i i32) (local $res (ref eq)) @@ -639,9 +642,6 @@ (global.set $cps_fiber_stack (local.get $saved_fiber_stack)) (throw $ocaml_exception (local.get $exn))) - (global $caml_trampoline_ref (export "caml_trampoline_ref") - (mut (ref null $function_1)) (ref.null $function_1)) - (func $caml_pop_fiber (result (ref eq)) (local $f (ref $cps_fiber)) (local.set $f @@ -765,7 +765,7 @@ (global $exn_handler (ref $closure) (struct.new $closure (ref.func $exn_handler))) - (func (export "caml_cps_alloc_stack") + (func (export "caml_alloc_stack") (param $hv (ref eq)) (param $hx (ref eq)) (param $hf (ref eq)) (result (ref eq)) (struct.new $cps_fiber @@ -773,9 +773,7 @@ (global.get $value_handler) (struct.new $exn_stack (global.get $exn_handler) (ref.null $exn_stack)) (ref.null $cps_fiber))) - - (func (export "caml_cps_initialize_effects") - (global.set $caml_trampoline_ref (ref.func $caml_trampoline))) +)) (func (export "caml_assume_no_perform") (param $f (ref eq)) (result (ref eq)) (local $saved_effect_allowed i32) diff --git a/runtime/wasm/obj.wat b/runtime/wasm/obj.wat index 8e44ecd376..4eba296265 100644 --- a/runtime/wasm/obj.wat +++ b/runtime/wasm/obj.wat @@ -23,8 +23,12 @@ (func $caml_dup_custom (param (ref eq)) (result (ref eq)))) (import "effect" "caml_is_continuation" (func $caml_is_continuation (param (ref eq)) (result i32))) - (import "effect" "caml_trampoline_ref" - (global $caml_trampoline_ref (mut (ref null $function_1)))) +(@if (= effects "cps") +(@then + (import "effect" "caml_trampoline" + (func $caml_trampoline (param (ref eq) (ref eq)) (result (ref eq)))) +)) + (type $block (array (mut (ref eq)))) (type $bytes (array (mut i8))) @@ -452,18 +456,28 @@ (func (export "caml_obj_reachable_words") (param (ref eq)) (result (ref eq)) (ref.i31 (i32.const 0))) +(@if (= effects "cps") +(@then (func $caml_callback_1 (export "caml_callback_1") (param $f (ref eq)) (param $x (ref eq)) (result (ref eq)) - (drop (block $cps (result (ref eq)) - (return_call_ref $function_1 (local.get $x) - (local.get $f) - (struct.get $closure 0 - (br_on_cast_fail $cps (ref eq) (ref $closure) - (local.get $f)))))) - (return_call_ref $function_1 + (return_call $caml_trampoline + (local.get $f) + (array.new_fixed $block 2 (ref.i31 (i32.const 0)) (local.get $x)))) + + (func (export "caml_callback_2") + (param $f (ref eq)) (param $x (ref eq)) (param $y (ref eq)) + (result (ref eq)) + (return_call $caml_trampoline (local.get $f) - (array.new_fixed $block 2 (ref.i31 (i32.const 0)) (local.get $x)) - (ref.as_non_null (global.get $caml_trampoline_ref)))) + (array.new_fixed $block 3 (ref.i31 (i32.const 0)) + (local.get $x) (local.get $y)))) +) +(@else + (func $caml_callback_1 (export "caml_callback_1") + (param $f (ref eq)) (param $x (ref eq)) (result (ref eq)) + (return_call_ref $function_1 (local.get $x) + (local.get $f) + (struct.get $closure 0 (ref.cast (ref $closure) (local.get $f))))) (func (export "caml_callback_2") (param $f (ref eq)) (param $x (ref eq)) (param $y (ref eq)) @@ -474,15 +488,8 @@ (struct.get $closure_2 1 (br_on_cast_fail $not_direct (ref eq) (ref $closure_2) (local.get $f)))))) - (if (ref.test (ref $closure) (local.get $f)) - (then - (return_call $caml_callback_1 - (call $caml_callback_1 (local.get $f) (local.get $x)) - (local.get $y))) - (else - (return_call_ref $function_1 - (local.get $f) - (array.new_fixed $block 3 (ref.i31 (i32.const 0)) - (local.get $x) (local.get $y)) - (ref.as_non_null (global.get $caml_trampoline_ref)))))) + (return_call $caml_callback_1 + (call $caml_callback_1 (local.get $f) (local.get $x)) + (local.get $y))) +)) )