diff --git a/Cargo.lock b/Cargo.lock index 4ce0dcb4..9ba06f56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -443,6 +443,7 @@ dependencies = [ "shell-words", "signal-hook 0.4.3", "smallvec", + "socket2", "ssh-key", "tempfile", "terminal_size", diff --git a/Cargo.toml b/Cargo.toml index 450b2fbe..954dd790 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ lru = "0.16.2" uuid = { version = "1.23.0", features = ["v4"] } fastrand = "2.3.0" tokio-util = "0.7.17" +socket2 = "0.6" shell-words = "1.1.1" libc = "0.2" ipnetwork = "0.21" diff --git a/src/executor/connection_manager.rs b/src/executor/connection_manager.rs index d2e21524..9eb02f15 100644 --- a/src/executor/connection_manager.rs +++ b/src/executor/connection_manager.rs @@ -42,9 +42,8 @@ pub(crate) struct ExecutionConfig<'a> { pub sudo_password: Option>, pub ssh_config: Option<&'a SshConfig>, /// SSH connection configuration (keepalive settings). - /// Note: This field is currently passed through the executor for future use. - /// Keepalive is applied at the Client::connect_with_ssh_config level. - #[allow(dead_code)] + /// Threaded through to `Client::connect_with_ssh_config` so user-configured + /// `server_alive_interval` / `server_alive_count_max` apply to exec mode. pub ssh_connection_config: Option<&'a SshConnectionConfig>, } @@ -82,6 +81,7 @@ pub(crate) async fn execute_on_node_with_jump_hosts( timeout_seconds: config.timeout, connect_timeout_seconds: config.connect_timeout, jump_hosts_spec: effective_jump_hosts, + ssh_connection_config: config.ssh_connection_config, }; // If sudo password is provided, use streaming execution to handle prompts diff --git a/src/executor/parallel.rs b/src/executor/parallel.rs index 216ed331..07eea415 100644 --- a/src/executor/parallel.rs +++ b/src/executor/parallel.rs @@ -1132,6 +1132,7 @@ impl ParallelExecutor { let jump_hosts = self.jump_hosts.clone(); let sudo_password = self.sudo_password.clone(); let semaphore = Arc::clone(&semaphore); + let ssh_connection_config = self.ssh_connection_config.clone(); let handle = tokio::spawn(async move { // Use defer pattern to ensure cleanup even on panic @@ -1177,6 +1178,7 @@ impl ParallelExecutor { timeout_seconds: timeout, connect_timeout_seconds: connect_timeout, jump_hosts_spec: jump_hosts.as_deref(), + ssh_connection_config: Some(&ssh_connection_config), }; // Execute with or without sudo password support diff --git a/src/ssh/client/command.rs b/src/ssh/client/command.rs index c874e54e..8ab10d87 100644 --- a/src/ssh/client/command.rs +++ b/src/ssh/client/command.rs @@ -62,6 +62,7 @@ impl SshClient { timeout_seconds, connect_timeout_seconds: None, // Use default jump_hosts_spec: None, // No jump hosts + ssh_connection_config: None, }; self.connect_and_execute_with_jump_hosts(command, &config) @@ -101,6 +102,7 @@ impl SshClient { config.use_agent, config.use_password, config.connect_timeout_seconds, + config.ssh_connection_config, ) .await?; @@ -211,6 +213,7 @@ impl SshClient { config.use_agent, config.use_password, config.connect_timeout_seconds, + config.ssh_connection_config, ) .await?; @@ -322,6 +325,7 @@ impl SshClient { config.use_agent, config.use_password, config.connect_timeout_seconds, + config.ssh_connection_config, ) .await?; diff --git a/src/ssh/client/config.rs b/src/ssh/client/config.rs index a134082b..1bd7f45c 100644 --- a/src/ssh/client/config.rs +++ b/src/ssh/client/config.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::ssh::known_hosts::StrictHostKeyChecking; +use crate::ssh::tokio_client::SshConnectionConfig; use std::path::Path; /// Configuration for SSH connection and command execution @@ -27,4 +28,6 @@ pub struct ConnectionConfig<'a> { pub timeout_seconds: Option, pub connect_timeout_seconds: Option, pub jump_hosts_spec: Option<&'a str>, + /// SSH keepalive / inactivity settings. `None` falls back to defaults. + pub ssh_connection_config: Option<&'a SshConnectionConfig>, } diff --git a/src/ssh/client/connection.rs b/src/ssh/client/connection.rs index 55327894..c5ae8ace 100644 --- a/src/ssh/client/connection.rs +++ b/src/ssh/client/connection.rs @@ -15,7 +15,7 @@ use super::core::SshClient; use crate::jump::{parse_jump_hosts, JumpHostChain}; use crate::ssh::known_hosts::StrictHostKeyChecking; -use crate::ssh::tokio_client::{AuthMethod, Client}; +use crate::ssh::tokio_client::{AuthMethod, Client, SshConnectionConfig}; use anyhow::{Context, Result}; use std::path::Path; use std::time::Duration; @@ -65,6 +65,7 @@ impl SshClient { auth_method: &AuthMethod, strict_mode: StrictHostKeyChecking, connect_timeout_seconds: Option, + ssh_connection_config: Option<&SshConnectionConfig>, ) -> Result { // SECURITY: Add rate limiting before connection attempts const RATE_LIMIT_DELAY: Duration = Duration::from_millis(100); @@ -79,9 +80,24 @@ impl SshClient { let connect_timeout = Duration::from_secs(connect_timeout_seconds.unwrap_or(SSH_CONNECT_TIMEOUT_SECS)); + let default_conn_cfg; + let conn_cfg = match ssh_connection_config { + Some(c) => c, + None => { + default_conn_cfg = SshConnectionConfig::default(); + &default_conn_cfg + } + }; + let result = match tokio::time::timeout( connect_timeout, - Client::connect(addr, &self.username, auth_method.clone(), check_method), + Client::connect_with_ssh_config( + addr, + &self.username, + auth_method.clone(), + check_method, + conn_cfg, + ), ) .await { @@ -148,13 +164,17 @@ impl SshClient { use_agent: bool, use_password: bool, connect_timeout_seconds: Option, + ssh_connection_config: Option<&SshConnectionConfig>, ) -> Result { // Create jump host chain with user-specified or default connect timeout let connect_timeout = Duration::from_secs(connect_timeout_seconds.unwrap_or(SSH_CONNECT_TIMEOUT_SECS)); - let chain = JumpHostChain::new(jump_hosts.to_vec()) + let mut chain = JumpHostChain::new(jump_hosts.to_vec()) .with_connect_timeout(connect_timeout) .with_command_timeout(Duration::from_secs(300)); + if let Some(cfg) = ssh_connection_config { + chain = chain.with_ssh_connection_config(cfg.clone()); + } // Connect through the chain let connection = chain @@ -195,6 +215,7 @@ impl SshClient { use_agent: bool, use_password: bool, connect_timeout_seconds: Option, + ssh_connection_config: Option<&SshConnectionConfig>, ) -> Result { if let Some(jump_spec) = jump_hosts_spec { // Parse jump hosts @@ -204,8 +225,13 @@ impl SshClient { if jump_hosts.is_empty() { tracing::debug!("No valid jump hosts found, using direct connection"); - self.connect_direct(auth_method, strict_mode, connect_timeout_seconds) - .await + self.connect_direct( + auth_method, + strict_mode, + connect_timeout_seconds, + ssh_connection_config, + ) + .await } else { tracing::info!( "Connecting to {}:{} via {} jump host(s): {}", @@ -227,14 +253,20 @@ impl SshClient { use_agent, use_password, connect_timeout_seconds, + ssh_connection_config, ) .await } } else { // Direct connection tracing::debug!("Using direct connection (no jump hosts)"); - self.connect_direct(auth_method, strict_mode, connect_timeout_seconds) - .await + self.connect_direct( + auth_method, + strict_mode, + connect_timeout_seconds, + ssh_connection_config, + ) + .await } } } diff --git a/src/ssh/client/file_transfer.rs b/src/ssh/client/file_transfer.rs index 2fe35891..900b31fe 100644 --- a/src/ssh/client/file_transfer.rs +++ b/src/ssh/client/file_transfer.rs @@ -700,6 +700,7 @@ impl SshClient { use_agent, use_password, connect_timeout_seconds, + None, ) .await } diff --git a/src/ssh/tokio_client/connection.rs b/src/ssh/tokio_client/connection.rs index b475770b..48121e46 100644 --- a/src/ssh/tokio_client/connection.rs +++ b/src/ssh/tokio_client/connection.rs @@ -100,13 +100,53 @@ impl SshConnectionConfig { } /// Convert this configuration to a russh client Config. + /// + /// When keepalive is enabled, `inactivity_timeout` is set to `None` so the + /// keepalive mechanism is the sole dead-peer detector. russh's default + /// `inactivity_timeout` is 10 minutes and would otherwise tear down an + /// otherwise-healthy idle session at that mark regardless of keepalive + /// liveness. When keepalive is disabled, we preserve a generous + /// inactivity timeout so truly dead sockets are still reaped. pub fn to_russh_config(&self) -> Config { + let inactivity_timeout = if self.keepalive_interval.is_some() { + None + } else { + Some(Duration::from_secs(3600)) + }; Config { keepalive_interval: self.keepalive_interval.map(Duration::from_secs), keepalive_max: self.keepalive_max, + inactivity_timeout, ..Default::default() } } + + /// Derive a TCP-level keepalive configuration from this SSH keepalive + /// configuration. Returns `None` if SSH keepalive is disabled. + /// + /// TCP keepalive is a belt-and-suspenders mechanism: it lets the kernel + /// detect a broken TCP path even when no application data is flowing and + /// even if SSH-level keepalive replies are dropped by a middlebox. + pub fn to_tcp_keepalive(&self) -> Option { + let interval = self.keepalive_interval?; + // Start probing after `interval` seconds of idleness, probe every + // half-interval, up to keepalive_max retries. + let probe_interval = (interval / 2).max(1); + let ka = socket2::TcpKeepalive::new() + .with_time(Duration::from_secs(interval)) + .with_interval(Duration::from_secs(probe_interval)); + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "freebsd", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "ios", + ))] + let ka = ka.with_retries(self.keepalive_max.max(1) as u32); + Some(ka) + } } use super::ToSocketAddrsWithHostname; @@ -213,7 +253,16 @@ impl Client { ssh_config: &SshConnectionConfig, ) -> Result { let config = ssh_config.to_russh_config(); - Self::connect_with_config(addr, username, auth, server_check, config).await + let tcp_keepalive = ssh_config.to_tcp_keepalive(); + Self::connect_with_config_inner( + addr, + username, + auth, + server_check, + config, + tcp_keepalive.as_ref(), + ) + .await } /// Same as `connect`, but with the option to specify a non default @@ -227,6 +276,17 @@ impl Client { auth: AuthMethod, server_check: ServerCheckMethod, config: Config, + ) -> Result { + Self::connect_with_config_inner(addr, username, auth, server_check, config, None).await + } + + async fn connect_with_config_inner( + addr: impl ToSocketAddrsWithHostname, + username: &str, + auth: AuthMethod, + server_check: ServerCheckMethod, + config: Config, + tcp_keepalive: Option<&socket2::TcpKeepalive>, ) -> Result { let config = Arc::new(config); @@ -234,7 +294,10 @@ impl Client { let socket_addrs = addr .to_socket_addrs() .map_err(super::Error::AddressInvalid)?; - let mut connect_res = Err(super::Error::AddressInvalid(io::Error::new( + let mut connect_res: Result< + (SocketAddr, russh::client::Handle), + super::Error, + > = Err(super::Error::AddressInvalid(io::Error::new( io::ErrorKind::InvalidInput, "could not resolve to any addresses", ))); @@ -244,7 +307,27 @@ impl Client { host: socket_addr, server_check: server_check.clone(), }; - match russh::client::connect(config.clone(), socket_addr, handler).await { + + let stream = match tokio::net::TcpStream::connect(socket_addr).await { + Ok(s) => s, + Err(e) => { + connect_res = Err(super::Error::IoError(e)); + continue; + } + }; + + if let Some(ka) = tcp_keepalive { + let sock_ref = socket2::SockRef::from(&stream); + if let Err(e) = sock_ref.set_tcp_keepalive(ka) { + tracing::debug!( + "Failed to set TCP keepalive on socket to {}: {}", + socket_addr, + e + ); + } + } + + match russh::client::connect_stream(config.clone(), stream, handler).await { Ok(h) => { connect_res = Ok((socket_addr, h)); break;