Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add slow consumers #444

Merged
merged 1 commit into from
May 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,24 +518,28 @@ impl ConnectionHandler {

// if the channel for subscription was dropped, remove the
// subscription from the map and unsubscribe.
if subscription.sender.send(message).await.is_err() {
self.subscriptions.remove(&sid);
self.connection
.write_op(ClientOp::Unsubscribe { sid, max: None })
.await?;
self.connection.flush().await?;
// if channel was open and we sent the messsage, increase the
// `delivered` counter.
} else {
subscription.delivered += 1;
// if this `Subscription` has set `max` value, check if it
// was reached. If yes, remove the `Subscription` and in
// the result, `drop` the `sender` channel.
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
self.subscriptions.remove(&sid);
match subscription.sender.try_send(message) {
Ok(_) => {
subscription.delivered += 1;
// if this `Subscription` has set `max` value, check if it
// was reached. If yes, remove the `Subscription` and in
// the result, `drop` the `sender` channel.
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
self.subscriptions.remove(&sid);
}
}
}
Err(mpsc::error::TrySendError::Full(_)) => {
self.events.send(ServerEvent::SlowConsumer(sid)).await.ok();
}
Err(mpsc::error::TrySendError::Closed(_)) => {
self.subscriptions.remove(&sid);
self.connection
.write_op(ClientOp::Unsubscribe { sid, max: None })
.await?;
self.connection.flush().await?;
}
}
}
}
Expand Down Expand Up @@ -705,13 +709,15 @@ impl ConnectionHandler {
pub struct Client {
sender: mpsc::Sender<Command>,
next_subscription_id: Arc<AtomicU64>,
subscription_capacity: usize,
}

impl Client {
pub(crate) fn new(sender: mpsc::Sender<Command>) -> Client {
pub(crate) fn new(sender: mpsc::Sender<Command>, capacity: usize) -> Client {
Client {
sender,
next_subscription_id: Arc::new(AtomicU64::new(0)),
subscription_capacity: capacity,
}
}

Expand Down Expand Up @@ -850,7 +856,7 @@ impl Client {
queue_group: Option<String>,
) -> Result<Subscriber, io::Error> {
let sid = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
let (sender, receiver) = mpsc::channel(16);
let (sender, receiver) = mpsc::channel(self.subscription_capacity);

self.sender
.send(Command::Subscribe {
Expand Down Expand Up @@ -913,7 +919,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
// TODO make channel size configurable
let (sender, receiver) = mpsc::channel(options.sender_capacity);

let client = Client::new(sender.clone());
let client = Client::new(sender.clone(), options.subscription_capacity);
let mut connect_info = ConnectInfo {
tls_required,
// FIXME(tp): have optional name
Expand Down Expand Up @@ -998,6 +1004,12 @@ pub async fn connect_with_options<A: ToServerAddrs>(
ServerEvent::Disconnect => options.disconnect_callback.call().await,
ServerEvent::Error(error) => options.error_callback.call(error).await,
ServerEvent::LameDuckMode => options.lame_duck_callback.call().await,
ServerEvent::SlowConsumer(sid) => {
options
.error_callback
.call(ServerError::SlowConsumer(sid))
.await
}
}
}
});
Expand All @@ -1011,6 +1023,7 @@ pub(crate) enum ServerEvent {
Reconnect,
Disconnect,
LameDuckMode,
SlowConsumer(u64),
Error(ServerError),
}

Expand Down Expand Up @@ -1162,6 +1175,7 @@ impl Stream for Subscriber {
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ServerError {
AuthorizationViolation,
SlowConsumer(u64),
Other(String),
}

Expand All @@ -1178,6 +1192,7 @@ impl std::fmt::Display for ServerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
Self::SlowConsumer(sid) => write!(f, "nats: subscription {} is a slow consumer", sid),
Self::Other(error) => write!(f, "nats: {}", error),
}
}
Expand Down
19 changes: 19 additions & 0 deletions async-nats/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub struct ConnectOptions {
pub(crate) tls_client_config: Option<rustls::ClientConfig>,
pub(crate) flush_interval: Duration,
pub(crate) ping_interval: Duration,
pub(crate) subscription_capacity: usize,
pub(crate) sender_capacity: usize,
pub(crate) reconnect_callback: CallbackArg0<()>,
pub(crate) disconnect_callback: CallbackArg0<()>,
Expand Down Expand Up @@ -90,6 +91,7 @@ impl Default for ConnectOptions {
flush_interval: Duration::from_millis(100),
ping_interval: Duration::from_secs(60),
sender_capacity: 128,
subscription_capacity: 1024,
reconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
disconnect_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
lame_duck_callback: CallbackArg0::<()>(Box::new(|| Box::pin(async {}))),
Expand Down Expand Up @@ -350,6 +352,23 @@ impl ConnectOptions {
self
}

/// Sets the capacity for `Subscribers`. Exceeding it will trigger `slow consumer` error
/// callback and drop messages.
/// Defualt is set to 1024 messages buffer.
///
/// # Examples
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// async_nats::ConnectOptions::new().subscription_capacity(1024).connect("demo.nats.io").await?;
/// # Ok(())
/// # }
/// ```
pub fn subscription_capacity(mut self, capacity: usize) -> ConnectOptions {
self.subscription_capacity = capacity;
self
}

/// Registers asynchronous callback for errors that are receiver over the wire from the server.
///
/// # Examples
Expand Down
49 changes: 48 additions & 1 deletion async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// limitations under the License.

mod client {
use async_nats::ConnectOptions;
use async_nats::{ConnectOptions, ServerError};
use bytes::Bytes;
use futures::future::join_all;
use futures::stream::StreamExt;
Expand Down Expand Up @@ -406,4 +406,51 @@ mod client {
.unwrap()
.unwrap();
}

#[tokio::test]
async fn slow_consumers() {
let server = nats_server::run_basic_server();

let (tx, mut rx) = tokio::sync::mpsc::channel(128);
let client = ConnectOptions::new()
.subscription_capacity(1)
.error_callback(move |err| {
let tx = tx.clone();
async move {
if let ServerError::SlowConsumer(_) = err {
tx.send(()).await.unwrap()
}
}
})
.connect(server.client_url())
.await
.unwrap();

let _sub = client.subscribe("data".to_string()).await.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client.flush().await.unwrap();
client
.publish("data".to_string(), "data".into())
.await
.unwrap();
client.flush().await.unwrap();

tokio::time::sleep(Duration::from_secs(1)).await;

tokio::time::timeout(Duration::from_secs(5), rx.recv())
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(5), rx.recv())
.await
.unwrap()
.unwrap();
}
}