Skip to content

Commit

Permalink
Shift primitives for Nat (#3112)
Browse files Browse the repository at this point in the history
Introduce two new primitives {`shiftLeft`, `shiftRight`}, both with signature `(Nat, Nat32) -> Nat`. This is hopefully helpful for #3175.

-------

- [x] `grep` for "TODO: use shift right instead"
- [x] `grep` for "TODO: use shift left instead"
- [ ] QuickChecks (but the tests are rather exhaustive already)
- [x] make the prims into calls? — 1672aa0

----------
## Further optimisation opportunities

- restructure the `lsh` code to reuse the 64-bit shift result to create the bignum ([see comment below](#3112 (comment)))
- `clz >= 33` as a criterion for compact
  • Loading branch information
ggreif committed Jun 21, 2022
1 parent 82146bd commit e4125ef
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 6 deletions.
15 changes: 14 additions & 1 deletion rts/motoko-rts/src/bigint.rs
Expand Up @@ -151,7 +151,7 @@ pub(crate) unsafe fn mp_iszero(p: *const mp_int) -> bool {
(*p).used == 0
}

// Allocates a mp_int on the stack
// Allocates an mp_int on the stack
unsafe fn tmp_bigint() -> mp_int {
let mut i: mp_int = core::mem::zeroed();
check(mp_init(&mut i));
Expand Down Expand Up @@ -400,6 +400,19 @@ unsafe extern "C" fn bigint_lsh(a: Value, b: i32) -> Value {
persist_bigint(i)
}

#[cfg(feature = "ic")]
#[no_mangle]
unsafe extern "C" fn bigint_rsh(a: Value, b: i32) -> Value {
let mut i = tmp_bigint();
check(mp_div_2d(
a.as_bigint().mp_int_ptr(),
b,
&mut i,
core::ptr::null_mut(),
));
persist_bigint(i)
}

#[no_mangle]
unsafe extern "C" fn bigint_count_bits(a: Value) -> i32 {
mp_count_bits(a.as_bigint().mp_int_ptr())
Expand Down
93 changes: 88 additions & 5 deletions src/codegen/compile.ml
Expand Up @@ -853,6 +853,7 @@ module RTS = struct
E.add_func_import env "rts" "bigint_pow" [I32Type; I32Type] [I32Type];
E.add_func_import env "rts" "bigint_neg" [I32Type] [I32Type];
E.add_func_import env "rts" "bigint_lsh" [I32Type; I32Type] [I32Type];
E.add_func_import env "rts" "bigint_rsh" [I32Type; I32Type] [I32Type];
E.add_func_import env "rts" "bigint_abs" [I32Type] [I32Type];
E.add_func_import env "rts" "bigint_leb128_size" [I32Type] [I32Type];
E.add_func_import env "rts" "bigint_leb128_encode" [I32Type; I32Type] [];
Expand Down Expand Up @@ -2123,6 +2124,8 @@ sig
val compile_unsigned_div : E.t -> G.t
val compile_unsigned_rem : E.t -> G.t
val compile_unsigned_pow : E.t -> G.t
val compile_lsh : E.t -> G.t
val compile_rsh : E.t -> G.t

(* comparisons *)
val compile_eq : E.t -> G.t
Expand Down Expand Up @@ -2344,6 +2347,72 @@ module MakeCompact (Num : BigNumType) : BigNumType = struct
get_res
end)

(*
Note [left shifting compact Nat]
For compact Nats (i.e. non-heap allocated ones) we first try to perform the shift in the i64 domain.
for this we extend (signed, but that doesn't really matter) to 64 bits and then perform the left shift.
Then we check whether the result will fit back into the compact representation by either
- comparing: truncate to i32, then sign-extend back to i64, with the shift result
- count leading zeros >= 33 (currently we don't use this idea).
If the test works out, we have to ensure that the shift amount was smaller than 64, due to Wasm semantics.
If this is the case then the truncated i32 is the result (lowest bit is guaranteed to be clear),
otherwise we have to fall back to bignum arithmetic. We have two choices:
- reuse the 64-bit shift result going to heap (not currently, amount must be less than 33 for this to work)
- convert the original base to bigum and do the shift there.
N.B. we currently choose the shift cutoff as 42, just because (it must be <64).
*)

let compile_lsh env =
Func.share_code2 env "B_lsh" (("n", I32Type), ("amount", I32Type)) [I32Type]
(fun env get_n get_amount ->
get_n ^^
BitTagged.if_tagged_scalar env [I32Type]
( (* see Note [left shifting compact Nat] *)
get_n ^^
G.i (Convert (Wasm.Values.I64 I64Op.ExtendSI32)) ^^
get_amount ^^
G.i (Convert (Wasm.Values.I64 I64Op.ExtendUI32)) ^^
G.i (Binary (Wasm.Values.I64 I64Op.Shl)) ^^
let set_remember, get_remember = new_local64 env "remember" in
set_remember ^^ get_remember ^^
G.i (Convert (Wasm.Values.I32 I32Op.WrapI64)) ^^
let set_res, get_res = new_local env "res" in
set_res ^^ get_res ^^
G.i (Convert (Wasm.Values.I64 I64Op.ExtendSI32)) ^^ (* exclude sign flip *)
get_remember ^^
G.i (Compare (Wasm.Values.I64 I64Op.Eq)) ^^
get_amount ^^ compile_rel_const I32Op.LeU 42l ^^
G.i (Binary (Wasm.Values.I32 I32Op.And)) ^^
G.if1 I32Type
get_res
(get_n ^^ compile_shrS_const 1l ^^ Num.from_word30 env ^^ get_amount ^^ Num.compile_lsh env)
)
(get_n ^^ get_amount ^^ Num.compile_lsh env))

let compile_rsh env =
Func.share_code2 env "B_rsh" (("n", I32Type), ("amount", I32Type)) [I32Type]
(fun env get_n get_amount ->
get_n ^^
BitTagged.if_tagged_scalar env [I32Type]
begin
get_n ^^
get_amount ^^
G.i (Binary (Wasm.Values.I32 I32Op.ShrU)) ^^
compile_bitand_const 0xFFFFFFFEl ^^
get_amount ^^ compile_rel_const I32Op.LeU 31l ^^
G.i (Binary (Wasm.Values.I32 I32Op.Mul)) (* branch-free `if` *)
end
begin
get_n ^^ get_amount ^^ Num.compile_rsh env ^^
let set_res, get_res = new_local env "res" in
set_res ^^ get_res ^^
fits_in_vanilla env ^^
G.if1 I32Type
(get_res ^^ Num.truncate_to_word32 env ^^ BitTagged.tag_i32)
get_res
end)

let compile_is_negative env =
let set_n, get_n = new_local env "n" in
set_n ^^ get_n ^^
Expand Down Expand Up @@ -2756,6 +2825,8 @@ module BigNumLibtommath : BigNumType = struct
let compile_unsigned_rem env = E.call_import env "rts" "bigint_rem"
let compile_unsigned_div env = E.call_import env "rts" "bigint_div"
let compile_unsigned_pow env = E.call_import env "rts" "bigint_pow"
let compile_lsh env = E.call_import env "rts" "bigint_lsh"
let compile_rsh env = E.call_import env "rts" "bigint_rsh"

let compile_eq env = E.call_import env "rts" "bigint_eq"
let compile_is_negative env = E.call_import env "rts" "bigint_isneg"
Expand Down Expand Up @@ -4028,9 +4099,9 @@ module Cycles = struct
compile_add_const 8l ^^
(G.i (Load {ty = I64Type; align = 0; offset = 0l; sz = None })) ^^
BigNum.from_word64 env ^^
(* shift left 64 *)
compile_unboxed_const (BigNum.vanilla_lit env (Big_int.power_int_positive_int 2 64)) ^^
BigNum.compile_mul env ^^ (* TODO: use shift left instead *)
(* shift left 64 bits *)
compile_unboxed_const 64l ^^
BigNum.compile_lsh env ^^
BigNum.compile_add env)

(* takes a bignum from the stack, traps if ≥2^128, and leaves two 64bit words on the stack *)
Expand All @@ -4045,8 +4116,8 @@ module Cycles = struct

get_val ^^
(* shift right 64 bits *)
compile_unboxed_const (BigNum.vanilla_lit env (Big_int.power_int_positive_int 2 64)) ^^
BigNum.compile_unsigned_div env ^^ (* TODO: use shift right instead *)
compile_unboxed_const 64l ^^
BigNum.compile_rsh env ^^
BigNum.truncate_to_word64 env ^^

get_val ^^
Expand Down Expand Up @@ -8469,6 +8540,18 @@ and compile_prim_invocation (env : E.t) ae p es at =
| OtherPrim "blob_iter_next", [e] ->
SR.Vanilla, compile_exp_vanilla env ae e ^^ Blob.iter_next env

| OtherPrim "lsh_Nat", [e1; e2] ->
SR.Vanilla,
compile_exp_vanilla env ae e1 ^^
compile_exp_as env ae SR.UnboxedWord32 e2 ^^
BigNum.compile_lsh env

| OtherPrim "rsh_Nat", [e1; e2] ->
SR.Vanilla,
compile_exp_vanilla env ae e1 ^^
compile_exp_as env ae SR.UnboxedWord32 e2 ^^
BigNum.compile_rsh env

| OtherPrim "abs", [e] ->
SR.Vanilla,
compile_exp_vanilla env ae e ^^
Expand Down
9 changes: 9 additions & 0 deletions src/mo_values/prim.ml
Expand Up @@ -179,6 +179,15 @@ let prim =
| Int64 y -> Int64 Int_64.(and_ y (shl (of_int 1) (as_int64 a)))
| _ -> failwith "btst")

| "lsh_Nat" -> fun _ v k ->
(match as_tup v with
| [x; shift] -> k (Int Numerics.Int.(mul (as_int x) (pow (of_int 2) (of_big_int (Nat32.to_big_int (as_nat32 shift))))))
| _ -> failwith "lsh_Nat")
| "rsh_Nat" -> fun _ v k ->
(match as_tup v with
| [x; shift] -> k (Int Numerics.Int.(div (as_int x) (pow (of_int 2) (of_big_int (Nat32.to_big_int (as_nat32 shift))))))
| _ -> failwith "rsh_Nat")

| "conv_Char_Text" -> fun _ v k -> let str = match as_char v with
| c when c <= 0o177 -> String.make 1 (Char.chr c)
| code -> Wasm.Utf8.encode [code]
Expand Down
2 changes: 2 additions & 0 deletions src/prelude/prim.mo
Expand Up @@ -39,6 +39,8 @@ module Types = {
};

func abs(x : Int) : Nat { (prim "abs" : Int -> Nat) x };
func shiftLeft(x : Nat, shift : Nat32) : Nat { (prim "lsh_Nat" : (Nat, Nat32) -> Nat) (x, shift) };
func shiftRight(x : Nat, shift : Nat32) : Nat { (prim "rsh_Nat" : (Nat, Nat32) -> Nat) (x, shift) };

// for testing
func idlHash(x : Text) : Nat32 { (prim "idlHash" : Text -> Nat32) x };
Expand Down
43 changes: 43 additions & 0 deletions test/run/nat-shift.mo
@@ -0,0 +1,43 @@
import { debugPrint; shiftLeft; shiftRight; nat32ToNat } = "mo:⛔"

func checkShiftLeft(base : Nat, amount : Nat32) =
assert base * (2 ** nat32ToNat amount) == shiftLeft(base, amount);

checkShiftLeft(42, 7);
checkShiftLeft(42, 24);
checkShiftLeft(42, 25);
checkShiftLeft(42, 26);
checkShiftLeft(42, 25 + 32); // 57
checkShiftLeft(42, 25 + 64); // 89
checkShiftLeft(42, 125);
checkShiftLeft(0, 125);
checkShiftLeft(10 ** 10, 25);

class range(x : Nat32, y : Nat32) {
var i = x;
public func next() : ?Nat32 { if (i > y) null else {let j = i; i += 1; ?j} };
};

for (i in range(0, 200)) { checkShiftLeft(1, i) };
for (i in range(0, 200)) { checkShiftLeft(42, i) };

func checkShiftRight(base : Nat, amount : Nat32) =
assert base / 2 ** nat32ToNat amount == shiftRight(base, amount);

for (i in range(0, 40)) { checkShiftRight(1, i) };
for (i in range(0, 40)) { checkShiftRight(42, i) };

let huge = 2 ** 190;
for (i in range(0, 200)) { checkShiftRight(huge, i) };
for (i in range(0, 200)) { checkShiftRight(huge - 1, i) };

// iterated

assert 1 == shiftRight(shiftRight(huge, 189), 1);
assert 0 == shiftRight(shiftRight(huge, 189), 33);

// roundtrips
for (i in range(0, 200)) { assert 1 == shiftRight(shiftLeft(1, i), i) };
for (i in range(0, 200)) { assert 42 == shiftRight(shiftLeft(42, i), i) };
for (i in range(0, 200)) { assert huge == shiftRight(shiftLeft(huge, i), i) };
for (i in range(0, 200)) { assert huge - 1 == shiftRight(shiftLeft(huge - 1, i), i) }
2 changes: 2 additions & 0 deletions test/run/ok/nat-shift.tc.ok
@@ -0,0 +1,2 @@
nat-shift.mo:43.35-43.43: warning [M0155], operator may trap for inferred type
Nat

0 comments on commit e4125ef

Please sign in to comment.