Skip to content

Commit e7799e7

Browse files
committed
register dependency on MemoRef::deref()
1 parent ca3bd18 commit e7799e7

File tree

5 files changed

+57
-32
lines changed

5 files changed

+57
-32
lines changed

crates/pico/src/database.rs

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ pub trait Database: DatabaseDyn + Sized {
3333
}
3434

3535
pub trait StorageDyn {
36-
fn get_value_as_any(&self, id: DerivedNodeId) -> Option<&dyn Any>;
36+
fn get_derived_node_value_and_revision(
37+
&self,
38+
id: DerivedNodeId,
39+
) -> Option<(&dyn Any, DerivedNodeRevision)>;
40+
41+
fn register_dependency_in_parent_memoized_fn(&self, node: NodeKind, time_updated: Epoch);
3742
}
3843

3944
#[derive(Debug)]
@@ -46,10 +51,17 @@ pub struct Storage<Db: Database> {
4651
}
4752

4853
impl<Db: Database> StorageDyn for Storage<Db> {
49-
fn get_value_as_any(&self, id: DerivedNodeId) -> Option<&dyn Any> {
54+
fn get_derived_node_value_and_revision(
55+
&self,
56+
id: DerivedNodeId,
57+
) -> Option<(&dyn Any, DerivedNodeRevision)> {
5058
self.internal
51-
.get_derived_node(id)
52-
.map(|node| node.value.as_ref().as_any())
59+
.get_derived_node_and_revision(id)
60+
.map(|(node, revision)| (node.value.as_ref().as_any(), revision))
61+
}
62+
63+
fn register_dependency_in_parent_memoized_fn(&self, node: NodeKind, time_updated: Epoch) {
64+
Storage::register_dependency_in_parent_memoized_fn(self, node, time_updated);
5365
}
5466
}
5567

@@ -201,14 +213,22 @@ impl<Db: Database> InternalStorage<Db> {
201213
&self,
202214
derived_node_id: DerivedNodeId,
203215
) -> Option<&DerivedNode<Db>> {
204-
let index = self
205-
.derived_node_id_to_revision
206-
.get(&derived_node_id)?
207-
.index;
208-
Some(self.derived_nodes.get(index.idx).expect(
216+
self.get_derived_node_and_revision(derived_node_id)
217+
.map(|(node, _)| node)
218+
}
219+
220+
pub(crate) fn get_derived_node_and_revision(
221+
&self,
222+
derived_node_id: DerivedNodeId,
223+
) -> Option<(&DerivedNode<Db>, DerivedNodeRevision)> {
224+
let revision = *self.derived_node_id_to_revision.get(&derived_node_id)?;
225+
226+
let node = self.derived_nodes.get(revision.index.idx).expect(
209227
"indexes should always be valid. \
210228
This is indicative of a bug in Pico.",
211-
))
229+
);
230+
231+
Some((node, revision))
212232
}
213233

214234
pub(crate) fn node_verified_in_current_epoch(&self, derived_node_id: DerivedNodeId) -> bool {

crates/pico/src/dependency.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ impl TrackedDependencies {
3232

3333
pub fn push(&mut self, dependency: Dependency, time_updated: Epoch) {
3434
self.max_time_updated = std::cmp::max(time_updated, self.max_time_updated);
35+
if let Some(last_dependency) = self.dependencies.last_mut()
36+
&& last_dependency.node_to == dependency.node_to
37+
{
38+
last_dependency.time_verified_or_updated = dependency.time_verified_or_updated;
39+
return;
40+
};
3541
self.dependencies.push(dependency);
3642
}
3743
}

crates/pico/src/execute_memoized_function.rs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,18 @@ pub fn execute_memoized_function<Db: Database>(
6868
db.get_storage().top_level_calls.push(derived_node_id);
6969
}
7070

71-
let (time_updated, did_recalculate) = if let Some(derived_node) =
72-
db.get_storage().internal.get_derived_node(derived_node_id)
71+
let (did_recalculate, time_updated) = if let Some((derived_node, revision)) = db
72+
.get_storage()
73+
.internal
74+
.get_derived_node_and_revision(derived_node_id)
7375
{
7476
if db
7577
.get_storage()
7678
.internal
7779
.node_verified_in_current_epoch(derived_node_id)
7880
{
7981
event!(Level::TRACE, "epoch not changed");
80-
(
81-
db.get_storage().internal.current_epoch,
82-
DidRecalculate::ReusedMemoizedValue,
83-
)
82+
(DidRecalculate::ReusedMemoizedValue, revision.time_updated)
8483
} else {
8584
db.get_storage()
8685
.internal
@@ -90,10 +89,7 @@ pub fn execute_memoized_function<Db: Database>(
9089
update_derived_node(db, derived_node_id, derived_node.value.as_ref(), inner_fn)
9190
} else {
9291
event!(Level::TRACE, "dependencies up-to-date");
93-
(
94-
db.get_storage().internal.current_epoch,
95-
DidRecalculate::ReusedMemoizedValue,
96-
)
92+
(DidRecalculate::ReusedMemoizedValue, revision.time_updated)
9793
}
9894
}
9995
} else {
@@ -111,7 +107,7 @@ fn create_derived_node<Db: Database>(
111107
db: &Db,
112108
derived_node_id: DerivedNodeId,
113109
inner_fn: InnerFn<Db>,
114-
) -> (Epoch, DidRecalculate) {
110+
) -> (DidRecalculate, Epoch) {
115111
let (value, tracked_dependencies) =
116112
invoke_with_dependency_tracking(db, derived_node_id, inner_fn).expect(
117113
"InnerFn call cannot fail for a new derived node. This is indicative of a bug in Pico.",
@@ -128,8 +124,8 @@ fn create_derived_node<Db: Database>(
128124
index,
129125
);
130126
(
131-
tracked_dependencies.max_time_updated,
132127
DidRecalculate::Recalculated,
128+
tracked_dependencies.max_time_updated,
133129
)
134130
}
135131

@@ -138,7 +134,7 @@ fn update_derived_node<Db: Database>(
138134
derived_node_id: DerivedNodeId,
139135
prev_value: &dyn DynEq,
140136
inner_fn: InnerFn<Db>,
141-
) -> (Epoch, DidRecalculate) {
137+
) -> (DidRecalculate, Epoch) {
142138
match invoke_with_dependency_tracking(db, derived_node_id, inner_fn) {
143139
Some((value, tracked_dependencies)) => {
144140
let mut occupied = if let Entry::Occupied(occupied) = db
@@ -169,9 +165,9 @@ fn update_derived_node<Db: Database>(
169165

170166
occupied.get_mut().index = index;
171167

172-
(tracked_dependencies.max_time_updated, did_recalculate)
168+
(did_recalculate, tracked_dependencies.max_time_updated)
173169
}
174-
None => (Epoch::new(), DidRecalculate::Error),
170+
None => (DidRecalculate::Error, Epoch::new()),
175171
}
176172
}
177173

crates/pico/src/memo_ref.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{marker::PhantomData, ops::Deref};
22

33
use intern::InternId;
44

5-
use crate::{DatabaseDyn, DerivedNodeId, ParamId};
5+
use crate::{DatabaseDyn, DerivedNodeId, ParamId, dependency::NodeKind};
66

77
#[derive(Debug)]
88
pub struct MemoRef<T> {
@@ -55,10 +55,14 @@ impl<T: 'static> Deref for MemoRef<T> {
5555
fn deref(&self) -> &T {
5656
// SAFETY: Database outlives this MemoRef
5757
let db = unsafe { &*self.db };
58-
db.get_storage_dyn()
59-
.get_value_as_any(self.derived_node_id)
60-
.unwrap()
61-
.downcast_ref::<T>()
62-
.unwrap()
58+
let storage = db.get_storage_dyn();
59+
let (value, revision) = storage
60+
.get_derived_node_value_and_revision(self.derived_node_id)
61+
.unwrap();
62+
storage.register_dependency_in_parent_memoized_fn(
63+
NodeKind::Derived(self.derived_node_id),
64+
revision.time_updated,
65+
);
66+
value.downcast_ref::<T>().unwrap()
6367
}
6468
}

crates/pico/tests/basic_multi_function_chain.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ fn multi_function_chain_with_irrelevant_change() {
7272
}
7373

7474
#[test]
75-
#[should_panic]
7675
fn sequential_functions_with_memoref_param() {
7776
let _serial_lock = RUN_SERIALLY.lock();
7877
FIRST_LETTER_COUNTER.store(0, Ordering::SeqCst);

0 commit comments

Comments
 (0)