diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 90b09d724c3..753b27a3f06 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -74,6 +74,11 @@ use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use codex_config::types::OAuthCredentialsStoreMode; +#[path = "streamable_http_retry.rs"] +mod streamable_http_retry; + +use self::streamable_http_retry::HandshakeError; + enum PendingTransport { InProcess { transport: tokio::io::DuplexStream, @@ -396,9 +401,13 @@ impl RmcpClient { } }; - let (service, oauth_persistor) = - Self::connect_pending_transport(pending_transport, client_service.clone(), timeout) - .await?; + let (service, oauth_persistor) = self + .connect_pending_transport_with_initialize_retries( + pending_transport, + client_service.clone(), + timeout, + ) + .await?; let initialize_result_rmcp = service .peer() @@ -849,10 +858,10 @@ impl RmcpClient { Some(duration) => time::timeout(duration, transport) .await .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? 
- .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + .map_err(|source| HandshakeError { source })?, None => transport .await - .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + .map_err(|source| HandshakeError { source })?, }; Ok((Arc::new(service), oauth_persistor)) diff --git a/codex-rs/rmcp-client/src/streamable_http_retry.rs b/codex-rs/rmcp-client/src/streamable_http_retry.rs new file mode 100644 index 00000000000..ebca075cb8a --- /dev/null +++ b/codex-rs/rmcp-client/src/streamable_http_retry.rs @@ -0,0 +1,169 @@ +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use anyhow::Result; +use anyhow::anyhow; +use codex_exec_server::ExecServerError; +use rmcp::service::RoleClient; +use rmcp::service::RunningService; +use rmcp::transport::streamable_http_client::StreamableHttpError; +use tokio::time; +use tracing::warn; + +use crate::elicitation_client_service::ElicitationClientService; +use crate::http_client_adapter::StreamableHttpClientAdapterError; +use crate::oauth::OAuthPersistor; + +use super::PendingTransport; +use super::RmcpClient; + +const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; +const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; + +impl RmcpClient { + pub(super) async fn connect_pending_transport_with_initialize_retries( + &self, + initial_transport: PendingTransport, + client_service: ElicitationClientService, + timeout: Option<Duration>, + ) -> Result<( + Arc<RunningService<RoleClient, ElicitationClientService>>, + Option<OAuthPersistor>, + )> { + let should_retry = match &initial_transport { + PendingTransport::InProcess { .. } | PendingTransport::Stdio { .. } => false, + PendingTransport::StreamableHttp { .. } + | PendingTransport::StreamableHttpWithOAuth { .. 
} => true, + }; + let retry_deadline = timeout.map(|duration| Instant::now() + duration); + let mut pending_transport = Some(initial_transport); + + for (attempt, retry_delay_ms) in STREAMABLE_HTTP_RETRY_DELAYS_MS + .iter() + .copied() + .map(Some) + .chain(std::iter::once(None)) + .enumerate() + { + let attempt_timeout = + retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); + if let Some(remaining) = attempt_timeout + && remaining.is_zero() + { + let duration = timeout.unwrap_or(remaining); + return Err(anyhow!( + "timed out handshaking with MCP server after {duration:?}" + )); + } + + let transport = match pending_transport.take() { + Some(transport) => transport, + None => Self::create_pending_transport(&self.transport_recipe).await?, + }; + + match Self::connect_pending_transport( + transport, + client_service.clone(), + attempt_timeout, + ) + .await + { + Ok(result) => return Ok(result), + Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => { + let Some(retry_delay_ms) = retry_delay_ms else { + return Err(error); + }; + let delay = Duration::from_millis(retry_delay_ms); + warn!( + attempt = attempt + 1, + max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, + delay_ms = delay.as_millis(), + error = %error, + "streamable HTTP MCP initialize failed with a retryable error; retrying" + ); + if !sleep_with_retry_deadline(delay, retry_deadline).await { + let duration = timeout.unwrap_or(delay); + return Err(anyhow!( + "timed out handshaking with MCP server after {duration:?}" + )); + } + } + Err(error) => return Err(error), + } + } + + unreachable!("initialize retry loop should return on success or final error") + } + + fn is_retryable_initialize_error(error: &anyhow::Error) -> bool { + error.chain().any(|source| { + source + .downcast_ref::<HandshakeError>() + .is_some_and(|error| Self::is_retryable_client_initialize_error(&error.source)) + || source + .downcast_ref::<rmcp::service::ClientInitializeError>() + 
.is_some_and(Self::is_retryable_client_initialize_error) + }) + } + + fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool { + match error { + rmcp::service::ClientInitializeError::TransportError { error, context } + if matches!( + context.as_ref(), + "send initialize request" | "send initialized notification" + ) => + { + error + .error + .downcast_ref::<StreamableHttpError<StreamableHttpClientAdapterError>>() + .is_some_and(Self::is_retryable_streamable_http_error) + } + _ => false, + } + } + + fn is_retryable_streamable_http_error( + error: &StreamableHttpError<StreamableHttpClientAdapterError>, + ) -> bool { + match error { + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::HttpRequest(_), + )) => true, + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Server { code, message }, + )) => { + *code == JSON_RPC_INTERNAL_ERROR_CODE && message.starts_with("http/request failed:") + } + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Protocol(message), + )) => message.starts_with("http response stream `") && message.contains("` failed:"), + _ => false, + } + } +} + +async fn sleep_with_retry_deadline(delay: Duration, deadline: Option<Instant>) -> bool { + if let Some(deadline) = deadline { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return false; + } + time::timeout(remaining, time::sleep(delay)).await.is_ok() + } else { + time::sleep(delay).await; + true + } +} + +#[derive(Debug, thiserror::Error)] +#[error("handshaking with MCP server failed: {source}")] +pub(super) struct HandshakeError { + #[source] + pub(super) source: rmcp::service::ClientInitializeError, +} + +#[cfg(test)] +#[path = "streamable_http_retry_tests.rs"] +mod tests; diff --git a/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs b/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs new file mode 100644 index 00000000000..8d2e6ecabf1 --- /dev/null +++ 
b/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs @@ -0,0 +1,72 @@ +use std::any::TypeId; + +use codex_exec_server::ExecServerError; +use pretty_assertions::assert_eq; +use rmcp::transport::DynamicTransportError; +use rmcp::transport::streamable_http_client::StreamableHttpError; + +use crate::http_client_adapter::StreamableHttpClientAdapterError; + +use super::*; + +#[test] +fn retryable_initialize_error_includes_initialized_notification_context() { + let contexts = [ + "send initialize request", + "send initialized notification", + "receive initialize response", + ]; + + assert_eq!( + contexts.map(|context| { + RmcpClient::is_retryable_client_initialize_error(&retryable_initialize_error(context)) + }), + [true, true, false], + ); +} + +#[test] +fn retryable_streamable_http_error_includes_remote_body_stream_failure() { + let errors = [ + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::HttpRequest("error sending request for url".to_string()), + )), + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Server { + code: JSON_RPC_INTERNAL_ERROR_CODE, + message: "http/request failed: error sending request for url".to_string(), + }, + )), + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Protocol( + "http response stream `http-1` failed: exec-server transport disconnected" + .to_string(), + ), + )), + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Protocol( + "http response stream `http-1` received seq 2, expected 1".to_string(), + ), + )), + ]; + + assert_eq!( + errors.map(|error| RmcpClient::is_retryable_streamable_http_error(&error)), + [true, true, true, false], + ); +} + +fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientInitializeError { + rmcp::service::ClientInitializeError::TransportError { + error: DynamicTransportError::from_parts( + 
"streamable_http", + TypeId::of::<()>(), + Box::new(StreamableHttpError::Client( + StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest( + "error sending request for url".to_string(), + )), + )), + ), + context: context.into(), + } +} diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 087d3d00df6..571d9adb8df 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,13 +1,114 @@ mod streamable_http_test_support; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; + +use codex_exec_server::Environment; +use codex_exec_server::ExecServerError; +use codex_exec_server::HttpClient; +use codex_exec_server::HttpRequestParams; +use codex_exec_server::HttpRequestResponse; +use codex_exec_server::HttpResponseBodyStream; +use futures::FutureExt as _; +use futures::future::BoxFuture; use pretty_assertions::assert_eq; +use serde_json::Value; use streamable_http_test_support::arm_session_post_failure; use streamable_http_test_support::call_echo_tool; use streamable_http_test_support::create_client; +use streamable_http_test_support::create_client_with_http_client; use streamable_http_test_support::expected_echo_result; use streamable_http_test_support::spawn_streamable_http_server; +const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; +const SIMULATED_NO_RESPONSE_MESSAGE: &str = + "http/request failed: error sending request for url (simulated no response)"; + +#[derive(Clone)] +struct FailFirstInitializeHttpClient { + inner: Arc<dyn HttpClient>, + failures_remaining: Arc<AtomicUsize>, + initialize_attempts: Arc<AtomicUsize>, +} + +impl FailFirstInitializeHttpClient { + fn new(inner: Arc<dyn HttpClient>) -> Self { + Self { + inner, + failures_remaining: Arc::new(AtomicUsize::new(1)), + initialize_attempts: Arc::new(AtomicUsize::new(0)), + } + } + + fn initialize_attempts(&self) -> usize { + 
self.initialize_attempts.load(Ordering::SeqCst) + } +} + +impl HttpClient for FailFirstInitializeHttpClient { + fn http_request( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> { + self.inner.http_request(params) + } + + fn http_request_stream( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { + let inner = Arc::clone(&self.inner); + let failures_remaining = Arc::clone(&self.failures_remaining); + let initialize_attempts = Arc::clone(&self.initialize_attempts); + + async move { + if is_initialize_post(&params) { + initialize_attempts.fetch_add(1, Ordering::SeqCst); + if failures_remaining.swap(0, Ordering::SeqCst) > 0 { + return Err(ExecServerError::Server { + code: JSON_RPC_INTERNAL_ERROR_CODE, + message: SIMULATED_NO_RESPONSE_MESSAGE.to_string(), + }); + } + } + + inner.http_request_stream(params).await + } + .boxed() + } +} + +fn is_initialize_post(params: &HttpRequestParams) -> bool { + params.method.eq_ignore_ascii_case("POST") + && params + .body + .as_ref() + .and_then(|body| serde_json::from_slice::<Value>(&body.0).ok()) + .and_then(|body| { + body.get("method") + .and_then(Value::as_str) + .map(|method| method == "initialize") + }) + .unwrap_or(false) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_initialize_retries_remote_no_response_error() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let http_client = + FailFirstInitializeHttpClient::new(Environment::default_for_tests().get_http_client()); + + let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; + let result = call_echo_tool(&client, "after-init-retry").await?; + + assert_eq!(http_client.initialize_attempts(), 2); + assert_eq!(result, expected_echo_result("after-init-retry")); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn 
streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs index 822acef1a26..b3cefa4296a 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -20,6 +20,7 @@ use anyhow::Context as _; use codex_config::types::OAuthCredentialsStoreMode; use codex_exec_server::Environment; use codex_exec_server::ExecServerClient; +use codex_exec_server::HttpClient; use codex_exec_server::RemoteExecServerConnectArgs; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; @@ -74,6 +75,14 @@ pub(crate) fn expected_echo_result(message: &str) -> CallToolResult { } pub(crate) async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> { + create_client_with_http_client(base_url, Environment::default_for_tests().get_http_client()) + .await +} + +pub(crate) async fn create_client_with_http_client( + base_url: &str, + http_client: Arc<dyn HttpClient>, +) -> anyhow::Result<RmcpClient> { let client = RmcpClient::new_streamable_http_client( "test-streamable-http", &format!("{base_url}/mcp"), @@ -81,7 +90,7 @@ pub(crate) async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> /*http_headers*/ None, /*env_http_headers*/ None, OAuthCredentialsStoreMode::File, - Environment::default_for_tests().get_http_client(), + http_client, /*auth_provider*/ None, ) .await?;