Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions execution_graph/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,143 @@ mod tests {
assert_eq!(g.node_run_count(n), Some(2));
}

#[test]
fn host_write_invalidates_prior_readers_of_same_key() {
#[derive(Clone)]
struct KvHost {
kv: Rc<RefCell<BTreeMap<u64, i64>>>,
get_sig: SigHash,
set_sig: SigHash,
}

impl Host for KvHost {
fn call(
&mut self,
symbol: &str,
sig_hash: SigHash,
args: &[ValueRef<'_>],
rets: &mut [Value],
mut ctx: HostContext<'_, '_>,
) -> Result<u64, HostError> {
match symbol {
"kv.get" => {
if sig_hash != self.get_sig {
return Err(HostError::SignatureMismatch);
}
let [ValueRef::U64(key)] = args else {
return Err(HostError::Failed);
};
ctx.record_read(ResourceKeyRef::HostState {
op: self.get_sig,
key: *key,
});
let v = *self.kv.borrow().get(key).unwrap_or(&0);
rets[0] = Value::I64(v);
Ok(0)
}
"kv.set" => {
if sig_hash != self.set_sig {
return Err(HostError::SignatureMismatch);
}
let [ValueRef::U64(key), ValueRef::I64(value)] = args else {
return Err(HostError::Failed);
};
self.kv.borrow_mut().insert(*key, *value);
// Use the reader's key namespace so this write invalidates prior reads.
ctx.record_write(ResourceKeyRef::HostState {
op: self.get_sig,
key: *key,
});
rets[0] = Value::Unit;
Ok(0)
}
_ => Err(HostError::UnknownSymbol),
}
}
}

let get_sig = HostSig {
args: vec![ValueType::U64],
rets: vec![ValueType::I64],
};
let set_sig = HostSig {
args: vec![ValueType::U64, ValueType::I64],
rets: vec![ValueType::Unit],
};
let get_hash = sig_hash(&get_sig);
let set_hash = sig_hash(&set_sig);

let mut get_builder = ProgramBuilder::new();
let get_host = get_builder.host_sig_for("kv.get", get_sig);
let mut get_asm = Asm::new();
get_asm.const_u64(1, 1);
get_asm.host_call(0, get_host, 0, &[1], &[2]);
get_asm.ret(0, &[2]);
let get_entry = get_builder
.push_function_checked(
get_asm,
FunctionSig {
arg_types: vec![],
ret_types: vec![ValueType::I64],
},
)
.unwrap();
get_builder
.set_function_output_name(get_entry, 0, "value")
.unwrap();
let get_prog = Arc::new(get_builder.build_verified().unwrap());

let mut set_builder = ProgramBuilder::new();
let set_host = set_builder.host_sig_for("kv.set", set_sig);
let mut set_asm = Asm::new();
set_asm.const_u64(1, 1);
set_asm.const_i64(2, 8);
set_asm.host_call(0, set_host, 0, &[1, 2], &[3]);
set_asm.ret(0, &[3]);
let set_entry = set_builder
.push_function_checked(
set_asm,
FunctionSig {
arg_types: vec![],
ret_types: vec![ValueType::Unit],
},
)
.unwrap();
set_builder
.set_function_output_name(set_entry, 0, "done")
.unwrap();
let set_prog = Arc::new(set_builder.build_verified().unwrap());

let kv = Rc::new(RefCell::new(BTreeMap::new()));
kv.borrow_mut().insert(1, 7);
let host = KvHost {
kv,
get_sig: get_hash,
set_sig: set_hash,
};

let mut g = ExecutionGraph::new(host, Limits::default());
let reader = g.add_node(get_prog, get_entry, vec![]);

g.run_all().unwrap();
assert_eq!(
g.node_outputs(reader).unwrap().get("value"),
Some(&Value::I64(7))
);
assert_eq!(g.node_run_count(reader), Some(1));

let writer = g.add_node(set_prog, set_entry, vec![]);
g.run_node(writer).unwrap();
assert_eq!(g.node_run_count(reader), Some(1));

g.run_all().unwrap();
assert_eq!(
g.node_outputs(reader).unwrap().get("value"),
Some(&Value::I64(8))
);
assert_eq!(g.node_run_count(reader), Some(2));
}

#[test]
fn host_read_order_changes_do_not_change_last_read_ids() {
#[derive(Clone)]
Expand Down
52 changes: 44 additions & 8 deletions execution_graph/src/tape_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ impl AccessSink for CollectingAccessSink<'_> {

fn write(&mut self, key: ResourceKeyRef<'_>) {
self.counter.set(self.counter.get().saturating_add(1));
let key = match key {
ResourceKeyRef::Input(name) => ResourceKey::input(name),
ResourceKeyRef::HostState { op, key } => {
ResourceKey::host_state(HostOpId::new(op.0), key)
}
ResourceKeyRef::OpaqueHost { op } => ResourceKey::opaque_host(HostOpId::new(op.0)),
};
let key = mark_tape_key_dirty(
self.dirty,
self.input_ids,
self.host_state_ids,
self.opaque_host_ids,
key,
);
self.log.push(Access::Write(key));
}
}
Expand Down Expand Up @@ -159,6 +159,35 @@ pub(crate) fn intern_opaque_host_key_id(
id
}

#[inline]
fn mark_tape_key_dirty(
dirty: &mut DirtyEngine,
input_ids: &mut BTreeMap<Box<str>, DirtyKey>,
host_state_ids: &mut HashMap<(HostOpId, u64), DirtyKey>,
opaque_host_ids: &mut HashMap<HostOpId, DirtyKey>,
key: ResourceKeyRef<'_>,
) -> ResourceKey {
match key {
ResourceKeyRef::Input(name) => {
let id = intern_input_key_id(dirty, input_ids, name);
dirty.mark_dirty(id);
ResourceKey::input(name)
}
ResourceKeyRef::HostState { op, key } => {
let op = HostOpId::new(op.0);
let id = intern_host_state_key_id(dirty, host_state_ids, op, key);
dirty.mark_dirty(id);
ResourceKey::host_state(op, key)
}
ResourceKeyRef::OpaqueHost { op } => {
let op = HostOpId::new(op.0);
let id = intern_opaque_host_key_id(dirty, opaque_host_ids, op);
dirty.mark_dirty(id);
ResourceKey::opaque_host(op)
}
}
}

/// Fast-path access sink used when per-node access log collection is disabled.
///
/// It emits dependency read IDs directly into `read_ids`, avoiding intermediate `AccessLog`
Expand Down Expand Up @@ -217,9 +246,16 @@ impl AccessSink for DepsOnlyAccessSink<'_> {
}

