Skip to content

Commit

Permalink
Add slow consumers
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarema committed May 16, 2022
1 parent 1bfc9ed commit bb0ac6d
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 20 deletions.
94 changes: 75 additions & 19 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ pub enum Command {
},
TryFlush,
Connect(ConnectInfo),
GetDroppedMessages {
response: tokio::sync::oneshot::Sender<u64>,
sid: u64,
},
}

/// `ClientOp` represents all actions of `Client`.
Expand Down Expand Up @@ -418,6 +422,7 @@ struct Subscription {
sender: mpsc::Sender<Message>,
queue_group: Option<String>,
delivered: u64,
dropped: u64,
max: Option<u64>,
}

Expand Down Expand Up @@ -518,24 +523,31 @@ impl ConnectionHandler {

// if the channel for subscription was dropped, remove the
// subscription from the map and unsubscribe.
if subscription.sender.send(message).await.is_err() {
self.subscriptions.remove(&sid);
self.connection
.write_op(ClientOp::Unsubscribe { sid, max: None })
.await?;
self.connection.flush().await?;
// if channel was open and we sent the messsage, increase the
// `delivered` counter.
} else {
subscription.delivered += 1;
// if this `Subscription` has set `max` value, check if it
// was reached. If yes, remove the `Subscription` and in
// the result, `drop` the `sender` channel.
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
self.subscriptions.remove(&sid);
match subscription.sender.try_send(message) {
Ok(_) => {
subscription.delivered += 1;
// if this `Subscription` has set `max` value, check if it
// was reached. If yes, remove the `Subscription` and in
// the result, `drop` the `sender` channel.
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
self.subscriptions.remove(&sid);
}
}
}
Err(err) => match err {
mpsc::error::TrySendError::Full(_) => {
subscription.dropped += 1;
self.events.send(ServerEvent::SlowConsumer(sid)).await.ok();
}
mpsc::error::TrySendError::Closed(_) => {
self.subscriptions.remove(&sid);
self.connection
.write_op(ClientOp::Unsubscribe { sid, max: None })
.await?;
self.connection.flush().await?;
}
},
}
}
}
Expand Down Expand Up @@ -619,6 +631,7 @@ impl ConnectionHandler {
let subscription = Subscription {
sender,
delivered: 0,
dropped: 0,
max: None,
subject: subject.to_owned(),
queue_group: queue_group.to_owned(),
Expand Down Expand Up @@ -667,6 +680,11 @@ impl ConnectionHandler {
self.handle_disconnect().await?;
}
}
Command::GetDroppedMessages { response, sid } => {
response
.send(self.subscriptions.get(&sid).unwrap().dropped)
.unwrap();
}
}

