Skip to content

Commit

Permalink
Merge pull request #34 from ocaml-wasm/names
Browse files Browse the repository at this point in the history
Improve Wat output
  • Loading branch information
vouillon authored May 13, 2024
2 parents d4c5423 + 6cb2f7e commit ef48ae6
Show file tree
Hide file tree
Showing 13 changed files with 429 additions and 245 deletions.
125 changes: 74 additions & 51 deletions compiler/lib/wasm/wa_asm_output.ml
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ module Output () = struct

let offs _ i = Int32.to_string i

let rec expression e =
let rec expression m e =
match e with
| Const op ->
line
Expand All @@ -259,54 +259,55 @@ module Output () = struct
| ConstSym (name, offset) ->
line (type_prefix (I32 ()) ^^ string "const " ^^ symbol name offset)
| UnOp (op, e') ->
expression e'
expression m e'
^^ line
(type_prefix op
^^ string (select int_un_op int_un_op float_un_op float_un_op op))
| BinOp (op, e1, e2) ->
expression e1
^^ expression e2
expression m e1
^^ expression m e2
^^ line
(type_prefix op
^^ string (select int_bin_op int_bin_op float_bin_op float_bin_op op))
| I32WrapI64 e -> expression e ^^ line (string "i32.wrap_i64")
| I64ExtendI32 (s, e) -> expression e ^^ line (string (signage "i64.extend_i32" s))
| F32DemoteF64 e -> expression e ^^ line (string "f32.demote_f64")
| F64PromoteF32 e -> expression e ^^ line (string "f64.promote_f32")
| I32WrapI64 e -> expression m e ^^ line (string "i32.wrap_i64")
| I64ExtendI32 (s, e) -> expression m e ^^ line (string (signage "i64.extend_i32" s))
| F32DemoteF64 e -> expression m e ^^ line (string "f32.demote_f64")
| F64PromoteF32 e -> expression m e ^^ line (string "f64.promote_f32")
| Load (offset, e') ->
expression e'
expression m e'
^^ line
(type_prefix offset
^^ string "load "
^^ string (select offs offs offs offs offset))
| Load8 (s, offset, e') ->
expression e'
expression m e'
^^ line
(type_prefix offset
^^ string (signage "load8" s)
^^ string " "
^^ string (select offs offs offs offs offset))
| LocalGet i -> line (string "local.get " ^^ integer i)
| LocalTee (i, e') -> expression e' ^^ line (string "local.tee " ^^ integer i)
| LocalGet i -> line (string "local.get " ^^ integer (Hashtbl.find m i))
| LocalTee (i, e') ->
expression m e' ^^ line (string "local.tee " ^^ integer (Hashtbl.find m i))
| GlobalGet nm -> line (string "global.get " ^^ symbol nm 0)
| BlockExpr (ty, l) ->
line (string "block" ^^ block_type ty)
^^ indent (concat_map instruction l)
^^ indent (concat_map (instruction m) l)
^^ line (string "end_block")
| Call_indirect (typ, f, l) ->
concat_map expression l
^^ expression f
concat_map (expression m) l
^^ expression m f
^^ line (string "call_indirect " ^^ func_type typ)
| Call (x, l) -> concat_map expression l ^^ line (string "call " ^^ index x)
| MemoryGrow (mem, e) -> expression e ^^ line (string "memory.grow " ^^ integer mem)
| Seq (l, e') -> concat_map instruction l ^^ expression e'
| Call (x, l) -> concat_map (expression m) l ^^ line (string "call " ^^ index x)
| MemoryGrow (mem, e) -> expression m e ^^ line (string "memory.grow " ^^ integer mem)
| Seq (l, e') -> concat_map (instruction m) l ^^ expression m e'
| Pop _ -> empty
| IfExpr (ty, e, e1, e2) ->
expression e
expression m e
^^ line (string "if" ^^ block_type { params = []; result = [ ty ] })
^^ indent (expression e1)
^^ indent (expression m e1)
^^ line (string "else")
^^ indent (expression e2)
^^ indent (expression m e2)
^^ line (string "end_if")
| RefFunc _
| Call_ref _
Expand All @@ -328,83 +329,85 @@ module Output () = struct
| ExternExternalize _
| ExternInternalize _ -> assert false (* Not supported *)

and instruction i =
and instruction m i =
match i with
| Drop e -> expression e ^^ line (string "drop")
| Drop e -> expression m e ^^ line (string "drop")
| Store (offset, e, e') ->
expression e
^^ expression e'
expression m e
^^ expression m e'
^^ line
(type_prefix offset
^^ string "store "
^^ string (select offs offs offs offs offset))
| Store8 (offset, e, e') ->
expression e
^^ expression e'
expression m e
^^ expression m e'
^^ line
(type_prefix offset
^^ string "store8 "
^^ string (select offs offs offs offs offset))
| LocalSet (i, e) -> expression e ^^ line (string "local.set " ^^ integer i)
| GlobalSet (nm, e) -> expression e ^^ line (string "global.set " ^^ symbol nm 0)
| LocalSet (i, e) ->
expression m e ^^ line (string "local.set " ^^ integer (Hashtbl.find m i))
| GlobalSet (nm, e) -> expression m e ^^ line (string "global.set " ^^ symbol nm 0)
| Loop (ty, l) ->
line (string "loop" ^^ block_type ty)
^^ indent (concat_map instruction l)
^^ indent (concat_map (instruction m) l)
^^ line (string "end_loop")
| Block (ty, l) ->
line (string "block" ^^ block_type ty)
^^ indent (concat_map instruction l)
^^ indent (concat_map (instruction m) l)
^^ line (string "end_block")
| If (ty, e, l1, l2) ->
expression e
expression m e
^^ line (string "if" ^^ block_type ty)
^^ indent (concat_map instruction l1)
^^ indent (concat_map (instruction m) l1)
^^ line (string "else")
^^ indent (concat_map instruction l2)
^^ indent (concat_map (instruction m) l2)
^^ line (string "end_if")
| Br_table (e, l, i) ->
expression e
expression m e
^^ line
(string "br_table {"
^^ separate_map (string ", ") integer (l @ [ i ])
^^ string "}")
| Br (i, Some e) -> expression e ^^ instruction (Br (i, None))
| Br (i, Some e) -> expression m e ^^ instruction m (Br (i, None))
| Br (i, None) -> line (string "br " ^^ integer i)
| Br_if (i, e) -> expression e ^^ line (string "br_if " ^^ integer i)
| Return (Some e) -> expression e ^^ instruction (Return None)
| Br_if (i, e) -> expression m e ^^ line (string "br_if " ^^ integer i)
| Return (Some e) -> expression m e ^^ instruction m (Return None)
| Return None -> line (string "return")
| CallInstr (x, l) -> concat_map expression l ^^ line (string "call " ^^ index x)
| CallInstr (x, l) -> concat_map (expression m) l ^^ line (string "call " ^^ index x)
| Nop -> empty
| Push e -> expression e
| Push e -> expression m e
| Try (ty, body, catches, catch_all) ->
Feature.require exception_handling;
line (string "try" ^^ block_type ty)
^^ indent (concat_map instruction body)
^^ indent (concat_map (instruction m) body)
^^ concat_map
(fun (tag, l) ->
line (string "catch " ^^ index tag) ^^ indent (concat_map instruction l))
line (string "catch " ^^ index tag)
^^ indent (concat_map (instruction m) l))
catches
^^ (match catch_all with
| None -> empty
| Some l -> line (string "catch_all") ^^ indent (concat_map instruction l))
| Some l -> line (string "catch_all") ^^ indent (concat_map (instruction m) l))
^^ line (string "end_try")
| Throw (i, e) ->
Feature.require exception_handling;
expression e ^^ line (string "throw " ^^ index i)
expression m e ^^ line (string "throw " ^^ index i)
| Rethrow i ->
Feature.require exception_handling;
line (string "rethrow " ^^ integer i)
| Return_call_indirect (typ, f, l) ->
Feature.require tail_call;
concat_map expression l
^^ expression f
concat_map (expression m) l
^^ expression m f
^^ line (string "return_call_indirect " ^^ func_type typ)
| Return_call (x, l) ->
Feature.require tail_call;
concat_map expression l ^^ line (string "return_call " ^^ index x)
concat_map (expression m) l ^^ line (string "return_call " ^^ index x)
| Location (_, i) ->
(* Source maps not supported for the non-GC target *)
instruction i
instruction m i
| ArraySet _ | StructSet _ | Return_call_ref _ -> assert false (* Not supported *)

let escape_string s =
Expand Down Expand Up @@ -595,7 +598,24 @@ module Output () = struct
concat_map
(fun f ->
match f with
| Function { name; exported_name; typ; locals; body } ->
| Function { name; exported_name; typ; param_names; locals; body } ->
let local_names = Hashtbl.create 8 in
let idx =
List.fold_left
~f:(fun idx x ->
Hashtbl.add local_names x idx;
idx + 1)
~init:0
param_names
in
let _ =
List.fold_left
~f:(fun idx (x, _) ->
Hashtbl.add local_names x idx;
idx + 1)
~init:idx
locals
in
indent
(section_header "text" (V name)
^^ define_symbol (V name)
Expand All @@ -616,8 +636,11 @@ module Output () = struct
else
line
(string ".local "
^^ separate_map (string ", ") value_type locals))
^^ concat_map instruction body
^^ separate_map
(string ", ")
(fun (_, ty) -> value_type ty)
locals))
^^ concat_map (instruction local_names) body
^^ line (string "end_function"))
| Import _ | Data _ | Global _ | Tag _ | Type _ -> empty)
fields
Expand Down
9 changes: 5 additions & 4 deletions compiler/lib/wasm/wa_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ type expression =
| F64PromoteF32 of expression
| Load of (memarg, memarg, memarg, memarg) op * expression
| Load8 of signage * (memarg, memarg, memarg, memarg) op * expression
| LocalGet of int
| LocalTee of int * expression
| LocalGet of var
| LocalTee of var * expression
| GlobalGet of symbol
| BlockExpr of func_type * instruction list
| Call_indirect of func_type * expression * expression list
Expand Down Expand Up @@ -163,7 +163,7 @@ and instruction =
| Drop of expression
| Store of (memarg, memarg, memarg, memarg) op * expression * expression
| Store8 of (memarg, memarg, memarg, memarg) op * expression * expression
| LocalSet of int * expression
| LocalSet of var * expression
| GlobalSet of symbol * expression
| Loop of func_type * instruction list
| Block of func_type * instruction list
Expand Down Expand Up @@ -215,7 +215,8 @@ type module_field =
{ name : var
; exported_name : string option
; typ : func_type
; locals : value_type list
; param_names : var list
; locals : (var * value_type) list
; body : instruction list
}
| Data of
Expand Down
26 changes: 14 additions & 12 deletions compiler/lib/wasm/wa_code_generation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ let make_context ~value_type =
}

type var =
| Local of int * W.value_type option
| Local of int * Var.t * W.value_type option
| Expr of W.expression t

and state =
Expand Down Expand Up @@ -247,14 +247,14 @@ let var x st =

let add_var ?typ x ({ var_count; vars; _ } as st) =
match Var.Map.find_opt x vars with
| Some (Local (i, typ')) ->
| Some (Local (_, x', typ')) ->
assert (Poly.equal typ typ');
i, st
x', st
| Some (Expr _) -> assert false
| None ->
let i = var_count in
let vars = Var.Map.add x (Local (i, typ)) vars in
i, { st with var_count = var_count + 1; vars }
let vars = Var.Map.add x (Local (i, x, typ)) vars in
x, { st with var_count = var_count + 1; vars }

let define_var x e st = (), { st with vars = Var.Map.add x (Expr e) st.vars }

Expand Down Expand Up @@ -442,7 +442,7 @@ let rec is_smi e =

let get_i31_value x st =
match st.instrs with
| LocalSet (x', RefI31 e) :: rem when x = x' && is_smi e ->
| LocalSet (x', RefI31 e) :: rem when Code.Var.equal x x' && is_smi e ->
let x = Var.fresh () in
let x, st = add_var ~typ:I32 x st in
Some x, { st with instrs = LocalSet (x', RefI31 (LocalTee (x, e))) :: rem }
Expand All @@ -451,7 +451,7 @@ let get_i31_value x st =
let load x =
let* x = var x in
match x with
| Local (x, _) -> return (W.LocalGet x)
| Local (_, x, _) -> return (W.LocalGet x)
| Expr e -> e

let tee ?typ x e =
Expand Down Expand Up @@ -509,7 +509,7 @@ let assign x e =
let* x = var x in
let* e = e in
match x with
| Local (x, _) -> instr (W.LocalSet (x, e))
| Local (_, x, _) -> instr (W.LocalSet (x, e))
| Expr _ -> assert false

let seq l e =
Expand Down Expand Up @@ -613,21 +613,23 @@ let need_dummy_fun ~cps ~arity st =

let init_code context = instrs context.init_code

let function_body ~context ~param_count ~body =
let function_body ~context ~param_names ~body =
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
let (), st = body st in
let local_count, body = st.var_count, List.rev st.instrs in
let local_types = Array.make local_count None in
let local_types = Array.make local_count (Var.fresh (), None) in
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
Var.Map.iter
(fun _ v ->
match v with
| Local (i, typ) -> local_types.(i) <- typ
| Local (i, x, typ) -> local_types.(i) <- x, typ
| Expr _ -> ())
st.vars;
let body = Wa_tail_call.f body in
let param_count = List.length param_names in
let locals =
local_types
|> Array.map ~f:(fun v -> Option.value ~default:context.value_type v)
|> Array.map ~f:(fun (x, v) -> x, Option.value ~default:context.value_type v)
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
|> Array.to_list
in
Expand Down
8 changes: 4 additions & 4 deletions compiler/lib/wasm/wa_code_generation.mli
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ val if_ : Wa_ast.func_type -> expression -> unit t -> unit t -> unit t

val try_ : Wa_ast.func_type -> unit t -> (Code.Var.t * unit t) list -> unit t

val add_var : ?typ:Wa_ast.value_type -> Wa_ast.var -> int t
val add_var : ?typ:Wa_ast.value_type -> Wa_ast.var -> Wa_ast.var t

val define_var : Wa_ast.var -> expression -> unit t

val is_small_constant : Wa_ast.expression -> bool t

val get_i31_value : int -> int option t
val get_i31_value : Wa_ast.var -> Wa_ast.var option t

val with_location : Code.loc -> unit t -> unit t

Expand Down Expand Up @@ -167,6 +167,6 @@ val need_dummy_fun : cps:bool -> arity:int -> Code.Var.t t

val function_body :
context:context
-> param_count:int
-> param_names:Code.Var.t list
-> body:unit t
-> Wa_ast.value_type list * Wa_ast.instruction list
-> (Wa_ast.var * Wa_ast.value_type) list * Wa_ast.instruction list
4 changes: 2 additions & 2 deletions compiler/lib/wasm/wa_core_target.ml
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ let handle_exceptions ~result_typ ~fall_through ~context body x exn_handler =
exn_handler ~result_typ ~fall_through ~context )
]

let post_process_function_body ~param_count:_ ~locals:_ instrs = instrs
let post_process_function_body ~param_names:_ ~locals:_ instrs = instrs

let entry_point ~context:_ ~toplevel_fun =
let code =
Expand All @@ -653,4 +653,4 @@ let entry_point ~context:_ ~toplevel_fun =
let* () = instr (W.GlobalSet (S "young_limit", low)) in
drop (return (W.Call (toplevel_fun, [])))
in
{ W.params = []; result = [] }, code
{ W.params = []; result = [] }, [], code
Loading

0 comments on commit ef48ae6

Please sign in to comment.