Skip to content

Commit

Permalink
Add slow consumers
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarema committed May 21, 2022
1 parent e72221b commit ad0728f
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 20 deletions.
53 changes: 34 additions & 19 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,24 +518,28 @@ impl ConnectionHandler {

// if the channel for subscription was dropped, remove the
// subscription from the map and unsubscribe.
if subscription.sender.send(message).await.is_err() {
self.subscriptions.remove(&sid);
self.connection
.write_op(ClientOp::Unsubscribe { sid, max: None })
.await?;
self.connection.flush().await?;
// if channel was open and we sent the messsage, increase the
// `delivered` counter.
} else {
subscription.delivered += 1;
// if this `Subscription` has set `max` value, check if it
// was reached. If yes, remove the `Subscription` and in
// the result, `drop` the `sender` channel.
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
self.subscriptions.remove(&sid);
match subscription.sender.try_send(message) {
Ok(_) => {
subscription.delivered += 1;
// if this `Subscription` has set `max` value, check if it
// was reached. If yes, remove the `Subscription` and in
// the result, `drop` the `sender` channel.
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
self.subscriptions.remove(&sid);
}
}
}
Err(mpsc::error::TrySendError::Full(_)) => {
self.events.send(ServerEvent::SlowConsumer(sid)).await.ok();
}
Err(mpsc::error::TrySendError::Closed(_)) => {
self.subscriptions.remove(&sid);
self.connection
.write_op(ClientOp::Unsubscribe { sid, max: None })
.await?;
self.connection.flush().await?;
}
}
}
}
Expand Down Expand Up @@ -705,13 +709,15 @@ impl ConnectionHandler {
pub struct Client {
sender: mpsc::Sender<Command>,
next_subscription_id: Arc<AtomicU64>,
subscription_capacity: usize,
}

impl Client {
pub(crate) fn new(sender: mpsc::Sender<Command>) -> Client {
pub(crate) fn new(sender: mpsc::Sender<Command>, capacity: usize) -> Client {
Client {
sender,
next_subscription_id: Arc::new(AtomicU64::new(0)),
subscription_capacity: capacity,
}
}

Expand Down Expand Up @@ -850,7 +856,7 @@ impl Client {
queue_group: Option<String>,
) -> Result<Subscriber, io::Error> {
let sid = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
let (sender, receiver) = mpsc::channel(16);
let (sender, receiver) = mpsc::channel(self.subscription_capacity);

self.sender
.send(Command::Subscribe {
Expand Down Expand Up @@ -913,7 +919,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
// TODO make channel size configurable
let (sender, receiver) = mpsc::channel(options.sender_capacity);

let client = Client::new(sender.clone());
let client = Client::new(sender.clone(), options.subscription_capacity);
let mut connect_info = ConnectInfo {
tls_required,
// FIXME(tp): have optional name
Expand Down Expand Up @@ -998,6 +1004,12 @@ pub async fn connect_with_options<A: ToServerAddrs>(
ServerEvent::Disconnect => options.disconnect_callback.call().await,
ServerEvent::Error(error) => options.error_callback.call(error).await,
ServerEvent::LameDuckMode => options.lame_duck_callback.call().await,
ServerEvent::SlowConsumer(sid) => {
options
.error_callback
.call(ServerError::SlowConsumer(sid))
.await
}
}
}
});
Expand All @@ -1011,6 +1023,7 @@ pub(crate) enum ServerEvent {
Reconnect,
Disconnect,
LameDuckMode,
SlowConsumer(u64),
Error(ServerError),
}

Expand Down Expand Up @@ -1162,6 +1175,7 @@ impl Stream for Subscriber {
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ServerError {
AuthorizationViolation,
SlowConsumer(u64),
Other(String),
}

Expand All @@ -1178,6 +1192,7 @@ impl std::fmt::Display for ServerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
Self::SlowConsumer(sid) => write!(f, "nats: subscription {} is a slow consumer", sid),
Self::Other(error) => write!(f, "nats: {}", error),
}
}
Expand Down
19 changes: 19 additions & 0 deletions async-nats/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub struct ConnectOptions {
pub(crate) tls_client_config: Option<rustls::ClientConfig>,
pub(crate) flush_interval: Duration,
pub(crate) ping_interval: Duration,
pub(crate) subscription_capacity: usize,
pub(crate) sender_capacity: usize,
pub(crate) reconnect_callback: CallbackArg0<()>,
pub(crate) disconnect_callback: CallbackArg0<()>,
Expand Down Expand Up @@ -90,6 +91,7 @@ impl Default for ConnectOptions {
flush_interval: Duration::from_millis(100),
ping_interval: Duration::from_secs(60),
sender_capacity: 128,
subscription_capacity: 1024,
reconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
disconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
lame_duck_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
Expand Down Expand Up @@ -350,6 +352,23 @@ impl ConnectOptions {
self
}

/// Sets the capacity for `Subscribers`. Exceeding it will trigger `slow consumer` error
/// callback and drop messages.
/// Defualt is set to 1024 messages buffer.
///
/// # Examples
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// async_nats::ConnectOptions::new().subscription_capacity(1024).connect("demo.nats.io").await?;
/// # Ok(())
/// # }
/// ```
pub fn subscription_capacity(mut self, capacity: usize) -> ConnectOptions {
self.subscription_capacity = capacity;
self
}

/// Registers asynchronous callback for errors that are receiver over the wire from the server.
///
/// # Examples
Expand Down
49 changes: 48 additions & 1 deletion async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// limitations under the License.

mod client {
use async_nats::ConnectOptions;
use async_nats::{ConnectOptions, ServerError};
use bytes::Bytes;
use futures::future::join_all;
use futures::stream::StreamExt;
Expand Down Expand Up @@ -406,4 +406,51 @@ mod client {
.unwrap()
.unwrap();
}

#[tokio::test]
async fn slow_consumers() {
let server = nats_server::run_basic_server();

let (tx, mut rx) = tokio::sync::mpsc::channel(128);
let client = ConnectOptions::new()
.subscription_capacity(1)
.error_callback(move |err| {
let tx = tx.clone();
async move {
if let ServerError::SlowConsumer(_) = err {
tx.send(()).await.unwrap()
}
}
})
.connect(server.client_url())
.await
.unwrap();

let _sub = client.subscribe("data".to_string()).await.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client.flush().await.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client.flush().await.unwrap();

tokio::time::sleep(Duration::from_secs(1)).await;

tokio::time::timeout(Duration::from_secs(5), rx.recv())
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(5), rx.recv())
.await
.unwrap()
.unwrap();
}
}

0 comments on commit ad0728f

Please sign in to comment.