From 5469a74dad5fdfdca0f39537b6d14b1aac7d952f Mon Sep 17 00:00:00 2001 From: zdevito Date: Mon, 16 Jun 2025 09:27:07 -0700 Subject: [PATCH] [7/n] Get history ready for working with ports Differential Revision: [D76649221](https://our.internmc.facebook.com/intern/diff/D76649221/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D76649221/)! [ghstack-poisoned] --- monarch_extension/src/mesh_controller.rs | 262 +++++++++-------------- 1 file changed, 103 insertions(+), 159 deletions(-) diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 63ae5f267..7c6a844b8 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::collections::VecDeque; use std::iter::repeat_n; +use std::sync; use std::sync::Arc; use std::sync::atomic; use std::sync::atomic::AtomicUsize; @@ -216,9 +217,8 @@ impl _Controller { Ok(()) } - fn drop_refs(&mut self, refs: Vec) -> Result<(), anyhow::Error> { - self.history.delete_invocations_for_refs(refs); - Ok(()) + fn drop_refs(&mut self, refs: Vec) { + self.history.drop_refs(refs); } fn send<'py>(&mut self, ranks: Bound<'py, PyAny>, message: Bound<'py, PyAny>) -> PyResult<()> { @@ -292,56 +292,84 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult /// It is useful for tracking the dependencies of an operation and propagating /// failures. In the future this will be used with more data dependency tracking /// to support better failure handling. -// Allowing dead code until we do something smarter with defs, uses etc. + +#[derive(Debug)] +enum Status { + Errored(Exception), + Complete(), + + /// When incomplete this holds this list of users of this invocation, + /// so a future error can be propagated to them., + Incomplete(HashMap>>), +} #[derive(Debug)] struct Invocation { /// The sequence number of the invocation. This should be unique and increasing across all /// invocations. seq: Seq, - /// The references that this invocation defines or redefines. Effectively the - /// output of the invocation. - defs: Vec, - /// The result of the invocation. This is set when the invocation is completed or - /// when a failure is inferred. A successful result will always supersede any failure. - result: Option>, - /// The seqs for the invocations that depend on this invocation. Useful for propagating failures. - users: HashSet, + status: Status, + /// Result reported to a future if this invocation was a fetch + /// Not all Invocations will be fetched so sometimes a Invocation will complete with + /// both result and error == None + result: Option, } impl Invocation { - fn new(seq: Seq, defs: Vec) -> Self { + fn new(seq: Seq) -> Self { Self { seq, - defs, + status: Status::Incomplete(HashMap::new()), result: None, - users: HashSet::new(), } } - fn add_user(&mut self, user: Seq) { - self.users.insert(user); + fn add_user(&mut self, user: Arc>) { + match &mut self.status { + Status::Complete() => {} + Status::Incomplete(users) => { + let seq = user.lock().unwrap().seq; + users.insert(seq, user); + } + Status::Errored(err) => { + user.lock().unwrap().set_exception(err.clone()); + } + } } /// Invocation results can only go from valid to failed, or be /// set if the invocation result is empty. fn set_result(&mut self, result: Serialized) { if self.result.is_none() { - self.result = Some(Ok(result)); + self.result = Some(result); } } - fn set_exception(&mut self, exception: Exception) { + fn succeed(&mut self) { + match self.status { + Status::Incomplete(_) => self.status = Status::Complete(), + _ => {} + } + } + + fn set_exception(&mut self, exception: Exception) -> Vec>> { match exception { - Exception::Error(_, caused_by, error) => { - let e = Err(Exception::Error(self.seq, caused_by, error)); - match self.result { - Some(Ok(_)) => { - self.result = Some(e); + Exception::Error(_, caused_by_new, error) => { + let err = Status::Errored(Exception::Error(self.seq, caused_by_new, error)); + match &self.status { + Status::Errored(Exception::Error(_, caused_by_current, _)) + if caused_by_new < *caused_by_current => + { + self.status = err; + } + Status::Incomplete(users) => { + let users = users.values().cloned().collect(); + self.status = err; + return users; } - None => { - self.result = Some(e); + Status::Complete() => { + panic!("Complete invocation getting an exception set") } - Some(Err(_)) => {} + _ => {} } } Exception::Failure(_) => { @@ -351,32 +379,20 @@ impl Invocation { ); } } + return vec![]; } - fn exception(&self) -> Option<&Exception> { - self.result - .as_ref() - .map(Result::as_ref) - .and_then(Result::err) - } - - #[allow(dead_code)] - fn value(&self) -> Option<&Serialized> { - self.result - .as_ref() - .map(Result::as_ref) - .and_then(Result::ok) + fn msg_result(&self) -> Option> { + match &self.status { + Status::Complete() => self.result.clone().map(|x| Ok(x)), + Status::Errored(err) => Some(Err(err.clone())), + Status::Incomplete(_) => { + panic!("Incomplete invocation doesn't have a result yet") + } + } } } -#[derive(Debug, PartialEq)] -enum RefStatus { - // The invocation for this ref is still in progress. - Invoked(Seq), - // The invocation for this ref has errored. - Errored(Exception), -} - /// The history of invocations sent by the client to be executed on the workers. /// This is used to track dependencies between invocations and to propagate exceptions. /// It purges history for completed invocations to avoid memory bloat. @@ -390,13 +406,11 @@ struct History { first_incomplete_seqs: MinVector, /// The minimum incomplete Seq across all ranks. min_incomplete_seq: Seq, - /// A map of seq to the invocation that it represents. - invocations: HashMap, + /// A map of seq to the invocation that it represents for all seq >= min_incomplete_seq + inflight_invocations: HashMap>>, /// A map of reference to the seq for the invocation that defines it. This is used to /// compute dependencies between invocations. - invocation_for_ref: HashMap, - // Refs to be deleted in mark_worker_complete_and_propagate_failures - marked_for_deletion: HashSet, + invocation_for_ref: HashMap>>, // no new sequence numbers should be below this bound. use for // sanity checking. seq_lower_bound: Seq, @@ -448,8 +462,7 @@ impl History { first_incomplete_seqs: MinVector::new(vec![Seq::default(); world_size]), min_incomplete_seq: Seq::default(), invocation_for_ref: HashMap::new(), - invocations: HashMap::new(), - marked_for_deletion: HashSet::new(), + inflight_invocations: HashMap::new(), seq_lower_bound: 0.into(), } } @@ -459,25 +472,10 @@ impl History { self.first_incomplete_seqs.vec() } - pub fn delete_invocations_for_refs(&mut self, refs: Vec) { - self.marked_for_deletion.extend(refs); - - self.marked_for_deletion - .retain(|ref_| match self.invocation_for_ref.get(ref_) { - Some(RefStatus::Invoked(seq)) => { - if seq < &self.min_incomplete_seq { - self.invocation_for_ref.remove(ref_); - false - } else { - true - } - } - Some(RefStatus::Errored(_)) => { - self.invocation_for_ref.remove(ref_); - false - } - None => true, - }); + pub fn drop_refs(&mut self, refs: Vec) { + for r in refs { + self.invocation_for_ref.remove(&r); + } } /// Add an invocation to the history. @@ -487,7 +485,6 @@ impl History { uses: Vec, defs: Vec, ) -> Vec<(Seq, Option>)> { - let mut results = Vec::new(); assert!( seq >= self.seq_lower_bound, "nonmonotonic seq: {:?}; current lower bound: {:?}", @@ -495,43 +492,22 @@ impl History { self.seq_lower_bound, ); self.seq_lower_bound = seq; - let mut invocation = Invocation::new(seq, defs.clone()); - - for use_ in uses { - // The invocation for every use_ should add this seq as a user. - match self.invocation_for_ref.get(&use_) { - Some(RefStatus::Errored(exception)) => { - // We know that this invocation hasn't been completed yet, so we can - // directly call set_exception on it. - if results.is_empty() { - invocation.set_exception(exception.clone()); - results.push((seq, Some(Err(exception.clone())))); - } - } - Some(RefStatus::Invoked(invoked_seq)) => { - if let Some(invocation) = self.invocations.get_mut(invoked_seq) { - invocation.add_user(seq) - } - } - None => tracing::debug!( - "ignoring dependency on potentially complete invocation for ref: {:?}", - use_ - ), - } + let invocation = Arc::new(sync::Mutex::new(Invocation::new(seq))); + self.inflight_invocations.insert(seq, invocation.clone()); + for ref use_ in uses { + let producer = self.invocation_for_ref.get(use_).unwrap(); + producer.lock().unwrap().add_user(invocation.clone()); } + for def in defs { - self.invocation_for_ref.insert( - def, - match invocation.exception() { - Some(err) => RefStatus::Errored(err.clone()), - None => RefStatus::Invoked(seq.clone()), - }, - ); + self.invocation_for_ref.insert(def, invocation.clone()); + } + let invocation = invocation.lock().unwrap(); + if matches!(invocation.status, Status::Errored(_)) { + vec![(seq, invocation.msg_result())] + } else { + vec![] } - - self.invocations.insert(seq, invocation); - - results } /// Propagate worker error to the invocation with the given Seq. This will also propagate @@ -542,31 +518,18 @@ impl History { exception: Exception, ) -> Vec<(Seq, Option>)> { let mut results = Vec::new(); - let mut queue = vec![seq]; - let mut visited = HashSet::new(); + let invocation = self.inflight_invocations.get(&seq).unwrap().clone(); - while let Some(seq) = queue.pop() { - if !visited.insert(seq) { - continue; - } + let mut queue: Vec>> = vec![invocation]; + let mut visited = HashSet::new(); - let Some(invocation) = self.invocations.get_mut(&seq) else { + while let Some(invocation) = queue.pop() { + let mut invocation = invocation.lock().unwrap(); + if !visited.insert(invocation.seq) { continue; }; - - // Overwrite the error, so we are using the last error for this invocation to send - // to the client. - for def in invocation.defs.iter() { - match self.invocation_for_ref.get(def) { - Some(RefStatus::Invoked(invoked_seq)) if *invoked_seq == seq => self - .invocation_for_ref - .insert(*def, RefStatus::Errored(exception.clone())), - _ => None, - }; - } - invocation.set_exception(exception.clone()); - results.push((seq, invocation.result.clone())); - queue.extend(invocation.users.iter()); + queue.extend(invocation.set_exception(exception.clone())); + results.push((seq, invocation.msg_result())); } results } @@ -584,41 +547,22 @@ impl History { let mut results: Vec<(Seq, Option>)> = Vec::new(); for i in Seq::iter_between(prev, self.min_incomplete_seq) { - if let Some(invocation) = self.invocations.remove(&i) { - match invocation.result { - Some(Err(_)) => { - // Retain the def history because we may need it to propagate - // errors in the future. We rely here on the fact that the invocation - // above has been marked as failed by way of failure propagation. - for def in &invocation.defs { - match self.invocation_for_ref.get(def) { - Some(RefStatus::Invoked(seq)) if *seq == i => { - self.invocation_for_ref.remove(def) - } - _ => None, - }; - } - - // we have already reported all exceptions when they are generated. - } - e => { - results.push((i, e)); - } - } + let invocation = self.inflight_invocations.remove(&i).unwrap(); + let mut invocation = invocation.lock().unwrap(); + + if matches!(invocation.status, Status::Errored(_)) { + // we already reported output early when it errored + continue; } + invocation.succeed(); + results.push((i, invocation.msg_result())); } results } - #[cfg(test)] - fn get_invocation(&self, seq: Seq) -> Option<&Invocation> { - self.invocations.get(&seq) - } - pub fn set_result(&mut self, seq: Seq, result: Serialized) { - if let Some(invocation) = self.invocations.get_mut(&seq) { - invocation.set_result(result); - } + let invocation = self.inflight_invocations.get(&seq).unwrap(); + invocation.lock().unwrap().set_result(result); } }