Ok(())
Expand Down Expand Up @@ -705,13 +723,15 @@ impl ConnectionHandler {
pub struct Client {
sender: mpsc::Sender<Command>,
next_subscription_id: Arc<AtomicU64>,
subscription_capacity: usize,
}

impl Client {
pub(crate) fn new(sender: mpsc::Sender<Command>) -> Client {
pub(crate) fn new(sender: mpsc::Sender<Command>, capacity: usize) -> Client {
Client {
sender,
next_subscription_id: Arc::new(AtomicU64::new(0)),
subscription_capacity: capacity,
}
}

Expand Down Expand Up @@ -850,7 +870,7 @@ impl Client {
queue_group: Option<String>,
) -> Result<Subscriber, io::Error> {
let sid = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
let (sender, receiver) = mpsc::channel(16);
let (sender, receiver) = mpsc::channel(self.subscription_capacity);

self.sender
.send(Command::Subscribe {
Expand Down Expand Up @@ -913,7 +933,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
// TODO make channel size configurable
let (sender, receiver) = mpsc::channel(128);

let client = Client::new(sender.clone());
let client = Client::new(sender.clone(), options.subscription_capacity);
let mut connect_info = ConnectInfo {
tls_required,
// FIXME(tp): have optional name
Expand Down Expand Up @@ -998,6 +1018,12 @@ pub async fn connect_with_options<A: ToServerAddrs>(
ServerEvent::Disconnect => options.disconnect_callback.call().await,
ServerEvent::Error(error) => options.error_callback.call(error).await,
ServerEvent::LameDuckMode => options.lame_duck_callback.call().await,
ServerEvent::SlowConsumer(sid) => {
options
.error_callback
.call(ServerError::SlowConsumer(sid))
.await
}
}
}
});
Expand All @@ -1011,6 +1037,7 @@ pub(crate) enum ServerEvent {
Reconnect,
Disconnect,
LameDuckMode,
SlowConsumer(u64),
Error(ServerError),
}

Expand Down Expand Up @@ -1133,6 +1160,33 @@ impl Subscriber {
.map_err(|err| io::Error::new(ErrorKind::Other, err))?;
Ok(())
}

/// Returns number of dropped messages due to exceeding `Subscription` buffer size.
/// Exceeding that buffer also triggers `slow consumer` error.
///
/// # Examples
/// ```
/// # #[tokio::main]
/// # async fn dropped_messages() -> Result<(), Box<dyn std::error::Error>> {
/// let client = async_nats::connect("demo.nats.io").await?;
///
/// let mut sub = client.subscribe("test".into()).await?;
/// println!("dropped messages: {}", sub.dropped().await?);
///
/// # Ok(())
/// # }
pub async fn dropped(&mut self) -> io::Result<u64> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.sender
.send(Command::GetDroppedMessages {
response: tx,
sid: self.sid,
})
.await
.map_err(|err| io::Error::new(ErrorKind::Other, err))?;
rx.await
.map_err(|err| io::Error::new(ErrorKind::Other, err))
}
}

impl Drop for Subscriber {
Expand Down Expand Up @@ -1162,6 +1216,7 @@ impl Stream for Subscriber {
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ServerError {
AuthorizationViolation,
SlowConsumer(u64),
Other(String),
}

Expand All @@ -1178,6 +1233,7 @@ impl std::fmt::Display for ServerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
Self::SlowConsumer(sid) => write!(f, "nats: subscription {} is a slow consumer", sid),
Self::Other(error) => write!(f, "nats: {}", error),
}
}
Expand Down
19 changes: 19 additions & 0 deletions async-nats/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub struct ConnectOptions {
pub(crate) tls_client_config: Option<rustls::ClientConfig>,
pub(crate) flush_interval: Duration,
pub(crate) ping_interval: Duration,
pub(crate) subscription_capacity: usize,
pub(crate) reconnect_callback: CallbackArg0<()>,
pub(crate) disconnect_callback: CallbackArg0<()>,
pub(crate) lame_duck_callback: CallbackArg0<()>,
Expand Down Expand Up @@ -87,6 +88,7 @@ impl Default for ConnectOptions {
tls_client_config: None,
flush_interval: Duration::from_millis(100),
ping_interval: Duration::from_secs(60),
subscription_capacity: 1024,
reconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
disconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
lame_duck_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
Expand Down Expand Up @@ -347,6 +349,23 @@ impl ConnectOptions {
self
}

/// Sets the capacity for `Subscriptions`. Exceeding it will trigger `slow consumer` error
/// callback and drop messages.
/// Defualt is set to 1024 messages buffer.
///
/// # Examples
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// async_nats::ConnectOptions::new().subscription_capacity(1024).connect("demo.nats.io").await?;
/// # Ok(())
/// # }
/// ```
pub fn subscription_capacity(mut self, capacity: usize) -> ConnectOptions {
self.subscription_capacity = capacity;
self
}

/// Registers asynchronous callback for errors that are receiver over the wire from the server.
///
/// # Examples
Expand Down
42 changes: 41 additions & 1 deletion async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// limitations under the License.

mod client {
use async_nats::ConnectOptions;
use async_nats::{ConnectOptions, ServerError};
use bytes::Bytes;
use futures::future::join_all;
use futures::stream::StreamExt;
Expand Down Expand Up @@ -401,4 +401,44 @@ mod client {
.unwrap()
.unwrap();
}

#[tokio::test]
async fn slow_consumers() {
let server = nats_server::run_basic_server();

let (tx, mut rx) = tokio::sync::mpsc::channel(128);
let client = ConnectOptions::new()
.subscription_capacity(1)
.error_callback(move |err| {
let tx = tx.clone();
async move {
if let ServerError::SlowConsumer(_) = err {
tx.send(()).await.unwrap()
}
}
})
.connect(server.client_url())
.await
.unwrap();

let mut sub = client.subscribe("data".to_string()).await.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client.flush().await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
let dropped = sub.dropped().await.unwrap();
println!("dropped: {}", dropped);
assert_eq!(dropped, 1);

tokio::time::timeout(Duration::from_secs(5), rx.recv())
.await
.unwrap()
.unwrap();
}
}

0 comments on commit bb0ac6d

Please sign in to comment.