From aae0971ef15884d9ac6c0019dfe938472c24d567 Mon Sep 17 00:00:00 2001 From: Nick Fitzgerald Date: Mon, 18 Sep 2023 13:59:09 -0700 Subject: [PATCH] Get inlining into an inlined function working --- src/optimize.rs | 49 ++++++-- src/profile.rs | 4 +- tests/all/optimize.rs | 283 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 322 insertions(+), 14 deletions(-) diff --git a/src/optimize.rs b/src/optimize.rs index 07440c2..f01fccd 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -131,6 +131,7 @@ impl Optimizer { num_imported_funcs: 0, types: vec![], funcs: vec![], + call_site_offsets: vec![], tables: TablesInfo::default(), func_bodies: vec![], }; @@ -315,7 +316,7 @@ impl Optimizer { wasm = &wasm[consumed..]; } - let mut new_code_section = Some(self.optimize_func_bodies(&context)?); + let mut new_code_section = Some(self.optimize_func_bodies(&mut context)?); log::trace!("Building final optimized module"); let mut module = wasm_encoder::Module::new(); @@ -339,7 +340,21 @@ impl Optimizer { Ok(module.finish()) } - fn optimize_func_bodies(&self, context: &OptimizeContext) -> Result { + fn optimize_func_bodies( + &self, + context: &mut OptimizeContext, + ) -> Result { + let mut call_site_index = 0; + for body in context.func_bodies.iter() { + context.call_site_offsets.push(call_site_index); + for op in body.get_operators_reader()?.into_iter() { + match op? { + wasmparser::Operator::CallIndirect { .. } => call_site_index += 1, + _ => {} + } + } + } + let mut new_code_section = wasm_encoder::CodeSection::new(); for (defined_func_index, body) in context.func_bodies.iter().cloned().enumerate() { let func_type = context.funcs[defined_func_index]; @@ -379,9 +394,6 @@ impl Optimizer { // The instructions making up the new body of the optimized function. let mut new_insts: Vec = vec![]; - // The index of the current `call_indirect` site. - let mut call_site_index = 0; - // Stack of functions to copy over to the new, optimized function. The // root is the original function itself and any subsequent entries are // being inlined into it. As we find a `call_indirect` that we'd like to @@ -394,6 +406,8 @@ impl Optimizer { .into_iter_with_offsets() .peekable(); StackEntry { + call_site_index: context.call_site_offsets + [usize::try_from(defined_func_index).unwrap()], defined_func_index, locals_delta: 0, func_body, @@ -466,6 +480,7 @@ impl Optimizer { table_index, table_byte: _, } => { + entry.call_site_index += 1; if let Some(new_entry) = self.try_enqueue_for_winlining( context, &on_stack, @@ -473,10 +488,11 @@ impl Optimizer { &mut num_locals, &mut new_insts, temp_callee_local, - call_site_index, + entry.call_site_index - 1, table_index, type_index, )? { + on_stack.insert(new_entry.defined_func_index); stack.push(new_entry); } else { new_insts.push(CowInst::Owned(wasm_encoder::Instruction::CallIndirect { @@ -484,8 +500,6 @@ impl Optimizer { table: table_index, })); } - - call_site_index += 1; } // `local.{get,set,tee}` instruction's need their local index adjusted. @@ -552,7 +566,7 @@ impl Optimizer { type_index: u32, ) -> Result>> { // If we haven't already reached our maximum inlining depth... - if on_stack.len() >= self.max_inline_depth { + if (on_stack.len() - 1) >= self.max_inline_depth { return Ok(None); } @@ -664,8 +678,17 @@ impl Optimizer { new_insts.push(CowInst::Owned(wasm_encoder::Instruction::LocalSet(local))); } + // Finally, create any additional locals that the callee function needs. + for l in func_body.get_locals_reader()?.into_iter() { + let (count, ty) = l?; + *num_locals += count; + locals.push((count, crate::convert::val_type(ty))); + } + Ok(Some(StackEntry { defined_func_index, + call_site_index: context.call_site_offsets + [usize::try_from(defined_func_index).unwrap()], locals_delta, func_body, ops, @@ -695,6 +718,11 @@ struct OptimizeContext<'a, 'b> { /// A map from defined function index to type. funcs: Vec, + /// A map from defined function index to the call site index offset for that + /// function (i.e. the count of how many `call_indirect` instructions + /// appeared in the code section before this function body). + call_site_offsets: Vec, + /// The static information we have about the tables present in the module. tables: TablesInfo, @@ -715,6 +743,9 @@ struct StackEntry<'a> { /// The defined function index of the function we are currently inlining. defined_func_index: u32, + /// The current `call_indirect` index we are processing. + call_site_index: u32, + /// The delta to apply to all `local.{get,set,tee}` instructions when /// inlining this function body. locals_delta: u32, diff --git a/src/profile.rs b/src/profile.rs index ddd6695..28fac56 100644 --- a/src/profile.rs +++ b/src/profile.rs @@ -57,7 +57,7 @@ use anyhow::{anyhow, ensure, Context, Result}; /// serde_json::to_writer(file, &my_profile)?; /// # Ok(()) } /// ``` -#[derive(Clone, Default)] +#[derive(Clone, Debug, Default)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Profile { // Per-call site profiling information. @@ -68,7 +68,7 @@ pub struct Profile { pub(crate) call_sites: BTreeMap, } -#[derive(Clone, Default)] +#[derive(Clone, Debug, Default)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct CallSiteProfile { // The total count of indirect calls for this call site. diff --git a/tests/all/optimize.rs b/tests/all/optimize.rs index e343a22..cb7291c 100644 --- a/tests/all/optimize.rs +++ b/tests/all/optimize.rs @@ -665,13 +665,290 @@ fn multiple_funcref_tables() -> Result<()> { ) } -// TODO FITZGEN: inline a call site with call sites that we want to inline inside it +#[test] +fn inlining_into_inlined_function() -> Result<()> { + assert_optimize( + Optimizer::new().min_total_calls(1_000).max_inline_depth(2), + &[&[(0, 1_000)], &[(1, 1_000)]], + r#" +(module + (type (func (result i32))) + (type (func (param i32) (result i32))) + + (func (type 0) + i32.const 11 + ) + + (func (type 1) + local.get 0 + call_indirect (type 0) + ) + + (func (param i32 i32) (result i32) + local.get 0 + local.get 1 + call_indirect (type 1) + ) + + (table 100 100 funcref) + (elem (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + r#" +(module + (type (;0;) (func (result i32))) + (type (;1;) (func (param i32) (result i32))) + (type (;2;) (func (param i32 i32) (result i32))) + (func (;0;) (type 0) (result i32) + (local i32) + i32.const 11 + ) + (func (;1;) (type 1) (param i32) (result i32) + (local i32) + local.get 0 + local.tee 1 + i32.const 0 + i32.eq + if (type 0) (result i32) ;; label = @1 + i32.const 11 + else + local.get 1 + call_indirect (type 0) + end + ) + (func (;2;) (type 2) (param i32 i32) (result i32) + (local i32 i32) + local.get 0 + local.get 1 + local.tee 2 + i32.const 1 + i32.eq + if (type 1) (param i32) (result i32) ;; label = @1 + local.set 3 + local.get 3 + local.tee 2 + i32.const 0 + i32.eq + if (type 0) (result i32) ;; label = @2 + i32.const 11 + else + local.get 2 + call_indirect (type 0) + end + else + local.get 2 + call_indirect (type 1) + end + ) + (table (;0;) 100 100 funcref) + (elem (;0;) (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + ) +} // TODO FITZGEN: too much inline depth +#[test] +fn reach_inline_depth_limit() -> Result<()> { + assert_optimize( + Optimizer::new().min_total_calls(1_000).max_inline_depth(1), + &[&[(0, 1_000)], &[(1, 1_000)]], + r#" +(module + (type (func (result i32))) + (type (func (param i32) (result i32))) + + (func (type 0) + i32.const 11 + ) + + (func (type 1) + local.get 0 + call_indirect (type 0) + ) + + (func (param i32 i32) (result i32) + local.get 0 + local.get 1 + call_indirect (type 1) + ) + + (table 100 100 funcref) + (elem (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + r#" +(module + (type (;0;) (func (result i32))) + (type (;1;) (func (param i32) (result i32))) + (type (;2;) (func (param i32 i32) (result i32))) + (func (;0;) (type 0) (result i32) + (local i32) + i32.const 11 + ) + (func (;1;) (type 1) (param i32) (result i32) + (local i32) + local.get 0 + local.tee 1 + i32.const 0 + i32.eq + if (type 0) (result i32) ;; label = @1 + i32.const 11 + else + local.get 1 + call_indirect (type 0) + end + ) + (func (;2;) (type 2) (param i32 i32) (result i32) + (local i32 i32) + local.get 0 + local.get 1 + local.tee 2 + i32.const 1 + i32.eq + if (type 1) (param i32) (result i32) ;; label = @1 + local.set 3 + local.get 3 + call_indirect (type 0) + else + local.get 2 + call_indirect (type 1) + end + ) + (table (;0;) 100 100 funcref) + (elem (;0;) (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + ) +} -// TODO FITZGEN: indirect recursion +#[test] +fn mutual_recursion() -> Result<()> { + assert_optimize( + Optimizer::new().min_total_calls(100).max_inline_depth(100), + &[&[(1, 100)], &[(0, 100)]], + r#" +(module + (type (func (param i32 i32) (result i32))) -// TODO FITZGEN: inline function with multiple locals + (func (type 0) + local.get 0 + local.get 1 + local.get 0 + call_indirect (type 0) + ) + + (func (type 0) + local.get 0 + local.get 1 + local.get 1 + call_indirect (type 0) + ) + + (table 100 100 funcref) + (elem (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + r#" +(module + (type (;0;) (func (param i32 i32) (result i32))) + (func (;0;) (type 0) (param i32 i32) (result i32) + (local i32 i32 i32) + local.get 0 + local.get 1 + local.get 0 + local.tee 2 + i32.const 1 + i32.eq + if (type 0) (param i32 i32) (result i32) ;; label = @1 + local.set 4 + local.set 3 + local.get 3 + local.get 4 + local.get 4 + call_indirect (type 0) + else + local.get 2 + call_indirect (type 0) + end + ) + (func (;1;) (type 0) (param i32 i32) (result i32) + (local i32 i32 i32) + local.get 0 + local.get 1 + local.get 1 + local.tee 2 + i32.const 0 + i32.eq + if (type 0) (param i32 i32) (result i32) ;; label = @1 + local.set 4 + local.set 3 + local.get 3 + local.get 4 + local.get 3 + call_indirect (type 0) + else + local.get 2 + call_indirect (type 0) + end + ) + (table (;0;) 100 100 funcref) + (elem (;0;) (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + ) +} + +#[test] +fn inline_a_function_with_many_locals() -> Result<()> { + assert_optimize( + Optimizer::new().min_total_calls(100), + &[&[(0, 100)]], + r#" +(module + (type (func (result i32))) + + (func (type 0) (local i32 i64 f32 f64 v128 externref funcref) + local.get 0 + ) + + (func (param i32) (result i32) + (local funcref externref v128 f64 f32 i64 i32) + local.get 0 + call_indirect (type 0) + ) + + (table 100 100 funcref) + (elem (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + r#" +(module + (type (;0;) (func (result i32))) + (type (;1;) (func (param i32) (result i32))) + (func (;0;) (type 0) (result i32) + (local i32 i64 f32 f64 v128 externref funcref i32) + local.get 0 + ) + (func (;1;) (type 1) (param i32) (result i32) + (local funcref externref v128 f64 f32 i64 i32 i32 i32 i64 f32 f64 v128 externref funcref) + local.get 0 + local.tee 8 + i32.const 0 + i32.eq + if (type 0) (result i32) ;; label = @1 + local.get 9 + else + local.get 8 + call_indirect (type 0) + end + ) + (table (;0;) 100 100 funcref) + (elem (;0;) (i32.const 0) funcref (ref.func 0) (ref.func 1)) +) + "#, + ) +} // TODO FITZGEN: probes for speculative hit/miss count