From 3b89bcc3e531ef0e177d63ddcd5f75206483f39b Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 14 May 2025 19:18:45 -0400 Subject: [PATCH 01/23] Add types and skeleton for physical memo API --- optd/src/memo/memory/helpers.rs | 77 ++++++++++++++++++++++++++ optd/src/memo/memory/implementation.rs | 3 +- optd/src/memo/memory/materialize.rs | 58 +++++++++++++++++-- optd/src/memo/memory/mod.rs | 16 +++++- optd/src/memo/mod.rs | 13 +++++ 5 files changed, 159 insertions(+), 8 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index c087a6cc..4bf5896c 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -64,11 +64,18 @@ impl MemoryMemo { /// /// See the implementation itself for the documentation of each helper method. pub trait MemoryMemoHelper: Memo { + async fn find_repr_goal_member_id(&self, id: GoalMemberId) -> Result; + async fn remap_logical_expr<'a>( &self, logical_expr: &'a LogicalExpression, ) -> Result, Infallible>; + async fn remap_physical_expr<'a>( + &self, + physical_expr: &'a PhysicalExpression, + ) -> Result, Infallible>; + async fn remap_goal<'a>(&self, goal: &'a Goal) -> Result, Infallible>; async fn merge_group_pair( @@ -98,6 +105,25 @@ pub trait MemoryMemoHelper: Memo { } impl MemoryMemoHelper for MemoryMemo { + /// Finds the representative ID for a given [`GoalMemberId`]. + /// + /// This handles both variants of `GoalMemberId` by finding the appropriate representative + /// based on whether it's a Goal or PhysicalExpr. 
+ async fn find_repr_goal_member_id(&self, id: GoalMemberId) -> Result { + use GoalMemberId::*; + + match id { + GoalId(goal_id) => { + let repr_goal_id = self.find_repr_goal_id(goal_id).await?; + Ok(GoalId(repr_goal_id)) + } + PhysicalExpressionId(expr_id) => { + let repr_expr_id = self.find_repr_physical_expr_id(expr_id).await?; + Ok(PhysicalExpressionId(repr_expr_id)) + } + } + } + /// Remaps the children of a logical expression such that they are all identified by their /// representative IDs. /// @@ -147,6 +173,55 @@ impl MemoryMemoHelper for MemoryMemo { }) } + /// Remaps the children of a physical expression such that they are all identified by their + /// representative IDs. + /// + /// For example, if a physical expression has a child goal with [`GoalMemberId`] 3, but the + /// representative of goal 3 is [`GoalMemberId`] 42, then the output expression will be the input + /// physical expression with a child goal of 42. + /// + /// If no remapping needs to occur, this returns the same [`PhysicalExpression`] object via the + /// [`Cow`]. Otherwise, this function will create a new owned [`PhysicalExpression`]. 
+ async fn remap_physical_expr<'a>( + &self, + physical_expr: &'a PhysicalExpression, + ) -> Result, Infallible> { + use Child::*; + + let mut needs_remapping = false; + let mut remapped_children = Vec::with_capacity(physical_expr.children.len()); + + for child in &physical_expr.children { + let remapped = match child { + Singleton(goal_id) => { + let repr = self.find_repr_goal_member_id(*goal_id).await?; + needs_remapping |= repr != *goal_id; + Singleton(repr) + } + VarLength(goal_ids) => { + let mut reprs = Vec::with_capacity(goal_ids.len()); + for goal_id in goal_ids { + let repr = self.find_repr_goal_member_id(*goal_id).await?; + needs_remapping |= repr != *goal_id; + reprs.push(repr); + } + VarLength(reprs) + } + }; + remapped_children.push(remapped); + } + + Ok(if needs_remapping { + Cow::Owned(PhysicalExpression { + tag: physical_expr.tag.clone(), + data: physical_expr.data.clone(), + children: remapped_children, + }) + } else { + Cow::Borrowed(physical_expr) + }) + } + /// Remaps the [`GroupId`] component of a [`Goal`] to its representative [`GroupId`]. /// /// For example, if the [`Goal`] has a [`GroupID`] of 3, but the representative of group 3 is @@ -227,6 +302,7 @@ impl MemoryMemoHelper for MemoryMemo { // Create and save the new group. let new_group_info = GroupInfo { expressions: all_exprs, + goals: HashMap::new(), // TODO(Alexis): this will be used to trigger the goal merge process. 
logical_properties: group1_info.logical_properties, }; self.group_info.insert(new_group_id, new_group_info); @@ -414,6 +490,7 @@ impl MemoryMemoHelper for MemoryMemo { Ok(MergeProducts { group_merges, goal_merges: vec![], + expr_merges: vec![], }) } } diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index ecd62309..0081b07b 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -6,7 +6,7 @@ use super::{ Infallible, Memo, MemoryMemo, MergeProducts, Representative, helpers::MemoryMemoHelper, }; use crate::{cir::*, memo::memory::GroupInfo}; -use hashbrown::HashSet; +use hashbrown::{HashMap, HashSet}; impl Memo for MemoryMemo { async fn get_logical_properties( @@ -60,6 +60,7 @@ impl Memo for MemoryMemo { let group_id = self.next_group_id(); let group_info = GroupInfo { expressions: HashSet::from([logical_expr_id]), + goals: HashMap::new(), logical_properties: props.clone(), }; diff --git a/optd/src/memo/memory/materialize.rs b/optd/src/memo/memory/materialize.rs index 711f37c4..8fbc993f 100644 --- a/optd/src/memo/memory/materialize.rs +++ b/optd/src/memo/memory/materialize.rs @@ -24,7 +24,7 @@ impl Materialize for MemoryMemo { // Otherwise, create a new entry in the memo table (slow path). let expr_id = self.next_logical_expression_id(); - // Update the logical expression to group index. + // Update the group referencing expression index. remapped_expr .children .iter() @@ -70,7 +70,15 @@ impl Materialize for MemoryMemo { // Otherwise, create a new entry in the memo table (slow path). let goal_id = self.next_goal_id(); self.id_to_goal.insert(goal_id, goal.clone().into_owned()); - self.goal_to_id.insert(goal.into_owned(), goal_id); + self.goal_to_id.insert(goal.clone().into_owned(), goal_id); + + // Connect the goal to its group. 
+ let Goal(group_id, props) = goal.into_owned(); + self.group_info + .get_mut(&group_id) + .expect("Group not found in memo table") + .goals + .insert(props, goal_id); Ok(goal_id) } @@ -86,15 +94,53 @@ impl Materialize for MemoryMemo { async fn get_physical_expr_id( &mut self, - _physical_expr: &PhysicalExpression, + physical_expr: &PhysicalExpression, ) -> Result { - todo!() + use Child::*; + + // Check if the expression is already in the memo table (fast path). + let remapped_expr = self.remap_physical_expr(physical_expr).await?; + if let Some(&expr_id) = self.physical_expr_to_id.get(remapped_expr.as_ref()) { + return Ok(expr_id); + } + + // Otherwise, create a new entry in the memo table (slow path). + let expr_id = self.next_physical_expression_id(); + + // Update the goal member referencing expression index. + remapped_expr + .children + .iter() + .flat_map(|child| match child { + Singleton(member_id) => vec![*member_id], + VarLength(member_ids) => member_ids.clone(), + }) + .for_each(|member_id| { + self.goal_member_referencing_exprs_index + .entry(member_id) + .or_default() + .insert(expr_id); + }); + + // Update the physical expression ID indexes. 
+ self.id_to_physical_expr + .insert(expr_id, remapped_expr.clone().into_owned()); + self.id_to_cost.insert(expr_id, None); + self.physical_expr_to_id + .insert(remapped_expr.into_owned(), expr_id); + + Ok(expr_id) } async fn materialize_physical_expr( &self, - _physical_expr_id: PhysicalExpressionId, + physical_expr_id: PhysicalExpressionId, ) -> Result { - todo!() + let repr_expr_id = self.find_repr_physical_expr_id(physical_expr_id).await?; + Ok(self + .id_to_physical_expr + .get(&repr_expr_id) + .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_expr_id)) + .clone()) } } diff --git a/optd/src/memo/memory/mod.rs b/optd/src/memo/memory/mod.rs index 7f8b813d..6be06abe 100644 --- a/optd/src/memo/memory/mod.rs +++ b/optd/src/memo/memory/mod.rs @@ -41,12 +41,24 @@ pub struct MemoryMemo { /// Each representative expression is mapped to its id, for faster lookups. logical_expr_to_id: HashMap, + /// Key is always a representative ID. + id_to_physical_expr: HashMap, + /// The cost of each representative physical expression, if available. + id_to_cost: HashMap>, + /// Each representative expression is mapped to its id, for faster lookups. + physical_expr_to_id: HashMap, + // Indexes: only deal with representative IDs, but speeds up most queries. /// To speed up expr->group lookup, we maintain a mapping from logical expression IDs to group IDs. logical_id_to_group_index: HashMap, /// To speed up recursive merges, we maintain a mapping from group IDs to all logical expression IDs - /// that contain a reference to this group. The value logical_expr_ids may *NOT* be a representative ID. + /// that contain a reference to this group. + /// The value logical_expr_ids may *NOT* be a representative ID. group_referencing_exprs_index: HashMap>, + /// To speed up recursive merges, we maintain a mapping from goal member IDs to all goal member IDs + /// that contain a reference to this goal member. + /// The value physical_expr_ids may *NOT* be a representative ID. 
+ goal_member_referencing_exprs_index: HashMap>, /// The shared next unique id to be used for goals, groups, logical expressions, and physical expressions. next_shared_id: i64, @@ -60,9 +72,11 @@ pub struct MemoryMemo { /// Information about a group: /// - All logical expressions in this group (always representative IDs). +/// - All goals that have this group as objective, for each physical properties. /// - Logical properties of this group. #[derive(Clone, Debug)] struct GroupInfo { expressions: HashSet, + goals: HashMap, logical_properties: LogicalProperties, } diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index fa619da7..dd0231b5 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -25,6 +25,16 @@ pub struct MergeGoalProduct { pub merged_goals: Vec, } +/// Result of merging two physical expressions. +#[derive(Debug)] +pub struct MergePhysicalExprProduct { + /// ID of the new physical expression. + pub new_physical_expr_id: PhysicalExpressionId, + + /// Physical expressions that were merged. + pub merged_physical_exprs: Vec, +} + /// Results of merge operations, including group and goal merges. #[derive(Debug, Default)] pub struct MergeProducts { @@ -33,6 +43,9 @@ pub struct MergeProducts { /// Goal merge results. pub goal_merges: Vec, + + /// Physical expression merge results. + pub expr_merges: Vec, } /// Base trait defining a shared implemention-defined error type for all memo-related traits. 
From 214f344b1478647f6ff02bf00444db0ac2371a04 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 14 May 2025 19:38:10 -0400 Subject: [PATCH 02/23] Add extra physical indexes and start implementation --- optd/src/memo/memory/implementation.rs | 38 ++++++++++++++++++++++---- optd/src/memo/memory/materialize.rs | 2 ++ optd/src/memo/memory/mod.rs | 5 ++++ optd/src/memo/mod.rs | 4 +-- 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index 0081b07b..b0e6bf2b 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -160,10 +160,33 @@ impl Memo for MemoryMemo { async fn add_goal_member( &mut self, - _goal_id: GoalId, - _member: GoalMemberId, + goal_id: GoalId, + member_id: GoalMemberId, ) -> Result { - todo!() + let repr_goal_id = self.find_repr_goal_id(goal_id).await?; + let repr_member_id = self.find_repr_goal_member_id(member_id).await?; + + // Check if the member is already in the goal (fast path). + if self + .id_to_goal_members + .get(&repr_goal_id) + .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_goal_id)) + .contains(&repr_member_id) + { + return Ok(false); + } + + // Otherwise, add the member to the goal (slow path). 
+ self.id_to_goal_members + .get_mut(&repr_goal_id) + .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_goal_id)) + .insert(repr_member_id); + self.goal_member_to_goals_index + .entry(repr_member_id) + .or_default() + .insert(repr_goal_id); + + Ok(true) } async fn update_physical_expr_cost( @@ -176,9 +199,14 @@ impl Memo for MemoryMemo { async fn get_physical_expr_cost( &self, - _physical_expr_id: PhysicalExpressionId, + physical_expr_id: PhysicalExpressionId, ) -> Result, Infallible> { - todo!() + let repr_physical_expr_id = self.find_repr_physical_expr_id(physical_expr_id).await?; + Ok(self + .id_to_cost + .get(&repr_physical_expr_id) + .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_physical_expr_id)) + .clone()) } } diff --git a/optd/src/memo/memory/materialize.rs b/optd/src/memo/memory/materialize.rs index 8fbc993f..c800cef2 100644 --- a/optd/src/memo/memory/materialize.rs +++ b/optd/src/memo/memory/materialize.rs @@ -7,6 +7,7 @@ use crate::{ cir::*, memo::{Materialize, Representative}, }; +use hashbrown::HashSet; impl Materialize for MemoryMemo { async fn get_logical_expr_id( @@ -70,6 +71,7 @@ impl Materialize for MemoryMemo { // Otherwise, create a new entry in the memo table (slow path). let goal_id = self.next_goal_id(); self.id_to_goal.insert(goal_id, goal.clone().into_owned()); + self.id_to_goal_members.insert(goal_id, HashSet::new()); self.goal_to_id.insert(goal.clone().into_owned(), goal_id); // Connect the goal to its group. diff --git a/optd/src/memo/memory/mod.rs b/optd/src/memo/memory/mod.rs index 6be06abe..f2eb4235 100644 --- a/optd/src/memo/memory/mod.rs +++ b/optd/src/memo/memory/mod.rs @@ -32,6 +32,8 @@ pub struct MemoryMemo { // Goals. /// Key is always a representative ID. id_to_goal: HashMap, + /// The members (physical expressions & sub-goals) inside each goal. + id_to_goal_members: HashMap>, /// Each representative goal is mapped to its id, for faster lookups. 
goal_to_id: HashMap, @@ -55,10 +57,13 @@ pub struct MemoryMemo { /// that contain a reference to this group. /// The value logical_expr_ids may *NOT* be a representative ID. group_referencing_exprs_index: HashMap>, + /// To speed up recursive merges, we maintain a mapping from goal member IDs to all goal member IDs /// that contain a reference to this goal member. /// The value physical_expr_ids may *NOT* be a representative ID. goal_member_referencing_exprs_index: HashMap>, + /// Similar idea for sub-goals. + goal_member_to_goals_index: HashMap>, /// The shared next unique id to be used for goals, groups, logical expressions, and physical expressions. next_shared_id: i64, diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index dd0231b5..f242da13 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -237,14 +237,14 @@ pub trait Memo: Representative + Materialize + Sync + 'static { /// /// # Parameters /// * `goal_id` - ID of the goal to add the member to. - /// * `member` - The member to add, either a physical expression ID or another goal ID. + /// * `member_id` - The member to add, either a physical expression ID or another goal ID. /// /// # Returns /// True if the member was added to the goal, or false if it already existed. async fn add_goal_member( &mut self, goal_id: GoalId, - member: GoalMemberId, + member_id: GoalMemberId, ) -> Result; /// Updates the cost of a physical expression ID. 
From 381749479e0e1b1f7e425e2341c6070d4c578923 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 13:21:24 -0400 Subject: [PATCH 03/23] remove all costing related code in optimizer & memo --- optd/src/memo/memory/implementation.rs | 30 +--------- optd/src/memo/mod.rs | 41 +------------ optd/src/optimizer/handlers.rs | 44 +------------- optd/src/optimizer/jobs/execute.rs | 79 ++------------------------ optd/src/optimizer/jobs/manage.rs | 7 --- optd/src/optimizer/jobs/mod.rs | 24 +------- optd/src/optimizer/memo_io/egest.rs | 2 +- optd/src/optimizer/memo_io/ingest.rs | 1 - optd/src/optimizer/memo_io/mod.rs | 3 - optd/src/optimizer/mod.rs | 14 +---- optd/src/optimizer/tasks/launch.rs | 2 +- optd/src/optimizer/tasks/manage.rs | 76 +------------------------ optd/src/optimizer/tasks/mod.rs | 67 +--------------------- 13 files changed, 17 insertions(+), 373 deletions(-) diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index b0e6bf2b..0cd9619c 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -1,12 +1,11 @@ //! The main implementation of the in-memory memo table. 
-use std::collections::VecDeque; - use super::{ Infallible, Memo, MemoryMemo, MergeProducts, Representative, helpers::MemoryMemoHelper, }; use crate::{cir::*, memo::memory::GroupInfo}; use hashbrown::{HashMap, HashSet}; +use std::collections::VecDeque; impl Memo for MemoryMemo { async fn get_logical_properties( @@ -144,13 +143,6 @@ impl Memo for MemoryMemo { self.consolidate_merge_results(merge_operations).await } - async fn get_best_optimized_physical_expr( - &self, - _goal_id: GoalId, - ) -> Result, Infallible> { - todo!() - } - async fn get_all_goal_members( &self, _goal_id: GoalId, @@ -188,26 +180,6 @@ impl Memo for MemoryMemo { Ok(true) } - - async fn update_physical_expr_cost( - &mut self, - _physical_expr_id: PhysicalExpressionId, - _new_cost: Cost, - ) -> Result { - todo!() - } - - async fn get_physical_expr_cost( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> Result, Infallible> { - let repr_physical_expr_id = self.find_repr_physical_expr_id(physical_expr_id).await?; - Ok(self - .id_to_cost - .get(&repr_physical_expr_id) - .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_physical_expr_id)) - .clone()) - } } #[cfg(test)] diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index f242da13..bf19b011 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -207,20 +207,7 @@ pub trait Memo: Representative + Materialize + Sync + 'static { // // Physical expression and goal operations. // - - /// Gets the best optimized physical expression ID for a goal ID. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to retrieve the best expression for. - /// - /// # Returns - /// The ID of the lowest-cost physical implementation found so far for the goal, - /// along with its cost. Returns None if no optimized expression exists. - async fn get_best_optimized_physical_expr( - &self, - goal_id: GoalId, - ) -> Result, Self::MemoError>; - + /// Gets all members of a goal, which can be physical expressions or other goals. 
/// /// # Parameters @@ -246,30 +233,4 @@ pub trait Memo: Representative + Materialize + Sync + 'static { goal_id: GoalId, member_id: GoalMemberId, ) -> Result; - - /// Updates the cost of a physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to update. - /// * `new_cost` - New cost to assign to the physical expression. - /// - /// # Returns - /// Whether the cost of the expression has improved. - async fn update_physical_expr_cost( - &mut self, - physical_expr_id: PhysicalExpressionId, - new_cost: Cost, - ) -> Result; - - /// Gets the cost of a physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to retrieve the cost for. - /// - /// # Returns - /// The cost of the physical expression, or None if it doesn't exist. - async fn get_physical_expr_cost( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> Result, Self::MemoError>; } diff --git a/optd/src/optimizer/handlers.rs b/optd/src/optimizer/handlers.rs index a022945d..f0e3cef7 100644 --- a/optd/src/optimizer/handlers.rs +++ b/optd/src/optimizer/handlers.rs @@ -1,7 +1,4 @@ -use super::{ - JobId, Optimizer, TaskId, - jobs::{CostedContinuation, LogicalContinuation}, -}; +use super::{JobId, Optimizer, TaskId, jobs::LogicalContinuation}; use crate::{ cir::*, memo::Memo, @@ -139,26 +136,6 @@ impl Optimizer { todo!() } - /// This method handles fully optimized physical expressions with cost information. - /// - /// When a new optimized expression is found, it's added to the memo. If it becomes - /// the new best expression for its goal, continuations are notified and and clients - /// receive the corresponding egested plan. - /// - /// # Parameters - /// * `expression_id` - ID of the physical expression to process. - /// * `cost` - Cost information for the expression. - /// - /// # Returns - /// * `Result<(), Error>` - Success or error during processing. 
- pub(super) async fn process_new_costed_physical( - &mut self, - _expression_id: PhysicalExpressionId, - _cost: Cost, - ) -> Result<(), M::MemoError> { - todo!() - } - /// This method handles group creation for expressions with derived properties /// and updates any pending messages that depend on this group. /// @@ -201,25 +178,6 @@ impl Optimizer { .await } - /// Registers a continuation for receiving optimized physical expressions for a goal. - /// The continuation will be notified about the best existing expression and any better ones found. - /// - /// # Parameters - /// * `goal` - The goal to subscribe to. - /// * `continuation` - Continuation to call when new optimized expressions are found. - /// * `job_id` - ID of the job that initiated this request. - /// - /// # Returns - /// * `Result<(), Error>` - Success or error during processing. - pub(super) async fn process_goal_subscription( - &mut self, - _goal: &Goal, - _continuation: CostedContinuation, - _job_id: JobId, - ) -> Result<(), M::MemoError> { - todo!() - } - /// Retrieves the logical properties for the given group from the memo /// and sends them back to the requestor through the provided channel. 
/// diff --git a/optd/src/optimizer/jobs/execute.rs b/optd/src/optimizer/jobs/execute.rs index 0a795517..8e4612d3 100644 --- a/optd/src/optimizer/jobs/execute.rs +++ b/optd/src/optimizer/jobs/execute.rs @@ -1,4 +1,4 @@ -use super::{CostedContinuation, JobId, LogicalContinuation}; +use super::{JobId, LogicalContinuation}; use crate::{ cir::*, dsl::{ @@ -9,12 +9,10 @@ use crate::{ optimizer::{ EngineProduct, Optimizer, OptimizerMessage, hir_cir::{ - from_cir::{ - partial_logical_to_value, partial_physical_to_value, physical_properties_to_value, - }, + from_cir::{partial_logical_to_value, physical_properties_to_value}, into_cir::{ - hir_goal_to_cir, hir_group_id_to_cir, value_to_cost, value_to_logical_properties, - value_to_partial_logical, value_to_partial_physical, + hir_group_id_to_cir, value_to_logical_properties, value_to_partial_logical, + value_to_partial_physical, }, }, }, @@ -175,44 +173,6 @@ impl Optimizer { Ok(()) } - /// Executes a job to compute the cost of a physical expression. - /// - /// This creates an engine instance and launches the cost calculation process - /// for the specified physical expression. - /// - /// # Parameters - /// * `expression_id`: The ID of the physical expression to cost. - /// * `job_id`: The ID of the job to be executed. 
- pub(super) async fn execute_cost_expression( - &self, - expression_id: PhysicalExpressionId, - job_id: JobId, - ) -> Result<(), M::MemoError> { - use EngineProduct::*; - - let engine = self.init_engine(); - let plan = partial_physical_to_value(&self.egest_partial_plan(expression_id).await?); - - let message_tx = self.message_tx.clone(); - tokio::spawn(async move { - let response = engine - .launch( - "cost", - vec![plan], - Arc::new(move |cost| { - Box::pin( - async move { NewCostedPhysical(expression_id, value_to_cost(&cost)) }, - ) - }), - ) - .await; - - Self::process_engine_response(job_id, message_tx, response).await; - }); - - Ok(()) - } - /// Executes a job to continue processing with a logical expression result. /// /// This materializes the logical expression and passes it to the continuation. @@ -251,32 +211,6 @@ impl Optimizer { Ok(()) } - /// Executes a job to continue processing with an optimized physical expression result. - /// - /// This materializes the physical expression and passes it along with its cost - /// to the continuation. - /// - /// # Parameters - /// * `expression_id`: The ID of the physical expression to continue with. - /// * `k`: The continuation function to be called with the materialized plan. - /// * `job_id`: The ID of the job to be executed. - pub(super) async fn execute_continue_with_costed( - &self, - expression_id: PhysicalExpressionId, - k: CostedContinuation, - job_id: JobId, - ) -> Result<(), M::MemoError> { - let plan = partial_physical_to_value(&self.egest_partial_plan(expression_id).await?); - - let message_tx = self.message_tx.clone(); - tokio::spawn(async move { - let response = k.0(plan).await; - Self::process_engine_response(job_id, message_tx, response).await; - }); - - Ok(()) - } - /// Helper function to process the engine response and send it to the optimizer. 
/// /// Handles `YieldGroup` and `YieldGoal` responses by sending subscription messages @@ -300,10 +234,7 @@ impl Optimizer { SubscribeGroup(hir_group_id_to_cir(&group_id), LogicalContinuation(k)), job_id, ), - YieldGoal(goal, k) => OptimizerMessage::product( - SubscribeGoal(hir_goal_to_cir(&goal), CostedContinuation(k)), - job_id, - ), + YieldGoal(_, _) => todo!("Decide what to do here depending on the cost model"), }; engine_tx diff --git a/optd/src/optimizer/jobs/manage.rs b/optd/src/optimizer/jobs/manage.rs index ee8d1d10..607567a3 100644 --- a/optd/src/optimizer/jobs/manage.rs +++ b/optd/src/optimizer/jobs/manage.rs @@ -63,17 +63,10 @@ impl Optimizer { self.execute_implementation_rule(rule_name, expression_id, goal_id, job_id) .await?; } - CostExpression(expression_id) => { - self.execute_cost_expression(expression_id, job_id).await?; - } ContinueWithLogical(expression_id, k) => { self.execute_continue_with_logical(expression_id, k, job_id) .await?; } - ContinueWithCosted(expression_id, k) => { - self.execute_continue_with_costed(expression_id, k, job_id) - .await?; - } } } diff --git a/optd/src/optimizer/jobs/mod.rs b/optd/src/optimizer/jobs/mod.rs index 409ab3ba..78968b3d 100644 --- a/optd/src/optimizer/jobs/mod.rs +++ b/optd/src/optimizer/jobs/mod.rs @@ -1,9 +1,6 @@ use super::{EngineProduct, TaskId}; use crate::{ - cir::{ - GoalId, GroupId, ImplementationRule, LogicalExpressionId, PhysicalExpressionId, - TransformationRule, - }, + cir::{GoalId, GroupId, ImplementationRule, LogicalExpressionId, TransformationRule}, dsl::{ analyzer::hir::Value, engine::{Continuation, EngineResponse}, @@ -29,10 +26,6 @@ pub(crate) struct Job(pub TaskId, pub JobKind); #[derive(Clone)] pub(crate) struct LogicalContinuation(Continuation>); -/// Represents a continuation for processing costed physical expressions. -#[derive(Clone)] -pub(crate) struct CostedContinuation(Continuation>); - /// Enumeration of different types of jobs in the optimizer. 
/// /// Each variant represents a specific optimization operation that can be @@ -55,26 +48,11 @@ pub(crate) enum JobKind { /// /// This job generates physical implementations of a logical expression /// based on specific implementation strategies. - #[allow(dead_code)] ImplementExpression(ImplementationRule, LogicalExpressionId, GoalId), - /// Starts computing the cost of a physical expression. - /// - /// This job estimates the execution cost of a physical implementation - /// to aid in selecting the optimal plan. - #[allow(dead_code)] - CostExpression(PhysicalExpressionId), - /// Continues processing with a logical expression result. /// /// This job represents a continuation-passing-style callback for /// handling the result of a logical expression operation. ContinueWithLogical(LogicalExpressionId, LogicalContinuation), - - /// Continues processing with an optimized expression result. - /// - /// This job represents a continuation-passing-style callback for - /// handling the result of an optimized physical expression operation. - #[allow(dead_code)] - ContinueWithCosted(PhysicalExpressionId, CostedContinuation), } diff --git a/optd/src/optimizer/memo_io/egest.rs b/optd/src/optimizer/memo_io/egest.rs index 4c6a4fba..0a87ca47 100644 --- a/optd/src/optimizer/memo_io/egest.rs +++ b/optd/src/optimizer/memo_io/egest.rs @@ -7,6 +7,7 @@ use async_recursion::async_recursion; use futures::future::try_join_all; use std::sync::Arc; +// TODO: Until costing is resolved, this file is not used (and just left here for reference) impl Optimizer { /// Recursively transforms a physical expression ID in the memo into a complete physical plan. /// @@ -21,7 +22,6 @@ impl Optimizer { /// * `Ok(None)` if any goal ID lacks a best expression ID. /// * `Err(Error)` if a memo operation fails. 
#[async_recursion] - #[allow(dead_code)] pub(crate) async fn egest_best_plan( &self, expression_id: PhysicalExpressionId, diff --git a/optd/src/optimizer/memo_io/ingest.rs b/optd/src/optimizer/memo_io/ingest.rs index ab648a0d..898bbd48 100644 --- a/optd/src/optimizer/memo_io/ingest.rs +++ b/optd/src/optimizer/memo_io/ingest.rs @@ -151,7 +151,6 @@ impl Optimizer { } } - #[allow(dead_code)] async fn probe_ingest_physical_operator( &mut self, operator: &Operator>, diff --git a/optd/src/optimizer/memo_io/mod.rs b/optd/src/optimizer/memo_io/mod.rs index 2fa128dd..d18f2dce 100644 --- a/optd/src/optimizer/memo_io/mod.rs +++ b/optd/src/optimizer/memo_io/mod.rs @@ -1,6 +1,3 @@ -mod egest; mod ingest; -#[allow(unused)] -pub(super) use egest::*; pub(super) use ingest::*; diff --git a/optd/src/optimizer/mod.rs b/optd/src/optimizer/mod.rs index de38da2f..3b13cd92 100644 --- a/optd/src/optimizer/mod.rs +++ b/optd/src/optimizer/mod.rs @@ -17,7 +17,7 @@ mod merge; mod retriever; mod tasks; -use jobs::{CostedContinuation, Job, JobId, LogicalContinuation}; +use jobs::{Job, JobId, LogicalContinuation}; use retriever::OptimizerRetriever; use tasks::{Task, TaskId}; @@ -49,17 +49,11 @@ enum EngineProduct { /// New physical implementation for a goal, awaiting recursive optimization. NewPhysicalPartial(PartialPhysicalPlan, GoalId), - /// Fully optimized physical expression with complete costing. - NewCostedPhysical(PhysicalExpressionId, Cost), - /// Create a new group with the provided logical properties. CreateGroup(LogicalExpressionId, LogicalProperties), /// Subscribe to logical expressions in a specific group. SubscribeGroup(GroupId, LogicalContinuation), - - /// Subscribe to costed physical expressions for a goal. - SubscribeGoal(Goal, CostedContinuation), } /// Messages passed within the optimization system. 
@@ -250,18 +244,12 @@ impl Optimizer { NewPhysicalPartial(plan, goal_id) => { self.process_new_physical_partial(plan, goal_id, job_id).await?; } - NewCostedPhysical(expression_id, cost) => { - self.process_new_costed_physical(expression_id, cost).await?; - } CreateGroup(expression_id, properties) => { self.process_create_group(expression_id, &properties, job_id).await?; } SubscribeGroup(group_id, continuation) => { self.process_group_subscription(group_id, continuation, job_id).await?; } - SubscribeGoal(goal, continuation) => { - self.process_goal_subscription(&goal, continuation, job_id).await?; - } } } diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index b67311fc..62506769 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -359,6 +359,6 @@ impl Optimizer { &mut self, _expression_id: PhysicalExpressionId, ) -> Result { - todo!() + todo!("What do we decide to do with costing is an open question"); } } diff --git a/optd/src/optimizer/tasks/manage.rs b/optd/src/optimizer/tasks/manage.rs index 3cf528ff..b80119b7 100644 --- a/optd/src/optimizer/tasks/manage.rs +++ b/optd/src/optimizer/tasks/manage.rs @@ -1,7 +1,6 @@ use super::{ - ContinueWithCostedTask, ContinueWithLogicalTask, CostExpressionTask, ExploreGroupTask, - ForkCostedTask, ForkLogicalTask, ImplementExpressionTask, OptimizeGoalTask, OptimizePlanTask, - Task, TaskId, TransformExpressionTask, + ContinueWithLogicalTask, ExploreGroupTask, ForkLogicalTask, ImplementExpressionTask, + OptimizeGoalTask, OptimizePlanTask, Task, TaskId, TransformExpressionTask, }; use crate::{memo::Memo, optimizer::Optimizer}; @@ -36,7 +35,6 @@ impl Optimizer { // Direct task type access by ID - immutable versions /// Get a task as an OptimizePlanTask by its ID. 
- #[allow(dead_code)] pub(crate) fn get_optimize_plan_task(&self, task_id: TaskId) -> Option<&OptimizePlanTask> { self.get_task(task_id).and_then(|task| match &task { Task::OptimizePlan(task) => Some(task), @@ -45,7 +43,6 @@ impl Optimizer { } /// Get a task as an OptimizeGoalTask by its ID. - #[allow(dead_code)] pub(crate) fn get_optimize_goal_task(&self, task_id: TaskId) -> Option<&OptimizeGoalTask> { self.get_task(task_id).and_then(|task| match &task { Task::OptimizeGoal(task) => Some(task), @@ -62,7 +59,6 @@ impl Optimizer { } /// Get a task as an ImplementExpressionTask by its ID. - #[allow(dead_code)] pub(crate) fn get_implement_expression_task( &self, task_id: TaskId, @@ -84,15 +80,6 @@ impl Optimizer { }) } - /// Get a task as a CostExpressionTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_cost_expression_task(&self, task_id: TaskId) -> Option<&CostExpressionTask> { - self.get_task(task_id).and_then(|task| match &task { - Task::CostExpression(task) => Some(task), - _ => None, - }) - } - /// Get a task as a ForkLogicalTask by its ID. pub(crate) fn get_fork_logical_task(&self, task_id: TaskId) -> Option<&ForkLogicalTask> { self.get_task(task_id).and_then(|task| match &task { @@ -101,15 +88,6 @@ impl Optimizer { }) } - /// Get a task as a ForkCostedTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_fork_costed_task(&self, task_id: TaskId) -> Option<&ForkCostedTask> { - self.get_task(task_id).and_then(|task| match &task { - Task::ForkCosted(task) => Some(task), - _ => None, - }) - } - /// Get a task as a ContinueWithLogicalTask by its ID. pub(crate) fn get_continue_with_logical_task( &self, @@ -121,22 +99,9 @@ impl Optimizer { }) } - /// Get a task as a ContinueWithCostedTask by its ID. 
- #[allow(dead_code)] - pub(crate) fn get_continue_with_costed_task( - &self, - task_id: TaskId, - ) -> Option<&ContinueWithCostedTask> { - self.get_task(task_id).and_then(|task| match &task { - Task::ContinueWithCosted(task) => Some(task), - _ => None, - }) - } - // Direct task type access by ID - mutable versions /// Get a mutable task as an OptimizePlanTask by its ID. - #[allow(dead_code)] pub(crate) fn get_optimize_plan_task_mut( &mut self, task_id: TaskId, @@ -170,7 +135,6 @@ impl Optimizer { } /// Get a mutable task as an ImplementExpressionTask by its ID. - #[allow(dead_code)] pub(crate) fn get_implement_expression_task_mut( &mut self, task_id: TaskId, @@ -192,18 +156,6 @@ impl Optimizer { }) } - /// Get a mutable task as a CostExpressionTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_cost_expression_task_mut( - &mut self, - task_id: TaskId, - ) -> Option<&mut CostExpressionTask> { - self.get_task_mut(task_id).and_then(|task| match task { - Task::CostExpression(task) => Some(task), - _ => None, - }) - } - /// Get a mutable task as a ForkLogicalTask by its ID. pub(crate) fn get_fork_logical_task_mut( &mut self, @@ -215,18 +167,6 @@ impl Optimizer { }) } - /// Get a mutable task as a ForkCostedTask by its ID. - #[allow(dead_code)] - pub(crate) fn get_fork_costed_task_mut( - &mut self, - task_id: TaskId, - ) -> Option<&mut ForkCostedTask> { - self.get_task_mut(task_id).and_then(|task| match task { - Task::ForkCosted(task) => Some(task), - _ => None, - }) - } - /// Get a mutable task as a ContinueWithLogicalTask by its ID. pub(crate) fn get_continue_with_logical_task_mut( &mut self, @@ -237,16 +177,4 @@ impl Optimizer { _ => None, }) } - - /// Get a mutable task as a ContinueWithCostedTask by its ID. 
- #[allow(dead_code)] - pub(crate) fn get_continue_with_costed_task_mut( - &mut self, - task_id: TaskId, - ) -> Option<&mut ContinueWithCostedTask> { - self.get_task_mut(task_id).and_then(|task| match task { - Task::ContinueWithCosted(task) => Some(task), - _ => None, - }) - } } diff --git a/optd/src/optimizer/tasks/mod.rs b/optd/src/optimizer/tasks/mod.rs index c020e4fd..ffba21e7 100644 --- a/optd/src/optimizer/tasks/mod.rs +++ b/optd/src/optimizer/tasks/mod.rs @@ -1,7 +1,7 @@ -use super::jobs::{CostedContinuation, LogicalContinuation}; +use super::jobs::LogicalContinuation; use crate::cir::{ - Cost, GoalId, GroupId, ImplementationRule, LogicalExpressionId, LogicalPlan, - PhysicalExpressionId, PhysicalPlan, TransformationRule, + GoalId, GroupId, ImplementationRule, LogicalExpressionId, LogicalPlan, PhysicalPlan, + TransformationRule, }; use hashbrown::HashSet; use tokio::sync::mpsc::Sender; @@ -23,17 +23,10 @@ pub(crate) enum Task { OptimizePlan(OptimizePlanTask), OptimizeGoal(OptimizeGoalTask), ExploreGroup(ExploreGroupTask), - #[allow(dead_code)] ImplementExpression(ImplementExpressionTask), TransformExpression(TransformExpressionTask), - #[allow(dead_code)] - CostExpression(CostExpressionTask), ForkLogical(ForkLogicalTask), - #[allow(dead_code)] - ForkCosted(ForkCostedTask), ContinueWithLogical(ContinueWithLogicalTask), - #[allow(dead_code)] - ContinueWithCosted(ContinueWithCostedTask), } //============================================================================= @@ -121,7 +114,6 @@ pub(crate) struct TransformExpressionTask { /// Task to implement a logical expression into a physical expression. #[derive(Clone)] -#[allow(dead_code)] pub(crate) struct ImplementExpressionTask { /// The implementation rule to apply. pub rule: ImplementationRule, @@ -138,25 +130,6 @@ pub(crate) struct ImplementExpressionTask { pub fork_in: Option, } -/// Task to cost a physical expression. 
-#[derive(Clone)] -#[allow(dead_code)] -pub(crate) struct CostExpressionTask { - /// The physical expression to cost. - pub expression_id: PhysicalExpressionId, - /// The current upper bound on the allowed cost budget. - pub budget: Cost, - - // Output tasks that get fed by the output of this task. - /// `OptimizeGoalTask` corresponding goal optimization task. - pub optimize_goal_out: HashSet, - - // Input tasks that feed this task. - /// `ForkCostedTask` cost fork points encountered during the - /// costing. - pub fork_in: Option, -} - /// Task to fork the logical optimization process. #[derive(Clone)] pub(crate) struct ForkLogicalTask { @@ -177,27 +150,6 @@ pub(crate) struct ForkLogicalTask { pub continue_with_logical_in: HashSet, } -/// Task to fork the costed optimization process. -#[derive(Clone)] -#[allow(dead_code)] -pub(crate) struct ForkCostedTask { - /// The fork continuation. - pub continuation: CostedContinuation, - /// The current upper bound on the allowed cost budget. - pub budget: Cost, - - /// `ContinueWithCostedTask` | `CostExpressionTask` that gets fed by the - /// output of this task. - pub out: TaskId, - - // Input tasks that feed this task. - /// `OptimizeGoalTask` corresponding goal optimization task producing - /// costed expressions. - pub optimize_goal_in: TaskId, - /// `ContinueWithCosted` tasks spawned off and producing for this task. - pub continue_with_costed_in: HashSet, -} - /// Task to continue with a logical expression. #[derive(Clone)] pub(crate) struct ContinueWithLogicalTask { @@ -209,16 +161,3 @@ pub(crate) struct ContinueWithLogicalTask { /// Potential `ForkLogicalTask` fork spawned off from this task. pub fork_in: Option, } - -/// Task to continue with a costed expression. -#[derive(Clone)] -#[allow(dead_code)] -pub(crate) struct ContinueWithCostedTask { - /// The physical expression to continue with. - pub expression_id: PhysicalExpressionId, - - /// `ForkCostedTask` that gets fed by the output of this continuation. 
- pub fork_out: TaskId, - /// Potential `ForkCostedTask` fork spawned off from this task. - pub fork_in: Option, -} From 973efc632997676155d705537b169c65d32ed9eb Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 13:22:26 -0400 Subject: [PATCH 04/23] fmt --- optd/src/memo/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index bf19b011..024cdfc2 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -207,7 +207,7 @@ pub trait Memo: Representative + Materialize + Sync + 'static { // // Physical expression and goal operations. // - + /// Gets all members of a goal, which can be physical expressions or other goals. /// /// # Parameters From a01ff0ecf31d81df543c74c39a6cd0ec82d664a5 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 15:51:12 -0400 Subject: [PATCH 05/23] fix regression and update memo struct --- optd/src/memo/memory/helpers.rs | 22 +++++++------ optd/src/memo/memory/implementation.rs | 43 ++++++++++++-------------- optd/src/memo/memory/materialize.rs | 16 ++++++---- optd/src/memo/memory/mod.rs | 21 ++++++++----- optd/src/memo/mod.rs | 4 +-- 5 files changed, 58 insertions(+), 48 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index 4bf5896c..ad6298eb 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -35,21 +35,25 @@ impl MemoryMemo { PhysicalExpressionId(id) } + /// Takes the set of [`LogicalExpressionId`] that reference a group, mapped to their + /// representatives. + fn take_referencing_expr_set(&mut self, group_id: GroupId) -> HashSet { + self.group_referencing_exprs_index + .remove(&group_id) + .unwrap_or_default() + .iter() + .map(|id| self.repr_logical_expr_id.find(id)) + .collect() + } + /// Merges the two sets of logical expressions that reference the two groups into a single set /// of expressions under a new [`GroupId`]. 
/// /// If a group does not exist, then the set of expressions referencing it is the empty set. fn merge_referencing_exprs(&mut self, group1: GroupId, group2: GroupId, new_group: GroupId) { // Remove the entries for the original two groups that we want to merge. - let exprs1 = self - .group_referencing_exprs_index - .remove(&group1) - .unwrap_or_default(); - let exprs2 = self - .group_referencing_exprs_index - .remove(&group2) - .unwrap_or_default(); - + let exprs1 = self.take_referencing_expr_set(group1); + let exprs2 = self.take_referencing_expr_set(group2); let new_set = exprs1.union(&exprs2).copied().collect(); // Update the index for the new group / set of logical expressions. diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index 0cd9619c..39d6b9e3 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -111,8 +111,6 @@ impl Memo for MemoryMemo { group_id_1: GroupId, group_id_2: GroupId, ) -> Result { - println!("Merging: {:?} and {:?}", group_id_1, group_id_2); - let mut merge_operations = vec![]; let mut pending_merges = VecDeque::from(vec![(group_id_1, group_id_2)]); @@ -124,7 +122,7 @@ impl Memo for MemoryMemo { continue; } - // Perform the group merge, creating a new representative + // Perform the group merge, creating a new representative. let (new_group_id, merge_product) = self.merge_group_pair(group_id_1, group_id_2).await?; merge_operations.push(merge_product); @@ -145,9 +143,22 @@ impl Memo for MemoryMemo { async fn get_all_goal_members( &self, - _goal_id: GoalId, - ) -> Result, Infallible> { - todo!() + goal_id: GoalId, + ) -> Result, Infallible> { + let repr_goal_id = self.find_repr_goal_id(goal_id).await?; + + // Get all members and map each to its representative ID. 
+ let mut result = HashSet::new(); + for member_id in &self + .goal_info + .get(&repr_goal_id) + .expect("Goal not found in memo table") + .members + { + result.insert(self.find_repr_goal_member_id(*member_id).await?); + } + + Ok(result) } async fn add_goal_member( @@ -158,25 +169,11 @@ impl Memo for MemoryMemo { let repr_goal_id = self.find_repr_goal_id(goal_id).await?; let repr_member_id = self.find_repr_goal_member_id(member_id).await?; - // Check if the member is already in the goal (fast path). - if self - .id_to_goal_members - .get(&repr_goal_id) - .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_goal_id)) - .contains(&repr_member_id) - { - return Ok(false); - } - - // Otherwise, add the member to the goal (slow path). - self.id_to_goal_members + self.goal_info .get_mut(&repr_goal_id) - .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_goal_id)) + .expect("Goal not found in memo table") + .members .insert(repr_member_id); - self.goal_member_to_goals_index - .entry(repr_member_id) - .or_default() - .insert(repr_goal_id); Ok(false) } diff --git a/optd/src/memo/memory/materialize.rs b/optd/src/memo/memory/materialize.rs index c800cef2..fe4195ec 100644 --- a/optd/src/memo/memory/materialize.rs +++ b/optd/src/memo/memory/materialize.rs @@ -2,7 +2,7 @@ //! //! See the documentation for [`Materialize`] for more information. -use super::{Infallible, MemoryMemo, helpers::MemoryMemoHelper}; +use super::{GoalInfo, Infallible, MemoryMemo, helpers::MemoryMemoHelper}; use crate::{ cir::*, memo::{Materialize, Representative}, @@ -70,8 +70,12 @@ impl Materialize for MemoryMemo { // Otherwise, create a new entry in the memo table (slow path). 
let goal_id = self.next_goal_id(); - self.id_to_goal.insert(goal_id, goal.clone().into_owned()); - self.id_to_goal_members.insert(goal_id, HashSet::new()); + let goal_info = GoalInfo { + goal: goal.as_ref().clone(), + members: HashSet::new(), + }; + + self.goal_info.insert(goal_id, goal_info); self.goal_to_id.insert(goal.clone().into_owned(), goal_id); // Connect the goal to its group. @@ -88,9 +92,10 @@ impl Materialize for MemoryMemo { async fn materialize_goal(&self, goal_id: GoalId) -> Result { let repr_goal_id = self.find_repr_goal_id(goal_id).await?; Ok(self - .id_to_goal + .goal_info .get(&repr_goal_id) - .unwrap_or_else(|| panic!("{:?} not found in memo table", repr_goal_id)) + .expect("Goal not found in memo table") + .goal .clone()) } @@ -127,7 +132,6 @@ impl Materialize for MemoryMemo { // Update the physical expression ID indexes. self.id_to_physical_expr .insert(expr_id, remapped_expr.clone().into_owned()); - self.id_to_cost.insert(expr_id, None); self.physical_expr_to_id .insert(remapped_expr.into_owned(), expr_id); diff --git a/optd/src/memo/memory/mod.rs b/optd/src/memo/memory/mod.rs index f2eb4235..4a04eab3 100644 --- a/optd/src/memo/memory/mod.rs +++ b/optd/src/memo/memory/mod.rs @@ -31,9 +31,7 @@ pub struct MemoryMemo { // Goals. /// Key is always a representative ID. - id_to_goal: HashMap, - /// The members (physical expressions & sub-goals) inside each goal. - id_to_goal_members: HashMap>, + goal_info: HashMap, /// Each representative goal is mapped to its id, for faster lookups. goal_to_id: HashMap, @@ -45,25 +43,21 @@ pub struct MemoryMemo { /// Key is always a representative ID. id_to_physical_expr: HashMap, - /// The cost of each representative physical expression, if available. - id_to_cost: HashMap>, /// Each representative expression is mapped to its id, for faster lookups. physical_expr_to_id: HashMap, // Indexes: only deal with representative IDs, but speeds up most queries. 
/// To speed up expr->group lookup, we maintain a mapping from logical expression IDs to group IDs. logical_id_to_group_index: HashMap, + /// To speed up recursive merges, we maintain a mapping from group IDs to all logical expression IDs /// that contain a reference to this group. /// The value logical_expr_ids may *NOT* be a representative ID. group_referencing_exprs_index: HashMap>, - /// To speed up recursive merges, we maintain a mapping from goal member IDs to all goal member IDs /// that contain a reference to this goal member. /// The value physical_expr_ids may *NOT* be a representative ID. goal_member_referencing_exprs_index: HashMap>, - /// Similar idea for sub-goals. - goal_member_to_goals_index: HashMap>, /// The shared next unique id to be used for goals, groups, logical expressions, and physical expressions. next_shared_id: i64, @@ -85,3 +79,14 @@ struct GroupInfo { goals: HashMap, logical_properties: LogicalProperties, } + +/// Information about a goal: +/// - The goal (group + properties), always representative. +/// - The members of the goal (physical expression IDs or other (sub)goal IDs), may *NOT* be representative. +#[derive(Clone, Debug)] +struct GoalInfo { + /// The goal (group + properties) + goal: Goal, + /// The members of the goal (physical expression IDs or other sub-goal IDs). + members: HashSet, +} diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index 024cdfc2..d2a21397 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -214,11 +214,11 @@ pub trait Memo: Representative + Materialize + Sync + 'static { /// * `goal_id` - ID of the goal to retrieve members from. /// /// # Returns - /// A vector of goal members, each being either a physical expression ID or another goal ID. + /// A set of goal members, each being either a physical expression ID or another goal ID. async fn get_all_goal_members( &self, goal_id: GoalId, - ) -> Result, Self::MemoError>; + ) -> Result, Self::MemoError>; /// Adds a member to a goal. 
/// From 5c5ad3182b3b29bfb2cd3443e83967fa6803653b Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 17:30:28 -0400 Subject: [PATCH 06/23] Start implementing goal merges --- optd/src/memo/memory/helpers.rs | 72 +++++++++++++++++++++----- optd/src/memo/memory/implementation.rs | 36 ++++++++++++- optd/src/memo/memory/materialize.rs | 2 +- optd/src/memo/memory/mod.rs | 7 +-- 4 files changed, 97 insertions(+), 20 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index ad6298eb..2873ab9a 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -4,7 +4,7 @@ use super::{Infallible, MemoryMemo}; use crate::{ cir::*, memo::{ - Materialize, Memo, MergeGroupProduct, MergeProducts, Representative, memory::GroupInfo, + Materialize, Memo, MergeGoalProduct, MergeGroupProduct, Representative, memory::GroupInfo, }, }; use hashbrown::{HashMap, HashSet}; @@ -102,10 +102,15 @@ pub trait MemoryMemoHelper: Memo { new_group_id: GroupId, ) -> Result, Infallible>; - async fn consolidate_merge_results( + async fn consolidate_merge_group_products( &self, merge_operations: Vec, - ) -> Result; + ) -> Result, Infallible>; + + async fn process_goal_merges( + &mut self, + group_merges: &[MergeGroupProduct], + ) -> Result, Infallible>; } impl MemoryMemoHelper for MemoryMemo { @@ -294,6 +299,18 @@ impl MemoryMemoHelper for MemoryMemo { .copied() .collect(); + // Combine goals from both groups. + let all_goals: HashMap<_, Vec<_>> = group1_info + .goals + .into_iter() + .chain(group2_info.goals) + .fold(HashMap::new(), |mut acc, (props, mut goals)| { + acc.entry(props) + .and_modify(|existing_goals| existing_goals.append(&mut goals)) + .or_insert(goals); + acc + }); + // Update the union-find structure. self.repr_group_id.merge(&group_id_1, &new_group_id); self.repr_group_id.merge(&group_id_2, &new_group_id); @@ -306,7 +323,7 @@ impl MemoryMemoHelper for MemoryMemo { // Create and save the new group. 
let new_group_info = GroupInfo { expressions: all_exprs, - goals: HashMap::new(), // TODO(Alexis): this will be used to trigger the goal merge process. + goals: all_goals, logical_properties: group1_info.logical_properties, }; self.group_info.insert(new_group_id, new_group_info); @@ -451,16 +468,16 @@ impl MemoryMemoHelper for MemoryMemo { Ok(new_pending_merges) } - /// Consolidates merge operations into a comprehensive result. + /// Consolidates merge group operations into a comprehensive result. /// - /// This function takes a list of individual merge operations and consolidates them into a - /// complete picture of which groups were merged into each final representative. + /// This function takes a list of individual group merge operations and consolidates + /// them into a complete picture of which groups were merged into each final representative. /// /// It also handles cases where past representatives themselves are merged into newer ones. - async fn consolidate_merge_results( + async fn consolidate_merge_group_products( &self, merge_operations: Vec, - ) -> Result { + ) -> Result, Infallible> { // Collect operations into a map from representative to all merged groups. let mut consolidated_map: HashMap> = HashMap::new(); @@ -491,10 +508,37 @@ impl MemoryMemoHelper for MemoryMemo { }) .collect(); - Ok(MergeProducts { - group_merges, - goal_merges: vec![], - expr_merges: vec![], - }) + Ok(group_merges) + } + + /// Processes goal merges and returns the results. + /// + /// This function consolidates the merge operations for goals and returns a list of + /// `MergeGoalProduct` instances representing the merged goals. + /// + /// # Parameters + /// + /// * `group_merges` - A slice of `MergeGroupProduct` instances representing the merged groups. + /// + /// # Returns + /// + /// A vector of `MergeGoalProduct` instances representing the merged goals. 
+ async fn process_goal_merges( + &mut self, + group_merges: &[MergeGroupProduct], + ) -> Result, Infallible> { + let mut merge_products = Vec::new(); + + for (goal_id, goal_info) in self.goal_info.iter() { + let repr_goal_id = self.find_repr_goal_id(*goal_id).await?; + if repr_goal_id != *goal_id { + merge_products.push(MergeGoalProduct { + new_goal_id: repr_goal_id, + merged_goals: vec![*goal_id], + }); + } + } + + Ok(merge_products) } } diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index 39d6b9e3..6a9058c8 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -90,12 +90,31 @@ impl Memo for MemoryMemo { /// 3. When expressions become identical, their containing groups also need to be merged, /// creating a cascading effect. /// + /// 4. After all logical groups have been merged, we need to identify and merge all goals + /// that have become equivalent. Goals are considered equivalent when they reference the + /// same group and share identical logical properties. + /// + /// 5. When goals are merged, this triggers a cascading effect on physical expressions: + /// - Physical expressions that reference merged goals need updating. + /// - This may cause previously distinct physical expressions to become identical. + /// - These newly identical expressions must then be merged. + /// - Each merge may create more identical expressions, continuing the chain reaction + /// until the entire structure is consistent. + /// /// To handle this complexity, we use an iterative approach: + /// PHASE 1: LOGICAL /// - We maintain a queue of group pairs that need to be merged. /// - For each pair, we create a new representative group. /// - We update all expressions that reference the merged groups. /// - If this creates new equivalences, we add the affected groups to the merge queue. /// - We continue until no more merges are needed. 
+ /// PHASE 2: PHYSICAL + /// - We inspect the result of logical merges, and identify the goals to merge in + /// the corresponding group_info of the new representative group. + /// - For each group, we create a new representative goal. + /// - We update all expressions that reference the merged goals. + /// - If this creates new equivalences, we add the affected goals to the merge queue. + /// - We continue until no more merges are needed. /// /// This approach ensures that all cascading effects are properly handled and the memo /// structure remains consistent after the merge. @@ -136,9 +155,22 @@ impl Memo for MemoryMemo { pending_merges.extend(new_pending_merges); } - // Consolidate the merge results by replacing the incremental merges + // Consolidate the merge products by replacing the incremental merges // with consolidated results that show the full picture. - self.consolidate_merge_results(merge_operations).await + let group_merges = self + .consolidate_merge_group_products(merge_operations) + .await?; + + // Now handle goal merges: we do not need to pass any extra parameters as + // the goals to merge are gathered in the `goals` member of each new + // representative group. + let goal_merges = self.process_goal_merges(&group_merges).await?; + + Ok(MergeProducts { + group_merges, + goal_merges, + expr_merges: vec![], // TODO(Alexis): Implement expression merges. 
+ }) } async fn get_all_goal_members( diff --git a/optd/src/memo/memory/materialize.rs b/optd/src/memo/memory/materialize.rs index fe4195ec..aa351073 100644 --- a/optd/src/memo/memory/materialize.rs +++ b/optd/src/memo/memory/materialize.rs @@ -84,7 +84,7 @@ impl Materialize for MemoryMemo { .get_mut(&group_id) .expect("Group not found in memo table") .goals - .insert(props, goal_id); + .insert(props, vec![goal_id]); Ok(goal_id) } diff --git a/optd/src/memo/memory/mod.rs b/optd/src/memo/memory/mod.rs index 4a04eab3..f9af48fc 100644 --- a/optd/src/memo/memory/mod.rs +++ b/optd/src/memo/memory/mod.rs @@ -76,7 +76,10 @@ pub struct MemoryMemo { #[derive(Clone, Debug)] struct GroupInfo { expressions: HashSet, - goals: HashMap, + // We make the key a Vec so that we can accumulate + // goals to merge while merging groups. Outside of merging, + // the value is always a single goal. + goals: HashMap>, logical_properties: LogicalProperties, } @@ -85,8 +88,6 @@ struct GroupInfo { /// - The members of the goal (physical expression IDs or other (sub)goal IDs), may *NOT* be representative. #[derive(Clone, Debug)] struct GoalInfo { - /// The goal (group + properties) goal: Goal, - /// The members of the goal (physical expression IDs or other sub-goal IDs). members: HashSet, } From a58cd325430aeda9050b20def7f057359b2b0f4a Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 18:46:52 -0400 Subject: [PATCH 07/23] Merge goals completed --- optd/src/memo/memory/helpers.rs | 105 ++++++++++++++++++------- optd/src/memo/memory/implementation.rs | 37 +++++---- optd/src/memo/memory/mod.rs | 2 +- 3 files changed, 98 insertions(+), 46 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index 2873ab9a..aa01c543 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -68,7 +68,7 @@ impl MemoryMemo { /// /// See the implementation itself for the documentation of each helper method. 
pub trait MemoryMemoHelper: Memo { - async fn find_repr_goal_member_id(&self, id: GoalMemberId) -> Result; + fn find_repr_goal_member_id(&self, id: GoalMemberId) -> GoalMemberId; async fn remap_logical_expr<'a>( &self, @@ -118,17 +118,17 @@ impl MemoryMemoHelper for MemoryMemo { /// /// This handles both variants of `GoalMemberId` by finding the appropriate representative /// based on whether it's a Goal or PhysicalExpr. - async fn find_repr_goal_member_id(&self, id: GoalMemberId) -> Result { + fn find_repr_goal_member_id(&self, id: GoalMemberId) -> GoalMemberId { use GoalMemberId::*; match id { GoalId(goal_id) => { - let repr_goal_id = self.find_repr_goal_id(goal_id).await?; - Ok(GoalId(repr_goal_id)) + let repr_goal_id = self.repr_goal_id.find(&goal_id); + GoalId(repr_goal_id) } PhysicalExpressionId(expr_id) => { - let repr_expr_id = self.find_repr_physical_expr_id(expr_id).await?; - Ok(PhysicalExpressionId(repr_expr_id)) + let repr_expr_id = self.repr_physical_expr_id.find(&expr_id); + PhysicalExpressionId(repr_expr_id) } } } @@ -198,27 +198,33 @@ impl MemoryMemoHelper for MemoryMemo { use Child::*; let mut needs_remapping = false; - let mut remapped_children = Vec::with_capacity(physical_expr.children.len()); - for child in &physical_expr.children { - let remapped = match child { + let remapped_children = physical_expr + .children + .iter() + .map(|child| match child { Singleton(goal_id) => { - let repr = self.find_repr_goal_member_id(*goal_id).await?; - needs_remapping |= repr != *goal_id; + let repr = self.find_repr_goal_member_id(*goal_id); + if repr != *goal_id { + needs_remapping = true; + } Singleton(repr) } VarLength(goal_ids) => { - let mut reprs = Vec::with_capacity(goal_ids.len()); - for goal_id in goal_ids { - let repr = self.find_repr_goal_member_id(*goal_id).await?; - needs_remapping |= repr != *goal_id; - reprs.push(repr); - } + let reprs: Vec<_> = goal_ids + .iter() + .map(|id| { + let repr = self.find_repr_goal_member_id(*id); + if repr != *id { 
+ needs_remapping = true; + } + repr + }) + .collect(); VarLength(reprs) } - }; - remapped_children.push(remapped); - } + }) + .collect(); Ok(if needs_remapping { Cow::Owned(PhysicalExpression { @@ -527,18 +533,59 @@ impl MemoryMemoHelper for MemoryMemo { &mut self, group_merges: &[MergeGroupProduct], ) -> Result, Infallible> { - let mut merge_products = Vec::new(); - - for (goal_id, goal_info) in self.goal_info.iter() { - let repr_goal_id = self.find_repr_goal_id(*goal_id).await?; - if repr_goal_id != *goal_id { - merge_products.push(MergeGoalProduct { - new_goal_id: repr_goal_id, - merged_goals: vec![*goal_id], + let mut goal_merges = Vec::new(); + + for merge_product in group_merges { + let new_group_id = merge_product.new_group_id; + let group_info = self.group_info.get_mut(&new_group_id).unwrap(); + + // Take related goals to the group to avoid borrowing issues. + let related_goals = std::mem::take(&mut group_info.goals); + let mut updated_goals = HashMap::new(); + + for (physical_props, merged_goals) in related_goals { + // Create a new representative goal for this physical property. + let new_goal = Goal(new_group_id, physical_props.clone()); + let new_goal_id = self.get_goal_id(&new_goal).await?; + + // Process each goal being merged. + for &old_goal_id in &merged_goals { + // Remove old goal info and extend new goal with its members. + let old_goal_info = self.goal_info.remove(&old_goal_id).unwrap(); + let new_goal_info = self.goal_info.get_mut(&new_goal_id).unwrap(); + + new_goal_info.members.extend(old_goal_info.members); + self.goal_to_id.remove(&old_goal_info.goal); + + // Update referencing expressions index. + let refs = self + .goal_member_referencing_exprs_index + .remove(&GoalMemberId::GoalId(old_goal_id)) + .unwrap_or_default(); + + self.goal_member_referencing_exprs_index + .entry(GoalMemberId::GoalId(new_goal_id)) + .or_default() + .extend(refs); + + // Update the union-find structure for goal representatives. 
+ self.repr_goal_id.merge(&old_goal_id, &new_goal_id); + } + + // Update the goals mapping for this physical property. + updated_goals.insert(physical_props, vec![new_goal_id]); + + // Record the merge operation. + goal_merges.push(MergeGoalProduct { + new_goal_id, + merged_goals, }); } + + // Update the group's goals with the consolidated mapping. + self.group_info.get_mut(&new_group_id).unwrap().goals = updated_goals; } - Ok(merge_products) + Ok(goal_merges) } } diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index 6a9058c8..3fed9782 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -179,18 +179,16 @@ impl Memo for MemoryMemo { ) -> Result, Infallible> { let repr_goal_id = self.find_repr_goal_id(goal_id).await?; - // Get all members and map each to its representative ID. - let mut result = HashSet::new(); - for member_id in &self + let members = self .goal_info .get(&repr_goal_id) .expect("Goal not found in memo table") .members - { - result.insert(self.find_repr_goal_member_id(*member_id).await?); - } + .iter() + .map(|&member_id| self.find_repr_goal_member_id(member_id)) + .collect(); - Ok(result) + Ok(members) } async fn add_goal_member( @@ -198,16 +196,23 @@ impl Memo for MemoryMemo { goal_id: GoalId, member_id: GoalMemberId, ) -> Result { - let repr_goal_id = self.find_repr_goal_id(goal_id).await?; - let repr_member_id = self.find_repr_goal_member_id(member_id).await?; - - self.goal_info - .get_mut(&repr_goal_id) - .expect("Goal not found in memo table") - .members - .insert(repr_member_id); + // We call `get_all_goal_members` to ensure we only have the representative IDs + // in the set. This is important because members may have been merged with another. 
+ let mut current_members = self.get_all_goal_members(goal_id).await?; + + let repr_member_id = self.find_repr_goal_member_id(member_id); + let added = current_members.insert(repr_member_id); + + if added { + let repr_goal_id = self.find_repr_goal_id(goal_id).await?; + self.goal_info + .get_mut(&repr_goal_id) + .expect("Goal not found in memo table") + .members + .insert(repr_member_id); + } - Ok(false) + Ok(added) } } diff --git a/optd/src/memo/memory/mod.rs b/optd/src/memo/memory/mod.rs index f9af48fc..a09dc405 100644 --- a/optd/src/memo/memory/mod.rs +++ b/optd/src/memo/memory/mod.rs @@ -54,7 +54,7 @@ pub struct MemoryMemo { /// that contain a reference to this group. /// The value logical_expr_ids may *NOT* be a representative ID. group_referencing_exprs_index: HashMap>, - /// To speed up recursive merges, we maintain a mapping from goal member IDs to all goal member IDs + /// To speed up recursive merges, we maintain a mapping from goal member IDs to all physical expression IDs /// that contain a reference to this goal member. /// The value physical_expr_ids may *NOT* be a representative ID. 
goal_member_referencing_exprs_index: HashMap>, From dedf201a3c9cbb1cbf11d8962de2b80fddd86bab Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 19:00:43 -0400 Subject: [PATCH 08/23] Start implementing physical expr merge --- optd/src/memo/memory/helpers.rs | 36 ++++++++++++++++++++++---- optd/src/memo/memory/implementation.rs | 8 ++++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index aa01c543..92905a45 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -4,7 +4,8 @@ use super::{Infallible, MemoryMemo}; use crate::{ cir::*, memo::{ - Materialize, Memo, MergeGoalProduct, MergeGroupProduct, Representative, memory::GroupInfo, + Materialize, Memo, MergeGoalProduct, MergeGroupProduct, MergePhysicalExprProduct, + Representative, memory::GroupInfo, }, }; use hashbrown::{HashMap, HashSet}; @@ -107,10 +108,15 @@ pub trait MemoryMemoHelper: Memo { merge_operations: Vec, ) -> Result, Infallible>; - async fn process_goal_merges( + async fn merge_dependent_goals( &mut self, group_merges: &[MergeGroupProduct], ) -> Result, Infallible>; + + async fn merge_dependent_physical_exprs( + &mut self, + goal_merges: &[MergeGoalProduct], + ) -> Result, Infallible>; } impl MemoryMemoHelper for MemoryMemo { @@ -519,8 +525,8 @@ impl MemoryMemoHelper for MemoryMemo { /// Processes goal merges and returns the results. /// - /// This function consolidates the merge operations for goals and returns a list of - /// `MergeGoalProduct` instances representing the merged goals. + /// This function performs the merge operations for goals based on the merged groups and + /// returns a list of `MergeGoalProduct` instances representing the merged goals. /// /// # Parameters /// @@ -529,7 +535,7 @@ impl MemoryMemoHelper for MemoryMemo { /// # Returns /// /// A vector of `MergeGoalProduct` instances representing the merged goals. 
- async fn process_goal_merges( + async fn merge_dependent_goals( &mut self, group_merges: &[MergeGroupProduct], ) -> Result, Infallible> { @@ -588,4 +594,24 @@ impl MemoryMemoHelper for MemoryMemo { Ok(goal_merges) } + + /// Processes physical expression merges and returns the results. + /// + /// This function performs the merge operations for physical expressions based on + /// the merged goals and returns a list of `MergePhysicalExprProduct` instances + /// representing the merged physical expressions. + /// + /// # Parameters + /// + /// * `goal_merges` - A slice of `MergeGoalProduct` instances representing the merged goals. + /// + /// # Returns + /// + /// A vector of `MergePhysicalExprProduct` instances representing the merged physical expressions. + async fn merge_dependent_physical_exprs( + &mut self, + goal_merges: &[MergeGoalProduct], + ) -> Result, Infallible> { + Ok(vec![]) + } } diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index 3fed9782..d68cf80d 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -164,12 +164,16 @@ impl Memo for MemoryMemo { // Now handle goal merges: we do not need to pass any extra parameters as // the goals to merge are gathered in the `goals` member of each new // representative group. - let goal_merges = self.process_goal_merges(&group_merges).await?; + let goal_merges = self.merge_dependent_goals(&group_merges).await?; + + // Finally, we need to recursively merge the physical expressions that are + // dependent on the merged goals (and the recursively merged expressions themselves). + let expr_merges = self.merge_dependent_physical_exprs(&goal_merges).await?; Ok(MergeProducts { group_merges, goal_merges, - expr_merges: vec![], // TODO(Alexis): Implement expression merges. 
+ expr_merges, }) } From f89d6973ff64bd5d4c62b0ed1c57fa26c688e567 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Sun, 18 May 2025 20:07:40 -0400 Subject: [PATCH 09/23] Finish physical expr merges --- optd/src/memo/memory/helpers.rs | 130 +++++++++++++++++++++++-- optd/src/memo/memory/implementation.rs | 3 +- 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index 92905a45..0eb9eff3 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -9,7 +9,7 @@ use crate::{ }, }; use hashbrown::{HashMap, HashSet}; -use std::borrow::Cow; +use std::{borrow::Cow, collections::VecDeque}; impl MemoryMemo { pub(super) fn next_group_id(&mut self) -> GroupId { @@ -51,7 +51,12 @@ impl MemoryMemo { /// of expressions under a new [`GroupId`]. /// /// If a group does not exist, then the set of expressions referencing it is the empty set. - fn merge_referencing_exprs(&mut self, group1: GroupId, group2: GroupId, new_group: GroupId) { + fn merge_referencing_logical_exprs( + &mut self, + group1: GroupId, + group2: GroupId, + new_group: GroupId, + ) { // Remove the entries for the original two groups that we want to merge. 
let exprs1 = self.take_referencing_expr_set(group1); let exprs2 = self.take_referencing_expr_set(group2); @@ -96,7 +101,7 @@ pub trait MemoryMemoHelper: Memo { prev_expr: LogicalExpression, ) -> Result, Infallible>; - async fn process_referencing_expressions( + async fn process_referencing_logical_exprs( &mut self, group_id_1: GroupId, group_id_2: GroupId, @@ -113,6 +118,11 @@ pub trait MemoryMemoHelper: Memo { group_merges: &[MergeGroupProduct], ) -> Result, Infallible>; + async fn consolidate_merge_physical_expr_products( + &self, + merge_operations: Vec, + ) -> Result, Infallible>; + async fn merge_dependent_physical_exprs( &mut self, goal_merges: &[MergeGoalProduct], @@ -415,7 +425,7 @@ impl MemoryMemoHelper for MemoryMemo { } } - /// Processes expressions that reference the merged groups. + /// Processes logical expressions that reference the merged groups. /// /// Returns the group pairs that need to be merged due to new equivalences. /// @@ -438,15 +448,15 @@ impl MemoryMemoHelper for MemoryMemo { /// b. Check if the remapped expression is different from the original. /// c. If different, handle the change and check for new equivalences. /// 4. Return any new group pairs that need to be merged. - async fn process_referencing_expressions( + async fn process_referencing_logical_exprs( &mut self, group_id_1: GroupId, group_id_2: GroupId, new_group_id: GroupId, ) -> Result, Infallible> { - // Merge the set of expressions that reference these two groups into one that references the - // new group. - self.merge_referencing_exprs(group_id_1, group_id_2, new_group_id); + // Merge the set of expressions that reference these two groups into one + // that references the new group. + self.merge_referencing_logical_exprs(group_id_1, group_id_2, new_group_id); // We need to clone here because we are modifying our `self` state inside the loop. // TODO: This is an inefficiency. 
This referencing index shouldn't be modified in the loop @@ -595,6 +605,51 @@ impl MemoryMemoHelper for MemoryMemo { Ok(goal_merges) } + /// Consolidates merge physical expression operations into a comprehensive result. + /// + /// This function takes a list of individual physical expression merge operations + /// and consolidates them into a complete picture of which physical expressions + /// were merged into each final representative. + /// + /// It also handles cases where past representatives themselves are merged into newer ones. + async fn consolidate_merge_physical_expr_products( + &self, + merge_operations: Vec, + ) -> Result, Infallible> { + // Collect operations into a map from representative to all merged physical expressions. + let mut consolidated_map: HashMap> = + HashMap::new(); + + for op in merge_operations { + let current_repr = self.repr_physical_expr_id.find(&op.new_physical_expr_id); + + consolidated_map + .entry(current_repr) + .or_default() + .extend(op.merged_physical_exprs.iter().copied()); + + if op.new_physical_expr_id != current_repr { + consolidated_map + .entry(current_repr) + .or_default() + .insert(op.new_physical_expr_id); + } + } + + // Build the final list of merge products from the consolidated map. + let physical_expr_merges = consolidated_map + .into_iter() + .filter_map(|(repr, exprs)| { + (!exprs.is_empty()).then(|| MergePhysicalExprProduct { + new_physical_expr_id: repr, + merged_physical_exprs: exprs.into_iter().collect(), + }) + }) + .collect(); + + Ok(physical_expr_merges) + } + /// Processes physical expression merges and returns the results. 
/// /// This function performs the merge operations for physical expressions based on @@ -612,6 +667,63 @@ impl MemoryMemoHelper for MemoryMemo { &mut self, goal_merges: &[MergeGoalProduct], ) -> Result, Infallible> { - Ok(vec![]) + let mut physical_expr_merges = Vec::new(); + let mut pending_dependencies = VecDeque::new(); + + // Initialize pending dependencies with the input goal merges. + for goal_merge in goal_merges { + pending_dependencies.push_back(GoalMemberId::GoalId(goal_merge.new_goal_id)); + } + + // Process dependencies in a loop to handle cascading merges. + while let Some(current_dependency) = pending_dependencies.pop_front() { + let referencing_exprs = self + .goal_member_referencing_exprs_index + .get(¤t_dependency) + .cloned() + .unwrap_or_default(); + + // Process each expression that references this dependency. + for reference_id in referencing_exprs { + // Remap the expression to use updated member references. + let prev_expr = self.materialize_physical_expr(reference_id).await?; + let new_expr = self.remap_physical_expr(&prev_expr).await?; + let new_id = self.get_physical_expr_id(&new_expr).await?; + + if reference_id != new_id { + // Remove old expression from indexes. + self.id_to_physical_expr.remove(&reference_id); + self.physical_expr_to_id.remove(&prev_expr); + + // Update the goal member referencing expressions index + // for the new physical expr. + let old_refs = self + .goal_member_referencing_exprs_index + .remove(&GoalMemberId::PhysicalExpressionId(reference_id)) + .unwrap_or_default(); + self.goal_member_referencing_exprs_index + .entry(GoalMemberId::PhysicalExpressionId(new_id)) + .or_default() + .extend(old_refs); + + // Update the representative ID for the expression. + self.repr_physical_expr_id.merge(&reference_id, &new_id); + + // This handles cascading effects where merging + // one expression affects others. + pending_dependencies.push_back(GoalMemberId::PhysicalExpressionId(new_id)); + + // Record the merge. 
+ physical_expr_merges.push(MergePhysicalExprProduct { + new_physical_expr_id: new_id, + merged_physical_exprs: vec![reference_id], + }); + } + } + } + + // Consolidate the merge products to ensure no duplicates. + self.consolidate_merge_physical_expr_products(physical_expr_merges) + .await } } diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index d68cf80d..ab63a01d 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -108,6 +108,7 @@ impl Memo for MemoryMemo { /// - We update all expressions that reference the merged groups. /// - If this creates new equivalences, we add the affected groups to the merge queue. /// - We continue until no more merges are needed. + /// /// PHASE 2: PHYSICAL /// - We inspect the result of logical merges, and identify the goals to merge in /// the corresponding group_info of the new representative group. @@ -149,7 +150,7 @@ impl Memo for MemoryMemo { // Process expressions that reference the merged groups, // which may trigger additional group merges. 
let new_pending_merges = self - .process_referencing_expressions(group_id_1, group_id_2, new_group_id) + .process_referencing_logical_exprs(group_id_1, group_id_2, new_group_id) .await?; pending_merges.extend(new_pending_merges); From 0be6463598f293c7dc031d688e395e04b1fc295e Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Mon, 19 May 2025 10:22:45 -0400 Subject: [PATCH 10/23] Test physical side of the memo table --- optd/src/memo/memory/helpers.rs | 45 ++--- optd/src/memo/memory/implementation.rs | 260 ++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 38 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index 0eb9eff3..63df5c91 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -35,37 +35,6 @@ impl MemoryMemo { self.next_shared_id += 1; PhysicalExpressionId(id) } - - /// Takes the set of [`LogicalExpressionId`] that reference a group, mapped to their - /// representatives. - fn take_referencing_expr_set(&mut self, group_id: GroupId) -> HashSet { - self.group_referencing_exprs_index - .remove(&group_id) - .unwrap_or_default() - .iter() - .map(|id| self.repr_logical_expr_id.find(id)) - .collect() - } - - /// Merges the two sets of logical expressions that reference the two groups into a single set - /// of expressions under a new [`GroupId`]. - /// - /// If a group does not exist, then the set of expressions referencing it is the empty set. - fn merge_referencing_logical_exprs( - &mut self, - group1: GroupId, - group2: GroupId, - new_group: GroupId, - ) { - // Remove the entries for the original two groups that we want to merge. - let exprs1 = self.take_referencing_expr_set(group1); - let exprs2 = self.take_referencing_expr_set(group2); - let new_set = exprs1.union(&exprs2).copied().collect(); - - // Update the index for the new group / set of logical expressions. 
- self.group_referencing_exprs_index - .insert(new_group, new_set); - } } /// Helper functions for the in-memory memo table implementation. @@ -456,7 +425,19 @@ impl MemoryMemoHelper for MemoryMemo { ) -> Result, Infallible> { // Merge the set of expressions that reference these two groups into one // that references the new group. - self.merge_referencing_logical_exprs(group_id_1, group_id_2, new_group_id); + let exprs1 = self + .group_referencing_exprs_index + .remove(&group_id_1) + .unwrap_or_default(); + let exprs2 = self + .group_referencing_exprs_index + .remove(&group_id_2) + .unwrap_or_default(); + let new_set = exprs1.union(&exprs2).copied().collect(); + + // Update the index for the new group / set of logical expressions. + self.group_referencing_exprs_index + .insert(new_group_id, new_set); // We need to clone here because we are modifying our `self` state inside the loop. // TODO: This is an inefficiency. This referencing index shouldn't be modified in the loop diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index ab63a01d..82ee550b 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -91,15 +91,15 @@ impl Memo for MemoryMemo { /// creating a cascading effect. /// /// 4. After all logical groups have been merged, we need to identify and merge all goals - /// that have become equivalent. Goals are considered equivalent when they reference the - /// same group and share identical logical properties. + /// that have become equivalent. Goals are considered equivalent when they reference the + /// same group and share identical logical properties. /// /// 5. When goals are merged, this triggers a cascading effect on physical expressions: /// - Physical expressions that reference merged goals need updating. /// - This may cause previously distinct physical expressions to become identical. /// - These newly identical expressions must then be merged. 
/// - Each merge may create more identical expressions, continuing the chain reaction - /// until the entire structure is consistent. + /// until the entire structure is consistent. /// /// To handle this complexity, we use an iterative approach: /// PHASE 1: LOGICAL @@ -108,10 +108,10 @@ impl Memo for MemoryMemo { /// - We update all expressions that reference the merged groups. /// - If this creates new equivalences, we add the affected groups to the merge queue. /// - We continue until no more merges are needed. - /// + /// /// PHASE 2: PHYSICAL /// - We inspect the result of logical merges, and identify the goals to merge in - /// the corresponding group_info of the new representative group. + /// the corresponding group_info of the new representative group. /// - For each group, we create a new representative goal. /// - We update all expressions that reference the merged goals. /// - If this creates new equivalences, we add the affected goals to the merge queue. @@ -224,7 +224,10 @@ impl Memo for MemoryMemo { #[cfg(test)] pub mod tests { use super::*; - use crate::cir::{Child, OperatorData}; + use crate::{ + cir::{Child, OperatorData}, + memo::Materialize, + }; pub async fn lookup_or_insert( memo: &mut impl Memo, @@ -316,6 +319,28 @@ pub mod tests { assert_eq!(retrieve(&memo, g4).await, vec![0, 1, 2, 3]); } + async fn create_goal( + memo: &mut MemoryMemo, + group_id: GroupId, + props: PhysicalProperties, + ) -> GoalId { + let goal = Goal(group_id, props); + memo.get_goal_id(&goal).await.unwrap() + } + + async fn create_physical_expr( + memo: &mut MemoryMemo, + tag: &str, + children: Vec, + ) -> PhysicalExpressionId { + let expr = PhysicalExpression { + tag: tag.to_string(), + data: vec![], + children: children.into_iter().map(Child::Singleton).collect(), + }; + memo.get_physical_expr_id(&expr).await.unwrap() + } + #[tokio::test] async fn test_recursive_merge() { let mut memo = MemoryMemo::default(); @@ -593,4 +618,227 @@ pub mod tests { "Merged group should 
contain exactly one copy of each expression" ); } + + #[tokio::test] + async fn test_goal_merge() { + let mut memo = MemoryMemo::default(); + + // Create two groups. + let g1 = lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = lookup_or_insert(&mut memo, 2, vec![]).await; + + // Create goals with the same properties. + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props).await; + + // Add a member to each goal. + let p1 = create_physical_expr(&mut memo, "a", vec![]).await; + let p2 = create_physical_expr(&mut memo, "b", vec![]).await; + + memo.add_goal_member(goal1, GoalMemberId::PhysicalExpressionId(p1)) + .await + .unwrap(); + memo.add_goal_member(goal2, GoalMemberId::PhysicalExpressionId(p2)) + .await + .unwrap(); + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // Verify goal merges in the result. + assert!( + !merge_result.goal_merges.is_empty(), + "Goal merges should not be empty" + ); + + // Get the new goal and verify members. + let new_goal_id = merge_result.goal_merges[0].new_goal_id; + let members = memo.get_all_goal_members(new_goal_id).await.unwrap(); + + assert_eq!(members.len(), 2, "Merged goal should have two members"); + assert!(members.contains(&GoalMemberId::PhysicalExpressionId(p1))); + assert!(members.contains(&GoalMemberId::PhysicalExpressionId(p2))); + + // Verify representatives. + assert_eq!(memo.find_repr_goal_id(goal1).await.unwrap(), new_goal_id); + assert_eq!(memo.find_repr_goal_id(goal2).await.unwrap(), new_goal_id); + } + + #[tokio::test] + async fn test_physical_expr_merge() { + let mut memo = MemoryMemo::default(); + + // Create two groups and goals. 
+ let g1 = lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = lookup_or_insert(&mut memo, 2, vec![]).await; + + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props).await; + + // Create physical expressions referencing these goals. + let p1 = create_physical_expr(&mut memo, "op", vec![GoalMemberId::GoalId(goal1)]).await; + let p2 = create_physical_expr(&mut memo, "op", vec![GoalMemberId::GoalId(goal2)]).await; + + // Verify they're different before merge. + assert_ne!(p1, p2); + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // Verify physical expression merges in the result. + assert!( + !merge_result.expr_merges.is_empty(), + "Physical expr merges should not be empty" + ); + + // Get the new physical expression id. + let new_expr_id = merge_result.expr_merges[0].new_physical_expr_id; + + // Verify representatives. + assert_eq!( + memo.find_repr_physical_expr_id(p1).await.unwrap(), + new_expr_id + ); + assert_eq!( + memo.find_repr_physical_expr_id(p2).await.unwrap(), + new_expr_id + ); + } + + #[tokio::test] + async fn test_recursive_physical_expr_merge() { + let mut memo = MemoryMemo::default(); + + // Create groups. + let g1 = lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = lookup_or_insert(&mut memo, 2, vec![]).await; + + // Create goals. + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props.clone()).await; + + // Create first level physical expressions. + let p1 = create_physical_expr(&mut memo, "op1", vec![GoalMemberId::GoalId(goal1)]).await; + let p2 = create_physical_expr(&mut memo, "op1", vec![GoalMemberId::GoalId(goal2)]).await; + + // Create second level physical expressions. 
+ let p3 = create_physical_expr( + &mut memo, + "op2", + vec![GoalMemberId::PhysicalExpressionId(p1)], + ) + .await; + let p4 = create_physical_expr( + &mut memo, + "op2", + vec![GoalMemberId::PhysicalExpressionId(p2)], + ) + .await; + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // There should be multiple levels of physical expr merges. + assert!( + merge_result.expr_merges.len() >= 2, + "Should have at least 2 levels of physical expr merges" + ); + + // Verify representatives for both levels. + let p1_repr = memo.find_repr_physical_expr_id(p1).await.unwrap(); + let p2_repr = memo.find_repr_physical_expr_id(p2).await.unwrap(); + let p3_repr = memo.find_repr_physical_expr_id(p3).await.unwrap(); + let p4_repr = memo.find_repr_physical_expr_id(p4).await.unwrap(); + + assert_eq!( + p1_repr, p2_repr, + "First level expressions should share representative" + ); + assert_eq!( + p3_repr, p4_repr, + "Second level expressions should share representative" + ); + } + + #[tokio::test] + async fn test_merge_products_completeness() { + let mut memo = MemoryMemo::default(); + + // Create a three-level structure. 
+ let g1 = super::tests::lookup_or_insert(&mut memo, 1, vec![]).await; + let g2 = super::tests::lookup_or_insert(&mut memo, 2, vec![]).await; + + let props = PhysicalProperties(None); + let goal1 = create_goal(&mut memo, g1, props.clone()).await; + let goal2 = create_goal(&mut memo, g2, props.clone()).await; + + // Level 1 + let p1 = create_physical_expr(&mut memo, "leaf1", vec![GoalMemberId::GoalId(goal1)]).await; + let p2 = create_physical_expr(&mut memo, "leaf1", vec![GoalMemberId::GoalId(goal2)]).await; + + // Level 2 + let p3 = create_physical_expr( + &mut memo, + "mid1", + vec![GoalMemberId::PhysicalExpressionId(p1)], + ) + .await; + let p4 = create_physical_expr( + &mut memo, + "mid1", + vec![GoalMemberId::PhysicalExpressionId(p2)], + ) + .await; + + // Level 3 + let p5 = create_physical_expr( + &mut memo, + "top1", + vec![GoalMemberId::PhysicalExpressionId(p3)], + ) + .await; + let p6 = create_physical_expr( + &mut memo, + "top1", + vec![GoalMemberId::PhysicalExpressionId(p4)], + ) + .await; + + // Merge the groups. + let merge_result = memo.merge_groups(g1, g2).await.unwrap(); + + // Verify structure of the merge products. + assert_eq!( + merge_result.group_merges.len(), + 1, + "Should have 1 group merge" + ); + assert_eq!( + merge_result.goal_merges.len(), + 1, + "Should have 1 goal merge" + ); + assert_eq!( + merge_result.expr_merges.len(), + 3, + "Should have 3 expression merges for the three levels" + ); + + // Verify all expressions share the same representatives at their respective levels. 
+ assert_eq!( + memo.find_repr_physical_expr_id(p1).await.unwrap(), + memo.find_repr_physical_expr_id(p2).await.unwrap(), + ); + assert_eq!( + memo.find_repr_physical_expr_id(p3).await.unwrap(), + memo.find_repr_physical_expr_id(p4).await.unwrap(), + ); + assert_eq!( + memo.find_repr_physical_expr_id(p5).await.unwrap(), + memo.find_repr_physical_expr_id(p6).await.unwrap(), + ); + } } From 827a2eb8acd6bf2c6e01c7282fc4a6622eeb3c1a Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 15:25:59 -0400 Subject: [PATCH 11/23] Add * handling and better retrieve props --- optd-cli/src/main.rs | 20 ++---- optd/src/demo/demo.opt | 22 +++++- optd/src/demo/mod.rs | 37 ++++++++-- optd/src/dsl/analyzer/from_ast/converter.rs | 26 ++----- optd/src/dsl/analyzer/hir/mod.rs | 32 ++++++--- optd/src/dsl/engine/eval/expr.rs | 46 +++++++------ optd/src/dsl/engine/eval/match.rs | 66 +++++++++++------- optd/src/memo/memory/implementation.rs | 2 + optd/src/optimizer/hir_cir/from_cir.rs | 76 +++++++++++++-------- optd/src/optimizer/jobs/execute.rs | 16 ++++- optd/src/optimizer/jobs/manage.rs | 4 +- optd/src/optimizer/jobs/mod.rs | 2 +- optd/src/optimizer/merge/helpers.rs | 8 ++- optd/src/optimizer/tasks/launch.rs | 16 +++-- 14 files changed, 238 insertions(+), 135 deletions(-) diff --git a/optd-cli/src/main.rs b/optd-cli/src/main.rs index cdad6e73..8b9cbccf 100644 --- a/optd-cli/src/main.rs +++ b/optd-cli/src/main.rs @@ -34,13 +34,12 @@ use clap::{Parser, Subcommand}; use colored::Colorize; -use optd::catalog::Catalog; use optd::catalog::iceberg::memory_catalog; use optd::dsl::analyzer::hir::{CoreData, HIR, Udf, Value}; use optd::dsl::compile::{Config, compile_hir}; use optd::dsl::engine::{Continuation, Engine, EngineResponse}; use optd::dsl::utils::errors::{CompileError, Diagnose}; -use optd::dsl::utils::retriever::{MockRetriever, Retriever}; +use optd::dsl::utils::retriever::MockRetriever; use std::collections::HashMap; use std::sync::Arc; use tokio::runtime::Builder; @@ 
-66,22 +65,17 @@ enum Commands { RunFunctions(Config), } -/// A unimplemented user-defined function. -pub fn unimplemented_udf( - _args: &[Value], - _catalog: &dyn Catalog, - _retriever: &dyn Retriever, -) -> Value { - println!("This user-defined function is unimplemented!"); - Value::new(CoreData::::None) -} - fn main() -> Result<(), Vec> { let cli = Cli::parse(); let mut udfs = HashMap::new(); let udf = Udf { - func: unimplemented_udf, + func: Arc::new(|_, _, _| { + Box::pin(async move { + println!("This user-defined function is unimplemented!"); + Value::new(CoreData::::None) + }) + }), }; udfs.insert("unimplemented_udf".to_string(), udf.clone()); diff --git a/optd/src/demo/demo.opt b/optd/src/demo/demo.opt index 2245f82b..15e63e7e 100644 --- a/optd/src/demo/demo.opt +++ b/optd/src/demo/demo.opt @@ -1,7 +1,9 @@ data Physical data PhysicalProperties data Statistics -data LogicalProperties +// Taking folded here is not the most interesting property, +// but it ensures they are the same for all expressions in the same group. +data LogicalProperties(folded: I64) data Logical = | Add(left: Logical, right: Logical) @@ -10,6 +12,8 @@ data Logical = | Div(left: Logical, right: Logical) \ Const(val: I64) +// This will be the input plan that will be optimized. +// Result is: ((1 - 2) * (3 / 4)) + ((5 - 6) * (7 / 8)) = 0 fn input(): Logical = Add( Mult( @@ -22,9 +26,21 @@ fn input(): Logical = ) ) -// TODO(Alexis): This should be $ really, make costing and derive consistent with each other. +// External function to allow the retrieval of properties. +fn properties(op: Logical*): LogicalProperties + +// FIXME: This should be $ really (or other), make costing and derive consistent with each other. // Also, be careful of not forking in there! And make it a required function in analyzer. 
-fn derive(op: Logical) = LogicalProperties +fn derive(op: Logical*) = match op + | Add(left, right) -> + LogicalProperties(left.properties()#folded + right.properties()#folded) + | Sub(left, right) -> + LogicalProperties(left.properties()#folded - right.properties()#folded) + | Mult(left, right) -> + LogicalProperties(left.properties()#folded * right.properties()#folded) + | Div(left, right) -> + LogicalProperties(left.properties()#folded / right.properties()#folded) + \ Const(val) -> LogicalProperties(val) [transformation] fn (op: Logical*) mult_commute(): Logical? = match op diff --git a/optd/src/demo/mod.rs b/optd/src/demo/mod.rs index 48f720c4..ce1f568f 100644 --- a/optd/src/demo/mod.rs +++ b/optd/src/demo/mod.rs @@ -1,10 +1,10 @@ use crate::{ - catalog::iceberg::memory_catalog, + catalog::{Catalog, iceberg::memory_catalog}, dsl::{ - analyzer::hir::Value, + analyzer::hir::{CoreData, LogicalOp, Materializable, Udf, Value}, compile::{Config, compile_hir}, engine::{Continuation, Engine, EngineResponse}, - utils::retriever::MockRetriever, + utils::retriever::{MockRetriever, Retriever}, }, memo::MemoryMemo, optimizer::{OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical}, @@ -12,10 +12,39 @@ use crate::{ use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::{sync::mpsc, time::timeout}; +pub async fn properties( + args: Vec, + _catalog: Arc, + retriever: Arc, +) -> Value { + let arg = args[0].clone(); + let group_id = match &arg.data { + CoreData::Logical(Materializable::Materialized(LogicalOp { group_id, .. })) => { + group_id.unwrap() + } + CoreData::Logical(Materializable::UnMaterialized(group_id)) => *group_id, + _ => panic!("Expected a logical plan"), + }; + + retriever.get_properties(group_id).await +} + async fn run_demo() { // Compile the HIR. let config = Config::new("src/demo/demo.opt".into()); - let udfs = HashMap::new(); + + // Create a properties UDF. 
+ let properties_udf = Udf { + func: Arc::new(|args, catalog, retriever| { + Box::pin(async move { properties(args, catalog, retriever).await }) + }), + }; + + // Create the UDFs HashMap. + let mut udfs = HashMap::new(); + udfs.insert("properties".to_string(), properties_udf); + + // Compile with the config and UDFs. let hir = compile_hir(config, udfs).unwrap(); // Create necessary components. diff --git a/optd/src/dsl/analyzer/from_ast/converter.rs b/optd/src/dsl/analyzer/from_ast/converter.rs index 92d311c8..76297c41 100644 --- a/optd/src/dsl/analyzer/from_ast/converter.rs +++ b/optd/src/dsl/analyzer/from_ast/converter.rs @@ -177,13 +177,12 @@ impl ASTConverter { #[cfg(test)] mod converter_tests { use super::*; - use crate::catalog::Catalog; use crate::dsl::analyzer::from_ast::from_ast; use crate::dsl::analyzer::hir::{CoreData, FunKind}; use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind}; use crate::dsl::parser::ast::{self, Adt, Function, Item, Module, Type as AstType}; - use crate::dsl::utils::retriever::Retriever; use crate::dsl::utils::span::{Span, Spanned}; + use std::sync::Arc; // Helper functions to create test items fn create_test_span() -> Span { @@ -382,19 +381,15 @@ mod converter_tests { let ext_func = create_simple_function("external_function", false); let module = create_module_with_functions(vec![ext_func]); - pub fn external_function( - _args: &[Value], - _catalog: &dyn Catalog, - _retriever: &dyn Retriever, - ) -> Value { - println!("Hello from UDF!"); - Value::new(CoreData::::None) - } - // Link the dummy function. let mut udfs = HashMap::new(); let udf = Udf { - func: external_function, + func: Arc::new(|_, _, _| { + Box::pin(async move { + println!("Hello from UDF!"); + Value::new(CoreData::None) + }) + }), }; udfs.insert("external_function".to_string(), udf); @@ -408,13 +403,6 @@ mod converter_tests { // Check that the function is in the context. 
let func_val = hir.context.lookup("external_function"); assert!(func_val.is_some()); - - // Verify it is the same function pointer. - if let CoreData::Function(FunKind::Udf(udf)) = &func_val.unwrap().data { - assert_eq!(udf.func as usize, external_function as usize); - } else { - panic!("Expected UDF function"); - } } #[test] diff --git a/optd/src/dsl/analyzer/hir/mod.rs b/optd/src/dsl/analyzer/hir/mod.rs index 458e6f8e..92c1cb68 100644 --- a/optd/src/dsl/analyzer/hir/mod.rs +++ b/optd/src/dsl/analyzer/hir/mod.rs @@ -21,7 +21,8 @@ use crate::dsl::utils::retriever::Retriever; use crate::dsl::utils::span::Span; use context::Context; use map::Map; -use std::fmt::Debug; +use std::fmt::{self, Debug}; +use std::pin::Pin; use std::{collections::HashMap, sync::Arc}; pub(crate) mod context; @@ -72,22 +73,33 @@ impl TypedSpan { } } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Udf { - /// The function pointer to the user-defined function. - /// - /// Note that [`Value`]s passed to and returned from this UDF do not have associated metadata. - pub func: fn(&[Value], &dyn Catalog, &dyn Retriever) -> Value, + pub func: Arc< + dyn Fn( + Vec, + Arc, + Arc, + ) -> Pin + Send>> + + Send + + Sync, + >, +} + +impl fmt::Debug for Udf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[udf]") + } } impl Udf { - pub fn call( + pub async fn call( &self, values: &[Value], - catalog: &dyn Catalog, - retriever: &dyn Retriever, + catalog: Arc, + retriever: Arc, ) -> Value { - (self.func)(values, catalog, retriever) + (self.func)(values.to_vec(), catalog, retriever).await } } diff --git a/optd/src/dsl/engine/eval/expr.rs b/optd/src/dsl/engine/eval/expr.rs index 616e7d1d..8cded69b 100644 --- a/optd/src/dsl/engine/eval/expr.rs +++ b/optd/src/dsl/engine/eval/expr.rs @@ -588,7 +588,7 @@ impl Engine { Arc::new(move |arg_values| { Box::pin(capture!([udf, catalog, deriver, k], async move { // Call the UDF with the argument values. 
- let result = udf.call(&arg_values, catalog.as_ref(), deriver.as_ref()); + let result = udf.call(&arg_values, catalog, deriver).await; // Pass the result to the continuation. k(result).await @@ -1143,18 +1143,22 @@ mod tests { // Define a Rust UDF that calculates the sum of array elements let sum_function = Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Array(elements) => { - let mut sum = 0; - for elem in elements { - if let CoreData::Literal(Literal::Int64(value)) = &elem.data { - sum += value; + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Array(elements) => { + let mut sum: i64 = 0; + for elem in elements { + if let CoreData::Literal(Literal::Int64(value)) = &elem.data { + sum += value; + } + } + Value::new(CoreData::Literal(Literal::Int64(sum))) } + _ => panic!("Expected array argument"), } - Value::new(CoreData::Literal(Literal::Int64(sum))) - } - _ => panic!("Expected array argument"), - }, + }) + }), }))); ctx.bind("sum".to_string(), sum_function); @@ -1271,16 +1275,18 @@ mod tests { ctx.bind( "get".to_string(), Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| { - if args.len() != 2 { - panic!("get function requires 2 arguments"); - } + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + if args.len() != 2 { + panic!("get function requires 2 arguments"); + } - match &args[0].data { - CoreData::Map(map) => map.get(&args[1]), - _ => panic!("First argument must be a map"), - } - }, + match &args[0].data { + CoreData::Map(map) => map.get(&args[1]), + _ => panic!("First argument must be a map"), + } + }) + }), }))), ); diff --git a/optd/src/dsl/engine/eval/match.rs b/optd/src/dsl/engine/eval/match.rs index e1554afa..9fc853be 100644 --- a/optd/src/dsl/engine/eval/match.rs +++ b/optd/src/dsl/engine/eval/match.rs @@ -671,12 +671,16 @@ mod tests { ctx.bind( "length".to_string(), 
Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Array(elements) => { - Value::new(CoreData::Literal(int(elements.len() as i64))) - } - _ => panic!("Expected array"), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Array(elements) => { + Value::new(CoreData::Literal(int(elements.len() as i64))) + } + _ => panic!("Expected array"), + } + }) + }), }))), ); @@ -985,12 +989,16 @@ mod tests { // Create a to_string function let to_string_fn = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Function( FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Literal(lit) => { - Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) - } - _ => Value::new(CoreData::Literal(string(""))), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Literal(lit) => { + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) + } + _ => Value::new(CoreData::Literal(string(""))), + } + }) + }), }), ))))); @@ -1098,12 +1106,16 @@ mod tests { ctx.bind( "to_string".to_string(), Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Literal(Literal::Int64(i)) => { - Value::new(CoreData::Literal(string(&i.to_string()))) - } - _ => panic!("Expected integer literal"), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Literal(Literal::Int64(i)) => { + Value::new(CoreData::Literal(string(&i.to_string()))) + } + _ => panic!("Expected integer literal"), + } + }) + }), }))), ); @@ -1232,13 +1244,17 @@ mod tests { ctx.bind( "to_string".to_string(), Value::new(CoreData::Function(FunKind::Udf(Udf { - func: |args, _catalog, _deriver| match &args[0].data { - CoreData::Literal(lit) => { - Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) - } 
- CoreData::Array(_) => Value::new(CoreData::Literal(string(""))), - _ => Value::new(CoreData::Literal(string(""))), - }, + func: Arc::new(|args, _catalog, _retriever| { + Box::pin(async move { + match &args[0].data { + CoreData::Literal(lit) => { + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) + } + CoreData::Array(_) => Value::new(CoreData::Literal(string(""))), + _ => Value::new(CoreData::Literal(string(""))), + } + }) + }), }))), ); diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index 82ee550b..fc948d2b 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -171,6 +171,8 @@ impl Memo for MemoryMemo { // dependent on the merged goals (and the recursively merged expressions themselves). let expr_merges = self.merge_dependent_physical_exprs(&goal_merges).await?; + println!("Group_info: {:?}", self.group_info); + Ok(MergeProducts { group_merges, goal_merges, diff --git a/optd/src/optimizer/hir_cir/from_cir.rs b/optd/src/optimizer/hir_cir/from_cir.rs index 2122b951..35c26765 100644 --- a/optd/src/optimizer/hir_cir/from_cir.rs +++ b/optd/src/optimizer/hir_cir/from_cir.rs @@ -4,34 +4,57 @@ use crate::cir::*; use crate::dsl::analyzer::hir::{ self, CoreData, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, }; -use std::sync::Arc; -/// Converts a [`PartialLogicalPlan`] into a [`Value`]. -pub fn partial_logical_to_value(plan: &PartialLogicalPlan) -> Value { +/// Converts a [`PartialLogicalPlan`] into a [`Value`], optionally associating it with a [`GroupId`]. +/// +/// If the plan is unmaterialized, it will be represented as a [`Value`] containing a group reference. +/// If the plan is materialized, the resulting [`Value`] contains the operator data, and may include +/// a group ID if provided. 
+pub fn partial_logical_to_value(plan: &PartialLogicalPlan, group_id: Option) -> Value { + use Child::*; use Materializable::*; match plan { - PartialLogicalPlan::UnMaterialized(group_id) => { - // For unmaterialized logical operators, we create a `Value` with the group ID. - Value::new(CoreData::Logical(UnMaterialized(hir::GroupId(group_id.0)))) + PartialLogicalPlan::UnMaterialized(gid) => { + // Represent unmaterialized logical operators using their group ID. + Value::new(CoreData::Logical(UnMaterialized(hir::GroupId(gid.0)))) } PartialLogicalPlan::Materialized(node) => { - // For materialized logical operators, we create a `Value` with the operator data. + // Convert each child to a Value recursively. + let children: Vec = node + .children + .iter() + .map(|child| match child { + Singleton(item) => partial_logical_to_value(item, None), + VarLength(items) => Value::new(CoreData::Array( + items + .iter() + .map(|item| partial_logical_to_value(item, None)) + .collect(), + )), + }) + .collect(); + let operator = Operator { tag: node.tag.clone(), data: convert_operator_data_to_values(&node.data), - children: convert_children_to_values(&node.children, partial_logical_to_value), + children, }; - Value::new(CoreData::Logical(Materialized(LogicalOp::logical( - operator, - )))) + let logical_op = if let Some(gid) = group_id { + LogicalOp::stored_logical(operator, cir_group_id_to_hir(&gid)) + } else { + LogicalOp::logical(operator) + }; + + Value::new(CoreData::Logical(Materialized(logical_op))) } } } /// Converts a [`PartialPhysicalPlan`] into a [`Value`]. 
pub fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { + use Child::*; use Materializable::*; match plan { @@ -45,7 +68,19 @@ pub fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { let operator = Operator { tag: node.tag.clone(), data: convert_operator_data_to_values(&node.data), - children: convert_children_to_values(&node.children, partial_physical_to_value), + children: node + .children + .iter() + .map(|child| match child { + Singleton(item) => partial_physical_to_value(item), + VarLength(items) => Value::new(CoreData::Array( + items + .iter() + .map(|item| partial_physical_to_value(item)) + .collect(), + )), + }) + .collect(), }; Value::new(CoreData::Physical(Materialized(PhysicalOp::physical( @@ -87,23 +122,6 @@ fn cir_group_id_to_hir(group_id: &GroupId) -> hir::GroupId { hir::GroupId(group_id.0) } -/// A generic function to convert a slice of children into a vector of [`Value`]s. -fn convert_children_to_values(children: &[Child>], converter: F) -> Vec -where - F: Fn(&T) -> Value, - T: 'static, -{ - children - .iter() - .map(|child| match child { - Child::Singleton(item) => converter(item), - Child::VarLength(items) => Value::new(CoreData::Array( - items.iter().map(|item| converter(item)).collect(), - )), - }) - .collect() -} - /// Converts a slice of [`OperatorData`] into a vector of [`Value`]s. fn convert_operator_data_to_values(data: &[OperatorData]) -> Vec { data.iter().map(operator_data_to_value).collect() diff --git a/optd/src/optimizer/jobs/execute.rs b/optd/src/optimizer/jobs/execute.rs index 8e4612d3..130c3d82 100644 --- a/optd/src/optimizer/jobs/execute.rs +++ b/optd/src/optimizer/jobs/execute.rs @@ -47,6 +47,11 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + None, // FIXME: This is correct, however in the DSL right now the parameter + // of the derive function is a *, to allow its children to be * too. However, + // since the group is not yet created, it isn't really stored yet.
+ // Hence, the input should be a different type, and ideally be made consistent + // with how costing is handled (i.e. with $ and *). ); let message_tx = self.message_tx.clone(); @@ -97,6 +102,7 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + Some(group_id), ); let message_tx = self.message_tx.clone(); @@ -141,6 +147,9 @@ impl Optimizer { ) -> Result<(), M::MemoError> { use EngineProduct::*; + let Goal(group_id, physical_props) = self.memo.materialize_goal(goal_id).await?; + let properties = physical_properties_to_value(&physical_props); + let engine = self.init_engine(); let plan = partial_logical_to_value( &self @@ -148,11 +157,9 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? .into(), + Some(group_id), ); - let Goal(_, physical_props) = self.memo.materialize_goal(goal_id).await?; - let properties = physical_properties_to_value(&physical_props); - let message_tx = self.message_tx.clone(); tokio::spawn(async move { let response = engine @@ -179,11 +186,13 @@ impl Optimizer { /// /// # Parameters /// * `expression_id`: The ID of the logical expression to continue with. + /// * `group_id`: The ID of the group to which the expression belongs. /// * `k`: The continuation function to be called with the materialized plan. /// * `job_id`: The ID of the job to be executed. pub(super) async fn execute_continue_with_logical( &self, expression_id: LogicalExpressionId, + group_id: GroupId, k: LogicalContinuation, job_id: JobId, ) -> Result<(), M::MemoError> { @@ -196,6 +205,7 @@ impl Optimizer { .materialize_logical_expr(expression_id) .await? 
.into(), + Some(group_id), ); let message_tx = self.message_tx.clone(); diff --git a/optd/src/optimizer/jobs/manage.rs b/optd/src/optimizer/jobs/manage.rs index 607567a3..f517c583 100644 --- a/optd/src/optimizer/jobs/manage.rs +++ b/optd/src/optimizer/jobs/manage.rs @@ -63,8 +63,8 @@ impl Optimizer { self.execute_implementation_rule(rule_name, expression_id, goal_id, job_id) .await?; } - ContinueWithLogical(expression_id, k) => { - self.execute_continue_with_logical(expression_id, k, job_id) + ContinueWithLogical(expression_id, group_id, k) => { + self.execute_continue_with_logical(expression_id, group_id, k, job_id) .await?; } } diff --git a/optd/src/optimizer/jobs/mod.rs b/optd/src/optimizer/jobs/mod.rs index 78968b3d..7774ecaf 100644 --- a/optd/src/optimizer/jobs/mod.rs +++ b/optd/src/optimizer/jobs/mod.rs @@ -54,5 +54,5 @@ pub(crate) enum JobKind { /// /// This job represents a continuation-passing-style callback for /// handling the result of a logical expression operation. - ContinueWithLogical(LogicalExpressionId, LogicalContinuation), + ContinueWithLogical(LogicalExpressionId, GroupId, LogicalContinuation), } diff --git a/optd/src/optimizer/merge/helpers.rs b/optd/src/optimizer/merge/helpers.rs index d15c6e75..18742974 100644 --- a/optd/src/optimizer/merge/helpers.rs +++ b/optd/src/optimizer/merge/helpers.rs @@ -39,8 +39,12 @@ impl Optimizer { .continuation .clone(); - let continuation_tasks = - self.create_logical_cont_tasks(&new_exprs, fork_task_id, &continuation); + let continuation_tasks = self.create_logical_cont_tasks( + &new_exprs, + group_id, + fork_task_id, + &continuation, + ); self.get_fork_logical_task_mut(fork_task_id) .unwrap() diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index 62506769..9508ca87 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -96,7 +96,7 @@ impl Optimizer { .dispatched_exprs .clone(); let continue_with_logical_in = - 
self.create_logical_cont_tasks(&expressions, fork_task_id, &continuation); + self.create_logical_cont_tasks(&expressions, group_id, fork_task_id, &continuation); // Create the fork task. let fork_logical_task = ForkLogicalTask { @@ -121,6 +121,7 @@ impl Optimizer { /// /// # Parameters /// * `expression_id`: The ID of the logical expression to continue with. + /// * `group_id`: The ID of the group that this expression belongs to. /// * `fork_out`: The ID of the fork task that this continue task feeds into. /// * `continuation`: The logical continuation to be used. /// @@ -129,6 +130,7 @@ impl Optimizer { pub(crate) fn launch_continue_with_logical_task( &mut self, expression_id: LogicalExpressionId, + group_id: GroupId, fork_out: TaskId, continuation: LogicalContinuation, ) -> TaskId { @@ -144,7 +146,7 @@ impl Optimizer { self.add_task(task_id, ContinueWithLogical(task)); self.schedule_job( task_id, - JobKind::ContinueWithLogical(expression_id, continuation), + JobKind::ContinueWithLogical(expression_id, group_id, continuation), ); task_id @@ -223,6 +225,7 @@ impl Optimizer { /// /// # Arguments /// * `expressions` - The logical expressions to continue with + /// * `group_id` - The group ID these expressions belong to /// * `fork_task_id` - The fork task that these continuations feed into /// * `continuation` - The continuation to apply /// @@ -231,13 +234,18 @@ impl Optimizer { pub(crate) fn create_logical_cont_tasks( &mut self, expressions: &HashSet, + group_id: GroupId, fork_task_id: TaskId, continuation: &LogicalContinuation, ) -> HashSet { let mut continuation_tasks = HashSet::new(); for &expr_id in expressions { - let task_id = - self.launch_continue_with_logical_task(expr_id, fork_task_id, continuation.clone()); + let task_id = self.launch_continue_with_logical_task( + expr_id, + group_id, + fork_task_id, + continuation.clone(), + ); continuation_tasks.insert(task_id); } From 5661f5dfafd7a6bc729474d944a1e011d8efe0a6 Mon Sep 17 00:00:00 2001 From: Alexis 
Schlomer Date: Tue, 20 May 2025 17:05:54 -0400 Subject: [PATCH 12/23] Fix bug in demo --- optd/src/demo/demo.opt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optd/src/demo/demo.opt b/optd/src/demo/demo.opt index 15e63e7e..5e41511c 100644 --- a/optd/src/demo/demo.opt +++ b/optd/src/demo/demo.opt @@ -21,7 +21,7 @@ fn input(): Logical = Div(Const(3), Const(4)) ), Mult( - Sub(Const(5), Const(6)), + Sub(Const(1), Const(2)), Div(Const(7), Const(8)) ) ) From 480ca853f4b5c95792ccd1e823323d2ed19ebf5a Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 17:39:54 -0400 Subject: [PATCH 13/23] launch initial implement tasks --- optd/src/optimizer/tasks/launch.rs | 134 +++++++++++++++++++++++------ 1 file changed, 110 insertions(+), 24 deletions(-) diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index 9508ca87..af550ac7 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -6,8 +6,8 @@ use crate::{ Optimizer, jobs::{JobKind, LogicalContinuation}, tasks::{ - ContinueWithLogicalTask, ExploreGroupTask, ForkLogicalTask, OptimizeGoalTask, - TransformExpressionTask, + ContinueWithLogicalTask, ExploreGroupTask, ForkLogicalTask, ImplementExpressionTask, + OptimizeGoalTask, TransformExpressionTask, }, }, }; @@ -188,6 +188,42 @@ impl Optimizer { task_id } + /// Method to launch a task for implementing a logical expression with a rule. + /// + /// # Parameters + /// * `expr_id`: The ID of the logical expression to implement. + /// * `rule`: The implementation rule to apply. + /// * `optimize_goal_out`: The ID of the optimization task that this implementation feeds into. + /// * `goal_id`: The ID of the goal that this implementation belongs to. + /// + /// # Returns + /// * `TaskId`: The ID of the created implement task. 
+ pub(crate) fn launch_implement_expression_task( + &mut self, + expr_id: LogicalExpressionId, + rule: ImplementationRule, + optimize_goal_out: TaskId, + goal_id: GoalId, + ) -> TaskId { + use Task::*; + + let task_id = self.next_task_id(); + let task = ImplementExpressionTask { + rule: rule.clone(), + expression_id: expr_id, + optimize_goal_out, + fork_in: None, + }; + + self.add_task(task_id, ImplementExpression(task)); + self.schedule_job( + task_id, + JobKind::ImplementExpression(rule, expr_id, goal_id), + ); + + task_id + } + /// Creates transform tasks for a set of logical expressions. /// /// # Arguments @@ -221,6 +257,39 @@ impl Optimizer { transform_tasks } + /// Creates implement tasks for a set of logical expressions. + /// + /// # Arguments + /// * `expressions` - The logical expressions to implement + /// * `goal_id` - The goal ID to implement for + /// * `optimize_task_id` - The optimization task that these implement tasks feed into + /// + /// # Returns + /// * `HashSet` - The IDs of all created implement tasks + pub(crate) fn create_implement_tasks( + &mut self, + expressions: &HashSet, + goal_id: GoalId, + optimize_task_id: TaskId, + ) -> HashSet { + let implementations = self.rule_book.get_implementations().to_vec(); + let mut implement_tasks = HashSet::new(); + + for &expr_id in expressions { + for rule in &implementations { + let task_id = self.launch_implement_expression_task( + expr_id, + rule.clone(), + optimize_task_id, + goal_id, + ); + implement_tasks.insert(task_id); + } + } + + implement_tasks + } + /// Creates logical continuation tasks for a fork task and a set of expressions. /// /// # Arguments @@ -324,20 +393,23 @@ impl Optimizer { let goal_optimize_task_id = self.next_task_id(); - // TODO(Alexis): Materialize the goal and only explore the group - for now. - // This is sufficient to support logical->logical transformation.
let Goal(group_id, _) = self.memo.materialize_goal(goal_id).await?; let explore_group_in = self.ensure_group_exploration_task(group_id).await?; + // Launch all implementation tasks. + let expressions = self.memo.get_all_logical_exprs(group_id).await?; + let implement_expression_in = + self.create_implement_tasks(&expressions, goal_id, goal_optimize_task_id); + let goal_optimize_task = OptimizeGoalTask { goal_id, optimize_plan_out: HashSet::new(), - optimize_goal_out: HashSet::new(), + optimize_goal_out: HashSet::new(), // filled below to avoid infinite recursion fork_costed_out: HashSet::new(), - optimize_goal_in: HashSet::new(), + optimize_goal_in: HashSet::new(), // filled below to avoid infinite recursion explore_group_in, - implement_expression_in: HashSet::new(), - cost_expression_in: HashSet::new(), + implement_expression_in, + cost_expression_in: HashSet::new(), // TODO: design proper costing }; // Add this task to the exploration task's outgoing edges. @@ -351,22 +423,36 @@ impl Optimizer { self.goal_optimization_task_index .insert(goal_repr, goal_optimize_task_id); - Ok(goal_optimize_task_id) - } + // Ensure sub-goals are getting explored too, we do this after registering + // the task to avoid infinite recursion. + let sub_goals = self + .memo + .get_all_goal_members(goal_id) + .await? + .into_iter() + .filter_map(|member_id| match member_id { + GoalMemberId::GoalId(sub_goal_id) => Some(sub_goal_id), + _ => None, + }); + + // Launch optimization tasks for each subgoal and establish links. + let mut subgoal_task_ids = HashSet::new(); + for sub_goal_id in sub_goals { + let sub_goal_task_id = self.ensure_optimize_goal_task(sub_goal_id).await?; + subgoal_task_ids.insert(sub_goal_task_id); + + // Add current task to subgoal's outgoing edges. + self.get_optimize_goal_task_mut(sub_goal_task_id) + .unwrap() + .optimize_goal_out + .insert(goal_optimize_task_id); + } - /// Ensures a cost expression task exists and and returns its id. 
- /// If a costing task already exists, we reuse it. - /// - /// # Parameters - /// * `expression_id`: The ID of the expression to be costed. - /// - /// # Returns - /// * `TaskId`: The ID of the task that was created or reused. - #[allow(dead_code)] - async fn ensure_cost_expression_task( - &mut self, - _expression_id: PhysicalExpressionId, - ) -> Result { - todo!("What do we decide to do with costing is an open question"); + // Update the parent task with links to subgoals. + self.get_optimize_goal_task_mut(goal_optimize_task_id) + .unwrap() + .optimize_goal_in = subgoal_task_ids; + + Ok(goal_optimize_task_id) } } From bdcfd29a1c13e4cd6817f47333b75d3a4cfb7e7c Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 17:53:15 -0400 Subject: [PATCH 14/23] ingest partial physical done --- optd/src/optimizer/handlers.rs | 12 +++++++----- optd/src/optimizer/mod.rs | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/optd/src/optimizer/handlers.rs b/optd/src/optimizer/handlers.rs index f0e3cef7..92154db5 100644 --- a/optd/src/optimizer/handlers.rs +++ b/optd/src/optimizer/handlers.rs @@ -123,17 +123,19 @@ impl Optimizer { /// # Parameters /// * `plan` - The partial physical plan to process. /// * `goal_id` - ID of the goal associated with this plan. - /// * `job_id` - ID of the job that generated this plan. /// /// # Returns /// * `Result<(), Error>` - Success or error during processing. pub(super) async fn process_new_physical_partial( &mut self, - _plan: PartialPhysicalPlan, - _goal_id: GoalId, - _job_id: JobId, + plan: PartialPhysicalPlan, + goal_id: GoalId, ) -> Result<(), M::MemoError> { - todo!() + let goal_id = self.memo.find_repr_goal_id(goal_id).await?; + let member_id = self.probe_ingest_physical_plan(&plan).await?; + // TODO: Here we would launch costing tasks based on the design. 
+ self.memo.add_goal_member(goal_id, member_id).await?; + Ok(()) } /// This method handles group creation for expressions with derived properties diff --git a/optd/src/optimizer/mod.rs b/optd/src/optimizer/mod.rs index 3b13cd92..8587a0d6 100644 --- a/optd/src/optimizer/mod.rs +++ b/optd/src/optimizer/mod.rs @@ -242,7 +242,7 @@ impl Optimizer { self.process_new_logical_partial(plan, group_id, job_id).await?; } NewPhysicalPartial(plan, goal_id) => { - self.process_new_physical_partial(plan, goal_id, job_id).await?; + self.process_new_physical_partial(plan, goal_id).await?; } CreateGroup(expression_id, properties) => { self.process_create_group(expression_id, &properties, job_id).await?; From c22258afe1074a2d2dffe596fb1aeb3de19d8313 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 19:13:56 -0400 Subject: [PATCH 15/23] dedup implementations --- optd/src/optimizer/merge/helpers.rs | 67 +++++++++++++++++++++-------- optd/src/optimizer/merge/mod.rs | 11 ++--- optd/src/optimizer/tasks/delete.rs | 27 ++++++++++-- 3 files changed, 79 insertions(+), 26 deletions(-) diff --git a/optd/src/optimizer/merge/helpers.rs b/optd/src/optimizer/merge/helpers.rs index 18742974..4dd3a184 100644 --- a/optd/src/optimizer/merge/helpers.rs +++ b/optd/src/optimizer/merge/helpers.rs @@ -18,7 +18,7 @@ impl Optimizer { /// * `task_id` - The ID of the group exploration task to update. /// * `all_logical_exprs` - The complete set of logical expressions known for this group. /// * `principal` - Whether this task is the principal one (responsible for launching transforms). - pub(super) async fn update_group_explore( + pub(super) async fn update_tasks( &mut self, task_id: TaskId, all_logical_exprs: &HashSet, @@ -121,7 +121,7 @@ impl Optimizer { }); } - /// Deduplicate dispatched expressions for a group exploration task. + /// Deduplicate dispatched expressions for a group. /// /// - During optimization, logical expressions in a group may get merged /// in the memo structure. 
As a result, multiple expressions in the same @@ -130,12 +130,12 @@ impl Optimizer { /// - This function: /// - Maps each expression to its current representative. /// - Identifies and prunes redundant (duplicate) expressions. - /// - Updates or deletes related transform and fork/continuation tasks - /// accordingly. + /// - Updates or deletes related transform, implement and continuation + /// tasks accordingly. /// /// # Arguments /// * `task_id` - The ID of the group exploration task to deduplicate. - pub(super) async fn dedup_group_explore( + pub(super) async fn dedup_tasks( &mut self, task_id: TaskId, ) -> Result<(), M::MemoError> { @@ -143,6 +143,7 @@ impl Optimizer { let old_exprs = std::mem::take(&mut task.dispatched_exprs); let transform_ids: Vec<_> = task.transform_expr_in.iter().copied().collect(); let fork_ids: Vec<_> = task.fork_logical_out.iter().copied().collect(); + let optimize_ids: Vec<_> = task.optimize_goal_out.iter().copied().collect(); let expr_to_repr = self.map_to_representatives(&old_exprs).await?; @@ -159,8 +160,9 @@ impl Optimizer { (seen, dups) }; - self.process_transform_tasks(&transform_ids, &expr_to_repr, &to_delete); - self.process_fork_tasks(&fork_ids, &expr_to_repr, &to_delete); + self.dedup_transform_tasks(&transform_ids, &expr_to_repr, &to_delete); + self.dedup_continue_tasks(&fork_ids, &expr_to_repr, &to_delete); + self.dedup_implement_tasks(&optimize_ids, &expr_to_repr, &to_delete); let task = self.get_explore_group_task_mut(task_id).unwrap(); task.dispatched_exprs = unique_reprs; @@ -182,22 +184,22 @@ impl Optimizer { } /// Update or delete transform tasks based on deduplicated expressions. 
- fn process_transform_tasks( + fn dedup_transform_tasks( &mut self, - tasks: &[TaskId], + transform_ids: &[TaskId], expr_to_repr: &HashMap, to_delete: &[LogicalExpressionId], ) { - for &task_id in tasks { + for &transform_id in transform_ids { let expr_id = self - .get_transform_expression_task(task_id) + .get_transform_expression_task(transform_id) .unwrap() .expression_id; if to_delete.contains(&expr_id) { - self.delete_task(task_id); + self.delete_task(transform_id); } else { - self.get_transform_expression_task_mut(task_id) + self.get_transform_expression_task_mut(transform_id) .unwrap() .expression_id = expr_to_repr[&expr_id]; } @@ -205,15 +207,15 @@ impl Optimizer { } /// Update or delete continuation tasks spawned by fork tasks. - fn process_fork_tasks( + fn dedup_continue_tasks( &mut self, - tasks: &[TaskId], + fork_ids: &[TaskId], expr_to_repr: &HashMap, to_delete: &[LogicalExpressionId], ) { - for &task_id in tasks { + for &fork_id in fork_ids { let continue_ids = self - .get_fork_logical_task(task_id) + .get_fork_logical_task(fork_id) .unwrap() .continue_with_logical_in .clone(); // Clone to avoid mutable borrow conflict. @@ -234,4 +236,35 @@ impl Optimizer { } } } + + /// Update or delete implement tasks based on deduplicated expressions. + fn dedup_implement_tasks( + &mut self, + optimize_goals: &[TaskId], + expr_to_repr: &HashMap, + to_delete: &[LogicalExpressionId], + ) { + for &optimize_goal in optimize_goals { + let implement_ids = self + .get_optimize_goal_task(optimize_goal) + .unwrap() + .implement_expression_in + .clone(); // Clone to avoid mutable borrow conflict. 
+ + for implement_id in implement_ids { + let expr_id = self + .get_implement_expression_task(implement_id) + .unwrap() + .expression_id; + + if to_delete.contains(&expr_id) { + self.delete_task(implement_id); + } else { + self.get_implement_expression_task_mut(implement_id) + .unwrap() + .expression_id = expr_to_repr[&expr_id]; + } + } + } + } } diff --git a/optd/src/optimizer/merge/mod.rs b/optd/src/optimizer/merge/mod.rs index b553c6c7..d114a69e 100644 --- a/optd/src/optimizer/merge/mod.rs +++ b/optd/src/optimizer/merge/mod.rs @@ -15,7 +15,8 @@ impl Optimizer { /// related transform and continuation tasks. /// /// 2. **Updates**: Sends any new logical expressions to each task, creating appropriate - /// transform tasks (for the principal task) and continuation tasks (for all tasks). + /// transform & implement tasks (for the principal task) and continuation tasks + /// (for all tasks). /// /// 3. **Consolidates**: Merges all secondary tasks into a principal task by transferring /// their dependencies and updating references, ensuring a clean 1:1 mapping between @@ -55,12 +56,12 @@ impl Optimizer { group_explore_tasks.split_first().unwrap(); for task_id in &group_explore_tasks { - // Step 1: Start by deduplicating all transform & continue tasks given - // potentially merged logical expressions. - self.dedup_group_explore(*task_id).await?; + // Step 1: Start by deduplicating all transform, implement & continue tasks + // given potentially merged logical expressions. + self.dedup_tasks(*task_id).await?; // Step 2: Send *new* logical expressions to each task. 
let is_principal = task_id == principal_task_id; - self.update_group_explore(*task_id, &all_logical_exprs, is_principal) + self.update_tasks(*task_id, &all_logical_exprs, is_principal) .await?; } diff --git a/optd/src/optimizer/tasks/delete.rs b/optd/src/optimizer/tasks/delete.rs index 88b82e6e..7e9c427c 100644 --- a/optd/src/optimizer/tasks/delete.rs +++ b/optd/src/optimizer/tasks/delete.rs @@ -28,7 +28,16 @@ impl Optimizer { self.delete_task(fork_id); } } + ImplementExpression(implement_expression_task) => { + let optimize_goal_task = self + .get_optimize_goal_task_mut(implement_expression_task.optimize_goal_out) + .unwrap(); + optimize_goal_task.implement_expression_in.remove(&task_id); + if let Some(fork_id) = implement_expression_task.fork_in { + self.delete_task(fork_id); + } + } ContinueWithLogical(task) => { let fork_task = self.get_fork_logical_task_mut(task.fork_out).unwrap(); fork_task.continue_with_logical_in.remove(&task_id); @@ -37,12 +46,12 @@ impl Optimizer { self.delete_task(fork_id); } } - ForkLogical(task) => { let explore_task = self .get_explore_group_task_mut(task.explore_group_in) .unwrap(); explore_task.fork_logical_out.remove(&task_id); + // Delete explore task if it has no more purpose. 
if explore_task.fork_logical_out.is_empty() && explore_task.optimize_goal_out.is_empty() @@ -56,7 +65,6 @@ impl Optimizer { self.delete_task(continue_id); } } - ExploreGroup(task) => { assert!(task.fork_logical_out.is_empty()); assert!(task.optimize_goal_out.is_empty()); @@ -69,10 +77,21 @@ impl Optimizer { self.delete_task(transform_id); } } + OptimizeGoal(task) => { + assert!(task.fork_costed_out.is_empty()); + assert!(task.optimize_goal_out.is_empty()); + assert!(task.optimize_plan_out.is_empty()); - _ => { - todo!(); + self.goal_optimization_task_index + .retain(|_, &mut v| v != task_id); + + let implement_tasks: Vec<_> = + task.implement_expression_in.iter().copied().collect(); + for implement_id in implement_tasks { + self.delete_task(implement_id); + } } + OptimizePlan(_) => todo!(), } // Finally, remove the task from the task collection. From 38ec758ef4d3963c09dac28b7c0311bda4187963 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 19:27:03 -0400 Subject: [PATCH 16/23] launch sub-goals properly after ingestion --- optd/src/optimizer/handlers.rs | 31 ++++++++++++++++++++++++++++-- optd/src/optimizer/mod.rs | 2 +- optd/src/optimizer/tasks/delete.rs | 1 - optd/src/optimizer/tasks/launch.rs | 7 +++++-- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/optd/src/optimizer/handlers.rs b/optd/src/optimizer/handlers.rs index 92154db5..4d696c64 100644 --- a/optd/src/optimizer/handlers.rs +++ b/optd/src/optimizer/handlers.rs @@ -123,6 +123,7 @@ impl Optimizer { /// # Parameters /// * `plan` - The partial physical plan to process. /// * `goal_id` - ID of the goal associated with this plan. + /// * `job_id` - ID of the job that generated this plan. /// /// # Returns /// * `Result<(), Error>` - Success or error during processing. 
@@ -130,11 +131,37 @@ impl Optimizer { &mut self, plan: PartialPhysicalPlan, goal_id: GoalId, + job_id: JobId, ) -> Result<(), M::MemoError> { + use GoalMemberId::*; + let goal_id = self.memo.find_repr_goal_id(goal_id).await?; + let related_task_id = self.get_related_task_id(job_id); + let member_id = self.probe_ingest_physical_plan(&plan).await?; - // TODO: Here we would launch costing tasks based on the design. - self.memo.add_goal_member(goal_id, member_id).await?; + let added = self.memo.add_goal_member(goal_id, member_id).await?; + + match member_id { + PhysicalExpressionId(_) => { + // TODO: Here we would launch costing tasks based on the design. + } + GoalId(goal_id) => { + if added { + // Optimize the new sub-goal and add to task graph. + let sub_optimize_task_id = self.ensure_optimize_goal_task(goal_id).await?; + + self.get_optimize_goal_task_mut(sub_optimize_task_id) + .unwrap() + .optimize_goal_out + .insert(related_task_id); + self.get_optimize_goal_task_mut(related_task_id) + .unwrap() + .optimize_goal_in + .insert(sub_optimize_task_id); + } + } + } + Ok(()) } diff --git a/optd/src/optimizer/mod.rs b/optd/src/optimizer/mod.rs index 8587a0d6..3b13cd92 100644 --- a/optd/src/optimizer/mod.rs +++ b/optd/src/optimizer/mod.rs @@ -242,7 +242,7 @@ impl Optimizer { self.process_new_logical_partial(plan, group_id, job_id).await?; } NewPhysicalPartial(plan, goal_id) => { - self.process_new_physical_partial(plan, goal_id).await?; + self.process_new_physical_partial(plan, goal_id, job_id).await?; } CreateGroup(expression_id, properties) => { self.process_create_group(expression_id, &properties, job_id).await?; diff --git a/optd/src/optimizer/tasks/delete.rs b/optd/src/optimizer/tasks/delete.rs index 7e9c427c..4bb93485 100644 --- a/optd/src/optimizer/tasks/delete.rs +++ b/optd/src/optimizer/tasks/delete.rs @@ -78,7 +78,6 @@ impl Optimizer { } } OptimizeGoal(task) => { - assert!(task.fork_costed_out.is_empty()); assert!(task.optimize_goal_out.is_empty()); 
assert!(task.optimize_plan_out.is_empty()); diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index af550ac7..d07528a9 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -334,7 +334,7 @@ impl Optimizer { /// /// # Returns /// * `TaskId`: The ID of the task that was created or reused. - async fn ensure_group_exploration_task( + pub(crate) async fn ensure_group_exploration_task( &mut self, group_id: GroupId, ) -> Result { @@ -380,7 +380,10 @@ impl Optimizer { /// /// # Returns /// * `TaskId`: The ID of the task that was created or reused. - async fn ensure_optimize_goal_task(&mut self, goal_id: GoalId) -> Result { + pub(crate) async fn ensure_optimize_goal_task( + &mut self, + goal_id: GoalId, + ) -> Result { use Task::*; // Find the representative goal for the given goal ID. From 9676a6f8bf5e8a558bdfe1e3925482b1400ed394 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 19:28:27 -0400 Subject: [PATCH 17/23] add async recursion macro --- optd/src/optimizer/tasks/launch.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index d07528a9..47c0c9c8 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -11,6 +11,7 @@ use crate::{ }, }, }; +use async_recursion::async_recursion; use hashbrown::HashSet; use tokio::sync::mpsc::Sender; @@ -380,6 +381,7 @@ impl Optimizer { /// /// # Returns /// * `TaskId`: The ID of the task that was created or reused. 
+ #[async_recursion] pub(crate) async fn ensure_optimize_goal_task( &mut self, goal_id: GoalId, From 01e0c88e19aebe16ab38078a8c46bec50c0eadd6 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 20:06:48 -0400 Subject: [PATCH 18/23] update physical impl --- optd/src/optimizer/merge/helpers.rs | 95 ++++++++++++++++++----------- optd/src/optimizer/merge/mod.rs | 3 +- 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/optd/src/optimizer/merge/helpers.rs b/optd/src/optimizer/merge/helpers.rs index 4dd3a184..b5c7276f 100644 --- a/optd/src/optimizer/merge/helpers.rs +++ b/optd/src/optimizer/merge/helpers.rs @@ -12,7 +12,13 @@ impl Optimizer { /// 1. Computes newly discovered expressions by subtracting existing ones. /// 2. For each fork task in the group, launches continuation tasks for each new expression. /// 3. If the task is the principal, launches transform tasks for each new expression and rule. - /// 4. Updates the task's dispatched expressions with the full input set. + /// 4. For all related optimize goal tasks, launch implement tasks for each new expression. + /// *NOTE*: This happens before merging goals. While this might be slightly inefficient, as + /// we might implement twice for the same "soon-to-be" merged goals. However it keeps the code + /// cleaner and easier to understand. Also, the performance impact is negligible, as once we + /// merge the goals, we effectively delete the implement tasks and its associated jobs will + /// *not* be launched. + /// 5. Updates the task's dispatched expressions with the full input set. /// /// # Arguments /// * `task_id` - The ID of the group exploration task to update. 
@@ -26,45 +32,67 @@ impl Optimizer { ) -> Result<(), M::MemoError> { let new_exprs = self.compute_new_expressions(task_id, all_logical_exprs); - if !new_exprs.is_empty() { - let (group_id, fork_tasks) = { - let task = self.get_explore_group_task(task_id).unwrap(); - (task.group_id, task.fork_logical_out.clone()) - }; + if new_exprs.is_empty() { + return Ok(()); + } - for &fork_task_id in &fork_tasks { - let continuation = self - .get_fork_logical_task(fork_task_id) - .unwrap() - .continuation - .clone(); + let (group_id, fork_tasks, optimize_goal_tasks) = { + let task = self.get_explore_group_task(task_id).unwrap(); + ( + task.group_id, + task.fork_logical_out.clone(), + task.optimize_goal_out.clone(), + ) + }; - let continuation_tasks = self.create_logical_cont_tasks( - &new_exprs, - group_id, - fork_task_id, - &continuation, - ); + // For each fork task, create continuation tasks for each new expression. + fork_tasks.iter().for_each(|&fork_task_id| { + let continuation = self + .get_fork_logical_task(fork_task_id) + .unwrap() + .continuation + .clone(); - self.get_fork_logical_task_mut(fork_task_id) - .unwrap() - .continue_with_logical_in - .extend(continuation_tasks); - } + let continuation_tasks = + self.create_logical_cont_tasks(&new_exprs, group_id, fork_task_id, &continuation); - if principal { - let transform_tasks = self.create_transform_tasks(&new_exprs, group_id, task_id); - self.get_explore_group_task_mut(task_id) - .unwrap() - .transform_expr_in - .extend(transform_tasks); - } + self.get_fork_logical_task_mut(fork_task_id) + .unwrap() + .continue_with_logical_in + .extend(continuation_tasks); + }); + + // For each optimize goal task, create implement tasks for each new expression. 
+ optimize_goal_tasks.iter().for_each(|&optimize_goal_id| { + let goal_id = self + .get_optimize_goal_task(optimize_goal_id) + .unwrap() + .goal_id; + + let implement_tasks = + self.create_implement_tasks(&new_exprs, goal_id, optimize_goal_id); + + self.get_optimize_goal_task_mut(optimize_goal_id) + .unwrap() + .implement_expression_in + .extend(implement_tasks); + }); + + // For the principal task, create transform tasks for each new expression. + // We could always do it, but this is a straightforward optimization. + if principal { + let transform_tasks = self.create_transform_tasks(&new_exprs, group_id, task_id); self.get_explore_group_task_mut(task_id) .unwrap() - .dispatched_exprs = all_logical_exprs.clone(); + .transform_expr_in + .extend(transform_tasks); } + self.get_explore_group_task_mut(task_id) + .unwrap() + .dispatched_exprs = all_logical_exprs.clone(); + Ok(()) } @@ -135,10 +163,7 @@ impl Optimizer { /// /// # Arguments /// * `task_id` - The ID of the group exploration task to deduplicate. - pub(super) async fn dedup_tasks( - &mut self, - task_id: TaskId, - ) -> Result<(), M::MemoError> { + pub(super) async fn dedup_tasks(&mut self, task_id: TaskId) -> Result<(), M::MemoError> { let task = self.get_explore_group_task_mut(task_id).unwrap(); let old_exprs = std::mem::take(&mut task.dispatched_exprs); let transform_ids: Vec<_> = task.transform_expr_in.iter().copied().collect(); diff --git a/optd/src/optimizer/merge/mod.rs b/optd/src/optimizer/merge/mod.rs index d114a69e..60771b71 100644 --- a/optd/src/optimizer/merge/mod.rs +++ b/optd/src/optimizer/merge/mod.rs @@ -15,8 +15,7 @@ impl Optimizer { /// related transform and continuation tasks. /// /// 2. **Updates**: Sends any new logical expressions to each task, creating appropriate - /// transform & implement tasks (for the principal task) and continuation tasks - /// (for all tasks). + /// transform (for the principal task), implement and continuation tasks (for all tasks). /// /// 3. 
**Consolidates**: Merges all secondary tasks into a principal task by transferring /// their dependencies and updating references, ensuring a clean 1:1 mapping between From 30a90285d5273c1650a825111ecdea9b5d26e1bf Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 21:02:03 -0400 Subject: [PATCH 19/23] finish log->phy optimizer --- optd/src/optimizer/merge/helpers.rs | 59 +++++++++++++++++++ optd/src/optimizer/merge/mod.rs | 89 ++++++++++++++++++++++++++--- optd/src/optimizer/tasks/launch.rs | 4 +- optd/src/optimizer/tasks/mod.rs | 4 -- 4 files changed, 142 insertions(+), 14 deletions(-) diff --git a/optd/src/optimizer/merge/helpers.rs b/optd/src/optimizer/merge/helpers.rs index b5c7276f..377f20bd 100644 --- a/optd/src/optimizer/merge/helpers.rs +++ b/optd/src/optimizer/merge/helpers.rs @@ -146,6 +146,65 @@ impl Optimizer { let goal_task = self.get_optimize_goal_task_mut(goal_id).unwrap(); goal_task.explore_group_in = principal_task_id; }); + + // *NOTE*: No need to consolidate transformations as all missing + // expressions have already been added during the update phase. + }); + } + + /// Consolidate a goal optimization task into a principal task. + /// + /// - Similar to the group exploration consolidation, a merge in the memo may cause + /// several goal optimization tasks to refer to the same underlying goal. This function + /// consolidates all such secondary tasks into a principal one. + /// + /// - This involves: + /// - Moving all outgoing optimize goal tasks, incoming optimize goal tasks, and + /// optimize plan tasks from secondary to principal. + /// - Updating each such task's references to point to the principal task. + /// - Deleting the secondary tasks. + /// + /// # Arguments + /// * `principal_task_id` - The ID of the task to retain as canonical. + /// * `secondary_task_ids` - All other task IDs to merge into the principal. 
+ pub(super) async fn consolidate_goal_optimize( + &mut self, + principal_task_id: TaskId, + secondary_task_ids: &[TaskId], + ) { + secondary_task_ids.iter().for_each(|&task_id| { + let task = self.get_optimize_goal_task_mut(task_id).unwrap(); + + // Move out the task sets before deletion. + let optimize_out_tasks = std::mem::take(&mut task.optimize_goal_out); + let optimize_in_tasks = std::mem::take(&mut task.optimize_goal_in); + let optimize_plan_tasks = std::mem::take(&mut task.optimize_plan_out); + + self.delete_task(task_id); + + optimize_out_tasks.into_iter().for_each(|optimize_id| { + let optimize_goal_task = self.get_optimize_goal_task_mut(optimize_id).unwrap(); + optimize_goal_task.optimize_goal_in.remove(&task_id); + optimize_goal_task + .optimize_goal_in + .insert(principal_task_id); + }); + + optimize_in_tasks.into_iter().for_each(|optimize_id| { + let optimize_goal_task = self.get_optimize_goal_task_mut(optimize_id).unwrap(); + optimize_goal_task.optimize_goal_out.remove(&task_id); + optimize_goal_task + .optimize_goal_out + .insert(principal_task_id); + }); + + optimize_plan_tasks.into_iter().for_each(|optimize_id| { + let optimize_plan_task = self.get_optimize_plan_task_mut(optimize_id).unwrap(); + optimize_plan_task.optimize_goal_in = Some(principal_task_id); + }); + + // *NOTE*: No need to consolidate implementations as all missing + // expressions have already been added during the update phase. }); } diff --git a/optd/src/optimizer/merge/mod.rs b/optd/src/optimizer/merge/mod.rs index 60771b71..8d2bb618 100644 --- a/optd/src/optimizer/merge/mod.rs +++ b/optd/src/optimizer/merge/mod.rs @@ -1,11 +1,29 @@ use super::Optimizer; -use crate::memo::{Memo, MergeGroupProduct, MergeProducts}; +use crate::memo::{Memo, MergeGoalProduct, MergeGroupProduct, MergeProducts}; mod helpers; impl Optimizer { /// Processes merge results by updating the task graph to reflect merges in the memo. 
/// + /// This function handles both group merges and goal merges by delegating to specialized + /// handlers for each type of merge. + /// + /// # Parameters + /// * `result` - The merge result to handle, containing information about merged groups and goals. + /// + /// # Returns + /// * `Result<(), OptimizeError>` - Success or an error that occurred during processing. + pub(super) async fn handle_merge_result( + &mut self, + result: MergeProducts, + ) -> Result<(), M::MemoError> { + self.handle_group_merges(&result.group_merges).await?; + self.handle_goal_merges(&result.goal_merges).await + } + + /// Handles merges of logical groups by updating the task graph. + /// /// When groups are merged in the memo, multiple exploration tasks may now refer to /// the same underlying group. This method handles the task graph updates required /// by such merges. For each merged group, it: @@ -28,18 +46,18 @@ impl Optimizer { /// with it, containing all logical expressions from the original groups with no duplicates. /// /// # Parameters - /// * `result` - The merge result to handle, containing information about merged groups. + /// * `group_merges` - A slice of group merge products to handle. /// /// # Returns /// * `Result<(), OptimizeError>` - Success or an error that occurred during processing. - pub(super) async fn handle_merge_result( + async fn handle_group_merges( &mut self, - result: MergeProducts, + group_merges: &[MergeGroupProduct], ) -> Result<(), M::MemoError> { for MergeGroupProduct { new_group_id, merged_groups, - } in result.group_merges + } in group_merges { // For each merged group, get all group exploration tasks. // We don't need to check for the new group ID since it is guaranteed to be new. 
@@ -49,7 +67,7 @@ impl Optimizer { .collect(); if !group_explore_tasks.is_empty() { - let all_logical_exprs = self.memo.get_all_logical_exprs(new_group_id).await?; + let all_logical_exprs = self.memo.get_all_logical_exprs(*new_group_id).await?; let (principal_task_id, secondary_task_ids) = group_explore_tasks.split_first().unwrap(); @@ -70,7 +88,64 @@ impl Optimizer { // Step 4: Set the index to point to the new representative task. self.group_exploration_task_index - .insert(new_group_id, *principal_task_id); + .insert(*new_group_id, *principal_task_id); + } + } + + Ok(()) + } + + /// Handles merges of optimization goals by updating the task graph. + /// + /// When goals are merged in the memo, multiple optimization tasks may now refer to + /// the same underlying goal. This method handles the task graph updates required + /// by such merges. For each merged goal, it: + /// + /// 1. **Consolidates**: Merges all secondary tasks into a principal task by transferring + /// their dependencies (incoming and outgoing optimization tasks and plan tasks) and + /// updating references, ensuring a clean 1:1 mapping between goals and optimization tasks. + /// + /// 2. **Re-indexes**: Updates the goal optimization index to point to the principal task + /// for the new goal ID. + /// + /// After processing, each merged goal will have exactly one optimization task associated + /// with it, with all dependencies properly redirected to it. + /// + /// # Parameters + /// * `goal_merges` - A slice of goal merge products to handle. + /// + /// # Returns + /// * `Result<(), OptimizeError>` - Success or an error that occurred during processing. + async fn handle_goal_merges( + &mut self, + goal_merges: &[MergeGoalProduct], + ) -> Result<(), M::MemoError> { + for MergeGoalProduct { + new_goal_id, + merged_goals, + } in goal_merges + { + // For each merged goal, get all goal optimization tasks. + // We don't need to check for the new goal ID since it is guaranteed to be new. 
+ let goal_optimize_tasks: Vec<_> = merged_goals + .iter() + .filter_map(|goal_id| self.goal_optimization_task_index.get(goal_id).copied()) + .collect(); + + if !goal_optimize_tasks.is_empty() { + let (principal_task_id, secondary_task_ids) = + goal_optimize_tasks.split_first().unwrap(); + + // *NOTE*: Deduplication and updates of implementation tasks have already been + // handled in the `handle_group_merges` method, so we don't need to do it here. + + // Step 1: Consolidate all dependent tasks into the new "representative" task. + self.consolidate_goal_optimize(*principal_task_id, secondary_task_ids) + .await; + + // Step 2: Set the index to point to the new representative task. + self.goal_optimization_task_index + .insert(*new_goal_id, *principal_task_id); } } diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index 47c0c9c8..73c34995 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -410,11 +410,9 @@ impl Optimizer { goal_id, optimize_plan_out: HashSet::new(), optimize_goal_out: HashSet::new(), // filled below to avoid infinite recursion - fork_costed_out: HashSet::new(), - optimize_goal_in: HashSet::new(), // filled below to avoid infinite recursion + optimize_goal_in: HashSet::new(), // filled below to avoid infinite recursion explore_group_in, implement_expression_in, - cost_expression_in: HashSet::new(), // TODO: design proper costing }; // Add this task to the exploration task's outgoing edges. diff --git a/optd/src/optimizer/tasks/mod.rs b/optd/src/optimizer/tasks/mod.rs index ffba21e7..ca0ec022 100644 --- a/optd/src/optimizer/tasks/mod.rs +++ b/optd/src/optimizer/tasks/mod.rs @@ -58,8 +58,6 @@ pub(crate) struct OptimizeGoalTask { /// `OptimizeGoalTask` parent goals that this task is simultaneously /// producing for. pub optimize_goal_out: HashSet, - /// `ForkCostedTask` subscribed to this goal. - pub fork_costed_out: HashSet, // Input tasks that feed this task. 
/// `OptimizeGoalTask` member (children) goals producing for this goal. @@ -69,8 +67,6 @@ pub(crate) struct OptimizeGoalTask { pub explore_group_in: TaskId, /// `ImplementExpressionTask` rules that are implementing logical expressions. pub implement_expression_in: HashSet, - /// `CostExpressionTask` costing of physical expressions. - pub cost_expression_in: HashSet, } /// Task to explore expressions in a logical group. From e3c226483ace6a42278344cf37475c1437d1f3cf Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Tue, 20 May 2025 21:07:03 -0400 Subject: [PATCH 20/23] clippy and fmt --- optd/src/dsl/analyzer/hir/mod.rs | 15 ++++++--------- optd/src/optimizer/merge/helpers.rs | 10 ++++++---- optd/src/optimizer/tasks/launch.rs | 2 +- optd/src/optimizer/tasks/mod.rs | 3 ++- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/optd/src/dsl/analyzer/hir/mod.rs b/optd/src/dsl/analyzer/hir/mod.rs index 92c1cb68..bcdf8106 100644 --- a/optd/src/dsl/analyzer/hir/mod.rs +++ b/optd/src/dsl/analyzer/hir/mod.rs @@ -73,17 +73,14 @@ impl TypedSpan { } } +/// Type aliases for user-defined functions (UDFs). +type UdfFutureOutput = Pin + Send>>; +type UdfFunction = + dyn Fn(Vec, Arc, Arc) -> UdfFutureOutput + Send + Sync; + #[derive(Clone)] pub struct Udf { - pub func: Arc< - dyn Fn( - Vec, - Arc, - Arc, - ) -> Pin + Send>> - + Send - + Sync, - >, + pub func: Arc, } impl fmt::Debug for Udf { diff --git a/optd/src/optimizer/merge/helpers.rs b/optd/src/optimizer/merge/helpers.rs index 377f20bd..68c10af5 100644 --- a/optd/src/optimizer/merge/helpers.rs +++ b/optd/src/optimizer/merge/helpers.rs @@ -13,11 +13,13 @@ impl Optimizer { /// 2. For each fork task in the group, launches continuation tasks for each new expression. /// 3. If the task is the principal, launches transform tasks for each new expression and rule. /// 4. For all related optimize goal tasks, launch implement tasks for each new expression. + /// /// *NOTE*: This happens before merging goals. 
While this might be slightly inefficient, as - /// we might implement twice for the same "soon-to-be" merged goals. However it keeps the code - /// cleaner and easier to understand. Also, the performance impact is negligible, as once we - /// merge the goals, we effectively delete the implement tasks and its associated jobs will - /// *not* be launched. + /// we might implement twice for the same "soon-to-be" merged goals. However it keeps + /// the code cleaner and easier to understand. Also, the performance impact is negligible, + /// as once we merge the goals, we effectively delete the implement tasks and its + /// associated jobs will *not* be launched. + /// /// 5. Updates the task's dispatched expressions with the full input set. /// /// # Arguments diff --git a/optd/src/optimizer/tasks/launch.rs b/optd/src/optimizer/tasks/launch.rs index 73c34995..fddcd019 100644 --- a/optd/src/optimizer/tasks/launch.rs +++ b/optd/src/optimizer/tasks/launch.rs @@ -210,7 +210,7 @@ impl Optimizer { let task_id = self.next_task_id(); let task = ImplementExpressionTask { - rule: rule.clone(), + _rule: rule.clone(), expression_id: expr_id, optimize_goal_out, fork_in: None, diff --git a/optd/src/optimizer/tasks/mod.rs b/optd/src/optimizer/tasks/mod.rs index ca0ec022..d6fcc498 100644 --- a/optd/src/optimizer/tasks/mod.rs +++ b/optd/src/optimizer/tasks/mod.rs @@ -112,7 +112,8 @@ pub(crate) struct TransformExpressionTask { #[derive(Clone)] pub(crate) struct ImplementExpressionTask { /// The implementation rule to apply. - pub rule: ImplementationRule, + /// NOTE: Variable not used but kept for observability. + pub _rule: ImplementationRule, /// The logical expression to implement. 
pub expression_id: LogicalExpressionId, From bebe1f4db78080e0081bf4613ae747996bddeac8 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 21 May 2025 09:48:01 -0400 Subject: [PATCH 21/23] Add memo dump debug util --- optd/src/demo/demo.opt | 18 +++++- optd/src/demo/mod.rs | 26 ++++++--- optd/src/memo/memory/implementation.rs | 80 +++++++++++++++++++++++++- optd/src/memo/mod.rs | 3 + optd/src/optimizer/mod.rs | 62 +++++++++++++------- 5 files changed, 156 insertions(+), 33 deletions(-) diff --git a/optd/src/demo/demo.opt b/optd/src/demo/demo.opt index 5e41511c..b44d5dc1 100644 --- a/optd/src/demo/demo.opt +++ b/optd/src/demo/demo.opt @@ -1,4 +1,3 @@ -data Physical data PhysicalProperties data Statistics // Taking folded here is not the most interesting property, @@ -12,6 +11,13 @@ data Logical = | Div(left: Logical, right: Logical) \ Const(val: I64) +data Physical = + | PhysicalAdd(left: Physical, right: Physical) + | PhysicalSub(left: Physical, right: Physical) + | PhysicalMult(left: Physical, right: Physical) + | PhysicalDiv(left: Physical, right: Physical) + \ PhysicalConst(val: I64) + // This will be the input plan that will be optimized. // Result is: ((1 - 2) * (3 / 4)) + ((5 - 6) * (7 / 8)) = 0 fn input(): Logical = @@ -71,4 +77,12 @@ fn (op: Logical*) const_fold_sub(): Logical? = match op fn (op: Logical*) const_fold_div(): Logical? = match op | Div(Const(a), Const(b)) -> if b == 0 then none else Const(a / b) - \ _ -> none \ No newline at end of file + \ _ -> none + +[implementation] +fn (op: Logical*) to_physical(props: PhysicalProperties?) 
= match op + | Add(left, right) -> PhysicalAdd(left.to_physical(props), right.to_physical(props)) + | Sub(left, right) -> PhysicalSub(left.to_physical(props), right.to_physical(props)) + | Mult(left, right) -> PhysicalMult(left.to_physical(props), right.to_physical(props)) + | Div(left, right) -> PhysicalDiv(left.to_physical(props), right.to_physical(props)) + \ Const(val) -> PhysicalConst(val) \ No newline at end of file diff --git a/optd/src/demo/mod.rs b/optd/src/demo/mod.rs index ce1f568f..5098085d 100644 --- a/optd/src/demo/mod.rs +++ b/optd/src/demo/mod.rs @@ -7,10 +7,13 @@ use crate::{ utils::retriever::{MockRetriever, Retriever}, }, memo::MemoryMemo, - optimizer::{OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical}, + optimizer::{ClientRequest, OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical}, }; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{sync::mpsc, time::timeout}; +use tokio::{ + sync::mpsc, + time::{sleep, timeout}, +}; pub async fn properties( args: Vec, @@ -64,15 +67,15 @@ async fn run_demo() { let optimize_channel = Optimizer::launch(memo, catalog, hir); let (tx, mut rx) = mpsc::channel(1); optimize_channel - .send(OptimizeRequest { - plan: logical_plan, - physical_tx: tx, - }) + .send(ClientRequest::Optimize(OptimizeRequest { + plan: logical_plan.clone(), + physical_tx: tx.clone(), + })) .await .unwrap(); - // Timeout after 2 seconds. - let timeout_duration = Duration::from_secs(2); + // Timeout after 5 seconds. + let timeout_duration = Duration::from_secs(5); let result = timeout(timeout_duration, async { while let Some(response) = rx.recv().await { println!("Received response: {:?}", response); @@ -84,6 +87,13 @@ async fn run_demo() { Ok(_) => println!("Finished receiving responses."), Err(_) => println!("Timed out after 5 seconds."), } + + // Dump the memo (debug utility). 
+ optimize_channel + .send(ClientRequest::DumpMemo) + .await + .unwrap(); + sleep(Duration::from_secs(10)).await; } #[cfg(test)] diff --git a/optd/src/memo/memory/implementation.rs b/optd/src/memo/memory/implementation.rs index fc948d2b..3f79f2ac 100644 --- a/optd/src/memo/memory/implementation.rs +++ b/optd/src/memo/memory/implementation.rs @@ -8,6 +8,84 @@ use hashbrown::{HashMap, HashSet}; use std::collections::VecDeque; impl Memo for MemoryMemo { + async fn debug_dump(&self) -> Result<(), Infallible> { + println!("\n===== MEMO TABLE DUMP ====="); + println!("---- GROUPS ----"); + + // Get all group IDs and sort them for consistent output. + let mut group_ids: Vec<_> = self.group_info.keys().copied().collect(); + group_ids.sort_by_key(|id| id.0); + + for group_id in group_ids { + println!("Group {:?}:", group_id); + + // Print logical properties. + let group_info = self.group_info.get(&group_id).unwrap(); + println!(" Properties: {:?}", group_info.logical_properties); + + // Print expressions in group. + println!(" Expressions:"); + for expr_id in &group_info.expressions { + let expr = self.id_to_logical_expr.get(expr_id).unwrap(); + let repr_id = self.find_repr_logical_expr_id(*expr_id).await; + println!(" {:?} [{:?}]: {:?}", expr_id, repr_id, expr); + } + + // Print goals in this group (if any). + if !group_info.goals.is_empty() { + println!(" Goals:"); + for (props, goal_ids) in &group_info.goals { + for goal_id in goal_ids { + println!(" Goal {:?} with properties: {:?}", goal_id, props); + } + } + } + } + + println!("---- GOALS ----"); + + // Get all goal IDs and sort them. + let mut goal_ids: Vec<_> = self.goal_info.keys().copied().collect(); + goal_ids.sort_by_key(|id| id.0); + + for goal_id in goal_ids { + println!("Goal {:?}:", goal_id); + + // Print goal definition. + let goal_info = self.goal_info.get(&goal_id).unwrap(); + println!(" Definition: {:?}", goal_info.goal); + + // Print members. 
+ println!(" Members:"); + for member in &goal_info.members { + match member { + GoalMemberId::GoalId(sub_goal_id) => { + println!(" Sub-goal: {:?}", sub_goal_id); + } + GoalMemberId::PhysicalExpressionId(phys_id) => { + let expr = self.id_to_physical_expr.get(phys_id).unwrap(); + println!(" Physical expr {:?}: {:?}", phys_id, expr); + } + } + } + } + + println!("---- PHYSICAL EXPRESSIONS ----"); + + // Get all physical expression IDs and sort them + let mut phys_expr_ids: Vec<_> = self.id_to_physical_expr.keys().copied().collect(); + phys_expr_ids.sort_by_key(|id| id.0); + + for phys_id in phys_expr_ids { + let expr = self.id_to_physical_expr.get(&phys_id).unwrap(); + println!("Physical expr {:?}: {:?}", phys_id, expr); + } + + println!("===== END MEMO TABLE DUMP ====="); + + Ok(()) + } + async fn get_logical_properties( &self, group_id: GroupId, @@ -171,8 +249,6 @@ impl Memo for MemoryMemo { // dependent on the merged goals (and the recursively merged expressions themselves). let expr_merges = self.merge_dependent_physical_exprs(&goal_merges).await?; - println!("Group_info: {:?}", self.group_info); - Ok(MergeProducts { group_merges, goal_merges, diff --git a/optd/src/memo/mod.rs b/optd/src/memo/mod.rs index d2a21397..fe7ba5ea 100644 --- a/optd/src/memo/mod.rs +++ b/optd/src/memo/mod.rs @@ -137,6 +137,9 @@ pub trait Representative: MemoBase { /// expressions, and finding representative nodes of the union-find substructures. #[trait_variant::make(Send)] pub trait Memo: Representative + Materialize + Sync + 'static { + /// Prints the contents of the memo table to the console for debugging purposes. + async fn debug_dump(&self) -> Result<(), Self::MemoError>; + /// Retrieves logical properties for a group ID. 
/// /// # Parameters diff --git a/optd/src/optimizer/mod.rs b/optd/src/optimizer/mod.rs index 3b13cd92..45d42ac7 100644 --- a/optd/src/optimizer/mod.rs +++ b/optd/src/optimizer/mod.rs @@ -24,7 +24,17 @@ use tasks::{Task, TaskId}; /// Default maximum number of concurrent jobs to run in the optimizer. const DEFAULT_MAX_CONCURRENT_JOBS: usize = 1000; -/// External client request to optimize a query in the optimizer. +/// External client requests that can be sent to the optimizer. +#[derive(Clone, Debug)] +pub enum ClientRequest { + /// Request to optimize a logical plan into a physical plan. + Optimize(OptimizeRequest), + + /// Request to dump the contents of the memo for debugging purposes. + DumpMemo, +} + +/// Request to optimize a query in the optimizer. /// /// Defines the public API for submitting a query and receiving execution plans. #[derive(Clone, Debug)] @@ -61,7 +71,7 @@ enum EngineProduct { /// Each message represents either a client request or the result of completed work, /// allowing the optimizer to track which tasks are progressing. enum OptimizerMessage { - /// Client request to optimize a plan. + /// Client request to the optimizer. Request(OptimizeRequest, TaskId), /// Request to retrieve the properties of a group. @@ -115,7 +125,7 @@ pub struct Optimizer { pending_messages: Vec, message_tx: Sender, message_rx: Receiver, - optimize_rx: Receiver, + client_rx: Receiver, // Task management. tasks: HashMap, @@ -144,7 +154,7 @@ impl Optimizer { catalog: Arc, message_tx: Sender, message_rx: Receiver, - optimize_rx: Receiver, + client_rx: Receiver, ) -> Self { Self { // Core components. @@ -158,7 +168,7 @@ impl Optimizer { pending_messages: Vec::new(), message_tx, message_rx, - optimize_rx, + client_rx, // Task management. tasks: HashMap::new(), @@ -179,9 +189,9 @@ impl Optimizer { } /// Launch a new optimizer and return a sender for client communication. 
- pub fn launch(memo: M, catalog: Arc, hir: HIR) -> Sender { + pub fn launch(memo: M, catalog: Arc, hir: HIR) -> Sender { let (message_tx, message_rx) = mpsc::channel(1); - let (optimize_tx, optimize_rx) = mpsc::channel(1); + let (client_tx, client_rx) = mpsc::channel(1); // Start the background processing loop. let optimizer = Self::new( @@ -190,7 +200,7 @@ impl Optimizer { catalog, message_tx.clone(), message_rx, - optimize_rx, + client_rx, ); tokio::spawn(async move { @@ -199,30 +209,40 @@ impl Optimizer { optimizer.run().await.expect("Optimizer failure"); }); - optimize_tx + client_tx } /// Run the optimizer's main processing loop. async fn run(mut self) -> Result<(), M::MemoError> { + use ClientRequest::*; use EngineProduct::*; use OptimizerMessage::*; loop { tokio::select! { - Some(request) = self.optimize_rx.recv() => { - let task_id = self.create_optimize_plan_task(request.plan.clone(), - request.physical_tx.clone()); + Some(client_request) = self.client_rx.recv() => { let message_tx = self.message_tx.clone(); - // Forward the optimization request to the message processing loop - // in a new coroutine to avoid a deadlock. - tokio::spawn( - async move { - message_tx.send(Request(request, task_id)) - .await - .expect("Failed to forward optimize request - channel closed"); + match client_request { + Optimize(optimize_request) => { + // Create a task for the optimization request. + let task_id = self.create_optimize_plan_task( + optimize_request.plan.clone(), + optimize_request.physical_tx.clone() + ); + + // Forward the client request to the message processing loop + // in a new coroutine to avoid a deadlock. + tokio::spawn(async move { + message_tx.send(Request(optimize_request, task_id)) + .await + .expect("Failed to forward client request - channel closed"); + }); + }, + DumpMemo => { + self.memo.debug_dump().await?; } - ); + } }, Some(message) = self.message_rx.recv() => { // Process the next message in the channel. 
@@ -231,7 +251,7 @@ impl Optimizer { self.process_optimize_request(plan, physical_tx, task_id).await?, Retrieve(group_id, response_tx) => { self.process_retrieve_properties(group_id, response_tx).await?; - } + }, Product(product, job_id) => { let task_id = self.get_related_task_id(job_id); From 2c9cb38275b4f536080da33a3522cfaf5cfef82d Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 21 May 2025 09:57:35 -0400 Subject: [PATCH 22/23] bring back fix --- optd/src/memo/memory/helpers.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index 63df5c91..52d9b72b 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -425,14 +425,20 @@ impl MemoryMemoHelper for MemoryMemo { ) -> Result, Infallible> { // Merge the set of expressions that reference these two groups into one // that references the new group. - let exprs1 = self + let exprs1: HashSet = self .group_referencing_exprs_index .remove(&group_id_1) - .unwrap_or_default(); + .unwrap_or_default() + .into_iter() + .map(|id| self.repr_logical_expr_id.find(&id)) + .collect(); let exprs2 = self .group_referencing_exprs_index .remove(&group_id_2) - .unwrap_or_default(); + .unwrap_or_default() + .into_iter() + .map(|id| self.repr_logical_expr_id.find(&id)) + .collect(); let new_set = exprs1.union(&exprs2).copied().collect(); // Update the index for the new group / set of logical expressions. 
From 30e571e94e530a826a466de53259e3ba4bf2a30a Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Thu, 22 May 2025 09:06:34 -0400 Subject: [PATCH 23/23] refactor regression fix in memo --- optd/src/memo/memory/helpers.rs | 46 +++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/optd/src/memo/memory/helpers.rs b/optd/src/memo/memory/helpers.rs index 52d9b72b..88727ac5 100644 --- a/optd/src/memo/memory/helpers.rs +++ b/optd/src/memo/memory/helpers.rs @@ -35,6 +35,32 @@ impl MemoryMemo { self.next_shared_id += 1; PhysicalExpressionId(id) } + + /// Takes the set of [`LogicalExpressionId`] that reference a group, mapped to their + /// representatives. + fn take_referencing_expr_set(&mut self, group_id: GroupId) -> HashSet { + self.group_referencing_exprs_index + .remove(&group_id) + .unwrap_or_default() + .iter() + .map(|id| self.repr_logical_expr_id.find(id)) + .collect() + } + + /// Merges the two sets of logical expressions that reference the two groups into a single set + /// of expressions under a new [`GroupId`]. + /// + /// If a group does not exist, then the set of expressions referencing it is the empty set. + fn merge_referencing_exprs(&mut self, group1: GroupId, group2: GroupId, new_group: GroupId) { + // Remove the entries for the original two groups that we want to merge. + let exprs1 = self.take_referencing_expr_set(group1); + let exprs2 = self.take_referencing_expr_set(group2); + let new_set = exprs1.union(&exprs2).copied().collect(); + + // Update the index for the new group / set of logical expressions. + self.group_referencing_exprs_index + .insert(new_group, new_set); + } } /// Helper functions for the in-memory memo table implementation. @@ -425,25 +451,7 @@ impl MemoryMemoHelper for MemoryMemo { ) -> Result, Infallible> { // Merge the set of expressions that reference these two groups into one // that references the new group. 
- let exprs1: HashSet = self - .group_referencing_exprs_index - .remove(&group_id_1) - .unwrap_or_default() - .into_iter() - .map(|id| self.repr_logical_expr_id.find(&id)) - .collect(); - let exprs2 = self - .group_referencing_exprs_index - .remove(&group_id_2) - .unwrap_or_default() - .into_iter() - .map(|id| self.repr_logical_expr_id.find(&id)) - .collect(); - let new_set = exprs1.union(&exprs2).copied().collect(); - - // Update the index for the new group / set of logical expressions. - self.group_referencing_exprs_index - .insert(new_group_id, new_set); + self.merge_referencing_exprs(group_id_1, group_id_2, new_group_id); // We need to clone here because we are modifying our `self` state inside the loop. // TODO: This is an inefficiency. This referencing index shouldn't be modified in the loop