diff --git a/.config/nats.dic b/.config/nats.dic index 1d5c7f61c..715023e5d 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -133,3 +133,4 @@ ConnectError DNS RequestErrorKind rustls +RttError diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 06ef82601..db5512204 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -463,6 +463,38 @@ impl Client { Ok(()) } + /// Calculates the round trip time between this client and the server, + /// if the server is currently connected. + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// let rtt = client.rtt().await?; + /// println!("server rtt: {:?}", rtt); + /// # Ok(()) + /// # } + /// ``` + pub async fn rtt(&self) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.sender + .send(Command::Rtt { result: tx }) + .await + .map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?; + + let rtt = rx + .await + // first handle rx error + .map_err(|err| RttError::with_source(RttErrorKind::PingError, err))? + // second handle the atual ping error + .map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?; + + Ok(rtt) + } + /// Returns the current state of the connection. /// /// # Examples @@ -688,3 +720,48 @@ impl From for RequestError { RequestError::with_source(RequestErrorKind::Other, e) } } + +/// Error returned when doing a round-trip time measurement fails. +/// To enumerate over the variants, call [RttError::kind]. +#[derive(Debug, Error)] +pub struct RttError { + kind: RttErrorKind, + source: Option>, +} + +impl Display for RttError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let source_info = self + .source + .as_ref() + .map(|e| e.to_string()) + .unwrap_or_else(|| "no details".into()); + match self.kind { + RttErrorKind::PingError => { + write!(f, "failed to ping server: {}", source_info) + } + RttErrorKind::Other => write!(f, "rtt failed: {}", source_info), + } + } +} + +impl RttError { + fn with_source(kind: RttErrorKind, source: E) -> RttError + where + E: Into>, + { + RttError { + kind, + source: Some(source.into()), + } + } + pub fn kind(&self) -> RttErrorKind { + self.kind + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum RttErrorKind { + PingError, + Other, +} diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 58ada931b..2080d8bf3 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -105,6 +105,7 @@ use thiserror::Error; use futures::future::FutureExt; use futures::select; use futures::stream::Stream; +use std::time::Instant; use tracing::{debug, error}; use core::fmt; @@ -261,6 +262,9 @@ pub enum Command { }, TryFlush, Connect(ConnectInfo), + Rtt { + result: oneshot::Sender>, + }, } /// `ClientOp` represents all actions of `Client`. @@ -301,10 +305,13 @@ pub(crate) struct ConnectionHandler { connector: Connector, subscriptions: HashMap, pending_pings: usize, + pending_pongs: usize, max_pings: usize, info_sender: tokio::sync::watch::Sender, ping_interval: Interval, flush_interval: Interval, + rtt_instant: Option, + rtt_sender: Option>>, } impl ConnectionHandler { @@ -326,10 +333,13 @@ impl ConnectionHandler { connector, subscriptions: HashMap::new(), pending_pings: 0, + pending_pongs: 0, max_pings: 2, info_sender, ping_interval, flush_interval, + rtt_instant: None, + rtt_sender: None, } } @@ -398,6 +408,18 @@ impl ConnectionHandler { ServerOp::Pong => { debug!("received PONG"); self.pending_pings = self.pending_pings.saturating_sub(1); + + if self.pending_pongs == 1 { + if let (Some(sender), Some(rtt)) = (self.rtt_sender.take(), self.rtt_instant) { + sender.send(Ok(rtt.elapsed())).map_err(|_| { + io::Error::new(io::ErrorKind::Other, "one shot failed to be received") + })?; + } + + // reset the pending pongs (we have at most 1 at any given moment to measure rtt) + self.pending_pongs = 0; + self.rtt_instant = None; + } } ServerOp::Error(error) => { self.connector @@ -509,26 +531,17 @@ impl ConnectionHandler { } } Command::Ping => { - debug!( - "PING command. Pending pings {}, max pings {}", - self.pending_pings, self.max_pings - ); - self.pending_pings += 1; - self.ping_interval.reset(); - - if self.pending_pings > self.max_pings { - debug!( - "pending pings {}, max pings {}. disconnecting", - self.pending_pings, self.max_pings - ); - self.handle_disconnect().await?; - } - - if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { - self.handle_disconnect().await?; + self.handle_ping().await?; + } + Command::Rtt { result } => { + self.rtt_sender = Some(result); + + if self.pending_pongs == 0 { + // start the clock for calculating round trip time + self.rtt_instant = Some(Instant::now()); + // do a ping and stop clock when handling pong + self.handle_ping().await?; } - - self.handle_flush().await?; } Command::Flush { result } => { if let Err(_err) = self.handle_flush().await { @@ -613,8 +626,37 @@ impl ConnectionHandler { Ok(()) } + async fn handle_ping(&mut self) -> Result<(), io::Error> { + debug!( + "PING command. Pending pings {}, max pings {}", + self.pending_pings, self.max_pings + ); + self.pending_pings += 1; + self.ping_interval.reset(); + + if self.pending_pongs == 0 { + self.pending_pongs = 1; + } + + if self.pending_pings > self.max_pings { + debug!( + "pending pings {}, max pings {}. disconnecting", + self.pending_pings, self.max_pings + ); + self.handle_disconnect().await?; + } + + if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { + self.handle_disconnect().await?; + } + + self.handle_flush().await?; + Ok(()) + } + async fn handle_disconnect(&mut self) -> io::Result<()> { self.pending_pings = 0; + self.pending_pongs = 0; self.connector.events_tx.try_send(Event::Disconnected).ok(); self.connector.state_tx.send(State::Disconnected).ok(); self.handle_reconnect().await?; diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index 538b78d2a..c4e244fb9 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -764,4 +764,15 @@ mod client { drop(servers.remove(0)); rx.recv().await; } + + #[tokio::test] + async fn rtt() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let rtt = client.rtt().await.unwrap(); + + println!("rtt: {:?}", rtt); + assert!(rtt.as_nanos() > 0); + } }