#[inline]
fn write(&mut self, _key: ResourceKeyRef<'_>) {
fn write(&mut self, key: ResourceKeyRef<'_>) {
// Strict-deps mode requires host scopes to emit at least one access event.
self.counter.set(self.counter.get().saturating_add(1));
let _ = mark_tape_key_dirty(
self.dirty,
self.input_ids,
self.host_state_ids,
self.opaque_host_ids,
key,
);
}
}

Expand Down
13 changes: 8 additions & 5 deletions execution_tape/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ use crate::value::Value;
/// let [ValueRef::U64(key), ValueRef::I64(value)] = args else {
/// return Err(HostError::Failed);
/// };
/// // Use the same `(op, key)` namespace as `kv.get` so this write invalidates
/// // prior reads of the key.
/// ctx.record_write(ResourceKeyRef::HostState {
/// op: sig_hash,
/// op: self.get_sig,
/// key: *key,
/// });
/// self.kv.insert(*key, *value);
Expand Down Expand Up @@ -364,10 +366,11 @@ impl<'vm, 'access> HostContext<'vm, 'access> {
/// The string is an embedder-chosen stable name.
///
/// - [`ResourceKeyRef::HostState`] is the main “precise” form for host-managed state.
/// It is explicitly namespaced by the host operation’s [`SigHash`], so different host ops can
/// reuse the same numeric `key` without colliding. The `key: u64` should identify *which*
/// piece of state was consulted/mutated for that operation (often a stable hash of a structured
/// key, or an intern id managed by the embedder).
/// It is explicitly namespaced by a stable [`SigHash`] chosen by the host, so unrelated state
/// domains can reuse the same numeric `key` without colliding. The `key: u64` should identify
/// *which* piece of state was consulted/mutated for that namespace (often a stable hash of a
/// structured key, or an intern id managed by the embedder). Writes that should invalidate
/// previous reads must use the same `(op, key)` pair those reads recorded.
///
/// - [`ResourceKeyRef::OpaqueHost`] is a conservative escape hatch for operations that depend on
/// (or mutate) host state but cannot (or choose not to) produce a more precise key.
Expand Down
Loading