diff --git a/crates/pico/src/database.rs b/crates/pico/src/database.rs index cecb34c1d..6f04d1820 100644 --- a/crates/pico/src/database.rs +++ b/crates/pico/src/database.rs @@ -33,7 +33,12 @@ pub trait Database: DatabaseDyn + Sized { } pub trait StorageDyn { - fn get_value_as_any(&self, id: DerivedNodeId) -> Option<&dyn Any>; + fn get_derived_node_value_and_revision( + &self, + id: DerivedNodeId, + ) -> Option<(&dyn Any, DerivedNodeRevision)>; + + fn register_dependency_in_parent_memoized_fn(&self, node: NodeKind, time_updated: Epoch); } #[derive(Debug)] @@ -46,10 +51,17 @@ pub struct Storage { } impl StorageDyn for Storage { - fn get_value_as_any(&self, id: DerivedNodeId) -> Option<&dyn Any> { + fn get_derived_node_value_and_revision( + &self, + id: DerivedNodeId, + ) -> Option<(&dyn Any, DerivedNodeRevision)> { self.internal - .get_derived_node(id) - .map(|node| node.value.as_ref().as_any()) + .get_derived_node_and_revision(id) + .map(|(node, revision)| (node.value.as_ref().as_any(), revision)) + } + + fn register_dependency_in_parent_memoized_fn(&self, node: NodeKind, time_updated: Epoch) { + Storage::register_dependency_in_parent_memoized_fn(self, node, time_updated); } } @@ -201,14 +213,22 @@ impl InternalStorage { &self, derived_node_id: DerivedNodeId, ) -> Option<&DerivedNode> { - let index = self - .derived_node_id_to_revision - .get(&derived_node_id)? - .index; - Some(self.derived_nodes.get(index.idx).expect( + self.get_derived_node_and_revision(derived_node_id) + .map(|(node, _)| node) + } + + pub(crate) fn get_derived_node_and_revision( + &self, + derived_node_id: DerivedNodeId, + ) -> Option<(&DerivedNode, DerivedNodeRevision)> { + let revision = *self.derived_node_id_to_revision.get(&derived_node_id)?; + + let node = self.derived_nodes.get(revision.index.idx).expect( "indexes should always be valid. \ This is indicative of a bug in Pico.", - )) + ); + + Some((node, revision)) } pub(crate) fn node_verified_in_current_epoch(&self, derived_node_id: DerivedNodeId) -> bool { diff --git a/crates/pico/src/dependency.rs b/crates/pico/src/dependency.rs index e23940183..18ce1f85e 100644 --- a/crates/pico/src/dependency.rs +++ b/crates/pico/src/dependency.rs @@ -32,6 +32,12 @@ impl TrackedDependencies { pub fn push(&mut self, dependency: Dependency, time_updated: Epoch) { self.max_time_updated = std::cmp::max(time_updated, self.max_time_updated); + if let Some(last_dependency) = self.dependencies.last_mut() + && last_dependency.node_to == dependency.node_to + { + last_dependency.time_verified_or_updated = dependency.time_verified_or_updated; + return; + }; self.dependencies.push(dependency); } } diff --git a/crates/pico/src/execute_memoized_function.rs b/crates/pico/src/execute_memoized_function.rs index 2e36ce2d7..9c62b3a71 100644 --- a/crates/pico/src/execute_memoized_function.rs +++ b/crates/pico/src/execute_memoized_function.rs @@ -68,8 +68,10 @@ pub fn execute_memoized_function( db.get_storage().top_level_calls.push(derived_node_id); } - let (time_updated, did_recalculate) = if let Some(derived_node) = - db.get_storage().internal.get_derived_node(derived_node_id) + let (did_recalculate, time_updated) = if let Some((derived_node, revision)) = db + .get_storage() + .internal + .get_derived_node_and_revision(derived_node_id) { if db .get_storage() @@ -77,10 +79,7 @@ pub fn execute_memoized_function( .node_verified_in_current_epoch(derived_node_id) { event!(Level::TRACE, "epoch not changed"); - ( - db.get_storage().internal.current_epoch, - DidRecalculate::ReusedMemoizedValue, - ) + (DidRecalculate::ReusedMemoizedValue, revision.time_updated) } else { db.get_storage() .internal @@ -90,10 +89,7 @@ pub fn execute_memoized_function( update_derived_node(db, derived_node_id, derived_node.value.as_ref(), inner_fn) } else { event!(Level::TRACE, "dependencies up-to-date"); - ( - db.get_storage().internal.current_epoch, - DidRecalculate::ReusedMemoizedValue, - ) + (DidRecalculate::ReusedMemoizedValue, revision.time_updated) } } } else { @@ -111,7 +107,7 @@ fn create_derived_node( db: &Db, derived_node_id: DerivedNodeId, inner_fn: InnerFn, -) -> (Epoch, DidRecalculate) { +) -> (DidRecalculate, Epoch) { let (value, tracked_dependencies) = invoke_with_dependency_tracking(db, derived_node_id, inner_fn).expect( "InnerFn call cannot fail for a new derived node. This is indicative of a bug in Pico.", @@ -128,8 +124,8 @@ fn create_derived_node( index, ); ( - tracked_dependencies.max_time_updated, DidRecalculate::Recalculated, + tracked_dependencies.max_time_updated, ) } @@ -138,7 +134,7 @@ fn update_derived_node( derived_node_id: DerivedNodeId, prev_value: &dyn DynEq, inner_fn: InnerFn, -) -> (Epoch, DidRecalculate) { +) -> (DidRecalculate, Epoch) { match invoke_with_dependency_tracking(db, derived_node_id, inner_fn) { Some((value, tracked_dependencies)) => { let mut occupied = if let Entry::Occupied(occupied) = db @@ -169,9 +165,9 @@ fn update_derived_node( occupied.get_mut().index = index; - (tracked_dependencies.max_time_updated, did_recalculate) + (did_recalculate, tracked_dependencies.max_time_updated) } - None => (Epoch::new(), DidRecalculate::Error), + None => (DidRecalculate::Error, Epoch::new()), } } diff --git a/crates/pico/src/memo_ref.rs b/crates/pico/src/memo_ref.rs index 045a82736..c6cd1887a 100644 --- a/crates/pico/src/memo_ref.rs +++ b/crates/pico/src/memo_ref.rs @@ -2,7 +2,7 @@ use std::{marker::PhantomData, ops::Deref}; use intern::InternId; -use crate::{DatabaseDyn, DerivedNodeId, ParamId}; +use crate::{DatabaseDyn, DerivedNodeId, ParamId, dependency::NodeKind}; #[derive(Debug)] pub struct MemoRef { @@ -55,10 +55,14 @@ impl Deref for MemoRef { fn deref(&self) -> &T { // SAFETY: Database outlives this MemoRef let db = unsafe { &*self.db }; - db.get_storage_dyn() - .get_value_as_any(self.derived_node_id) - .unwrap() - .downcast_ref::() - .unwrap() + let storage = db.get_storage_dyn(); + let (value, revision) = storage + .get_derived_node_value_and_revision(self.derived_node_id) + .unwrap(); + storage.register_dependency_in_parent_memoized_fn( + NodeKind::Derived(self.derived_node_id), + revision.time_updated, + ); + value.downcast_ref::().unwrap() } } diff --git a/crates/pico/tests/basic_multi_function_chain.rs b/crates/pico/tests/basic_multi_function_chain.rs index d2caf7d99..e79cf7e98 100644 --- a/crates/pico/tests/basic_multi_function_chain.rs +++ b/crates/pico/tests/basic_multi_function_chain.rs @@ -72,7 +72,6 @@ fn multi_function_chain_with_irrelevant_change() { } #[test] -#[should_panic] fn sequential_functions_with_memoref_param() { let _serial_lock = RUN_SERIALLY.lock(); FIRST_LETTER_COUNTER.store(0, Ordering::SeqCst);