Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions codex-rs/rmcp-client/src/rmcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;

#[path = "streamable_http_retry.rs"]
mod streamable_http_retry;

use self::streamable_http_retry::HandshakeError;

enum PendingTransport {
InProcess {
transport: tokio::io::DuplexStream,
Expand Down Expand Up @@ -396,9 +401,13 @@ impl RmcpClient {
}
};

let (service, oauth_persistor) =
Self::connect_pending_transport(pending_transport, client_service.clone(), timeout)
.await?;
let (service, oauth_persistor) = self
.connect_pending_transport_with_initialize_retries(
pending_transport,
client_service.clone(),
timeout,
)
.await?;

let initialize_result_rmcp = service
.peer()
Expand Down Expand Up @@ -849,10 +858,10 @@ impl RmcpClient {
Some(duration) => time::timeout(duration, transport)
.await
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
.map_err(|source| HandshakeError { source })?,
None => transport
.await
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
.map_err(|source| HandshakeError { source })?,
};

Ok((Arc::new(service), oauth_persistor))
Expand Down
169 changes: 169 additions & 0 deletions codex-rs/rmcp-client/src/streamable_http_retry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use anyhow::Result;
use anyhow::anyhow;
use codex_exec_server::ExecServerError;
use rmcp::service::RoleClient;
use rmcp::service::RunningService;
use rmcp::transport::streamable_http_client::StreamableHttpError;
use tokio::time;
use tracing::warn;

use crate::elicitation_client_service::ElicitationClientService;
use crate::http_client_adapter::StreamableHttpClientAdapterError;
use crate::oauth::OAuthPersistor;

use super::PendingTransport;
use super::RmcpClient;

/// JSON-RPC 2.0 "Internal error" code. An `ExecServerError::Server` failure is only
/// considered for retry when it carries this code.
const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603;
/// Backoff delays, in milliseconds, slept before each retry of the initialize handshake.
/// The handshake is attempted once up front and then once more per entry in this array.
const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000];

impl RmcpClient {
    /// Connects `initial_transport` and runs the MCP initialize handshake, retrying
    /// transient failures for streamable HTTP transports.
    ///
    /// The first attempt consumes `initial_transport`; each retry builds a fresh transport
    /// from `self.transport_recipe`. In-process and stdio transports are never retried.
    /// There is one initial attempt plus one retry per entry in
    /// `STREAMABLE_HTTP_RETRY_DELAYS_MS`, and that entry's delay is slept before the retry.
    ///
    /// `timeout`, when set, is a single budget covering every attempt and every backoff
    /// sleep: each attempt is only given the time that is still remaining.
    ///
    /// # Errors
    ///
    /// Returns the error of the failing attempt when it is not retryable or when the
    /// retries are exhausted, any error from re-creating the transport, or a
    /// "timed out handshaking with MCP server" error once the budget is spent.
    pub(super) async fn connect_pending_transport_with_initialize_retries(
        &self,
        initial_transport: PendingTransport,
        client_service: ElicitationClientService,
        timeout: Option<Duration>,
    ) -> Result<(
        Arc<RunningService<RoleClient, ElicitationClientService>>,
        Option<OAuthPersistor>,
    )> {
        let should_retry = match &initial_transport {
            PendingTransport::InProcess { .. } | PendingTransport::Stdio { .. } => false,
            PendingTransport::StreamableHttp { .. }
            | PendingTransport::StreamableHttpWithOAuth { .. } => true,
        };
        // `Instant + Duration` panics when the sum is not representable (for example a
        // `Duration::MAX` "effectively no timeout" value). `checked_add` yields `None` in
        // that case, which is handled as "no deadline" instead of panicking.
        let retry_deadline = timeout.and_then(|duration| Instant::now().checked_add(duration));
        let mut pending_transport = Some(initial_transport);

        // Pair every attempt with the delay to sleep if it fails; the trailing `None`
        // marks the final attempt, after which no further retry is made.
        for (attempt, retry_delay_ms) in STREAMABLE_HTTP_RETRY_DELAYS_MS
            .iter()
            .copied()
            .map(Some)
            .chain(std::iter::once(None))
            .enumerate()
        {
            let attempt_timeout =
                retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
            if let Some(remaining) = attempt_timeout
                && remaining.is_zero()
            {
                // Report the configured budget rather than the (zero) remainder.
                let duration = timeout.unwrap_or(remaining);
                return Err(anyhow!(
                    "timed out handshaking with MCP server after {duration:?}"
                ));
            }

            // The caller's transport is used exactly once; retries need a new one.
            let transport = match pending_transport.take() {
                Some(transport) => transport,
                None => Self::create_pending_transport(&self.transport_recipe).await?,
            };

            match Self::connect_pending_transport(
                transport,
                client_service.clone(),
                attempt_timeout,
            )
            .await
            {
                Ok(result) => return Ok(result),
                Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => {
                    // No delay left means this was the last attempt: surface its error.
                    let Some(retry_delay_ms) = retry_delay_ms else {
                        return Err(error);
                    };
                    let delay = Duration::from_millis(retry_delay_ms);
                    warn!(
                        attempt = attempt + 1,
                        max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1,
                        delay_ms = delay.as_millis(),
                        error = %error,
                        "streamable HTTP MCP initialize failed with a retryable error; retrying"
                    );
                    if !sleep_with_retry_deadline(delay, retry_deadline).await {
                        let duration = timeout.unwrap_or(delay);
                        return Err(anyhow!(
                            "timed out handshaking with MCP server after {duration:?}"
                        ));
                    }
                }
                Err(error) => return Err(error),
            }
        }

        // The final iteration carries `None` as its delay, so every path through it returns.
        unreachable!("initialize retry loop should return on success or final error")
    }

