diff --git a/optd-cli/src/main.rs b/optd-cli/src/main.rs index cdad6e73..8b9cbccf 100644 --- a/optd-cli/src/main.rs +++ b/optd-cli/src/main.rs @@ -34,13 +34,12 @@ use clap::{Parser, Subcommand}; use colored::Colorize; -use optd::catalog::Catalog; use optd::catalog::iceberg::memory_catalog; use optd::dsl::analyzer::hir::{CoreData, HIR, Udf, Value}; use optd::dsl::compile::{Config, compile_hir}; use optd::dsl::engine::{Continuation, Engine, EngineResponse}; use optd::dsl::utils::errors::{CompileError, Diagnose}; -use optd::dsl::utils::retriever::{MockRetriever, Retriever}; +use optd::dsl::utils::retriever::MockRetriever; use std::collections::HashMap; use std::sync::Arc; use tokio::runtime::Builder; @@ -66,22 +65,17 @@ enum Commands { RunFunctions(Config), } -/// A unimplemented user-defined function. -pub fn unimplemented_udf( - _args: &[Value], - _catalog: &dyn Catalog, - _retriever: &dyn Retriever, -) -> Value { - println!("This user-defined function is unimplemented!"); - Value::new(CoreData::<Value>::None) -} - fn main() -> Result<(), Vec<CompileError>> { let cli = Cli::parse(); let mut udfs = HashMap::new(); let udf = Udf { - func: unimplemented_udf, + func: Arc::new(|_, _, _| { + Box::pin(async move { + println!("This user-defined function is unimplemented!"); + Value::new(CoreData::<Value>::None) + }) + }), }; udfs.insert("unimplemented_udf".to_string(), udf.clone()); diff --git a/optd/src/demo/demo.opt b/optd/src/demo/demo.opt index 2245f82b..b44d5dc1 100644 --- a/optd/src/demo/demo.opt +++ b/optd/src/demo/demo.opt @@ -1,7 +1,8 @@ -data Physical data PhysicalProperties data Statistics -data LogicalProperties +// Taking folded here is not the most interesting property, +// but it ensures they are the same for all expressions in the same group. 
+data LogicalProperties(folded: I64) data Logical = | Add(left: Logical, right: Logical) @@ -10,6 +11,15 @@ data Logical = | Div(left: Logical, right: Logical) \ Const(val: I64) +data Physical = + | PhysicalAdd(left: Physical, right: Physical) + | PhysicalSub(left: Physical, right: Physical) + | PhysicalMult(left: Physical, right: Physical) + | PhysicalDiv(left: Physical, right: Physical) + \ PhysicalConst(val: I64) + +// This will be the input plan that will be optimized. +// Result is: ((1 - 2) * (3 / 4)) + ((1 - 2) * (7 / 8)) = 0 fn input(): Logical = Add( Mult( @@ -17,14 +27,26 @@ fn input(): Logical = Div(Const(3), Const(4)) ), Mult( - Sub(Const(5), Const(6)), + Sub(Const(1), Const(2)), Div(Const(7), Const(8)) ) ) -// TODO(Alexis): This should be $ really, make costing and derive consistent with each other. +// External function to allow the retrieval of properties. +fn properties(op: Logical*): LogicalProperties + +// FIXME: This should be $ really (or other), make costing and derive consistent with each other. // Also, be careful of not forking in there! And make it a required function in analyzer. -fn derive(op: Logical) = LogicalProperties +fn derive(op: Logical*) = match op + | Add(left, right) -> + LogicalProperties(left.properties()#folded + right.properties()#folded) + | Sub(left, right) -> + LogicalProperties(left.properties()#folded - right.properties()#folded) + | Mult(left, right) -> + LogicalProperties(left.properties()#folded * right.properties()#folded) + | Div(left, right) -> + LogicalProperties(left.properties()#folded / right.properties()#folded) + \ Const(val) -> LogicalProperties(val) [transformation] fn (op: Logical*) mult_commute(): Logical? = match op @@ -55,4 +77,12 @@ fn (op: Logical*) const_fold_sub(): Logical? = match op fn (op: Logical*) const_fold_div(): Logical? 
= match op | Div(Const(a), Const(b)) -> if b == 0 then none else Const(a / b) - \ _ -> none \ No newline at end of file + \ _ -> none + +[implementation] +fn (op: Logical*) to_physical(props: PhysicalProperties?) = match op + | Add(left, right) -> PhysicalAdd(left.to_physical(props), right.to_physical(props)) + | Sub(left, right) -> PhysicalSub(left.to_physical(props), right.to_physical(props)) + | Mult(left, right) -> PhysicalMult(left.to_physical(props), right.to_physical(props)) + | Div(left, right) -> PhysicalDiv(left.to_physical(props), right.to_physical(props)) + \ Const(val) -> PhysicalConst(val) \ No newline at end of file diff --git a/optd/src/demo/mod.rs b/optd/src/demo/mod.rs index 48f720c4..5098085d 100644 --- a/optd/src/demo/mod.rs +++ b/optd/src/demo/mod.rs @@ -1,21 +1,53 @@ use crate::{ - catalog::iceberg::memory_catalog, + catalog::{Catalog, iceberg::memory_catalog}, dsl::{ - analyzer::hir::Value, + analyzer::hir::{CoreData, LogicalOp, Materializable, Udf, Value}, compile::{Config, compile_hir}, engine::{Continuation, Engine, EngineResponse}, - utils::retriever::MockRetriever, + utils::retriever::{MockRetriever, Retriever}, }, memo::MemoryMemo, - optimizer::{OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical}, + optimizer::{ClientRequest, OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical}, }; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{sync::mpsc, time::timeout}; +use tokio::{ + sync::mpsc, + time::{sleep, timeout}, +}; + +pub async fn properties( + args: Vec, + _catalog: Arc, + retriever: Arc, +) -> Value { + let arg = args[0].clone(); + let group_id = match &arg.data { + CoreData::Logical(Materializable::Materialized(LogicalOp { group_id, .. })) => { + group_id.unwrap() + } + CoreData::Logical(Materializable::UnMaterialized(group_id)) => *group_id, + _ => panic!("Expected a logical plan"), + }; + + retriever.get_properties(group_id).await +} async fn run_demo() { // Compile the HIR. 
let config = Config::new("src/demo/demo.opt".into()); - let udfs = HashMap::new(); + + // Create a properties UDF. + let properties_udf = Udf { + func: Arc::new(|args, catalog, retriever| { + Box::pin(async move { properties(args, catalog, retriever).await }) + }), + }; + + // Create the UDFs HashMap. + let mut udfs = HashMap::new(); + udfs.insert("properties".to_string(), properties_udf); + + // Compile with the config and UDFs. let hir = compile_hir(config, udfs).unwrap(); // Create necessary components. @@ -35,15 +67,15 @@ async fn run_demo() { let optimize_channel = Optimizer::launch(memo, catalog, hir); let (tx, mut rx) = mpsc::channel(1); optimize_channel - .send(OptimizeRequest { - plan: logical_plan, - physical_tx: tx, - }) + .send(ClientRequest::Optimize(OptimizeRequest { + plan: logical_plan.clone(), + physical_tx: tx.clone(), + })) .await .unwrap(); - // Timeout after 2 seconds. - let timeout_duration = Duration::from_secs(2); + // Timeout after 5 seconds. + let timeout_duration = Duration::from_secs(5); let result = timeout(timeout_duration, async { while let Some(response) = rx.recv().await { println!("Received response: {:?}", response); @@ -55,6 +87,13 @@ async fn run_demo() { Ok(_) => println!("Finished receiving responses."), Err(_) => println!("Timed out after 5 seconds."), } + + // Dump the memo (debug utility). 
+ optimize_channel + .send(ClientRequest::DumpMemo) + .await + .unwrap(); + sleep(Duration::from_secs(10)).await; } #[cfg(test)] diff --git a/optd/src/dsl/analyzer/from_ast/converter.rs b/optd/src/dsl/analyzer/from_ast/converter.rs index 92d311c8..76297c41 100644 --- a/optd/src/dsl/analyzer/from_ast/converter.rs +++ b/optd/src/dsl/analyzer/from_ast/converter.rs @@ -177,13 +177,12 @@ impl ASTConverter { #[cfg(test)] mod converter_tests { use super::*; - use crate::catalog::Catalog; use crate::dsl::analyzer::from_ast::from_ast; use crate::dsl::analyzer::hir::{CoreData, FunKind}; use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind}; use crate::dsl::parser::ast::{self, Adt, Function, Item, Module, Type as AstType}; - use crate::dsl::utils::retriever::Retriever; use crate::dsl::utils::span::{Span, Spanned}; + use std::sync::Arc; // Helper functions to create test items fn create_test_span() -> Span { @@ -382,19 +381,15 @@ mod converter_tests { let ext_func = create_simple_function("external_function", false); let module = create_module_with_functions(vec![ext_func]); - pub fn external_function( - _args: &[Value], - _catalog: &dyn Catalog, - _retriever: &dyn Retriever, - ) -> Value { - println!("Hello from UDF!"); - Value::new(CoreData::::None) - } - // Link the dummy function. let mut udfs = HashMap::new(); let udf = Udf { - func: external_function, + func: Arc::new(|_, _, _| { + Box::pin(async move { + println!("Hello from UDF!"); + Value::new(CoreData::None) + }) + }), }; udfs.insert("external_function".to_string(), udf); @@ -408,13 +403,6 @@ mod converter_tests { // Check that the function is in the context. let func_val = hir.context.lookup("external_function"); assert!(func_val.is_some()); - - // Verify it is the same function pointer. 
- if let CoreData::Function(FunKind::Udf(udf)) = &func_val.unwrap().data { - assert_eq!(udf.func as usize, external_function as usize); - } else { - panic!("Expected UDF function"); - } } #[test] diff --git a/optd/src/dsl/analyzer/hir/mod.rs b/optd/src/dsl/analyzer/hir/mod.rs index 458e6f8e..bcdf8106 100644 --- a/optd/src/dsl/analyzer/hir/mod.rs +++ b/optd/src/dsl/analyzer/hir/mod.rs @@ -21,7 +21,8 @@ use crate::dsl::utils::retriever::Retriever; use crate::dsl::utils::span::Span; use context::Context; use map::Map; -use std::fmt::Debug; +use std::fmt::{self, Debug}; +use std::pin::Pin; use std::{collections::HashMap, sync::Arc}; pub(crate) mod context; @@ -72,22 +73,30 @@ impl TypedSpan { } } -#[derive(Debug, Clone)] +/// Type aliases for user-defined functions (UDFs). +type UdfFutureOutput = Pin + Send>>; +type UdfFunction = + dyn Fn(Vec, Arc, Arc) -> UdfFutureOutput + Send + Sync; + +#[derive(Clone)] pub struct Udf { - /// The function pointer to the user-defined function. - /// - /// Note that [`Value`]s passed to and returned from this UDF do not have associated metadata. - pub func: fn(&[Value], &dyn Catalog, &dyn Retriever) -> Value, + pub func: Arc, +} + +impl fmt::Debug for Udf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[udf]") + } } impl Udf { - pub fn call( + pub async fn call( &self, values: &[Value], - catalog: &dyn Catalog, - retriever: &dyn Retriever, + catalog: Arc, + retriever: Arc, ) -> Value { - (self.func)(values, catalog, retriever) + (self.func)(values.to_vec(), catalog, retriever).await } } diff --git a/optd/src/dsl/engine/eval/expr.rs b/optd/src/dsl/engine/eval/expr.rs index 616e7d1d..8cded69b 100644 --- a/optd/src/dsl/engine/eval/expr.rs +++ b/optd/src/dsl/engine/eval/expr.rs @@ -588,7 +588,7 @@ impl Engine { Arc::new(move |arg_values| { Box::pin(capture!([udf, catalog, deriver, k], async move { // Call the UDF with the argument values. 
- let result = udf.call(&arg_values, catalog.as_ref(), deriver.as_ref()); + let result = udf.call(&arg_values, catalog, deriver).await; // Pass the result to the continuation. k(result).await @@ -1143,18 +1143,22 @@ mod tests { // Define a Rust UDF that calculates the sum of array elements let sum_function = Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Array(elements) => { - let mut sum = 0; - for elem in elements { - if let CoreData::Literal(Literal::Int64(value)) = &elem.data { - sum += value; + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Array(elements) => { + let mut sum: i64 = 0; + for elem in elements { + if let CoreData::Literal(Literal::Int64(value)) = &elem.data { + sum += value; + } + } + Value::new(CoreData::Literal(Literal::Int64(sum))) } + _ => panic!("Expected array argument"), } - Value::new(CoreData::Literal(Literal::Int64(sum))) - } - _ => panic!("Expected array argument"), - }, + }) + }), }))); ctx.bind("sum".to_string(), sum_function); @@ -1271,16 +1275,18 @@ mod tests { ctx.bind( "get".to_string(), Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| { - if args.len() != 2 { - panic!("get function requires 2 arguments"); - } + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + if args.len() != 2 { + panic!("get function requires 2 arguments"); + } - match &args[0].data { - CoreData::Map(map) => map.get(&args[1]), - _ => panic!("First argument must be a map"), - } - }, + match &args[0].data { + CoreData::Map(map) => map.get(&args[1]), + _ => panic!("First argument must be a map"), + } + }) + }), }))), ); diff --git a/optd/src/dsl/engine/eval/match.rs b/optd/src/dsl/engine/eval/match.rs index e1554afa..9fc853be 100644 --- a/optd/src/dsl/engine/eval/match.rs +++ b/optd/src/dsl/engine/eval/match.rs @@ -671,12 +671,16 @@ mod tests { ctx.bind( "length".to_string(), 
Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Array(elements) => { - Value::new(CoreData::Literal(int(elements.len() as i64))) - } - _ => panic!("Expected array"), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Array(elements) => { + Value::new(CoreData::Literal(int(elements.len() as i64))) + } + _ => panic!("Expected array"), + } + }) + }), }))), ); @@ -985,12 +989,16 @@ mod tests { // Create a to_string function let to_string_fn = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Function( FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Literal(lit) => { - Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) - } - _ => Value::new(CoreData::Literal(string(""))), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Literal(lit) => { + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) + } + _ => Value::new(CoreData::Literal(string(""))), + } + }) + }), }), ))))); @@ -1098,12 +1106,16 @@ mod tests { ctx.bind( "to_string".to_string(), Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Literal(Literal::Int64(i)) => { - Value::new(CoreData::Literal(string(&i.to_string()))) - } - _ => panic!("Expected integer literal"), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Literal(Literal::Int64(i)) => { + Value::new(CoreData::Literal(string(&i.to_string()))) + } + _ => panic!("Expected integer literal"), + } + }) + }), }))), ); @@ -1232,13 +1244,17 @@ mod tests { ctx.bind( "to_string".to_string(), Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Literal(lit) => { - Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) - } 
- CoreData::Array(_) => Value::new(CoreData::Literal(string(""))), - _ => Value::new(CoreData::Literal(string(""))), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Literal(lit) => { + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) + } + CoreData::Array(_) => Value::new(CoreData::Literal(string(""))), + _ => Value::new(CoreData::Literal(string(""))), + } + }) + }), }))), ); diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index c087a6cc..88727ac5 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -4,11 +4,12 @@ use super::{Infallible, MemoryMemo}; use crate::{ cir::*, memo::{ - Materialize, Memo, MergeGroupProduct, MergeProducts, Representative, memory::GroupInfo, + Materialize, Memo, MergeGoalProduct, MergeGroupProduct, MergePhysicalExprProduct, + Representative, memory::GroupInfo, }, }; use hashbrown::{HashMap, HashSet}; -use std::borrow::Cow; +use std::{borrow::Cow, collections::VecDeque}; impl MemoryMemo { pub(super) fn next_group_id(&mut self) -> GroupId { @@ -35,21 +36,25 @@ impl MemoryMemo { PhysicalExpressionId(id) } + /// Takes the set of [`LogicalExpressionId`] that reference a group, mapped to their + /// representatives. + fn take_referencing_expr_set(&mut self, group_id: GroupId) -> HashSet { + self.group_referencing_exprs_index + .remove(&group_id) + .unwrap_or_default() + .iter() + .map(|id| self.repr_logical_expr_id.find(id)) + .collect() + } + /// Merges the two sets of logical expressions that reference the two groups into a single set /// of expressions under a new [`GroupId`]. /// /// If a group does not exist, then the set of expressions referencing it is the empty set. fn merge_referencing_exprs(&mut self, group1: GroupId, group2: GroupId, new_group: GroupId) { // Remove the entries for the original two groups that we want to merge. 
- let exprs1 = self - .group_referencing_exprs_index - .remove(&group1) - .unwrap_or_default(); - let exprs2 = self - .group_referencing_exprs_index - .remove(&group2) - .unwrap_or_default(); - + let exprs1 = self.take_referencing_expr_set(group1); + let exprs2 = self.take_referencing_expr_set(group2); let new_set = exprs1.union(&exprs2).copied().collect(); // Update the index for the new group / set of logical expressions. @@ -64,11 +69,18 @@ impl MemoryMemo { /// /// See the implementation itself for the documentation of each helper method. pub trait MemoryMemoHelper: Memo { + fn find_repr_goal_member_id(&self, id: GoalMemberId) -> GoalMemberId; + async fn remap_logical_expr<'a>( &self, logical_expr: &'a LogicalExpression, ) -> Result, Infallible>; + async fn remap_physical_expr<'a>( + &self, + physical_expr: &'a PhysicalExpression, + ) -> Result, Infallible>; + async fn remap_goal<'a>(&self, goal: &'a Goal) -> Result, Infallible>; async fn merge_group_pair( @@ -84,20 +96,54 @@ pub trait MemoryMemoHelper: Memo { prev_expr: LogicalExpression, ) -> Result, Infallible>; - async fn process_referencing_expressions( + async fn process_referencing_logical_exprs( &mut self, group_id_1: GroupId, group_id_2: GroupId, new_group_id: GroupId, ) -> Result, Infallible>; - async fn consolidate_merge_results( + async fn consolidate_merge_group_products( &self, merge_operations: Vec, - ) -> Result; + ) -> Result, Infallible>; + + async fn merge_dependent_goals( + &mut self, + group_merges: &[MergeGroupProduct], + ) -> Result, Infallible>; + + async fn consolidate_merge_physical_expr_products( + &self, + merge_operations: Vec, + ) -> Result, Infallible>; + + async fn merge_dependent_physical_exprs( + &mut self, + goal_merges: &[MergeGoalProduct], + ) -> Result, Infallible>; } impl MemoryMemoHelper for MemoryMemo { + /// Finds the representative ID for a given [`GoalMemberId`]. 
+ /// + /// This handles both variants of `GoalMemberId` by finding the appropriate representative + /// based on whether it's a Goal or PhysicalExpr. + fn find_repr_goal_member_id(&self, id: GoalMemberId) -> GoalMemberId { + use GoalMemberId::*; + + match id { + GoalId(goal_id) => { + let repr_goal_id = self.repr_goal_id.find(&goal_id); + GoalId(repr_goal_id) + } + PhysicalExpressionId(expr_id) => { + let repr_expr_id = self.repr_physical_expr_id.find(&expr_id); + PhysicalExpressionId(repr_expr_id) + } + } + } + /// Remaps the children of a logical expression such that they are all identified by their /// representative IDs. /// @@ -147,6 +193,61 @@ impl MemoryMemoHelper for MemoryMemo { }) } + /// Remaps the children of a physical expression such that they are all identified by their + /// representative IDs. + /// + /// For example, if a physical expression has a child goal with [`GoalMemberId`] 3, but the + /// representative of goal 3 is [`GoalMemberId`] 42, then the output expression will be the input + /// physical expression with a child goal of 42. + /// + /// If no remapping needs to occur, this returns the same [`PhysicalExpression`] object via the + /// [`Cow`]. Otherwise, this function will create a new owned [`PhysicalExpression`]. 
+ async fn remap_physical_expr<'a>( + &self, + physical_expr: &'a PhysicalExpression, + ) -> Result, Infallible> { + use Child::*; + + let mut needs_remapping = false; + + let remapped_children = physical_expr + .children + .iter() + .map(|child| match child { + Singleton(goal_id) => { + let repr = self.find_repr_goal_member_id(*goal_id); + if repr != *goal_id { + needs_remapping = true; + } + Singleton(repr) + } + VarLength(goal_ids) => { + let reprs: Vec<_> = goal_ids + .iter() + .map(|id| { + let repr = self.find_repr_goal_member_id(*id); + if repr != *id { + needs_remapping = true; + } + repr + }) + .collect(); + VarLength(reprs) + } + }) + .collect(); + + Ok(if needs_remapping { + Cow::Owned(PhysicalExpression { + tag: physical_expr.tag.clone(), + data: physical_expr.data.clone(), + children: remapped_children, + }) + } else { + Cow::Borrowed(physical_expr) + }) + } + /// Remaps the [`GroupId`] component of a [`Goal`] to its representative [`GroupId`]. /// /// For example, if the [`Goal`] has a [`GroupID`] of 3, but the representative of group 3 is @@ -215,6 +316,18 @@ impl MemoryMemoHelper for MemoryMemo { .copied() .collect(); + // Combine goals from both groups. + let all_goals: HashMap<_, Vec<_>> = group1_info + .goals + .into_iter() + .chain(group2_info.goals) + .fold(HashMap::new(), |mut acc, (props, mut goals)| { + acc.entry(props) + .and_modify(|existing_goals| existing_goals.append(&mut goals)) + .or_insert(goals); + acc + }); + // Update the union-find structure. self.repr_group_id.merge(&group_id_1, &new_group_id); self.repr_group_id.merge(&group_id_2, &new_group_id); @@ -227,6 +340,7 @@ impl MemoryMemoHelper for MemoryMemo { // Create and save the new group. 
let new_group_info = GroupInfo { expressions: all_exprs, + goals: all_goals, logical_properties: group1_info.logical_properties, }; self.group_info.insert(new_group_id, new_group_info); @@ -306,7 +420,7 @@ impl MemoryMemoHelper for MemoryMemo { } } - /// Processes expressions that reference the merged groups. + /// Processes logical expressions that reference the merged groups. /// /// Returns the group pairs that need to be merged due to new equivalences. /// @@ -329,14 +443,14 @@ impl MemoryMemoHelper for MemoryMemo { /// b. Check if the remapped expression is different from the original. /// c. If different, handle the change and check for new equivalences. /// 4. Return any new group pairs that need to be merged. - async fn process_referencing_expressions( + async fn process_referencing_logical_exprs( &mut self, group_id_1: GroupId, group_id_2: GroupId, new_group_id: GroupId, ) -> Result, Infallible> { - // Merge the set of expressions that reference these two groups into one that references the - // new group. + // Merge the set of expressions that reference these two groups into one + // that references the new group. self.merge_referencing_exprs(group_id_1, group_id_2, new_group_id); // We need to clone here because we are modifying our `self` state inside the loop. @@ -371,16 +485,16 @@ impl MemoryMemoHelper for MemoryMemo { Ok(new_pending_merges) } - /// Consolidates merge operations into a comprehensive result. + /// Consolidates merge group operations into a comprehensive result. /// - /// This function takes a list of individual merge operations and consolidates them into a - /// complete picture of which groups were merged into each final representative. + /// This function takes a list of individual group merge operations and consolidates + /// them into a complete picture of which groups were merged into each final representative. /// /// It also handles cases where past representatives themselves are merged into newer ones. 
- async fn consolidate_merge_results( + async fn consolidate_merge_group_products( &self, merge_operations: Vec, - ) -> Result { + ) -> Result, Infallible> { // Collect operations into a map from representative to all merged groups. let mut consolidated_map: HashMap> = HashMap::new(); @@ -411,9 +525,200 @@ impl MemoryMemoHelper for MemoryMemo { }) .collect(); - Ok(MergeProducts { - group_merges, - goal_merges: vec![], - }) + Ok(group_merges) + } + + /// Processes goal merges and returns the results. + /// + /// This function performs the merge operations for goals based on the merged groups and + /// returns a list of `MergeGoalProduct` instances representing the merged goals. + /// + /// # Parameters + /// + /// * `group_merges` - A slice of `MergeGroupProduct` instances representing the merged groups. + /// + /// # Returns + /// + /// A vector of `MergeGoalProduct` instances representing the merged goals. + async fn merge_dependent_goals( + &mut self, + group_merges: &[MergeGroupProduct], + ) -> Result, Infallible> { + let mut goal_merges = Vec::new(); + + for merge_product in group_merges { + let new_group_id = merge_product.new_group_id; + let group_info = self.group_info.get_mut(&new_group_id).unwrap(); + + // Take related goals to the group to avoid borrowing issues. + let related_goals = std::mem::take(&mut group_info.goals); + let mut updated_goals = HashMap::new(); + + for (physical_props, merged_goals) in related_goals { + // Create a new representative goal for this physical property. + let new_goal = Goal(new_group_id, physical_props.clone()); + let new_goal_id = self.get_goal_id(&new_goal).await?; + + // Process each goal being merged. + for &old_goal_id in &merged_goals { + // Remove old goal info and extend new goal with its members. 
+ let old_goal_info = self.goal_info.remove(&old_goal_id).unwrap(); + let new_goal_info = self.goal_info.get_mut(&new_goal_id).unwrap(); + + new_goal_info.members.extend(old_goal_info.members); + self.goal_to_id.remove(&old_goal_info.goal); + + // Update referencing expressions index. + let refs = self + .goal_member_referencing_exprs_index + .remove(&GoalMemberId::GoalId(old_goal_id)) + .unwrap_or_default(); + + self.goal_member_referencing_exprs_index + .entry(GoalMemberId::GoalId(new_goal_id)) + .or_default() + .extend(refs); + + // Update the union-find structure for goal representatives. + self.repr_goal_id.merge(&old_goal_id, &new_goal_id); + } + + // Update the goals mapping for this physical property. + updated_goals.insert(physical_props, vec![new_goal_id]); + + // Record the merge operation. + goal_merges.push(MergeGoalProduct { + new_goal_id, + merged_goals, + }); + } + + // Update the group's goals with the consolidated mapping. + self.group_info.get_mut(&new_group_id).unwrap().goals = updated_goals; + } + + Ok(goal_merges) + } + + /// Consolidates merge physical expression operations into a comprehensive result. + /// + /// This function takes a list of individual physical expression merge operations + /// and consolidates them into a complete picture of which physical expressions + /// were merged into each final representative. + /// + /// It also handles cases where past representatives themselves are merged into newer ones. + async fn consolidate_merge_physical_expr_products( + &self, + merge_operations: Vec, + ) -> Result, Infallible> { + // Collect operations into a map from representative to all merged physical expressions. 
+ let mut consolidated_map: HashMap<PhysicalExpressionId, HashSet<PhysicalExpressionId>> = HashMap::new(); + + for op in merge_operations { + let current_repr = self.repr_physical_expr_id.find(&op.new_physical_expr_id); + + consolidated_map + .entry(current_repr) + .or_default() + .extend(op.merged_physical_exprs.iter().copied()); + + if op.new_physical_expr_id != current_repr { + consolidated_map + .entry(current_repr) + .or_default() + .insert(op.new_physical_expr_id); + } + } + + // Build the final list of merge products from the consolidated map. + let physical_expr_merges = consolidated_map + .into_iter() + .filter_map(|(repr, exprs)| { + (!exprs.is_empty()).then(|| MergePhysicalExprProduct { + new_physical_expr_id: repr, + merged_physical_exprs: exprs.into_iter().collect(), + }) + }) + .collect(); + + Ok(physical_expr_merges) + } + + /// Processes physical expression merges and returns the results. + /// + /// This function performs the merge operations for physical expressions based on + /// the merged goals and returns a list of `MergePhysicalExprProduct` instances + /// representing the merged physical expressions. + /// + /// # Parameters + /// + /// * `goal_merges` - A slice of `MergeGoalProduct` instances representing the merged goals. + /// + /// # Returns + /// + /// A vector of `MergePhysicalExprProduct` instances representing the merged physical expressions. + async fn merge_dependent_physical_exprs( + &mut self, + goal_merges: &[MergeGoalProduct], + ) -> Result<Vec<MergePhysicalExprProduct>, Infallible> { + let mut physical_expr_merges = Vec::new(); + let mut pending_dependencies = VecDeque::new(); + + // Initialize pending dependencies with the input goal merges. + for goal_merge in goal_merges { + pending_dependencies.push_back(GoalMemberId::GoalId(goal_merge.new_goal_id)); + } + + // Process dependencies in a loop to handle cascading merges. 
+ while let Some(current_dependency) = pending_dependencies.pop_front() { + let referencing_exprs = self + .goal_member_referencing_exprs_index + .get(&current_dependency) + .cloned() + .unwrap_or_default(); + + // Process each expression that references this dependency. + for reference_id in referencing_exprs { + // Remap the expression to use updated member references. + let prev_expr = self.materialize_physical_expr(reference_id).await?; + let new_expr = self.remap_physical_expr(&prev_expr).await?; + let new_id = self.get_physical_expr_id(&new_expr).await?; + + if reference_id != new_id { + // Remove old expression from indexes. + self.id_to_physical_expr.remove(&reference_id); + self.physical_expr_to_id.remove(&prev_expr); + + // Update the goal member referencing expressions index + // for the new physical expr. + let old_refs = self + .goal_member_referencing_exprs_index + .remove(&GoalMemberId::PhysicalExpressionId(reference_id)) + .unwrap_or_default(); + self.goal_member_referencing_exprs_index + .entry(GoalMemberId::PhysicalExpressionId(new_id)) + .or_default() + .extend(old_refs); + + // Update the representative ID for the expression. + self.repr_physical_expr_id.merge(&reference_id, &new_id); + + // This handles cascading effects where merging + // one expression affects others. + pending_dependencies.push_back(GoalMemberId::PhysicalExpressionId(new_id)); + + // Record the merge. + physical_expr_merges.push(MergePhysicalExprProduct { + new_physical_expr_id: new_id, + merged_physical_exprs: vec![reference_id], + }); + } + } + } + + // Consolidate the merge products to ensure no duplicates. + self.consolidate_merge_physical_expr_products(physical_expr_merges) + .await } } diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index ecd62309..3f79f2ac 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -1,14 +1,91 @@ //! The main implementation of the in-memory memo table. 
-use std::collections::VecDeque; - use super::{ Infallible, Memo, MemoryMemo, MergeProducts, Representative, helpers::MemoryMemoHelper, }; use crate::{cir::*, memo::memory::GroupInfo}; -use hashbrown::HashSet; +use hashbrown::{HashMap, HashSet}; +use std::collections::VecDeque; impl Memo for MemoryMemo { + async fn debug_dump(&self) -> Result<(), Infallible> { + println!("\n===== MEMO TABLE DUMP ====="); + println!("---- GROUPS ----"); + + // Get all group IDs and sort them for consistent output. + let mut group_ids: Vec<_> = self.group_info.keys().copied().collect(); + group_ids.sort_by_key(|id| id.0); + + for group_id in group_ids { + println!("Group {:?}:", group_id); + + // Print logical properties. + let group_info = self.group_info.get(&group_id).unwrap(); + println!(" Properties: {:?}", group_info.logical_properties); + + // Print expressions in group. + println!(" Expressions:"); + for expr_id in &group_info.expressions { + let expr = self.id_to_logical_expr.get(expr_id).unwrap(); + let repr_id = self.find_repr_logical_expr_id(*expr_id).await; + println!(" {:?} [{:?}]: {:?}", expr_id, repr_id, expr); + } + + // Print goals in this group (if any). + if !group_info.goals.is_empty() { + println!(" Goals:"); + for (props, goal_ids) in &group_info.goals { + for goal_id in goal_ids { + println!(" Goal {:?} with properties: {:?}", goal_id, props); + } + } + } + } + + println!("---- GOALS ----"); + + // Get all goal IDs and sort them. + let mut goal_ids: Vec<_> = self.goal_info.keys().copied().collect(); + goal_ids.sort_by_key(|id| id.0); + + for goal_id in goal_ids { + println!("Goal {:?}:", goal_id); + + // Print goal definition. + let goal_info = self.goal_info.get(&goal_id).unwrap(); + println!(" Definition: {:?}", goal_info.goal); + + // Print members. 
+ println!(" Members:"); + for member in &goal_info.members { + match member { + GoalMemberId::GoalId(sub_goal_id) => { + println!(" Sub-goal: {:?}", sub_goal_id); + } + GoalMemberId::PhysicalExpressionId(phys_id) => { + let expr = self.id_to_physical_expr.get(phys_id).unwrap(); + println!(" Physical expr {:?}: {:?}", phys_id, expr); + } + } + } + } + + println!("---- PHYSICAL EXPRESSIONS ----"); + + // Get all physical expression IDs and sort them + let mut phys_expr_ids: Vec<_> = self.id_to_physical_expr.keys().copied().collect(); + phys_expr_ids.sort_by_key(|id| id.0); + + for phys_id in phys_expr_ids { + let expr = self.id_to_physical_expr.get(&phys_id).unwrap(); + println!("Physical expr {:?}: {:?}", phys_id, expr); + } + + println!("===== END MEMO TABLE DUMP ====="); + + Ok(()) + } + async fn get_logical_properties( &self, group_id: GroupId, @@ -60,6 +137,7 @@ impl Memo for MemoryMemo { let group_id = self.next_group_id(); let group_info = GroupInfo { expressions: HashSet::from([logical_expr_id]), + goals: HashMap::new(), logical_properties: props.clone(), }; @@ -90,13 +168,33 @@ impl Memo for MemoryMemo { /// 3. When expressions become identical, their containing groups also need to be merged, /// creating a cascading effect. /// + /// 4. After all logical groups have been merged, we need to identify and merge all goals + /// that have become equivalent. Goals are considered equivalent when they reference the + /// same group and share identical logical properties. + /// + /// 5. When goals are merged, this triggers a cascading effect on physical expressions: + /// - Physical expressions that reference merged goals need updating. + /// - This may cause previously distinct physical expressions to become identical. + /// - These newly identical expressions must then be merged. + /// - Each merge may create more identical expressions, continuing the chain reaction + /// until the entire structure is consistent. 
+ /// /// To handle this complexity, we use an iterative approach: + /// PHASE 1: LOGICAL /// - We maintain a queue of group pairs that need to be merged. /// - For each pair, we create a new representative group. /// - We update all expressions that reference the merged groups. /// - If this creates new equivalences, we add the affected groups to the merge queue. /// - We continue until no more merges are needed. /// + /// PHASE 2: PHYSICAL + /// - We inspect the result of logical merges, and identify the goals to merge in + /// the corresponding group_info of the new representative group. + /// - For each group, we create a new representative goal. + /// - We update all expressions that reference the merged goals. + /// - If this creates new equivalences, we add the affected goals to the merge queue. + /// - We continue until no more merges are needed. + /// /// This approach ensures that all cascading effects are properly handled and the memo /// structure remains consistent after the merge. /// @@ -111,8 +209,6 @@ impl Memo for MemoryMemo { group_id_1: GroupId, group_id_2: GroupId, ) -> Result { - println!("Merging: {:?} and {:?}", group_id_1, group_id_2); - let mut merge_operations = vec![]; let mut pending_merges = VecDeque::from(vec![(group_id_1, group_id_2)]); @@ -124,7 +220,7 @@ impl Memo for MemoryMemo { continue; } - // Perform the group merge, creating a new representative + // Perform the group merge, creating a new representative. let (new_group_id, merge_product) = self.merge_group_pair(group_id_1, group_id_2).await?; merge_operations.push(merge_product); @@ -132,59 +228,84 @@ impl Memo for MemoryMemo { // Process expressions that reference the merged groups, // which may trigger additional group merges. 
let new_pending_merges = self - .process_referencing_expressions(group_id_1, group_id_2, new_group_id) + .process_referencing_logical_exprs(group_id_1, group_id_2, new_group_id) .await?; pending_merges.extend(new_pending_merges); } - // Consolidate the merge results by replacing the incremental merges + // Consolidate the merge products by replacing the incremental merges // with consolidated results that show the full picture. - self.consolidate_merge_results(merge_operations).await - } - - async fn get_best_optimized_physical_expr( - &self, - _goal_id: GoalId, - ) -> Result, Infallible> { - todo!() + let group_merges = self + .consolidate_merge_group_products(merge_operations) + .await?; + + // Now handle goal merges: we do not need to pass any extra parameters as + // the goals to merge are gathered in the `goals` member of each new + // representative group. + let goal_merges = self.merge_dependent_goals(&group_merges).await?; + + // Finally, we need to recursively merge the physical expressions that are + // dependent on the merged goals (and the recursively merged expressions themselves). 
+ let expr_merges = self.merge_dependent_physical_exprs(&goal_merges).await?; + + Ok(MergeProducts { + group_merges, + goal_merges, + expr_merges, + }) } async fn get_all_goal_members( &self, - _goal_id: GoalId, - ) -> Result, Infallible> { - todo!() - } + goal_id: GoalId, + ) -> Result, Infallible> { + let repr_goal_id = self.find_repr_goal_id(goal_id).await?; + + let members = self + .goal_info + .get(&repr_goal_id) + .expect("Goal not found in memo table") + .members + .iter() + .map(|&member_id| self.find_repr_goal_member_id(member_id)) + .collect(); - async fn add_goal_member( - &mut self, - _goal_id: GoalId, - _member: GoalMemberId, - ) -> Result { - todo!() + Ok(members) } - async fn update_physical_expr_cost( + async fn add_goal_member( &mut self, - _physical_expr_id: PhysicalExpressionId, - _new_cost: Cost, + goal_id: GoalId, + member_id: GoalMemberId, ) -> Result { - todo!() - } + // We call `get_all_goal_members` to ensure we only have the representative IDs + // in the set. This is important because members may have been merged with another. 
+ let mut current_members = self.get_all_goal_members(goal_id).await?; + + let repr_member_id = self.find_repr_goal_member_id(member_id); + let added = current_members.insert(repr_member_id); + + if added { + let repr_goal_id = self.find_repr_goal_id(goal_id).await?; + self.goal_info + .get_mut(&repr_goal_id) + .expect("Goal not found in memo table") + .members + .insert(repr_member_id); + } - async fn get_physical_expr_cost( - &self, - _physical_expr_id: PhysicalExpressionId, - ) -> Result, Infallible> { - todo!() + Ok(added) } } #[cfg(test)] pub mod tests { use super::*; - use crate::cir::{Child, OperatorData}; + use crate::{ + cir::{Child, OperatorData}, + memo::Materialize, + }; pub async fn lookup_or_insert( memo: &mut impl Memo, @@ -276,6 +397,28 @@ pub mod tests { assert_eq!(retrieve(&memo, g4).await, vec![0, 1, 2, 3]); } + async fn create_goal( + memo: &mut MemoryMemo, + group_id: GroupId, + props: PhysicalProperties, + ) -> GoalId { + let goal = Goal(group_id, props); + memo.get_goal_id(&goal).await.unwrap() + } + + async fn create_physical_expr( + memo: &mut MemoryMemo, + tag: &str, + children: Vec, + ) -> PhysicalExpressionId { + let expr = PhysicalExpression { + tag: tag.to_string(), + data: vec![], + children: children.into_iter().map(Child::Singleton).collect(), + }; + memo.get_physical_expr_id(&expr).await.unwrap() + } + #[tokio::test] async fn test_recursive_merge() { let mut memo = MemoryMemo::default(); @@ -553,4 +696,227 @@ pub mod tests { "Merged group should contain exactly one copy of each expression" ); } + + #[tokio::test] + async fn test_goal_merge() { + let mut memo = MemoryMemo::default(); + + // Create two groups. + let g1 = lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = lookup_or_insert(&mut memo, 2, vec![]).await; + + // Create goals with the same properties. 
+ let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props).await; + + // Add a member to each goal. + let p1 = create_physical_expr(&mut memo, "a", vec![]).await; + let p2 = create_physical_expr(&mut memo, "b", vec![]).await; + + memo.add_goal_member(goal1, GoalMemberId::PhysicalExpressionId(p1)) + .await + .unwrap(); + memo.add_goal_member(goal2, GoalMemberId::PhysicalExpressionId(p2)) + .await + .unwrap(); + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // Verify goal merges in the result. + assert!( + !merge_result.goal_merges.is_empty(), + "Goal merges should not be empty" + ); + + // Get the new goal and verify members. + let new_goal_id = merge_result.goal_merges[0].new_goal_id; + let members = memo.get_all_goal_members(new_goal_id).await.unwrap(); + + assert_eq!(members.len(), 2, "Merged goal should have two members"); + assert!(members.contains(&GoalMemberId::PhysicalExpressionId(p1))); + assert!(members.contains(&GoalMemberId::PhysicalExpressionId(p2))); + + // Verify representatives. + assert_eq!(memo.find_repr_goal_id(goal1).await.unwrap(), new_goal_id); + assert_eq!(memo.find_repr_goal_id(goal2).await.unwrap(), new_goal_id); + } + + #[tokio::test] + async fn test_physical_expr_merge() { + let mut memo = MemoryMemo::default(); + + // Create two groups and goals. + let g1 = lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = lookup_or_insert(&mut memo, 2, vec![]).await; + + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props).await; + + // Create physical expressions referencing these goals. + let p1 = create_physical_expr(&mut memo, "op", vec![GoalMemberId::GoalId(goal1)]).await; + let p2 = create_physical_expr(&mut memo, "op", vec![GoalMemberId::GoalId(goal2)]).await; + + // Verify they're different before merge. 
+ assert_ne!(p1, p2); + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // Verify physical expression merges in the result. + assert!( + !merge_result.expr_merges.is_empty(), + "Physical expr merges should not be empty" + ); + + // Get the new physical expression id. + let new_expr_id = merge_result.expr_merges[0].new_physical_expr_id; + + // Verify representatives. + assert_eq!( + memo.find_repr_physical_expr_id(p1).await.unwrap(), + new_expr_id + ); + assert_eq!( + memo.find_repr_physical_expr_id(p2).await.unwrap(), + new_expr_id + ); + } + + #[tokio::test] + async fn test_recursive_physical_expr_merge() { + let mut memo = MemoryMemo::default(); + + // Create groups. + let g1 = lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = lookup_or_insert(&mut memo, 2, vec![]).await; + + // Create goals. + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props.clone()).await; + + // Create first level physical expressions. + let p1 = create_physical_expr(&mut memo, "op1", vec![GoalMemberId::GoalId(goal1)]).await; + let p2 = create_physical_expr(&mut memo, "op1", vec![GoalMemberId::GoalId(goal2)]).await; + + // Create second level physical expressions. + let p3 = create_physical_expr( + &mut memo, + "op2", + vec![GoalMemberId::PhysicalExpressionId(p1)], + ) + .await; + let p4 = create_physical_expr( + &mut memo, + "op2", + vec![GoalMemberId::PhysicalExpressionId(p2)], + ) + .await; + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // There should be multiple levels of physical expr merges. + assert!( + merge_result.expr_merges.len() >= 2, + "Should have at least 2 levels of physical expr merges" + ); + + // Verify representatives for both levels. 
+ let p1_repr = memo.find_repr_physical_expr_id(p1).await.unwrap(); + let p2_repr = memo.find_repr_physical_expr_id(p2).await.unwrap(); + let p3_repr = memo.find_repr_physical_expr_id(p3).await.unwrap(); + let p4_repr = memo.find_repr_physical_expr_id(p4).await.unwrap(); + + assert_eq!( + p1_repr, p2_repr, + "First level expressions should share representative" + ); + assert_eq!( + p3_repr, p4_repr, + "Second level expressions should share representative" + ); + } + + #[tokio::test] + async fn test_merge_products_completeness() { + let mut memo = MemoryMemo::default(); + + // Create a three-level structure. + let g1 = super::tests::lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = super::tests::lookup_or_insert(&mut memo, 2, vec![]).await; + + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props.clone()).await; + + // Level 1 + let p1 = create_physical_expr(&mut memo, "leaf1", vec![GoalMemberId::GoalId(goal1)]).await; + let p2 = create_physical_expr(&mut memo, "leaf1", vec![GoalMemberId::GoalId(goal2)]).await; + + // Level 2 + let p3 = create_physical_expr( + &mut memo, + "mid1", + vec![GoalMemberId::PhysicalExpressionId(p1)], + ) + .await; + let p4 = create_physical_expr( + &mut memo, + "mid1", + vec![GoalMemberId::PhysicalExpressionId(p2)], + ) + .await; + + // Level 3 + let p5 = create_physical_expr( + &mut memo, + "top1", + vec![GoalMemberId::PhysicalExpressionId(p3)], + ) + .await; + let p6 = create_physical_expr( + &mut memo, + "top1", + vec![GoalMemberId::PhysicalExpressionId(p4)], + ) + .await; + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // Verify structure of the merge products. 
+ assert_eq!( + merge_result.group_merges.len(), + 1, + "Should have 1 group merge" + ); + assert_eq!( + merge_result.goal_merges.len(), + 1, + "Should have 1 goal merge" + ); + assert_eq!( + merge_result.expr_merges.len(), + 3, + "Should have 3 expression merges for the three levels" + ); + + // Verify all expressions share the same representatives at their respective levels. + assert_eq!( + memo.find_repr_physical_expr_id(p1).await.unwrap(), + memo.find_repr_physical_expr_id(p2).await.unwrap(), + ); + assert_eq!( + memo.find_repr_physical_expr_id(p3).await.unwrap(), + memo.find_repr_physical_expr_id(p4).await.unwrap(), + ); + assert_eq!( + memo.find_repr_physical_expr_id(p5).await.unwrap(), + memo.find_repr_physical_expr_id(p6).await.unwrap(), + ); + } } diff --git a/optd/src/memo/memory/materialize.rs b/optd/src/memo/memory/materialize.rs index 711f37c4..aa351073 100644 --- a/optd/src/memo/memory/materialize.rs +++ b/optd/src/memo/memory/materialize.rs @@ -2,11 +2,12 @@ //! //! See the documentation for [`Materialize`] for more information. -use super::{Infallible, MemoryMemo, helpers::MemoryMemoHelper}; +use super::{GoalInfo, Infallible, MemoryMemo, helpers::MemoryMemoHelper}; use crate::{ cir::*, memo::{Materialize, Representative}, }; +use hashbrown::HashSet; impl Materialize for MemoryMemo { async fn get_logical_expr_id( @@ -24,7 +25,7 @@ impl Materialize for MemoryMemo { // Otherwise, create a new entry in the memo table (slow path). let expr_id = self.next_logical_expression_id(); - // Update the logical expression to group index. + // Update the group referencing expression index. remapped_expr .children .iter() @@ -69,8 +70,21 @@ impl Materialize for MemoryMemo { // Otherwise, create a new entry in the memo table (slow path). 
let goal_id = self.next_goal_id(); - self.id_to_goal.insert(goal_id, goal.clone().into_owned()); - self.goal_to_id.insert(goal.into_owned(), goal_id); + let goal_info = GoalInfo { + goal: goal.as_ref().clone(), + members: HashSet::new(), + }; + + self.goal_info.insert(goal_id, goal_info); + self.goal_to_id.insert(goal.clone().into_owned(), goal_id); + + // Connect the goal to its group. + let Goal(group_id, props) = goal.into_owned(); + self.group_info + .get_mut(&group_id) + .expect("Group not found in memo table") + .goals + .insert(props, vec![goal_id]); Ok(goal_id) } @@ -78,23 +92,61 @@ impl Materialize for MemoryMemo { async fn materialize_goal(&self, goal_id: GoalId) -> Result { let repr_goal_id = self.find_repr_goal_id(goal_id).await?; Ok(self - .id_to_goal + .goal_info .get(&repr_goal_id) - .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_goal_id)) + .expect("Goal not found in memo table") + .goal .clone()) } async fn get_physical_expr_id( &mut self, - _physical_expr: &PhysicalExpression, + physical_expr: &PhysicalExpression, ) -> Result { - todo!() + use Child::*; + + // Check if the expression is already in the memo table (fast path). + let remapped_expr = self.remap_physical_expr(physical_expr).await?; + if let Some(&expr_id) = self.physical_expr_to_id.get(remapped_expr.as_ref()) { + return Ok(expr_id); + } + + // Otherwise, create a new entry in the memo table (slow path). + let expr_id = self.next_physical_expression_id(); + + // Update the goal member referencing expression index. + remapped_expr + .children + .iter() + .flat_map(|child| match child { + Singleton(member_id) => vec![*member_id], + VarLength(member_ids) => member_ids.clone(), + }) + .for_each(|member_id| { + self.goal_member_referencing_exprs_index + .entry(member_id) + .or_default() + .insert(expr_id); + }); + + // Update the physical expression ID indexes. 
+ self.id_to_physical_expr + .insert(expr_id, remapped_expr.clone().into_owned()); + self.physical_expr_to_id + .insert(remapped_expr.into_owned(), expr_id); + + Ok(expr_id) } async fn materialize_physical_expr( &self, - _physical_expr_id: PhysicalExpressionId, + physical_expr_id: PhysicalExpressionId, ) -> Result { - todo!() + let repr_expr_id = self.find_repr_physical_expr_id(physical_expr_id).await?; + Ok(self + .id_to_physical_expr + .get(&repr_expr_id) + .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_expr_id)) + .clone()) } } diff --git a/optd/src/memo/memory/mod.rs b/optd/src/memo/memory/mod.rs index 7f8b813d..a09dc405 100644 --- a/optd/src/memo/memory/mod.rs +++ b/optd/src/memo/memory/mod.rs @@ -31,7 +31,7 @@ pub struct MemoryMemo { // Goals. /// Key is always a representative ID. - id_to_goal: HashMap, + goal_info: HashMap, /// Each representative goal is mapped to its id, for faster lookups. goal_to_id: HashMap, @@ -41,12 +41,23 @@ pub struct MemoryMemo { /// Each representative expression is mapped to its id, for faster lookups. logical_expr_to_id: HashMap, + /// Key is always a representative ID. + id_to_physical_expr: HashMap, + /// Each representative expression is mapped to its id, for faster lookups. + physical_expr_to_id: HashMap, + // Indexes: only deal with representative IDs, but speeds up most queries. /// To speed up expr->group lookup, we maintain a mapping from logical expression IDs to group IDs. logical_id_to_group_index: HashMap, + /// To speed up recursive merges, we maintain a mapping from group IDs to all logical expression IDs - /// that contain a reference to this group. The value logical_expr_ids may *NOT* be a representative ID. + /// that contain a reference to this group. + /// The value logical_expr_ids may *NOT* be a representative ID. 
group_referencing_exprs_index: HashMap>, + /// To speed up recursive merges, we maintain a mapping from goal member IDs to all physical expression IDs + /// that contain a reference to this goal member. + /// The value physical_expr_ids may *NOT* be a representative ID. + goal_member_referencing_exprs_index: HashMap>, /// The shared next unique id to be used for goals, groups, logical expressions, and physical expressions. next_shared_id: i64, @@ -60,9 +71,23 @@ pub struct MemoryMemo { /// Information about a group: /// - All logical expressions in this group (always representative IDs). +/// - All goals that have this group as objective, for each physical properties. /// - Logical properties of this group. #[derive(Clone, Debug)] struct GroupInfo { expressions: HashSet, + // We make the key a Vec so that we can accumulate + // goals to merge while merging groups. Outside of merging, + // the value is always a single goal. + goals: HashMap>, logical_properties: LogicalProperties, } + +/// Information about a goal: +/// - The goal (group + properties), always representative. +/// - The members of the goal (physical expression IDs or other (sub)goal IDs), may *NOT* be representative. +#[derive(Clone, Debug)] +struct GoalInfo { + goal: Goal, + members: HashSet, +} diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index fa619da7..fe7ba5ea 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -25,6 +25,16 @@ pub struct MergeGoalProduct { pub merged_goals: Vec, } +/// Result of merging two physical expressions. +#[derive(Debug)] +pub struct MergePhysicalExprProduct { + /// ID of the new physical expression. + pub new_physical_expr_id: PhysicalExpressionId, + + /// Physical expressions that were merged. + pub merged_physical_exprs: Vec, +} + /// Results of merge operations, including group and goal merges. #[derive(Debug, Default)] pub struct MergeProducts { @@ -33,6 +43,9 @@ pub struct MergeProducts { /// Goal merge results. 
pub goal_merges: Vec, + + /// Physical expression merge results. + pub expr_merges: Vec, } /// Base trait defining a shared implemention-defined error type for all memo-related traits. @@ -124,6 +137,9 @@ pub trait Representative: MemoBase { /// expressions, and finding representative nodes of the union-find substructures. #[trait_variant::make(Send)] pub trait Memo: Representative + Materialize + Sync + 'static { + /// Prints the contents of the memo table to the console for debugging purposes. + async fn debug_dump(&self) -> Result<(), Self::MemoError>; + /// Retrieves logical properties for a group ID. /// /// # Parameters @@ -195,68 +211,29 @@ pub trait Memo: Representative + Materialize + Sync + 'static { // Physical expression and goal operations. // - /// Gets the best optimized physical expression ID for a goal ID. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to retrieve the best expression for. - /// - /// # Returns - /// The ID of the lowest-cost physical implementation found so far for the goal, - /// along with its cost. Returns None if no optimized expression exists. - async fn get_best_optimized_physical_expr( - &self, - goal_id: GoalId, - ) -> Result, Self::MemoError>; - /// Gets all members of a goal, which can be physical expressions or other goals. /// /// # Parameters /// * `goal_id` - ID of the goal to retrieve members from. /// /// # Returns - /// A vector of goal members, each being either a physical expression ID or another goal ID. + /// A set of goal members, each being either a physical expression ID or another goal ID. async fn get_all_goal_members( &self, goal_id: GoalId, - ) -> Result, Self::MemoError>; + ) -> Result, Self::MemoError>; /// Adds a member to a goal. /// /// # Parameters /// * `goal_id` - ID of the goal to add the member to. - /// * `member` - The member to add, either a physical expression ID or another goal ID. + /// * `member_id` - The member to add, either a physical expression ID or another goal ID. 
/// /// # Returns /// True if the member was added to the goal, or false if it already existed. async fn add_goal_member( &mut self, goal_id: GoalId, - member: GoalMemberId, - ) -> Result; - - /// Updates the cost of a physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to update. - /// * `new_cost` - New cost to assign to the physical expression. - /// - /// # Returns - /// Whether the cost of the expression has improved. - async fn update_physical_expr_cost( - &mut self, - physical_expr_id: PhysicalExpressionId, - new_cost: Cost, + member_id: GoalMemberId, ) -> Result; - - /// Gets the cost of a physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to retrieve the cost for. - /// - /// # Returns - /// The cost of the physical expression, or None if it doesn't exist. - async fn get_physical_expr_cost( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> Result, Self::MemoError>; } diff --git a/optd/src/optimizer/handlers.rs b/optd/src/optimizer/handlers.rs index a022945d..4d696c64 100644 --- a/optd/src/optimizer/handlers.rs +++ b/optd/src/optimizer/handlers.rs @@ -1,7 +1,4 @@ -use super::{ - JobId, Optimizer, TaskId, - jobs::{CostedContinuation, LogicalContinuation}, -}; +use super::{JobId, Optimizer, TaskId, jobs::LogicalContinuation}; use crate::{ cir::*, memo::Memo, @@ -132,31 +129,40 @@ impl Optimizer { /// * `Result<(), Error>` - Success or error during processing. pub(super) async fn process_new_physical_partial( &mut self, - _plan: PartialPhysicalPlan, - _goal_id: GoalId, - _job_id: JobId, + plan: PartialPhysicalPlan, + goal_id: GoalId, + job_id: JobId, ) -> Result<(), M::MemoError> { - todo!() - } + use GoalMemberId::*; - /// This method handles fully optimized physical expressions with cost information. - /// - /// When a new optimized expression is found, it's added to the memo. 
If it becomes - /// the new best expression for its goal, continuations are notified and and clients - /// receive the corresponding egested plan. - /// - /// # Parameters - /// * `expression_id` - ID of the physical expression to process. - /// * `cost` - Cost information for the expression. - /// - /// # Returns - /// * `Result<(), Error>` - Success or error during processing. - pub(super) async fn process_new_costed_physical( - &mut self, - _expression_id: PhysicalExpressionId, - _cost: Cost, - ) -> Result<(), M::MemoError> { - todo!() + let goal_id = self.memo.find_repr_goal_id(goal_id).await?; + let related_task_id = self.get_related_task_id(job_id); + + let member_id = self.probe_ingest_physical_plan(&plan).await?; + let added = self.memo.add_goal_member(goal_id, member_id).await?; + + match member_id { + PhysicalExpressionId(_) => { + // TODO: Here we would launch costing tasks based on the design. + } + GoalId(goal_id) => { + if added { + // Optimize the new sub-goal and add to task graph. + let sub_optimize_task_id = self.ensure_optimize_goal_task(goal_id).await?; + + self.get_optimize_goal_task_mut(sub_optimize_task_id) + .unwrap() + .optimize_goal_out + .insert(related_task_id); + self.get_optimize_goal_task_mut(related_task_id) + .unwrap() + .optimize_goal_in + .insert(sub_optimize_task_id); + } + } + } + + Ok(()) } /// This method handles group creation for expressions with derived properties @@ -201,25 +207,6 @@ impl Optimizer { .await } - /// Registers a continuation for receiving optimized physical expressions for a goal. - /// The continuation will be notified about the best existing expression and any better ones found. - /// - /// # Parameters - /// * `goal` - The goal to subscribe to. - /// * `continuation` - Continuation to call when new optimized expressions are found. - /// * `job_id` - ID of the job that initiated this request. - /// - /// # Returns - /// * `Result<(), Error>` - Success or error during processing. 
- pub(super) async fn process_goal_subscription( - &mut self, - _goal: &Goal, - _continuation: CostedContinuation, - _job_id: JobId, - ) -> Result<(), M::MemoError> { - todo!() - } - /// Retrieves the logical properties for the given group from the memo /// and sends them back to the requestor through the provided channel. /// diff --git a/optd/src/optimizer/hir_cir/from_cir.rs b/optd/src/optimizer/hir_cir/from_cir.rs index 2122b951..35c26765 100644 --- a/optd/src/optimizer/hir_cir/from_cir.rs +++ b/optd/src/optimizer/hir_cir/from_cir.rs @@ -4,34 +4,57 @@ use crate::cir::*; use crate::dsl::analyzer::hir::{ self, CoreData, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, }; -use std::sync::Arc; -/// Converts a [`PartialLogicalPlan`] into a [`Value`]. -pub fn partial_logical_to_value(plan: &PartialLogicalPlan) -> Value { +/// Converts a [`PartialLogicalPlan`] into a [`Value`], optionally associating it with a [`GroupId`]. +/// +/// If the plan is unmaterialized, it will be represented as a [`Value`] containing a group reference. +/// If the plan is materialized, the resulting [`Value`] contains the operator data, and may include +/// a group ID if provided. +pub fn partial_logical_to_value(plan: &PartialLogicalPlan, group_id: Option) -> Value { + use Child::*; use Materializable::*; match plan { - PartialLogicalPlan::UnMaterialized(group_id) => { - // For unmaterialized logical operators, we create a `Value` with the group ID. - Value::new(CoreData::Logical(UnMaterialized(hir::GroupId(group_id.0)))) + PartialLogicalPlan::UnMaterialized(gid) => { + // Represent unmaterialized logical operators using their group ID. + Value::new(CoreData::Logical(UnMaterialized(hir::GroupId(gid.0)))) } PartialLogicalPlan::Materialized(node) => { - // For materialized logical operators, we create a `Value` with the operator data. + // Convert each child to a Value recursively. 
+ let children: Vec = node + .children + .iter() + .map(|child| match child { + Singleton(item) => partial_logical_to_value(item, None), + VarLength(items) => Value::new(CoreData::Array( + items + .iter() + .map(|item| partial_logical_to_value(item, None)) + .collect(), + )), + }) + .collect(); + let operator = Operator { tag: node.tag.clone(), data: convert_operator_data_to_values(&node.data), - children: convert_children_to_values(&node.children, partial_logical_to_value), + children, }; - Value::new(CoreData::Logical(Materialized(LogicalOp::logical( - operator, - )))) + let logical_op = if let Some(gid) = group_id { + LogicalOp::stored_logical(operator, cir_group_id_to_hir(&gid)) + } else { + LogicalOp::logical(operator) + }; + + Value::new(CoreData::Logical(Materialized(logical_op))) } } } /// Converts a [`PartialPhysicalPlan`] into a [`Value`]. pub fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { + use Child::*; use Materializable::*; match plan { @@ -45,7 +68,19 @@ pub fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { let operator = Operator { tag: node.tag.clone(), data: convert_operator_data_to_values(&node.data), - children: convert_children_to_values(&node.children, partial_physical_to_value), + children: node + .children + .iter() + .map(|child| match child { + Singleton(item) => partial_physical_to_value(item), + VarLength(items) => Value::new(CoreData::Array( + items + .iter() + .map(|item| partial_physical_to_value(item)) + .collect(), + )), + }) + .collect(), }; Value::new(CoreData::Physical(Materialized(PhysicalOp::physical( @@ -87,23 +122,6 @@ fn cir_group_id_to_hir(group_id: &GroupId) -> hir::GroupId { hir::GroupId(group_id.0) } -/// A generic function to convert a slice of children into a vector of [`Value`]s. 
-fn convert_children_to_values(children: &[Child>], converter: F) -> Vec -where - F: Fn(&T) -> Value, - T: 'static, -{ - children - .iter() - .map(|child| match child { - Child::Singleton(item) => converter(item), - Child::VarLength(items) => Value::new(CoreData::Array( - items.iter().map(|item| converter(item)).collect(), - )), - }) - .collect() -} - /// Converts a slice of [`OperatorData`] into a vector of [`Value`]s. fn convert_operator_data_to_values(data: &[OperatorData]) -> Vec { data.iter().map(operator_data_to_value).collect() diff --git a/optd/src/optimizer/jobs/execute.rs b/optd/src/optimizer/jobs/execute.rs index 0a795517..130c3d82 100644 --- a/optd/src/optimizer/jobs/execute.rs +++ b/optd/src/optimizer/jobs/execute.rs @@ -1,4 +1,4 @@ -use super::{CostedContinuation, JobId, LogicalContinuation}; +use super::{JobId, LogicalContinuation}; use crate::{ cir::*, dsl::{ @@ -9,12 +9,10 @@ use crate::{ optimizer::{ EngineProduct, Optimizer, OptimizerMessage, hir_cir::{ - from_cir::{ - partial_logical_to_value, partial_physical_to_value, physical_properties_to_value, - }, + from_cir::{partial_logical_to_value, physical_properties_to_value}, into_cir::{ - hir_goal_to_cir, hir_group_id_to_cir, value_to_cost, value_to_logical_properties, - value_to_partial_logical, value_to_partial_physical, + hir_group_id_to_cir, value_to_logical_properties, value_to_partial_logical, + value_to_partial_physical, }, }, }, @@ -49,6 +47,11 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + None, // FIXME: This is correct, however in the DSL right now the parameters + // of the derive function is a *, to allow its children to be * too. However, + // since the group is not yet created, it isn't really stored yet it. + // Hence, the input should be a different type, and ideally be made consistent + // with how costing is handled (i.e. with $ and *). 
); let message_tx = self.message_tx.clone(); @@ -99,6 +102,7 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + Some(group_id), ); let message_tx = self.message_tx.clone(); @@ -143,6 +147,9 @@ impl Optimizer { ) -> Result<(), M::MemoError> { use EngineProduct::*; + let Goal(group_id, physical_props) = self.memo.materialize_goal(goal_id).await?; + let properties = physical_properties_to_value(&physical_props); + let engine = self.init_engine(); let plan = partial_logical_to_value( &self @@ -150,11 +157,9 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + Some(group_id), ); - let Goal(_, physical_props) = self.memo.materialize_goal(goal_id).await?; - let properties = physical_properties_to_value(&physical_props); - let message_tx = self.message_tx.clone(); tokio::spawn(async move { let response = engine @@ -175,55 +180,19 @@ impl Optimizer { Ok(()) } - /// Executes a job to compute the cost of a physical expression. - /// - /// This creates an engine instance and launches the cost calculation process - /// for the specified physical expression. - /// - /// # Parameters - /// * `expression_id`: The ID of the physical expression to cost. - /// * `job_id`: The ID of the job to be executed. - pub(super) async fn execute_cost_expression( - &self, - expression_id: PhysicalExpressionId, - job_id: JobId, - ) -> Result<(), M::MemoError> { - use EngineProduct::*; - - let engine = self.init_engine(); - let plan = partial_physical_to_value(&self.egest_partial_plan(expression_id).await?); - - let message_tx = self.message_tx.clone(); - tokio::spawn(async move { - let response = engine - .launch( - "cost", - vec![plan], - Arc::new(move |cost| { - Box::pin( - async move { NewCostedPhysical(expression_id, value_to_cost(&cost)) }, - ) - }), - ) - .await; - - Self::process_engine_response(job_id, message_tx, response).await; - }); - - Ok(()) - } - /// Executes a job to continue processing with a logical expression result. 
/// /// This materializes the logical expression and passes it to the continuation. /// /// # Parameters /// * `expression_id`: The ID of the logical expression to continue with. + /// * `group_id`: The ID of the group to which the expression belongs. /// * `k`: The continuation function to be called with the materialized plan. /// * `job_id`: The ID of the job to be executed. pub(super) async fn execute_continue_with_logical( &self, expression_id: LogicalExpressionId, + group_id: GroupId, k: LogicalContinuation, job_id: JobId, ) -> Result<(), M::MemoError> { @@ -236,6 +205,7 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + Some(group_id), ); let message_tx = self.message_tx.clone(); @@ -251,32 +221,6 @@ impl Optimizer { Ok(()) } - /// Executes a job to continue processing with an optimized physical expression result. - /// - /// This materializes the physical expression and passes it along with its cost - /// to the continuation. - /// - /// # Parameters - /// * `expression_id`: The ID of the physical expression to continue with. - /// * `k`: The continuation function to be called with the materialized plan. - /// * `job_id`: The ID of the job to be executed. - pub(super) async fn execute_continue_with_costed( - &self, - expression_id: PhysicalExpressionId, - k: CostedContinuation, - job_id: JobId, - ) -> Result<(), M::MemoError> { - let plan = partial_physical_to_value(&self.egest_partial_plan(expression_id).await?); - - let message_tx = self.message_tx.clone(); - tokio::spawn(async move { - let response = k.0(plan).await; - Self::process_engine_response(job_id, message_tx, response).await; - }); - - Ok(()) - } - /// Helper function to process the engine response and send it to the optimizer. 
/// /// Handles `YieldGroup` and `YieldGoal` responses by sending subscription messages @@ -300,10 +244,7 @@ impl Optimizer { SubscribeGroup(hir_group_id_to_cir(&group_id), LogicalContinuation(k)), job_id, ), - YieldGoal(goal, k) => OptimizerMessage::product( - SubscribeGoal(hir_goal_to_cir(&goal), CostedContinuation(k)), - job_id, - ), + YieldGoal(_, _) => todo!("Decide what to do here depending on the cost model"), }; engine_tx diff --git a/optd/src/optimizer/jobs/manage.rs b/optd/src/optimizer/jobs/manage.rs index ee8d1d10..f517c583 100644 --- a/optd/src/optimizer/jobs/manage.rs +++ b/optd/src/optimizer/jobs/manage.rs @@ -63,15 +63,8 @@ impl Optimizer { self.execute_implementation_rule(rule_name, expression_id, goal_id, job_id) .await?; } - CostExpression(expression_id) => { - self.execute_cost_expression(expression_id, job_id).await?; - } - ContinueWithLogical(expression_id, k) => { - self.execute_continue_with_logical(expression_id, k, job_id) - .await?; - } - ContinueWithCosted(expression_id, k) => { - self.execute_continue_with_costed(expression_id, k, job_id) + ContinueWithLogical(expression_id, group_id, k) => { + self.execute_continue_with_logical(expression_id, group_id, k, job_id) .await?; } } diff --git a/optd/src/optimizer/jobs/mod.rs b/optd/src/optimizer/jobs/mod.rs index 409ab3ba..7774ecaf 100644 --- a/optd/src/optimizer/jobs/mod.rs +++ b/optd/src/optimizer/jobs/mod.rs @@ -1,9 +1,6 @@ use super::{EngineProduct, TaskId}; use crate::{ - cir::{ - GoalId, GroupId, ImplementationRule, LogicalExpressionId, PhysicalExpressionId, - TransformationRule, - }, + cir::{GoalId, GroupId, ImplementationRule, LogicalExpressionId, TransformationRule}, dsl::{ analyzer::hir::Value, engine::{Continuation, EngineResponse}, @@ -29,10 +26,6 @@ pub(crate) struct Job(pub TaskId, pub JobKind); #[derive(Clone)] pub(crate) struct LogicalContinuation(Continuation>); -/// Represents a continuation for processing costed physical expressions. 
-#[derive(Clone)] -pub(crate) struct CostedContinuation(Continuation>); - /// Enumeration of different types of jobs in the optimizer. /// /// Each variant represents a specific optimization operation that can be @@ -55,26 +48,11 @@ pub(crate) enum JobKind { /// /// This job generates physical implementations of a logical expression /// based on specific implementation strategies. - #[allow(dead_code)] ImplementExpression(ImplementationRule, LogicalExpressionId, GoalId), - /// Starts computing the cost of a physical expression. - /// - /// This job estimates the execution cost of a physical implementation - /// to aid in selecting the optimal plan. - #[allow(dead_code)] - CostExpression(PhysicalExpressionId), - /// Continues processing with a logical expression result. /// /// This job represents a continuation-passing-style callback for /// handling the result of a logical expression operation. - ContinueWithLogical(LogicalExpressionId, LogicalContinuation), - - /// Continues processing with an optimized expression result. - /// - /// This job represents a continuation-passing-style callback for - /// handling the result of an optimized physical expression operation. - #[allow(dead_code)] - ContinueWithCosted(PhysicalExpressionId, CostedContinuation), + ContinueWithLogical(LogicalExpressionId, GroupId, LogicalContinuation), } diff --git a/optd/src/optimizer/memo_io/egest.rs b/optd/src/optimizer/memo_io/egest.rs index 4c6a4fba..0a87ca47 100644 --- a/optd/src/optimizer/memo_io/egest.rs +++ b/optd/src/optimizer/memo_io/egest.rs @@ -7,6 +7,7 @@ use async_recursion::async_recursion; use futures::future::try_join_all; use std::sync::Arc; +// TODO: Until costing is resolved, this file is not used (and just left here for reference) impl Optimizer { /// Recursively transforms a physical expression ID in the memo into a complete physical plan. /// @@ -21,7 +22,6 @@ impl Optimizer { /// * `Ok(None)` if any goal ID lacks a best expression ID. 
/// * `Err(Error)` if a memo operation fails. #[async_recursion] - #[allow(dead_code)] pub(crate) async fn egest_best_plan( &self, expression_id: PhysicalExpressionId, diff --git a/optd/src/optimizer/memo_io/ingest.rs b/optd/src/optimizer/memo_io/ingest.rs index ab648a0d..898bbd48 100644 --- a/optd/src/optimizer/memo_io/ingest.rs +++ b/optd/src/optimizer/memo_io/ingest.rs @@ -151,7 +151,6 @@ impl Optimizer { } } - #[allow(dead_code)] async fn probe_ingest_physical_operator( &mut self, operator: &Operator>, diff --git a/optd/src/optimizer/memo_io/mod.rs b/optd/src/optimizer/memo_io/mod.rs index 2fa128dd..d18f2dce 100644 --- a/optd/src/optimizer/memo_io/mod.rs +++ b/optd/src/optimizer/memo_io/mod.rs @@ -1,6 +1,3 @@ -mod egest; mod ingest; -#[allow(unused)] -pub(super) use egest::*; pub(super) use ingest::*; diff --git a/optd/src/optimizer/merge/helpers.rs b/optd/src/optimizer/merge/helpers.rs index d15c6e75..68c10af5 100644 --- a/optd/src/optimizer/merge/helpers.rs +++ b/optd/src/optimizer/merge/helpers.rs @@ -12,13 +12,21 @@ impl Optimizer { /// 1. Computes newly discovered expressions by subtracting existing ones. /// 2. For each fork task in the group, launches continuation tasks for each new expression. /// 3. If the task is the principal, launches transform tasks for each new expression and rule. - /// 4. Updates the task's dispatched expressions with the full input set. + /// 4. For all related optimize goal tasks, launch implement tasks for each new expression. + /// + /// *NOTE*: This happens before merging goals. While this might be slightly inefficient, as + /// we might implement twice for the same "soon-to-be" merged goals. However it keeps + /// the code cleaner and easier to understand. Also, the performance impact is negligible, + /// as once we merge the goals, we effectively delete the implement tasks and its + /// associated jobs will *not* be launched. + /// + /// 5. Updates the task's dispatched expressions with the full input set. 
/// /// # Arguments /// * `task_id` - The ID of the group exploration task to update. /// * `all_logical_exprs` - The complete set of logical expressions known for this group. /// * `principal` - Whether this task is the principal one (responsible for launching transforms). - pub(super) async fn update_group_explore( + pub(super) async fn update_tasks( &mut self, task_id: TaskId, all_logical_exprs: &HashSet, @@ -26,41 +34,67 @@ impl Optimizer { ) -> Result<(), M::MemoError> { let new_exprs = self.compute_new_expressions(task_id, all_logical_exprs); - if !new_exprs.is_empty() { - let (group_id, fork_tasks) = { - let task = self.get_explore_group_task(task_id).unwrap(); - (task.group_id, task.fork_logical_out.clone()) - }; + if new_exprs.is_empty() { + return Ok(()); + } - for &fork_task_id in &fork_tasks { - let continuation = self - .get_fork_logical_task(fork_task_id) - .unwrap() - .continuation - .clone(); + let (group_id, fork_tasks, optimize_goal_tasks) = { + let task = self.get_explore_group_task(task_id).unwrap(); + ( + task.group_id, + task.fork_logical_out.clone(), + task.optimize_goal_out.clone(), + ) + }; - let continuation_tasks = - self.create_logical_cont_tasks(&new_exprs, fork_task_id, &continuation); + // For each fork task, create continuation tasks for each new expression. 
+ fork_tasks.iter().for_each(|&fork_task_id| { + let continuation = self + .get_fork_logical_task(fork_task_id) + .unwrap() + .continuation + .clone(); - self.get_fork_logical_task_mut(fork_task_id) - .unwrap() - .continue_with_logical_in - .extend(continuation_tasks); - } + let continuation_tasks = + self.create_logical_cont_tasks(&new_exprs, group_id, fork_task_id, &continuation); - if principal { - let transform_tasks = self.create_transform_tasks(&new_exprs, group_id, task_id); - self.get_explore_group_task_mut(task_id) - .unwrap() - .transform_expr_in - .extend(transform_tasks); - } + self.get_fork_logical_task_mut(fork_task_id) + .unwrap() + .continue_with_logical_in + .extend(continuation_tasks); + }); + + // For each optimize goal task, create implement tasks for each new expression. + optimize_goal_tasks.iter().for_each(|&optimize_goal_id| { + let goal_id = self + .get_optimize_goal_task(optimize_goal_id) + .unwrap() + .goal_id; + + let implement_tasks = + self.create_implement_tasks(&new_exprs, goal_id, optimize_goal_id); + + self.get_optimize_goal_task_mut(optimize_goal_id) + .unwrap() + .implement_expression_in + .extend(implement_tasks); + }); + + // For the principal task, create transform tasks for each new expression. + // We could always do it, but this is a straightforward optimization. + if principal { + let transform_tasks = self.create_transform_tasks(&new_exprs, group_id, task_id); self.get_explore_group_task_mut(task_id) .unwrap() - .dispatched_exprs = all_logical_exprs.clone(); + .transform_expr_in + .extend(transform_tasks); } + self.get_explore_group_task_mut(task_id) + .unwrap() + .dispatched_exprs = all_logical_exprs.clone(); + Ok(()) } @@ -114,10 +148,69 @@ impl Optimizer { let goal_task = self.get_optimize_goal_task_mut(goal_id).unwrap(); goal_task.explore_group_in = principal_task_id; }); + + // *NOTE*: No need to consolidate transformations as all missing + // expressions have already been added during the update phase. 
}); } - /// Deduplicate dispatched expressions for a group exploration task. + /// Consolidate a goal optimization task into a principal task. + /// + /// - Similar to the group exploration consolidation, a merge in the memo may cause + /// several goal optimization tasks to refer to the same underlying goal. This function + /// consolidates all such secondary tasks into a principal one. + /// + /// - This involves: + /// - Moving all outgoing optimize goal tasks, incoming optimize goal tasks, and + /// optimize plan tasks from secondary to principal. + /// - Updating each such task's references to point to the principal task. + /// - Deleting the secondary tasks. + /// + /// # Arguments + /// * `principal_task_id` - The ID of the task to retain as canonical. + /// * `secondary_task_ids` - All other task IDs to merge into the principal. + pub(super) async fn consolidate_goal_optimize( + &mut self, + principal_task_id: TaskId, + secondary_task_ids: &[TaskId], + ) { + secondary_task_ids.iter().for_each(|&task_id| { + let task = self.get_optimize_goal_task_mut(task_id).unwrap(); + + // Move out the task sets before deletion. 
+ let optimize_out_tasks = std::mem::take(&mut task.optimize_goal_out); + let optimize_in_tasks = std::mem::take(&mut task.optimize_goal_in); + let optimize_plan_tasks = std::mem::take(&mut task.optimize_plan_out); + + self.delete_task(task_id); + + optimize_out_tasks.into_iter().for_each(|optimize_id| { + let optimize_goal_task = self.get_optimize_goal_task_mut(optimize_id).unwrap(); + optimize_goal_task.optimize_goal_in.remove(&task_id); + optimize_goal_task + .optimize_goal_in + .insert(principal_task_id); + }); + + optimize_in_tasks.into_iter().for_each(|optimize_id| { + let optimize_goal_task = self.get_optimize_goal_task_mut(optimize_id).unwrap(); + optimize_goal_task.optimize_goal_out.remove(&task_id); + optimize_goal_task + .optimize_goal_out + .insert(principal_task_id); + }); + + optimize_plan_tasks.into_iter().for_each(|optimize_id| { + let optimize_plan_task = self.get_optimize_plan_task_mut(optimize_id).unwrap(); + optimize_plan_task.optimize_goal_in = Some(principal_task_id); + }); + + // *NOTE*: No need to consolidate implementations as all missing + // expressions have already been added during the update phase. + }); + } + + /// Deduplicate dispatched expressions for a group. /// /// - During optimization, logical expressions in a group may get merged /// in the memo structure. As a result, multiple expressions in the same @@ -126,19 +219,17 @@ impl Optimizer { /// - This function: /// - Maps each expression to its current representative. /// - Identifies and prunes redundant (duplicate) expressions. - /// - Updates or deletes related transform and fork/continuation tasks - /// accordingly. + /// - Updates or deletes related transform, implement and continuation + /// tasks accordingly. /// /// # Arguments /// * `task_id` - The ID of the group exploration task to deduplicate. 
- pub(super) async fn dedup_group_explore( - &mut self, - task_id: TaskId, - ) -> Result<(), M::MemoError> { + pub(super) async fn dedup_tasks(&mut self, task_id: TaskId) -> Result<(), M::MemoError> { let task = self.get_explore_group_task_mut(task_id).unwrap(); let old_exprs = std::mem::take(&mut task.dispatched_exprs); let transform_ids: Vec<_> = task.transform_expr_in.iter().copied().collect(); let fork_ids: Vec<_> = task.fork_logical_out.iter().copied().collect(); + let optimize_ids: Vec<_> = task.optimize_goal_out.iter().copied().collect(); let expr_to_repr = self.map_to_representatives(&old_exprs).await?; @@ -155,8 +246,9 @@ impl Optimizer { (seen, dups) }; - self.process_transform_tasks(&transform_ids, &expr_to_repr, &to_delete); - self.process_fork_tasks(&fork_ids, &expr_to_repr, &to_delete); + self.dedup_transform_tasks(&transform_ids, &expr_to_repr, &to_delete); + self.dedup_continue_tasks(&fork_ids, &expr_to_repr, &to_delete); + self.dedup_implement_tasks(&optimize_ids, &expr_to_repr, &to_delete); let task = self.get_explore_group_task_mut(task_id).unwrap(); task.dispatched_exprs = unique_reprs; @@ -178,22 +270,22 @@ impl Optimizer { } /// Update or delete transform tasks based on deduplicated expressions. - fn process_transform_tasks( + fn dedup_transform_tasks( &mut self, - tasks: &[TaskId], + transform_ids: &[TaskId], expr_to_repr: &HashMap, to_delete: &[LogicalExpressionId], ) { - for &task_id in tasks { + for &transform_id in transform_ids { let expr_id = self - .get_transform_expression_task(task_id) + .get_transform_expression_task(transform_id) .unwrap() .expression_id; if to_delete.contains(&expr_id) { - self.delete_task(task_id); + self.delete_task(transform_id); } else { - self.get_transform_expression_task_mut(task_id) + self.get_transform_expression_task_mut(transform_id) .unwrap() .expression_id = expr_to_repr[&expr_id]; } @@ -201,15 +293,15 @@ impl Optimizer { } /// Update or delete continuation tasks spawned by fork tasks. 
- fn process_fork_tasks( + fn dedup_continue_tasks( &mut self, - tasks: &[TaskId], + fork_ids: &[TaskId], expr_to_repr: &HashMap, to_delete: &[LogicalExpressionId], ) { - for &task_id in tasks { + for &fork_id in fork_ids { let continue_ids = self - .get_fork_logical_task(task_id) + .get_fork_logical_task(fork_id) .unwrap() .continue_with_logical_in .clone(); // Clone to avoid mutable borrow conflict. @@ -230,4 +322,35 @@ impl Optimizer { } } } + + /// Update or delete implement tasks based on deduplicated expressions. + fn dedup_implement_tasks( + &mut self, + optimize_goals: &[TaskId], + expr_to_repr: &HashMap, + to_delete: &[LogicalExpressionId], + ) { + for &optimize_goal in optimize_goals { + let implement_ids = self + .get_optimize_goal_task(optimize_goal) + .unwrap() + .implement_expression_in + .clone(); // Clone to avoid mutable borrow conflict. + + for implement_id in implement_ids { + let expr_id = self + .get_implement_expression_task(implement_id) + .unwrap() + .expression_id; + + if to_delete.contains(&expr_id) { + self.delete_task(implement_id); + } else { + self.get_implement_expression_task_mut(implement_id) + .unwrap() + .expression_id = expr_to_repr[&expr_id]; + } + } + } + } } diff --git a/optd/src/optimizer/merge/mod.rs b/optd/src/optimizer/merge/mod.rs index b553c6c7..8d2bb618 100644 --- a/optd/src/optimizer/merge/mod.rs +++ b/optd/src/optimizer/merge/mod.rs @@ -1,11 +1,29 @@ use super::Optimizer; -use crate::memo::{Memo, MergeGroupProduct, MergeProducts}; +use crate::memo::{Memo, MergeGoalProduct, MergeGroupProduct, MergeProducts}; mod helpers; impl Optimizer { /// Processes merge results by updating the task graph to reflect merges in the memo. /// + /// This function handles both group merges and goal merges by delegating to specialized + /// handlers for each type of merge. + /// + /// # Parameters + /// * `result` - The merge result to handle, containing information about merged groups and goals. 
+ /// + /// # Returns + /// * `Result<(), OptimizeError>` - Success or an error that occurred during processing. + pub(super) async fn handle_merge_result( + &mut self, + result: MergeProducts, + ) -> Result<(), M::MemoError> { + self.handle_group_merges(&result.group_merges).await?; + self.handle_goal_merges(&result.goal_merges).await + } + + /// Handles merges of logical groups by updating the task graph. + /// /// When groups are merged in the memo, multiple exploration tasks may now refer to /// the same underlying group. This method handles the task graph updates required /// by such merges. For each merged group, it: @@ -15,7 +33,7 @@ impl Optimizer { /// related transform and continuation tasks. /// /// 2. **Updates**: Sends any new logical expressions to each task, creating appropriate - /// transform tasks (for the principal task) and continuation tasks (for all tasks). + /// transform (for the principal task), implement and continuation tasks (for all tasks). /// /// 3. **Consolidates**: Merges all secondary tasks into a principal task by transferring /// their dependencies and updating references, ensuring a clean 1:1 mapping between @@ -28,18 +46,18 @@ impl Optimizer { /// with it, containing all logical expressions from the original groups with no duplicates. /// /// # Parameters - /// * `result` - The merge result to handle, containing information about merged groups. + /// * `group_merges` - A slice of group merge products to handle. /// /// # Returns /// * `Result<(), OptimizeError>` - Success or an error that occurred during processing. - pub(super) async fn handle_merge_result( + async fn handle_group_merges( &mut self, - result: MergeProducts, + group_merges: &[MergeGroupProduct], ) -> Result<(), M::MemoError> { for MergeGroupProduct { new_group_id, merged_groups, - } in result.group_merges + } in group_merges { // For each merged group, get all group exploration tasks. 
// We don't need to check for the new group ID since it is guaranteed to be new. @@ -49,18 +67,18 @@ impl Optimizer { .collect(); if !group_explore_tasks.is_empty() { - let all_logical_exprs = self.memo.get_all_logical_exprs(new_group_id).await?; + let all_logical_exprs = self.memo.get_all_logical_exprs(*new_group_id).await?; let (principal_task_id, secondary_task_ids) = group_explore_tasks.split_first().unwrap(); for task_id in &group_explore_tasks { - // Step 1: Start by deduplicating all transform & continue tasks given - // potentially merged logical expressions. - self.dedup_group_explore(*task_id).await?; + // Step 1: Start by deduplicating all transform, implement & continue tasks + // given potentially merged logical expressions. + self.dedup_tasks(*task_id).await?; // Step 2: Send *new* logical expressions to each task. let is_principal = task_id == principal_task_id; - self.update_group_explore(*task_id, &all_logical_exprs, is_principal) + self.update_tasks(*task_id, &all_logical_exprs, is_principal) .await?; } @@ -70,7 +88,64 @@ impl Optimizer { // Step 4: Set the index to point to the new representative task. self.group_exploration_task_index - .insert(new_group_id, *principal_task_id); + .insert(*new_group_id, *principal_task_id); + } + } + + Ok(()) + } + + /// Handles merges of optimization goals by updating the task graph. + /// + /// When goals are merged in the memo, multiple optimization tasks may now refer to + /// the same underlying goal. This method handles the task graph updates required + /// by such merges. For each merged goal, it: + /// + /// 1. **Consolidates**: Merges all secondary tasks into a principal task by transferring + /// their dependencies (incoming and outgoing optimization tasks and plan tasks) and + /// updating references, ensuring a clean 1:1 mapping between goals and optimization tasks. + /// + /// 2. **Re-indexes**: Updates the goal optimization index to point to the principal task + /// for the new goal ID. 
+ /// + /// After processing, each merged goal will have exactly one optimization task associated + /// with it, with all dependencies properly redirected to it. + /// + /// # Parameters + /// * `goal_merges` - A slice of goal merge products to handle. + /// + /// # Returns + /// * `Result<(), OptimizeError>` - Success or an error that occurred during processing. + async fn handle_goal_merges( + &mut self, + goal_merges: &[MergeGoalProduct], + ) -> Result<(), M::MemoError> { + for MergeGoalProduct { + new_goal_id, + merged_goals, + } in goal_merges + { + // For each merged goal, get all goal optimization tasks. + // We don't need to check for the new goal ID since it is guaranteed to be new. + let goal_optimize_tasks: Vec<_> = merged_goals + .iter() + .filter_map(|goal_id| self.goal_optimization_task_index.get(goal_id).copied()) + .collect(); + + if !goal_optimize_tasks.is_empty() { + let (principal_task_id, secondary_task_ids) = + goal_optimize_tasks.split_first().unwrap(); + + // *NOTE*: Deduplication and updates of implementation tasks have already been + // handled in the `handle_group_merges` method, so we don't need to do it here. + + // Step 1: Consolidate all dependent tasks into the new "representative" task. + self.consolidate_goal_optimize(*principal_task_id, secondary_task_ids) + .await; + + // Step 2: Set the index to point to the new representative task. + self.goal_optimization_task_index + .insert(*new_goal_id, *principal_task_id); } } diff --git a/optd/src/optimizer/mod.rs b/optd/src/optimizer/mod.rs index de38da2f..45d42ac7 100644 --- a/optd/src/optimizer/mod.rs +++ b/optd/src/optimizer/mod.rs @@ -17,14 +17,24 @@ mod merge; mod retriever; mod tasks; -use jobs::{CostedContinuation, Job, JobId, LogicalContinuation}; +use jobs::{Job, JobId, LogicalContinuation}; use retriever::OptimizerRetriever; use tasks::{Task, TaskId}; /// Default maximum number of concurrent jobs to run in the optimizer. 
const DEFAULT_MAX_CONCURRENT_JOBS: usize = 1000; -/// External client request to optimize a query in the optimizer. +/// External client requests that can be sent to the optimizer. +#[derive(Clone, Debug)] +pub enum ClientRequest { + /// Request to optimize a logical plan into a physical plan. + Optimize(OptimizeRequest), + + /// Request to dump the contents of the memo for debugging purposes. + DumpMemo, +} + +/// Request to optimize a query in the optimizer. /// /// Defines the public API for submitting a query and receiving execution plans. #[derive(Clone, Debug)] @@ -49,17 +59,11 @@ enum EngineProduct { /// New physical implementation for a goal, awaiting recursive optimization. NewPhysicalPartial(PartialPhysicalPlan, GoalId), - /// Fully optimized physical expression with complete costing. - NewCostedPhysical(PhysicalExpressionId, Cost), - /// Create a new group with the provided logical properties. CreateGroup(LogicalExpressionId, LogicalProperties), /// Subscribe to logical expressions in a specific group. SubscribeGroup(GroupId, LogicalContinuation), - - /// Subscribe to costed physical expressions for a goal. - SubscribeGoal(Goal, CostedContinuation), } /// Messages passed within the optimization system. @@ -67,7 +71,7 @@ enum EngineProduct { /// Each message represents either a client request or the result of completed work, /// allowing the optimizer to track which tasks are progressing. enum OptimizerMessage { - /// Client request to optimize a plan. + /// Client request to the optimizer. Request(OptimizeRequest, TaskId), /// Request to retrieve the properties of a group. @@ -121,7 +125,7 @@ pub struct Optimizer { pending_messages: Vec, message_tx: Sender, message_rx: Receiver, - optimize_rx: Receiver, + client_rx: Receiver, // Task management. tasks: HashMap, @@ -150,7 +154,7 @@ impl Optimizer { catalog: Arc, message_tx: Sender, message_rx: Receiver, - optimize_rx: Receiver, + client_rx: Receiver, ) -> Self { Self { // Core components. 
@@ -164,7 +168,7 @@ impl Optimizer { pending_messages: Vec::new(), message_tx, message_rx, - optimize_rx, + client_rx, // Task management. tasks: HashMap::new(), @@ -185,9 +189,9 @@ impl Optimizer { } /// Launch a new optimizer and return a sender for client communication. - pub fn launch(memo: M, catalog: Arc, hir: HIR) -> Sender { + pub fn launch(memo: M, catalog: Arc, hir: HIR) -> Sender { let (message_tx, message_rx) = mpsc::channel(1); - let (optimize_tx, optimize_rx) = mpsc::channel(1); + let (client_tx, client_rx) = mpsc::channel(1); // Start the background processing loop. let optimizer = Self::new( @@ -196,7 +200,7 @@ impl Optimizer { catalog, message_tx.clone(), message_rx, - optimize_rx, + client_rx, ); tokio::spawn(async move { @@ -205,30 +209,40 @@ impl Optimizer { optimizer.run().await.expect("Optimizer failure"); }); - optimize_tx + client_tx } /// Run the optimizer's main processing loop. async fn run(mut self) -> Result<(), M::MemoError> { + use ClientRequest::*; use EngineProduct::*; use OptimizerMessage::*; loop { tokio::select! { - Some(request) = self.optimize_rx.recv() => { - let task_id = self.create_optimize_plan_task(request.plan.clone(), - request.physical_tx.clone()); + Some(client_request) = self.client_rx.recv() => { let message_tx = self.message_tx.clone(); - // Forward the optimization request to the message processing loop - // in a new coroutine to avoid a deadlock. - tokio::spawn( - async move { - message_tx.send(Request(request, task_id)) - .await - .expect("Failed to forward optimize request - channel closed"); + match client_request { + Optimize(optimize_request) => { + // Create a task for the optimization request. + let task_id = self.create_optimize_plan_task( + optimize_request.plan.clone(), + optimize_request.physical_tx.clone() + ); + + // Forward the client request to the message processing loop + // in a new coroutine to avoid a deadlock. 
+ tokio::spawn(async move { + message_tx.send(Request(optimize_request, task_id)) + .await + .expect("Failed to forward client request - channel closed"); + }); + }, + DumpMemo => { + self.memo.debug_dump().await?; } - ); + } }, Some(message) = self.message_rx.recv() => { // Process the next message in the channel. @@ -237,7 +251,7 @@ impl Optimizer { self.process_optimize_request(plan, physical_tx, task_id).await?, Retrieve(group_id, response_tx) => { self.process_retrieve_properties(group_id, response_tx).await?; - } + }, Product(product, job_id) => { let task_id = self.get_related_task_id(job_id); @@ -250,18 +264,12 @@ impl Optimizer { NewPhysicalPartial(plan, goal_id) => { self.process_new_physical_partial(plan, goal_id, job_id).await?; } - NewCostedPhysical(expression_id, cost) => { - self.process_new_costed_physical(expression_id, cost).await?; - } CreateGroup(expression_id, properties) => { self.process_create_group(expression_id, &properties, job_id).await?; } SubscribeGroup(group_id, continuation) => { self.process_group_subscription(group_id, continuation, job_id).await?; } - SubscribeGoal(goal, continuation) => { - self.process_goal_subscription(&goal, continuation, job_id).await?; - } } } diff --git a/optd/src/optimizer/tasks/delete.rs b/optd/src/optimizer/tasks/delete.rs index 88b82e6e..4bb93485 100644 --- a/optd/src/optimizer/tasks/delete.rs +++ b/optd/src/optimizer/tasks/delete.rs @@ -28,7 +28,16 @@ impl Optimizer { self.delete_task(fork_id); } } + ImplementExpression(implement_expression_task) => { + let optimize_goal_task = self + .get_optimize_goal_task_mut(implement_expression_task.optimize_goal_out) + .unwrap(); + optimize_goal_task.implement_expression_in.remove(&task_id); + if let Some(fork_id) = implement_expression_task.fork_in { + self.delete_task(fork_id); + } + } ContinueWithLogical(task) => { let fork_task = self.get_fork_logical_task_mut(task.fork_out).unwrap(); fork_task.continue_with_logical_in.remove(&task_id); @@ -37,12 +46,12 @@ 
impl Optimizer { self.delete_task(fork_id); } } - ForkLogical(task) => { let explore_task = self .get_explore_group_task_mut(task.explore_group_in) .unwrap(); explore_task.fork_logical_out.remove(&task_id); + // Delete explore task if it has no more purpose. if explore_task.fork_logical_out.is_empty() && explore_task.optimize_goal_out.is_empty() @@ -56,7 +65,6 @@ impl Optimizer { self.delete_task(continue_id); } } - ExploreGroup(task) => { assert!(task.fork_logical_out.is_empty()); assert!(task.optimize_goal_out.is_empty()); @@ -69,10 +77,20 @@ impl Optimizer { self.delete_task(transform_id); } } + OptimizeGoal(task) => { + assert!(task.optimize_goal_out.is_empty()); + assert!(task.optimize_plan_out.is_empty()); - _ => { - todo!(); + self.goal_optimization_task_index + .retain(|_, &mut v| v != task_id); + + let implement_tasks: Vec<_> = + task.implement_expression_in.iter().copied().collect(); + for implement_id in implement_tasks { + self.delete_task(implement_id); + } } + OptimizePlan(_) => todo!(), } // Finally, remove the task from the task collection. diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index b67311fc..fddcd019 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -6,11 +6,12 @@ use crate::{ Optimizer, jobs::{JobKind, LogicalContinuation}, tasks::{ - ContinueWithLogicalTask, ExploreGroupTask, ForkLogicalTask, OptimizeGoalTask, - TransformExpressionTask, + ContinueWithLogicalTask, ExploreGroupTask, ForkLogicalTask, ImplementExpressionTask, + OptimizeGoalTask, TransformExpressionTask, }, }, }; +use async_recursion::async_recursion; use hashbrown::HashSet; use tokio::sync::mpsc::Sender; @@ -96,7 +97,7 @@ impl Optimizer { .dispatched_exprs .clone(); let continue_with_logical_in = - self.create_logical_cont_tasks(&expressions, fork_task_id, &continuation); + self.create_logical_cont_tasks(&expressions, group_id, fork_task_id, &continuation); // Create the fork task. 
let fork_logical_task = ForkLogicalTask { @@ -121,6 +122,7 @@ impl Optimizer { /// /// # Parameters /// * `expression_id`: The ID of the logical expression to continue with. + /// * `group_id`: The ID of the group that this expression belongs to. /// * `fork_out`: The ID of the fork task that this continue task feeds into. /// * `continuation`: The logical continuation to be used. /// @@ -129,6 +131,7 @@ impl Optimizer { pub(crate) fn launch_continue_with_logical_task( &mut self, expression_id: LogicalExpressionId, + group_id: GroupId, fork_out: TaskId, continuation: LogicalContinuation, ) -> TaskId { @@ -144,7 +147,7 @@ impl Optimizer { self.add_task(task_id, ContinueWithLogical(task)); self.schedule_job( task_id, - JobKind::ContinueWithLogical(expression_id, continuation), + JobKind::ContinueWithLogical(expression_id, group_id, continuation), ); task_id @@ -186,6 +189,42 @@ impl Optimizer { task_id } + /// Method to launch a task for implementing a logical expression with a rule. + /// + /// # Parameters + /// * `expr_id`: The ID of the logical expression to implement. + /// * `rule`: The implementation rule to apply. + /// * `optimize_goal_out`: The ID of the optimization task that this implementation feeds into. + /// * `goal_id`: The ID of the goal that this implementation belongs to. + /// + /// # Returns + /// * `TaskId`: The ID of the created implement task. + pub(crate) fn launch_implement_expression_task( + &mut self, + expr_id: LogicalExpressionId, + rule: ImplementationRule, + optimize_goal_out: TaskId, + goal_id: GoalId, + ) -> TaskId { + use Task::*; + + let task_id = self.next_task_id(); + let task = ImplementExpressionTask { + _rule: rule.clone(), + expression_id: expr_id, + optimize_goal_out, + fork_in: None, + }; + + self.add_task(task_id, ImplementExpression(task)); + self.schedule_job( + task_id, + JobKind::ImplementExpression(rule, expr_id, goal_id), + ); + + task_id + } + /// Creates transform tasks for a set of logical expressions. 
/// /// # Arguments @@ -219,10 +258,44 @@ impl Optimizer { transform_tasks } + /// Creates implement tasks for a set of logical expressions. + /// + /// # Arguments + /// * `expressions` - The logical expressions to implement + /// * `goal_id` - The goal ID to implement for + /// * `optimize_task_id` - The optimization task that these transforms feed into + /// + /// # Returns + /// * `HashSet` - The IDs of all created implement tasks + pub(crate) fn create_implement_tasks( + &mut self, + expressions: &HashSet, + goal_id: GoalId, + optimize_task_id: TaskId, + ) -> HashSet { + let implementations = self.rule_book.get_implementations().to_vec(); + let mut implement_tasks = HashSet::new(); + + for &expr_id in expressions { + for rule in &implementations { + let task_id = self.launch_implement_expression_task( + expr_id, + rule.clone(), + optimize_task_id, + goal_id, + ); + implement_tasks.insert(task_id); + } + } + + implement_tasks + } + /// Creates logical continuation tasks for a fork task and a set of expressions. /// /// # Arguments /// * `expressions` - The logical expressions to continue with + /// * `group_id` - The group ID these expressions belong to /// * `fork_task_id` - The fork task that these continuations feed into /// * `continuation` - The continuation to apply /// @@ -231,13 +304,18 @@ impl Optimizer { pub(crate) fn create_logical_cont_tasks( &mut self, expressions: &HashSet, + group_id: GroupId, fork_task_id: TaskId, continuation: &LogicalContinuation, ) -> HashSet { let mut continuation_tasks = HashSet::new(); for &expr_id in expressions { - let task_id = - self.launch_continue_with_logical_task(expr_id, fork_task_id, continuation.clone()); + let task_id = self.launch_continue_with_logical_task( + expr_id, + group_id, + fork_task_id, + continuation.clone(), + ); continuation_tasks.insert(task_id); } @@ -257,7 +335,7 @@ impl Optimizer { /// /// # Returns /// * `TaskId`: The ID of the task that was created or reused. 
- async fn ensure_group_exploration_task( + pub(crate) async fn ensure_group_exploration_task( &mut self, group_id: GroupId, ) -> Result { @@ -303,7 +381,11 @@ impl Optimizer { /// /// # Returns /// * `TaskId`: The ID of the task that was created or reused. - async fn ensure_optimize_goal_task(&mut self, goal_id: GoalId) -> Result { + #[async_recursion] + pub(crate) async fn ensure_optimize_goal_task( + &mut self, + goal_id: GoalId, + ) -> Result { use Task::*; // Find the representative goal for the given goal ID. @@ -316,20 +398,21 @@ impl Optimizer { let goal_optimize_task_id = self.next_task_id(); - // TODO(Alexis): Materialize the goal and only explore the group - for now. - // This is sufficient to support logical->logical transformation. let Goal(group_id, _) = self.memo.materialize_goal(goal_id).await?; let explore_group_in = self.ensure_group_exploration_task(group_id).await?; + // Launch all implementation tasks. + let expressions = self.memo.get_all_logical_exprs(group_id).await?; + let implement_expression_in = + self.create_implement_tasks(&expressions, goal_id, goal_optimize_task_id); + let goal_optimize_task = OptimizeGoalTask { goal_id, optimize_plan_out: HashSet::new(), - optimize_goal_out: HashSet::new(), - fork_costed_out: HashSet::new(), - optimize_goal_in: HashSet::new(), + optimize_goal_out: HashSet::new(), // filled below to avoid infinite recursion + optimize_goal_in: HashSet::new(), // filled below to avoid infinite recursion explore_group_in, - implement_expression_in: HashSet::new(), - cost_expression_in: HashSet::new(), + implement_expression_in, }; // Add this task to the exploration task's outgoing edges. @@ -343,22 +426,36 @@ impl Optimizer { self.goal_optimization_task_index .insert(goal_repr, goal_optimize_task_id); - Ok(goal_optimize_task_id) - } + // Ensure sub-goals are getting explored too, we do this after registering + // the task to avoid infinite recursion. 
+ let sub_goals = self + .memo + .get_all_goal_members(goal_id) + .await? + .into_iter() + .filter_map(|member_id| match member_id { + GoalMemberId::GoalId(sub_goal_id) => Some(sub_goal_id), + _ => None, + }); + + // Launch optimization tasks for each subgoal and establish links. + let mut subgoal_task_ids = HashSet::new(); + for sub_goal_id in sub_goals { + let sub_goal_task_id = self.ensure_optimize_goal_task(sub_goal_id).await?; + subgoal_task_ids.insert(sub_goal_task_id); + + // Add current task to subgoal's outgoing edges. + self.get_optimize_goal_task_mut(sub_goal_task_id) + .unwrap() + .optimize_goal_out + .insert(goal_optimize_task_id); + } - /// Ensures a cost expression task exists and and returns its id. - /// If a costing task already exists, we reuse it. - /// - /// # Parameters - /// * `expression_id`: The ID of the expression to be costed. - /// - /// # Returns - /// * `TaskId`: The ID of the task that was created or reused. - #[allow(dead_code)] - async fn ensure_cost_expression_task( - &mut self, - _expression_id: PhysicalExpressionId, - ) -> Result { - todo!() + // Update the parent task with links to subgoals. 
+ self.get_optimize_goal_task_mut(goal_optimize_task_id) + .unwrap() + .optimize_goal_in = subgoal_task_ids; + + Ok(goal_optimize_task_id) } } diff --git a/optd/src/optimizer/tasks/manage.rs b/optd/src/optimizer/tasks/manage.rs index 3cf528ff..b80119b7 100644 --- a/optd/src/optimizer/tasks/manage.rs +++ b/optd/src/optimizer/tasks/manage.rs @@ -1,7 +1,6 @@ use super::{ - ContinueWithCostedTask, ContinueWithLogicalTask, CostExpressionTask, ExploreGroupTask, - ForkCostedTask, ForkLogicalTask, ImplementExpressionTask, OptimizeGoalTask, OptimizePlanTask, - Task, TaskId, TransformExpressionTask, + ContinueWithLogicalTask, ExploreGroupTask, ForkLogicalTask, ImplementExpressionTask, + OptimizeGoalTask, OptimizePlanTask, Task, TaskId, TransformExpressionTask, }; use crate::{memo::Memo, optimizer::Optimizer}; @@ -36,7 +35,6 @@ impl Optimizer { // Direct task type access by ID - immutable versions /// Get a task as an OptimizePlanTask by its ID. - #[allow(dead_code)] pub(crate) fn get_optimize_plan_task(&self, task_id: TaskId) -> Option<&OptimizePlanTask> { self.get_task(task_id).and_then(|task| match &task { Task::OptimizePlan(task) => Some(task), @@ -45,7 +43,6 @@ impl Optimizer { } /// Get a task as an OptimizeGoalTask by its ID. - #[allow(dead_code)] pub(crate) fn get_optimize_goal_task(&self, task_id: TaskId) -> Option<&OptimizeGoalTask> { self.get_task(task_id).and_then(|task| match &task { Task::OptimizeGoal(task) => Some(task), @@ -62,7 +59,6 @@ impl Optimizer { } /// Get a task as an ImplementExpressionTask by its ID. - #[allow(dead_code)] pub(crate) fn get_implement_expression_task( &self, task_id: TaskId, @@ -84,15 +80,6 @@ impl Optimizer { }) } - /// Get a task as a CostExpressionTask by its ID. 
- #[allow(dead_code)] - pub(crate) fn get_cost_expression_task(&self, task_id: TaskId) -> Option<&CostExpressionTask> { - self.get_task(task_id).and_then(|task| match &task { - Task::CostExpression(task) => Some(task), - _ => None, - }) - } - /// Get a task as a ForkLogicalTask by its ID. pub(crate) fn get_fork_logical_task(&self, task_id: TaskId) -> Option<&ForkLogicalTask> { self.get_task(task_id).and_then(|task| match &task { @@ -101,15 +88,6 @@ impl Optimizer { }) } - /// Get a task as a ForkCostedTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_fork_costed_task(&self, task_id: TaskId) -> Option<&ForkCostedTask> { - self.get_task(task_id).and_then(|task| match &task { - Task::ForkCosted(task) => Some(task), - _ => None, - }) - } - /// Get a task as a ContinueWithLogicalTask by its ID. pub(crate) fn get_continue_with_logical_task( &self, @@ -121,22 +99,9 @@ impl Optimizer { }) } - /// Get a task as a ContinueWithCostedTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_continue_with_costed_task( - &self, - task_id: TaskId, - ) -> Option<&ContinueWithCostedTask> { - self.get_task(task_id).and_then(|task| match &task { - Task::ContinueWithCosted(task) => Some(task), - _ => None, - }) - } - // Direct task type access by ID - mutable versions /// Get a mutable task as an OptimizePlanTask by its ID. - #[allow(dead_code)] pub(crate) fn get_optimize_plan_task_mut( &mut self, task_id: TaskId, @@ -170,7 +135,6 @@ impl Optimizer { } /// Get a mutable task as an ImplementExpressionTask by its ID. - #[allow(dead_code)] pub(crate) fn get_implement_expression_task_mut( &mut self, task_id: TaskId, @@ -192,18 +156,6 @@ impl Optimizer { }) } - /// Get a mutable task as a CostExpressionTask by its ID. 
- #[allow(dead_code)] - pub(crate) fn get_cost_expression_task_mut( - &mut self, - task_id: TaskId, - ) -> Option<&mut CostExpressionTask> { - self.get_task_mut(task_id).and_then(|task| match task { - Task::CostExpression(task) => Some(task), - _ => None, - }) - } - /// Get a mutable task as a ForkLogicalTask by its ID. pub(crate) fn get_fork_logical_task_mut( &mut self, @@ -215,18 +167,6 @@ impl Optimizer { }) } - /// Get a mutable task as a ForkCostedTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_fork_costed_task_mut( - &mut self, - task_id: TaskId, - ) -> Option<&mut ForkCostedTask> { - self.get_task_mut(task_id).and_then(|task| match task { - Task::ForkCosted(task) => Some(task), - _ => None, - }) - } - /// Get a mutable task as a ContinueWithLogicalTask by its ID. pub(crate) fn get_continue_with_logical_task_mut( &mut self, @@ -237,16 +177,4 @@ impl Optimizer { _ => None, }) } - - /// Get a mutable task as a ContinueWithCostedTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_continue_with_costed_task_mut( - &mut self, - task_id: TaskId, - ) -> Option<&mut ContinueWithCostedTask> { - self.get_task_mut(task_id).and_then(|task| match task { - Task::ContinueWithCosted(task) => Some(task), - _ => None, - }) - } } diff --git a/optd/src/optimizer/tasks/mod.rs b/optd/src/optimizer/tasks/mod.rs index c020e4fd..d6fcc498 100644 --- a/optd/src/optimizer/tasks/mod.rs +++ b/optd/src/optimizer/tasks/mod.rs @@ -1,7 +1,7 @@ -use super::jobs::{CostedContinuation, LogicalContinuation}; +use super::jobs::LogicalContinuation; use crate::cir::{ - Cost, GoalId, GroupId, ImplementationRule, LogicalExpressionId, LogicalPlan, - PhysicalExpressionId, PhysicalPlan, TransformationRule, + GoalId, GroupId, ImplementationRule, LogicalExpressionId, LogicalPlan, PhysicalPlan, + TransformationRule, }; use hashbrown::HashSet; use tokio::sync::mpsc::Sender; @@ -23,17 +23,10 @@ pub(crate) enum Task { OptimizePlan(OptimizePlanTask), OptimizeGoal(OptimizeGoalTask), 
ExploreGroup(ExploreGroupTask), - #[allow(dead_code)] ImplementExpression(ImplementExpressionTask), TransformExpression(TransformExpressionTask), - #[allow(dead_code)] - CostExpression(CostExpressionTask), ForkLogical(ForkLogicalTask), - #[allow(dead_code)] - ForkCosted(ForkCostedTask), ContinueWithLogical(ContinueWithLogicalTask), - #[allow(dead_code)] - ContinueWithCosted(ContinueWithCostedTask), } //============================================================================= @@ -65,8 +58,6 @@ pub(crate) struct OptimizeGoalTask { /// `OptimizeGoalTask` parent goals that this task is simultaneously /// producing for. pub optimize_goal_out: HashSet, - /// `ForkCostedTask` subscribed to this goal. - pub fork_costed_out: HashSet, // Input tasks that feed this task. /// `OptimizeGoalTask` member (children) goals producing for this goal. @@ -76,8 +67,6 @@ pub(crate) struct OptimizeGoalTask { pub explore_group_in: TaskId, /// `ImplementExpressionTask` rules that are implementing logical expressions. pub implement_expression_in: HashSet, - /// `CostExpressionTask` costing of physical expressions. - pub cost_expression_in: HashSet, } /// Task to explore expressions in a logical group. @@ -121,10 +110,10 @@ pub(crate) struct TransformExpressionTask { /// Task to implement a logical expression into a physical expression. #[derive(Clone)] -#[allow(dead_code)] pub(crate) struct ImplementExpressionTask { /// The implementation rule to apply. - pub rule: ImplementationRule, + /// NOTE: Variable not used but kept for observability. + pub _rule: ImplementationRule, /// The logical expression to implement. pub expression_id: LogicalExpressionId, @@ -138,25 +127,6 @@ pub(crate) struct ImplementExpressionTask { pub fork_in: Option, } -/// Task to cost a physical expression. -#[derive(Clone)] -#[allow(dead_code)] -pub(crate) struct CostExpressionTask { - /// The physical expression to cost. 
- pub expression_id: PhysicalExpressionId, - /// The current upper bound on the allowed cost budget. - pub budget: Cost, - - // Output tasks that get fed by the output of this task. - /// `OptimizeGoalTask` corresponding goal optimization task. - pub optimize_goal_out: HashSet, - - // Input tasks that feed this task. - /// `ForkCostedTask` cost fork points encountered during the - /// costing. - pub fork_in: Option, -} - /// Task to fork the logical optimization process. #[derive(Clone)] pub(crate) struct ForkLogicalTask { @@ -177,27 +147,6 @@ pub(crate) struct ForkLogicalTask { pub continue_with_logical_in: HashSet, } -/// Task to fork the costed optimization process. -#[derive(Clone)] -#[allow(dead_code)] -pub(crate) struct ForkCostedTask { - /// The fork continuation. - pub continuation: CostedContinuation, - /// The current upper bound on the allowed cost budget. - pub budget: Cost, - - /// `ContinueWithCostedTask` | `CostExpressionTask` that gets fed by the - /// output of this task. - pub out: TaskId, - - // Input tasks that feed this task. - /// `OptimizeGoalTask` corresponding goal optimization task producing - /// costed expressions. - pub optimize_goal_in: TaskId, - /// `ContinueWithCosted` tasks spawned off and producing for this task. - pub continue_with_costed_in: HashSet, -} - /// Task to continue with a logical expression. #[derive(Clone)] pub(crate) struct ContinueWithLogicalTask { @@ -209,16 +158,3 @@ pub(crate) struct ContinueWithLogicalTask { /// Potential `ForkLogicalTask` fork spawned off from this task. pub fork_in: Option, } - -/// Task to continue with a costed expression. -#[derive(Clone)] -#[allow(dead_code)] -pub(crate) struct ContinueWithCostedTask { - /// The physical expression to continue with. - pub expression_id: PhysicalExpressionId, - - /// `ForkCostedTask` that gets fed by the output of this continuation. - pub fork_out: TaskId, - /// Potential `ForkCostedTask` fork spawned off from this task. - pub fork_in: Option, -}