Skip to content

Commit

Permalink
FIXES bytecodealliance#137: make calling wasm from python 7x faster
Browse files Browse the repository at this point in the history
  • Loading branch information
muayyad-alsadi committed Apr 2, 2023
1 parent d383a05 commit 383cad6
Showing 1 changed file with 41 additions and 8 deletions.
49 changes: 41 additions & 8 deletions wasmtime/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Callable, Optional, Generic, TypeVar, List, Union, Tuple, cast as cast_type, Sequence
from ._exportable import AsExtern
from ._store import Storelike
from ._bindings import wasmtime_val_raw_t, wasm_valtype_kind
from ._bindings import wasmtime_val_raw_t, wasm_valtype_kind, wasmtime_val_t, wasmtime_externref_t, wasmtime_func_t
from ._value import _unintern
from ._ffi import (
WASMTIME_I32,
WASMTIME_I64,
Expand All @@ -18,6 +19,7 @@
WASMTIME_EXTERNREF,
WASM_ANYREF,
WASM_FUNCREF,
wasmtime_externref_data,
)


Expand All @@ -40,11 +42,41 @@
def get_valtype_attr(ty: ValType):
return val_id2attr[wasm_valtype_kind(ty._ptr)]

from struct import Struct

def val_getter(store_id, val_raw, attr):
val = getattr(val_raw, attr)

if attr=='externref':
ptr = ctypes.POINTER(wasmtime_externref_t)
if not val: return None
ffi = ptr.from_address(val)
if not ffi: return
extern_id = wasmtime_externref_data(ffi)
ret = _unintern(extern_id)
return ret
elif attr=='funcref':
if val==0: return None
f=wasmtime_func_t()
f.store_id=store_id
f.index=val
ret=Func._from_raw(f)
return ret
return val

def val_setter(dst, attr, val):
if attr=='externref':
# TODO: handle None
v = Val.externref(val)
casted = ctypes.addressof(v._raw.of.externref)
if isinstance(val, Val) and val._raw.kind==WASMTIME_EXTERNREF.value:
if val._raw.of.externref:
extern_id = wasmtime_externref_data(val._raw.of.externref)
casted = ctypes.addressof(val._raw.of.externref)
else:
v = Val.externref(val)
casted = ctypes.addressof(v._raw.of.externref)
elif attr=='funcref':
if isinstance(val, Val) and val._raw.kind==WASMTIME_FUNCREF.value:
casted = val._raw.of.funcref.index
else: raise RuntimeError("foo")
elif isinstance(val, Func):
# TODO: handle null_funcref
# TODO: validate same val._func.store_id
Expand Down Expand Up @@ -112,9 +144,10 @@ def _extract_return(self, vals_raw: ctypes.Array[wasmtime_val_raw_t]) -> Union[I
if self._results_n==0:
return None
if self._results_n==1:
return getattr(vals_raw[0], self._results_str0)
ret = val_getter(self._func.store_id, vals_raw[0], self._results_str0)
return ret
# we can use tuple construct, but I'm using list for compatability
return [getattr(val_raw, ret_str) for val_raw, ret_str in zip(vals_raw, self._results_str)]
return [val_getter(self._func.store_id, val_raw, ret_str) for val_raw, ret_str in zip(vals_raw, self._results_str)]

def _init_call(self, ty: FuncType):
"""init signature properties used by call"""
Expand All @@ -123,8 +156,8 @@ def _init_call(self, ty: FuncType):
ty_results = ty.results
params_n = len(ty_params)
results_n = len(ty_results)
self._params_str = (get_valtype_attr(i) for i in ty_params)
self._results_str = (get_valtype_attr(i) for i in ty_results)
self._params_str = [get_valtype_attr(i) for i in ty_params]
self._results_str = [get_valtype_attr(i) for i in ty_results]
self._results_str0 = get_valtype_attr(ty_results[0]) if results_n else None
self._params_n = params_n
self._results_n = results_n
Expand Down

0 comments on commit 383cad6

Please sign in to comment.