    /// Returns `true` when any error in `error`'s source chain is a [`HandshakeError`] or a
    /// bare `ClientInitializeError` that
    /// [`Self::is_retryable_client_initialize_error`] accepts.
    fn is_retryable_initialize_error(error: &anyhow::Error) -> bool {
        error.chain().any(|source| {
            source
                .downcast_ref::<HandshakeError>()
                .is_some_and(|error| Self::is_retryable_client_initialize_error(&error.source))
                || source
                    .downcast_ref::<rmcp::service::ClientInitializeError>()
                    .is_some_and(Self::is_retryable_client_initialize_error)
        })
    }

    /// Returns `true` for a transport error raised while sending the initialize request or
    /// the initialized notification, provided the wrapped error is a streamable HTTP error
    /// that [`Self::is_retryable_streamable_http_error`] accepts.
    ///
    /// Every other variant and every other context is treated as not retryable.
    fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool {
        match error {
            rmcp::service::ClientInitializeError::TransportError { error, context }
                if matches!(
                    context.as_ref(),
                    "send initialize request" | "send initialized notification"
                ) =>
            {
                error
                    .error
                    .downcast_ref::<StreamableHttpError<StreamableHttpClientAdapterError>>()
                    .is_some_and(Self::is_retryable_streamable_http_error)
            }
            _ => false,
        }
    }

    /// Classifies a streamable HTTP client error as retryable.
    ///
    /// Three shapes of `ExecServerError` wrapped in
    /// `StreamableHttpClientAdapterError::HttpRequest` are accepted:
    /// - `HttpRequest(_)`, regardless of its message;
    /// - `Server` with the JSON-RPC internal error code and a message starting with
    ///   `http/request failed:`;
    /// - `Protocol` with a message of the form ``http response stream `<id>` failed: ...``.
    ///
    /// Everything else is not retryable.
    fn is_retryable_streamable_http_error(
        error: &StreamableHttpError<StreamableHttpClientAdapterError>,
    ) -> bool {
        match error {
            StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest(
                ExecServerError::HttpRequest(_),
            )) => true,
            StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest(
                ExecServerError::Server { code, message },
            )) => {
                *code == JSON_RPC_INTERNAL_ERROR_CODE && message.starts_with("http/request failed:")
            }
            StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest(
                ExecServerError::Protocol(message),
            )) => message.starts_with("http response stream `") && message.contains("` failed:"),
            _ => false,
        }
    }
}

