Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ lru = "0.16.2"
uuid = { version = "1.23.0", features = ["v4"] }
fastrand = "2.3.0"
tokio-util = "0.7.17"
socket2 = "0.6"
shell-words = "1.1.1"
libc = "0.2"
ipnetwork = "0.21"
Expand Down
6 changes: 3 additions & 3 deletions src/executor/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ pub(crate) struct ExecutionConfig<'a> {
pub sudo_password: Option<Arc<SudoPassword>>,
pub ssh_config: Option<&'a SshConfig>,
/// SSH connection configuration (keepalive settings).
/// Note: This field is currently passed through the executor for future use.
/// Keepalive is applied at the Client::connect_with_ssh_config level.
#[allow(dead_code)]
/// Threaded through to `Client::connect_with_ssh_config` so user-configured
/// `server_alive_interval` / `server_alive_count_max` apply to exec mode.
pub ssh_connection_config: Option<&'a SshConnectionConfig>,
}

Expand Down Expand Up @@ -82,6 +81,7 @@ pub(crate) async fn execute_on_node_with_jump_hosts(
timeout_seconds: config.timeout,
connect_timeout_seconds: config.connect_timeout,
jump_hosts_spec: effective_jump_hosts,
ssh_connection_config: config.ssh_connection_config,
};

