diff --git a/CHANGELOG.md b/CHANGELOG.md index a660685..e52cdc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed archive **no_std** support which was broken in the previous release, and added more tests to ensure it stays working +- Check returns in untyped host functions ([#27](https://github.com/explodingcamera/tinywasm/pull/27)) (thanks [@WhaleKit](https://github.com/WhaleKit)) ## [0.8.0] - 2024-08-29 diff --git a/Cargo.lock b/Cargo.lock index 64d907b..1615111 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -698,6 +698,7 @@ dependencies = [ "tinywasm-types", "wasm-testsuite", "wast", + "wat", ] [[package]] diff --git a/crates/tinywasm/Cargo.toml b/crates/tinywasm/Cargo.toml index 2d5c6dd..9bf9808 100644 --- a/crates/tinywasm/Cargo.toml +++ b/crates/tinywasm/Cargo.toml @@ -23,6 +23,7 @@ libm={version="0.2", default-features=false} wasm-testsuite={version="0.3.3"} indexmap="2.7" wast={workspace=true} +wat={workspace=true} eyre={workspace=true} pretty_env_logger={workspace=true} criterion={workspace=true} diff --git a/crates/tinywasm/src/error.rs b/crates/tinywasm/src/error.rs index ebc71af..e0510b2 100644 --- a/crates/tinywasm/src/error.rs +++ b/crates/tinywasm/src/error.rs @@ -1,4 +1,5 @@ use alloc::string::{String, ToString}; +use alloc::vec::Vec; use core::{fmt::Display, ops::ControlFlow}; use tinywasm_types::FuncType; @@ -20,8 +21,13 @@ pub enum Error { /// An unknown error occurred Other(String), - /// A function did not return a value - FuncDidNotReturn, + /// A host function returned an invalid value + InvalidHostFnReturn { + /// The expected type + expected: FuncType, + /// The actual value + actual: Vec, + }, /// An invalid label type was encountered InvalidLabelType, @@ -183,7 +189,9 @@ impl Display for Error { Self::InvalidLabelType => write!(f, "invalid label type"), Self::Other(message) => write!(f, "unknown error: {message}"), Self::UnsupportedFeature(feature) => write!(f, "unsupported feature: {feature}"), - Self::FuncDidNotReturn => write!(f, "function did not return"), + Self::InvalidHostFnReturn { expected, actual } => { + write!(f, "invalid host function return: expected={expected:?}, actual={actual:?}") + } Self::InvalidStore => write!(f, "invalid store"), } } diff --git a/crates/tinywasm/src/func.rs b/crates/tinywasm/src/func.rs index 7035109..47a2cdf 100644 --- a/crates/tinywasm/src/func.rs +++ b/crates/tinywasm/src/func.rs @@ -51,9 +51,9 @@ impl FuncHandle { let func_inst = store.get_func(self.addr); let wasm_func = match &func_inst.func { Function::Host(host_func) => { - let func = &host_func.clone().func; + let host_func = host_func.clone(); let ctx = FuncContext { store, module_addr: self.module_addr }; - return (func)(ctx, params); + return host_func.call(ctx, params); } Function::Wasm(wasm_func) => wasm_func, }; diff --git a/crates/tinywasm/src/imports.rs b/crates/tinywasm/src/imports.rs index 7cc5bbb..0797470 100644 --- a/crates/tinywasm/src/imports.rs +++ b/crates/tinywasm/src/imports.rs @@ -139,7 +139,26 @@ impl Extern { ty: &tinywasm_types::FuncType, func: impl Fn(FuncContext<'_>, &[WasmValue]) -> Result> + 'static, ) -> Self { - Self::Function(Function::Host(Rc::new(HostFunction { func: Box::new(func), ty: ty.clone() }))) + let _ty = ty.clone(); + let inner_func = move |ctx: FuncContext<'_>, args: &[WasmValue]| -> Result> { + let _ty = _ty.clone(); + let result = func(ctx, args)?; + + if result.len() != _ty.results.len() { + return Err(crate::Error::InvalidHostFnReturn { expected: _ty.clone(), actual: result }); + }; + + result.iter().zip(_ty.results.iter()).try_for_each(|(val, ty)| { + if val.val_type() != *ty { + return Err(crate::Error::InvalidHostFnReturn { expected: _ty.clone(), actual: result.clone() }); + } + Ok(()) + })?; + + Ok(result) + }; + + Self::Function(Function::Host(Rc::new(HostFunction { func: Box::new(inner_func), ty: ty.clone() }))) } /// Create a new typed function import diff --git a/crates/tinywasm/src/interpreter/executor.rs b/crates/tinywasm/src/interpreter/executor.rs index 2b64942..f05ae30 100644 --- a/crates/tinywasm/src/interpreter/executor.rs +++ b/crates/tinywasm/src/interpreter/executor.rs @@ -331,7 +331,7 @@ impl<'store, 'stack> Executor<'store, 'stack> { let func = &host_func.clone(); let params = self.stack.values.pop_params(&host_func.ty.params); let res = - (func.func)(FuncContext { store: self.store, module_addr: self.module.id() }, ¶ms).to_cf()?; + func.call(FuncContext { store: self.store, module_addr: self.module.id() }, ¶ms).to_cf()?; self.stack.values.extend_from_wasmvalues(&res); self.cf.incr_instr_ptr(); return ControlFlow::Continue(()); @@ -370,7 +370,7 @@ impl<'store, 'stack> Executor<'store, 'stack> { let host_func = host_func.clone(); let params = self.stack.values.pop_params(&host_func.ty.params); let res = - match (host_func.func)(FuncContext { store: self.store, module_addr: self.module.id() }, ¶ms) { + match host_func.call(FuncContext { store: self.store, module_addr: self.module.id() }, ¶ms) { Ok(res) => res, Err(e) => return ControlFlow::Break(Some(e)), }; diff --git a/crates/tinywasm/tests/host_func_signature_check.rs b/crates/tinywasm/tests/host_func_signature_check.rs new file mode 100644 index 0000000..787b24a --- /dev/null +++ b/crates/tinywasm/tests/host_func_signature_check.rs @@ -0,0 +1,173 @@ +use eyre::Result; +use std::fmt::Write; +use tinywasm::{ + types::{FuncType, ValType, WasmValue}, + Extern, FuncContext, Imports, Module, Store, +}; + +const VAL_LISTS: &[&[WasmValue]] = &[ + &[], + &[WasmValue::I32(0)], + &[WasmValue::I32(0), WasmValue::I32(0)], // 2 of the same + &[WasmValue::I32(0), WasmValue::I32(0), WasmValue::F64(0.0)], // add another type + &[WasmValue::I32(0), WasmValue::F64(0.0), WasmValue::I32(0)], // reorder + &[WasmValue::RefExtern(0), WasmValue::F64(0.0), WasmValue::I32(0)], // all different types +]; +// (f64, i32, i32) and (f64) can be used to "match_none" + +fn get_type_lists() -> impl Iterator + Clone> + Clone { + VAL_LISTS.iter().map(|l| l.iter().map(WasmValue::val_type)) +} +fn get_modules() -> Vec<(Module, FuncType, Vec)> { + let mut result = Vec::<(Module, FuncType, Vec)>::new(); + let val_and_tys = get_type_lists().zip(VAL_LISTS); + for res_types in get_type_lists() { + for (arg_types, arg_vals) in val_and_tys.clone() { + let ty = FuncType { results: res_types.clone().collect(), params: arg_types.collect() }; + result.push((proxy_module(&ty), ty, arg_vals.to_vec())); + } + } + result +} + +#[test] +fn test_return_invalid_type() -> Result<()> { + // try to return from host functions types that don't match their signatures + let mod_list = get_modules(); + + for (module, func_ty, test_args) in mod_list { + for result_to_try in VAL_LISTS { + println!("trying"); + let mut store = Store::default(); + let mut imports = Imports::new(); + imports + .define("host", "hfn", Extern::func(&func_ty, |_: FuncContext<'_>, _| Ok(result_to_try.to_vec()))) + .unwrap(); + + let instance = module.clone().instantiate(&mut store, Some(imports)).unwrap(); + let caller = instance.exported_func_untyped(&store, "call_hfn").unwrap(); + let res_types_returned = result_to_try.iter().map(WasmValue::val_type); + dbg!(&res_types_returned, &func_ty); + let res_types_expected = &func_ty.results; + let should_succeed = res_types_returned.eq(res_types_expected.iter().cloned()); + // Extern::func that returns wrong type(s) can only be detected when it runs + let call_res = caller.call(&mut store, &test_args); + dbg!(&call_res); + assert_eq!(call_res.is_ok(), should_succeed); + println!("this time ok"); + } + } + Ok(()) +} + +#[test] +fn test_linking_invalid_untyped_func() -> Result<()> { + // try to import host functions with function types no matching those expected by modules + let mod_list = get_modules(); + for (module, actual_func_ty, _) in &mod_list { + for (_, func_ty_to_try, _) in &mod_list { + let tried_fn = Extern::func(func_ty_to_try, |_: FuncContext<'_>, _| panic!("not intended to be called")); + let mut store = Store::default(); + let mut imports = Imports::new(); + imports.define("host", "hfn", tried_fn).unwrap(); + + let should_succeed = func_ty_to_try == actual_func_ty; + let link_res = module.clone().instantiate(&mut store, Some(imports)); + + assert_eq!(link_res.is_ok(), should_succeed); + } + } + Ok(()) +} + +#[test] +fn test_linking_invalid_typed_func() -> Result<()> { + type Existing = (i32, i32, f64); + type NonMatchingOne = f64; + type NonMatchingMul = (f64, i32, i32); + const DONT_CALL: &str = "not meant to be called"; + + // they don't match any signature from get_modules() + #[rustfmt::skip] // to make it table-like + let matching_none= &[ + Extern::typed_func(|_, _: NonMatchingMul| -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: NonMatchingMul| -> tinywasm::Result<()> { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: NonMatchingOne| -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: NonMatchingOne| -> tinywasm::Result<()> { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: Existing | -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: Existing | -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: () | -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: () | -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: NonMatchingOne| -> tinywasm::Result { panic!("{DONT_CALL}") } ), + Extern::typed_func(|_, _: NonMatchingOne| -> tinywasm::Result { panic!("{DONT_CALL}") } ), + ]; + + let mod_list = get_modules(); + for (module, _, _) in mod_list { + for typed_fn in matching_none.clone() { + let mut store = Store::default(); + let mut imports = Imports::new(); + imports.define("host", "hfn", typed_fn).unwrap(); + let link_failure = module.clone().instantiate(&mut store, Some(imports)); + link_failure.expect_err("no func in matching_none list should link to any mod"); + } + } + + // the valid cases are well-checked in other tests + Ok(()) +} + +fn to_name(ty: &ValType) -> &str { + match ty { + ValType::I32 => "i32", + ValType::I64 => "i64", + ValType::F32 => "f32", + ValType::F64 => "f64", + ValType::V128 => "v128", + ValType::RefFunc => "funcref", + ValType::RefExtern => "externref", + } +} + +// make a module with imported function {module:"host", name:"hfn"} that takes specified results and returns specified params +// and 2 wasm functions: call_hfn takes params, passes them to hfn and returns it's results +// and 2 wasm functions: call_hfn_discard takes params, passes them to hfn and drops it's results +fn proxy_module(func_ty: &FuncType) -> Module { + let results = func_ty.results.as_ref(); + let params = func_ty.params.as_ref(); + let join_surround = |list: &[ValType], keyword| { + if list.is_empty() { + return "".to_string(); + } + let step = list.iter().map(|ty| format!("{} ", to_name(ty)).to_string()).collect::(); + format!("({keyword} {step})") + }; + + let results_text = join_surround(results, "result"); + let params_text = join_surround(params, "param"); + + // let params_gets: String = params.iter().enumerate().map(|(num, _)| format!("(local.get {num})\n")).collect(); + let params_gets: String = params.iter().enumerate().fold(String::new(), |mut acc, (num, _)| { + let _ = writeln!(acc, "(local.get {num})", num = num); + acc + }); + + let result_drops = "(drop)\n".repeat(results.len()).to_string(); + let wasm_text = format!( + r#"(module + (import "host" "hfn" (func $host_fn {params_text} {results_text})) + (func (export "call_hfn") {params_text} {results_text} + {params_gets} + (call $host_fn) + ) + (func (export "call_hfn_discard") {params_text} + {params_gets} + (call $host_fn) + {result_drops} + ) + ) + "# + ); + let wasm = wat::parse_str(wasm_text).expect("failed to parse wat"); + Module::parse_bytes(&wasm).expect("failed to make module") +}