Skip to content

Commit

Permalink
Add rtt to Client
Browse files Browse the repository at this point in the history
  • Loading branch information
n1ghtmare committed Jul 24, 2023
1 parent 625d1da commit c0dfa53
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .config/nats.dic
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,12 @@ RequestErrorKind
rustls
Acker
EndpointSchema
<<<<<<< HEAD
auth
filter_subject
filter_subjects
rollup
IoT
=======
RttError
>>>>>>> 85121a7 (Add `rtt` to `Client`)
40 changes: 40 additions & 0 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,35 @@ impl Client {
Ok(())
}

/// Calculates the round trip time between this client and the server,
/// if the server is currently connected.
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let rtt = client.rtt().await?;
/// println!("server rtt: {:?}", rtt);
/// # Ok(())
/// # }
/// ```
pub async fn rtt(&self) -> Result<Duration, RttError> {
let (tx, rx) = tokio::sync::oneshot::channel();

self.sender.send(Command::Rtt { result: tx }).await?;

let rtt = rx
.await
// first handle rx error
.map_err(|err| RttError(Box::new(err)))?
// second handle the actual rtt error
.map_err(|err| RttError(Box::new(err)))?;

Ok(rtt)
}

/// Returns the current state of the connection.
///
/// # Examples
Expand Down Expand Up @@ -684,3 +713,14 @@ impl From<SubscribeError> for RequestError {
RequestError::with_source(RequestErrorKind::Other, e)
}
}

/// Error returned when doing a round-trip time measurement fails.
#[derive(Debug, Error)]
#[error("failed to measure round-trip time: {0}")]
pub struct RttError(#[source] Box<dyn std::error::Error + Send + Sync>);

impl From<tokio::sync::mpsc::error::SendError<Command>> for RttError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
RttError(Box::new(err))
}
}
65 changes: 65 additions & 0 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ use thiserror::Error;
use futures::future::FutureExt;
use futures::select;
use futures::stream::Stream;
use std::time::Instant;
use tracing::{debug, error};

use core::fmt;
Expand Down Expand Up @@ -280,6 +281,9 @@ pub(crate) enum Command {
result: oneshot::Sender<Result<(), io::Error>>,
},
TryFlush,
Rtt {
result: oneshot::Sender<Result<Duration, io::Error>>,
},
}

/// `ClientOp` represents all actions of `Client`.
Expand Down Expand Up @@ -323,6 +327,9 @@ pub(crate) struct ConnectionHandler {
info_sender: tokio::sync::watch::Sender<ServerInfo>,
ping_interval: Interval,
flush_interval: Interval,
last_ping_time: Option<Instant>,
last_pong_time: Option<Instant>,
rtt_senders: Vec<oneshot::Sender<Result<Duration, io::Error>>>,
}

impl ConnectionHandler {
Expand All @@ -347,6 +354,9 @@ impl ConnectionHandler {
info_sender,
ping_interval,
flush_interval,
last_ping_time: None,
last_pong_time: None,
rtt_senders: Vec::new(),
}
}

Expand Down Expand Up @@ -425,6 +435,22 @@ impl ConnectionHandler {
}
ServerOp::Pong => {
debug!("received PONG");
if self.pending_pings == 1 {
self.last_pong_time = Some(Instant::now());

while let Some(sender) = self.rtt_senders.pop() {
if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time)
{
let rtt = pong.duration_since(ping);
sender.send(Ok(rtt)).map_err(|_| {
io::Error::new(
io::ErrorKind::Other,
"one shot failed to be received",
)
})?;
}
}
}
self.pending_pings = self.pending_pings.saturating_sub(1);
}
ServerOp::Error(error) => {
Expand Down Expand Up @@ -538,6 +564,14 @@ impl ConnectionHandler {
}
}
}
Command::Rtt { result } => {
self.rtt_senders.push(result);

if self.pending_pings == 0 {
// do a ping and expect a pong - will calculate rtt when handling the pong
self.handle_ping().await?;
}
}
Command::Flush { result } => {
if let Err(_err) = self.handle_flush().await {
if let Err(err) = self.handle_disconnect().await {
Expand Down Expand Up @@ -612,8 +646,39 @@ impl ConnectionHandler {
Ok(())
}

async fn handle_ping(&mut self) -> Result<(), io::Error> {
debug!(
"PING command. Pending pings {}, max pings {}",
self.pending_pings, MAX_PENDING_PINGS
);
self.pending_pings += 1;
self.ping_interval.reset();

if self.pending_pings > MAX_PENDING_PINGS {
debug!(
"pending pings {}, max pings {}. disconnecting",
self.pending_pings, MAX_PENDING_PINGS
);
self.handle_disconnect().await?;
}

if self.pending_pings == 1 {
// start the clock for calculating round trip time
self.last_ping_time = Some(Instant::now());
}

if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
self.handle_disconnect().await?;
}

self.handle_flush().await?;
Ok(())
}

async fn handle_disconnect(&mut self) -> io::Result<()> {
self.pending_pings = 0;
self.last_ping_time = None;
self.last_pong_time = None;
self.connector.events_tx.try_send(Event::Disconnected).ok();
self.connector.state_tx.send(State::Disconnected).ok();
self.handle_reconnect().await?;
Expand Down
11 changes: 11 additions & 0 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,4 +867,15 @@ mod client {
.await
.unwrap();
}

#[tokio::test]
async fn rtt() {
let server = nats_server::run_basic_server();
let client = async_nats::connect(server.client_url()).await.unwrap();

let rtt = client.rtt().await.unwrap();

println!("rtt: {:?}", rtt);
assert!(rtt.as_nanos() > 0);
}
}

0 comments on commit c0dfa53

Please sign in to comment.