diff --git a/mithril-relay/src/lib.rs b/mithril-relay/src/lib.rs index f08b4b89565..9e08eb40d4b 100644 --- a/mithril-relay/src/lib.rs +++ b/mithril-relay/src/lib.rs @@ -5,6 +5,7 @@ mod commands; /// Peer to peer module pub mod p2p; mod relay; +mod repeater; pub use commands::Args; pub use commands::RelayCommands; diff --git a/mithril-relay/src/repeater.rs b/mithril-relay/src/repeater.rs new file mode 100644 index 00000000000..6dd1efe35e4 --- /dev/null +++ b/mithril-relay/src/repeater.rs @@ -0,0 +1,110 @@ +use anyhow::anyhow; +use mithril_common::StdResult; +use slog_scope::debug; +use std::{fmt::Debug, sync::Arc, time::Duration}; +use tokio::sync::{mpsc::UnboundedSender, Mutex}; + +/// A message repeater will send a message to a channel at a given delay +pub struct MessageRepeater { + message: Arc>>, + tx_message: UnboundedSender, + delay: Duration, +} + +impl MessageRepeater { + /// Factory for MessageRepeater + pub fn new(tx_message: UnboundedSender, delay: Duration) -> Self { + Self { + message: Arc::new(Mutex::new(None)), + tx_message, + delay, + } + } + + /// Set the message to repeat + pub async fn set_message(&self, message: M) { + debug!("MessageRepeater: set message"; "message" => format!("{:#?}", message)); + *self.message.lock().await = Some(message); + } + + /// Start repeating the message if any + pub async fn repeat_message(&self) -> StdResult<()> { + tokio::time::sleep(self.delay).await; + match self.message.lock().await.as_ref() { + Some(message) => { + debug!("MessageRepeater: repeat message"; "message" => format!("{:#?}", message)); + self.tx_message + .send(message.clone()) + .map_err(|e| anyhow!(e))? + } + None => { + debug!("MessageRepeater: no message to repeat"); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use tokio::{sync::mpsc, time}; + + use super::*; + + #[tokio::test] + async fn should_repeat_message_when_exists() { + let (tx, mut rx) = mpsc::unbounded_channel(); + let delay = Duration::from_millis(100); + let repeater = MessageRepeater::new(tx, delay); + + let message = "Hello, world!"; + repeater.set_message(message.to_string()).await; + repeater.repeat_message().await.unwrap(); + + let received = rx.recv().await.unwrap(); + assert_eq!(message, received); + } + + #[tokio::test] + async fn should_repeat_message_when_exists_with_expected_delay() { + let (tx, _rx) = mpsc::unbounded_channel(); + let delay = Duration::from_secs(1); + let repeater = MessageRepeater::new(tx, delay); + + let message = "Hello, world!"; + repeater.set_message(message.to_string()).await; + + let result = tokio::select! { + _ = time::sleep(delay-Duration::from_millis(100)) => {Err(anyhow!("Timeout"))} + _ = repeater.repeat_message() => {Ok(())} + }; + + result.expect_err("should have timed out"); + } + + #[tokio::test] + async fn should_do_nothing_when_message_not_exists() { + let (tx, rx) = mpsc::unbounded_channel::(); + let delay = Duration::from_millis(100); + let repeater = MessageRepeater::new(tx, delay); + + repeater.repeat_message().await.unwrap(); + + assert!(rx.is_empty()); + } + + #[tokio::test] + async fn should_do_nothing_when_message_not_exists_with_expected_delay() { + let (tx, _rx) = mpsc::unbounded_channel::(); + let delay = Duration::from_secs(1); + let repeater = MessageRepeater::new(tx, delay); + + let result = tokio::select! { + _ = time::sleep(delay-Duration::from_millis(100)) => {Err(anyhow!("Timeout"))} + _ = repeater.repeat_message() => {Ok(())} + }; + + result.expect_err("should have timed out"); + } +}