Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 103 additions & 159 deletions monarch_extension/src/mesh_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::iter::repeat_n;
use std::sync;
use std::sync::Arc;
use std::sync::atomic;
use std::sync::atomic::AtomicUsize;
Expand Down Expand Up @@ -216,9 +217,8 @@ impl _Controller {
Ok(())
}

fn drop_refs(&mut self, refs: Vec<Ref>) -> Result<(), anyhow::Error> {
self.history.delete_invocations_for_refs(refs);
Ok(())
// Hands the given refs to `self.history.drop_refs`; nothing is returned.
fn drop_refs(&mut self, refs: Vec<Ref>) {
self.history.drop_refs(refs);
}

fn send<'py>(&mut self, ranks: Bound<'py, PyAny>, message: Bound<'py, PyAny>) -> PyResult<()> {
Expand Down Expand Up @@ -292,56 +292,84 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult
/// It is useful for tracking the dependencies of an operation and propagating
/// failures. In the future this will be used with more data dependency tracking
/// to support better failure handling.
// Allowing dead code until we do something smarter with defs, uses etc.

#[derive(Debug)]
enum Status {
/// The invocation has an exception recorded against it.
Errored(Exception),
/// The invocation finished without an exception being recorded.
Complete(),

/// While incomplete, this holds the users of this invocation (keyed by
/// their `Seq`), so that a future error can be propagated to them.
Incomplete(HashMap<Seq, Arc<sync::Mutex<Invocation>>>),
}
#[derive(Debug)]
struct Invocation {
/// The sequence number of the invocation. This should be unique and increasing across all
/// invocations.
seq: Seq,
/// The references that this invocation defines or redefines. Effectively the
/// output of the invocation.
defs: Vec<Ref>,
/// The result of the invocation. This is set when the invocation is completed or
/// when a failure is inferred. A successful result will always supersede any failure.
result: Option<Result<Serialized, Exception>>,
/// The seqs for the invocations that depend on this invocation. Useful for propagating failures.
users: HashSet<Seq>,
status: Status,
/// Result reported to a future if this invocation was a fetch.
/// Not all invocations are fetched, so an `Invocation` will sometimes
/// complete with both result and error being `None`.
result: Option<Serialized>,
}

impl Invocation {
fn new(seq: Seq, defs: Vec<Ref>) -> Self {
fn new(seq: Seq) -> Self {
Self {
seq,
defs,
status: Status::Incomplete(HashMap::new()),
result: None,
users: HashSet::new(),
}
}

fn add_user(&mut self, user: Seq) {
self.users.insert(user);
/// Registers `user` as an invocation that depends on this one.
///
/// What happens depends on the current status:
/// - `Errored`: the stored exception is cloned and set on `user` right away.
/// - `Incomplete`: `user` is stored under its own `seq` so a later error can
///   reach it.
/// - `Complete`: nothing is recorded.
fn add_user(&mut self, user: Arc<sync::Mutex<Invocation>>) {
    match &mut self.status {
        Status::Errored(cause) => {
            let inherited = cause.clone();
            user.lock().unwrap().set_exception(inherited);
        }
        Status::Incomplete(dependents) => {
            // Read the seq in its own statement so the lock is released
            // before `user` is moved into the map.
            let user_seq = user.lock().unwrap().seq;
            dependents.insert(user_seq, user);
        }
        Status::Complete() => {}
    }
}

/// Invocation results can only go from valid to failed, or be
/// set if the invocation result is empty.
fn set_result(&mut self, result: Serialized) {
if self.result.is_none() {
self.result = Some(Ok(result));
self.result = Some(result);
}
}

/// Marks the invocation as `Complete` if it is still `Incomplete`.
///
/// Any other status (already complete or errored) is left untouched.
fn succeed(&mut self) {
    if matches!(self.status, Status::Incomplete(_)) {
        self.status = Status::Complete();
    }
}

fn set_exception(&mut self, exception: Exception) {
fn set_exception(&mut self, exception: Exception) -> Vec<Arc<sync::Mutex<Invocation>>> {
match exception {
Exception::Error(_, caused_by, error) => {
let e = Err(Exception::Error(self.seq, caused_by, error));
match self.result {
Some(Ok(_)) => {
self.result = Some(e);
Exception::Error(_, caused_by_new, error) => {
let err = Status::Errored(Exception::Error(self.seq, caused_by_new, error));
match &self.status {
Status::Errored(Exception::Error(_, caused_by_current, _))
if caused_by_new < *caused_by_current =>
{
self.status = err;
}
Status::Incomplete(users) => {
let users = users.values().cloned().collect();
self.status = err;
return users;
}
None => {
self.result = Some(e);
Status::Complete() => {
panic!("Complete invocation getting an exception set")
}
Some(Err(_)) => {}
_ => {}
}
}
Exception::Failure(_) => {
Expand All @@ -351,32 +379,20 @@ impl Invocation {
);
}
}
vec![]
}

fn exception(&self) -> Option<&Exception> {
self.result
.as_ref()
.map(Result::as_ref)
.and_then(Result::err)
}

#[allow(dead_code)]
fn value(&self) -> Option<&Serialized> {
self.result
.as_ref()
.map(Result::as_ref)
.and_then(Result::ok)
/// Builds the result to report for this invocation from its status.
///
/// - `Errored` yields `Some(Err(..))` with a clone of the stored exception.
/// - `Complete` yields a clone of the stored `result` wrapped in `Ok`, or
///   `None` when no result was stored.
///
/// # Panics
///
/// Panics if the invocation is still `Incomplete`.
fn msg_result(&self) -> Option<Result<Serialized, Exception>> {
    match &self.status {
        Status::Incomplete(_) => {
            panic!("Incomplete invocation doesn't have a result yet")
        }
        Status::Errored(failure) => Some(Err(failure.clone())),
        Status::Complete() => self.result.as_ref().map(|value| Ok(value.clone())),
    }
}
}

#[derive(Debug, PartialEq)]
enum RefStatus {
// The invocation for this ref is still in progress.
Invoked(Seq),
// The invocation for this ref has errored.
Errored(Exception),
}

/// The history of invocations sent by the client to be executed on the workers.
/// This is used to track dependencies between invocations and to propagate exceptions.
/// It purges history for completed invocations to avoid memory bloat.
Expand All @@ -390,13 +406,11 @@ struct History {
first_incomplete_seqs: MinVector<Seq>,
/// The minimum incomplete Seq across all ranks.
min_incomplete_seq: Seq,
/// A map of seq to the invocation that it represents.
invocations: HashMap<Seq, Invocation>,
/// A map of seq to the invocation that it represents for all seq >= min_incomplete_seq
inflight_invocations: HashMap<Seq, Arc<sync::Mutex<Invocation>>>,
/// A map of reference to the seq for the invocation that defines it. This is used to
/// compute dependencies between invocations.
invocation_for_ref: HashMap<Ref, RefStatus>,
// Refs to be deleted in mark_worker_complete_and_propagate_failures
marked_for_deletion: HashSet<Ref>,
invocation_for_ref: HashMap<Ref, Arc<sync::Mutex<Invocation>>>,
// No new sequence numbers should be below this bound. Used for
// sanity checking.
seq_lower_bound: Seq,
Expand Down Expand Up @@ -448,8 +462,7 @@ impl History {
first_incomplete_seqs: MinVector::new(vec![Seq::default(); world_size]),
min_incomplete_seq: Seq::default(),
invocation_for_ref: HashMap::new(),
invocations: HashMap::new(),
marked_for_deletion: HashSet::new(),
inflight_invocations: HashMap::new(),
seq_lower_bound: 0.into(),
}
}
Expand All @@ -459,25 +472,10 @@ impl History {
self.first_incomplete_seqs.vec()
}

pub fn delete_invocations_for_refs(&mut self, refs: Vec<Ref>) {
self.marked_for_deletion.extend(refs);

self.marked_for_deletion
.retain(|ref_| match self.invocation_for_ref.get(ref_) {
Some(RefStatus::Invoked(seq)) => {
if seq < &self.min_incomplete_seq {
self.invocation_for_ref.remove(ref_);
false
} else {
true
}
}
Some(RefStatus::Errored(_)) => {
self.invocation_for_ref.remove(ref_);
false
}
None => true,
});
/// Removes the `invocation_for_ref` entry of every ref in `refs`.
///
/// Refs that have no entry are ignored.
pub fn drop_refs(&mut self, refs: Vec<Ref>) {
    let tracked = &mut self.invocation_for_ref;
    refs.iter().for_each(|dropped| {
        tracked.remove(dropped);
    });
}

/// Add an invocation to the history.
Expand All @@ -487,51 +485,29 @@ impl History {
uses: Vec<Ref>,
defs: Vec<Ref>,
) -> Vec<(Seq, Option<Result<Serialized, Exception>>)> {
let mut results = Vec::new();
assert!(
seq >= self.seq_lower_bound,
"nonmonotonic seq: {:?}; current lower bound: {:?}",
seq,
self.seq_lower_bound,
);
self.seq_lower_bound = seq;
let mut invocation = Invocation::new(seq, defs.clone());

for use_ in uses {
// The invocation for every use_ should add this seq as a user.
match self.invocation_for_ref.get(&use_) {
Some(RefStatus::Errored(exception)) => {
// We know that this invocation hasn't been completed yet, so we can
// directly call set_exception on it.
if results.is_empty() {
invocation.set_exception(exception.clone());
results.push((seq, Some(Err(exception.clone()))));
}
}
Some(RefStatus::Invoked(invoked_seq)) => {
if let Some(invocation) = self.invocations.get_mut(invoked_seq) {
invocation.add_user(seq)
}
}
None => tracing::debug!(
"ignoring dependency on potentially complete invocation for ref: {:?}",
use_
),
}
let invocation = Arc::new(sync::Mutex::new(Invocation::new(seq)));
self.inflight_invocations.insert(seq, invocation.clone());
for ref use_ in uses {
let producer = self.invocation_for_ref.get(use_).unwrap();
producer.lock().unwrap().add_user(invocation.clone());
}

for def in defs {
self.invocation_for_ref.insert(
def,
match invocation.exception() {
Some(err) => RefStatus::Errored(err.clone()),
None => RefStatus::Invoked(seq.clone()),
},
);
self.invocation_for_ref.insert(def, invocation.clone());
}
let invocation = invocation.lock().unwrap();
if matches!(invocation.status, Status::Errored(_)) {
vec![(seq, invocation.msg_result())]
} else {
vec![]
}

self.invocations.insert(seq, invocation);

results
}

/// Propagate worker error to the invocation with the given Seq. This will also propagate
Expand All @@ -542,31 +518,18 @@ impl History {
exception: Exception,
) -> Vec<(Seq, Option<Result<Serialized, Exception>>)> {
let mut results = Vec::new();
let mut queue = vec![seq];
let mut visited = HashSet::new();
let invocation = self.inflight_invocations.get(&seq).unwrap().clone();

while let Some(seq) = queue.pop() {
if !visited.insert(seq) {
continue;
}
let mut queue: Vec<Arc<sync::Mutex<Invocation>>> = vec![invocation];
let mut visited = HashSet::new();

let Some(invocation) = self.invocations.get_mut(&seq) else {
while let Some(invocation) = queue.pop() {
let mut invocation = invocation.lock().unwrap();
if !visited.insert(invocation.seq) {
continue;
};

// Overwrite the error, so we are using the last error for this invocation to send
// to the client.
for def in invocation.defs.iter() {
match self.invocation_for_ref.get(def) {
Some(RefStatus::Invoked(invoked_seq)) if *invoked_seq == seq => self
.invocation_for_ref
.insert(*def, RefStatus::Errored(exception.clone())),
_ => None,
};
}
invocation.set_exception(exception.clone());
results.push((seq, invocation.result.clone()));
queue.extend(invocation.users.iter());
queue.extend(invocation.set_exception(exception.clone()));
results.push((seq, invocation.msg_result()));
}
results
}
Expand All @@ -584,40 +547,21 @@ impl History {

let mut results: Vec<(Seq, Option<Result<Serialized, Exception>>)> = Vec::new();
for i in Seq::iter_between(prev, self.min_incomplete_seq) {
if let Some(invocation) = self.invocations.remove(&i) {
match invocation.result {
Some(Err(_)) => {
// Retain the def history because we may need it to propagate
// errors in the future. We rely here on the fact that the invocation
// above has been marked as failed by way of failure propagation.
for def in &invocation.defs {
match self.invocation_for_ref.get(def) {
Some(RefStatus::Invoked(seq)) if *seq == i => {
self.invocation_for_ref.remove(def)
}
_ => None,
};
}

// we have already reported all exceptions when they are generated.
}
e => {
results.push((i, e));
}
}
let invocation = self.inflight_invocations.remove(&i).unwrap();
let mut invocation = invocation.lock().unwrap();

if matches!(invocation.status, Status::Errored(_)) {
// we already reported output early when it errored
continue;
}
invocation.succeed();
results.push((i, invocation.msg_result()));
}
results
}

#[cfg(test)]
fn get_invocation(&self, seq: Seq) -> Option<&Invocation> {
self.invocations.get(&seq)
}

pub fn set_result(&mut self, seq: Seq, result: Serialized) {
if let Some(invocation) = self.invocations.get_mut(&seq) {
invocation.set_result(result);
}
let invocation = self.inflight_invocations.get(&seq).unwrap();
invocation.lock().unwrap().set_result(result);
}
}