Skip to content

Commit

Permalink
FIXES bytecodealliance#137: make calling wasm from python 7x faster
Browse files Browse the repository at this point in the history
  • Loading branch information
muayyad-alsadi committed Mar 25, 2023
1 parent 4a52ebb commit c724ce1
Showing 1 changed file with 45 additions and 33 deletions.
78 changes: 45 additions & 33 deletions wasmtime/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Callable, Optional, Generic, TypeVar, List, Union, Tuple, cast as cast_type, Sequence
from ._exportable import AsExtern
from ._store import Storelike

from ._bindings import wasmtime_val_raw_t

T = TypeVar('T')
FUNCTIONS: "Slab[Tuple]"
Expand All @@ -27,11 +27,11 @@ def __init__(self, store: Storelike, ty: FuncType, func: Callable, access_caller
set to `True` then the first argument given to `func` is an instance of
type `Caller` below.
"""

if not isinstance(store, Store):
raise TypeError("expected a Store")
if not isinstance(ty, FuncType):
raise TypeError("expected a FuncType")
self._func_call_init(ty)
idx = FUNCTIONS.allocate((func, ty.results, access_caller))
_func = ffi.wasmtime_func_t()
ffi.wasmtime_func_new(
Expand All @@ -56,6 +56,33 @@ def type(self, store: Storelike) -> FuncType:
ptr = ffi.wasmtime_func_type(store._context, byref(self._func))
return FuncType._from_ptr(ptr, None)

def _func_call_init(self, ty):
self._ty = ty
ty_params = ty.params
ty_results = ty.results
self._params_str = (str(i) for i in ty_params)
self._results_str = (str(i) for i in ty_results)
params_n = len(ty_params)
results_n = len(ty_results)
self._params_n = params_n
self._results_n = results_n
n = max(params_n, results_n)
self._vals_raw_type = wasmtime_val_raw_t*n

def _create_raw_vals(self, *params: IntoVal) -> ctypes.Array[wasmtime_val_raw_t]:
raw = self._vals_raw_type()
for i, param_str in enumerate(self._params_str):
setattr(raw[i], param_str, params[i])
return raw

def _extract_return(self, vals_raw: ctypes.Array[wasmtime_val_raw_t]) -> Union[IntoVal, Sequence[IntoVal], None]:
if self._results_n==0:
return None
if self._results_n==1:
return getattr(vals_raw[0], self._results_str[0])
# we can use tuple construct, but I'm using list for compatability
return [getattr(val_raw, ret_str) for val_raw, ret_str in zip(vals_raw, self._results_str)]

def __call__(self, store: Storelike, *params: IntoVal) -> Union[IntoVal, Sequence[IntoVal], None]:
"""
Calls this function with the given parameters
Expand All @@ -70,45 +97,30 @@ def __call__(self, store: Storelike, *params: IntoVal) -> Union[IntoVal, Sequenc
Note that you can also use the `__call__` method and invoke a `Func` as
if it were a function directly.
"""

ty = self.type(store)
param_tys = ty.params
if len(params) > len(param_tys):
params_n = len(params)
if params_n > self._params_n:
raise WasmtimeError("too many parameters provided: given %s, expected %s" %
(len(params), len(param_tys)))
if len(params) < len(param_tys):
(params_n, self._params_n))
if params_n < self._params_n:
raise WasmtimeError("too few parameters provided: given %s, expected %s" %
(len(params), len(param_tys)))

param_vals = [Val._convert(ty, params[i]) for i, ty in enumerate(param_tys)]
params_ptr = (ffi.wasmtime_val_t * len(params))()
for i, val in enumerate(param_vals):
params_ptr[i] = val._unwrap_raw()

result_tys = ty.results
results_ptr = (ffi.wasmtime_val_t * len(result_tys))()

(params_n, self._params_n))
vals_raw = self._create_raw_vals(*params)
vals_raw_ptr = ctypes.cast(vals_raw, ctypes.POINTER(wasmtime_val_raw_t))
# according to https://docs.wasmtime.dev/c-api/func_8h.html#a3b54596199641a8647a7cd89f322966f
# it's safe to call wasmtime_func_call_unchecked because
# - we allocate enough space to hold all the parameters and all the results
# - we set proper types
# - but not sure about "Values such as externref and funcref are valid within the store being called"
with enter_wasm(store) as trap:
error = ffi.wasmtime_func_call(
error = None
ffi.wasmtime_func_call_unchecked(
store._context,
byref(self._func),
params_ptr,
len(params),
results_ptr,
len(result_tys),
vals_raw_ptr,
trap)
if error:
raise WasmtimeError._from_ptr(error)

results = []
for i in range(0, len(result_tys)):
results.append(Val(results_ptr[i]).value)
if len(results) == 0:
return None
elif len(results) == 1:
return results[0]
else:
return results
return self._extract_return(vals_raw)

def _as_extern(self) -> ffi.wasmtime_extern_t:
union = ffi.wasmtime_extern_union(func=self._func)
Expand Down

0 comments on commit c724ce1

Please sign in to comment.