/// Sleeps for `delay` without running past `deadline`.
///
/// Returns `true` when the full delay elapsed. Returns `false` when the deadline had
/// already passed, or when it arrived before the delay finished. Without a deadline the
/// full delay is always slept and the result is `true`.
async fn sleep_with_retry_deadline(delay: Duration, deadline: Option<Instant>) -> bool {
    let Some(cutoff) = deadline else {
        time::sleep(delay).await;
        return true;
    };

    let budget = cutoff.saturating_duration_since(Instant::now());
    if budget.is_zero() {
        return false;
    }

    // The sleep is cut short once the remaining budget is used up.
    let outcome = time::timeout(budget, time::sleep(delay)).await;
    outcome.is_ok()
}

/// Error produced when the MCP initialize handshake fails.
///
/// The typed `ClientInitializeError` is kept as `source` so the retry classification can
/// downcast to it instead of inspecting a formatted message.
#[derive(Debug, thiserror::Error)]
#[error("handshaking with MCP server failed: {source}")]
pub(super) struct HandshakeError {
/// The underlying rmcp initialize failure.
#[source]
pub(super) source: rmcp::service::ClientInitializeError,
}

#[cfg(test)]
#[path = "streamable_http_retry_tests.rs"]
mod tests;
72 changes: 72 additions & 0 deletions codex-rs/rmcp-client/src/streamable_http_retry_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use std::any::TypeId;

use codex_exec_server::ExecServerError;
use pretty_assertions::assert_eq;
use rmcp::transport::DynamicTransportError;
use rmcp::transport::streamable_http_client::StreamableHttpError;

use crate::http_client_adapter::StreamableHttpClientAdapterError;

use super::*;

/// Only the two "send" phases of the handshake are classified as retryable; a failure
/// while receiving the initialize response is not.
#[test]
fn retryable_initialize_error_includes_initialized_notification_context() {
    let contexts = [
        "send initialize request",
        "send initialized notification",
        "receive initialize response",
    ];

    let verdicts = contexts.map(|context| {
        let error = retryable_initialize_error(context);
        RmcpClient::is_retryable_client_initialize_error(&error)
    });

    assert_eq!(verdicts, [true, true, false]);
}

/// Request failures and body-stream failures are retryable; a body-stream sequencing
/// error is not.
#[test]
fn retryable_streamable_http_error_includes_remote_body_stream_failure() {
    // Every case shares the same two outer layers, so wrap them in one place.
    let http_request_error = |inner: ExecServerError| {
        StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest(inner))
    };

    let errors = [
        http_request_error(ExecServerError::HttpRequest(
            "error sending request for url".to_string(),
        )),
        http_request_error(ExecServerError::Server {
            code: JSON_RPC_INTERNAL_ERROR_CODE,
            message: "http/request failed: error sending request for url".to_string(),
        }),
        http_request_error(ExecServerError::Protocol(
            "http response stream `http-1` failed: exec-server transport disconnected"
                .to_string(),
        )),
        http_request_error(ExecServerError::Protocol(
            "http response stream `http-1` received seq 2, expected 1".to_string(),
        )),
    ];

    let verdicts = errors.map(|error| RmcpClient::is_retryable_streamable_http_error(&error));

    assert_eq!(verdicts, [true, true, true, false]);
}

fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientInitializeError {
rmcp::service::ClientInitializeError::TransportError {
error: DynamicTransportError::from_parts(
"streamable_http",
TypeId::of::<()>(),
Box::new(StreamableHttpError::Client(
StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest(
"error sending request for url".to_string(),
)),
)),
),
context: context.into(),
}
}
101 changes: 101 additions & 0 deletions codex-rs/rmcp-client/tests/streamable_http_recovery.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,114 @@
mod streamable_http_test_support;

use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;

use codex_exec_server::Environment;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpRequestResponse;
use codex_exec_server::HttpResponseBodyStream;
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pretty_assertions::assert_eq;
use serde_json::Value;

use streamable_http_test_support::arm_session_post_failure;
use streamable_http_test_support::call_echo_tool;
use streamable_http_test_support::create_client;
use streamable_http_test_support::create_client_with_http_client;
use streamable_http_test_support::expected_echo_result;
use streamable_http_test_support::spawn_streamable_http_server;

// JSON-RPC 2.0 "Internal error" code carried by the simulated failure.
const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603;
// Message of the simulated failure returned for the first `initialize` POST.
const SIMULATED_NO_RESPONSE_MESSAGE: &str =
"http/request failed: error sending request for url (simulated no response)";

