From ca3bd188c8ecf627846196f45a7ca4df700cfdc3 Mon Sep 17 00:00:00 2001 From: ch1ffa Date: Wed, 24 Sep 2025 12:46:43 +0300 Subject: [PATCH 1/2] demonstrate broken memo --- .../pico/tests/basic_multi_function_chain.rs | 55 ++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/crates/pico/tests/basic_multi_function_chain.rs b/crates/pico/tests/basic_multi_function_chain.rs index a894ddddd..087bd07c6 100644 --- a/crates/pico/tests/basic_multi_function_chain.rs +++ b/crates/pico/tests/basic_multi_function_chain.rs @@ -3,11 +3,12 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, }; -use pico::{Database, SourceId, Storage}; +use pico::{Database, MemoRef, SourceId, Storage}; use pico_macros::{Db, Source, memo}; static FIRST_LETTER_COUNTER: AtomicUsize = AtomicUsize::new(0); static CAPITALIZED_LETTER_COUNTER: AtomicUsize = AtomicUsize::new(0); +static MEMO_REF_PARAM_COUNTER: AtomicUsize = AtomicUsize::new(0); static RUN_SERIALLY: LazyLock> = LazyLock::new(Mutex::default); @@ -70,6 +71,52 @@ fn multi_function_chain_with_irrelevant_change() { assert_eq!(CAPITALIZED_LETTER_COUNTER.load(Ordering::SeqCst), 1); } +#[test] +#[should_panic] +fn sequential_functions_with_memoref_param() { + let _serial_lock = RUN_SERIALLY.lock(); + FIRST_LETTER_COUNTER.store(0, Ordering::SeqCst); + MEMO_REF_PARAM_COUNTER.store(0, Ordering::SeqCst); + + let mut db = TestDatabase::default(); + + let id = db.set(Input { + key: "key", + value: "asdf".to_string(), + }); + + assert_eq!( + *capitalized_first_letter_from_memoref(&db, first_letter(&db, id)), + 'A', + ); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 1); + assert_eq!(MEMO_REF_PARAM_COUNTER.load(Ordering::SeqCst), 1); + + db.set(Input { + key: "key", + value: "bsdf".to_string(), + }); + + assert_eq!( + *capitalized_first_letter_from_memoref(&db, first_letter(&db, id)), + 'B', + ); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 2); + assert_eq!(MEMO_REF_PARAM_COUNTER.load(Ordering::SeqCst), 2); + + db.set(Input { + key: "key", + value: "balt".to_string(), + }); + + assert_eq!( + *capitalized_first_letter_from_memoref(&db, first_letter(&db, id)), + 'B', + ); + assert_eq!(FIRST_LETTER_COUNTER.load(Ordering::SeqCst), 3); + assert_eq!(MEMO_REF_PARAM_COUNTER.load(Ordering::SeqCst), 2); +} + #[derive(Debug, Clone, PartialEq, Eq, Source)] struct Input { #[key] @@ -90,3 +137,9 @@ fn capitalized_first_letter(db: &TestDatabase, input_id: SourceId) -> cha let first = first_letter(db, input_id); first.to_ascii_uppercase() } + +#[memo] +fn capitalized_first_letter_from_memoref(db: &TestDatabase, first: MemoRef) -> char { + MEMO_REF_PARAM_COUNTER.fetch_add(1, Ordering::SeqCst); + first.to_ascii_uppercase() +} From e7799e751de6bb982481475d7cad67c4fb4f46ac Mon Sep 17 00:00:00 2001 From: ch1ffa Date: Tue, 23 Sep 2025 13:17:52 +0300 Subject: [PATCH 2/2] register dependency on MemoRef::deref() --- crates/pico/src/database.rs | 40 ++++++++++++++----- crates/pico/src/dependency.rs | 6 +++ crates/pico/src/execute_memoized_function.rs | 26 +++++------- crates/pico/src/memo_ref.rs | 16 +++++--- .../pico/tests/basic_multi_function_chain.rs | 1 - 5 files changed, 57 insertions(+), 32 deletions(-) diff --git a/crates/pico/src/database.rs b/crates/pico/src/database.rs index cecb34c1d..6f04d1820 100644 --- a/crates/pico/src/database.rs +++ b/crates/pico/src/database.rs @@ -33,7 +33,12 @@ pub trait Database: DatabaseDyn + Sized { } pub trait StorageDyn { - fn get_value_as_any(&self, id: DerivedNodeId) -> Option<&dyn Any>; + fn get_derived_node_value_and_revision( + &self, + id: DerivedNodeId, + ) -> Option<(&dyn Any, DerivedNodeRevision)>; + + fn register_dependency_in_parent_memoized_fn(&self, node: NodeKind, time_updated: Epoch); } #[derive(Debug)] @@ -46,10 +51,17 @@ pub struct Storage { } impl StorageDyn for Storage { - fn get_value_as_any(&self, id: DerivedNodeId) -> Option<&dyn Any> { + fn get_derived_node_value_and_revision( + &self, + id: DerivedNodeId, + ) -> Option<(&dyn Any, DerivedNodeRevision)> { self.internal - .get_derived_node(id) - .map(|node| node.value.as_ref().as_any()) + .get_derived_node_and_revision(id) + .map(|(node, revision)| (node.value.as_ref().as_any(), revision)) + } + + fn register_dependency_in_parent_memoized_fn(&self, node: NodeKind, time_updated: Epoch) { + Storage::register_dependency_in_parent_memoized_fn(self, node, time_updated); } } @@ -201,14 +213,22 @@ impl InternalStorage { &self, derived_node_id: DerivedNodeId, ) -> Option<&DerivedNode> { - let index = self - .derived_node_id_to_revision - .get(&derived_node_id)? - .index; - Some(self.derived_nodes.get(index.idx).expect( + self.get_derived_node_and_revision(derived_node_id) + .map(|(node, _)| node) + } + + pub(crate) fn get_derived_node_and_revision( + &self, + derived_node_id: DerivedNodeId, + ) -> Option<(&DerivedNode, DerivedNodeRevision)> { + let revision = *self.derived_node_id_to_revision.get(&derived_node_id)?; + + let node = self.derived_nodes.get(revision.index.idx).expect( "indexes should always be valid. \ This is indicative of a bug in Pico.", - )) + ); + + Some((node, revision)) } pub(crate) fn node_verified_in_current_epoch(&self, derived_node_id: DerivedNodeId) -> bool { diff --git a/crates/pico/src/dependency.rs b/crates/pico/src/dependency.rs index e23940183..18ce1f85e 100644 --- a/crates/pico/src/dependency.rs +++ b/crates/pico/src/dependency.rs @@ -32,6 +32,12 @@ impl TrackedDependencies { pub fn push(&mut self, dependency: Dependency, time_updated: Epoch) { self.max_time_updated = std::cmp::max(time_updated, self.max_time_updated); + if let Some(last_dependency) = self.dependencies.last_mut() + && last_dependency.node_to == dependency.node_to + { + last_dependency.time_verified_or_updated = dependency.time_verified_or_updated; + return; + }; self.dependencies.push(dependency); } } diff --git a/crates/pico/src/execute_memoized_function.rs b/crates/pico/src/execute_memoized_function.rs index 2e36ce2d7..9c62b3a71 100644 --- a/crates/pico/src/execute_memoized_function.rs +++ b/crates/pico/src/execute_memoized_function.rs @@ -68,8 +68,10 @@ pub fn execute_memoized_function( db.get_storage().top_level_calls.push(derived_node_id); } - let (time_updated, did_recalculate) = if let Some(derived_node) = - db.get_storage().internal.get_derived_node(derived_node_id) + let (did_recalculate, time_updated) = if let Some((derived_node, revision)) = db + .get_storage() + .internal + .get_derived_node_and_revision(derived_node_id) { if db .get_storage() @@ -77,10 +79,7 @@ pub fn execute_memoized_function( .node_verified_in_current_epoch(derived_node_id) { event!(Level::TRACE, "epoch not changed"); - ( - db.get_storage().internal.current_epoch, - DidRecalculate::ReusedMemoizedValue, - ) + (DidRecalculate::ReusedMemoizedValue, revision.time_updated) } else { db.get_storage() .internal @@ -90,10 +89,7 @@ pub fn execute_memoized_function( update_derived_node(db, derived_node_id, derived_node.value.as_ref(), inner_fn) } else { event!(Level::TRACE, "dependencies up-to-date"); - ( - db.get_storage().internal.current_epoch, - DidRecalculate::ReusedMemoizedValue, - ) + (DidRecalculate::ReusedMemoizedValue, revision.time_updated) } } } else { @@ -111,7 +107,7 @@ fn create_derived_node( db: &Db, derived_node_id: DerivedNodeId, inner_fn: InnerFn, -) -> (Epoch, DidRecalculate) { +) -> (DidRecalculate, Epoch) { let (value, tracked_dependencies) = invoke_with_dependency_tracking(db, derived_node_id, inner_fn).expect( "InnerFn call cannot fail for a new derived node. This is indicative of a bug in Pico.", @@ -128,8 +124,8 @@ fn create_derived_node( index, ); ( - tracked_dependencies.max_time_updated, DidRecalculate::Recalculated, + tracked_dependencies.max_time_updated, ) } @@ -138,7 +134,7 @@ fn update_derived_node( derived_node_id: DerivedNodeId, prev_value: &dyn DynEq, inner_fn: InnerFn, -) -> (Epoch, DidRecalculate) { +) -> (DidRecalculate, Epoch) { match invoke_with_dependency_tracking(db, derived_node_id, inner_fn) { Some((value, tracked_dependencies)) => { let mut occupied = if let Entry::Occupied(occupied) = db @@ -169,9 +165,9 @@ fn update_derived_node( occupied.get_mut().index = index; - (tracked_dependencies.max_time_updated, did_recalculate) + (did_recalculate, tracked_dependencies.max_time_updated) } - None => (Epoch::new(), DidRecalculate::Error), + None => (DidRecalculate::Error, Epoch::new()), } } diff --git a/crates/pico/src/memo_ref.rs b/crates/pico/src/memo_ref.rs index 045a82736..c6cd1887a 100644 --- a/crates/pico/src/memo_ref.rs +++ b/crates/pico/src/memo_ref.rs @@ -2,7 +2,7 @@ use std::{marker::PhantomData, ops::Deref}; use intern::InternId; -use crate::{DatabaseDyn, DerivedNodeId, ParamId}; +use crate::{DatabaseDyn, DerivedNodeId, ParamId, dependency::NodeKind}; #[derive(Debug)] pub struct MemoRef { @@ -55,10 +55,14 @@ impl Deref for MemoRef { fn deref(&self) -> &T { // SAFETY: Database outlives this MemoRef let db = unsafe { &*self.db }; - db.get_storage_dyn() - .get_value_as_any(self.derived_node_id) - .unwrap() - .downcast_ref::() - .unwrap() + let storage = db.get_storage_dyn(); + let (value, revision) = storage + .get_derived_node_value_and_revision(self.derived_node_id) + .unwrap(); + storage.register_dependency_in_parent_memoized_fn( + NodeKind::Derived(self.derived_node_id), + revision.time_updated, + ); + value.downcast_ref::().unwrap() } } diff --git a/crates/pico/tests/basic_multi_function_chain.rs b/crates/pico/tests/basic_multi_function_chain.rs index 087bd07c6..ebc99c682 100644 --- a/crates/pico/tests/basic_multi_function_chain.rs +++ b/crates/pico/tests/basic_multi_function_chain.rs @@ -72,7 +72,6 @@ fn multi_function_chain_with_irrelevant_change() { } #[test] -#[should_panic] fn sequential_functions_with_memoref_param() { let _serial_lock = RUN_SERIALLY.lock(); FIRST_LETTER_COUNTER.store(0, Ordering::SeqCst);