Skip to content

Commit

Permalink
FIXES bytecodealliance#137: make calling wasm from python 7x faster
Browse files Browse the repository at this point in the history
  • Loading branch information
muayyad-alsadi committed Apr 4, 2023
1 parent 1446d79 commit 9122bb4
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 27 deletions.
5 changes: 3 additions & 2 deletions examples/gcd.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Example of instantiating a wasm module and calling an export on it

from wasmtime import Store, Module, Instance

from functools import partial
store = Store()
module = Module.from_file(store.engine, './examples/gcd.wat')
instance = Instance(store, module, [])
gcd = instance.exports(store)["gcd"]

gcd_func = partial(gcd, store)
print("gcd(6, 27) = %d" % gcd(store, 6, 27))
print("gcd(6, 27) = %d" % gcd_func(6, 27))
4 changes: 3 additions & 1 deletion examples/gcd_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from gcd import gcd_func as wasm_gcd
from gcd_alt import gcd as wasm_gcd_alt


def python_gcd(x, y):
while y:
x, y = y, x % y
return abs(x)


N = 1_000
by_name = locals()
for name in 'math_gcd', 'python_gcd', 'wasm_gcd', 'wasm_gcd_alt':
for name in "math_gcd", "python_gcd", "wasm_gcd", "wasm_gcd_alt":
gcdf = by_name[name]
start_time = time.perf_counter()
for _ in range(N):
Expand Down
19 changes: 11 additions & 8 deletions examples/simd_i8x16.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
how to call v128 SMID operations
how to call v128 SIMD operations
for more details see https://github.com/WebAssembly/simd/blob/main/proposals/simd/SIMD.md#integer-addition
"""
import ctypes
Expand All @@ -8,9 +8,10 @@
from wasmtime import Store, Module, Instance



store = Store()
module = Module(store.engine, """
module = Module(
store.engine,
"""
(module
(func $add_v128 (param $a v128) (param $b v128) (result v128)
local.get $a
Expand All @@ -19,12 +20,14 @@
)
(export "add_v128" (func $add_v128))
)
""")
""",
)

instance = Instance(store, module, [])
vector_type = ctypes.c_uint8*16
vector_type = ctypes.c_uint8 * 16
add_v128 = partial(instance.exports(store)["add_v128"], store)
a=vector_type(*(i for i in range(16)))
b=vector_type(*(40+i for i in range(16)))
c=add_v128(a, b)
a = vector_type(*(i for i in range(16)))
b = vector_type(*(40 + i for i in range(16)))
c = add_v128(a, b)
print([v for v in c])
print([v for v in c] == [i + j for i, j in zip(a, b)])
3 changes: 2 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pytest]
addopts = --doctest-modules --flake8 --mypy
addopts = --doctest-modules
;addopts = --doctest-modules --flake8 --mypy
27 changes: 27 additions & 0 deletions tests/test_func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
import ctypes

from functools import partial
from wasmtime import *


Expand All @@ -17,6 +19,31 @@ def test_add(self):
func = Func(store, ty, lambda a, b: a + b)
self.assertEqual(func(store, 1, 2), 3)

def test_simd_i8x16_add(self):
# i8x16.add is SIMD 128-bit vector of i8 items of size 16
store = Store()
module = Module(
store.engine,
"""
(module
(func $add_v128 (param $a v128) (param $b v128) (result v128)
local.get $a
local.get $b
i8x16.add
)
(export "add_v128" (func $add_v128))
)
""",
)

instance = Instance(store, module, [])
vector_type = ctypes.c_uint8 * 16
add_v128 = partial(instance.exports(store)["add_v128"], store)
a = vector_type(*(i for i in range(16)))
b = vector_type(*(40 + i for i in range(16)))
c = add_v128(a, b)
self.assertEqual([v for v in c], [i + j for i, j in zip(a, b)])

def test_calls(self):
store = Store()
ty = FuncType([ValType.i32()], [])
Expand Down
9 changes: 4 additions & 5 deletions wasmtime/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def _create_raw_vals(self, *params: IntoVal) -> ctypes.Array[wasmtime_val_raw_t]
for i, param_str in enumerate(self._params_str):
val_setter(raw[i], param_str, params[i])
return raw

def _extract_return(self, vals_raw: ctypes.Array[wasmtime_val_raw_t]) -> Union[IntoVal, Sequence[IntoVal], None]:
if self._results_n==0:
if self._results_n == 0:
return None
if self._results_n==1:
if self._results_n == 1:
return val_getter(self._func.store_id, vals_raw[0], self._results_str0)
# we can use tuple construct, but I'm using list for compatability
return [val_getter(self._func.store_id, val_raw, ret_str) for val_raw, ret_str in zip(vals_raw, self._results_str)]
Expand All @@ -91,8 +91,7 @@ def _init_call(self, ty: FuncType):
self._params_n = params_n
self._results_n = results_n
n = max(params_n, results_n)
self._vals_raw_type = wasmtime_val_raw_t*n

self._vals_raw_type = wasmtime_val_raw_t * n

def __call__(self, store: Storelike, *params: IntoVal) -> Union[IntoVal, Sequence[IntoVal], None]:
"""
Expand Down
27 changes: 17 additions & 10 deletions wasmtime/_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
WASM_ANYREF.value: 'externref',
}


@ctypes.CFUNCTYPE(None, c_void_p)
def _externref_finalizer(extern_id: int) -> None:
Val._id_to_ref_count[extern_id] -= 1
Expand All @@ -35,35 +36,41 @@ def _intern(obj: typing.Any) -> c_void_p:
def _unintern(val: int) -> typing.Any:
return Val._id_to_extern.get(val)


def get_valtype_attr(ty: ValType) -> str:
return val_id2attr[wasm_valtype_kind(ty._ptr)]


def val_getter(store_id: int, val_raw: wasmtime_val_raw_t, attr: str) -> typing.Union[int, float, "wasmtime.Func", typing.Any]:
val = getattr(val_raw, attr)
if attr=='externref':

if attr == 'externref':
ptr = ctypes.POINTER(wasmtime_externref_t)
if not val: return None
if not val:
return None
ffi = ptr.from_address(val)
if not ffi: return None
if not ffi:
return None
extern_id = wasmtime_externref_data(ffi)
return _unintern(extern_id)
elif attr=='funcref':
if val==0: return None
elif attr == 'funcref':
if val == 0:
return None
f = wasmtime_func_t()
f.store_id = store_id
f.index = val
return wasmtime.Func._from_raw(f)
return val


def val_setter(dst: wasmtime_val_raw_t, attr: str, val: "IntoVal"):
if attr=='externref':
if isinstance(val, Val) and val._raw.kind==WASMTIME_EXTERNREF.value:
if attr == 'externref':
if isinstance(val, Val) and val._raw.kind == WASMTIME_EXTERNREF.value:
casted = ctypes.addressof(val._raw.of.externref)
else:
casted = ctypes.addressof(Val.externref(val)._raw.of.externref)
elif attr=='funcref':
if isinstance(val, Val) and val._raw.kind==WASMTIME_FUNCREF.value:
elif attr == 'funcref':
if isinstance(val, Val) and val._raw.kind == WASMTIME_FUNCREF.value:
casted = val._raw.of.funcref.index
elif isinstance(val, wasmtime.Func):
# TODO: validate same val._func.store_id
Expand Down

0 comments on commit 9122bb4

Please sign in to comment.