/// Test `HttpClient` wrapper that fails the first streamed `initialize` POST and forwards
/// every other request to `inner`.
///
/// The counters are behind `Arc`, so clones share the same state.
#[derive(Clone)]
struct FailFirstInitializeHttpClient {
/// Client that receives every request that is not failed.
inner: Arc<dyn HttpClient>,
/// Non-zero while the next `initialize` POST should still be failed.
failures_remaining: Arc<AtomicUsize>,
/// Number of `initialize` POSTs seen so far, including the failed one.
initialize_attempts: Arc<AtomicUsize>,
}

impl FailFirstInitializeHttpClient {
    /// Wraps `inner` with one pending failure and an attempt counter starting at zero.
    fn new(inner: Arc<dyn HttpClient>) -> Self {
        let failures_remaining = Arc::new(AtomicUsize::new(1));
        let initialize_attempts = Arc::new(AtomicUsize::new(0));

        Self {
            inner,
            failures_remaining,
            initialize_attempts,
        }
    }

    /// Returns how many `initialize` POSTs have been observed so far.
    fn initialize_attempts(&self) -> usize {
        let counter = &self.initialize_attempts;
        counter.load(Ordering::SeqCst)
    }
}

impl HttpClient for FailFirstInitializeHttpClient {
    /// Non-streaming requests go straight to the wrapped client.
    fn http_request(
        &self,
        params: HttpRequestParams,
    ) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
        self.inner.http_request(params)
    }

    /// Counts every `initialize` POST and fails the first one with a simulated
    /// "no response" server error; all other requests are forwarded unchanged.
    fn http_request_stream(
        &self,
        params: HttpRequestParams,
    ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
        let delegate = Arc::clone(&self.inner);
        let pending_failures = Arc::clone(&self.failures_remaining);
        let attempt_counter = Arc::clone(&self.initialize_attempts);

        async move {
            if !is_initialize_post(&params) {
                return delegate.http_request_stream(params).await;
            }

            attempt_counter.fetch_add(1, Ordering::SeqCst);

            // `swap` clears the pending failure, so only one request can observe it.
            let must_fail = pending_failures.swap(0, Ordering::SeqCst) > 0;
            if must_fail {
                return Err(ExecServerError::Server {
                    code: JSON_RPC_INTERNAL_ERROR_CODE,
                    message: SIMULATED_NO_RESPONSE_MESSAGE.to_string(),
                });
            }

            delegate.http_request_stream(params).await
        }
        .boxed()
    }
}

/// Returns `true` when `params` is a POST whose body parses as JSON with a top-level
/// string `"method"` equal to `"initialize"`.
///
/// A missing body, a body that is not valid JSON, or a missing or non-string `"method"`
/// all yield `false`.
fn is_initialize_post(params: &HttpRequestParams) -> bool {
    if !params.method.eq_ignore_ascii_case("POST") {
        return false;
    }

    let Some(raw_body) = params.body.as_ref() else {
        return false;
    };
    let Ok(payload) = serde_json::from_slice::<Value>(&raw_body.0) else {
        return false;
    };

    payload.get("method").and_then(Value::as_str) == Some("initialize")
}

/// A single simulated "no response" failure on `initialize` is retried: the client still
/// connects, exactly two `initialize` POSTs are seen, and a later tool call succeeds.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_initialize_retries_remote_no_response_error() -> anyhow::Result<()> {
    // `_server` is kept bound so the test server stays alive for the whole test.
    let (_server, base_url) = spawn_streamable_http_server().await?;
    let real_http_client = Environment::default_for_tests().get_http_client();
    let flaky_http_client = FailFirstInitializeHttpClient::new(real_http_client);

    let client =
        create_client_with_http_client(&base_url, Arc::new(flaky_http_client.clone())).await?;
    let echo_result = call_echo_tool(&client, "after-init-retry").await?;

    assert_eq!(flaky_http_client.initialize_attempts(), 2);
    assert_eq!(echo_result, expected_echo_result("after-init-retry"));

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> {
let (_server, base_url) = spawn_streamable_http_server().await?;
Expand Down
Loading
Loading