Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions src/commands/interactive/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,31 @@ use russh::client::Msg;
use russh::Channel;
use std::io::{self, Write};
use tokio::time::{timeout, Duration};
use zeroize::Zeroizing;

use crate::jump::{parse_jump_hosts, JumpHostChain};
use crate::node::Node;
use crate::ssh::{
known_hosts::get_check_method,
tokio_client::{AuthMethod, Client, ServerCheckMethod},
tokio_client::{AuthMethod, Client, Error as SshError, ServerCheckMethod},
};

use super::types::{InteractiveCommand, NodeSession};

impl InteractiveCommand {
/// Helper function to establish SSH connection with proper error handling and rate limiting
/// This eliminates code duplication across different connection paths and prevents brute-force attacks
///
/// If `allow_password_fallback` is true and key authentication fails, it will prompt for password
/// and retry with password authentication (matching OpenSSH behavior).
async fn establish_connection(
addr: (&str, u16),
username: &str,
auth_method: AuthMethod,
check_method: ServerCheckMethod,
host: &str,
port: u16,
allow_password_fallback: bool,
) -> Result<Client> {
const SSH_CONNECT_TIMEOUT_SECS: u64 = 30;
let connect_timeout = Duration::from_secs(SSH_CONNECT_TIMEOUT_SECS);
Expand All @@ -56,15 +61,47 @@ impl InteractiveCommand {

let result = timeout(
connect_timeout,
Client::connect(addr, username, auth_method, check_method),
Client::connect(addr, username, auth_method, check_method.clone()),
)
.await
.with_context(|| {
format!(
"Connection timeout: Failed to connect to {host}:{port} after {SSH_CONNECT_TIMEOUT_SECS} seconds"
)
})?
.with_context(|| format!("SSH connection failed to {host}:{port}"));
})?;

// Check if key authentication failed and password fallback is allowed
let result = match result {
Err(SshError::KeyAuthFailed)
if allow_password_fallback && atty::is(atty::Stream::Stdin) =>
{
tracing::debug!(
"SSH key authentication failed for {username}@{host}:{port}, attempting password fallback"
);

// Prompt for password (matching OpenSSH behavior)
let password = Self::prompt_password(username, host).await?;

// Retry with password authentication
let password_auth = AuthMethod::with_password(&password);

// Small delay before retry to prevent rapid attempts
tokio::time::sleep(Duration::from_millis(500)).await;

timeout(
connect_timeout,
Client::connect(addr, username, password_auth, check_method),
)
.await
.with_context(|| {
format!(
"Connection timeout: Failed to connect to {host}:{port} after {SSH_CONNECT_TIMEOUT_SECS} seconds"
)
})?
.with_context(|| format!("SSH connection failed to {host}:{port}"))
}
other => other.with_context(|| format!("SSH connection failed to {host}:{port}")),
};

// SECURITY: Normalize timing to prevent timing attacks
// Ensure all authentication attempts take at least 500ms to complete
Expand All @@ -79,6 +116,22 @@ impl InteractiveCommand {
result
}

/// Prompt for password with secure handling
async fn prompt_password(username: &str, host: &str) -> Result<Zeroizing<String>> {
let username = username.to_string();
let host = host.to_string();

tokio::task::spawn_blocking(move || {
let password = Zeroizing::new(
rpassword::prompt_password(format!("{username}@{host}'s password: "))
.with_context(|| "Failed to read password")?,
);
Ok(password)
})
.await
.with_context(|| "Password prompt task failed")?
}

/// Determine authentication method based on node and config (same logic as exec mode)
pub(super) async fn determine_auth_method(&self, node: &Node) -> Result<AuthMethod> {
// Use centralized authentication logic from auth module
Expand Down Expand Up @@ -164,13 +217,15 @@ impl InteractiveCommand {
tracing::debug!("No valid jump hosts found, using direct connection");

// Use the helper function to establish connection
// Enable password fallback for interactive mode (matches OpenSSH behavior)
Self::establish_connection(
addr,
&node.username,
auth_method.clone(),
check_method.clone(),
&node.host,
node.port,
!self.use_password, // Allow fallback unless explicit password mode
)
.await?
} else {
Expand Down Expand Up @@ -239,13 +294,15 @@ impl InteractiveCommand {
tracing::debug!("Using direct connection (no jump hosts)");

// Use the helper function to establish connection
// Enable password fallback for interactive mode (matches OpenSSH behavior)
Self::establish_connection(
addr,
&node.username,
auth_method,
check_method,
&node.host,
node.port,
!self.use_password, // Allow fallback unless explicit password mode
)
.await?
};
Expand Down Expand Up @@ -300,13 +357,15 @@ impl InteractiveCommand {
tracing::debug!("No valid jump hosts found, using direct connection for PTY");

// Use the helper function to establish connection
// Enable password fallback for interactive mode (matches OpenSSH behavior)
Self::establish_connection(
addr,
&node.username,
auth_method.clone(),
check_method.clone(),
&node.host,
node.port,
!self.use_password, // Allow fallback unless explicit password mode
)
.await?
} else {
Expand Down Expand Up @@ -375,13 +434,15 @@ impl InteractiveCommand {
tracing::debug!("Using direct connection for PTY (no jump hosts)");

// Use the helper function to establish connection
// Enable password fallback for interactive mode (matches OpenSSH behavior)
Self::establish_connection(
addr,
&node.username,
auth_method,
check_method,
&node.host,
node.port,
!self.use_password, // Allow fallback unless explicit password mode
)
.await?
};
Expand Down
11 changes: 8 additions & 3 deletions src/ssh/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,20 @@ impl AuthContext {
}
}

// Priority 3: Key file authentication
// Priority 3: Key file authentication (explicit -i flag)
if let Some(ref key_path) = self.key_path {
return self.key_file_auth(key_path).await;
}

// Priority 4: SSH agent auto-detection (if use_agent is true)
// Priority 4: SSH agent auto-detection (like OpenSSH behavior)
// OpenSSH tries SSH agent first when available, as it can try all registered keys
#[cfg(not(target_os = "windows"))]
if self.use_agent {
if !self.use_agent {
// Auto-detect SSH agent even without --use-agent flag
if let Some(auth) = self.agent_auth()? {
tracing::debug!(
"Using SSH agent (auto-detected) - agent will try all registered keys"
);
return Ok(auth);
}
}
Expand Down
21 changes: 15 additions & 6 deletions src/utils/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,24 @@
use tracing_subscriber::EnvFilter;

pub fn init_logging(verbosity: u8) {
let filter = match verbosity {
0 => EnvFilter::new("bssh=warn"),
1 => EnvFilter::new("bssh=info"),
2 => EnvFilter::new("bssh=debug"),
_ => EnvFilter::new("bssh=trace"),
// Priority: RUST_LOG environment variable > verbosity flag
let filter = if std::env::var("RUST_LOG").is_ok() {
// Use RUST_LOG if set (allows debugging russh and other dependencies)
EnvFilter::from_default_env()
} else {
// Fall back to verbosity-based filter
match verbosity {
0 => EnvFilter::new("bssh=warn"),
1 => EnvFilter::new("bssh=info"),
// -vv: Include russh debug logs for SSH troubleshooting
2 => EnvFilter::new("bssh=debug,russh=debug"),
// -vvv: Full trace including all dependencies
_ => EnvFilter::new("bssh=trace,russh=trace,russh_sftp=debug"),
}
};

tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.with_target(true) // Show module targets for better debugging
.init();
}