Skip to content

Commit

Permalink
feat(server): add Server::tcp_keepalive_interval and `Server::tcp_k…
Browse files Browse the repository at this point in the history
…eepalive_retries` (#2991)

If the platform supports setting the options, otherwise it's a noop.
  • Loading branch information
hansonchar authored Sep 23, 2022
1 parent 0ff6213 commit 287d712
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 18 deletions.
23 changes: 18 additions & 5 deletions src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::error::Error as StdError;
use std::fmt;
#[cfg(feature = "tcp")]
use std::net::{SocketAddr, TcpListener as StdTcpListener};
#[cfg(any(feature = "tcp", feature = "http1"))]

#[cfg(feature = "tcp")]
use std::time::Duration;

use pin_project_lite::pin_project;

use tokio::io::{AsyncRead, AsyncWrite};
use tracing::trace;

Expand Down Expand Up @@ -559,16 +561,27 @@ impl<I, E> Builder<I, E> {
doc(cfg(all(feature = "tcp", any(feature = "http1", feature = "http2"))))
)]
impl<E> Builder<AddrIncoming, E> {
/// Set whether TCP keepalive messages are enabled on accepted connections.
/// Set the duration to remain idle before sending TCP keepalive probes.
///
/// If `None` is specified, keepalive is disabled, otherwise the duration
/// specified will be the time to remain idle before sending TCP keepalive
/// probes.
/// If `None` is specified, keepalive is disabled.
pub fn tcp_keepalive(mut self, keepalive: Option<Duration>) -> Self {
self.incoming.set_keepalive(keepalive);
self
}

/// Set the duration between two successive TCP keepalive retransmissions,
/// if acknowledgement to the previous keepalive transmission is not received.
pub fn tcp_keepalive_interval(mut self, interval: Option<Duration>) -> Self {
self.incoming.set_keepalive_interval(interval);
self
}

/// Set the number of retransmissions to be carried out before declaring that remote end is not available.
pub fn tcp_keepalive_retries(mut self, retries: Option<u32>) -> Self {
self.incoming.set_keepalive_retries(retries);
self
}

/// Set the value of `TCP_NODELAY` option for accepted connections.
pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
self.incoming.set_nodelay(enabled);
Expand Down
135 changes: 122 additions & 13 deletions src/server/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt;
use std::io;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::time::Duration;
use socket2::TcpKeepalive;

use tokio::net::TcpListener;
use tokio::time::Sleep;
Expand All @@ -13,13 +14,65 @@ use crate::common::{task, Future, Pin, Poll};
pub use self::addr_stream::AddrStream;
use super::accept::Accept;

#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
time: Option<Duration>,
interval: Option<Duration>,
retries: Option<u32>,
}

impl TcpKeepaliveConfig {
/// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
fn into_socket2(self) -> Option<TcpKeepalive> {
let mut dirty = false;
let mut ka = TcpKeepalive::new();
if let Some(time) = self.time {
ka = ka.with_time(time);
dirty = true
}
if let Some(interval) = self.interval {
ka = Self::ka_with_interval(ka, interval, &mut dirty)
};
if let Some(retries) = self.retries {
ka = Self::ka_with_retries(ka, retries, &mut dirty)
};
if dirty {
Some(ka)
} else {
None
}
}

#[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
*dirty = true;
ka.with_interval(interval)
}

#[cfg(any(target_os = "openbsd", target_os = "redox", target_os = "solaris"))]
fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
ka // no-op as keepalive interval is not supported on this platform
}

#[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris", target_os = "windows")))]
fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
*dirty = true;
ka.with_retries(retries)
}

#[cfg(any(target_os = "openbsd", target_os = "redox", target_os = "solaris", target_os = "windows"))]
fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
ka // no-op as keepalive retries is not supported on this platform
}
}

/// A stream of connections from binding to an address.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
addr: SocketAddr,
listener: TcpListener,
sleep_on_errors: bool,
tcp_keepalive_timeout: Option<Duration>,
tcp_keepalive_config: TcpKeepaliveConfig,
tcp_nodelay: bool,
timeout: Option<Pin<Box<Sleep>>>,
}
Expand Down Expand Up @@ -52,7 +105,7 @@ impl AddrIncoming {
listener,
addr,
sleep_on_errors: true,
tcp_keepalive_timeout: None,
tcp_keepalive_config: TcpKeepaliveConfig::default(),
tcp_nodelay: false,
timeout: None,
})
Expand All @@ -63,13 +116,24 @@ impl AddrIncoming {
self.addr
}

/// Set whether TCP keepalive messages are enabled on accepted connections.
/// Set the duration to remain idle before sending TCP keepalive probes.
///
/// If `None` is specified, keepalive is disabled, otherwise the duration
/// specified will be the time to remain idle before sending TCP keepalive
/// probes.
pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self {
self.tcp_keepalive_timeout = keepalive;
/// If `None` is specified, keepalive is disabled.
pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self {
self.tcp_keepalive_config.time = time;
self
}

/// Set the duration between two successive TCP keepalive retransmissions,
/// if acknowledgement to the previous keepalive transmission is not received.
pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self {
self.tcp_keepalive_config.interval = interval;
self
}

/// Set the number of retransmissions to be carried out before declaring that remote end is not available.
pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self {
self.tcp_keepalive_config.retries = retries;
self
}

Expand Down Expand Up @@ -108,10 +172,9 @@ impl AddrIncoming {
loop {
match ready!(self.listener.poll_accept(cx)) {
Ok((socket, remote_addr)) => {
if let Some(dur) = self.tcp_keepalive_timeout {
let socket = socket2::SockRef::from(&socket);
let conf = socket2::TcpKeepalive::new().with_time(dur);
if let Err(e) = socket.set_tcp_keepalive(&conf) {
if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() {
let sock_ref = socket2::SockRef::from(&socket);
if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) {
trace!("error trying to set TCP keepalive: {}", e);
}
}
Expand Down Expand Up @@ -188,7 +251,7 @@ impl fmt::Debug for AddrIncoming {
f.debug_struct("AddrIncoming")
.field("addr", &self.addr)
.field("sleep_on_errors", &self.sleep_on_errors)
.field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout)
.field("tcp_keepalive_config", &self.tcp_keepalive_config)
.field("tcp_nodelay", &self.tcp_nodelay)
.finish()
}
Expand Down Expand Up @@ -316,3 +379,49 @@ mod addr_stream {
}
}
}

#[cfg(test)]
mod tests {
use std::time::Duration;
use crate::server::tcp::TcpKeepaliveConfig;

#[test]
fn no_tcp_keepalive_config() {
assert!(TcpKeepaliveConfig::default().into_socket2().is_none());
}

#[test]
fn tcp_keepalive_time_config() {
let mut kac = TcpKeepaliveConfig::default();
kac.time = Some(Duration::from_secs(60));
if let Some(tcp_keepalive) = kac.into_socket2() {
assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
} else {
panic!("test failed");
}
}

#[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
#[test]
fn tcp_keepalive_interval_config() {
let mut kac = TcpKeepaliveConfig::default();
kac.interval = Some(Duration::from_secs(1));
if let Some(tcp_keepalive) = kac.into_socket2() {
assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
} else {
panic!("test failed");
}
}

#[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris", target_os = "windows")))]
#[test]
fn tcp_keepalive_retries_config() {
let mut kac = TcpKeepaliveConfig::default();
kac.retries = Some(3);
if let Some(tcp_keepalive) = kac.into_socket2() {
assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
} else {
panic!("test failed");
}
}
}

0 comments on commit 287d712

Please sign in to comment.