// If sudo password is provided, use streaming execution to handle prompts
Expand Down
2 changes: 2 additions & 0 deletions src/executor/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ impl ParallelExecutor {
let jump_hosts = self.jump_hosts.clone();
let sudo_password = self.sudo_password.clone();
let semaphore = Arc::clone(&semaphore);
let ssh_connection_config = self.ssh_connection_config.clone();

let handle = tokio::spawn(async move {
// Use defer pattern to ensure cleanup even on panic
Expand Down Expand Up @@ -1177,6 +1178,7 @@ impl ParallelExecutor {
timeout_seconds: timeout,
connect_timeout_seconds: connect_timeout,
jump_hosts_spec: jump_hosts.as_deref(),
ssh_connection_config: Some(&ssh_connection_config),
};

// Execute with or without sudo password support
Expand Down
4 changes: 4 additions & 0 deletions src/ssh/client/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl SshClient {
timeout_seconds,
connect_timeout_seconds: None, // Use default
jump_hosts_spec: None, // No jump hosts
ssh_connection_config: None,
};

self.connect_and_execute_with_jump_hosts(command, &config)
Expand Down Expand Up @@ -101,6 +102,7 @@ impl SshClient {
config.use_agent,
config.use_password,
config.connect_timeout_seconds,
config.ssh_connection_config,
)
.await?;

Expand Down Expand Up @@ -211,6 +213,7 @@ impl SshClient {
config.use_agent,
config.use_password,
config.connect_timeout_seconds,
config.ssh_connection_config,
)
.await?;

Expand Down Expand Up @@ -322,6 +325,7 @@ impl SshClient {
config.use_agent,
config.use_password,
config.connect_timeout_seconds,
config.ssh_connection_config,
)
.await?;

Expand Down
3 changes: 3 additions & 0 deletions src/ssh/client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use crate::ssh::known_hosts::StrictHostKeyChecking;
use crate::ssh::tokio_client::SshConnectionConfig;
use std::path::Path;

/// Configuration for SSH connection and command execution
Expand All @@ -27,4 +28,6 @@ pub struct ConnectionConfig<'a> {
pub timeout_seconds: Option<u64>,
pub connect_timeout_seconds: Option<u64>,
pub jump_hosts_spec: Option<&'a str>,
/// SSH keepalive / inactivity settings. `None` falls back to defaults.
pub ssh_connection_config: Option<&'a SshConnectionConfig>,
}
46 changes: 39 additions & 7 deletions src/ssh/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use super::core::SshClient;
use crate::jump::{parse_jump_hosts, JumpHostChain};
use crate::ssh::known_hosts::StrictHostKeyChecking;
use crate::ssh::tokio_client::{AuthMethod, Client};
use crate::ssh::tokio_client::{AuthMethod, Client, SshConnectionConfig};
use anyhow::{Context, Result};
use std::path::Path;
use std::time::Duration;
Expand Down Expand Up @@ -65,6 +65,7 @@ impl SshClient {
auth_method: &AuthMethod,
strict_mode: StrictHostKeyChecking,
connect_timeout_seconds: Option<u64>,
ssh_connection_config: Option<&SshConnectionConfig>,
) -> Result<Client> {
// SECURITY: Add rate limiting before connection attempts
const RATE_LIMIT_DELAY: Duration = Duration::from_millis(100);
Expand All @@ -79,9 +80,24 @@ impl SshClient {
let connect_timeout =
Duration::from_secs(connect_timeout_seconds.unwrap_or(SSH_CONNECT_TIMEOUT_SECS));

let default_conn_cfg;
let conn_cfg = match ssh_connection_config {
Some(c) => c,
None => {
default_conn_cfg = SshConnectionConfig::default();
&default_conn_cfg
}
};

let result = match tokio::time::timeout(
connect_timeout,
Client::connect(addr, &self.username, auth_method.clone(), check_method),
Client::connect_with_ssh_config(
addr,
&self.username,
auth_method.clone(),
check_method,
conn_cfg,
),
)
.await
{
Expand Down Expand Up @@ -148,13 +164,17 @@ impl SshClient {
use_agent: bool,
use_password: bool,
connect_timeout_seconds: Option<u64>,
ssh_connection_config: Option<&SshConnectionConfig>,
) -> Result<Client> {
// Create jump host chain with user-specified or default connect timeout
let connect_timeout =
Duration::from_secs(connect_timeout_seconds.unwrap_or(SSH_CONNECT_TIMEOUT_SECS));
let chain = JumpHostChain::new(jump_hosts.to_vec())
let mut chain = JumpHostChain::new(jump_hosts.to_vec())
.with_connect_timeout(connect_timeout)
.with_command_timeout(Duration::from_secs(300));
if let Some(cfg) = ssh_connection_config {
chain = chain.with_ssh_connection_config(cfg.clone());
}

// Connect through the chain
let connection = chain
Expand Down Expand Up @@ -195,6 +215,7 @@ impl SshClient {
use_agent: bool,
use_password: bool,
connect_timeout_seconds: Option<u64>,
ssh_connection_config: Option<&SshConnectionConfig>,
) -> Result<Client> {
if let Some(jump_spec) = jump_hosts_spec {
// Parse jump hosts
Expand All @@ -204,8 +225,13 @@ impl SshClient {

if jump_hosts.is_empty() {
tracing::debug!("No valid jump hosts found, using direct connection");
self.connect_direct(auth_method, strict_mode, connect_timeout_seconds)
.await
self.connect_direct(
auth_method,
strict_mode,
connect_timeout_seconds,
ssh_connection_config,
)
.await
} else {
tracing::info!(
"Connecting to {}:{} via {} jump host(s): {}",
Expand All @@ -227,14 +253,20 @@ impl SshClient {
use_agent,
use_password,
connect_timeout_seconds,
ssh_connection_config,
)
.await
}
} else {
// Direct connection
tracing::debug!("Using direct connection (no jump hosts)");
self.connect_direct(auth_method, strict_mode, connect_timeout_seconds)
.await
self.connect_direct(
auth_method,
strict_mode,
connect_timeout_seconds,
ssh_connection_config,
)
.await
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ssh/client/file_transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ impl SshClient {
use_agent,
use_password,
connect_timeout_seconds,
None,
)
.await
}
Expand Down
89 changes: 86 additions & 3 deletions src/ssh/tokio_client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,53 @@ impl SshConnectionConfig {
}

/// Convert this configuration to a russh client Config.
///
/// When keepalive is enabled, `inactivity_timeout` is set to `None` so the
/// keepalive mechanism is the sole dead-peer detector. russh's default
/// `inactivity_timeout` is 10 minutes and would otherwise tear down an
/// otherwise-healthy idle session at that mark regardless of keepalive
/// liveness. When keepalive is disabled, we preserve a generous
/// inactivity timeout so truly dead sockets are still reaped.
pub fn to_russh_config(&self) -> Config {
let inactivity_timeout = if self.keepalive_interval.is_some() {
None
} else {
Some(Duration::from_secs(3600))
};
Config {
keepalive_interval: self.keepalive_interval.map(Duration::from_secs),
keepalive_max: self.keepalive_max,
inactivity_timeout,
..Default::default()
}
}

/// Derive a TCP-level keepalive configuration from this SSH keepalive
/// configuration. Returns `None` if SSH keepalive is disabled.
///
/// TCP keepalive is a belt-and-suspenders mechanism: it lets the kernel
/// detect a broken TCP path even when no application data is flowing and
/// even if SSH-level keepalive replies are dropped by a middlebox.
pub fn to_tcp_keepalive(&self) -> Option<socket2::TcpKeepalive> {
let interval = self.keepalive_interval?;
// Start probing after `interval` seconds of idleness, probe every
// half-interval, up to keepalive_max retries.
let probe_interval = (interval / 2).max(1);
let ka = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(interval))
.with_interval(Duration::from_secs(probe_interval));
#[cfg(any(
target_os = "linux",
target_os = "macos",
target_os = "freebsd",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "ios",
))]
let ka = ka.with_retries(self.keepalive_max.max(1) as u32);
Some(ka)
}
}
use super::ToSocketAddrsWithHostname;

Expand Down Expand Up @@ -213,7 +253,16 @@ impl Client {
ssh_config: &SshConnectionConfig,
) -> Result<Self, super::Error> {
let config = ssh_config.to_russh_config();
Self::connect_with_config(addr, username, auth, server_check, config).await
let tcp_keepalive = ssh_config.to_tcp_keepalive();
Self::connect_with_config_inner(
addr,
username,
auth,
server_check,
config,
tcp_keepalive.as_ref(),
)
.await
}

/// Same as `connect`, but with the option to specify a non default
Expand All @@ -227,14 +276,28 @@ impl Client {
auth: AuthMethod,
server_check: ServerCheckMethod,
config: Config,
) -> Result<Self, super::Error> {
Self::connect_with_config_inner(addr, username, auth, server_check, config, None).await
}

async fn connect_with_config_inner(
addr: impl ToSocketAddrsWithHostname,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
config: Config,
tcp_keepalive: Option<&socket2::TcpKeepalive>,
) -> Result<Self, super::Error> {
let config = Arc::new(config);

// Connection code inspired from std::net::TcpStream::connect and std::net::each_addr
let socket_addrs = addr
.to_socket_addrs()
.map_err(super::Error::AddressInvalid)?;
let mut connect_res = Err(super::Error::AddressInvalid(io::Error::new(
let mut connect_res: Result<
(SocketAddr, russh::client::Handle<ClientHandler>),
super::Error,
> = Err(super::Error::AddressInvalid(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)));
Expand All @@ -244,7 +307,27 @@ impl Client {
host: socket_addr,
server_check: server_check.clone(),
};
match russh::client::connect(config.clone(), socket_addr, handler).await {

let stream = match tokio::net::TcpStream::connect(socket_addr).await {
Ok(s) => s,
Err(e) => {
connect_res = Err(super::Error::IoError(e));
continue;
}
};

if let Some(ka) = tcp_keepalive {
let sock_ref = socket2::SockRef::from(&stream);
if let Err(e) = sock_ref.set_tcp_keepalive(ka) {
tracing::debug!(
"Failed to set TCP keepalive on socket to {}: {}",
socket_addr,
e
);
}
}

match russh::client::connect_stream(config.clone(), stream, handler).await {
Ok(h) => {
connect_res = Ok((socket_addr, h));
break;
Expand Down
Loading