diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index e9c5710d2..ec19b3857 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -518,24 +518,28 @@ impl ConnectionHandler { // if the channel for subscription was dropped, remove the // subscription from the map and unsubscribe. - if subscription.sender.send(message).await.is_err() { - self.subscriptions.remove(&sid); - self.connection - .write_op(ClientOp::Unsubscribe { sid, max: None }) - .await?; - self.connection.flush().await?; - // if channel was open and we sent the messsage, increase the - // `delivered` counter. - } else { - subscription.delivered += 1; - // if this `Subscription` has set `max` value, check if it - // was reached. If yes, remove the `Subscription` and in - // the result, `drop` the `sender` channel. - if let Some(max) = subscription.max { - if subscription.delivered.ge(&max) { - self.subscriptions.remove(&sid); + match subscription.sender.try_send(message) { + Ok(_) => { + subscription.delivered += 1; + // if this `Subscription` has set `max` value, check if it + // was reached. If yes, remove the `Subscription` and in + // the result, `drop` the `sender` channel. + if let Some(max) = subscription.max { + if subscription.delivered.ge(&max) { + self.subscriptions.remove(&sid); + } } } + Err(mpsc::error::TrySendError::Full(_)) => { + self.events.send(ServerEvent::SlowConsumer(sid)).await.ok(); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + self.subscriptions.remove(&sid); + self.connection + .write_op(ClientOp::Unsubscribe { sid, max: None }) + .await?; + self.connection.flush().await?; + } } } } @@ -705,13 +709,15 @@ impl ConnectionHandler { pub struct Client { sender: mpsc::Sender, next_subscription_id: Arc, + subscription_capacity: usize, } impl Client { - pub(crate) fn new(sender: mpsc::Sender) -> Client { + pub(crate) fn new(sender: mpsc::Sender, capacity: usize) -> Client { Client { sender, next_subscription_id: Arc::new(AtomicU64::new(0)), + subscription_capacity: capacity, } } @@ -850,7 +856,7 @@ impl Client { queue_group: Option, ) -> Result { let sid = self.next_subscription_id.fetch_add(1, Ordering::Relaxed); - let (sender, receiver) = mpsc::channel(16); + let (sender, receiver) = mpsc::channel(self.subscription_capacity); self.sender .send(Command::Subscribe { @@ -913,7 +919,7 @@ pub async fn connect_with_options( // TODO make channel size configurable let (sender, receiver) = mpsc::channel(options.sender_capacity); - let client = Client::new(sender.clone()); + let client = Client::new(sender.clone(), options.subscription_capacity); let mut connect_info = ConnectInfo { tls_required, // FIXME(tp): have optional name @@ -998,6 +1004,12 @@ pub async fn connect_with_options( ServerEvent::Disconnect => options.disconnect_callback.call().await, ServerEvent::Error(error) => options.error_callback.call(error).await, ServerEvent::LameDuckMode => options.lame_duck_callback.call().await, + ServerEvent::SlowConsumer(sid) => { + options + .error_callback + .call(ServerError::SlowConsumer(sid)) + .await + } } } }); @@ -1011,6 +1023,7 @@ pub(crate) enum ServerEvent { Reconnect, Disconnect, LameDuckMode, + SlowConsumer(u64), Error(ServerError), } @@ -1162,6 +1175,7 @@ impl Stream for Subscriber { #[derive(Clone, Debug, Eq, PartialEq)] pub enum ServerError { AuthorizationViolation, + SlowConsumer(u64), Other(String), } @@ -1178,6 +1192,7 @@ impl std::fmt::Display for ServerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::AuthorizationViolation => write!(f, "nats: authorization violation"), + Self::SlowConsumer(sid) => write!(f, "nats: subscription {} is a slow consumer", sid), Self::Other(error) => write!(f, "nats: {}", error), } } diff --git a/async-nats/src/options.rs b/async-nats/src/options.rs index 8d127cce9..7bba8f24d 100644 --- a/async-nats/src/options.rs +++ b/async-nats/src/options.rs @@ -46,6 +46,7 @@ pub struct ConnectOptions { pub(crate) tls_client_config: Option, pub(crate) flush_interval: Duration, pub(crate) ping_interval: Duration, + pub(crate) subscription_capacity: usize, pub(crate) sender_capacity: usize, pub(crate) reconnect_callback: CallbackArg0<()>, pub(crate) disconnect_callback: CallbackArg0<()>, @@ -90,6 +91,7 @@ impl Default for ConnectOptions { flush_interval: Duration::from_millis(100), ping_interval: Duration::from_secs(60), sender_capacity: 128, + subscription_capacity: 1024, reconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))), disconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))), lame_duck_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))), @@ -350,6 +352,23 @@ impl ConnectOptions { self } + /// Sets the capacity for `Subscribers`. Exceeding it will trigger `slow consumer` error + /// callback and drop messages. + /// Defualt is set to 1024 messages buffer. + /// + /// # Examples + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// async_nats::ConnectOptions::new().subscription_capacity(1024).connect("demo.nats.io").await?; + /// # Ok(()) + /// # } + /// ``` + pub fn subscription_capacity(mut self, capacity: usize) -> ConnectOptions { + self.subscription_capacity = capacity; + self + } + /// Registers asynchronous callback for errors that are receiver over the wire from the server. /// /// # Examples diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index 60c2ff9a9..ba454d751 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -12,7 +12,7 @@ // limitations under the License. mod client { - use async_nats::ConnectOptions; + use async_nats::{ConnectOptions, ServerError}; use bytes::Bytes; use futures::future::join_all; use futures::stream::StreamExt; @@ -406,4 +406,51 @@ mod client { .unwrap() .unwrap(); } + + #[tokio::test] + async fn slow_consumers() { + let server = nats_server::run_basic_server(); + + let (tx, mut rx) = tokio::sync::mpsc::channel(128); + let client = ConnectOptions::new() + .subscription_capacity(1) + .error_callback(move |err| { + let tx = tx.clone(); + async move { + if let ServerError::SlowConsumer(_) = err { + tx.send(()).await.unwrap() + } + } + }) + .connect(server.client_url()) + .await + .unwrap(); + + let _sub = client.subscribe("data".to_string()).await.unwrap(); + client + .publish("data".to_string(), "data".into()) + .await + .unwrap(); + client + .publish("data".to_string(), "data".into()) + .await + .unwrap(); + client.flush().await.unwrap(); + client + .publish("data".to_string(), "data".into()) + .await + .unwrap(); + client.flush().await.unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + tokio::time::timeout(Duration::from_secs(5), rx.recv()) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(5), rx.recv()) + .await + .unwrap() + .unwrap(); + } }