diff --git a/bastion/src/children.rs b/bastion/src/children.rs index 6dad2248..e6201a08 100644 --- a/bastion/src/children.rs +++ b/bastion/src/children.rs @@ -1,3 +1,4 @@ +use crate::bastion::REGISTRY; use crate::broadcast::{BastionMessage, Broadcast, Sender}; use crate::context::{BastionContext, BastionId}; use futures::future::CatchUnwind; @@ -70,11 +71,15 @@ impl Children { } async fn run(mut self) -> Self { + REGISTRY.add_children(&self); + loop { match poll!(&mut self.bcast.next()) { Poll::Ready(Some(msg)) => { match msg { BastionMessage::PoisonPill | BastionMessage::Dead { .. } | BastionMessage::Faulted { .. } => { + REGISTRY.remove_children(&self); + if msg.is_faulted() { self.bcast.faulted(); } else { @@ -88,6 +93,8 @@ impl Children { } } Poll::Ready(None) => { + REGISTRY.remove_children(&self); + self.bcast.faulted(); return self; @@ -129,8 +136,12 @@ impl Child { } async fn run(mut self) { + REGISTRY.add_child(&self); + loop { if let Poll::Ready(res) = poll!(&mut self.exec) { + REGISTRY.remove_child(&self); + match res { Ok(Ok(())) => self.bcast.dead(), Ok(Err(())) | Err(_) => self.bcast.faulted(), @@ -143,11 +154,15 @@ impl Child { Poll::Ready(Some(msg)) => { match msg { BastionMessage::PoisonPill => { + REGISTRY.remove_child(&self); + self.bcast.dead(); return; } BastionMessage::Dead { .. } | BastionMessage::Faulted { .. } => { + REGISTRY.remove_child(&self); + self.bcast.faulted(); return; @@ -157,6 +172,8 @@ impl Child { } } Poll::Ready(None) => { + REGISTRY.remove_child(&self); + self.bcast.faulted(); return; diff --git a/bastion/src/registry.rs b/bastion/src/registry.rs index 718ca0f0..9a6c93a9 100644 --- a/bastion/src/registry.rs +++ b/bastion/src/registry.rs @@ -55,6 +55,24 @@ impl Registry { self.registered.insert(id, registrant); } + pub(super) fn remove_supervisor(&self, supervisor: &Supervisor) { + let id = supervisor.id(); + + self.registered.remove(id); + } + + pub(super) fn remove_children(&self, children: &Children) { + let id = children.id(); + + self.registered.remove(id); + } + + pub(super) fn remove_child(&self, child: &Child) { + let id = child.id(); + + self.registered.remove(id); + } + pub(super) fn send_supervisor(&self, id: &BastionId, msg: BastionMessage) -> Result<(), BastionMessage> { let registrant = if let Some(registrant) = self.registered.get(id) { registrant diff --git a/bastion/src/supervisor.rs b/bastion/src/supervisor.rs index 46276c8d..0bdaa961 100644 --- a/bastion/src/supervisor.rs +++ b/bastion/src/supervisor.rs @@ -1,4 +1,4 @@ -use crate::bastion::SYSTEM; +use crate::bastion::{REGISTRY, SYSTEM}; use crate::broadcast::{BastionMessage, Broadcast, Sender}; use crate::children::{Children, Closure, Message}; use crate::context::BastionId; @@ -141,11 +141,15 @@ impl Supervisor { } pub(super) async fn run(mut self) -> Self { + REGISTRY.add_supervisor(&self); + loop { match poll!(&mut self.bcast.next()) { Poll::Ready(Some(msg)) => { match msg { BastionMessage::PoisonPill => { + REGISTRY.remove_supervisor(&self); + self.bcast.dead(); return self; @@ -158,6 +162,8 @@ impl Supervisor { } BastionMessage::Faulted { id } => { if self.recover(id).await.is_err() { + REGISTRY.remove_supervisor(&self); + self.bcast.faulted(); return self; @@ -171,6 +177,8 @@ impl Supervisor { } } Poll::Ready(None) => { + REGISTRY.remove_supervisor(&self); + self.bcast.faulted(); return self;